## ----message=FALSE, warning=FALSE, class.source = 'fold-hide'-----------------

## Source required libraries
library(data.table)
library(tidyverse)
library(ggthemes)
library(ggrepel)
library(harmony)
library(patchwork)
library(tidyr)

## Useful util functions

cosine_normalize <- function(X, margin) {
    ## Scale every row (margin == 1) or every column (any other margin) of X
    ## to unit L2 norm. Returns a dense matrix carrying X's row and column
    ## names (for a data frame, its row.names are copied as rownames).
    by_row <- margin == 1
    l2_norms <- sqrt(if (by_row) rowSums(X ^ 2) else colSums(X ^ 2))
    out <- sweep(as.matrix(X), if (by_row) 1 else 2, l2_norms, '/')
    row.names(out) <- row.names(X)
    colnames(out) <- colnames(X)
    out
}

onehot <- function(vals) {
    ## Indicator matrix with one row per distinct level of `vals` and one
    ## column per element of `vals`: the transpose of a no-intercept design.
    design <- model.matrix(~0 + as.factor(vals))
    t(design)
}


## One fixed colour per dataset in cell_lines, given as 0-255 RGB triples.
colors_use <- rgb(
    red   = c(129, 208, 0),
    green = c(15, 158, 109),
    blue  = c(124, 45, 44),
    names = c('jurkat', 't293', 'half'),
    maxColorValue = 255
)


do_scatter <- function(umap_use, meta_data, label_name, no_guides = TRUE, do_labels = TRUE, nice_names, 
                       palette_use = colors_use,
                       pt_size = 4, point_size = .5, base_size = 10, do_points = TRUE, do_density = FALSE, h = 4, w = 8) {
    ## Scatter plot of the first two columns of an embedding, coloured by a
    ## metadata column.
    ##
    ## umap_use:    matrix / data frame of cell embeddings; only columns 1-2
    ##              are used.
    ## meta_data:   per-cell metadata, row-aligned with umap_use.
    ## label_name:  name of the meta_data column used for colour, fill and
    ##              group labels.
    ## no_guides:   if TRUE, hide the legend.
    ## do_labels:   if TRUE, add one label per group at the mean position of
    ##              its cells (and drop the colour/fill guides).
    ## nice_names:  optional data frame with columns `given_name` and
    ##              `nice_name` used to relabel groups. Rows are inner-joined
    ##              on `given_name`; rows whose nice_name is "" or NA are dropped.
    ## palette_use: named colour vector passed to scale_color/fill_manual.
    ## pt_size:     text size of the group labels.
    ## base_size:   base font size of theme_tufte().
    ## do_points / do_density: draw points and/or 2D density contours.
    ## point_size, h, w: accepted for backward compatibility but not used
    ##              (points are drawn at a fixed size of 0.2).
    ## Returns a ggplot object.

    ## Keep the first two dimensions under fixed names so the aesthetics
    ## below can refer to them as X1 / X2.
    umap_use <- umap_use[, 1:2]
    colnames(umap_use) <- c('X1', 'X2')
    ## Attach metadata, then shuffle rows so no group is systematically
    ## drawn on top of the others.
    plt_df <- umap_use %>% data.frame() %>% 
        cbind(meta_data) %>% 
        dplyr::sample_frac(1L) 
    plt_df$given_name <- plt_df[[label_name]]
    
    if (!missing(nice_names)) {
        ## Explicit reassignment: the compound pipe `%<>%` lives in magrittr,
        ## which library(tidyverse) does not attach (dplyr re-exports only
        ## `%>%`), so `%<>%` failed with "could not find function".
        plt_df <- plt_df %>%
            dplyr::inner_join(nice_names, by = "given_name") %>% 
            subset(nice_name != "" & !is.na(nice_name))

        plt_df[[label_name]] <- plt_df$nice_name        
    }
        
    plt <- plt_df %>% 
        ggplot(aes(X1, X2, colour = .data[[label_name]], fill = .data[[label_name]])) + 
        theme_tufte(base_size = base_size) + 
        theme(panel.background = element_rect(fill = NA, color = "black")) + 
        guides(color = guide_legend(override.aes = list(stroke = 1, alpha = 1, shape = 16, size = 4)), alpha = "none") +
        scale_color_manual(values = palette_use) + 
        scale_fill_manual(values = palette_use) +    
        theme(plot.title = element_text(hjust = .5)) + 
        labs(x = "UMAP 1", y = "UMAP 2") 
    
    if (do_points) 
        plt <- plt + geom_point(size = 0.2)
    if (do_density) 
        plt <- plt + geom_density_2d()    

    if (no_guides)
        plt <- plt + theme(legend.position = "none")
    
    if (do_labels) 
        ## One label per group, placed at the mean embedding of its cells.
        plt <- plt + geom_label_repel(data = data.table(plt_df)[, .(X1 = mean(X1), X2 = mean(X2)), by = label_name], label.size = NA,
                                      aes(label = .data[[label_name]]), color = "white", size = pt_size, alpha = 1, segment.size = 0) + 
        guides(col = "none", fill = "none")
    return(plt)
}


