############################################################
# Gibbs sampler for:
#   X | n,p ~ Binomial(n, p)   (observe X = x)
#   n | lam ~ Poisson(lam)
#   p ~ Beta(alph, bet)
#
# Full conditionals:
#   p | n,x ~ Beta(alph + x, bet + n - x)
#   n | p,x ~ x + Poisson(lam * (1 - p))
############################################################
library(MASS)

# ----- inputs -----
# Beta(alph, bet) prior on p, Poisson(lam) prior on n, one observed count.
alph  <- 1
bet   <- 1
lam   <- 30
x_obs <- 15

# # ----- plot priors -----
# curve(dbeta(x, alph, bet), from = 0, to = 1, xlab = "p", ylab = "f(p)")
# curve(dpois(x, lam), from = 0, to = 2 * lam, xlab = "n", ylab = "Pr(n)")

# ----- MCMC settings -----
# total number of sweeps, and how many initial sweeps are thrown away
n_iter <- 50000
burn   <- 2000
set.seed(1)

# ----- initialize (must have n >= x_obs, and p in (0,1)) -----
# start n at the larger of the observed count and the prior mean
n_curr <- max(x_obs, lam)
p_curr <- (x_obs + alph) / (n_curr + alph + bet)

# ----- storage -----
# one slot per retained (post-burn-in) sweep
keep <- n_iter - burn
n_draws <- integer(keep)
p_draws <- numeric(keep)

# ----- Gibbs sampler -----
# Each sweep draws p from its full conditional, then n from its full
# conditional given the fresh p. The order of the two draws matters for
# reproducibility under the seed set above.
for (iter in seq_len(n_iter)) {
  # p | n, x_obs ~ Beta(alph + x_obs, bet + n - x_obs)
  p_curr <- rbeta(1, shape1 = alph + x_obs, shape2 = bet + (n_curr - x_obs))

  # n | p, x_obs ~ x_obs + Poisson(lam * (1 - p))
  n_curr <- x_obs + rpois(1, lambda = lam * (1 - p_curr))

  # k is the position in the storage vectors; it is <= 0 during burn-in
  k <- iter - burn
  if (k >= 1) {
    n_draws[k] <- n_curr
    p_draws[k] <- p_curr
  }
}

# ----- diagnostics -----
# trace plots of the retained draws, one panel per parameter
plot(
  p_draws,
  type = "l",
  ylab = "p",
  xlab = "iter (post-burn)"
)
plot(
  n_draws,
  type = "l",
  ylab = "n",
  xlab = "iter (post-burn)"
)

# joint draws as a semi-transparent scatterplot
plot(
  n_draws, p_draws,
  cex = 0.5,
  col = adjustcolor("black", alpha.f = 0.2)
)

# smoothed version of the same scatter, to see the 2D density
smoothScatter(
  n_draws, p_draws,
  xlim = c(max(x_obs), 2*lam),
  ylim = c(0, 1),
  xlab = "n",
  ylab = "p"
)

# overlay the curve p = x_obs / n: the draws concentrate along this "ridge"
curve(x_obs / x, lwd = 2, col = "red", add = TRUE)
text(x = 50, y = 0.35, labels = "p = x/n", srt = -10, col = "red")
# hpd region bends around ridge but is quite diffuse -- n, p not well identified
# Overlay a highest-density contour of a 2-D sample on the current plot.
#
# A kernel density estimate of (x, y) is computed on an ngrid x ngrid grid
# with MASS::kde2d(). Grid cells are ranked from highest to lowest density
# and the contour is drawn at the density level at which the highest cells
# together hold a fraction `prob` of the density mass on the grid.
#
# Arguments:
#   x, y     numeric vectors of draws (same length).
#   prob     mass enclosed by the contour; a single number in (0, 1].
#   ngrid    number of grid points per direction passed to kde2d().
#   col, lwd, lty
#            graphical parameters forwarded to contour().
#   hx, hy   optional kde2d() bandwidths for the x and y directions. Supply
#            both or neither; with neither, kde2d() picks its defaults.
#
# Returns (invisibly) list(kd = <kde2d result>, level = <contour level>).
# Requires an open plot, because the contour is drawn with add = TRUE.
add_hpd_region <- function(x, y, prob = 0.95, ngrid = 500,
                           col = "red", lwd = 2, lty = 1,
                           hx = NULL, hy = NULL) {
  stopifnot(is.numeric(prob), length(prob) == 1L, prob > 0, prob <= 1)

  if (is.null(hx) && is.null(hy)) {
    kd <- MASS::kde2d(x, y, n = ngrid)
  } else {
    stopifnot(!is.null(hx), !is.null(hy))
    kd <- MASS::kde2d(x, y, n = ngrid, h = c(hx, hy))
  }

  # Rank the grid cells by density, highest first.
  z_sorted <- sort(as.vector(kd$z), decreasing = TRUE)

  # Cumulative share of the on-grid mass. kde2d() evaluates the density only
  # on the range of the data by default, so the raw Riemann sum
  # (sum(z) * dx * dy) is generally below 1; with wide bandwidths it can be
  # below `prob`, which would leave the level undefined (NA). Normalising by
  # the total on-grid mass makes the level well defined for every valid
  # `prob`. The constant cell area dx * dy cancels in the ratio.
  cum_mass <- cumsum(z_sorted)
  cum_mass <- cum_mass / cum_mass[length(cum_mass)]
  z_star   <- z_sorted[which(cum_mass >= prob)[1]]

  contour(kd$x, kd$y, kd$z,
          levels = z_star, add = TRUE, drawlabels = FALSE,
          col = col, lwd = lwd, lty = lty)

  invisible(list(kd = kd, level = z_star))
}

