---
title: "General introduction: Survival on the RMS Titanic"
author: "Przemyslaw Biecek"
date: "`r Sys.Date()`"
output: rmarkdown::html_document
vignette: >
  %\VignetteIndexEntry{General introduction: Survival on the RMS Titanic}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(
  collapse = FALSE,
  comment = "#>",
  warning = FALSE,
  message = FALSE
)
```

# Data for Titanic survival

This vignette walks through an example of using the `DALEX` package to explain a classification model that predicts survival on the Titanic.
We use the `titanic_imputed` dataset available in the `DALEX` package. Note that this data was copied from the `stablelearner` package and modified for practicality.

```{r}
library("DALEX")
head(titanic_imputed)
```
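
Before modelling, it can help to look at the column types and at how the outcome is distributed. The chunk below is a small sketch that uses only base R functions on `titanic_imputed`. It assumes that the outcome is stored in the `survived` column, as the model formula in the next section suggests. The name `survival_share` is introduced here only for illustration.

```{r}
library("DALEX")

# Column types and factor levels of the passenger data
str(titanic_imputed)

# Share of passengers in each outcome class of the `survived` column
survival_share <- prop.table(table(titanic_imputed$survived))
survival_share
```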

# Model for Titanic survival

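Now it's time to create a model. Let's use a random forest fitted with the `ranger` package. Setting `probability = TRUE` makes the forest return class probabilities rather than hard labels.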

```{r}
# prepare model
library("ranger")
model_titanic_rf <- ranger(survived ~ gender + age + class + embarked +
                           fare + sibsp + parch,
                           data = titanic_imputed, probability = TRUE)
model_titanic_rf
```

# Explainer for Titanic survival

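The third step (optional but useful) is to create a `DALEX` explainer for the random forest model. The explainer wraps the model together with the data and the target variable, so that all explanation functions can be called in the same way.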

```{r}
library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic_imputed[,-8],
                              y = titanic_imputed[,8],
                              label = "Random Forest")
```
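
As a quick sanity check, the explainer can be asked for predictions directly. The sketch below assumes that `predict()` on a `DALEX` explainer returns the predicted probability of survival for this `ranger` model. The `if (!exists(...))` guard is not part of the workflow; it only rebuilds the explainer when the chunk is run on its own. The name `first_predictions` is illustrative.

```{r}
library("DALEX")
library("ranger")

# Rebuild the explainer only when this chunk is run on its own
if (!exists("explain_titanic_rf")) {
  rf <- ranger(survived ~ ., data = titanic_imputed, probability = TRUE)
  explain_titanic_rf <- explain(rf, data = titanic_imputed[, -8],
                                y = titanic_imputed[, 8],
                                label = "Random Forest", verbose = FALSE)
}

# Predicted survival probabilities for the first six passengers
first_predictions <- predict(explain_titanic_rf, head(titanic_imputed))
first_predictions
```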

# Model Level Feature Importance

Use `feature_importance()` to measure how important each feature is to the model. Note that by default the raw dropout losses are reported; calling the function with `type = "difference"` subtracts the loss of the full model, so that all dropouts start at 0.

```{r}
library("ingredients")

fi_rf <- feature_importance(explain_titanic_rf)
head(fi_rf)
plot(fi_rf)
```
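
To see the normalized variant mentioned in the note above, the same call can be repeated with `type = "difference"`. This is a minimal sketch: the name `fi_rf_diff` is introduced only for illustration, and the guard at the top merely rebuilds the explainer when the chunk is run on its own. Because the importance is based on random permutations, the exact values will differ between runs.

```{r}
library("DALEX")
library("ranger")
library("ingredients")

# Rebuild the explainer only when this chunk is run on its own
if (!exists("explain_titanic_rf")) {
  rf <- ranger(survived ~ ., data = titanic_imputed, probability = TRUE)
  explain_titanic_rf <- explain(rf, data = titanic_imputed[, -8],
                                y = titanic_imputed[, 8],
                                label = "Random Forest", verbose = FALSE)
}

# Dropout losses reported relative to the loss of the full model
fi_rf_diff <- feature_importance(explain_titanic_rf, type = "difference")
head(fi_rf_diff)
plot(fi_rf_diff)
```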

# Feature effects

As we can see, the most important feature is `gender`. The next three important features are `class`, `age` and `fare`. Let's see the link between the model response and these features.

Such univariate relations can be calculated with `partial_dependence()`.

## age

Children aged 5 and younger have a much higher predicted survival probability.

### Partial Dependence Profiles

```{r}
pp_age <- partial_dependence(explain_titanic_rf, variables = c("age", "fare"))
head(pp_age)
plot(pp_age)
```
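
To make clear what a partial dependence profile measures, it can also be computed by hand: fix `age` at a given value for every passenger, predict, and average the predictions. The chunk below is a rough sketch of that idea and not the implementation used by `partial_dependence()`. The grid `age_grid` and the name `pd_age` are chosen here for illustration, and the results may differ slightly from the plot above, for example if the package works on a sample of the data or on a different grid.

```{r}
library("DALEX")
library("ranger")

# Rebuild the explainer only when this chunk is run on its own
if (!exists("explain_titanic_rf")) {
  rf <- ranger(survived ~ ., data = titanic_imputed, probability = TRUE)
  explain_titanic_rf <- explain(rf, data = titanic_imputed[, -8],
                                y = titanic_imputed[, 8],
                                label = "Random Forest", verbose = FALSE)
}

# Illustrative grid of ages at which the profile is evaluated
age_grid <- seq(0, 70, by = 10)

# For each age: overwrite `age` for all passengers and average the predictions
pd_age <- vapply(age_grid, function(a) {
  data_mod <- explain_titanic_rf$data
  data_mod$age <- a
  mean(predict(explain_titanic_rf, data_mod))
}, numeric(1))

data.frame(age = age_grid, average_prediction = pd_age)
```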

### Conditional Dependence Profiles

```{r}
cp_age <- conditional_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(cp_age)
```

### Accumulated Local Effect Profiles

```{r}
ap_age <- accumulated_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(ap_age)
```

# Instance level explanations

Let's explain the model prediction for a single passenger: an 8-year-old male travelling in 1st class who embarked at Southampton.

First, Ceteris Paribus profiles for numerical variables.

```{r}
new_passanger <- data.frame(
  class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
  gender = factor("male", levels = c("female", "male")),
  age = 8,
  sibsp = 0,
  parch = 0,
  fare = 72,
  embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)

sp_rf <- ceteris_paribus(explain_titanic_rf, new_passanger)
plot(sp_rf) +
  show_observations(sp_rf)
```

Next, the profiles for selected categorical variables. Note that `sibsp` is numerical, but here it is presented as a categorical variable.

```{r}
plot(sp_rf,
     variables = c("class", "embarked", "gender", "sibsp"),
     variable_type = "categorical")
```

It looks like the most important features for this passenger are `age` and `gender`. Overall, his odds of survival are higher than for the average passenger, mainly because of his young age and despite being male.
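
The comparison with the average passenger can be checked numerically. The sketch below compares the prediction for this passenger with the mean prediction over the data stored in the explainer. The guards only rebuild `explain_titanic_rf` and `new_passanger` when the chunk is run on its own, and the name `comparison` is introduced for illustration. The exact numbers depend on the fitted forest, so they are not quoted here.

```{r}
library("DALEX")
library("ranger")

# Rebuild the explainer only when this chunk is run on its own
if (!exists("explain_titanic_rf")) {
  rf <- ranger(survived ~ ., data = titanic_imputed, probability = TRUE)
  explain_titanic_rf <- explain(rf, data = titanic_imputed[, -8],
                                y = titanic_imputed[, 8],
                                label = "Random Forest", verbose = FALSE)
}

# Recreate the passenger described above only when it is not yet defined
if (!exists("new_passanger")) {
  new_passanger <- data.frame(
    class = factor("1st", levels = levels(titanic_imputed$class)),
    gender = factor("male", levels = levels(titanic_imputed$gender)),
    age = 8, sibsp = 0, parch = 0, fare = 72,
    embarked = factor("Southampton", levels = levels(titanic_imputed$embarked))
  )
}

# Prediction for the passenger next to the mean prediction over all passengers
comparison <- c(
  passenger = unname(predict(explain_titanic_rf, new_passanger)),
  average   = mean(predict(explain_titanic_rf, explain_titanic_rf$data))
)
comparison
```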


# Profile clustering

```{r}
passangers <- select_sample(titanic_imputed, n = 100)

sp_rf <- ceteris_paribus(explain_titanic_rf, passangers)
clust_rf <- cluster_profiles(sp_rf, k = 3)
head(clust_rf)
plot(sp_rf, alpha = 0.1) +
  show_aggregated_profiles(clust_rf, color = "_label_", size = 2)
```


# Session info

```{r}
sessionInfo()
```
