## ----setup, include = FALSE---------------------------------------------------
knitr::opts_chunk$set(
  # Show code and its output together, output lines prefixed by "#>"
  collapse = TRUE,
  comment  = "#>",
  # Default figure size in inches (individual chunks below override it)
  fig.width  = 6,
  fig.height = 4
)

## ----load-package-------------------------------------------------------------
library(PSsurvival)

## ----data---------------------------------------------------------------------
# Example datasets used throughout this document
data("simdata_bin")   # binary treatment: A vs B
data("simdata_multi") # four treatment arms: A, B, C, D

# Binary treatment data: first rows, then arm sizes
head(simdata_bin, n = 6L)
table(simdata_bin[["Z"]])

# Multiple treatment data: arm sizes
table(simdata_multi[["Z"]])

## ----surveff-basic------------------------------------------------------------
# Binary treatment: overlap weights (OW) with a Weibull censoring model
result_bin <- surveff(
  data = simdata_bin,
  # propensity-score model for the treatment Z
  ps_formula = Z ~ X1 + X2 + X3 + B1 + B2,
  # censoring model
  censoring_formula = survival::Surv(time, event) ~ X1 + B1,
  weight_method = "OW", censoring_method = "weibull"
)

# max.len = 3 keeps the printed output short
print(result_bin, max.len = 3)

## ----surveff-multi------------------------------------------------------------
# Pairwise comparisons of arm A against every other arm: one row per
# contrast, one column per treatment level. The column names must match the
# treatment levels of simdata_multi$Z, and the row names are the labels used
# later to select contrasts, e.g. plot(..., strata_to_plot = c("A vs B")).
contrast_mat <- matrix(
  c(1, -1,  0,  0,   # A vs B
    1,  0, -1,  0,   # A vs C
    1,  0,  0, -1),  # A vs D
  nrow = 3, byrow = TRUE,
  dimnames = list(
    c("A vs B", "A vs C", "A vs D"),  # row names: contrast labels
    c("A", "B", "C", "D")             # column names: treatment levels
  )
)

# Enforce the level-name requirement up front, so a typo fails here with a
# clear message instead of surfacing later in the analysis.
if (!setequal(colnames(contrast_mat),
              unique(as.character(simdata_multi$Z)))) {
  stop("colnames(contrast_mat) must match the treatment levels of ",
       "simdata_multi$Z", call. = FALSE)
}

# Four arms: IPW weights, Cox censoring model, bootstrap variance
result_multi <- surveff(
  data = simdata_multi,
  ps_formula = Z ~ X1 + X2 + X3 + B1 + B2,
  censoring_formula = survival::Surv(time, event) ~ X1 + B1,
  weight_method = "IPW",
  contrast_matrix = contrast_mat,   # A vs B, A vs C, A vs D
  censoring_method = "cox",
  variance_method = "bootstrap", B = 50, seed = 123
)

print(result_multi, max.len = 3)

## ----surveff-multi-trimmed----------------------------------------------------
# Same four-arm IPW analysis, now with symmetric trimming (delta = 0.15)
result_multi_trimmed <- surveff(
  data = simdata_multi,
  ps_formula = Z ~ X1 + X2 + X3 + B1 + B2,
  censoring_formula = survival::Surv(time, event) ~ X1 + B1,
  weight_method = "IPW",
  trim = TRUE, delta = 0.15,
  contrast_matrix = contrast_mat,
  censoring_method = "cox",
  variance_method = "bootstrap", B = 50, seed = 123
)

print(result_multi_trimmed, max.len = 3)

## ----weight-methods, eval=FALSE-----------------------------------------------
# # IPW (targets ATE)
# surveff(..., weight_method = "IPW")
# 
# # ATT targeting group B
# surveff(..., weight_method = "ATT", att_group = "B")
# 
# # Overlap weights
# surveff(..., weight_method = "OW")