## -----------------------------------------------------------------------------
## Load the example data shipped with harmony and unpack it: the PCA
## embedding of the cells (V), its row-normalised copy (V_cos) and the
## per-cell metadata table.
data(cell_lines)
meta_data <- cell_lines[["meta_data"]]
V <- cell_lines[["scaled_pcs"]]
V_cos <- V %>% cosine_normalize(1)

## ----warning=FALSE, fig.width=5, fig.height=3, fig.align="center"-------------
## Side-by-side PC1/PC2 scatter of the uncorrected embedding V: left panel
## coloured by dataset, right panel by cell type. In this patchwork chain,
## each `+ labs(...)` applies to the panel added just before it; the trailing
## `+ NULL` is a no-op so every line can end in `+`.
do_scatter(V, meta_data, 'dataset', no_guides = TRUE, do_labels = TRUE) + 
    labs(title = 'Colored by dataset', x = 'PC1', y = 'PC2') +
do_scatter(V, meta_data, 'cell_type', no_guides = TRUE, do_labels = TRUE) + 
    labs(title = 'Colored by cell type', x = 'PC1', y = 'PC2') +
NULL

## -----------------------------------------------------------------------------

## Build a Harmony model but stop right after initialisation (max_iter = 0),
## so the individual steps can be run by hand in the chunks below.
## set.seed() makes any randomness in the initialisation reproducible
## (presumably the k-means centroids plotted later -- confirm).
set.seed(1)
harmonyObj <- harmony::RunHarmony(
    data_mat = V, ## PCA embedding matrix of cells
    meta_data = meta_data, ## dataframe with cell labels
    theta = 1, ## cluster diversity enforcement
    vars_use = 'dataset', ## variable to integrate out
    nclust = 5, ## number of clusters in Harmony model
    max_iter = 0, ## stop after initialization
    return_object = TRUE ## return the full Harmony model object
)



## -----------------------------------------------------------------------------
## Copy the embeddings out of the model as plain R matrices oriented
## cells x dims like the input V. The getters are used because Z_orig /
## Z_corr are not exposed as fields, and they return dims x cells, hence t().
Z_orig <- harmonyObj$getZorig() %>% t()              # N × d
Z_cos  <- Z_orig %>% cosine_normalize(margin = 1)    # unit-length rows, N × d

## Dataset design matrix (B × N), rebuilt from meta_data to mirror the one
## harmony constructs internally: one indicator row per dataset level.
phi <- Matrix::sparse.model.matrix(~0 + as.factor(meta_data$dataset)) %>%
    Matrix::t()
## phi_moe ((B + 1) × N): an all-ones intercept row stacked on top of phi,
## matching harmony's Phi_moe.
phi_moe <- rbind(
    Matrix::Matrix(1, nrow = 1, ncol = ncol(phi), sparse = TRUE),
    phi
)

## ----fig.width=5, fig.height=3, fig.align="center"----------------------------
## Raw embedding (Z_orig) next to its row-normalised version (Z_cos). For
## unit vectors, squared Euclidean distance equals 2 * (1 - cosine
## similarity), so Euclidean distances in Z_cos are induced cosine
## distances in Z_orig. Each `+ labs()` targets the panel added just before it.
do_scatter(Z_orig, meta_data, 'dataset', no_guides = TRUE, do_labels = TRUE) +
    labs(title = 'Z_orig', subtitle = 'Euclidean distance', x = 'PC1', y = 'PC2') +
