## ----include = FALSE----------------------------------------------------------
# Vignette-wide settings: print chunk output inline, prefixed with "#>", and
# turn off rmarkdown's html_vignette title check.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")
options(rmarkdown.html_vignette.check_title = FALSE)

## ----setup--------------------------------------------------------------------
library(xhaz)

## ----eval=FALSE, echo=TRUE----------------------------------------------------
# install.packages("xhaz")
# 

## ----eval=FALSE---------------------------------------------------------------
# install.packages("survexp.fr")

## -----------------------------------------------------------------------------
library(survival)
library(xhaz)

## -----------------------------------------------------------------------------
library(survexp.fr)
data("breast", package = "xhaz")

head(breast)
dim(breast)

# armt: numbers and percentages
breast$armt2 <- factor(breast$armt, levels = c("0", "1"))
armt_counts <- table(breast$armt2)
armt_counts
# Percentage of patients in each arm (the counts alone do not give this).
round(100 * prop.table(armt_counts), 1)

# The expected (population) hazard comes from the survexp.us life table of
# the survival package. Its sex dimension is labelled "male"/"female", not
# coded numerically (e.g. 2), so every patient is mapped to "female".
# NOTE(review): this assumes the breast cohort is entirely female; confirm.
breast$sexe <- "female"

# Expected hazard (ehazard) and cumulative expected hazard (ehazardInt) for
# each subject. Life-table rows are matched on age, sex and calendar date via
# `rmap`. only_ehazard = FALSE is needed so that ehazardInt is returned too.
fit.haz <- exphaz(
  formula = Surv(temps, statut) ~ 1,
  data = breast,
  ratetable = survexp.us,
  only_ehazard = FALSE,
  rmap = list(age = "age", sex = "sexe", year = "date")
)

# Store them as columns so that mexhazLT() can refer to them by name below.
breast$expected <- fit.haz$ehazard
breast$expectedCum <- fit.haz$ehazardInt


# Interior knots of the spline baseline: the 1/3 and 2/3 quantiles of the
# follow-up times of subjects with an observed event (statut == 1).
qknots <- quantile(breast$temps[breast$statut == 1], probs = c(1, 2) / 3)

# Model 1: cubic B-spline log-baseline excess hazard (base = "exp.bs"),
# adjusted for agecr (presumably centred age) and treatment arm. The
# population hazard is taken as given (pophaz = "classic").
mod.bs <- mexhazLT(
  formula = Surv(temps, statut) ~ agecr + armt2,
  data = breast,
  expected = "expected",
  expectedCum = "expectedCum",
  base = "exp.bs",
  degree = 3,
  knots = qknots,
  pophaz = "classic"
)
mod.bs

## -----------------------------------------------------------------------------
# Model 2: same baseline and covariates as the "classic" model, but with
# pophaz = "rescaled". With this option the model adjusts the life-table
# hazard rather than taking it as exact (see ?mexhazLT).
mod.bs2 <- mexhazLT(
  formula = Surv(temps, statut) ~ agecr + armt2,
  data = breast,
  base = "exp.bs",
  degree = 3,
  knots = qknots,
  expected = "expected",
  expectedCum = "expectedCum",
  pophaz = "rescaled"
)
mod.bs2

## -----------------------------------------------------------------------------

# Model 3: the "classic" model plus a random effect defined by the `hosp`
# column (presumably the hospital identifier; confirm against the data).
mod.bs3 <- mexhazLT(
  formula = Surv(temps, statut) ~ agecr + armt2,
  data = breast,
  random = "hosp",
  pophaz = "classic",
  base = "exp.bs",
  degree = 3,
  knots = qknots,
  expected = "expected",
  expectedCum = "expectedCum"
)
mod.bs3

## -----------------------------------------------------------------------------

# Model 4: "rescaled" population hazard combined with the `hosp` random
# effect.
mod.bs4 <- mexhazLT(
  formula = Surv(temps, statut) ~ agecr + armt2,
  data = breast,
  pophaz = "rescaled",
  random = "hosp",
  expected = "expected",
  expectedCum = "expectedCum",
  base = "exp.bs",
  knots = qknots,
  degree = 3
)
mod.bs4

## -----------------------------------------------------------------------------
# Akaike information criterion of each candidate model (lower is better).
compared_models <- list(
  mod.bs = mod.bs,
  mod.bs2 = mod.bs2,
  mod.bs3 = mod.bs3,
  mod.bs4 = mod.bs4
)

# vapply() always yields a named numeric vector with one AIC per model.
# sapply() would silently change its result type if AIC() ever returned
# something other than a single number.
vapply(compared_models, AIC, numeric(1))


## -----------------------------------------------------------------------------
# Compare the "classic" and "rescaled" population-hazard specifications with
# anova(): first without, then with the `hosp` random effect. For xhaz models
# this is presumably a likelihood-ratio test; see the xhaz anova() method.
anova(mod.bs, mod.bs2)
anova(mod.bs3, mod.bs4)