## ----trimming, eval=FALSE-----------------------------------------------------
# # Symmetric trimming with default delta (automatic selection)
# surveff(..., weight_method = "IPW", trim = TRUE)
# 
# # Symmetric trimming with custom delta
# surveff(..., weight_method = "IPW", trim = TRUE, delta = 0.1)

## ----censoring-methods, eval=FALSE--------------------------------------------
# # Weibull censoring model
# surveff(..., censoring_method = "weibull")
# 
# # Cox censoring model (requires bootstrap)
# surveff(..., censoring_method = "cox", variance_method = "bootstrap", B = 200)

## ----variance-methods, eval=FALSE---------------------------------------------
# # Analytical variance (binary treatment with Weibull censoring only)
# surveff(..., censoring_method = "weibull", variance_method = "analytical")
# 
# # Bootstrap variance
# surveff(..., variance_method = "bootstrap", B = 200)

## ----boot-level, eval=FALSE---------------------------------------------------
# # Stratified bootstrap
# surveff(..., variance_method = "bootstrap", boot_level = "strata", B = 200)

## ----summary-surveff----------------------------------------------------------
# 95% confidence level; max.len = 5 caps the printed length
summary(result_bin,
        conf_level = 0.95,
        max.len = 5)

## ----summary-returns----------------------------------------------------------
# style = "returns" keeps the summary as an R object for further use
summ <- summary(result_bin, style = "returns")
names(summ)

# Survival summary for arm A and the "B vs A" difference summary
head(summ$survival_summary$A, n = 6L)
head(summ$difference_summary$`B vs A`, n = 6L)

## ----plot-surv, fig.width=7, fig.height=5-------------------------------------
# Estimated survival curves for every treatment group
plot(result_multi,
     type = "surv")

## ----plot-survdiff, fig.width=7, fig.height=5---------------------------------
# Treatment-effect (survival difference) curves for the defined contrasts
plot(result_multi,
     type = "survdiff")

## ----plot-subset-surv, fig.width=7, fig.height=5------------------------------
# Restrict the survival plot to groups A and C
plot(result_multi, type = "surv", strata_to_plot = c("A", "C"))

## ----plot-subset-diff, fig.width=7, fig.height=5------------------------------
# Restrict to selected contrasts; the names must match the row names of
# contrast_mat exactly
plot(result_multi, type = "survdiff", strata_to_plot = c("A vs B", "A vs D"))

## ----plot-custom, fig.width=7, fig.height=5-----------------------------------
# Customized survival plot: two groups, custom colours, max_time = 15,
# confidence intervals, legend at the bottom, and a title
plot(
  result_multi,
  type = "surv",
  strata_to_plot = c("A", "B"),
  strata_colors = c("steelblue", "coral"),
  max_time = 15,
  include_CI = TRUE,
  legend_position = "bottom",
  plot_title = "Survival by Treatment Group"
)

## ----marCoxph-basic-----------------------------------------------------------
# Marginal hazard ratio for the binary treatment with overlap weights;
# arm A is the reference level
hr_result <- marCoxph(
  data = simdata_bin, ps_formula = Z ~ X1 + X2 + X3 + B1 + B2,
  time_var = "time", event_var = "event",
  reference_level = "A", weight_method = "OW"
)

print(hr_result)

## ----marCoxph-multi-----------------------------------------------------------
# Four arms against reference A: IPW weights, bootstrap variance
hr_multi <- marCoxph(
  data = simdata_multi, ps_formula = Z ~ X1 + X2 + X3 + B1 + B2,
  time_var = "time", event_var = "event",
  reference_level = "A", weight_method = "IPW",
  variance_method = "bootstrap", B = 50, seed = 456
)

print(hr_multi)

## ----marCoxph-multi-trimmed---------------------------------------------------
# Same four-arm IPW hazard-ratio analysis with symmetric trimming
# (delta = 0.15)
hr_multi_trimmed <- marCoxph(
  data = simdata_multi, ps_formula = Z ~ X1 + X2 + X3 + B1 + B2,
  time_var = "time", event_var = "event",
  reference_level = "A", weight_method = "IPW",
  trim = TRUE, delta = 0.15,
  variance_method = "bootstrap", B = 50, seed = 456
)

