---
title: "Multivariate Gaussian marginal likelihood via THAMES"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{multivariate-gaussian}
  %\VignetteEngine{knitr::knitr}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
options(rmarkdown.html_vignette.check_title = FALSE)
```

```{r setup}
set.seed(2023)
library(thames)
```

To use the function `thames()` we need only two things:

  1. A $T\times d$ matrix of parameters drawn from the posterior distribution. The columns are the parameters (dimension $d$) and the rows are the $T$ different posterior draws.
  2. The vector of unnormalized log posterior values of length $T$ (sum of the log prior and the log likelihood for each drawn parameter).
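
As a minimal sketch of what these two inputs look like, the chunk below builds placeholder objects with the required shapes. The names `n_draws_demo`, `d_demo`, `params_demo` and `lps_demo` are illustrative and not part of the package, and the values are arbitrary stand-ins rather than a real posterior sample.

```{r}
# Illustrative shapes only: the names and values below are placeholders
n_draws_demo <- 5   # T, the number of posterior draws
d_demo       <- 2   # d, the number of parameters

# T x d matrix: one row per draw, one column per parameter
params_demo <- matrix(seq(-1, 1, length.out = n_draws_demo * d_demo),
                      nrow = n_draws_demo, ncol = d_demo)

# Length-T vector: one log value per row (a stand-in, not a real log posterior)
lps_demo <- apply(params_demo, 1, function(x) sum(dnorm(x, log = TRUE)))

dim(params_demo)
length(lps_demo)
```

The only structural requirement illustrated here is that the number of rows of the matrix equals the length of the vector.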

To illustrate the use of the `thames()` function we give a toy example on multivariate Gaussian data. For more details on the method, see the paper
[Metodiev M, Perrot-Dockès M, Ouadah S, Irons N. J., Raftery A. E. (2023), Easily Computed Marginal Likelihoods from Posterior Simulation Using the THAMES Estimator.](https://arxiv.org/abs/2305.08952)

Here the data $y_i, i=1,\ldots,n$ are drawn independently
from a multivariate normal distribution:
\begin{eqnarray*}
y_i|\mu & \stackrel{\rm iid}{\sim} & {\rm MVN}_d(\mu,  I_d), \;\; i=1,\ldots, n,
\end{eqnarray*}
along with a prior distribution on the mean vector $\mu$:
\begin{equation*}
    p(\mu)= {\rm MVN}_d(\mu; 0_d, s_0 I_d),
\end{equation*}
with $s_0 > 0$. It can be shown that the posterior distribution of the mean vector $\mu$ given the data $D=\{y_1, \ldots, y_n\}$ is given by:
\begin{equation}\label{eq:postMultiGauss}
    p(\mu|D) = {\rm MVN}_d(\mu; m_n,s_n I_d),
\end{equation}
where $m_n=n\bar{y}/(n+1/s_0)$, $\bar{y}=(1/n)\sum_{i=1}^n y_i$, and $s_n=1/(n +1/s_0)$. 
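
The step behind this result is completing the square in $\mu$. A short sketch of the derivation, using only the quantities defined above and writing "const" for terms that do not depend on $\mu$:
\begin{eqnarray*}
\log p(\mu|D) & = & -\frac{1}{2}\sum_{i=1}^n \|y_i-\mu\|^2 - \frac{1}{2s_0}\|\mu\|^2 + \mbox{const} \\
& = & -\frac{1}{2}\left(n+\frac{1}{s_0}\right)\|\mu\|^2 + n\,\bar{y}^\top\mu + \mbox{const} \\
& = & -\frac{1}{2s_n}\|\mu-m_n\|^2 + \mbox{const}.
\end{eqnarray*}
The last line is the log kernel of a ${\rm MVN}_d(m_n, s_n I_d)$ density, since $1/s_n = n+1/s_0$ and $m_n/s_n = n\bar{y}$.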

### Toy example: d = 1, n = 20

We set both $s_0$ (the prior variance) and the true mean $\mu$ used to simulate the data to 1 (note that the results are similar for some other values).

```{r}
s0      <- 1
mu_star <- 1
n       <- 20
d       <- 1
```

We simulate the data $Y$ and calculate the associated $s_n$ and $m_n$:

```{r}
library(mvtnorm)

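# Simulate n observations from MVN_d(mu_star, I_d)
Y  <- rmvnorm(n, mean = mu_star, sigma = diag(d))
# Posterior variance and mean of mu (closed form)
sn <- 1 / (n + 1 / s0)
mn <- sum(Y) / (n + 1 / s0)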
```


To use `thames()` we need a sample drawn from the posterior distribution of the parameters. In this toy example the only parameter is $\mu$, and we can draw from the posterior exactly. (More generally, MCMC can be used to obtain approximate posterior samples.) Here we take $2000$ draws.

```{r}
mu_sample <- rmvnorm(2000, mean = mn, sigma = sn * diag(d))
```
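
When exact sampling is not available, the draws have to come from an MCMC algorithm. The chunk below is a minimal sketch of one such algorithm, a random-walk Metropolis sampler written in base R for the $d = 1$ model. It is not part of the package: all names ending in `_mh`, the toy data, the proposal standard deviation of 0.5 and the chain length are assumptions made for illustration.

```{r}
# Hypothetical toy data and prior variance, separate from the objects used in the text
y_mh  <- rnorm(20, mean = 1)
s0_mh <- 1

# Unnormalized log posterior of a scalar mu: log prior + log likelihood
log_post_mh <- function(mu) {
  dnorm(mu, 0, sqrt(s0_mh), log = TRUE) + sum(dnorm(y_mh, mu, 1, log = TRUE))
}

n_iter_mh <- 5000
chain_mh  <- numeric(n_iter_mh)   # stores the draws
lps_mh    <- numeric(n_iter_mh)   # stores their unnormalized log posterior values
mu_cur    <- mean(y_mh)           # start near the posterior mode
lp_cur    <- log_post_mh(mu_cur)

for (iter in seq_len(n_iter_mh)) {
  mu_prop <- mu_cur + rnorm(1, sd = 0.5)   # symmetric random-walk proposal
  lp_prop <- log_post_mh(mu_prop)
  # Accept with probability min(1, posterior ratio)
  if (log(runif(1)) < lp_prop - lp_cur) {
    mu_cur <- mu_prop
    lp_cur <- lp_prop
  }
  chain_mh[iter] <- mu_cur
  lps_mh[iter]   <- lp_cur
}

# Compare the chain mean with the exact posterior mean m_n
c(mcmc = mean(chain_mh), exact = sum(y_mh) / (length(y_mh) + 1 / s0_mh))
```

Stored this way, `matrix(chain_mh, ncol = 1)` and `lps_mh` have the two shapes described at the start of this vignette.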

Now we calculate the unnormalized log posterior at each posterior draw $\mu^{(t)}$, $t=1,\ldots,T$ (here $T=2000$).

We first evaluate the log prior at each draw:

```{r}
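# Log prior density of one parameter vector under MVN_d(0, sig2 * I_d)
reg_log_prior <- function(param, sig2) {
  d <- length(param)
  dmvnorm(param, rep(0, d), sig2 * diag(d), log = TRUE)
}

# Evaluate the log prior at each row (draw) of mu_sample
log_prior <- apply(mu_sample, 1, reg_log_prior, s0)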
```

and the log likelihood of the data at each draw:

```{r}
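# Log likelihood of the data X (one observation per row) at one parameter vector
reg_log_likelihood <- function(param, X) {
  d <- length(param)
  sum(dmvnorm(X, param, diag(d), log = TRUE))
}

# Evaluate the log likelihood at each row (draw) of mu_sample
log_likelihood <- apply(mu_sample, 1, reg_log_likelihood, Y)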
```


and then sum the two to get the unnormalized log posterior:

```{r}
log_post <- log_prior + log_likelihood
```
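
The word "unnormalized" can be made concrete in this conjugate model. By Bayes' theorem, the log prior plus the log likelihood differs from the exact log posterior density by the log marginal likelihood, which does not depend on $\mu$. The chunk below is a small self-contained check of this for $d = 1$, in base R. The data, the values of $\mu$ and all names ending in `_id` are arbitrary and purely illustrative.

```{r}
# Hypothetical d = 1 example; names ending in _id are illustrative only
y_id  <- c(0.3, 1.2, 0.8, 1.9, 0.5)   # arbitrary fixed data
s0_id <- 1
n_id  <- length(y_id)
sn_id <- 1 / (n_id + 1 / s0_id)          # posterior variance
mn_id <- sum(y_id) / (n_id + 1 / s0_id)  # posterior mean

mu_id <- c(-0.5, 0.4, 1.1, 2.0)       # arbitrary values of mu

# Unnormalized log posterior (log prior + log likelihood) at each mu
lp_id <- sapply(mu_id, function(m) {
  dnorm(m, 0, sqrt(s0_id), log = TRUE) + sum(dnorm(y_id, m, 1, log = TRUE))
})

# Subtracting the exact log posterior density leaves the same constant for every mu
log_ml_id <- lp_id - dnorm(mu_id, mn_id, sqrt(sn_id), log = TRUE)
log_ml_id
```

All four entries agree up to floating-point error; that common value is the log marginal likelihood of this small data set.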

We can now estimate the marginal likelihood using THAMES:

```{r}
result <- thames(log_post,mu_sample)
```


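The element `log_zhat_inv` of the result is on the scale of the log reciprocal marginal likelihood, so the THAMES estimate of the log marginal likelihood is its negative: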

```{r}
-result$log_zhat_inv
```

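The upper and lower bounds of a confidence interval for the log marginal likelihood, based on asymptotic normality of the estimator (a 95\% interval by default), are obtained with the same sign change, which is why `log_zhat_inv_L` yields the upper bound and `log_zhat_inv_U` the lower bound: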

```{r}
-result$log_zhat_inv_L 
-result$log_zhat_inv_U
```

If we instead want a 90\% confidence interval, we set the lower quantile `p` to $0.05$:

```{r}
result_90 <- thames(log_post,mu_sample, p = 0.05)
-result_90$log_zhat_inv_L 
-result_90$log_zhat_inv_U
```

To check our estimate, we can calculate the Gaussian log marginal likelihood analytically as
$$
\ell(y) = -\frac{nd}{2}\log(2\pi)-\frac{d}{2}\log(s_0n+1)-\frac{1}{2}\sum_{i=1}^n\|y_i\|^2+ \frac{n^2}{2(n+1/s_0)}\|\bar{y}\|^2.
$$

```{r}
- n*d*log(2*pi)/2 - d*log(s0*n+1)/2 - sum(Y^2)/2 + sum(colSums(Y)^2)/(2*(n+1/s0))
```
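
For $d = 1$ the closed-form expression can also be checked independently by integrating the prior times the likelihood over $\mu$ numerically. The chunk below is a sketch of such a check with base R's `integrate()`. The data, the integration limits of $\pm 10$ and all names ending in `_ni` are assumptions chosen for illustration; the limits are wide enough that the neglected tails are negligible for this data set.

```{r}
# Hypothetical d = 1 data; names ending in _ni are illustrative only
y_ni  <- c(0.3, 1.2, 0.8, 1.9, 0.5)
s0_ni <- 1
n_ni  <- length(y_ni)

# Prior times likelihood, vectorized over mu as integrate() requires
unnorm_post_ni <- function(mu) {
  sapply(mu, function(m) {
    exp(dnorm(m, 0, sqrt(s0_ni), log = TRUE) + sum(dnorm(y_ni, m, 1, log = TRUE)))
  })
}

# Log marginal likelihood by numerical integration over mu
numeric_ni <- log(integrate(unnorm_post_ni, -10, 10, rel.tol = 1e-10)$value)

# Closed-form expression from the text with d = 1
analytic_ni <- -n_ni * log(2 * pi) / 2 - log(s0_ni * n_ni + 1) / 2 -
  sum(y_ni^2) / 2 + sum(y_ni)^2 / (2 * (n_ni + 1 / s0_ni))

c(numeric = numeric_ni, analytic = analytic_ni)
```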


### In higher dimensions

We can do exactly the same thing when $y_i \in \mathbb{R}^d$ with $d>1$. Here we take $d=2$:

```{r}
s0      <- 1
n       <- 20
d       <- 2
mu_star <- rep(1,d)
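# Simulate the data and compute the closed-form posterior variance and mean
Y  <- rmvnorm(n, mean = mu_star, sigma = diag(d))
sn <- 1 / (n + 1 / s0)
mn <- colSums(Y) / (n + 1 / s0)
# Draw exactly from the posterior MVN_d(mn, sn * I_d)
mu_sample <- rmvnorm(2000, mean = mn, sigma = sn * diag(d))

# Unnormalized log posterior (log prior + log likelihood) of one parameter vector
mvg_log_post <- function(param, X, sig2) {
  d <- length(param)
  l <- sum(dmvnorm(X, param, diag(d), log = TRUE))
  p <- dmvnorm(param, rep(0, d), sig2 * diag(d), log = TRUE)
  p + l
}
log_post <- apply(mu_sample, 1, mvg_log_post, Y, s0)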
result <- thames(log_post,mu_sample)
-result$log_zhat_inv
```

To check our estimate, we again calculate the Gaussian log marginal likelihood analytically:

```{r}
- n*d*log(2*pi)/2 - d*log(s0*n+1)/2 - sum(Y^2)/2 + sum(colSums(Y)^2)/(2*(n+1/s0))
```
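
Since the same expression is used for both examples, it can be wrapped in a small function and checked against a direct evaluation. The sketch below relies on the fact that, after integrating out $\mu$, each column of the data matrix is marginally ${\rm MVN}_n(0, I_n + s_0 1_n 1_n^\top)$, independently across columns. The function `analytic_log_ml()`, the data and the names ending in `_chk` are hypothetical and are not part of the package.

```{r}
library(mvtnorm)

# Hypothetical helper wrapping the closed-form expression used in the text
analytic_log_ml <- function(Y, s0) {
  n <- nrow(Y)
  d <- ncol(Y)
  -n * d * log(2 * pi) / 2 - d * log(s0 * n + 1) / 2 -
    sum(Y^2) / 2 + sum(colSums(Y)^2) / (2 * (n + 1 / s0))
}

# Fixed illustrative data with n = 4 and d = 2 (arbitrary values)
Y_chk  <- matrix(c(0.3, 1.2, 0.8, 1.9,
                   1.1, 0.2, 1.4, 0.7), ncol = 2)
s0_chk <- 1
n_chk  <- nrow(Y_chk)

# Marginal covariance of one column of Y: I_n + s0 * 1 1^T
Sigma_chk <- diag(n_chk) + s0_chk * matrix(1, n_chk, n_chk)

# Sum the marginal log densities of the d columns
direct_chk <- sum(apply(Y_chk, 2, dmvnorm, mean = rep(0, n_chk),
                        sigma = Sigma_chk, log = TRUE))

c(closed_form = analytic_log_ml(Y_chk, s0_chk), direct = direct_chk)
```

The two numbers should agree up to floating-point error, which gives a check of the closed form that does not go through posterior sampling.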
