library(arrow)
library(tidyverse)
library(sf)
library(markovchain)

## DATA ------------------------------------------------------------------------

# Yellow-cab trip records for January 2025 (one row per trip).
# Source: https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page
trips <- arrow::read_parquet("_data/yellow_tripdata_2025-01.parquet")

# quick look at column names and types
trips |> str()

# Taxi zone polygons; draw them once, filled by borough, as an overview map.
zones <- sf::st_read("_data/taxi_zones/taxi_zones.shp", quiet = TRUE)
zones |>
  ggplot() +
  geom_sf(
    mapping = aes(fill = borough),
    color = "white",
    linewidth = 0.1
  ) +
  coord_sf(datum = NA) +
  theme_minimal() +
  guides(fill = guide_none())

# 261 states (pickup location zones), ordered numerically by zone id
# (sorting the character ids would give "1", "10", "100", ...)
states <- trips$PULocationID |>
  unique() |>
  sort() |>
  as.character()

# extract transitions for estimation
mc_dat <- trips |>
  transmute(
    from = as.character(PULocationID),
    to   = as.character(DOLocationID)
  ) |>
  filter(!is.na(from), !is.na(to))

# A drop-off zone that never occurs as a pickup is not a state of the chain.
# factor(to, levels = states) would turn such trips into NA and table() would
# silently discard them, so drop them explicitly and report how many there are.
n_outside <- sum(!(mc_dat$to %in% states))
message(n_outside, " trips end in a zone that is never a pickup zone; ",
        "excluding them from the transition counts")
mc_dat <- filter(mc_dat, to %in% states)

# transition counts: rows = pickup zone (from), columns = drop-off zone (to)
N <- with(mc_dat,
          table(factor(from, levels = states),
                factor(to,   levels = states)))

## CONDITIONAL MLE -------------------------------------------------------------

# conditional mle: row-normalise the counts, p_ij = N_ij / sum_j N_ij.
# A row with no counts would become all NaN (not a valid transition row),
# so fail early with a clear message instead of an opaque validity error.
stopifnot(
  "every state needs at least one retained outgoing trip" = all(rowSums(N) > 0)
)
phat_mle <- prop.table(N, 1) |> unclass()

# manually create markovchain object (phat_mle is already a plain matrix)
mcfit_mle <- new("markovchain",
                 states = rownames(N),
                 byrow = TRUE,
                 transitionMatrix = phat_mle)

# ...but the MLE chain has two degenerate stationary distributions,
# each putting all of its mass on a single state
steadyStates(object = mcfit_mle)

# one per recurrent class
mcfit_mle |> recurrentClasses()

# every remaining state is transient; count them
length(unlist(transientClasses(mcfit_mle)))

# why? first note many zones with very low effective sample size
rowSums(N) |> sort() |> head(20)

# then note high row sparsity (many "impossible" destinations)
zeros_by_row <- rowSums(N == 0)
frac_zero_by_row <- zeros_by_row / ncol(N)
sort(frac_zero_by_row, decreasing = TRUE) |>
  head(20)

# in extreme cases this gives a degenerate mle (only one possible destination)
which(rowSums(N > 0) == 1) |> head()

# and in some of those cases every observed transition is a self loop, so the
# MLE makes the state absorbing (p_ii = 1). Find these states from the data
# rather than hard-coding zone ids that only hold for one month of trips.
absorbing <- rownames(N)[rowSums(N > 0) == 1 & diag(N) > 0]
phat_mle[absorbing, absorbing, drop = FALSE]

## FIX: BAYESIAN ESTIMATION WITH DIRICHLET PRIOR -------------------------------

# Symmetric Dirichlet(alpha, ..., alpha) prior on every row of the transition
# matrix; the posterior mean of p_ij is (N_ij + alpha) / (N_i. + K * alpha).
alpha <- 2  # tune: 0.1, 0.5, 1
K <- nrow(N)
row_tot <- rowSums(N)
phat_bayes <- sweep(N + alpha, MARGIN = 1, STATS = row_tot + K * alpha,
                    FUN = "/") |>
  unclass()