print(hr_multi_trimmed)

## ----marcoxph-weight-methods, eval=FALSE--------------------------------------
# # IPW (targets ATE)
# marCoxph(..., weight_method = "IPW")
# 
# # ATT targeting a specific group (here arm B)
# marCoxph(..., weight_method = "ATT", att_group = "B")
# 
# # Overlap weights
# marCoxph(..., weight_method = "OW")

## ----marcoxph-trimming, eval=FALSE--------------------------------------------
# # Symmetric trimming with default delta
# marCoxph(..., weight_method = "IPW", trim = TRUE)
# 
# # Symmetric trimming with custom delta
# marCoxph(..., weight_method = "IPW", trim = TRUE, delta = 0.1)

## ----marCoxph-variance, eval=FALSE--------------------------------------------
# # Bootstrap variance (default)
# marCoxph(..., variance_method = "bootstrap", B = 200)
# 
# # Robust sandwich variance (no bootstrap)
# marCoxph(..., variance_method = "robust")

## ----summary-marcoxph---------------------------------------------------------
# Hazard-ratio summary at the 95% confidence level
summary(hr_multi,
        conf_level = 0.95)

## ----summary-marcoxph-returns-------------------------------------------------
# Keep the summary as an object and inspect its components
summ_hr <- summary(hr_multi, style = "returns")
names(summ_hr)
summ_hr$HR  # hazard ratios on the exponentiated scale

## ----weightedKM-basic---------------------------------------------------------
# Classical (unweighted) Kaplan-Meier: no propensity model is supplied
km_result <- weightedKM(
  data = simdata_bin, treatment_var = "Z",
  time_var = "time", event_var = "event",
  weight_method = "none"
)

print(km_result)

## ----weightedKM-OW------------------------------------------------------------
# Kaplan-Meier weighted by overlap weights from the propensity model
km_ow <- weightedKM(
  data = simdata_bin, treatment_var = "Z",
  ps_formula = Z ~ X1 + X2 + X3 + B1 + B2,
  time_var = "time", event_var = "event",
  weight_method = "OW"
)

print(km_ow)

## ----weightedKM-IPW-----------------------------------------------------------
# Four-arm Kaplan-Meier with IPW weights and symmetric trimming (delta = 0.1)
km_ipw <- weightedKM(
  data = simdata_multi, treatment_var = "Z",
  ps_formula = Z ~ X1 + X2 + X3 + B1 + B2,
  time_var = "time", event_var = "event",
  weight_method = "IPW",
  trim = TRUE, delta = 0.1
)

print(km_ipw)

## ----plot-weightedKM, fig.width=7, fig.height=5-------------------------------
# Kaplan-Meier survival curves with log-type confidence intervals
plot(km_ow,
     type = "Kaplan-Meier",
     conf_type = "log")

## ----plot-weightedKM-CR, fig.width=7, fig.height=5----------------------------
# Cumulative risk (CR) curves with log-log confidence intervals
plot(km_ow,
     type = "CR",
     conf_type = "log-log")

## ----plot-weightedKM-risktable, fig.width=7, fig.height=6---------------------
# Add a risk table reporting n.risk and n.acc.event at times 0, 5, 10, 15
plot(
  km_ipw,
  risk_table = TRUE,
  risk_table_breaks = c(0, 5, 10, 15),
  risk_table_height = 0.3,
  risk_table_stats = c("n.risk", "n.acc.event")
)

## ----summary-weightedKM-------------------------------------------------------
# Tabular Kaplan-Meier summary with log-log intervals; print.rows limits
# how many rows are shown
summary(km_ow,
        type = "Kaplan-Meier",
        conf_type = "log-log",
        print.rows = 5)