do_scatter(Z_cos, meta_data, 'dataset', no_guides = TRUE, do_labels = TRUE) +
    labs(title = 'Z_cos', subtitle = 'Induced Cosine distance', x = 'PC1', y = 'PC2')


## ----fig.width=8, fig.height=3, out.width="100%"------------------------------

## Distribution of every PC of the unit-length embedding, split by dataset.
## Columns X1..Xd are reshaped to one row per (cell, PC) with the superseded
## gather() replaced by pivot_longer(); the number of PCs is taken from
## Z_cos instead of being hard-coded.
Z_cos %>% data.frame() %>%
    cbind(meta_data) %>%
    tidyr::pivot_longer(
        cols = tidyselect::num_range('X', seq_len(ncol(Z_cos))),
        names_to = 'key',
        values_to = 'val'
    ) %>%
    ## order the x axis numerically (PC2 before PC10), not alphabetically
    ggplot(aes(reorder(gsub('X', 'PC', key), as.integer(gsub('X', '', key))), val)) + 
        geom_boxplot(aes(color = dataset)) + 
        scale_color_manual(values = colors_use) + 
        labs(x = 'PC number', y = 'PC embedding value',
             title = sprintf('Z_cos (unit scaled PCA embeddings) for all %d PCs', ncol(Z_cos))) + 
        theme_tufte(base_size = 10) + geom_rangeframe() + 
        theme(axis.text.x = element_text(angle = 45, hjust = 1))

## ----fig.width=4, fig.height=3, fig.align="center"----------------------------

## Initial cluster centroids. harmonyObj$Y holds one column per cluster, so
## t() turns it into one row (one plotted point) per centroid.
cluster_centroids <- harmonyObj$Y

## colour and fill are set to constants in the centroid layer. Set aesthetics
## replace the colour/fill mapping inherited from do_scatter(), so the
## centroid data frame does not need a `dataset` column.
do_scatter(Z_cos, meta_data, 'dataset', no_guides = FALSE, do_labels = FALSE) +
    labs(title = 'Initial kmeans cluster centroids', subtitle = '', x = 'PC1', y = 'PC2') +
    geom_point(
        data = data.frame(t(cluster_centroids)),
        color = 'black', fill = 'black', alpha = .8,
        shape = 21, size = 6
    ) +
NULL


## -----------------------------------------------------------------------------
cluster_assignment_matrix <- harmonyObj$R


## ----fig.height=5, fig.width=5------------------------------------------------
## One facet per (cluster, dataset); each cell is coloured by r, its soft
## assignment probability to that cluster. The assignment matrix (clusters x
## cells) is reshaped with pivot_longer() (replacing the superseded gather())
## into one row per (cell id, cluster), then joined onto the unit-length
## embedding by row index.
Z_cos %>% data.frame() %>%
    cbind(meta_data) %>%
    tibble::rowid_to_column('id') %>%
    dplyr::inner_join(
        cluster_assignment_matrix %>% t() %>% data.table() %>%
            tibble::rowid_to_column('id') %>%
            tidyr::pivot_longer(-id, names_to = 'cluster', values_to = 'r') %>%
            dplyr::mutate(cluster = gsub('V', 'Cluster ', cluster)),
        by = 'id'
    ) %>%
    ## shuffle rows so no facet group is systematically drawn on top
    dplyr::sample_frac(1L) %>%
    ggplot(aes(X1, X2, color = r)) +
        geom_point(size=0.2) +
        theme_tufte(base_size = 10) + theme(panel.background = element_rect()) +
        facet_grid(cluster ~ dataset) +
        scale_color_gradient(low = 'lightgrey', breaks = seq(0, 1, .1)) +
        labs(x = 'Scaled PC1', y = 'Scaled PC2', title = 'Initial probabilistic cluster assignments')

## -----------------------------------------------------------------------------
## Soft number of cells from each dataset in each cluster: the assignment
## matrix R (clusters x cells) times the transposed dataset indicators phi.
observed_counts <- harmonyObj$R %*% t(as.matrix(phi))
round(observed_counts)



## -----------------------------------------------------------------------------
## observed counts (O) as stored in the model; presumably the same quantity
## as observed_counts above
round(harmonyObj$O)

