## Practice problem 4 from the midterm: wind-direction transitions
## Examples of confidence intervals, credible intervals, and joint confidence/credible regions
source("https://ruizt.quarto.pub/stat545/scripts/ci-helpers.R")
set.seed(1)

# ----- constants -----
# Compass states; this order fixes the row/column order of every matrix below.
states <- c("N", "E", "S", "W")
K <- length(states)

# ----- data -----
# Observed transition counts: entry [i, j] counts transitions from
# states[i] ("from", rows) to states[j] ("to", columns).
counts <- rbind(
  c(18,  6, 2,  4),
  c( 5, 15, 6,  4),
  c( 1,  3, 0,  1),
  c( 6,  4, 3, 17)
)
dimnames(counts) <- list(from = states, to = states)

# All unordered pairs of states for the 2D panels (used later in plotting),
# in the order (N,E), (N,S), (N,W), (E,S), (E,W), (S,W).
pairs_2d <- combn(states, 2, simplify = FALSE)

# ----- MLE -----
# The MLE of each row of the transition matrix is that row's observed
# proportions: divide every count by its row total. The from/to dimnames
# carry over from `counts`.
P_mle <- prop.table(counts, margin = 1)
P_mle

# ----- pointwise MLE intervals -----
# Per-cell intervals for one row of the transition matrix, treating each
# cell count x_j as Binomial(n, p_j): exact Clopper-Pearson and Wald.
row_idx <- 1L                 # 1 = N
level <- 0.95
zcrit <- qnorm(1 - (1 - level) / 2)

x <- counts[row_idx, ]
n <- sum(x)
p_hat <- setNames(as.numeric(x / n), states)

# Clopper-Pearson: Beta quantiles from inverting the exact binomial tails
ci_binom <- cbind(
  lwr.binom = qbeta((1 - level) / 2, x, n - x + 1),
  upr.binom = qbeta(1 - (1 - level) / 2, x + 1, n - x)
)

# Wald: normal approximation, truncated to the unit interval
se <- sqrt(p_hat * (1 - p_hat) / n)
ci_wald <- cbind(
  lwr.wald = pmax(p_hat - zcrit * se, 0),
  upr.wald = pmin(p_hat + zcrit * se, 1)
)

rownames(ci_wald) <- states
rownames(ci_binom) <- states
cbind(p_hat = p_hat, ci_binom, ci_wald)

# ----- Wald joint region for entire row + 2D contour projections -----
# The Wald region for the whole row is the ellipsoid
#   { p : (p - p_hat)' (Sigma_sub / n)^{-1} (p - p_hat) <= qchisq(level, K - 1) }
# in the K - 1 kept coordinates; the dropped coordinate is 1 minus their sum.
# The region is visualised by sampling points uniformly inside the ellipsoid,
# completing each to a full probability vector, and discarding points that
# have a negative coordinate.
M_wald <- 50000
drop_idx <- 4L              # 4 = W
keep_idx <- setdiff(seq_len(K), drop_idx)

keep_names <- states[keep_idx]
drop_name  <- states[drop_idx]

# Per-observation multinomial covariance diag(p) - p p'. Its rows sum to zero
# because p_hat sums to 1, so it is singular; use the kept sub-block instead.
Sigma <- diag(p_hat) - tcrossprod(p_hat)
dimnames(Sigma) <- list(states, states)
Sigma_sub <- Sigma[keep_idx, keep_idx, drop = FALSE]

c2 <- qchisq(level, df = length(keep_idx))

# chol() returns the UPPER-triangular R with Sigma_sub = t(R) %*% R.
# Transpose it so L is lower-triangular with Sigma_sub = L %*% t(L).
L <- t(chol(Sigma_sub))

# Uniform draws in the unit ball: Gaussian directions normalised onto the
# sphere, radius runif^(1/d) so the volume is filled uniformly.
U <- matrix(rnorm(M_wald * length(keep_idx)), nrow = M_wald)
U <- U / sqrt(rowSums(U^2))
rad <- runif(M_wald)^(1 / length(keep_idx))

# Map the ball onto the ellipsoid. Each row is (L %*% u)' = u' %*% t(L), so
# d' Sigma_sub^{-1} d = |u|^2 and the scaled rows satisfy the Wald bound.
# `rad` (length M_wald) recycles down columns, i.e. it scales row i.
D <- (sqrt(c2 / n) * rad) * (U %*% t(L))

P_sub <- sweep(D, 2, p_hat[keep_idx], "+")
colnames(P_sub) <- keep_names
P_wald <- cbind(P_sub, 1 - rowSums(P_sub))
colnames(P_wald) <- c(keep_names, drop_name)
P_wald <- P_wald[, states, drop = FALSE]
# keep only points inside the probability simplex
P_wald <- P_wald[apply(P_wald, 1, min) >= 0, , drop = FALSE]

