## ----setup, include=FALSE-----------------------------------------------------
# Show the source code of every chunk in the rendered document by default
knitr::opts_chunk$set(echo = TRUE)

## ----simulate-data, eval=F----------------------------------------------------
# library(refundBayes)
#
# set.seed(42)
#
# # --- Dimensions ---
# n <- 200  # number of subjects
# L <- 30   # number of predictor-domain grid points
# M <- 30   # number of response-domain grid points
#
# sindex <- seq(0, 1, length.out = L)  # predictor domain grid
# tindex <- seq(0, 1, length.out = M)  # response domain grid
#
# # --- Functional predictor X(s): smooth random curves ---
# # X_i(s) = a_i sin(2 pi s) + b_i cos(2 pi s) + c_i sin(4 pi s) + d_i,
# # with a_i, b_i, c_i ~ N(0, 1) and d_i ~ N(0, 0.3^2). The four scores of
# # subject i are consecutive draws (the matrix is filled by row), which
# # reproduces the random-number stream of drawing them subject by subject.
# scores <- matrix(rnorm(4 * n), nrow = n, ncol = 4, byrow = TRUE)
# X_func <- outer(scores[, 1], sin(2 * pi * sindex)) +
#   outer(scores[, 2], cos(2 * pi * sindex)) +
#   outer(scores[, 3], sin(4 * pi * sindex)) +
#   0.3 * scores[, 4]  # length-n vector recycles down columns: one shift/row
#
# # --- Scalar predictor ---
# age <- rnorm(n)
#
# # --- True coefficient functions ---
# # Bivariate coefficient: beta(s, t) = sin(2*pi*s) * cos(2*pi*t)
# beta_true <- outer(sin(2 * pi * sindex), cos(2 * pi * tindex))
#
# # Scalar coefficient function: alpha(t) = 0.5 * sin(pi*t)
# alpha_true <- 0.5 * sin(pi * tindex)
#
# # --- Generate functional response ---
# # Y_i(t) = age_i * alpha(t) + integral X_i(s) beta(s,t) ds + epsilon_i(t)
# signal_scalar <- outer(age, alpha_true)                     # n x M
# # Riemann sum with weight 1/L. The grid spacing is 1/(L - 1), so this is
# # the usual left/right-sum approximation scaled by (L - 1)/L.
# signal_func   <- (X_func %*% beta_true) / L                 # n x M
# epsilon       <- matrix(rnorm(n * M, sd = 0.3), nrow = n)   # n x M
#
# Y_mat <- signal_scalar + signal_func + epsilon
#
# # --- Organize data ---
# # Matrix-valued columns: one row per subject, one column per grid point
# dat <- data.frame(age = age)
# dat$Y_mat  <- Y_mat
# dat$X_func <- X_func
# dat$sindex <- matrix(rep(sindex, n), nrow = n, byrow = TRUE)

## ----fit-model, eval=F--------------------------------------------------------
# # Fit the model: functional response Y_mat, scalar covariate age, and the
# # functional covariate X_func entering through a smooth in sindex
# # (basis "cr" with k = 10).
# fit_fofr <- fofr_bayes(
#   data        = dat,
#   formula     = Y_mat ~ age + s(sindex, by = X_func, bs = "cr", k = 10),
#   # basis settings forwarded to fofr_bayes
#   spline_type = "bs",
#   spline_df   = 10,
#   # MCMC settings: chains, cores, iterations and warmup
#   nchain      = 3,
#   ncores      = 3,
#   niter       = 2000,
#   nwarmup     = 1000
# )

## ----plot-bivar, eval=F-------------------------------------------------------
# # Posterior draws of the bivariate coefficient surface
# beta_draws <- fit_fofr$bivar_func_coef[[1]]
#
# # Posterior mean of the bivariate coefficient
# beta_est <- apply(beta_draws, c(2, 3), mean)
#
# # Pointwise 95% credible interval bounds
# beta_lower <- apply(beta_draws, c(2, 3), quantile, probs = 0.025)
# beta_upper <- apply(beta_draws, c(2, 3), quantile, probs = 0.975)
#
# # image() scales colours to each matrix's own range by default, which makes
# # side-by-side panels incomparable. Use one symmetric range for the truth
# # and the estimate (so zero sits at the centre of the diverging palette)
# # and a separate symmetric range for the difference.
# pal       <- hcl.colors(64, "Blue-Red 3")
# zlim_beta <- c(-1, 1) * max(abs(c(beta_true, beta_est)))
# zlim_diff <- c(-1, 1) * max(abs(beta_est - beta_true))
#
# # Side-by-side heatmaps: true vs estimated vs difference
# op <- par(mfrow = c(1, 3), mar = c(4, 4, 2, 1))
# image(sindex, tindex, beta_true, zlim = zlim_beta, col = pal,
#       xlab = "s (predictor domain)", ylab = "t (response domain)",
#       main = expression("True " * beta(s, t)))
# image(sindex, tindex, beta_est, zlim = zlim_beta, col = pal,
#       xlab = "s (predictor domain)", ylab = "t (response domain)",
#       main = expression("Estimated " * hat(beta)(s, t)))
# image(sindex, tindex, beta_est - beta_true, zlim = zlim_diff, col = pal,
#       xlab = "s (predictor domain)", ylab = "t (response domain)",
#       main = "Difference (Est - True)")
# par(op)  # restore the previous layout and margins

