---
title: "How to use the breakDown package for models created with caret"
author: "Przemyslaw Biecek"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{model agnostic breakDown plots for caret}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

This example demonstrates how to use the `breakDown` package for models created with the [caret](https://CRAN.R-project.org/package=caret) package. 

First, we simulate a small two-class training set with `twoClassSim()`.

```{r, warning=FALSE, message=FALSE}
library(caret)

set.seed(2)
training <- twoClassSim(50, linearVars = 2)
trainX <- training[, -ncol(training)]
trainY <- training$Class

head(training)
```

Next, we fit a `glm` model through `caret::train()`, using 3-fold cross-validation and ROC as the selection metric.

```{r}
cctrl1 <- trainControl(method = "cv", number = 3, returnResamp = "all",
                       classProbs = TRUE, 
                       summaryFunction = twoClassSummary)

test_class_cv_model <- train(trainX, trainY, 
                             method = "glm", 
                             trControl = cctrl1,
                             metric = "ROC", 
                             preProc = c("center", "scale"))
test_class_cv_model
```

To use `breakDown` we need a function that calculates a numeric score (prediction) for a single observation. By default, `predict()` on a caret model returns the predicted class.

We therefore pass the `type = "prob"` argument to get class probabilities. Since this returns one score per class for each observation, we extract a single column, here the first one.

```{r}
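# Return the probability of the first class as a single numeric score
predict.fun <- function(model, x) predict(model, x, type = "prob")[, 1]
# Simulate a few new observations to explain
testing <- twoClassSim(10, linearVars = 2)
predict.fun(test_class_cv_model, testing[1, ])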
```
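
The contract used by `predict.fun`, a function of `(model, x)` that returns one numeric score per row, can be tried out on any model. The sketch below is an illustration only: it assumes a plain `stats::glm()` fit on the built-in `mtcars` data instead of the caret model, and the names `toy_fit`, `toy_predict` and `toy_score` are hypothetical.

```{r}
# Illustration only: a logistic regression on the built-in mtcars data,
# standing in for the caret model trained above.
toy_fit <- glm(am ~ wt + hp, data = mtcars, family = binomial())

# Same shape as predict.fun: takes a model and a data frame,
# returns one numeric score per row.
toy_predict <- function(model, x) {
  unname(predict(model, newdata = x, type = "response"))
}

# One observation in, one probability out
toy_score <- toy_predict(toy_fit, mtcars[1, ])
toy_score
```

Any wrapper of this shape should be usable as a prediction function, as long as it returns a plain numeric value rather than a factor or a data frame.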

Now we are ready to call the `broken()` function.

```{r}
library("breakDown")
explain_2 <- broken(test_class_cv_model, testing[1,],
                    data = trainX,
                    predict.function = predict.fun)
explain_2
```
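
The explanation splits a prediction into a baseline plus per-variable contributions. For a plain linear model such an additive split has a simple closed form, which the sketch below illustrates. It is a simplified illustration of the idea under stated assumptions (an `lm()` fit on `mtcars`, hypothetical `toy_*` names), not the model-agnostic procedure that `broken()` applies to the caret model.

```{r}
# Illustration only: additive decomposition of a linear model prediction
toy_lm   <- lm(mpg ~ wt + hp, data = mtcars)
toy_obs  <- mtcars[1, ]
toy_vars <- c("wt", "hp")

# Baseline: the average prediction over the data
toy_baseline <- mean(predict(toy_lm, mtcars))

# Contribution of each variable: coefficient times the distance
# of the observation from the variable's mean
toy_contrib <- coef(toy_lm)[toy_vars] *
  (unlist(toy_obs[toy_vars]) - colMeans(mtcars[toy_vars]))
toy_contrib

# Baseline plus contributions reproduces the prediction for this observation
all.equal(toy_baseline + sum(toy_contrib), unname(predict(toy_lm, toy_obs)))
```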

Finally, we plot the explanation.

```{r, fig.width=7}
library(ggplot2)
plot(explain_2) + ggtitle("breakDown plot for caret/glm model")
```