## expected counts (E)
round(harmonyObj$E)


## -----------------------------------------------------------------------------
## Soft number of cells of each cell type in each cluster: the assignment
## matrix R (clusters x cells) times the transposed cell-type indicators.
phi_celltype <- onehot(meta_data$cell_type) 
observed_cell_counts <- harmonyObj$R %*% t(phi_celltype)
round(observed_cell_counts)


## -----------------------------------------------------------------------------
## Current setting for how many clustering rounds a cluster_cpp() call runs
harmonyObj$max_iter_kmeans

## -----------------------------------------------------------------------------
## we can specify how many rounds of clustering to do
harmonyObj$max_iter_kmeans <- 10
## cluster_cpp() is called for its effect on harmonyObj: the chunks below
## read the updated O and R back from the object.
harmonyObj$cluster_cpp()

## -----------------------------------------------------------------------------
## Observed soft counts per cluster and dataset after the extra rounds
round(harmonyObj$O)

## ----fig.height=5, fig.width=5------------------------------------------------
new_cluster_assignment_matrix <- harmonyObj$R

## Same facet plot as for the initial assignments, after the extra clustering
## rounds. The assignment matrix (clusters x cells) is reshaped with
## pivot_longer() (replacing the superseded gather()) into one row per
## (cell id, cluster) and joined onto the embedding by row index.
Z_cos %>% data.frame() %>%
    cbind(meta_data) %>%
    tibble::rowid_to_column('id') %>%
    dplyr::inner_join(
        new_cluster_assignment_matrix %>% t() %>% data.table() %>%
            tibble::rowid_to_column('id') %>%
            tidyr::pivot_longer(-id, names_to = 'cluster', values_to = 'r') %>%
            dplyr::mutate(cluster = gsub('V', 'Cluster ', cluster)),
        by = 'id'
    ) %>%
    ## shuffle rows so no facet group is systematically drawn on top
    dplyr::sample_frac(1L) %>%
    ggplot(aes(X1, X2, color = r)) +
        geom_point(shape = '.') +
        theme_tufte(base_size = 10) + theme(panel.background = element_rect()) +
        facet_grid(cluster ~ dataset) +
        scale_color_gradient(low = 'lightgrey', breaks = seq(0, 1, .1)) +
        labs(x = 'Scaled PC1', y = 'Scaled PC2', title = 'New probabilistic cluster assignments')

## -----------------------------------------------------------------------------
## Recompute the soft cell-type counts per cluster from the updated
## assignment matrix R (clusters x cells).
phi_celltype <- meta_data[['cell_type']] %>% onehot()
observed_cell_counts <- harmonyObj$R %*% t(phi_celltype)
observed_cell_counts %>% round()

## -----------------------------------------------------------------------------
## For each cluster: share (in percent, 3 decimals) of its least represented
## cell type.
observed_cell_counts %>%
    prop.table(margin = 1) %>%
    apply(MARGIN = 1, FUN = min) %>%
    `*`(100) %>%
    round(digits = 3)

## -----------------------------------------------------------------------------

with(harmonyObj, {
    ## Squared distance between unit-length centroids and cells:
    ## ||z - y||^2 = 2 * (1 - y . z); gives a clusters x cells matrix.
    distance_matrix <- 2 * (1 - t(Y) %*% t(Z_cos))
    distance_score <- exp(-distance_matrix / as.numeric(sigma))
    ## Diversity penalty (E / O) ^ theta for each (cluster, dataset) pair, with
    ## theta applied per dataset column, then mapped to cells through phi.
    ## The exponent must be '^': dividing by theta (as before) only matched
    ## the penalty at theta = 1.
    diversity_score <- sweep(E / O, 2, theta, '^') %*% as.matrix(phi)
    ## new assignments are based on distance and diversity
    R_new <- distance_score * diversity_score
    ## normalize R so each cell sums to 1
    R_new <- prop.table(R_new, 2)
})


## -----------------------------------------------------------------------------
## with theta = 0
## The chunks below evaluate the per-(cluster, dataset) diversity penalty
## ((E + 1) / (O + E + 1)) ^ theta for increasing theta. At theta = 0 every
## entry is exactly 1, i.e. assignments are not penalised.
with(harmonyObj, {
    ((E+1) / (O+E+1)) ^ 0
})