## ----plot-scalar, eval=F------------------------------------------------------
# # Posterior draws of the scalar coefficient function for age
# alpha_draws <- fit_fofr$scalar_func_coef[, 1, ]
#
# # Posterior mean and pointwise 95% credible interval bounds
# alpha_est   <- apply(alpha_draws, 2, mean)
# alpha_lower <- apply(alpha_draws, 2, quantile, probs = 0.025)
# alpha_upper <- apply(alpha_draws, 2, quantile, probs = 0.975)
#
# par(mfrow = c(1, 1))
# # Include the truth and the estimate in the y-range so neither is clipped
# # when it falls outside the credible band
# plot(tindex, alpha_true, type = "n",
#      ylim = range(c(alpha_true, alpha_est, alpha_lower, alpha_upper)),
#      xlab = "t (response domain)", ylab = expression(alpha(t)),
#      main = "Scalar coefficient function: age")
# # Draw the translucent band first so it does not cover the curves
# polygon(c(tindex, rev(tindex)),
#         c(alpha_lower, rev(alpha_upper)),
#         col = rgb(0, 0, 1, 0.2), border = NA)
# lines(tindex, alpha_true, col = "black", lwd = 2)
# lines(tindex, alpha_est, col = "blue", lwd = 2)
# legend("topright",
#        legend = c("Truth", "Posterior mean", "95% CI"),
#        col = c("black", "blue", rgb(0, 0, 1, 0.2)),
#        lwd = c(2, 2, 10), bty = "n")

## ----plot-slices, eval=F------------------------------------------------------
# # Fix s at the midpoint of the predictor domain and plot beta(s_mid, t)
# s_mid_idx <- which.min(abs(sindex - 0.5))
#
# # Posterior draws of beta(s_mid, t); columns run over the t grid
# slice_draws <- fit_fofr$bivar_func_coef[[1]][, s_mid_idx, ]
#
# beta_slice_est   <- apply(slice_draws, 2, mean)
# beta_slice_lower <- apply(slice_draws, 2, quantile, probs = 0.025)
# beta_slice_upper <- apply(slice_draws, 2, quantile, probs = 0.975)
# beta_slice_true  <- beta_true[s_mid_idx, ]
#
# # Include the truth and the estimate in the y-range so neither is clipped
# # when it falls outside the credible band
# plot(tindex, beta_slice_true, type = "n",
#      ylim = range(c(beta_slice_true, beta_slice_est,
#                     beta_slice_lower, beta_slice_upper)),
#      xlab = "t (response domain)",
#      ylab = expression(beta(s[mid], t)),
#      main = paste0("Slice at s = ", round(sindex[s_mid_idx], 2)))
# # Draw the translucent band first so it does not cover the curves
# polygon(c(tindex, rev(tindex)),
#         c(beta_slice_lower, rev(beta_slice_upper)),
#         col = rgb(1, 0, 0, 0.2), border = NA)
# lines(tindex, beta_slice_true, col = "black", lwd = 2)
# lines(tindex, beta_slice_est, col = "red", lwd = 2)
# legend("topright",
#        legend = c("Truth", "Posterior mean", "95% CI"),
#        col = c("black", "red", rgb(1, 0, 0, 0.2)),
#        lwd = c(2, 2, 10), bty = "n")

## ----summary, eval=F----------------------------------------------------------
# # RMSE of the bivariate coefficient surface
# cat("RMSE of beta(s,t):", sqrt(mean((beta_est - beta_true)^2)), "\n")
#
# # RMSE of the scalar coefficient function
# cat("RMSE of alpha(t): ", sqrt(mean((alpha_est - alpha_true)^2)), "\n")
#
# # Empirical coverage of the pointwise 95% credible intervals: the share of
# # grid points at which the true value lies inside the interval
# cat("Coverage of beta(s,t) 95% CI:",
#     mean(beta_true >= beta_lower & beta_true <= beta_upper), "\n")
# cat("Coverage of alpha(t) 95% CI: ",
#     mean(alpha_true >= alpha_lower & alpha_true <= alpha_upper), "\n")

## ----inspect-code, eval=F-----------------------------------------------------
# # Build the Stan program only: with runStan = FALSE the sampler is not run
# fofr_code <- fofr_bayes(
#   runStan     = FALSE,
#   formula     = Y_mat ~ age + s(sindex, by = X_func, bs = "cr", k = 10),
#   data        = dat,
#   spline_type = "bs",
#   spline_df   = 10
# )
#
# # Show the generated Stan model code
# cat(fofr_code$stancode)