## -----------------------------------------------------------------------------
# hazard and survival in the two arms
# All eight predictions share one time grid (0.1 to 10 years in steps of
# 0.1) and the covariate profile agecr = 0; only the model and arm change.
# Defining the grid once keeps the curves plotted later on identical points.
time_grid <- seq(0.1, 10, by = 0.1)

# Predict from `model` for treatment arm `arm` ("0" or "1") at agecr = 0 on
# `time_grid`. Extra arguments (e.g. marginal = TRUE) go to predict().
predict_arm <- function(model, arm, ...) {
  predict(model,
          time.pts = time_grid,
          data.val = data.frame(agecr = 0, armt2 = arm),
          ...)
}

# Models without a random effect.
predict_mod_amr0 <- predict_arm(mod.bs, "0")
predict_mod_amr1 <- predict_arm(mod.bs, "1")

predict_mod2_arm0 <- predict_arm(mod.bs2, "0")
predict_mod2_arm1 <- predict_arm(mod.bs2, "1")

# Models with the `hosp` random effect. marginal = TRUE requests predictions
# that are marginal over the random effect (see ?predict.mexhaz).
predict_mod3_arm0 <- predict_arm(mod.bs3, "0", marginal = TRUE)
predict_mod3_arm1 <- predict_arm(mod.bs3, "1", marginal = TRUE)

predict_mod4_arm0 <- predict_arm(mod.bs4, "0", marginal = TRUE)
predict_mod4_arm1 <- predict_arm(mod.bs4, "1", marginal = TRUE)


## ----fig.width=10, fig.height=10----------------------------------------------
## The figure is drawn inside a function on purpose. At the top level of a
## script there is no enclosing function for on.exit() to attach to, so a
## top-level `on.exit({ layout(1); par(old.par) })` does not restore the
## graphics state after the figure. Inside a function, the restore runs when
## the function exits, including on error.
plot_model_predictions <- function() {
  old_par <- par(no.readonly = TRUE)
  on.exit({
    layout(1)
    par(old_par)
  }, add = TRUE)

  # One entry per model: its colour and its predictions for the two arms
  # (arm 0 drawn solid, arm 1 dashed). Names become the model legend labels.
  curves <- list(
    mod.bs = list(
      col = "black", arm0 = predict_mod_amr0, arm1 = predict_mod_amr1
    ),
    mod.bs2 = list(
      col = "blue", arm0 = predict_mod2_arm0, arm1 = predict_mod2_arm1
    ),
    mod.bs3 = list(
      col = "red", arm0 = predict_mod3_arm0, arm1 = predict_mod3_arm1
    ),
    mod.bs4 = list(
      col = "green", arm0 = predict_mod4_arm0, arm1 = predict_mod4_arm1
    )
  )
  model_cols <- vapply(curves, function(m) m$col, character(1))

  # Draw one panel: every model/arm curve for column `column` of the
  # predict() results, plus the treatment-arm legend.
  draw_panel <- function(column, ylab, ylim) {
    par(mar = c(4, 4, 2, 1))
    ref <- curves[[1]]$arm0$results
    plot(ref$time.pts, ref[[column]], type = "n",
         xlab = "Time (years)", ylab = ylab, ylim = ylim)
    for (m in curves) {
      lines(m$arm0$results$time.pts, m$arm0$results[[column]],
            lwd = 2, col = m$col, lty = 1)
      lines(m$arm1$results$time.pts, m$arm1$results[[column]],
            lwd = 2, col = m$col, lty = 2)
    }
    grid()
    legend("topright", bty = "n", title = "Treatment arm",
           legend = c("No immunotherapy", "Immunotherapy"),
           lty = c(1, 2), lwd = 2, col = "black", cex = 0.9)
  }

  ## ----- 1) Make a 2-row layout: top row = legend, bottom row = 2 plots -----
  layout(
    matrix(c(1, 1, 2, 3), nrow = 2, byrow = TRUE),  # top spans both columns
    heights = c(1.2, 8)                             # adjust top strip height
  )

  ## ----- 2) TOP STRIP: shared model legend (outside the plots) -----
  par(mar = c(0, 0, 0, 0))
  plot.new()
  legend("center", bty = "n", horiz = TRUE, title = "Model",
         legend = names(curves),
         lty = 1, lwd = 2, col = model_cols, cex = 1)

  ## ----- 3) LEFT PANEL: excess hazard -----
  draw_panel("hazard", ylab = "Excess hazard", ylim = c(0, 0.5))

  ## ----- 4) RIGHT PANEL: net survival -----
  draw_panel("surv", ylab = "Net survival", ylim = c(0, 1))

  invisible(NULL)
}

plot_model_predictions()

## -----------------------------------------------------------------------------
# Record the R version, platform and package versions used to build this
# vignette, for reproducibility.
print(sessionInfo())

