---
title: "Using flashlight"
bibliography: "biblio.bib"
link-citations: true
output: 
  rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Using flashlight}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  warning = FALSE,
  message = FALSE,
  fig.width = 5.5,
  fig.height = 4.5
)
```

## Overview

**No black-box model without XAI.** This is where packages like 

- [{DALEX}](https://CRAN.R-project.org/package=DALEX), 
- [{iml}](https://CRAN.R-project.org/package=iml), and 
- [{flashlight}](https://CRAN.R-project.org/package=flashlight) enter the stage.

{flashlight} offers the following XAI methods:

- `light_performance()`: Performance metrics like RMSE and/or $R^2$
- `light_importance()`: Permutation variable importance [@fisher]
- `light_ice()`: Individual conditional expectation (ICE) profiles [@goldstein] (centered or uncentered)
- `light_profile()`: Partial dependence [@friedman2001], accumulated local effects (ALE) [@apley], average predicted/observed/residual
- `light_profile2d()`: Two-dimensional version of `light_profile()`
- `light_effects()`: Combines partial dependence, ALE, response and prediction profiles
- `light_interaction()`: Different variants of Friedman's H-statistic [@friedman2008]
- `light_breakdown()`: Variable contribution breakdown (approximate SHAP) for single observations [@gosiewska]
- `light_global_surrogate()`: Global surrogate trees [@molnar]

Good to know:

- Each method acts on an explainer object called `flashlight` (see examples and Section "flashlights"). 
- Multiple models can be compared via `multiflashlight()`.
- Calling `plot()` visualizes the results via {ggplot2}.
- Methods support case weights.
- Methods support a grouping variable.

## Installation

```r
# From CRAN
install.packages("flashlight")

# Development version
devtools::install_github("mayer79/flashlight")
```

## Usage

Let's start with an iris example. For simplicity, we do not split the data into training and testing/validation sets.

```{r}
library(ggplot2)
library(MetricsWeighted)
library(flashlight)

fit_lm <- lm(Sepal.Length ~ ., data = iris)

# Make explainer object
fl_lm <- flashlight(
  model = fit_lm, 
  data = iris, 
  y = "Sepal.Length", 
  label = "lm",               
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)
```

### Performance

```{r}
fl_lm |> 
  light_performance() |> 
  plot(fill = "darkred") +
  labs(x = NULL, title = "Performance on training data")

fl_lm |> 
  light_performance(by = "Species") |> 
  plot(fill = "darkred") +
  ggtitle("Performance split by Species")
```
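
For intuition, both metrics can be recomputed by hand. The following is a minimal base-R sketch. It assumes no case weights and the textbook definitions RMSE $= \sqrt{\operatorname{mean}((y - \hat y)^2)}$ and $R^2 = 1 - \text{SSE}/\text{SST}$. The helpers `rmse_manual()` and `r2_manual()` are hypothetical names, and the code is not flashlight's internal implementation.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)

# Hypothetical helpers implementing the textbook definitions (no case weights)
rmse_manual <- function(y, pred) sqrt(mean((y - pred)^2))
r2_manual <- function(y, pred) 1 - sum((y - pred)^2) / sum((y - mean(y))^2)

pred <- predict(fit_lm, iris)
rmse_all <- rmse_manual(iris$Sepal.Length, pred)
r2_all <- r2_manual(iris$Sepal.Length, pred)

# Same metrics within each Species, similar in spirit to light_performance(by = "Species")
by_species <- sapply(split(seq_len(nrow(iris)), iris$Species), function(i) {
  c(RMSE = rmse_manual(iris$Sepal.Length[i], pred[i]),
    R2 = r2_manual(iris$Sepal.Length[i], pred[i]))
})
c(RMSE = rmse_all, R2 = r2_all)
by_species
```

In this sketch, each group's $R^2$ is measured against that group's own mean response. As a result, it can be much lower than the overall value even when the RMSE is similar.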

### Permutation importance with respect to the first metric

With `m_repetitions = 4`, each feature is permuted four times. The error bars represent standard errors across these repetitions, i.e., the uncertainty of the estimated importance.

```{r}
fl_lm |>
  light_importance(m_repetitions = 4) |> 
  plot(fill = "darkred") +
  labs(title = "Permutation importance", y = "Increase in RMSE")
```
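
The idea behind permutation importance can be sketched in base R. Shuffle one feature, recompute the RMSE, and record the increase over the unshuffled RMSE, repeating this a few times per feature. The sketch assumes no case weights. It uses the standard deviation over repetitions divided by $\sqrt{m}$ as one common choice of standard error, which may differ in detail from what flashlight reports.

```r
set.seed(1)
fit_lm <- lm(Sepal.Length ~ ., data = iris)
rmse_manual <- function(y, pred) sqrt(mean((y - pred)^2))  # hypothetical helper

features <- setdiff(names(iris), "Sepal.Length")
base_rmse <- rmse_manual(iris$Sepal.Length, predict(fit_lm, iris))
m_repetitions <- 4

# One column per feature, one row per repetition: increase in RMSE after shuffling
imp <- sapply(features, function(v) {
  replicate(m_repetitions, {
    X <- iris
    X[[v]] <- sample(X[[v]])  # break the link between feature v and the response
    rmse_manual(iris$Sepal.Length, predict(fit_lm, X)) - base_rmse
  })
})

importance <- colMeans(imp)
std_err <- apply(imp, 2, sd) / sqrt(m_repetitions)
sort(importance, decreasing = TRUE)
```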

### ICE curves for `Sepal.Width`

```{r}
fl_lm |> 
  light_ice("Sepal.Width", n_max = 200) |> 
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "ICE curves for 'Sepal.Width'", y = "Prediction")

fl_lm |> 
  light_ice("Sepal.Width", n_max = 200, center = "middle") |> 
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "c-ICE curves for 'Sepal.Width'", y = "Prediction (centered)")
```
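
An ICE curve is obtained by varying one feature over a grid while keeping all other features of a single observation fixed. The base-R sketch below assumes 20 sampled rows and 25 equally spaced grid values, both arbitrary choices. It centers each curve at its middle grid point, which is how we read `center = "middle"`.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)

set.seed(1)
obs <- iris[sample(nrow(iris), 20), ]  # a few observations, similar in spirit to n_max
grid <- seq(min(iris$Sepal.Width), max(iris$Sepal.Width), length.out = 25)

# One row per observation, one column per grid value
ice <- t(sapply(seq_len(nrow(obs)), function(i) {
  X <- obs[rep(i, length(grid)), ]  # replicate observation i once per grid value
  X$Sepal.Width <- grid
  predict(fit_lm, X)
}))

# Centered ICE: subtract each curve's value at the middle grid point
mid <- ceiling(length(grid) / 2)
c_ice <- ice - ice[, mid]
```

Because the linear model contains no interactions, all centered curves coincide. The uncentered curves differ only by vertical shifts.
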

### PDPs

```{r}
fl_lm |> 
  light_profile("Sepal.Width", n_bins = 40) |> 
  plot() +
  ggtitle("PDP for 'Sepal.Width'")

fl_lm |> 
  light_profile("Sepal.Width", n_bins = 40, by = "Species") |> 
  plot() +
  ggtitle("Same grouped by 'Species'")
```
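
Partial dependence is the average of the ICE curves. For each grid value, set the feature to that value in every row and average the predictions. A base-R sketch follows. It assumes 40 equally spaced grid values; flashlight's own choice of evaluation points may differ.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
grid <- seq(min(iris$Sepal.Width), max(iris$Sepal.Width), length.out = 40)

# Partial dependence at g: fix Sepal.Width at g in all rows, then average predictions
pd <- sapply(grid, function(g) {
  X <- iris
  X$Sepal.Width <- g
  mean(predict(fit_lm, X))
})

# Grouped version: the same average, computed within each Species
pd_by <- sapply(split(iris, iris$Species), function(d) {
  sapply(grid, function(g) {
    d$Sepal.Width <- g
    mean(predict(fit_lm, d))
  })
})
```

For a linear model, the resulting curve is a straight line whose slope equals the `Sepal.Width` coefficient.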

### 2D PDP

```{r}
fl_lm |> 
  light_profile2d(c("Petal.Width", "Petal.Length")) |> 
  plot()
```
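
The two-dimensional version follows the same recipe on a grid of value pairs. The base-R sketch below assumes a 10 x 10 grid of equally spaced values, an arbitrary resolution.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
pw <- seq(min(iris$Petal.Width), max(iris$Petal.Width), length.out = 10)
pl <- seq(min(iris$Petal.Length), max(iris$Petal.Length), length.out = 10)
grid <- expand.grid(Petal.Width = pw, Petal.Length = pl)

# For each pair: fix both features in all rows, then average the predictions
grid$pd <- mapply(function(w, l) {
  X <- iris
  X$Petal.Width <- w
  X$Petal.Length <- l
  mean(predict(fit_lm, X))
}, grid$Petal.Width, grid$Petal.Length)
```

Since `fit_lm` has no interaction between the two petal features, the resulting surface is a plane.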

### ALE

```{r}
fl_lm |> 
  light_profile("Sepal.Width", type = "ale") |> 
  plot() +
  ggtitle("ALE plot for 'Sepal.Width'")
```
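
ALE avoids evaluating the model at unrealistic feature combinations. Within each bin of the feature, it averages the prediction difference between the bin's upper and lower edge over the rows falling into that bin, then accumulates these local effects. The base-R sketch below makes three assumptions: 10 quantile-based bins, empty bins contributing zero, and centering to mean zero over the rows. flashlight may bin and shift the curve differently.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
v <- "Sepal.Width"
n_bins <- 10

# Bin edges at quantiles of the feature (duplicates removed because of ties)
edges <- unique(quantile(iris[[v]], probs = seq(0, 1, length.out = n_bins + 1), names = FALSE))
bin <- cut(iris[[v]], breaks = edges, include.lowest = TRUE, labels = FALSE)

# Local effect per bin: mean of f(upper edge) - f(lower edge) over the rows in that bin
local_effect <- sapply(seq_len(length(edges) - 1), function(k) {
  X <- iris[bin == k, , drop = FALSE]
  if (nrow(X) == 0) return(0)
  X_lo <- X
  X_lo[[v]] <- edges[k]
  X_hi <- X
  X_hi[[v]] <- edges[k + 1]
  mean(predict(fit_lm, X_hi) - predict(fit_lm, X_lo))
})

# Accumulate the local effects (one value per edge), then center to mean zero over the rows
ale <- c(0, cumsum(local_effect))
ale <- ale - mean(ale[bin + 1])
data.frame(edge = edges, ale = ale)
```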

### Different profile plots in one

```{r}
fl_lm |> 
  light_effects("Sepal.Width") |> 
  plot(use = "all") +
  ggtitle("Different types of profiles for 'Sepal.Width'")
```
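
Roughly speaking, the response and prediction profiles combined by `light_effects()` are averages of the observed response and of the predictions within bins of the feature. Their difference gives the residual profile. The base-R sketch below assumes 6 equal-width bins, an arbitrary choice.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
bins <- cut(iris$Sepal.Width, breaks = 6)

# Average observed response and average prediction per bin of Sepal.Width
profiles <- data.frame(
  bin = levels(bins),
  response = as.numeric(tapply(iris$Sepal.Length, bins, mean)),
  predicted = as.numeric(tapply(predict(fit_lm, iris), bins, mean))
)
profiles$residual <- profiles$response - profiles$predicted
profiles
```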

### Variable contribution breakdown for a single observation

```{r}
fl_lm |> 
  light_breakdown(new_obs = iris[1, ]) |> 
  plot()
```
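
The breakdown idea can be sketched in base R. Start from the average prediction on the reference data. Then fix one feature after the other at the value of the observation of interest, and record each change in the average prediction. The contributions add up to the difference between the observation's prediction and the average prediction. flashlight determines the visit order itself (see `?light_breakdown`). The fixed column order below is an assumption for illustration, and a different order generally gives different contributions.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
new_obs <- iris[1, ]
features <- c("Sepal.Width", "Petal.Length", "Petal.Width", "Species")  # assumed visit order

X <- iris
baseline <- mean(predict(fit_lm, X))  # starting point: average prediction
current <- baseline
contrib <- setNames(numeric(length(features)), features)

for (v in features) {
  X[[v]] <- new_obs[[v]]              # fix feature v at the value of new_obs in all rows
  new_mean <- mean(predict(fit_lm, X))
  contrib[[v]] <- new_mean - current  # contribution = change in average prediction
  current <- new_mean
}
contrib
```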

### Global surrogate tree

```{r}
fl_lm |> 
  light_global_surrogate() |> 
  plot()
```
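
A global surrogate is an interpretable model fitted to the black-box predictions rather than to the observed response. The sketch below uses {rpart}. The depth of 3 and the $R^2$-type fidelity measure are illustrative choices, not necessarily flashlight's defaults.

```r
library(rpart)

fit_lm <- lm(Sepal.Length ~ ., data = iris)

# Surrogate data: the features plus the model's predictions as new response
sur_data <- iris[setdiff(names(iris), "Sepal.Length")]
sur_data$pred <- predict(fit_lm, iris)

# A shallow tree trained to mimic the model's predictions
surrogate <- rpart(pred ~ ., data = sur_data, control = rpart.control(maxdepth = 3))

# Fidelity: share of the prediction variance reproduced by the surrogate
r2_surrogate <- 1 - sum((sur_data$pred - predict(surrogate, sur_data))^2) /
  sum((sur_data$pred - mean(sur_data$pred))^2)
r2_surrogate
```
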

### Multiple models

Multiple flashlights can be combined into a multiflashlight via `multiflashlight()`.

```{r}
library(rpart)

fit_tree <- rpart(
  Sepal.Length ~ ., 
  data = iris, 
  control = list(cp = 0, xval = 0, maxdepth = 5)
)

# Make explainer object
fl_tree <- flashlight(
  model = fit_tree, 
  data = iris, 
  y = "Sepal.Length", 
  label = "tree",               
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)

# Combine with other explainer
fls <- multiflashlight(list(fl_tree, fl_lm))

fls |> 
  light_performance() |> 
  plot(fill = "chartreuse4") +
  labs(x = "Model", title = "Performance")

fls |> 
  light_importance() |> 
  plot(fill = "chartreuse4") +
  labs(y = "Increase in RMSE", title = "Permutation importance")

fls |> 
  light_profile("Petal.Length", n_bins = 40) |> 
  plot() +
  ggtitle("PDP")

fls |> 
  light_profile("Petal.Length", n_bins = 40, by = "Species") |> 
  plot() +
  ggtitle("PDP by Species")
```

### flashlights

The "flashlight" explainer expects the following information:

- `model`: Fitted model. Currently, this argument must be named.
- `data`: Reference data used to compute performance, importance, profiles, etc., often part of the validation data.
- `y`: Column name in `data` corresponding to the **numeric** response.
- `predict_function`: Function with the same signature as `stats::predict()`. It takes a `model` and a data.frame `data` and returns numeric predictions (see below for more details).
- `linkinv`: Optional function applied to the output of `predict_function()`, e.g., `exp` to back-transform predictions of a model fitted on the log scale. *Should actually be called "trafo".*
- `w`: Optional column name in `data` corresponding to case weights.
- `by`: Optional column name in `data` used to group the results. Must be discrete.
- `metrics`: List of metrics, by default `list(rmse = MetricsWeighted::rmse)`. For binary (probabilistic) classification, a good candidate metric is `MetricsWeighted::logLoss`.
- `label`: Mandatory name of the model.
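
Conceptually, the predictions used throughout are `linkinv(predict_function(model, data))`. The base-R sketch below illustrates this composition with a model fitted on the log scale. The log model and the names `fit_log`, `predict_function`, and `linkinv` are illustrative assumptions. In practice, both functions would be passed as arguments to `flashlight()`.

```r
# Model fitted on the log scale (illustrative choice)
fit_log <- lm(log(Sepal.Length) ~ Sepal.Width + Petal.Length + Petal.Width + Species, data = iris)

predict_function <- function(m, X) predict(m, X)  # returns predictions on the log scale
linkinv <- exp                                    # back-transform to the original scale

pred <- linkinv(predict_function(fit_log, iris))
head(pred)
```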

#### Typical `predict_function`s (a selection)

The default `stats::predict()` works for models fitted with

- `lm()`, 
- `glm()` (for predictions on link scale), and
- `rpart()`.

It also works for meta-learner models like

- {caret}, and
- {mlr3}.
  
Custom prediction functions are required, e.g., for

- {ranger}: Use `function(m, X) predict(m, X)$predictions` for regression, and 
  `function(m, X) predict(m, X)$predictions[, 2]` for probabilistic binary classification (forest grown with `probability = TRUE`)
- `glm()`: Use `function(m, X) predict(m, X, type = "response")` to get GLM predictions at the response scale
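
The difference between the two `glm()` prediction functions can be checked in base R. The logistic regression on `mtcars` below is an arbitrary illustrative choice. With the default, predictions are on the logit (link) scale. With `type = "response"`, they are probabilities, i.e., `plogis()` of the link-scale values.

```r
# Binary response: manual (0) vs. automatic (1) transmission
fit_glm <- glm(am ~ wt, data = mtcars, family = binomial())

pf_link <- function(m, X) predict(m, X)                     # default: link (logit) scale
pf_resp <- function(m, X) predict(m, X, type = "response")  # probabilities

p_link <- pf_link(fit_glm, mtcars)
p_resp <- pf_resp(fit_glm, mtcars)
head(cbind(link = p_link, response = p_resp))
```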

A bit more complicated are models whose native predict function does not work on data.frames:

- {xgboost} and {lightgbm}: They digest numeric matrices only, so the prediction function also needs to deal with the mapping from data.frame to matrix. 
- {keras}: It might accept data.frame inputs, but any input scaling applied during training must also be applied inside the prediction function.

**Example (XGBoost):**

This works when all non-numeric features are factors (not character columns):

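```r
# Feature columns used to fit the XGBoost model (example for iris; adapt to your model)
x <- c("Sepal.Width", "Petal.Length", "Petal.Width", "Species")

# Map the data.frame to a numeric matrix before calling XGBoost's predict()
predict_function <- function(m, df) predict(m, data.matrix(df[x]))
```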
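
Factors matter because `data.matrix()` replaces each factor by its integer codes, which stay stable as long as the factor levels are kept. A character column, in contrast, is converted to a factor from scratch, so its codes depend on the values present in the data at hand. The base-R check below (no XGBoost needed) illustrates this on iris. The feature vector `x` is the same illustrative assumption as above.

```r
x <- c("Sepal.Width", "Petal.Length", "Petal.Width", "Species")
X <- data.matrix(iris[x])  # Species coded 1 (setosa), 2 (versicolor), 3 (virginica)

# Subsetting keeps the factor levels, so virginica stays coded as 3
X_sub <- data.matrix(iris[101:150, x])

# The same subset with Species as character: codes are rebuilt, virginica becomes 1
df_chr <- iris[101:150, x]
df_chr$Species <- as.character(df_chr$Species)
X_chr <- data.matrix(df_chr)

c(factor_code = X_sub[1, "Species"], character_code = X_chr[1, "Species"])
```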

## References
