---
title: "Cross-validation with a single algorithm"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Cross-validation with a single algorithm}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.path = "../man/figures/README-"
  )

library(dplyr)

load("../data/star.rda")

# specifying the outcome
outcomes <- "g3tlangss"

# specifying the treatment
treatment <- "treatment"

# specifying the data (remove other outcomes)
star_data <- star %>% dplyr::select(-c(g3treadss,g3tmathss))

# specifying the formula
user_formula <- as.formula(
  "g3tlangss ~ treatment + gender + race + birthmonth + 
  birthyear + SCHLURBN + GRDRANGE + GKENRMNT + GKFRLNCH + 
  GKBUSED + GKWHITE ")
```


When users choose to estimate and evaluate an individualized treatment rule (ITR) under cross-validation, the package implements Algorithm 1 from [Imai and Li
(2023)](https://arxiv.org/abs/1905.05389). For more information about Algorithm 1, please refer to [this page](../articles/paper_alg1.html).
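
As a rough guide to what cross-validation does here, the procedure can be sketched as follows. The notation is ours and is meant only as a summary; the paper's statement of Algorithm 1 is authoritative. The data $\mathbf{Z}$ are split into $K$ folds $\mathbf{Z}_1, \dots, \mathbf{Z}_K$. For each fold $k$, the algorithm $F$ is trained on the remaining folds $\mathbf{Z}_{-k}$, and the resulting rule is evaluated on the held-out fold. The fold-level estimates are then averaged:

$$
\hat{f}_{-k} = F(\mathbf{Z}_{-k}), \qquad
\hat{\tau}_k = \hat{\tau}_{\hat{f}_{-k}}(\mathbf{Z}_k), \qquad
\hat{\tau}_F = \frac{1}{K} \sum_{k=1}^{K} \hat{\tau}_k .
$$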


Instead of specifying the `split_ratio` argument, 
we choose the number of folds (`n_folds`). We present an example of estimating an ITR with 3-fold cross-validation. 
In practice, we recommend using 10 folds to obtain more stable performance estimates. 
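
To make the role of `n_folds` concrete, here is a minimal base-R sketch of how observations might be partitioned into folds. It is only an illustration: the sample size `n` and the object `fold_id` are hypothetical, and this is not necessarily how the package assigns folds internally.

```r
# Illustrative only: a base-R sketch of K-fold assignment,
# not the internal implementation of evalITR.
set.seed(2021)

n <- 12        # hypothetical number of observations
n_folds <- 3   # number of folds K

# Give each observation a fold label 1..K, then shuffle the labels
fold_id <- sample(rep(seq_len(n_folds), length.out = n))

for (k in seq_len(n_folds)) {
  test_idx  <- which(fold_id == k)   # held-out fold, used for evaluation
  train_idx <- which(fold_id != k)   # remaining folds, used for training
  cat("Fold", k, ": train on", length(train_idx),
      "obs, evaluate on", length(test_idx), "obs\n")
}
```

Each observation is held out exactly once, so every unit contributes to evaluation while never being used to train the rule that is evaluated on it.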


| Input                                                            | R package input                                               | Descriptions                                                                                                                                                         |
|:-----------------------------------------------------------------|:---------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Data $\mathbf{Z}=\left\{\mathbf{X}_i, T_i, Y_i\right\}_{i=1}^n$ | `treatment = treatment, form = user_formula, data = star_data` | `treatment` is a character string naming the treatment variable in `data`; `form` is a formula specifying the outcome and covariates; `data` is a data frame containing these variables |
| Machine learning algorithm $F$                                   | `algorithms = c("causal_forest")`                              | a character vector specifying the ML algorithms to be used                                                                                                           |
| Evaluation metric $\tau_f$                                       | PAPE, PAPD, AUPEC, GATE                                        | Computed by default; no argument needs to be specified |
| Number of folds $K$                                              | `n_folds = 3`                                                  | `n_folds` is a numeric value indicating the number of folds used for cross-validation                                                                                |
| …                                                                | `budget = 0.2`                                                 | `budget` is a numeric value specifying the maximum proportion of the population that can be treated under the budget constraint (here, 20%) |
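
The `budget` argument can be pictured with a small base-R sketch. The scores below are simulated stand-ins for estimated treatment effects, and the rule shown (treat the units with the largest positive scores, up to the budget) is our assumption about a typical budget-constrained rule, not a description of the package's internals. All object names are hypothetical.

```r
# Illustrative only: hypothetical scores stand in for estimated
# treatment effects; this is not the evalITR implementation.
set.seed(2021)

n <- 10
score <- rnorm(n)   # hypothetical estimated treatment effects
budget <- 0.2       # at most 20% of units may be treated

max_treated <- floor(budget * n)

# Rank units from the largest to the smallest score
ranked <- order(score, decreasing = TRUE)
top <- ranked[seq_len(max_treated)]

# Treat only top-ranked units whose score is positive
itr <- integer(n)
itr[top[score[top] > 0]] <- 1L

# Share of units assigned to treatment; never exceeds the budget
mean(itr)
```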


```{r cv_estimate, message = FALSE, out.width = '60%'}
library(evalITR)
# estimate the ITR with 3-fold cross-validation
set.seed(2021)  # for reproducibility
fit_cv <- estimate_itr(
  treatment = treatment,
  form = user_formula,
  data = star_data,
  algorithms = c("causal_forest"),
  budget = 0.2,
  n_folds = 3
)
```

The output is an object that
includes the estimated evaluation metric $\hat{\tau}_F$ and its estimated
variance for each metric (PAPE, PAPD, AUPEC).


```{r cv_eval, message = FALSE, out.width = '50%'}
# evaluate ITR 
est_cv <- evaluate_itr(fit_cv)

# summarize estimates
summary(est_cv)
```
