---
title: "RStanTVA Parameter Recovery Study"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{RStanTVA Parameter Recovery Study}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  warning = FALSE
)

```


In this vignette, we report the results of a parameter recovery study based on the `tva_recovery` data set. We need the packages `tibble`, `tidyr`, and `dplyr` for data wrangling, `ggplot2` for plotting, `RStanTVA` for the actual analyses, and `brms` for specifying priors.

```{r setup}
library(tibble)
library(dplyr)
library(tidyr)
library(ggplot2)
library(RStanTVA)
library(brms)
```

To make use of the within-chain parallelization built into Stan, we set the number of threads per chain in RStan to the number of available cores:

```{r}
rstan_options(threads_per_chain = parallel::detectCores())
```

# Generative model

## Level 1: Trial

Each trial is simulated assuming partial report in TVA with a global processing capacity $C$, top-down selectivity $\alpha$, a hemispheric bias $r$, a probit-distributed memory capacity $K$ (parameters $\mu_K$ and $\sigma_K$), and a Gaussian sensory threshold $t_0$ (parameters $\mu_0$ and $\sigma_0$).

For easier interpretation, we will later only look at the expected value of $K$ instead of the distributional parameters $\mu_K$ and $\sigma_K$. The expected value (mean) can be derived from $(\mu_K,\sigma_K)$ samples as follows:

```{r meanprobitk}
meanprobitK <- Vectorize(function(mK, sigmaK, nS) {
  p <- pnorm(1:nS, mK, sigmaK)
  lps <- c(0, p)
  ups <- c(p, 1)
  sum(0:nS * (ups - lps))
}, c("mK","sigmaK"))
```
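Since `meanprobitK()` assigns $P(K<k)=\Phi\big((k-\mu_K)/\sigma_K\big)$ for $k=1,\dots,n_S$, the same expected value can also be written as a tail sum, $E[K]=\sum_{k=1}^{n_S}\left[1-\Phi\big((k-\mu_K)/\sigma_K\big)\right]$. The sketch below checks this identity numerically; `meanprobitK_tail()` is a hypothetical helper, not part of RStanTVA, and `meanprobitK()` is repeated so that the block runs on its own:

```r
# Copy of meanprobitK() from above, so that this block is self-contained
meanprobitK <- Vectorize(function(mK, sigmaK, nS) {
  p <- pnorm(1:nS, mK, sigmaK)
  lps <- c(0, p)
  ups <- c(p, 1)
  sum(0:nS * (ups - lps))
}, c("mK","sigmaK"))

# Hypothetical alternative: E[K] = sum_{k=1}^{nS} P(K >= k) = sum_k (1 - Phi(k))
meanprobitK_tail <- function(mK, sigmaK, nS) {
  sapply(seq_along(mK), function(i) sum(1 - pnorm(1:nS, mK[i], sigmaK[i])))
}

mK <- c(2.2, 3.5, 4.8)
sK <- c(0.3, 0.5, 1.0)
all.equal(meanprobitK(mK, sK, 6), meanprobitK_tail(mK, sK, 6))
```

Because every term of the tail sum lies between 0 and 1, this form also makes it obvious that $0\le E[K]\le n_S$.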

We do the same for $t_0$, whose expected value under the Gaussian assumption is simply $\mu_0$.

*Note:* For a demonstration of simulated TVA trials, see the `ShinyTVA` app, which can be run as follows:

```r
shiny::runGitHub("mmrabe/ShinyTVA")
```
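To make the Level 1 assumptions more concrete, here is a minimal sketch of a single whole-report trial under the standard TVA race with exponential processing times. The function `simulate_trial()` is hypothetical (it is neither part of RStanTVA nor taken from `data-raw/tva_recovery.R`); it takes the processing rates `v` as given instead of deriving them from $C$, $\alpha$, and $r$, and it assumes that the probit $K$ is a latent normal value rounded down and truncated to $[0,n_S]$, which reproduces the probabilities used in `meanprobitK()`:

```r
# Hypothetical sketch of one whole-report TVA trial (rates v per ms, t in ms)
simulate_trial <- function(v, t, mu0, sigma0, mK, sK) {
  nS <- length(v)
  # Gaussian sensory threshold t0 and effective exposure duration tau
  t0 <- rnorm(1, mu0, sigma0)
  tau <- t - t0
  # Probit capacity: latent normal value, rounded down, truncated to [0, nS]
  K <- min(max(floor(rnorm(1, mK, sK)), 0), nS)
  if (tau <= 0 || K == 0) return(integer(0))
  # Exponential processing times; the first K items finishing before tau are encoded
  finish <- rexp(nS, rate = v)
  finished <- which(finish < tau)
  finished <- finished[order(finish[finished])]
  head(finished, K)
}

set.seed(1)
simulate_trial(v = c(0.05, 0.02, 0.04, 0.01, 0.03, 0.02), t = 80,
               mu0 = 15, sigma0 = 5, mK = 3.5, sK = 0.5)
```

The returned vector contains the indices of the encoded (and hence reported) items in the order in which they were encoded.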

## Level 2: Experimental condition

Processing capacity $C$ differs between the "low" and "high" conditions, so that $C_\textrm{low}=`r exp(tva_recovery_true_params$b["C_Intercept"])`$ and $C_\textrm{high}=`r exp(tva_recovery_true_params$b["C_Intercept"]+tva_recovery_true_params$b["C_conditionhigh"])`$. There are no true differences between the conditions for any of the other model parameters.

## Level 3: Subject-level effects (hyperparameters)

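Each of the `r n_distinct(tva_recovery$subject)` subjects has an individual parameter configuration and individual condition effects, all of which were sampled from a multivariate normal distribution. Here, $a$ denotes a subject's intercept (the value in the "low" condition) and $b$ the subject's effect of the "high" condition on the respective (transformed) parameter: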

$$
\begin{pmatrix}
a_{\log C}\\
b_{\log C}\\
a_{\log\alpha}\\
b_{\log\alpha}\\
a_{\mu_K}\\
b_{\mu_K}\\
a_{\log\sigma_K}\\
b_{\log\sigma_K}\\
a_{\log\sigma_0}\\
b_{\log\sigma_0}\\
a_{\mu_0}\\
b_{\mu_0}\\
a_{\textrm{logit}(r)}\\
b_{\textrm{logit}(r)}
\end{pmatrix}\sim\mathcal{N}_{`r length(tva_recovery_true_params$b)`}\left(\begin{pmatrix}
`r paste0(tva_recovery_true_params$b, collapse = "\\\\")`
\end{pmatrix},\begin{pmatrix}
`r (tva_recovery_true_params$r_subject * tcrossprod(tva_recovery_true_params$s_subject)) %>% apply(1, paste, collapse = "&") %>% paste(collapse="\\\\")`
\end{pmatrix}\right)
$$
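The covariance matrix above is built from a correlation matrix $\mathbf R$ and a vector of standard deviations $\mathbf s$ as $\mathbf R\odot\mathbf s\mathbf s^\top$, which equals $\operatorname{diag}(\mathbf s)\,\mathbf R\,\operatorname{diag}(\mathbf s)$. A minimal sketch of this sampling step with `MASS::mvrnorm()`, restricted to the two dimensions $a_{\log C}$ and $b_{\log C}$ and using made-up values (not the true values behind `tva_recovery`):

```r
# Hypothetical population values for (a_logC, b_logC); NOT the simulation's true values
mu <- c(a_logC = 4.5, b_logC = 0.3)
s  <- c(0.3, 0.1)                            # subject-level SDs
R  <- matrix(c(1, 0.2, 0.2, 1), nrow = 2)    # subject-level correlations

# Covariance = correlations scaled elementwise by the outer product of the SDs
Sigma <- R * tcrossprod(s)

set.seed(123)
# MASS:: prefix avoids attaching MASS, whose select() would mask dplyr::select()
theta <- MASS::mvrnorm(n = 50, mu = mu, Sigma = Sigma)  # one row per subject

# Per-subject processing capacity in both conditions
C_low  <- exp(theta[, "a_logC"])
C_high <- exp(theta[, "a_logC"] + theta[, "b_logC"])
```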

## Simulations


The data frame `tva_recovery` contains simulated CombiTVA trials for `r n_distinct(tva_recovery$subject)` subjects (`r nrow(tva_recovery)/n_distinct(tva_recovery$subject)` trials per subject, `r nrow(tva_recovery)` trials in total). Half of each subject's trials were simulated in the "low" condition and the other half in the "high" condition:

```{r load-data}
data("tva_recovery")
head(tva_recovery)
```

The package also contains the true values used for the simulation (including the subject-level parameters):

```{r load-true-vals}
data("tva_recovery_true_params")
str(tva_recovery_true_params)
```

These can conveniently be converted to a data frame for later comparison with the fitted values:

```{r narrow-true-vals}
true_params <- tva_recovery_true_params$coef_subject

# Long format with plotmath labels: w_lambda is the second column of w, and
# K is summarized by its expected value (see meanprobitK() above)
true_params_narrow <- true_params %>%
  mutate(`italic(w)[lambda]` = w[,2], `italic(K)` = meanprobitK(mK, sK, 6)) %>%
  select(-w, -mK, -sK) %>%
  rename(`italic(t)[0]` = mu0, `sigma[0]` = sigma0, `italic(C)` = C) %>%
  pivot_longer(-c(subject, condition), names_to = "param", values_to = "true_value")
head(true_params_narrow)

# One value per subject for all parameters except C (taken from the "low"
# condition); C is split into C[low] and C[high]
true_params_narrow2 <- bind_rows(
  true_params_narrow %>% filter(condition == "low" & param != "italic(C)") %>% select(-condition),
  true_params_narrow %>% filter(param == "italic(C)") %>% transmute(subject, param = sprintf("%s[%s]", param, condition), true_value)
)
head(true_params_narrow2)
```


*Note:* The R script that generates `tva_recovery` is located in the package source at `data-raw/tva_recovery.R`.



# Model likelihood

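For a single trial, given the effective exposure duration $\tau=t-t_0$ of the set of displayed items $S$ and the predicted processing rates $\mathbf v$ of those items, the probability of a memory report $M$ (i.e., that exactly the items in $M$ are encoded into VSTM) is given as follows, where $F(t;v)=1-e^{-vt}$, $\bar F(t;v)=e^{-vt}$, and $f(t;v)=ve^{-vt}$ denote the distribution function, survival function, and density of the exponentially distributed processing time of an item with rate $v$, as assumed in standard TVA: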

$$
P_{\mathrm{M}}\left(M\mid\tau,S,K,\mathbf{v}\right)
=\begin{cases}
1 & \textrm{if }M=\emptyset\textrm{ and }K=0\\
1 & \textrm{if }M=\emptyset\textrm{ and }\tau=0\\
\bar{F}\left(\tau;\sum_{z\in S}v_{z}\right) & \textrm{if }M=\emptyset\textrm{ and }K>0\textrm{ and }\tau>0\\
\prod_{y\in M}F\left(\tau;v_{y}\right)\prod_{z\in S\setminus M}\bar{F}\left(\tau;v_{z}\right) & \textrm{if }0<\left|M\right|<\min\left(K,\left|S\right|\right)\textrm{ and }\tau>0\\
\prod_{y\in M}F\left(\tau;v_{y}\right)\prod_{z\in S\setminus M}\bar{F}\left(\tau;v_{z}\right) & \textrm{if }0<\left|M\right|=\left|S\right|\leq K\textrm{ and }\tau>0\\
\sum_{x\in M}\int_{0}^{\tau}f\left(t;v_{x}\right)\prod_{y\in M\setminus\left\{ x\right\} }F\left(t;v_{y}\right)\prod_{z\in S\setminus M}\bar{F}\left(t;v_{z}\right)\mathrm{d}t & \textrm{if }0<\left|M\right|=K<\left|S\right|\textrm{ and }\tau>0\\
0 & \textrm{otherwise}
\end{cases}
$$
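The case distinction translates almost line by line into code. The sketch below is a hypothetical, unoptimized re-implementation for illustration only (it is not RStanTVA's Stan code and `p_memory()` is not an RStanTVA function). It assumes exponential processing times, identifies items by integer indices into the rate vector `v`, applies the product form whenever VSTM is not full ($0<|M|<\min(K,|S|)$ or $|M|=|S|\le K$), and uses `stats::integrate()` for the case in which VSTM fills up before $\tau$:

```r
# Hypothetical sketch of P_M(M | tau, S, K, v) with exponential processing times
p_memory <- function(M, S, v, tau, K) {
  cdf  <- function(t, rate) 1 - exp(-rate * t)     # F(t; v)
  surv <- function(t, rate) exp(-rate * t)         # F-bar(t; v)
  dens <- function(t, rate) rate * exp(-rate * t)  # f(t; v)
  nM <- length(M)
  nS <- length(S)
  notM <- setdiff(S, M)
  if (tau < 0) return(0)
  # Without capacity or exposure, only the empty report is possible
  if (K == 0 || tau == 0) return(as.numeric(nM == 0))
  # Empty report: no item finished before tau
  if (nM == 0) return(surv(tau, sum(v[S])))
  # VSTM not full: items in M finished before tau, all other items did not
  if (nM < min(K, nS) || (nM == nS && nS <= K)) {
    return(prod(cdf(tau, v[M])) * prod(surv(tau, v[notM])))
  }
  # VSTM full: the last item x of M finishes at t < tau, the rest of M before t
  if (nM == K && K < nS) {
    total <- 0
    for (x in M) {
      others <- setdiff(M, x)
      integrand <- function(t) vapply(t, function(ti)
        dens(ti, v[x]) * prod(cdf(ti, v[others])) * prod(surv(ti, v[notM])),
        numeric(1))
      total <- total + integrate(integrand, lower = 0, upper = tau)$value
    }
    return(total)
  }
  0  # e.g., more items reported than fit into VSTM
}

# Example: four items, K = 2, effective exposure of 30 ms
p_memory(M = c(1L, 3L), S = 1:4, v = c(0.02, 0.03, 0.05, 0.01), tau = 30, K = 2)
```

A useful sanity check for such an implementation is that the probabilities of all subsets $M\subseteq S$ sum to 1 for any fixed $\tau$, $K$, and $\mathbf v$.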

For partial reports (i.e., when the display contains targets $S_{{\rm T}}$ and distractors $S_{{\rm D}}$), we need to sum over all subsets $R_{{\rm D}}\subseteq S_{{\rm D}}$ with at most $K-\left|R_{{\rm T}}\right|$ elements, i.e., over $\mathcal{P}_{\leq K-\left|R_{{\rm T}}\right|}\left(S_{{\rm D}}\right)$, because any such set of distractors could have entered VSTM within $\tau$ in addition to the actually reported targets $R_\mathrm{T}$:

$$
P_{{\rm PR}}\left(R_{{\rm T}}\mid\tau,S_{{\rm T}},S_{{\rm D}},K,\mathbf{v}\right)
=\begin{cases}
\sum_{R_{{\rm D}}\in\mathcal{P}_{\leq K-\left|R_{{\rm T}}\right|}\left(S_{{\rm D}}\right)}P_{{\rm M}}\left(R_{{\rm T}}\cup R_{{\rm D}}\mid\tau,S_{{\rm T}}\cup S_{{\rm D}},K,\mathbf{v}\right) & \textrm{if }S_{{\rm D}}\neq\emptyset\textrm{ and }\tau\geq0\\
P_{{\rm M}}\left(R_{{\rm T}}\mid\tau,S_{{\rm T}},K,\mathbf{v}\right) & \textrm{if }S_{{\rm D}}=\emptyset\textrm{ and }\tau\geq0\\
0 & \textrm{otherwise}
\end{cases}
$$

Note that $P_{{\rm PR}}$ simplifies to $P_{\mathrm{M}}$ if there were no distractors ($S_{{\rm D}}=\emptyset$), i.e. when the trial was effectively a whole-report trial.
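The sum over $\mathcal{P}_{\leq K-|R_{\rm T}|}(S_{\rm D})$ amounts to enumerating all distractor subsets with at most $K-|R_{\rm T}|$ elements. A minimal sketch with hypothetical function names, where $P_{\rm M}$ is passed in as an argument `p_m` (any function with signature `p_m(M, S, v, tau, K)`) so that the enumeration logic stands on its own:

```r
# Hypothetical helper: all subsets of `items` with at most `max_size` elements,
# i.e. P_{<= max_size}(items), including the empty set
subsets_up_to <- function(items, max_size) {
  if (max_size < 0) return(list())
  out <- list(items[0])
  for (k in seq_len(min(max_size, length(items)))) {
    # index-based combn() avoids R's special case combn(n, k) for a scalar n
    out <- c(out, combn(length(items), k, function(i) items[i], simplify = FALSE))
  }
  out
}

# Hypothetical P_PR: sum P_M over all distractor subsets that fit into VSTM
# next to the reported targets R_T
p_partial_report <- function(R_T, S_T, S_D, v, tau, K, p_m) {
  if (tau < 0) return(0)
  if (length(S_D) == 0) return(p_m(R_T, S_T, v, tau, K))  # whole-report case
  total <- 0
  for (R_D in subsets_up_to(S_D, K - length(R_T))) {
    total <- total + p_m(c(R_T, R_D), c(S_T, S_D), v, tau, K)
  }
  total
}

# With 4 distractors and K - |R_T| = 2 there are 1 + 4 + 6 = 11 terms
length(subsets_up_to(3:6, 2))
```

The number of terms grows quickly with the number of distractors and with $K$, which is one reason why partial-report likelihoods are considerably more expensive to evaluate than whole-report likelihoods.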

# Parameter recovery

Because it is based on RStan, RStanTVA supports different kinds of inference, namely maximum-likelihood estimation and Bayesian inference, at different levels of hierarchical complexity (single-condition, single-participant/fixed-effects, or fully hierarchical/mixed-effects). Hence, we can also recover the model parameters in different ways.

## Maximum-likelihood estimation (MLE)

Using MLE, we can define four different inference models:

### RStanTVA-ML: Single participant, single condition

The most straightforward model uses no regularizing priors (`priors = FALSE`) and is fitted separately to the data of each participant in each condition:

```r
m <- stantva_model(
  locations = 6,
  task = "pr",
  regions = list(left = 1:3, right = 4:6),
  w_mode = "regions",
  t0_mode = "gaussian",
  K_mode = "probit",
  save_log_lik = FALSE,
  priors = FALSE
)
```

### RStanTVA-MLN: Single participant, condition as fixed effect

By defining a regression model for $\log C$, which includes an intercept and a treatment effect of "high" vs. "low" (`log(C) ~ 1 + condition`), we can model both experimental conditions in a single model:

```r
m_nested <- stantva_model(
  formula = list(log(C) ~ 1 + condition),
  locations = 6,
  task = "pr",
  regions = list(left = 1:3, right = 4:6),
  w_mode = "regions",
  t0_mode = "gaussian",
  K_mode = "probit",
  save_log_lik = FALSE,
  priors = FALSE
)
```
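Under R's default treatment contrasts with "low" as the reference level (as the coefficient name `C_conditionhigh` used further below suggests), the formula `log(C) ~ 1 + condition` corresponds to $\log C=\beta_0+\beta_1\,[\textrm{condition}=\textrm{high}]$, so that $C_\textrm{low}=e^{\beta_0}$ and $C_\textrm{high}=e^{\beta_0+\beta_1}=C_\textrm{low}\,e^{\beta_1}$. A short base-R illustration with made-up coefficients (not fitted values):

```r
# Design matrix for the two conditions ("low" is the reference level)
cond <- data.frame(condition = factor(c("low", "high"), levels = c("low", "high")))
X <- model.matrix(~ 1 + condition, data = cond)
colnames(X)

# Hypothetical coefficients on the log scale
beta <- c(4.5, 0.3)
C <- exp(drop(X %*% beta))  # C_low = exp(4.5), C_high = exp(4.5 + 0.3)
C[2] / C[1]                 # multiplicative effect of "high", i.e. exp(0.3)
```

The treatment effect on $\log C$ is therefore a multiplicative effect on $C$ itself.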

### RStanTVA-MLR: Single participant, single condition, with regularization (priors)

Omitting the `priors = FALSE` argument falls back to the default `priors = NULL`, which uses weakly informative default priors for all parameters. Otherwise, the model is identical to the baseline RStanTVA-ML model above:

```r
m_reg <- stantva_model(
  locations = 6,
  task = "pr",
  regions = list(left = 1:3, right = 4:6),
  w_mode = "regions",
  t0_mode = "gaussian",
  K_mode = "probit",
  save_log_lik = FALSE
)
```

### RStanTVA-MLRN: Single participant, condition as fixed effect, with regularization (priors)

When a parameter such as $C$ is described by a regression formula and regularizing priors are used, we need to specify priors for its intercept and its slopes:

```r
priors <-
  prior(normal(0,.5),dpar=C)+
  prior(normal(4.5,.6),dpar=C,coef=Intercept)

m_nested_reg <- stantva_model(
  formula = list(log(C) ~ 1 + condition),
  locations = 6,
  task = "pr",
  regions = list(left = 1:3, right = 4:6),
  w_mode = "regions",
  t0_mode = "gaussian",
  K_mode = "probit",
  save_log_lik = FALSE,
  priors = priors
)
```
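Because the regression is specified for $\log C$, these priors apply on the log scale, assuming brms-style semantics in which `prior(normal(0,.5),dpar=C)` covers all coefficients of $C$ and the `coef=Intercept` prior overrides it for the intercept. The implied priors on the natural scale can be checked with base R:

```r
# Implied prior on C_low = exp(intercept): log-normal(meanlog = 4.5, sdlog = 0.6)
qlnorm(c(0.025, 0.5, 0.975), meanlog = 4.5, sdlog = 0.6)  # about 28, 90, 292

# Implied prior on the multiplicative effect C_high / C_low = exp(slope)
exp(qnorm(c(0.025, 0.975), mean = 0, sd = 0.5))          # about 0.38 and 2.66
```

The intercept prior is thus centered on $e^{4.5}\approx 90$, the same value used as the prior mean of $C$ in the hierarchical model below.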

For brevity, we only report the results of the RStanTVA-MLRN fits here, since these were found to be the most reliable (see manuscript). The model can be fitted to each participant's data set as follows:

```{r fit_mle_nested_reg, include=FALSE}
fit_mle_nested_reg <- readRDS("tva_recovery_cache/fit_mle_nested_reg.rds")
```

```r
fit_mle_nested_reg <- lapply(1:50, function(i) {
  d <- tva_recovery %>% filter(subject == i)
  # Retry from new random inits (at most 10 times) until Stan reports success
  for (attempt in 1:10) {
    p <- optimizing(m_nested_reg, d)
    if (p$return_code == 0L) break
  }
  tibble(subject = i,
    param = c("italic(C)[low]","italic(C)[high]","alpha","italic(w)[lambda]","italic(t)[0]","sigma[0]","italic(K)"),
    fitted_value = c(exp(p$par["C_Intercept"]), exp(p$par["C_Intercept"]+p$par["C_conditionhigh"]),
                     p$par[c("alpha","r[2]","mu0","sigma0")], meanprobitK(p$par["mK"],p$par["sK"],6)),
    converged = p$return_code == 0L)
}) %>% bind_rows() %>% left_join(true_params_narrow2, by = c("subject", "param"))
```

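The key call above is `optimizing(m_nested_reg, d)`, which fits `m_nested_reg` to the subject-level subset `d` by (regularized) maximum-likelihood estimation. Because the optimizer can fail from a given random initialization, the call is repeated until it reports success (`return_code == 0`).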


## RStanTVA-NB: Bayesian non-hierarchical inference

In principle, we can fit exactly the same models as above using Bayesian inference instead of MLE, simply by replacing the `optimizing(...)` call with a `sampling(...)` call. The following code fits the RStanTVA-MLRN model (`m_nested_reg`) to the subject-level subsets `d` using Stan's Hamiltonian Monte Carlo (HMC) sampler:

```{r fit_bayes_fixed, include=FALSE}
fit_bayesian_nested <- readRDS("tva_recovery_cache/fit_bayesian_nested.rds")
```


```r
fit_bayesian_nested <- lapply(1:50, function(i) {
  d <- tva_recovery %>% filter(subject == i)
  sf <- sampling(m_nested_reg, d)
  p1 <- predict(sf, data.frame(subject = i, condition = "low"), c("C","alpha", "r", "mu0","sigma0","mK","sK"))
  p2 <- predict(sf, data.frame(subject = i, condition = "high"), "C")
  tibble(
    `italic(C)[low]` = t(p1$C),
    `italic(C)[high]` = t(p2$C),
    `alpha` = t(p1$alpha),
    `italic(w)[lambda]` = t(p1$r[,2,1]),
    `italic(t)[0]` = t(p1$mu0),
    `sigma[0]` = t(p1$sigma0),
    `italic(K)` = t(meanprobitK(p1$mK,p1$sK,6))
  ) %>%
    pivot_longer(everything(), names_to = "param", values_to = "samples") %>%
    rowwise(param) %>%
    reframe(subject = i, posterior_sd = sd(samples), fitted_value = median(samples), cri = t(quantile(samples, c(.025,.975))), converged = Rhat(matrix(samples, ncol = sf@sim$chains)) < 1.1)
}) %>% 
  bind_rows() %>%
  left_join(true_params_narrow2)
```

Note that we must use `predict(fit, data, parameters)` to derive the TVA parameter values from the fitted regression coefficients. It returns a posterior distribution for each row (subject) in `data` and each requested parameter. From these, we compute posterior medians as point estimates, 95% credible intervals (CrIs) and posterior SDs as measures of uncertainty, and $\hat R$ as a convergence diagnostic ($\hat R<1.1$).
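For readers who want to see what these summaries compute, here is a minimal sketch for the draws of a single parameter, arranged as an iterations × chains matrix (the layout passed to `Rhat()` above). The helper `summarise_posterior_sketch()` is hypothetical, and it uses the classic Gelman-Rubin $\hat R$ (no chain splitting, no rank normalization); rstan's `Rhat()` implements the more robust rank-normalized split-$\hat R$, so its values will differ somewhat:

```r
# Hypothetical helper: summarise posterior draws (iterations x chains matrix)
summarise_posterior_sketch <- function(draws) {
  n <- nrow(draws)
  W <- mean(apply(draws, 2, var))      # mean within-chain variance
  B <- n * var(colMeans(draws))        # between-chain variance
  var_plus <- (n - 1) / n * W + B / n  # pooled estimate of the posterior variance
  x <- as.vector(draws)
  list(
    fitted_value = median(x),
    cri = unname(quantile(x, c(0.025, 0.975))),
    posterior_sd = sd(x),
    rhat = sqrt(var_plus / W)          # classic (non-split) Gelman-Rubin R-hat
  )
}

set.seed(42)
good <- matrix(rnorm(4000), ncol = 4)       # 4 well-mixed chains
bad  <- sweep(good, 2, c(0, 0, 0, 5), `+`)  # one chain stuck in another region
summarise_posterior_sketch(good)$rhat       # close to 1
summarise_posterior_sketch(bad)$rhat        # far above the 1.1 threshold
```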


## RStanTVA-HB*: Bayesian hierarchical inference

For hierarchical inference, we define mixed-effects regressions for the model parameters $C$, $\log\alpha$, $\mu_K$, $\mu_0$, and $\textrm{logit}(r)$, while $\sigma_0$ and $\sigma_K$ have no subject-level effects (i.e., they are shared by all subjects):

```r
m_hierarchical <- stantva_model(
  formula = list(
    C ~ 1 + condition + (1 + condition | subject), 
    log(alpha) ~ 1 + (1 | subject), 
    mK ~ 1 + (1 | subject), 
    mu0 ~ 1 + (1 | subject), 
    log(r) ~ 1 + (1 | subject)
  ),
  locations = 6,
  task = "pr",
  regions = list(left = 1:3, right = 4:6),
  C_mode = "equal",
  w_mode = "regions",
  t0_mode = "gaussian",
  K_mode = "probit",
  priors = 
    prior(normal(90, 20/sqrt2()), coef = Intercept, dpar = C) + 
    prior(normal(0, 10/sqrt2()), dpar = C) + 
    prior(gamma(2, 2/20 * sqrt2()), class = sd, coef = Intercept, dpar = C) + 
    prior(gamma(2, 2/10 * sqrt2()), class = sd, dpar = C) + 
    prior(normal(-0.4, 0.6/sqrt2()), coef = Intercept, dpar = alpha) + 
    prior(normal(0, 0.3/sqrt2()), dpar = alpha) + 
    prior(gamma(2, 2/0.6 * sqrt2()), class = sd, coef = Intercept, dpar = alpha) + 
    prior(gamma(2, 2/0.3 * sqrt2()), class = sd, dpar = alpha) + 
    prior(normal(0, 0.1/sqrt2()), coef = Intercept, dpar = r) + 
    prior(normal(0, 0.05/sqrt2()), dpar = r) + 
    prior(gamma(2, 2/0.1 * sqrt2()), class = sd, coef = Intercept, dpar = r) + 
    prior(gamma(2, 2/0.05 * sqrt2()), class = sd, dpar = r) + 
    prior(normal(3.5, 0.5/sqrt2()), coef = Intercept, dpar = mK) + 
    prior(normal(0, 0.25/sqrt2()), dpar = mK) + 
    prior(gamma(2, 2/0.5 * sqrt2()), class = sd, coef = Intercept, dpar = mK) + 
    prior(gamma(2, 2/0.25 * sqrt2()), class = sd, dpar = mK) + 
    prior(normal(20, 15/sqrt2()), coef = Intercept, dpar = mu0) +
    prior(normal(0, 7.5/sqrt2()), dpar = mu0) + 
    prior(gamma(2, 2/15 * sqrt2()), class = sd, coef = Intercept, dpar = mu0) + 
    prior(gamma(2, 2/7.5 * sqrt2()), class = sd, dpar = mu0) + 
    prior(normal(0.5, 0.05/sqrt2()), coef = Intercept, dpar = sK) + 
    prior(normal(0, 0.025/sqrt2()), dpar = sK) + 
    prior(gamma(2, 2/0.05 * sqrt2()), class = sd, coef = Intercept, dpar = sK) + 
    prior(gamma(2, 2/0.025 * sqrt2()), class = sd, dpar = sK) + 
    prior(normal(20, 15/sqrt2()), coef = Intercept, dpar = sigma0) + 
    prior(normal(0, 7.5/sqrt2()), dpar = sigma0) + 
    prior(gamma(2, 2/15 * sqrt2()), class = sd, coef = Intercept, dpar = sigma0) + 
    prior(gamma(2, 2/7.5 * sqrt2()), class = sd, dpar = sigma0)
)
```
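The priors above follow a common pattern for each parameter with intercept scale $s$: `normal(m, s/sqrt2())` for the population-level intercept and `gamma(2, 2/s * sqrt2())` for the between-subject SD, with $s/2$ used for the slopes and their SDs. Since Stan's `gamma(shape, rate)` has mean shape/rate, the SD prior has mean $s/\sqrt2$, equal to the SD of the intercept prior. One plausible reading (our interpretation, not stated above) is that the factor $\sqrt2$ splits a total prior variance of $s^2$ evenly between population-level uncertainty and between-subject variability, as the following arithmetic shows when plugging in the prior mean of the SD:

```r
# Intercept prior scales s used in m_hierarchical
s <- c(C = 20, alpha = 0.6, r = 0.1, mK = 0.5, mu0 = 15, sK = 0.05, sigma0 = 15)

intercept_sd  <- s / sqrt(2)             # SD of normal(m, s/sqrt2())
sd_prior_mean <- 2 / (2 / s * sqrt(2))   # mean of gamma(2, 2/s * sqrt2()) = shape / rate

# Combining both components recovers the overall scale s
sqrt(intercept_sd^2 + sd_prior_mean^2)
```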

We can then simply fit the hierarchical model `m_hierarchical` to the entire data set `tva_recovery` without subsetting:

```{r fit_bayesian_hierarchical_globals, include=FALSE}
fit_bayesian_hierarchical_globals <- readRDS("tva_recovery_cache/fit_bayesian_hierarchical_globals.rds")
```

```r
fg <- sampling(m_hierarchical, tva_recovery)

fit_bayesian_hierarchical_globals <-
  {
    p1 <- predict(fg, tibble(subject = 1:50, condition = "low"), c("C","alpha", "r", "mu0","sigma0","mK","sK"))
    p2 <- predict(fg, tibble(subject = 1:50, condition = "high"), "C")
    tibble(
      subject = 1:50,
      `italic(C)[low]` = t(p1$C),
      `italic(C)[high]` = t(p2$C),
      `alpha` = t(p1$alpha),
      `italic(w)[lambda]` = t(p1$r[,2,]),
      `italic(t)[0]` = t(p1$mu0),
      `sigma[0]` = t(p1$sigma0),
      `italic(K)` = t(matrix(meanprobitK(p1$mK,p1$sK,6),ncol=50))
    )
  } %>%
  pivot_longer(-subject, names_to = "param", values_to = "samples") %>%
  rowwise(c(subject,param)) %>%
  reframe(posterior_sd = sd(samples), fitted_value = median(samples), cri = t(quantile(samples, c(.025,.975))), converged = Rhat(t(samples)) < 1.1) %>%
  left_join(true_params_narrow2)

```


## Comparison of MLE and Bayes

```{r plot-bayes, fig.width = 8, fig.height = 5, dpi = 100}
figdat <- bind_rows(
  fit_mle_nested_reg %>% add_column(method = "MLRN"),
  fit_bayesian_hierarchical_globals %>% add_column(method = "HB"),
  fit_bayesian_nested %>% add_column(method = "NB")
) %>%
  mutate(method = factor(method, levels = c("MLRN","NB","HB")))

fig <- ggplot(figdat) +
  theme_minimal() +
  theme(text = element_text(family = "sans", size = 10)) +
  labs(x = "True value", y = "Recovered value") +
  facet_wrap(method~param, ncol = 7, scales = "free", labeller = function(x) label_parsed(list(sprintf("%s~(\"%s\")", x$param, x$method)))) +
  geom_abline(linetype = "dashed") +
  geom_linerange(aes(x=true_value,ymin=cri[,1],ymax=cri[,2]), color = "gray50", linewidth = 0.2) +
  geom_point(aes(x=true_value,y=fitted_value), size = 0.5, color = "blue") +
  geom_point(aes(x=vx,y=vy), color = "transparent", figdat %>% group_by(param) %>% reframe(vx = range(true_value), vy = range(c(fitted_value,cri), na.rm = TRUE)))

print(fig)

```