# wrap the posterior-mean matrix as a markovchain object
mcfit_bayes <- new(
  "markovchain",
  transitionMatrix = phat_bayes,
  byrow = TRUE,
  states = rownames(N)
)

# With alpha > 0 every entry of phat_bayes is strictly positive, so the chain is
# irreducible: a single stationary distribution and no transient states.
mcfit_bayes |> steadyStates()
mcfit_bayes |> recurrentClasses()
mcfit_bayes |> transientClasses()

# predict where a cab picked up in `origin` is going: the six most likely
# drop-off zones under the Bayesian chain
origin <- "89"
conditionalDistribution(mcfit_bayes, state = origin) |>
  sort(decreasing = TRUE) |>
  head()

# Map the predictive distribution: each zone is shaded by its transition
# probability from `origin` (sqrt colour scale) and the origin is outlined in
# black. Zone ids missing from the names of `p` get an NA probability.
origin <- "89"
p <- conditionalDistribution(mcfit_bayes, state = origin)
ggplot(mutate(zones, prob = as.numeric(p[as.character(LocationID)]))) +
  geom_sf(mapping = aes(fill = prob), color = "lightgrey", linewidth = 0.1) +
  geom_sf(
    data = zones |> filter(as.character(LocationID) == origin),
    fill = NA,
    color = "black",
    linewidth = 0.5
  ) +
  scale_fill_gradient(low = "white", high = "red", trans = "sqrt") +
  coord_sf(datum = NA) +
  theme_minimal() +
  guides(fill = "none")


## EMPIRICAL BAYES -------------------------------------------------------------

# Dirichlet(alpha * g) prior on each row: g is the pooled drop-off distribution
# estimated from the data, and alpha is the concentration (equivalent number of
# pseudo-trips added to every row).
alpha <- 50

# prior mean g: share of all counted trips that end in each zone
g <- prop.table(colSums(N))

# posterior mean: (N_ij + alpha * g_j) / (N_i. + alpha)
row_tot <- rowSums(N)
phat_ebayes <- (sweep(N, MARGIN = 2, STATS = alpha * g, FUN = `+`) /
                  (row_tot + alpha)) |>
  unclass()

# wrap the posterior-mean matrix as a markovchain object
mcfit_ebayes <- new(
  "markovchain",
  transitionMatrix = phat_ebayes,
  byrow = TRUE,
  states = rownames(N)
)

# single stationary distribution
mcfit_ebayes |> steadyStates()

# Interestingly, you do get two transient classes. A zone that never appears
# as a drop-off in N has g_j = 0, so its column of phat_ebayes is all zeros and
# the chain can never enter it.
mcfit_ebayes |> recurrentClasses()
mcfit_ebayes |> transientClasses()

# prediction: the six most likely drop-off zones for a pickup in `origin`
# under the empirical-Bayes chain
origin <- "89"
conditionalDistribution(mcfit_ebayes, state = origin) |>
  sort(decreasing = TRUE) |>
  head()

# Map the empirical-Bayes predictive distribution: each zone is shaded by its
# transition probability from `origin` (sqrt colour scale) and the origin is
# outlined in black. Zone ids missing from the names of `p` get NA.
origin <- "89"
p <- conditionalDistribution(mcfit_ebayes, state = origin)
ggplot(mutate(zones, prob = as.numeric(p[as.character(LocationID)]))) +
  geom_sf(mapping = aes(fill = prob), color = "lightgrey", linewidth = 0.1) +
  geom_sf(
    data = zones |> filter(as.character(LocationID) == origin),
    fill = NA,
    color = "black",
    linewidth = 0.5
  ) +
  scale_fill_gradient(low = "white", high = "red", trans = "sqrt") +
  coord_sf(datum = NA) +
  theme_minimal() +
  guides(fill = "none")

# expected return times (how many trips does it take the cab to get back?)
# The mean return time to a recurrent state i is 1 / pi_i. The stationary
# distribution is stored as `pi_hat` so base R's constant `pi` is not masked.
pi_hat <- steadyStates(mcfit_ebayes)
sort(1 / pi_hat[, recurrentStates(mcfit_ebayes)]) |> head()
