## ----include = FALSE----------------------------------------------------------
# Vignette setup: collapse source and output into one block, prefix printed
# output with "#>", and disable rmarkdown's vignette-title consistency check.
knitr::opts_chunk$set(
  comment = "#>",
  collapse = TRUE
)
options(rmarkdown.html_vignette.check_title = FALSE)

## ----rfae---------------------------------------------------------------------
# Load libraries
library(RFAE)
library(data.table)
library(ggplot2)
library(arf)
library(ranger)
set.seed(42)

# Train-test split: 100 randomly chosen rows of iris for training, the
# remaining rows for testing. seq_len() is the safe form of 1:n (it yields an
# empty sequence rather than c(1, 0) when n == 0) and draws the same sample.
trn <- sample(seq_len(nrow(iris)), 100)
tst <- setdiff(seq_len(nrow(iris)), trn)

# Train a supervised RF (ranger, predicting Species) and an unsupervised
# adversarial RF (arf) on the training rows, each with 50 trees.
# parallel = FALSE runs the ARF sequentially; no parallel backend has been
# registered at this point.
rf <- ranger(Species ~ ., data = iris[trn, ], num.trees = 50)
arf <- adversarial_rf(iris[trn, ], num_trees = 50, parallel = FALSE)

## ----par, eval=FALSE----------------------------------------------------------
# # Register cores - Unix
# library(doParallel)
# registerDoParallel(cores = 2)

## ----par2, eval=FALSE---------------------------------------------------------
# # Register cores - Windows
# library(doParallel)
# cl <- makeCluster(2)
# registerDoParallel(cl)

## ----rfae3--------------------------------------------------------------------
# Retrain both forests, this time leaving adversarial_rf()'s `parallel`
# argument at its default.
# NOTE(review): presumably this uses the doParallel backend registered in the
# (eval = FALSE) chunks above; confirm against the arf documentation.
# The ARF is trained first and the RF second; keep this order, since both
# calls draw from the R random number stream.
arf <- adversarial_rf(iris[trn, ], num_trees = 50)
rf <- ranger(Species ~ ., data = iris[trn, ], num.trees = 50)

## ----encoding-----------------------------------------------------------------
# Encode the training rows once per forest type (ARF, then supervised RF).
# k = 2 requests a two-dimensional embedding so it can be plotted below.
emap_arf <- encode(arf, iris[trn, ], k = 2)
emap_rf <- encode(rf, iris[trn, ], k = 2)

# Show the first five training rows next to their embeddings:
# the raw input rows ...
iris[trn, ][seq_len(5), ]
# ... their ARF embedding ...
emap_arf$Z[seq_len(5), ]
# ... and their RF embedding.
emap_rf$Z[seq_len(5), ]

## ----encoding2, fig.height=5, fig.width=7-------------------------------------
# Scatter plot of the 2-D ARF embedding of the training rows, coloured by the
# true species label.
tmp <- data.frame(
  dim1 = emap_arf$Z[, 1],
  dim2 = emap_arf$Z[, 2],
  class = iris$Species[trn]
)
ggplot(tmp, aes(x = dim1, y = dim2, color = class)) +
  geom_point(alpha = 0.8, size = 2) +
  labs(
    color = "Species",
    x = "Diffusion Component 1",
    y = "Diffusion Component 2"
  ) +
  theme_minimal()

## ----encoding4----------------------------------------------------------------
# Pull the `A` component out of the ARF encoding and show its top-left 5x5
# corner.
# NOTE(review): presumably A is the forest-induced affinity/kernel matrix
# between training rows; confirm against the RFAE::encode documentation.
A <- emap_arf$A
first_five <- seq_len(5)
A[first_five, first_five]

## ----encoding5, fig.height=5, fig.width=7-------------------------------------
# Project the held-out test rows into the embedding learned on the training
# rows, using the fitted encoding together with the ARF it was built from.
emb <- predict(emap_arf, arf, iris[tst, ])

# Scatter plot of the projected test rows, coloured by true species.
tmp <- data.frame(
  dim1 = emb[, 1],
  dim2 = emb[, 2],
  class = iris$Species[tst]
)

ggplot(tmp, aes(x = dim1, y = dim2, color = class)) +
  geom_point(alpha = 0.8, size = 2) +
  labs(
    color = "Species",
    x = "Diffusion Component 1",
    y = "Diffusion Component 2"
  ) +
  theme_minimal()

## ----decoding1----------------------------------------------------------------
# Map the test embeddings back to the input space with decode_knn(), then
# print the first five reconstructed rows above the first five original rows.
out <- decode_knn(arf, emap_arf, emb)
out$x_hat[seq_len(5), ]
iris[tst[seq_len(5)], ]

## ----errors-------------------------------------------------------------------
# Compare the reconstructed test rows with the originals and print each
# component of the resulting error summary.
errors <- reconstruction_error(out$x_hat, iris[tst, ])
# Per-feature error, numeric columns:
errors$num_error
# Per-feature error, categorical columns:
errors$cat_error
# Mean error across numeric columns:
errors$num_avg
# Mean error across categorical columns:
errors$cat_avg
# Single overall error figure:
errors$ovr_error

## ----errors2, fig.height=5, fig.width=7---------------------------------------

# Plot the reconstruction error of every feature as a bar, coloured by
# feature type, with a dashed red reference line at the overall error.
error_df <- data.frame(
  Variable = c(names(errors$num_error), names(errors$cat_error)),
  Error = c(unlist(errors$num_error), unlist(errors$cat_error)),
  Type = c(
    rep("Numeric", length(errors$num_error)),
    rep("Categorical", length(errors$cat_error))
  )
)

ggplot(error_df, aes(x = reorder(Variable, Error), y = Error, fill = Type)) +
  # geom_col() is the idiomatic form of geom_bar(stat = "identity").
  geom_col(width = 0.7) +
  # The reference line is a single constant, so it is set outside aes().
  # Mapping it inside aes() would inherit error_df and draw one overlapping
  # line per row.
  geom_hline(
    yintercept = errors$ovr_error,
    linetype = "dashed", color = "red"
  ) +
  annotate(
    "text", x = 1.5, y = errors$ovr_error + 0.02,
    label = paste("Avg Error:", round(errors$ovr_error, 3))
  ) +
  theme_minimal() +
  labs(
    title = "Reconstruction Error by Feature",
    x = NULL,
    y = "Distortion"
  )