## -----------------------------------------------------------------------------
## with theta = 1
## At theta = 1 the penalty is the base ratio itself: the larger the observed
## count O relative to the expected count E, the smaller the value.
with(harmonyObj, {
    round(((E+1) / (O+E+1)) ^ 1, 2)
})


## -----------------------------------------------------------------------------
## as theta approaches infinity
## The base ratio is below 1 whenever O > 0, so a very large exponent drives
## the entries towards 0, fastest where O is large relative to E.
with(harmonyObj, {
    round(((E+1) / (O+E+1)) ^ 1e6, 2)
})


## -----------------------------------------------------------------------------
## Centroid update, step 1: soft-assignment-weighted sum of the unit-length
## cell embeddings, one column per cluster (dims x clusters).
Y_unscaled <- t(Z_cos) %*% t(harmonyObj$R)

## -----------------------------------------------------------------------------
## Step 2: L2-normalise each centroid column. Y_new is computed for
## illustration only; it is not written back into harmonyObj.
Y_new <- cosine_normalize(Y_unscaled, 2)

## -----------------------------------------------------------------------------
## Run the mixture-of-experts ridge correction on harmonyObj, then read the
## corrected embedding back out.
harmonyObj$moe_correct_ridge_cpp()
Z_corr <- t(harmonyObj$getZcorr())    # N × d, corrected (not L2-normalised)

## ----fig.width=5, fig.height=3, fig.align="center"----------------------------

## Row-normalised embedding before and after one MoE correction step,
## coloured by dataset (each `+ labs()` targets the panel added before it).
do_scatter(Z_cos, meta_data, 'dataset', no_guides = TRUE, do_labels = TRUE) +
    labs(title = 'Z_cos before MoE', x = 'PC1', y = 'PC2') +
do_scatter(cosine_normalize(Z_corr, 1), meta_data, 'dataset', no_guides = TRUE, do_labels = TRUE) +
    labs(title = 'Z_cos after MoE', x = 'PC1', y = 'PC2')

## ----fig.width=8, fig.height=3, fig.align="center", out.width="100%"----------

## Three views of the correction: original PCs, corrected PCs, and the
## corrected PCs after row normalisation.
do_scatter(Z_orig, meta_data, 'dataset', no_guides = TRUE, do_labels = TRUE) +
    labs(title = 'Z_orig', subtitle = 'Original PCA embeddings', x = 'PC1', y = 'PC2') +
do_scatter(Z_corr, meta_data, 'dataset', no_guides = TRUE, do_labels = TRUE) +
    labs(title = 'Z_corr', subtitle = '= Z_orig - correction_factors', x = 'PC1', y = 'PC2') +
do_scatter(cosine_normalize(Z_corr, 1), meta_data, 'dataset', no_guides = TRUE, do_labels = TRUE) +
    labs(title = 'Z_cos', subtitle = '= Unit_scaled(Z_corr)', x = 'Scaled PC1', y = 'Scaled PC2') +
NULL

## ----fig.width=5, fig.height=3, fig.align="center"----------------------------

## PC1 of every cell before vs after correction; cells on the y = x line were
## not moved along PC1. `plt` is a shared scaffold without a point layer.
plt <- data.table(PC1_After = Z_corr[, 1], PC1_Before = Z_orig[, 1]) %>%
    cbind(meta_data) %>% 
    dplyr::sample_frac(1L) %>% 
    ggplot(aes(PC1_Before, PC1_After)) + 
        geom_abline(slope = 1, intercept = 0) + 
        theme_tufte(base_size = 10) + geom_rangeframe() + 
        scale_color_tableau() + 
        guides(color = guide_legend(override.aes = list(stroke = 1, alpha = 1, shape = 16, size = 4))) + 
        NULL

## Two panels from the same scaffold: points coloured by dataset, then by
## cell type.
plt + geom_point(shape = '.', aes(color = dataset)) + 
        labs(x = 'PC1 before correction', y = 'PC1 after correction', 
             title = 'PC1 correction for each cell', subtitle = 'Colored by Dataset') + 
plt + geom_point(shape = '.', aes(color = cell_type)) + 
        labs(x = 'PC1 before correction', y = 'PC1 after correction', 
             title = 'PC1 correction for each cell', subtitle = 'Colored by Cell Type') + 
