# Metropolis–Hastings for a Markov chain transition matrix P (KxK)
# Prior: each row P[i,] ~ Dirichlet(alpha[i,])
# Likelihood: transition counts N[i,j]
# Stationary start: X0 ~ pi(P), contributes log(pi[x0])

# ---------- data ----------
# Transition counts N[i, j], written one matrix row per line.
N <- rbind(
  c(18, 6, 2, 4),
  c(5, 15, 6, 4),
  c(1, 3, 0, 1),
  c(6, 4, 3, 17)
)

# Number of states, taken from the count matrix.
K <- nrow(N)

# Dirichlet(1,...,1) per row (uniform); change if desired.
alpha <- matrix(1, nrow = K, ncol = K)

# Starting state index in {1,...,K}.
x0 <- 1

# Fix the RNG seed so the run is reproducible.
set.seed(1)

# ---------- helpers ----------
# Softmax of a numeric vector: exponentiate and normalise to sum to 1.
# The maximum is subtracted before exp() so large entries do not overflow;
# this shift does not change the result.
softmax <- function(z) {
  weights <- exp(z - max(z))
  weights / sum(weights)
}

# Map unconstrained parameters to a matrix of row probabilities.
#
# theta: numeric matrix with one row per state and one column per
#        non-baseline category (K x (K-1) for a K-state chain).
# Returns a matrix with nrow(theta) rows and ncol(theta) + 1 columns. Row i is
# softmax(c(theta[i, ], 0)): the last category is the baseline with logit 0,
# and every row sums to 1.
#
# The output size is taken from theta itself rather than from a global K, so
# the function works for any conforming input.
theta_to_P <- function(theta) {
  n_row <- nrow(theta)
  P <- matrix(NA_real_, n_row, ncol(theta) + 1L)
  for (i in seq_len(n_row)) {
    P[i, ] <- softmax(c(theta[i, ], 0))  # last category baseline
  }
  P
}

# Stationary distribution of a transition matrix via power iteration.
#
# P:     square matrix whose rows are probability vectors.
# tol:   stop once the largest entrywise change between successive iterates
#        falls below this value.
# maxit: maximum number of iterations.
# Returns a numeric vector of length nrow(P), normalised to sum to 1.
# If the tolerance is not reached within maxit iterations, the last iterate is
# returned (normalised) without any warning.
#
# The state count is read from P rather than from a global K, and the iterate
# is not named `pi` so the base constant is not shadowed.
stationary_dist <- function(P, tol = 1e-12, maxit = 10000) {
  n_states <- nrow(P)
  stat_cur <- rep(1 / n_states, n_states)  # start from the uniform distribution
  for (iter in seq_len(maxit)) {
    stat_next <- as.numeric(stat_cur %*% P)
    if (max(abs(stat_next - stat_cur)) < tol) {
      return(stat_next / sum(stat_next))
    }
    stat_cur <- stat_next
  }
  stat_cur / sum(stat_cur)
}

# Unnormalised log-posterior, expressed as a density over the unconstrained
# parameters theta (the space in which the random walk moves).
#
# theta: K x (K-1) real matrix, mapped to the transition matrix P by
#        theta_to_P().
# Returns a single number, the sum of:
#   * the row-wise Dirichlet(alpha[i, ]) log prior evaluated at P,
#   * the log likelihood sum_{i,j} N[i,j] log P[i,j] (constants dropped),
#   * log(pi[x0]), the stationary-start contribution, and
#   * the log-Jacobian of the map theta -> P.
# A non-finite total is returned as -Inf so the sampler rejects that proposal.
# Reads the globals alpha, N and x0.
logpost_theta <- function(theta) {
  P <- theta_to_P(theta)
  log_P <- log(P)

  # log prior: product of Dirichlet rows
  lp <- sum(lgamma(rowSums(alpha)) - rowSums(lgamma(alpha))) +
    sum((alpha - 1) * log_P)

  # log likelihood: sum_{i,j} N[i,j] log P[i,j]  (constants dropped)
  lp <- lp + sum(N * log_P)

  # stationary start contribution
  stat <- stationary_dist(P)
  lp <- lp + log(stat[x0])

  # Change of variables: for each row, the softmax-with-baseline map has
  # |det dP/dtheta| = prod_j P[i, j]. Without this term the chain would target
  # Dirichlet(alpha - 1) rows instead of the stated Dirichlet(alpha) prior.
  lp <- lp + sum(log_P)

  # If a probability underflows to 0, terms such as 0 * log(0) become NaN;
  # treat any non-finite value as zero posterior density.
  if (is.finite(lp)) lp else -Inf
}

