---
title: "Working with Distributions"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Working with Distributions}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

## Distributions

Distributions are a set of classes in `{reservr}` that specify distribution families of random variables.
A Distribution inherits from the R6 class `Distribution` and provides all functionality necessary for working with a
specific family.

A Distribution is created by calling one of the constructor functions, which are prefixed by `dist_` in the package.
All constructors accept the parameters of the family as arguments.
If an argument is specified, the corresponding parameter is considered _fixed_: it need not be specified when
computing something for the distribution, and it is held fixed when calling `fit()` on the distribution instance.

### Sample

For example, an unspecified normal distribution can be created by calling `dist_normal()` without arguments.
Its parameters `mean` and `sd` are then considered _placeholders_.
To, e.g., sample from such a distribution, we must specify these placeholders in the `with_params` argument:

```{r}
library(reservr)
set.seed(1L)
# Instantiate an unspecified normal distribution
norm <- dist_normal()
x <- norm$sample(n = 10L, with_params = list(mean = 3, sd = 1))

set.seed(1L)
norm2 <- dist_normal(sd = 1)
x2 <- norm2$sample(n = 10L, with_params = list(mean = 3))

# the same RVs are drawn because the distribution parameters and the seed were the same
stopifnot(identical(x, x2))
```
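
As a small illustration of _fixed_ parameters, the following sketch fixes both `mean` and `sd` at construction, so no
`with_params` argument should be needed.
It only uses `dist_normal()` and `density()`; the names `norm_fixed` and `fixed_pts` are hypothetical and chosen
just for this example.

```{r}
library(reservr)

# Both parameters are fixed at construction, so there are no placeholders left
norm_fixed <- dist_normal(mean = 3, sd = 1)
fixed_pts <- c(2, 3, 4)

# No with_params needed; the result should agree with base R's dnorm()
norm_fixed$density(fixed_pts)
dnorm(fixed_pts, mean = 3, sd = 1)
```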

### Density

The `density()` function computes the density of the distribution with respect to its natural measure.
Use `is_discrete_at()` to check whether a point carries discrete mass or Lebesgue density.

```{r}
norm$density(x, with_params = list(mean = 3, sd = 1))
dnorm(x, mean = 3, sd = 1)
norm$density(x, log = TRUE, with_params = list(mean = 3, sd = 1)) # log-density
norm$is_discrete_at(x, with_params = list(mean = 3, sd = 1))

# A discrete distribution with mass only at point = x[1].
dd <- dist_dirac(point = x[1])
dd$density(x)
dd$is_discrete_at(x)
```
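
To make "natural measure" concrete, here is a minimal base R sketch that does not use `{reservr}` at all.
It assumes a normal and a Poisson distribution as stand-ins for the continuous and the discrete case; all variable
names are hypothetical.

```{r}
# Continuous case: the normal density is taken with respect to Lebesgue measure,
# so it integrates to one over the real line.
lebesgue_total <- integrate(dnorm, lower = -Inf, upper = Inf, mean = 3, sd = 1)$value

# Discrete case: the Poisson "density" is taken with respect to counting measure,
# so its values are probabilities that sum to one over the support.
pois_support <- 0:100
counting_total <- sum(dpois(pois_support, lambda = 1.5))

c(lebesgue = lebesgue_total, counting = counting_total)
```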

`diff_density()` computes the gradient of the density with respect to each free parameter.
Setting `log = TRUE` computes the gradient of the log-density, i.e., the gradient of `log f(x, params)` instead.

```{r}
norm$diff_density(x, with_params = list(mean = 3, sd = 1))
```
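
For the normal family, the gradient of the log-density is available in closed form, which allows a quick plausibility
check.
The sketch below uses base R only and compares the analytic gradient to central finite differences; `grad_pts`,
`grad_mu`, `grad_sigma` and `fd_eps` are hypothetical names, and nothing here is taken from the `{reservr}`
implementation.

```{r}
grad_pts <- c(2, 3, 4.5)
grad_mu <- 3
grad_sigma <- 1

# Analytic gradient of log f(x, mean, sd) for the normal distribution
grad_mean <- (grad_pts - grad_mu) / grad_sigma^2
grad_sd <- ((grad_pts - grad_mu)^2 - grad_sigma^2) / grad_sigma^3

# Central finite differences of the log-density in each parameter
fd_eps <- 1e-6
fd_mean <- (dnorm(grad_pts, grad_mu + fd_eps, grad_sigma, log = TRUE) -
  dnorm(grad_pts, grad_mu - fd_eps, grad_sigma, log = TRUE)) / (2 * fd_eps)
fd_sd <- (dnorm(grad_pts, grad_mu, grad_sigma + fd_eps, log = TRUE) -
  dnorm(grad_pts, grad_mu, grad_sigma - fd_eps, log = TRUE)) / (2 * fd_eps)

cbind(grad_mean, fd_mean, grad_sd, fd_sd)
```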

### Probability

With `probability()`, the c.d.f., the survival function, and their logarithms can be computed via the `lower.tail`
and `log.p` arguments.
For discrete distributions, `dist$probability(x, lower.tail = TRUE)` returns $P(X \le x)$ and
`dist$probability(x, lower.tail = FALSE)` returns $P(X > x)$.

```{r}
norm$probability(x, with_params = list(mean = 3, sd = 1))
pnorm(x, mean = 3, sd = 1)

dd$probability(x)
dd$probability(x, lower.tail = FALSE, log.p = TRUE)
```
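
The tail convention for discrete distributions mirrors the one used by base R.
As a sketch, assuming a Poisson distribution as an example and using only base R functions (variable names are
hypothetical), the lower tail includes the atom at `x` and the upper tail excludes it:

```{r}
tail_pts <- 0:4
tail_lower <- ppois(tail_pts, lambda = 1.5)                      # P(X <= x)
tail_upper <- ppois(tail_pts, lambda = 1.5, lower.tail = FALSE)  # P(X > x)

# The lower tail accumulates the point masses up to and including x
stopifnot(isTRUE(all.equal(tail_lower, cumsum(dpois(tail_pts, lambda = 1.5)))))

# Both tails partition the probability mass
stopifnot(isTRUE(all.equal(tail_lower + tail_upper, rep(1, length(tail_pts)))))

# log.p = TRUE returns the logarithm of the requested tail probability
tail_log_upper <- ppois(tail_pts, lambda = 1.5, lower.tail = FALSE, log.p = TRUE)
stopifnot(isTRUE(all.equal(tail_log_upper, log(tail_upper))))
```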

Gradients of the (log-)c.d.f. or survival function with respect to parameters can be computed using
`diff_probability()`.

```{r}
norm$diff_probability(x, with_params = list(mean = 3, sd = 1))
```

### Hazard

The hazard rate is defined by $h(x, \theta) = f(x, \theta) / S(x, \theta)$, i.e., the ratio of the density to the
survival function.

```{r}
norm$hazard(x, with_params = list(mean = 3, sd = 1))
norm$hazard(x, log = TRUE, with_params = list(mean = 3, sd = 1))
```
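
The definition can be reproduced directly from the density and the survival function.
The following base R sketch (hypothetical variable names, no `{reservr}` calls) computes the hazard of a normal
distribution and its logarithm; working on the log scale avoids dividing two very small numbers far in the tail.

```{r}
haz_pts <- c(1, 3, 5)

# h(x) = f(x) / S(x)
haz_manual <- dnorm(haz_pts, mean = 3, sd = 1) /
  pnorm(haz_pts, mean = 3, sd = 1, lower.tail = FALSE)

# log h(x) = log f(x) - log S(x)
log_haz_manual <- dnorm(haz_pts, mean = 3, sd = 1, log = TRUE) -
  pnorm(haz_pts, mean = 3, sd = 1, lower.tail = FALSE, log.p = TRUE)

stopifnot(isTRUE(all.equal(log(haz_manual), log_haz_manual)))
haz_manual
```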

### Fitting

The `fit()` generic is defined for Distributions and will perform maximum likelihood estimation.
It accepts a weighted, censored and truncated sample of class `trunc_obs`, but can automatically convert uncensored,
untruncated observations without weight into the proper `trunc_obs` object.
```{r}
# Fit with mean, sd free
fit1 <- fit(norm, x)
# Fit with mean free
fit2 <- fit(norm2, x)
# Fit with sd free
fit3 <- fit(dist_normal(mean = 3), x)

# Fitted parameters
fit1$params
fit2$params
fit3$params

# Information criteria such as the AIC can be computed on the log-likelihoods
AIC(fit1$logLik)
AIC(fit2$logLik)
AIC(fit3$logLik)

# Convergence checks
fit1$opt$message
fit2$opt$message
fit3$opt$message
```
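
For uncensored, untruncated normal data, the maximum likelihood estimates are known in closed form, which gives a
reference point for the numbers above.
The sketch below uses base R only and a small made-up data vector (`mle_obs` is hypothetical, not the sample `x`);
it makes no claim about how `fit()` obtains its estimates internally.

```{r}
mle_obs <- c(2.1, 3.4, 2.9, 3.8, 2.5, 3.3)  # made-up data for illustration

# Both parameters free: sample mean and the 1/n (not 1/(n - 1)) standard deviation
mle_mean <- mean(mle_obs)
mle_sd <- sqrt(mean((mle_obs - mle_mean)^2))

# Mean fixed at 3: only sd is estimated, around the fixed mean
mle_sd_fixed_mean <- sqrt(mean((mle_obs - 3)^2))

# Log-likelihood at the estimate and AIC = 2 * (number of free parameters) - 2 * logLik
mle_loglik <- sum(dnorm(mle_obs, mean = mle_mean, sd = mle_sd, log = TRUE))
mle_aic <- 2 * 2 - 2 * mle_loglik

c(mean = mle_mean, sd = mle_sd, sd_fixed_mean = mle_sd_fixed_mean, logLik = mle_loglik, AIC = mle_aic)
```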

### Fitting censored data

You can also fit interval-censored data, where only an interval `[xmin, xmax]` containing each observation is known.

```{r}
params <- list(mean = 30, sd = 10)
x <- norm$sample(100L, with_params = params)
xl <- floor(x)
xr <- ceiling(x)

cens_fit <- fit(norm, trunc_obs(xmin = xl, xmax = xr))
print(cens_fit)
```
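
Conceptually, an interval-censored observation contributes the probability of its interval, $F(x_r) - F(x_l)$, to the
likelihood instead of a density value.
The following sketch maximises such a likelihood by hand with base R's `optim()`; it assumes this standard
construction, uses a deterministic pseudo-sample of normal quantiles instead of random draws, and all names in it are
hypothetical.

```{r}
# Deterministic pseudo-sample: evenly spaced normal quantiles (no RNG involved)
cens_latent <- qnorm(ppoints(50), mean = 30, sd = 10)
cens_lower <- floor(cens_latent)
cens_upper <- cens_lower + 1  # unit-width intervals, so every interval has positive width

# Negative log-likelihood; sd is optimised on the log scale to keep it positive
cens_negloglik <- function(par) {
  interval_prob <- pnorm(cens_upper, mean = par[1], sd = exp(par[2])) -
    pnorm(cens_lower, mean = par[1], sd = exp(par[2]))
  -sum(log(interval_prob))
}

# Start from the moments of the interval midpoints
cens_mid <- (cens_lower + cens_upper) / 2
sketch_opt <- optim(c(mean(cens_mid), log(sd(cens_mid))), cens_negloglik)

c(mean = sketch_opt$par[1], sd = exp(sketch_opt$par[2]))
```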

### Fitting truncated data

It is possible to fit randomly truncated samples, i.e., samples where the truncation bound itself is also random and
differs for each observation.

```{r}
params <- list(mean = 30, sd = 10)
x <- norm$sample(100L, with_params = params)
tl <- runif(length(x), min = 0, max = 20)
tr <- runif(length(x), min = 0, max = 60) + tl

# truncate_obs() also truncates the sample, i.e., it drops observations outside their truncation bounds.
# If the data is already truncated, use trunc_obs(x = ..., tmin = ..., tmax = ...) instead.
trunc_fit <- fit(norm, truncate_obs(x, tl, tr))
print(trunc_fit)

# Number of observations used in the fit, i.e., those remaining after truncation
attr(trunc_fit$logLik, "nobs")
```
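
As a sketch of why the bounds matter, assume the standard conditional likelihood for truncated data.
With $f$ the density as in the hazard section, $F$ the c.d.f., and $t_{l,i} < x_i \le t_{r,i}$ the (assumed notation
for the) truncation bounds of observation $i$, each observation contributes

$$
\ell_i(\theta) = \log f(x_i, \theta) - \log\bigl(F(t_{r,i}, \theta) - F(t_{l,i}, \theta)\bigr)
$$

to the log-likelihood, i.e., the density is renormalised to the region in which the observation could have been seen.
Ignoring the second term would bias the estimates towards the observable region.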

### Plotting

Visualising different distributions or parametrizations, e.g., fits, can be done with `plot_distributions()`.

```{r}
# Plot fitted densities
plot_distributions(
  true = norm,
  fit1 = norm,
  fit2 = norm2,
  fit3 = dist_normal(mean = 3),
  .x = seq(-2, 7, 0.01),
  with_params = list(
    true = list(mean = 3, sd = 1),
    fit1 = fit1$params,
    fit2 = fit2$params,
    fit3 = fit3$params
  ),
  plots = "density"
)

# Plot fitted densities, c.d.f.s and hazard rates
plot_distributions(
  true = norm,
  cens_fit = norm,
  trunc_fit = norm,
  .x = seq(0, 60, length.out = 101L),
  with_params = list(
    true = list(mean = 30, sd = 10),
    cens_fit = cens_fit$params,
    trunc_fit = trunc_fit$params
  )
)

# More complex distributions
plot_distributions(
  bdegp = dist_bdegp(2, 3, 10, 3),
  .x = c(seq(0, 12, length.out = 121), 1.5 - 1e-6),
  with_params = list(
    bdegp = list(
      dists = list(
        list(), list(), list(
          dists = list(
            list(
              dist = list(
                shapes = as.list(1:3),
                scale = 2.0,
                probs = list(0.2, 0.5, 0.3)
              )
            ),
            list(
              sigmau = 0.4,
              xi = 0.2
            )
          ),
          probs = list(0.7, 0.3)
        )
      ),
      probs = list(0.15, 0.1, 0.75)
    )
  )
)
```