# 95% HPD contour on top of the smoothed scatter (explicit bandwidths)
add_hpd_region(n_draws, p_draws, hx = 8, hy = 0.1, ngrid = 300, prob = 0.95)

# ----- estimates and marginal intervals -----
# posterior mean with the 2.5% / 97.5% quantiles of each parameter;
# note how wide the intervals are, i.e. the posterior uncertainty is high
setNames(
  c(mean(p_draws), quantile(p_draws, probs = c(0.025, 0.975), names = FALSE)),
  c("p_mean", "p_q025", "p_q975")
)

setNames(
  c(mean(n_draws), quantile(n_draws, probs = c(0.025, 0.975), names = FALSE)),
  c("n_mean", "n_q025", "n_q975")
)

############################################################
# Gibbs sampler for:
#   Xi | n,p ~ Binomial(n, p)   (observe Xi = x_i, i=1..m)
#   n | lam ~ Poisson(lam)
#   p ~ Beta(alph, bet)
#
# Full conditionals:
#   p | n, x_obs ~ Beta(alph + S, bet + m*n - S),  where S = sum_i x_i
#   n | p, x_obs : MH step (m > 1), closed form only when m = 1
############################################################

# ----- inputs -----
# observed counts first, then the prior hyperparameters
x_obs <- c(15, 18, 6, 15, 20, 11)
lam   <- 30
alph  <- 1
bet   <- 1

# ----- sufficient stats -----
s <- sum(x_obs)     # total of the observed counts
m <- length(x_obs)  # number of observations
l <- max(x_obs)     # n can never be below the largest observed count

# # ----- plot priors -----
# curve(dbeta(x, alph, bet), from = 0, to = 1, xlab = "p", ylab = "f(p)")
# curve(dpois(x, lam), from = 0, to = 2 * lam, xlab = "n", ylab = "Pr(n)")

# ----- helpers -----
# Unnormalised log full conditional of n given p and the data.
#
# Arguments:
#   n      candidate value of n.
#   p      current success probability.
#   x_obs  vector of observed counts.
#   lam    Poisson prior mean for n.
#   l      smallest admissible n; anything below it gets log-density -Inf.
#
# Returns the Poisson log prior at n plus the summed Binomial log-likelihood
# of x_obs, or -Inf when n < l.
log_fullcond_n <- function(n, p, x_obs, lam, l) {
  if (n >= l) {
    log_prior <- dpois(n, lambda = lam, log = TRUE)
    log_lik   <- dbinom(x_obs, size = n, prob = p, log = TRUE)
    log_prior + sum(log_lik)
  } else {
    -Inf
  }
}

# Symmetric integer random-walk proposal for n on {l, l + 1, ...}.
#
# Arguments:
#   n    current value of n (expected to be >= l).
#   l    lower bound of the support.
#   s_n  standard deviation of the Gaussian step before rounding.
#
# Returns a single proposed value >= l.
#
# The rounded Gaussian step is reflected about l - 0.5, i.e. l - j maps to
# l + j - 1. Reflecting about l itself (l - j -> l + j) is NOT symmetric:
# the state l then has one pre-image while every state above it has two, so
# q(n | l) = 2 * q(l | n). A Metropolis step that uses the plain ratio of
# target densities would under-sample n = l by a factor of two. Reflecting
# about the half-integer gives every state exactly two pre-images, so
# q(n' | n) = q(n | n') for all n, n' >= l.
rw_prop_n <- function(n, l, s_n) {
  n_raw <- as.integer(round(n + rnorm(1, 0, s_n)))
  l + abs(n_raw - l + 0.5) - 0.5
}

