## ----include=FALSE------------------------------------------------------------
# Global chunk options for the whole vignette:
#   collapse / comment   - merge source and output of a chunk into one block,
#                          prefixing output lines with "#>"
#   fig.width/fig.height - default figure size (inches)
#   error                - keep knitting when a chunk errors and show the error
#                          inline.
# NOTE(review): error = TRUE also lets genuinely broken chunks (e.g. a missing
# .rds file below) render without failing the build; consider FALSE once the
# vignette data is guaranteed to be present.
knitr::opts_chunk$set(
  collapse   = TRUE,
  comment    = "#>",
  fig.width  = 7,
  fig.height = 5,
  error      = TRUE
)

## -----------------------------------------------------------------------------
library(fairGNN)
library(dplyr)
library(readxl)

# Simulated stand-in data so the vignette builds reproducibly without access
# to the real cohort. Replace this with code that loads your own data.
# The columns are generated in a fixed order from a fixed seed, so every build
# produces the same 430 rows.
set.seed(123)
raw_data <- local({
  n <- 430L
  data.frame(
    subjectid   = seq_len(n),
    hdremit.all = sample(0:1, size = n, replace = TRUE),
    sex         = sample(1:2, size = n, replace = TRUE),
    madrs.total = rnorm(n, mean = 25, sd = 5),
    feature1    = rnorm(n),
    feature2    = rnorm(n)
  )
})

## -----------------------------------------------------------------------------
# The raw `sex` column codes 2 = Male and 1 = Female. This mapping recodes the
# raw codes (as character keys) to numeric values: Male -> 0, Female -> 1.
numeric_mappings_gender <- setNames(list(0, 1), c("2", "1"))

# Columns passed to prepare_data() for removal. Most of them do not exist in the
# dummy `raw_data` above (only "subjectid" and "madrs.total" do); they are
# listed for use with the real dataset.
# NOTE(review): presumably prepare_data() tolerates absent names; confirm.
cols_to_drop <- c(
  # identifier / bookkeeping columns (Row.names and the .x suffix look like
  # merge() artefacts; confirm)
  "subjectid", "Row.names", "bloodsampleid.x",
  # scale totals and scores
  "madrs.total", "hrsd.total", "bdi.total", "bdi14wk0", "bdi20wk0",
  "f61score0", "f62score0", "f64score0", "f65score0",
  "k30"
)

# Prepare the modelling data:
#   outcome_var    - "hdremit.all" is the outcome column
#   group_var      - "sex" defines the groups, recoded via
#                    `numeric_mappings_gender` (raw 2/1 -> 0/1)
#   cols_to_remove - columns in `cols_to_drop` are excluded
prepared_data_gender <- prepare_data(
  data           = raw_data,
  outcome_var    = "hdremit.all",
  group_var      = "sex",
  group_mappings = numeric_mappings_gender,
  cols_to_remove = cols_to_drop
)

## -----------------------------------------------------------------------------
# Training is skipped in the vignette: a real analysis would call train_gnn()
# here (and run the expert analysis on its output). Instead, both result
# objects are loaded from files saved by the create_vignette_data.R script.
# The relative paths resolve against the directory knitr runs in (by default
# the directory containing this vignette's source file).
gnn_results <- readRDS(file.path("data", "gnn_results.rds"))
expert_analyses <- readRDS(file.path("data", "expert_analyses.rds"))

## -----------------------------------------------------------------------------
# Display labels for the recoded group values (0 = Male, 1 = Female), matching
# the Male -> 0 / Female -> 1 recoding used when preparing the data.
label_mappings_gender <- setNames(list("Male", "Female"), c("0", "1"))

# Basic analysis of the (pre-computed) GNN results against the prepared data;
# groups are labelled through `label_mappings_gender`.
basic_analyses <- analyse_gnn_results(
  gnn_results    = gnn_results,
  prepared_data  = prepared_data_gender,
  group_mappings = label_mappings_gender
)

# --- View all plots from the basic analysis ---
# Each section writes a heading with cat() and then explicitly print()s the
# matching element of `basic_analyses` (presumably plot objects; confirm). The
# statement order is the order in which headings and plots appear in the output.
# NOTE(review): with the global chunk option comment = "#>", cat() text is shown
# as "#>"-prefixed console output rather than a Markdown heading unless this
# chunk is run with results = "asis"; confirm the intended rendering.
cat("## ROC Curve\n")
print(basic_analyses$roc_plot)

cat("\n## Calibration Plot\n")
print(basic_analyses$calibration_plot)

cat("\n## Gate Weight Distribution\n")
print(basic_analyses$gate_density_plot)

cat("\n## Gate Entropy Distribution\n")
print(basic_analyses$entropy_density_plot)

## -----------------------------------------------------------------------------
# --- View all results from the expert analysis ---
# As above, the cat()/print() order defines the order of the rendered output.
cat("\n## Feature Importance: Female vs. Male\n")
# Shows the first rows (head() default: 6) of the Female-vs-Male table in
# `expert_analyses$pairwise_differences`. NOTE(review): these are the features
# with the biggest importance difference between the two experts only if the
# table is sorted by that difference; confirm against the producing function.
print(head(expert_analyses$pairwise_differences$Female_vs_Male))

cat("\n## Feature Importance Difference Plot\n")
# This plot presumably visualises the same differences as the table above.
print(expert_analyses$difference_plot)

## -----------------------------------------------------------------------------
# Build the Sankey plot from the raw data, the GNN results and the expert
# analysis, grouping by the raw `sex` column and labelling groups through
# `label_mappings_gender`; then print it explicitly so it renders in the chunk.
sankey_diagram <- plot_sankey(
  raw_data       = raw_data,
  gnn_results    = gnn_results,
  expert_results = expert_analyses,
  group_mappings = label_mappings_gender,
  group_var      = "sex"
)

print(sankey_diagram)