NULL


## ----echo=TRUE----------------------------------------------------------------

## Convert sparse data structures to dense matrix
Phi.moe <- as.matrix(phi_moe)
## NOTE(review): indexed as lambda_mat[k, ] below, i.e. assumed to hold one
## row of ridge penalties (intercept + one per dataset) per cluster --
## confirm against the installed harmony version.
lambda_mat <- harmonyObj$getLambda()
## Read the soft assignments (clusters x cells) once, not twice per cluster.
R_soft <- harmonyObj$R
W <- vector("list", harmonyObj$K)

## Get beta coefficients for all the clusters: for cluster k solve the ridge
## regression  W_k = (Phi R_k Phi^T + Lambda_k)^-1 Phi R_k Z_orig,
## where R_k = diag(R[k, ]). W[[k]] has one row per row of Phi.moe
## (intercept first) and one column per PC.
for (k in seq_len(harmonyObj$K)) {
    lambda <- diag(lambda_mat[k, ])
    ## Phi.moe %*% diag(R[k, ]) computed by scaling column i of Phi.moe by
    ## R[k, i], instead of materialising an N x N diagonal matrix.
    Phi_Rk <- sweep(Phi.moe, 2, R_soft[k, ], '*')
    W[[k]] <- solve(Phi_Rk %*% t(Phi.moe) + lambda) %*% (Phi_Rk %*% Z_orig)
}



## ----fig.width=5, fig.height=5------------------------------------------------

cluster_assignment_matrix <- harmonyObj$R

## Current soft cluster assignments drawn in the original (uncorrected) PCA
## space, one facet per (cluster, dataset). The assignment matrix is reshaped
## with pivot_longer() (replacing the superseded gather()) into one row per
## (cell id, cluster) and joined onto the embedding by row index.
Z_orig %>% data.frame() %>%
    cbind(meta_data) %>%
    tibble::rowid_to_column('id') %>%
    dplyr::inner_join(
        cluster_assignment_matrix %>% t() %>% data.table() %>%
            tibble::rowid_to_column('id') %>%
            tidyr::pivot_longer(-id, names_to = 'cluster', values_to = 'r') %>%
            dplyr::mutate(cluster = gsub('V', 'Cluster ', cluster)),
        by = 'id'
    ) %>%
    dplyr::sample_frac(1L) %>%
    ggplot(aes(X1, X2, color = r)) +
        ## `size`, not `shape`: shape = 0.2 is not a valid point shape (it was
        ## coerced to pch 0, open squares), unlike the matching plots above.
        geom_point(size = 0.2) +
        theme_tufte(base_size = 10) + theme(panel.background = element_rect()) +
        facet_grid(cluster ~ dataset) +
        scale_color_gradient(low = 'grey', breaks = seq(0, 1, .2)) +
        labs(x = 'PC1', y = 'PC2', title = 'Cluster assigned in original PCA space (Z_orig)')


## -----------------------------------------------------------------------------
plt_list <- lapply(seq_len(harmonyObj$K), function(k) {
    ## Rows of W[[k]]: the intercept, then one row per dataset in the
    ## factor-level order used when phi was built with as.factor().
    plt_df <- W[[k]] %>% data.frame() %>% 
        dplyr::select(X1, X2)
    n_batch <- nrow(plt_df) - 1L
    ## Arrow origins: the intercept arrow starts at (0, 0); every dataset
    ## arrow starts at the tip of the intercept arrow.
    plt_df$x0 <- c(0, rep(plt_df$X1[1], n_batch))
    plt_df$y0 <- c(0, rep(plt_df$X2[1], n_batch))
    ## Label by factor level, not unique() (order of first appearance, which
    ## need not match the rows of W and yields integer codes for a factor).
    plt_df$type <- c('intercept', levels(as.factor(meta_data$dataset)))
    plt <- plt_df %>% 
        ggplot() + 
            ## all cells in the original PCA space, as a grey backdrop
            geom_point(aes(X1, X2),
                       data = Z_orig %>% data.frame(),
                       size = 0.5,
                       color = 'grey'
            ) +
            geom_segment(aes(x = x0, y = y0, xend = X1 + x0, yend = X2 + y0, color = type), linewidth=1) + 
            scale_color_manual(values = c('intercept' = 'black', colors_use)) + 
            theme_tufte(base_size = 10) + theme(panel.background = element_rect()) + 
            labs(x = 'PC 1', y = 'PC 2', title = sprintf('Cluster %d', k))
    plt + guides(color = guide_legend(override.aes = list(stroke = 1, alpha = 1, shape = 16)))
})



