---
title: "Using Arena with multiclass classifiers"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Using Arena with multiclass classifiers}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
# Chunk defaults for the whole vignette: collapsed output, "#>" as the output
# prefix, and no evaluation of the code chunks.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>", eval = FALSE)
```

## Preface
Multiclass classifiers are not supported by default in Arena. To overcome that,
ArenaR splits each of them into a set of binary classifiers, one per class.
The result of this example is available at [this link](https://arena.drwhy.ai/branch/dev/?data=https://gist.githubusercontent.com/piotrpiatyszek/9e24710b794c377fa4740ee4651bf1b7/raw/b6d6036ccd8ea317920ebe977fa4c3ef9638dba6/data.json).

## Load data & libraries
```{r setup, message=FALSE, warning=FALSE}
library(arenar)
library(dplyr)
library(DALEX)
library(MASS)
library(gbm)

# Data set: the HR data shipped with DALEX
HR <- DALEX::HR

# Draw 10 random rows to explain
observations <- HR[sample(seq_len(nrow(HR)), size = 10), ]

# Build a readable name for every observation, e.g. "Male 31yr Grade: 2":
# capitalised gender, rounded age and evaluation grade
observation_gender <- as.character(observations$gender)
observation_names <- paste0(
  toupper(substr(observation_gender, 1, 1)),
  substring(observation_gender, 2),
  " ",
  round(observations$age),
  "yr",
  " Grade: ",
  observations$evaluation
)

# Two sampled rows can share gender, rounded age and grade. Data frames reject
# duplicated row names, so repeated names get a " #1", " #2", ... suffix.
rownames(observations) <- make.unique(observation_names, sep = " #")
```

## Models
```{r}
# Gradient boosting model for the three-level `status` target
model_gbm <- gbm(
  status ~ .,
  data = HR,
  n.trees = 100,
  interaction.depth = 3
)

# Linear discriminant analysis fitted with the same formula and data
model_lda <- lda(status ~ ., data = HR)
```

## Create Explainers & Arena
```{r}
# Explainer for the gbm model
explainer_gbm <- DALEX::explain(model_gbm, data = HR, y = HR$status)

# For LDA the model_info has to be set manually, and the predict function
# returns the posterior element of the prediction
explainer_lda <- DALEX::explain(
  model_lda,
  data = HR,
  y = HR$status,
  model_info = list(package = "MASS", ver = "", type = "multiclass"),
  predict_function = function(m, x) {
    predict(m, x)$posterior
  }
)

# Start from an empty arena, then add the prepared observations and both
# explainers one step at a time
arena <- create_arena()
arena <- push_observations(arena, observations)
arena <- push_model(arena, explainer_gbm)
arena <- push_model(arena, explainer_lda)

# The summary shows that each explainer was split into three, one per class
print(arena)
# ===== Static Arena Summary =====
# Models: gbm [fired], gbm [ok], gbm [promoted], lda [fired], lda [ok], lda [promoted] 
# Observations: Male 31yr Grade: 2, Female 24yr Grade: 4, Female 21yr Grade: 3, Female 25yr Grade: 3, Male 43yr Grade: 3, Female 32yr Grade: 5, Female 54yr Grade: 2, Female 32yr Grade: 2, Male 21yr Grade: 2, Male 41yr Grade: 2 
# Variables: gender, age, hours, evaluation, salary 
# Plots count: 510 
# NULL

# Upload the arena (interactive sessions only)
if (interactive()) {
  upload_arena(arena)
}
```

## Make explainers manually
Sometimes it is useful to perform this process manually.
```{r}
# Create new arena and add prepared observations
arena <- create_arena() %>% push_observations(observations)

# Levels of target variable
levels(HR$status)
# [1] "fired"    "ok"       "promoted"

# Predictors only: drop the target variable by name, not by column position
HR_predictors <- HR[, names(HR) != "status", drop = FALSE]

# Factories for the per-class predict functions. Each returned function keeps
# its own copy of `class_name`. A function written directly inside the loop
# would instead look up the loop variable at call time, so its result would
# depend on the value the variable holds at that moment.
gbm_class_probability <- function(class_name) {
  force(class_name)
  function(m, x) {
    predict(m, x, n.trees = 100, type = "response")[, class_name, ]
  }
}
lda_class_probability <- function(class_name) {
  force(class_name)
  function(m, x) predict(m, x)$posterior[, class_name]
}

# For each target level create one binary explainer per model
for (status in levels(HR$status)) {
  # Target variable as 0/1: does the row belong to the current level?
  is_current_level <- as.numeric(HR$status == status)

  # Explainer for gbm; DALEX:: is spelled out because dplyr also exports
  # a function called explain()
  explainer_gbm <- DALEX::explain(
    model_gbm,
    y = is_current_level,
    data = HR_predictors,
    label = paste0("GBM [", status, "]"),
    # The predict function extracts the probability of the current class
    predict_function = gbm_class_probability(status)
  )
  # Explainer for lda
  explainer_lda <- DALEX::explain(
    model_lda,
    y = is_current_level,
    data = HR_predictors,
    label = paste0("LDA [", status, "]"),
    # The predict function extracts the probability of the current class
    predict_function = lda_class_probability(status)
  )

  # Add explainers
  arena <- push_model(arena, explainer_gbm)
  arena <- push_model(arena, explainer_lda)
}

# Upload the arena once, after all explainers have been added
if (interactive()) upload_arena(arena)
```