plot_hdr_contour_panel(
  P_wald, pairs_2d,
  mass = 0.95, nbins = 160, blur = 5,
  xlab_prefix = "p_", ylab_prefix = "p_",
  mark = p_hat
)

# ----- Dirichlet smoothing for row S + pointwise credible intervals -----
# Row S is shrunk toward the pooled transition frequencies of the other rows
# using a Dirichlet(alpha) prior with total concentration alpha0 and mean
# equal to those pooled frequencies.
alpha0 <- 10
rowS_idx <- 3L                                  # 3 = S
pool_rows_idx <- setdiff(seq_len(K), rowS_idx)  # the other rows: N, E, W

pooled <- colSums(counts[pool_rows_idx, , drop = FALSE])
prior_mean <- pooled / sum(pooled)

alpha <- alpha0 * prior_mean
names(alpha) <- states

xS <- counts[rowS_idx, ]
nS <- sum(xS)

# Conjugate update: row S becomes its posterior mean; other rows keep the MLE.
# The from/to dimnames are inherited from P_mle.
P_dir <- P_mle
P_dir[rowS_idx, ] <- (xS + alpha) / (nS + alpha0)
P_dir

# Component j of Dirichlet(xS + alpha) is marginally
# Beta(xS_j + alpha_j, (nS - xS_j) + (alpha0 - alpha_j)); report the mean and
# equal-tailed interval at `level`.
ci_S <- local({
  shape1 <- xS + alpha
  shape2 <- (nS - xS) + (alpha0 - alpha)
  cbind(
    lower = qbeta((1 - level) / 2, shape1, shape2),
    mean  = shape1 / (nS + alpha0),
    upper = qbeta(1 - (1 - level) / 2, shape1, shape2)
  )
})
rownames(ci_S) <- states
ci_S

# ----- Beta marginals for Dirichlet-smoothed row S -----
# Component j of the Dirichlet(xS + alpha) posterior is marginally
# Beta(xS_j + alpha_j, (nS - xS_j) + (alpha0 - alpha_j)). Each density is
# plotted with its posterior mean (dashed) and its equal-tailed credible
# interval at `level` (dotted), which are the same quantities tabulated in ci_S.

# par() returns the previous settings so they can be restored afterwards.
old_par <- par(mfrow = c(2, 2), mar = c(3.2, 3.2, 2.2, 1))

xx <- seq(0, 1, length.out = 600)   # density evaluation grid, same for every j

for (j in seq_len(K)) {
  a_post <- xS[j] + alpha[j]
  b_post <- (nS - xS[j]) + (alpha0 - alpha[j])

  # Derive the tail probabilities from `level` (not hard-coded 0.025/0.975)
  # so the lines coincide exactly with ci_S.
  q <- qbeta(c((1 - level) / 2, 1 - (1 - level) / 2), a_post, b_post)

  yy <- dbeta(xx, a_post, b_post)

  plot(
    xx, yy, type = "l",
    xlab = sprintf("p(%s -> %s)", states[rowS_idx], states[j]),
    ylab = "density",
    main = sprintf("Beta(%.1f, %.1f)", a_post, b_post)
  )

  # posterior mean for this component (matches ci_S[, "mean"])
  abline(v = (xS[j] + alpha[j]) / (nS + alpha0), lty = 2)

  # credible interval endpoints (match ci_S[, "lower"] and ci_S[, "upper"])
  abline(v = q, lty = 3)
}

# Restore the graphics state so later plots are not drawn into this 2 x 2
# grid with these margins.
par(old_par)

# ----- HPD set for row S + 2D contour projections -----
# Monte Carlo approximation of the highest-posterior-density region of the
# Dirichlet(a) posterior for row S:
#   1. draw from the posterior;
#   2. score each draw by its unnormalised log density;
#   3. keep the top `prob_hpd` fraction of the draws.
M_hpd <- 2e6
prob_hpd <- 0.95

# posterior Dirichlet parameters for row S
a <- setNames(xS + alpha, states)

# Dirichlet draws: independent Gamma(a_j, 1) columns, each row normalised
G <- vapply(
  a,
  function(shape) rgamma(M_hpd, shape = shape, rate = 1),
  FUN.VALUE = numeric(M_hpd)
)
P_hpd <- G / rowSums(G)
colnames(P_hpd) <- states

# unnormalised Dirichlet log density: sum_j (a_j - 1) * log(p_j)
log_u <- rowSums(sweep(log(P_hpd), MARGIN = 2, STATS = a - 1, FUN = "*"))
P_hpd_hpd <- P_hpd[log_u >= quantile(log_u, 1 - prob_hpd), , drop = FALSE]

plot_hdr_contour_panel(
  P_hpd_hpd, pairs_2d,
  mass = 0.95, nbins = 160, blur = 4,
  xlab_prefix = "p_S", ylab_prefix = "p_S",
  mark = a / sum(a)
)