## ----fig.height=6, fig.width=6------------------------------------------------
## Combine the per-cluster panels into one patchwork figure: two columns,
## one shared title.
Reduce(`+`, plt_list) + 
  patchwork::plot_annotation(title = 'Mixture of experts beta terms before correction (Z_orig)') + 
  plot_layout(ncol = 2)

## ----fig.width=4, fig.height=3, fig.align="center"----------------------------

plt_list <- lapply(seq_len(harmonyObj$K), function(k) {
    ## Rows of W[[k]]: the intercept, then one row per dataset in the
    ## factor-level order used when phi was built with as.factor().
    plt_df <- W[[k]] %>% data.frame() %>% 
        dplyr::select(X1, X2)
    n_batch <- nrow(plt_df) - 1L
    ## Arrow origins: the intercept arrow starts at (0, 0); every dataset
    ## arrow starts at the tip of the intercept arrow.
    plt_df$x0 <- c(0, rep(plt_df$X1[1], n_batch))
    plt_df$y0 <- c(0, rep(plt_df$X2[1], n_batch))
    ## Label by factor level, not unique() (order of first appearance, which
    ## need not match the rows of W and yields integer codes for a factor).
    plt_df$type <- c('intercept', levels(as.factor(meta_data$dataset)))

    plt <- plt_df %>% 
        ggplot() + 
            ## all cells in the corrected space, as a grey backdrop
            geom_point(aes(X1, X2),
                data = Z_corr %>% data.frame(),
                shape = '.',
                color = 'grey'
            ) +
            geom_segment(aes(x = x0, y = y0, xend = X1 + x0, yend = X2 + y0, color = type), linewidth=1) + 
            scale_color_manual(values = c('intercept' = 'black', colors_use)) + 
            theme_tufte(base_size = 10) + theme(panel.background = element_rect()) + 
            labs(x = 'PC 1', y = 'PC 2', title = sprintf('Cluster %d', k))
    plt + guides(color = guide_legend(override.aes = list(stroke = 1, alpha = 1, shape = 16)))
})



## ----fig.height=6, fig.width=6------------------------------------------------
## Combine the per-cluster panels (drawn over the corrected embedding) into
## one patchwork figure: two columns, one shared title.
Reduce(`+`, plt_list) + 
  patchwork::plot_annotation(title = 'Mixture of experts beta terms after correction (Z_corr)') + 
  plot_layout(ncol = 2)

## ----echo=TRUE----------------------------------------------------------------

## Observed embedding of cell 5 and its reconstruction from the MoE model.
## Each cluster's W[[k]] is weighted element-wise by the cell's design
## column (intercept + dataset indicator) and by the cell's assignment to
## that cluster. The weighted terms are summed over clusters, then over the
## design rows.
Z_i <- Z_orig[5, ]
Z_i_pred <- colSums(Reduce(`+`, lapply(seq_len(harmonyObj$K), function(cl) {
    harmonyObj$R[cl, 5] * (W[[cl]] * phi_moe[, 5])
})))



## ----fig.width=4, fig.height=3, fig.align="center"----------------------------
## Observed vs predicted PC scores for cell 5, one labelled point per PC;
## points on the y = x line are reproduced exactly by the model.
ggplot(
    tibble::rowid_to_column(data.table(obs = Z_i, pred = Z_i_pred), 'PC'),
    aes(x = obs, y = pred)
) +
    geom_point(shape = 21) +
    geom_label_repel(aes(label = PC)) +
    geom_abline(slope = 1, intercept = 0) +
    theme_tufte() +
    geom_rangeframe() +
    labs(
        x = 'Observed PC score',
        y = 'Predicted PC score',
        title = 'Observed and predicted values of PC scores\nfor cell 5'
    )