# ----- MCMC settings -----
# total number of sweeps, and how many initial sweeps are thrown away
n_iter <- 50000
burn   <- 2000
set.seed(1)

# ----- initialize -----
# start n at the larger of the prior mean and the lower bound l
n_curr <- max(lam, l)
p_curr <- (s + alph) / (m * n_curr + alph + bet)

# ----- storage -----
# one slot per retained (post-burn-in) sweep
keep <- n_iter - burn
n_draws <- integer(keep)
p_draws <- numeric(keep)

# ----- Gibbs sampler (MH-within-Gibbs for n when m>1) -----
# Each sweep draws p exactly from its Beta full conditional, then updates n:
# with several observations by one random-walk Metropolis step, with a single
# observation by an exact draw.
s_n <- 10  # random-walk step sd for n; tune
for (iter in seq_len(n_iter)) {
  # p | n, x_obs ~ Beta(alph + s, bet + m * n - s)
  p_curr <- rbeta(1, shape1 = alph + s, shape2 = bet + (m * n_curr - s))

  # n | p, x_obs
  if (m != 1L) {
    # propose, then accept with probability min(1, target ratio)
    n_prop <- rw_prop_n(n_curr, l, s_n)
    log_r  <- log_fullcond_n(n_prop, p_curr, x_obs, lam, l) -
      log_fullcond_n(n_curr, p_curr, x_obs, lam, l)
    if (log(runif(1)) < min(0, log_r)) {
      n_curr <- n_prop
    }
  } else {
    # closed-form full conditional: shifted Poisson
    n_curr <- s + rpois(1, lambda = lam * (1 - p_curr))
  }

  # k is the position in the storage vectors; it is <= 0 during burn-in
  k <- iter - burn
  if (k >= 1) {
    n_draws[k] <- n_curr
    p_draws[k] <- p_curr
  }
}

# ----- diagnostics -----
# trace plots of the retained draws (these look good)
plot(
  p_draws,
  type = "l",
  ylab = "p",
  xlab = "iter (post-burn)"
)
plot(
  n_draws,
  type = "l",
  ylab = "n",
  xlab = "iter (post-burn)"
)

# smoothed joint density: with several observations the parameters are
# better identified and the posterior is less diffuse
smoothScatter(
  n_draws, p_draws,
  xlim = c(max(x_obs), 2*lam),
  ylim = c(0, 1),
  xlab = "n",
  ylab = "p"
)

# the ridge p = mean(x_obs) / n is still visible -- it is inherent to the model
curve(mean(x_obs) / x, lwd = 2, col = "red", add = TRUE)
text(
  x = 57, y = 0.32,
  labels = expression(paste('p = ', bar(x)/n)),
  srt = -5,
  col = "red"
)

# HPD region is more concentrated here, i.e. there is less uncertainty
add_hpd_region(n_draws, p_draws, hy = 0.1, hx = 5)

# mixing (univariate effective sample sizes)
mcmcse::ess(n_draws)
mcmcse::ess(p_draws)

# mixing (rank-normalized bulk / tail effective sample sizes, per parameter)
# posterior::ess_bulk() and ess_tail() expect the draws of ONE variable (a
# vector, or an iterations x chains matrix). Calling them directly on a
# draws_matrix with several variables treats the variable columns as chains
# of a single quantity and returns one meaningless number. summarise_draws()
# applies each measure to every variable separately.
# The column labelled "n" holds log(n_draws).
x <- posterior::as_draws_matrix(cbind(n = log(n_draws), p = p_draws))
posterior::summarise_draws(
  x,
  "ess_bulk",  # effective sample size for center
  "ess_tail"   # effective sample size for tail probabilities
)

# ----- estimates -----
# posterior mean with the 2.5% / 97.5% quantiles of each parameter
setNames(
  c(mean(p_draws), quantile(p_draws, probs = c(0.025, 0.975), names = FALSE)),
  c("p_mean", "p_q025", "p_q975")
)

setNames(
  c(mean(n_draws), quantile(n_draws, probs = c(0.025, 0.975), names = FALSE)),
  c("n_mean", "n_q025", "n_q975")
)