---
title: Prediction model class (`learner`)
author: Klaus Kähler Holst, Benedikt Sommer
date: "`r Sys.Date()`"
format:
  html:
    toc: true
    html-math-method: mathjax
execute:
  echo: true
  warning: false
  collapse: true
  comment: "#>"
vignette: >
  %\VignetteIndexEntry{Prediction models}
  %\VignetteEngine{quarto::html}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
library("targeted")
```

# Introduction

Methods for targeted and semiparametric inference rely on fitting nuisance
models to observed data when estimating the target parameter of interest. The
`{targeted}` package implements the [R6 reference class](https://r6.r-lib.org/)
`learner` to harmonize common statistical and machine learning models for use
as nuisance models across the various implemented estimators. Commonly
used models are constructed as `learner` class objects through the `learner_`
functions. These functions are wrappers for statistical and machine learning
models that have been implemented in other R packages. This is conceptually
similar to packages such as `{caret}` and `{mlr3}`. Besides implementing a large
number of prediction models, these packages provide extensive functionality
for data preprocessing, model selection and diagnostics, and hyperparameter
optimization. The design of the `learner` class and its accompanying functions
is much leaner, as our use cases often only require a standardized interface for
estimating models and generating predictions, instead of the full predictive
modelling pipeline.

This vignette uses the Mayo Clinic Primary Biliary Cholangitis data
(`?survival::pbc`) to exemplify the usage of the `learner` class objects. Our
interest lies in predicting the composite event of death or transplant within
2 years (730 days).
```{r pbcdata}
data(pbc, package="survival")
pbc <- transform(pbc, y = (time < 730) * (status > 0))
```

A logistic regression model with a single `age` covariate is defined and
estimated via

```{r logistic1}
lr <- learner_glm(y ~ age, family = binomial())
lr$estimate(pbc)
```

and predictions for the event `y = 1` (class 2 probabilities) for new data are
obtained by

```{r logistic2}
lr$predict(newdata = data.frame(age = c(20, 40, 60, 80)))
```

The remainder of this vignette is structured as follows. We first provide more
details about the built-in learners that wrap statistical and machine
learning models from other R packages. Once the usage of the returned
`learner` class objects has been introduced, we show how to implement new
learners.

# Built-in learners

The various constructors for commonly used statistical and machine learning
models are listed as part of the `learner` class documentation
```{r, eval = FALSE}
?learner # help(learner)
```
which contains all essential information for the usage of the `learner` class
objects. With the exception of `learner_sl`, all constructors require a
`formula` argument that specifies the response vector and design matrix for the
underlying model. This way the construction of a learner is separated from the
estimation of the model, as shown in the above logistic regression example. A
result of this design is the convenient specification of ensemble learners
(superlearner), where individual learners operate on different subsets of
features.

```{r}
lr_sl <- learner_sl(
  learners = list(
    glm1 = learner_glm(y ~ age * bili, family = "binomial"),
    glm2 = learner_glm(y ~ age, family = "binomial"),
    gam = learner_gam(y ~ s(age) + s(bili), family = "binomial")
  )
)
lr_sl$estimate(pbc, nfolds = 10)
lr_sl
```

Most constructors have additional arguments that impact the resulting model fit,
ranging from the specification of a link function for generalized linear models
to hyperparameters of machine learning models. Arguments naturally vary
between the constructors and are documented for each function (e.g.
`?learner_gam`). The implemented constructors only provide arguments for the
most relevant parameters of the underlying fit function (`mgcv::gam` in the case
of `learner_gam`). As described in the documentation, the ellipsis argument
(`...`) can be used to pass additional arguments to the fit function. However,
this is rarely required.
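
As a minimal sketch of passing an argument through the ellipsis, the following
forwards `method = "REML"` to the fit function. It assumes that `method` is not
already a named argument of `learner_gam` and that it reaches `mgcv::gam`
unchanged; `lr_gam` is a name introduced only for this illustration.

```{r}
# setup repeated so that the chunk runs on its own
library("targeted")
data(pbc, package = "survival")
pbc <- transform(pbc, y = (time < 730) * (status > 0))

# `method` is an argument of mgcv::gam; it is assumed here to be forwarded
# through the ellipsis of learner_gam
lr_gam <- learner_gam(y ~ s(age), family = "binomial", method = "REML")
lr_gam$estimate(pbc)
# smoothing parameter estimation method recorded in the fitted mgcv object
lr_gam$fit$method
```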

It is often necessary to consider a range of models with different
hyperparameters. The `learner_expand_grid` function facilitates this
```{r learner_expand_grid}
lrs <- learner_expand_grid(
  learner_xgboost,
  list(formula = Sepal.Length ~ ., eta = c(0.2, 0.5, 0.3))
)
lrs
```

The list of models can then serve as input for `learner_sl` or for assessing
the generalization error of the model across different hyperparameters with the
`cv` method.

# Usage

The basic usage is to construct a new learner by providing the `formula`
argument and the set of additional parameters that control the model fitting
process and the task (binary classification in this example).
```{r}
lr_xgboost <- learner_xgboost(
  formula = y ~ age + sex + bili,
  eta = 0.3, nrounds = 5,  # hyperparameters
  objective = "binary:logistic" # learning task
)
lr_xgboost
```

The model is estimated via the `estimate` method
```{r}
lr_xgboost$estimate(data = pbc)
```
The default behavior is to assign the fitted model to the learner object, which
can be accessed via
```{r}
class(lr_xgboost$fit)
```

Once the model has been fitted, predictions are generated with
```{r}
lr_xgboost$predict(head(pbc))
```
where the fitted model is used inside the learner object to generate the
predictions.

S3 methods are implemented for the `learner` class objects as an alternative to
the R6 class syntax for fitting the model and making predictions.
```{r}
lr <- learner_glm(y ~ age, family = "binomial")
estimate(lr, pbc)
predict(lr, head(pbc))
```

## Update formula

The estimate and predict functions of `learner` class objects are by design
immutable. Both functions can be inspected as part of the return values of the
summary method
```{r}
lr_xgboost$summary()$estimate
```

Rare situations may arise where one wants to update the formula argument. This
is supported via the `update` method
```{r}
lr_xgboost$update(y ~ age + sex)
```
The design matrix that results from the specified formula can be inspected
via
```{r}
head(lr_xgboost$design(pbc)$x)
```
and the response vector via
```{r}
head(lr_xgboost$response(pbc))
```
See `?learner` for the differences between `learner$design` and
`learner$response` and `?design` for more details about the construction of the
design matrix and response vector.
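
For intuition, the split of a formula and a data set into a design matrix and a
response vector can be mimicked with base R. This is only a sketch using
`stats::model.frame`, `stats::model.matrix` and `stats::model.response`;
`targeted::design` may differ in details such as the handling of the intercept
column. The names `mf`, `x_base` and `y_base` are introduced here for
illustration.

```{r}
# setup repeated so that the chunk runs on its own
data(pbc, package = "survival")
pbc <- transform(pbc, y = (time < 730) * (status > 0))

# model frame: the variables referenced in the formula, incomplete rows removed
mf <- model.frame(y ~ age + sex, data = pbc)
# design matrix: numeric encoding of the right-hand side
# (base R adds an intercept column by default)
x_base <- model.matrix(y ~ age + sex, data = mf)
# response vector: left-hand side of the formula
y_base <- model.response(mf)

head(x_base)
head(y_base)
```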

## Cross-validation

The `{targeted}` package provides a generic implementation for repeated k-fold
cross-validation with `learner` class objects. Parallelization is supported via
the `{future}` and `{parallel}` packages (see `?cv` for more details).
```{r}
# future::plan("multicore")
lrs <- list(
  glm = learner_glm(y ~ age, family = "binomial"),
  gam = learner_gam(y ~ s(age) + s(bili), family = "binomial")
)
# 2 times repeated 5-fold cross-validation
cv(lrs, data = pbc, rep = 2, nfolds = 5)
```
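
What a single repetition of k-fold cross-validation computes can be sketched in
base R. The example below is independent of `cv`: it calls `stats::glm`
directly, and the Brier score (mean squared difference between the outcome and
the predicted probability) is chosen for illustration only and is not claimed to
be the scoring rule that `cv` reports. The names `fold`, `brier` and `fit_k` are
introduced here for illustration.

```{r}
# setup repeated so that the chunk runs on its own
data(pbc, package = "survival")
pbc <- transform(pbc, y = (time < 730) * (status > 0))

set.seed(1)
nfolds <- 5
# randomly assign each observation to one of the folds
fold <- sample(rep(seq_len(nfolds), length.out = nrow(pbc)))

brier <- numeric(nfolds)
for (k in seq_len(nfolds)) {
  train <- pbc[fold != k, ]
  test <- pbc[fold == k, ]
  # fit on all folds except fold k
  fit_k <- glm(y ~ age, family = binomial(), data = train)
  # evaluate on the held-out fold k
  pr <- predict(fit_k, newdata = test, type = "response")
  brier[k] <- mean((test$y - pr)^2)
}
# average of the fold-specific scores
mean(brier)
```

Repeated cross-validation reruns this procedure with new random fold
assignments and averages the results.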

# Defining new learners

Statistical or machine learning models for which no constructors are provided
can be implemented with a few lines of code. In what follows we cover three
general cases where the input to the fit function differs.

## Fit function with formula and data arguments

The first general case covers fit functions that expect a `formula` and a `data`
argument. Both arguments are then used by the fitting routine to construct the
design matrix and response vector. Statistical models are usually implemented
with such an interface, with examples including `stats::glm`, `mgcv::gam` and
`earth::earth`. The `learner` R6 class supports these fitting routines by
checking whether the provided estimate function has a `formula` and a `data`
argument. If it does, then the `formula` and `data` arguments are passed on to
the estimate function without further modification. This is exemplified in the
following for `stats::glm`, where we define a new constructor that returns a
`learner` class object that fits a generalized linear model
```{r}
new_glm <- function(formula, ...) {
  learner$new(
    formula = formula,
    estimate = stats::glm,
    predict = stats::predict,
    predict.args = list(type = "response"),
    estimate.args = list(...),
    info = "new glm learner" # optional
  )
}
lr <- new_glm(y ~ age, family = "binomial")
lr
```
It can be seen that the optional arguments of `new_glm` define the
`estimate.args` of a constructed learner object. When estimating the model with
```{r}
lr$estimate(pbc)
```
the parameters defined in `estimate.args` are passed on with the `formula` and
`data` to `stats::glm` (i.e. the function specified via the `estimate`
argument). Thus, the above is equivalent to
```{r}
fit <- glm(y ~ age, family = "binomial", data = pbc)
all(coef(fit) == coef(lr$fit))
```
The constructor further specifies that the learner object uses
`stats::predict` to make predictions. The learner object similarly passes on the
`predict.args` to `stats::predict` for
```{r}
lr$predict(head(pbc))
```
which is equivalent to
```{r}
predict(fit, newdata = head(pbc), type = "response")
```
Indeed, `?learner` documents that the predict function always requires an
`object` and a `newdata` argument. An important implementation detail is that
the `estimate.args` and `predict.args` defined at construction can be overridden
in the method calls.
```{r}
lr$estimate(pbc, family = "poisson")
lr$predict(head(pbc), type = "link")
```

## Fit function with x and y arguments

Most fitting routines for machine learning models expect a design matrix and
response vector as inputs. The R6 `learner` class supports these fitting
functions by internally processing the `formula` argument via the
`targeted::design` function. Take the example of `grf::probability_forest`,
which expects `X` (design matrix) and `Y` (response vector) as inputs.
```{r}
new_grf <- function(formula, ...) {
  learner$new(
    formula = formula,
    estimate = function(x, y, ...) grf::probability_forest(X = x, Y = y, ...),
    predict = function(object, newdata) {
      predict(object, newdata = newdata)$predictions
    },
    estimate.args = list(...),
    info = "grf::probability_forest"
  )
}
lr <- new_grf(as.factor(y) ~ age + bili, num.trees = 100)
lr$estimate(pbc)
```
In contrast to the previous case, the `pbc` data object is not directly passed
on to the fit function. Instead, `targeted::design` constructs the design matrix
and response vector from the defined `formula` argument. As shown previously,
```{r}
dsgn <- lr$design(pbc)
```
can be used to inspect the design object. The `x` and `y` attributes of the
returned `design` object are then passed on to the fit function.

## Fit function with single data argument

To support ensemble/meta learners, it is also possible to construct a learner
without providing a formula argument.
```{r}
new_sl <- function(learners, ...) {
  learner$new(
    info = "new superlearner",
    estimate = superlearner,
    predict = targeted:::predict.superlearner,
    estimate.args = c(list(learners = learners), list(...))
  )
}
lrs <- list(
  glm = learner_glm(y ~ age, family = "binomial"),
  gam = learner_gam(y ~ s(age), family = "binomial")
)
lr <- new_sl(lrs, nfolds = 2)
lr$estimate(pbc)
lr
```
In this case, all arguments provided to `lr$estimate` are joined together with
the specified `estimate.args` and passed on to the defined estimate function (i.e.
`superlearner`).

## Specials in formula

Certain learner functions allow for specifying special terms in the formula. For
example, let’s consider an aggregated dataset that includes only the response
variable, treatment, and sex:
```{r naivebayes_aggregate_data}
library("data.table")
dd <- data.table(pbc)[!is.na(trt), .(.N), by=.(y,trt,sex)]
print(dd)
```

Next, we fit a Naive Bayes classifier using the `weights` special term to specify
the frequency weights for the estimation
```{r naivebayes}
lr <- learner_naivebayes(y ~ trt + sex + weights(N))
lr$estimate(dd)
lr$predict(dd)
```

Here `weights.numeric` is simply the identity function
```{r}
targeted:::weights.numeric
```
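
The idea behind frequency weights can be checked with base R: fitting a model to
aggregated data with the cell counts as weights gives the same estimates as
fitting it to the individual-level data. The sketch below uses `stats::glm`
rather than the Naive Bayes classifier, purely as an illustration; the names
`ind`, `agg`, `fit_ind` and `fit_agg` are introduced here and are not part of
the package.

```{r}
# setup repeated so that the chunk runs on its own
data(pbc, package = "survival")
pbc <- transform(pbc, y = (time < 730) * (status > 0))

# individual-level data, restricted to subjects with observed treatment
ind <- subset(pbc, !is.na(trt))
ind$N <- 1
# aggregated data: one row per (y, trt, sex) cell with its frequency N
agg <- aggregate(N ~ y + trt + sex, data = ind, FUN = sum)

fit_ind <- glm(y ~ trt + sex, family = binomial(), data = ind)
fit_agg <- glm(y ~ trt + sex, family = binomial(), data = agg, weights = N)

# compare the coefficient estimates up to numerical tolerance
all.equal(coef(fit_ind), coef(fit_agg), tolerance = 1e-4)
```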

To illustrate how to define a custom learner that utilizes special
terms, consider the following example where we introduce a `strata` term. We
define an estimation function that fits a linear regression model for each
level of the strata variable, together with a corresponding prediction function
```{r eststrata}
est <- function(formula, data, strata, ...)
  lapply(levels(strata), \(s) lm(formula, data[which(strata==s),]))

pred <- function(object, newdata, strata, ...) {
  res <- numeric(length(strata))
  for (i in seq_along(levels(strata))) {
    idx <- which(strata == levels(strata)[i])
    res[idx] <- predict(object[[i]], newdata[idx, ], ...)
  }
  return(res)
}
```

The new learner is now defined by including the argument
`specials = "strata"`, which ensures that the strata variable is correctly
passed to both the estimate and predict functions
```{r lr_strata}
lr <- learner$new(y ~ sex + strata(trt),
                  estimate=est, predict=pred, specials = "strata")
des <- lr$design(head(pbc))
des
des$strata
```

```{r lr_strata_est}
lr$estimate(pbc)
lr
lr$predict(head(pbc))
```