## -----------------------------------------------------------------------------
## Batch-specific part of the MoE fit for cell 5. Drop the intercept row of
## each W[[k]] (works for any number of datasets, not just three), weight by
## the cell's dataset indicators and its assignment to cluster k, then sum
## over clusters and datasets.
delta <- Reduce(`+`, lapply(seq_len(harmonyObj$K), function(k) {
    W[[k]][-1, , drop = FALSE] * phi[, 5] * harmonyObj$R[k, 5]
})) %>% colSums

## Corrected embedding of cell 5: original minus its batch-specific part.
Z_corrected <- Z_orig[5, ] - delta


## ----fig.width=3, fig.height=3, fig.align="center"----------------------------


## All cells in the original PCA space, with cell 5 highlighted in red and
## an arrow from its original position to its corrected position.
ggplot(data.frame(Z_orig), aes(x = X1, y = X2)) +
    geom_point(shape = '.') +
    geom_point(data = data.frame(Z_orig[5, , drop = FALSE]), color = 'red') +
    geom_segment(
        aes(x = x0, y = y0, xend = x1, yend = y1),
        data = data.table(
            x0 = Z_orig[5, 1], y0 = Z_orig[5, 2],
            x1 = Z_corrected[1], y1 = Z_corrected[2]
        ),
        color = 'red',
        linewidth = 1,
        arrow = arrow(length = unit(0.05, "npc"), type = 'closed')
    ) +
    theme_tufte(base_size = 10) +
    geom_rangeframe() +
    labs(x = 'PC1', y = 'PC2', title = 'Correction of cell #5')


## -----------------------------------------------------------------------------

## Re-initialise Harmony with many more clusters (nclust = 50), again without
## running any iteration, so the rounds can be stepped through one at a time.
harmonyObj <- RunHarmony(
    data_mat = V, ## PCA embedding matrix of cells
    meta_data = meta_data, ## dataframe with cell labels
    theta = 1, ## cluster diversity enforcement
    vars_use = 'dataset', ## (list of) variable(s) we'd like to Harmonize out
    nclust = 50, ## number of clusters in Harmony model
    max_iter = 0, ## don't actually run Harmony, stop after initialization
    return_object = TRUE ## return the full Harmony model object, not just the corrected PCA matrix
)
Z_cos <- cosine_normalize(t(harmonyObj$getZorig()), 1)   # N × d, initial state


## ----message=FALSE, fig.width=5, fig.height=3, fig.align="center"-------------

## Round 0: the row-normalised embedding before any Harmony iteration,
## coloured by dataset (left) and by cell type (right).
i <- 0

do_scatter(Z_cos, meta_data, 'dataset', no_guides = TRUE, do_labels = TRUE) +
    labs(title = sprintf('Round %d', i), subtitle = 'Colored by dataset', x = 'Scaled PC1', y = 'Scaled PC2') +
do_scatter(Z_cos, meta_data, 'cell_type', no_guides = TRUE, do_labels = TRUE) +
    labs(title = sprintf('Round %d', i), subtitle = 'Colored by cell type', x = 'Scaled PC1', y = 'Scaled PC2') +
NULL

## ----fig.width=5, fig.height=3, fig.align="center", message=FALSE-------------

## Run one Harmony round at a time on harmonyObj and plot the row-normalised
## corrected embedding after each round. plot() is needed because
## auto-printing does not happen inside a for loop.
## NOTE(review): harmony:::harmonize is an unexported internal function; its
## name or signature may change between harmony releases.
for (i in 1:5) {
    harmony:::harmonize(harmonyObj, 1)
    Z_cos <- cosine_normalize(t(harmonyObj$getZcorr()), 1)   # N × d, after round i
    plt <- do_scatter(Z_cos, meta_data, 'dataset', no_guides = TRUE, do_labels = TRUE) +
        labs(title = sprintf('Round %d', i), subtitle = 'Colored by dataset', x = 'Scaled PC1', y = 'Scaled PC2') +
    do_scatter(Z_cos, meta_data, 'cell_type', no_guides = TRUE, do_labels = TRUE) +
        labs(title = sprintf('Round %d', i), subtitle = 'Colored by cell type', x = 'Scaled PC1', y = 'Scaled PC2') +
    NULL
    plot(plt)
}
    

## -----------------------------------------------------------------------------
## Record package versions for reproducibility.
sessionInfo()