# ---------- MCMC settings ----------
n_iter <- 30000   # total number of MH iterations
burn   <- 5000    # leading iterations that are never saved
thin   <- 10      # after burn-in, save every thin-th iteration
step   <- 0.04    # random-walk size; tune to get ~0.2–0.4 accept rate

# Storage for the saved draws: one K x K matrix per kept iteration.
n_keep <- (n_iter - burn) %/% thin
P_draws <- array(NA_real_, dim = c(n_keep, K, K))

# ---------- initialize ----------
# Start at the row-wise posterior mean (ignoring the stationary-start term).
P0 <- (N + alpha) / rowSums(N + alpha)

# Invert the softmax row by row: theta[i, j] = log(P0[i, j] / P0[i, K]),
# with column K as the baseline category. The division recycles P0[, K]
# down the columns, so each row is divided by its own baseline entry.
theta <- log(P0[, -K, drop = FALSE] / P0[, K])

# Log-posterior at the starting point.
lp <- logpost_theta(theta)

# ---------- MH loop ----------
accept <- 0          # number of accepted proposals
keep <- 0            # number of draws stored in P_draws so far
plot_every <- 2000   # update plots every this many iterations

# Panel grid with room for one trace panel per row of P (2 x 2 when K = 4).
n_panel_col <- ceiling(sqrt(K))
n_panel_row <- ceiling(K / n_panel_col)

# on.exit() only takes effect inside a function, so in this top-level script
# the graphical parameters are saved here and restored explicitly after the
# loop.
oldpar <- par(no.readonly = TRUE)

for (it in seq_len(n_iter)) {
  # propose in unconstrained space (theta = logit-like for rows)
  theta_prop <- theta + matrix(rnorm(K * (K - 1), sd = step), K, K - 1)
  lp_prop <- logpost_theta(theta_prop)

  # MH accept/reject (the proposal is symmetric, so only the posterior ratio
  # enters). The uniform is always drawn, which keeps the RNG stream the same
  # whether or not lp_prop is finite; a non-finite lp_prop is rejected instead
  # of making `if` fail on NA.
  log_u <- log(runif(1))
  if (is.finite(lp_prop) && log_u < (lp_prop - lp)) {
    theta <- theta_prop
    lp <- lp_prop
    accept <- accept + 1
  }

  # save draw
  if (it > burn && ((it - burn) %% thin == 0)) {
    keep <- keep + 1
    P_draws[keep, , ] <- theta_to_P(theta)
  }

  # ---- live plot: whole P matrix traces (row panels; each column overlaid) ----
  if (it %% plot_every == 0 && keep >= 2) {
    idx <- seq_len(keep)

    # Hold the device while all panels are drawn, then flush once, so the
    # refresh happens in a single step on devices that support holding.
    dev.hold()
    par(mfrow = c(n_panel_row, n_panel_col), mar = c(3.2, 3.2, 2.2, 1))

    for (i in seq_len(K)) {
      plot(
        idx, P_draws[idx, i, 1], type = "l", col = 1, ylim = c(0, 1),
        xlab = "saved draw index",
        ylab = sprintf("row %d probs", i),
        main = sprintf("Row %d traces", i)
      )
      # One colour per column so the overlaid traces can be told apart.
      for (j in seq_len(K)[-1]) lines(idx, P_draws[idx, i, j], col = j)

      # optional legend: comment out if too busy
      # legend("topright", legend = paste0("col ", seq_len(K)),
      #        col = seq_len(K), lty = 1, bty = "n", cex = 0.8)
    }

    dev.flush()
  }
}

# Restore the graphical parameters saved before the loop.
par(oldpar)

# Fraction of all n_iter proposals that were accepted.
accept_rate <- accept / n_iter
accept_rate

# ---------- posterior summaries ----------
# Entrywise mean over the saved draws (first array dimension).
P_mean <- apply(P_draws, 2:3, mean)
P_mean

# 95% pointwise credible intervals for each entry
P_q025 <- apply(P_draws, 2:3, function(draws) quantile(draws, probs = 0.025))
P_q975 <- apply(P_draws, 2:3, function(draws) quantile(draws, probs = 0.975))

P_q025
P_q975

# --- compare posterior mean to MLE (and show deltas) ---
# Row-normalised counts; rowSums(N) is recycled down the columns, so each
# row is divided by its own total.
P_mle <- N / rowSums(N)

P_mean
P_mle

deltas <- P_mean - P_mle
deltas

# Frobenius norm of the difference.
sqrt(sum(deltas * deltas))
