---
author: "Metodiev Martin"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteEngine{knitr::knitr}
  %\VignetteIndexEntry{univmixtures}
  %\usepackage[UTF-8]{inputenc}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = FALSE)
library(gor)
library(igraph)
library(parallel)
library(Rfast)
library(quadprog)
library(multimode)
```

## Preparing the data

Load the galaxy dataset:

```{r load data}
library(multimode)
# load in galaxy data
# The raw values are divided by 1000 (presumably to work in units of
# 1000 km/s -- confirm against the multimode::galaxy documentation).
y <- multimode::galaxy/1000
# n: number of observations; later chunks use it to index the columns of the
# JAGS output.
n <- length(y)

# histogram of the rescaled data, with a tick on the axis for every observation
hist(y)
rug(y)
```

## Computing the sample from MCMC output using bayesmix

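Define the unnormalized log-posterior as a function of the data `y` and the posterior draws `sims` (the number of components G is read from the dimensions of `sims`):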

```{r logposterior}

# R: width of the observed data range, max(y) - min(y).
# NOTE: logprior_gmm_marginal() below reads R and m as global variables, so
# they must be defined before the log-posterior is evaluated.
R <- diff(range(y))
# m: midpoint of the observed data range.
m <- mean(range(y))

# likelihood
#
# Log-likelihood of a univariate Gaussian mixture, evaluated once per draw.
#
# Arguments:
#   y     numeric vector of observations.
#   sims  array indexed as [draw, component, parameter]; slice 1 holds the
#         component means, slice 2 the variances, slice 3 the proportions.
#   G     number of mixture components to sum over.
#
# Returns a numeric vector with one log-likelihood value per draw.
#
# Warnings raised during the evaluation (e.g. NaNs produced by the square
# root of a negative variance) are suppressed; such draws yield NaN.
#
# NOTE: `sims[, , k]` must remain a matrix, so this requires at least two
# draws and at least two components (otherwise the dimensions are dropped).
loglik_gmm <- function(y, sims, G) {
  mus <- sims[, , 1]
  sigma_squs <- sims[, , 2]
  pis <- sims[, , 3]

  # log mixture density of a single observation x: one value per draw,
  # vectorised over x so that the result is a (draws x observations) matrix
  log_single_y <- Vectorize(function(x) {
    log(rowSums(sapply(seq_len(G), function(g) {
      pis[, g] * dnorm(x, mus[, g], sqrt(sigma_squs[, g]))
    })))
  })

  # sum over the observations; computed once, with warnings suppressed
  suppressWarnings(rowSums(log_single_y(y)))
}

# prior
#
# Log prior density of the mixture parameters, evaluated once per draw.
#
# Arguments:
#   sims  array indexed as [draw, component, parameter]; slice 1 holds the
#         component means, slice 2 the variances (slice 3, the proportions,
#         is not needed because their prior density is constant).
#   G     number of mixture components.
#
# Reads the global variables `m` (prior mean of the component means) and `R`
# (prior standard deviation of the means, also used in the variance term).
#
# Returns a numeric vector with one log prior value per draw.
#
# NOTE(review): the variance term looks like the closed form obtained by
# integrating a Gamma(0.2, 10 / R^2) hyperprior out of independent
# inverse-gamma(2, C0) priors on the variances -- confirm against the
# reference prior used for the sampler.
logprior_gmm_marginal <- function(sims, G) {
  mus <- sims[, , 1]
  sigma_squs <- sims[, , 2]

  # independent normal priors on the component means
  l_mus <- rowSums(sapply(seq_len(G), function(g) {
    dnorm(mus[, g], mean = m, sd = R, log = TRUE)
  }))

  # flat Dirichlet(1, ..., 1) prior on the proportions: its log density is
  # the constant lgamma(G) everywhere on the simplex, so no evaluation point
  # (and no extra package) is needed
  l_pis <- lgamma(G)

  # marginal prior of the variances
  l_sigma_squs <- lgamma(2 * G + 0.2) - lgamma(0.2) +
    0.2 * log(10 / R^2) -
    (2 * G + 0.2) * log(rowSums(sigma_squs^(-1)) + 10 / R^2) -
    3 * rowSums(log(sigma_squs))

  l_mus + l_pis + l_sigma_squs
}


# unnormalized log-posterior density
#
# Arguments:
#   y     numeric vector of observations.
#   sims  array indexed as [draw, component, parameter]; slice 1 holds the
#         component means, slice 2 the variances, slice 3 the proportions.
#         The number of components G is taken from dim(sims)[2].
#
# Returns a numeric vector with one value per draw: log-likelihood plus log
# prior, or -Inf for draws that lie outside the support of the posterior.
logposty <- function(y, sims) {
  G <- dim(sims)[2]
  sigma_squs <- sims[, , 2]
  pis <- sims[, , 3]

  # Support indicator, one row per draw: every proportion and every variance
  # must be positive and the first G - 1 proportions must not sum to more
  # than 1. G == 2 is handled separately because selecting a single column
  # drops the matrix to a vector, on which rowSums() would fail.
  if (G > 2) {
    mask <- (pis > 0) & (rowSums(pis[, 1:(G - 1)]) <= 1) & (sigma_squs > 0)
  } else {
    mask <- (pis > 0) & (pis[, 1] <= 1) & (sigma_squs > 0)
  }

  # warnings from draws outside the support (e.g. NaNs) are expected here;
  # those draws are overwritten with -Inf right below
  l_total <- suppressWarnings(
    loglik_gmm(y, sims, G) + logprior_gmm_marginal(sims, G)
  )

  # a draw is outside the support as soon as one entry of its mask row fails
  l_total[rowSums(!mask) > 0] <- -Inf
  l_total
}
```

Calculating the MCMC output for 2 components:

```{r mcmc output for 2 components}
library(bayesmix)
library(label.switching)
# sampler settings: number of iterations passed as n.iter, burn-in length and
# random seed
n_samples = 10000
burn.in = 2000
seed = 2024
# range width and midpoint of the data, used for the prior hyperparameters
R <- diff(range(y))
m <- mean(range(y))

# simulating for G=2
model.g2 <- BMMmodel(y, k = 2,  priors = list(kind = "independence",
                                              hierarchical = "tau"),
                    initialValues = list(S0 = 2))
  
control <- JAGScontrol(variables = c("mu", "tau", "eta", "S"),
                         burn.in = burn.in, n.iter = n_samples, seed=seed)
  
# Use the same prior as in "Model Selection for Mixture Models - Perspectives
# and Strategies" (Gilles Celeux, Sylvia Frühwirth-Schnatter, Christian Robert).
# NOTE(review): the meaning of each hyperparameter below is defined by
# bayesmix; the values are presumably chosen to match logprior_gmm_marginal()
# (means with mean m and sd R) -- confirm against the bayesmix documentation.
R2=R^2
model.g2$data$B0inv <-1/R2
model.g2$data$b0 <- m
model.g2$data$nu0Half <- 2
model.g2$data$g0Half <- 0.2
model.g2$data$g0G0Half <- 20 / R2

z <- JAGSrun(y, model = model.g2, control = control)

# Rearrange the sampler output into a [draw, component, parameter] array.
# The code below treats slice 1 as means, slice 2 as variances and slice 3 as
# proportions, i.e. it assumes columns 1..n of z$results are the allocations,
# n+1..n+2 the proportions and n+3..n+6 the means followed by the variances
# -- TODO confirm this column layout against bayesmix.
results=array(c(z$results[,c((n+2+1):(n+3*2),(n+1):(n+2))]),
              dim=c(n_samples,2,3))
mcmc.g2 = results

# determine logposterior values
alloc_vec.g2 = z$results[,1:n]
lps.g2 = logposty(y,mcmc.g2)

# apply the ECR algorithm to relabel
# the pivot is the allocation vector of the draw with the highest log-posterior
zpivot.g2 = z$results[which.max(lps.g2),1:n]
relab = ecr(zpivot.g2,alloc_vec.g2,2)
# reorder the component axis of every draw by the permutation ECR returned
for(i in 1:n_samples){
  mcmc.g2[i,,] = mcmc.g2[i,relab$permutations[i,],]
}

# one boxplot per parameter type, one box per component
boxplot(mcmc.g2[,,1],main="mean parameters")
boxplot(mcmc.g2[,,2],main="variance parameters")
boxplot(mcmc.g2[,,3],main="proportions")
```


Same for 3 components:

```{r mcmc output for 3 components}
library(bayesmix)
library(label.switching)

# sampler settings: number of iterations passed as n.iter, burn-in length and
# random seed
n_samples = 10000
burn.in = 2000
seed = 2024
# range width and midpoint of the data, used for the prior hyperparameters
R <- diff(range(y))
m <- mean(range(y))

# simulating for G=3
model.g3 <- BMMmodel(y, k = 3,  priors = list(kind = "independence",
                                              hierarchical = "tau"),
                    initialValues = list(S0 = 2))
  
control <- JAGScontrol(variables = c("mu", "tau", "eta", "S"),
                         burn.in = burn.in, n.iter = n_samples, seed=seed)
  
# Use the same prior as in "Model Selection for Mixture Models - Perspectives
# and Strategies" (Gilles Celeux, Sylvia Frühwirth-Schnatter, Christian Robert).
# NOTE(review): the meaning of each hyperparameter below is defined by
# bayesmix; the values are presumably chosen to match logprior_gmm_marginal()
# (means with mean m and sd R) -- confirm against the bayesmix documentation.
R2=R^2
model.g3$data$B0inv <-1/R2
model.g3$data$b0 <- m
model.g3$data$nu0Half <- 2
model.g3$data$g0Half <- 0.2
model.g3$data$g0G0Half <- 20 / R2

z <- JAGSrun(y, model = model.g3, control = control)

# Rearrange the sampler output into a [draw, component, parameter] array.
# The code below treats slice 1 as means, slice 2 as variances and slice 3 as
# proportions, i.e. it assumes columns 1..n of z$results are the allocations,
# n+1..n+3 the proportions and n+4..n+9 the means followed by the variances
# -- TODO confirm this column layout against bayesmix.
results=array(c(z$results[,c((n+3+1):(n+3*3),(n+1):(n+3))]),
                dim=c(n_samples,3,3))
mcmc.g3 = results

# determine logposterior values
alloc_vec.g3 = z$results[,1:n]
lps.g3 = logposty(y,mcmc.g3)

# apply the ECR algorithm to relabel
# the pivot is the allocation vector of the draw with the highest log-posterior
zpivot.g3 = z$results[which.max(lps.g3),1:n]
relab.g3 = ecr(zpivot.g3,alloc_vec.g3,3)
# reorder the component axis of every draw by the permutation ECR returned
for(i in 1:n_samples){
  mcmc.g3[i,,] = mcmc.g3[i,relab.g3$permutations[i,],]
}

# one boxplot per parameter type, one box per component
boxplot(mcmc.g3[,,1],main="mean parameters")
boxplot(mcmc.g3[,,2],main="variance parameters")
boxplot(mcmc.g3[,,3],main="proportions")
```

## Computing the THAMES

Estimates of the logarithm of the marginal likelihood:

```{r compute thames g2 and g3}
# THAMES for the two-component model; the log-posterior is passed as a
# function of the draws only, with the data y fixed
thames.g2 <- thamesmix::thames_mixtures(
  logpost = function(sims) logposty(y, sims),
  sims = mcmc.g2,
  seed = 2025
)

# THAMES for the three-component model, same seed
thames.g3 <- thamesmix::thames_mixtures(
  logpost = function(sims) logposty(y, sims),
  sims = mcmc.g3,
  seed = 2025
)

# note that the package returns the negative log marginal likelihood,
# hence the sign flip
df_margliks <- setNames(
  as.data.frame(cbind(-thames.g2$log_zhat_inv, -thames.g3$log_zhat_inv)),
  c("G = 2", "G = 3")
)
df_margliks
```
It is also possible to plot the overlap graphs and compute the criterion of overlap (co):

```{r plot overlap g2 and g3}
# overlap graphs stored in the THAMES results
plot(thames.g2$graph)
plot(thames.g3$graph)

# criterion of overlap for both models, side by side
df_cos <- setNames(
  as.data.frame(cbind(thames.g2$co, thames.g3$co)),
  c("G = 2", "G = 3")
)
df_cos
```
This can also be done directly, without computing the THAMES:

```{r plot overlap g2 and g3 using overlapgraph function}
# Compute each overlap graph once and reuse the result for both the plot and
# the criterion of overlap, instead of calling overlapgraph() twice per model.
overlap.g2 <- thamesmix::overlapgraph(mcmc.g2)
overlap.g3 <- thamesmix::overlapgraph(mcmc.g3)

plot(overlap.g2$graph)
plot(overlap.g3$graph)

# criterion of overlap for both models, side by side
df_cos <- as.data.frame(cbind(overlap.g2$co, overlap.g3$co))
names(df_cos) <- c("G = 2", "G = 3")
df_cos
```
