## ----setup, include = FALSE---------------------------------------------------
# Global chunk options for the vignette: prefix printed output with "#>"
# and collapse source and output into a single block.
knitr::opts_chunk$set(comment = "#>", collapse = TRUE)

## ----message=FALSE, warning=FALSE, fig.width=7, fig.height=6, fig.align='center'----
# Load the package
library(AddiVortes)

# --- Generate Training Data ---
set.seed(42) # for reproducibility

# 500 observations of 5 predictors. Each column starts as Uniform(0, 1)
# draws and is then shifted/scaled onto its own range:
#   column 1: [-20, -10]   column 2: [0, 100]   column 3: [-9, 1]
#   column 4: [8, 9]       column 5: [0, 10]
X <- matrix(runif(2500), ncol = 5)
X[, 1] <- -10 - 10 * X[, 1]
X[, 2] <- 100 * X[, 2]
X[, 3] <- 10 * X[, 3] - 9
X[, 4] <- X[, 4] + 8
X[, 5] <- 10 * X[, 5]

# Underlying mean: 10 for observations with -X2 > 10 * X1 + 100, else 0.
# Only columns 1 and 2 drive the response; standard normal noise is added.
Y_underlying <- ifelse(-X[, 2] > 10 * X[, 1] + 100, 10, 0)
Y <- Y_underlying + rnorm(nrow(X))

# Visualise the relationships in the data, colouring each point by its
# underlying group (red = mean 10, blue = mean 0)
group_colours <- ifelse(Y_underlying == 10, "red", "blue")
pairs(X,
      col = group_colours,
      pch = 19,
      cex = 0.5,
      main = "Structure of Predictor Variables")

## ----results='hide'-----------------------------------------------------------
# Fit the model to the training data (response Y, predictor matrix X)
# with m = 50 and the progress display switched off.
AModel <- AddiVortes(
  Y, X,
  m = 50,
  showProgress = FALSE
)


## -----------------------------------------------------------------------------
# The fitted object carries its in-sample Root Mean Squared Error
# (read here from the `inSampleRmse` element); print it on one line.
cat("In-sample RMSE:",
    AModel$inSampleRmse,
    "\n")

## -----------------------------------------------------------------------------
# --- Generate Test Data ---
set.seed(101) # Use a different seed for the test set

# 200 test observations, built with the same per-column shift/scale as the
# training predictors.
testX <- matrix(runif(1000), ncol = 5)
testX[, 1] <- -10 - testX[, 1] * 10
testX[, 2] <- testX[, 2] * 100
testX[, 3] <- -9 + testX[, 3] * 10
testX[, 4] <- 8 + testX[, 4]
testX[, 5] <- testX[, 5] * 10

# Create the true test response values with the same rule as the training data
testY_underlying <- ifelse(-testX[, 2] > 10 * testX[, 1] + 100, 10, 0)
testY <- testY_underlying + rnorm(length(testY_underlying))

# --- Make Predictions ---
# Predict the mean response
preds <- predict(AModel, testX,
                 type = "response",
                 showProgress = FALSE)

# Predict the 90% credible interval (from the 0.05 to the 0.95 quantile).
# Arguments are passed by name so the call does not depend on the order of
# predict()'s formal arguments.
# By default, this uses interval = "credible", which only accounts for
# uncertainty in the mean function.
preds_q <- predict(AModel, testX,
                   type = "quantile",
                   quantiles = c(0.05, 0.95),
                   showProgress = FALSE)

# For prediction intervals that also include the model's error variance
# (similar to lm's prediction intervals), use interval = "prediction":
# preds_q_pred <- predict(AModel, testX,
#                         type = "quantile",
#                         quantiles = c(0.05, 0.95),
#                         interval = "prediction",
#                         showProgress = FALSE)

## ----fig.width=7, fig.height=6, fig.align='center'----------------------------
# Plot observed vs. predicted values on a common scale for both axes.
# The limits cover the observations, the predicted means and both interval
# bounds, so no point or error bar is clipped.
axis_lim <- range(testY, preds, preds_q)
plot(testY,
     preds,
     xlab = "Observed Values",
     ylab = "Predicted Mean Values",
     main = "Out-of-Sample Prediction Performance",
     xlim = axis_lim,
     ylim = axis_lim,
     pch = 19, col = "darkblue"
)

# Add error bars for the 90% credible interval: one vertical segment per
# test observation, from its lower (column 1) to its upper (column 2)
# quantile. segments() is vectorised, so all bars are drawn in one call.
segments(testY, preds_q[, 1],
         testY, preds_q[, 2],
         col = rgb(0, 0, 0.5, 0.5), lwd = 1.5
)

# Add horizontal reference lines at the two true underlying means (0 and 10).
# NOTE(review): the x-extents 3 and 7 are hard-coded cut points that appear
# to sit between the two groups of observed values.
lines(c(min(testY) - 0.2, 3), c(0, 0), col = "pink", lwd = 3, lty = 2)
lines(c(7, max(testY) + 0.2), c(10, 10), col = "pink", lwd = 3, lty = 2)

# Add a legend
legend("bottomright",
       legend = c("Predicted Mean & 90% Interval",
                  "True Underlying Mean"),
       col = c("darkblue", "pink"),
       pch = c(19, NA),
       lty = c(1, 2),
       lwd = c(1.5, 3),
       bty = "n"
)

## ----fig.width=7, fig.height=6, fig.align='center'----------------------------
# Use a subset of test data for clearer visualisation: the first 20
# observations, or fewer if the test set is smaller.
subset_indices <- seq_len(min(20, nrow(testX)))
# drop = FALSE keeps a one-row subset as a matrix rather than a vector
testX_subset <- testX[subset_indices, , drop = FALSE]
testY_subset <- testY[subset_indices]

# Get credible intervals (uncertainty in mean only)
cred_intervals <- predict(AModel, testX_subset,
                         type = "quantile",
                         interval = "credible",
                         quantiles = c(0.025, 0.975),
                         showProgress = FALSE)

# Get prediction intervals (includes error variance)
pred_intervals <- predict(AModel, testX_subset,
                         type = "quantile",
                         interval = "prediction",
                         quantiles = c(0.025, 0.975),
                         showProgress = FALSE)

# Get mean predictions
mean_preds <- predict(AModel, testX_subset,
                     type = "response",
                     showProgress = FALSE)

# x-positions of the plotted observations
obs_index <- seq_along(testY_subset)

# Vertical limits: keep at least [-5, 20], which leaves room for the legend
# in the top-right corner, but widen them if any point or interval bound
# would otherwise be clipped.
y_lim <- range(-5, 20, testY_subset, mean_preds,
               cred_intervals, pred_intervals)

# Create comparison plot
plot(obs_index, testY_subset,
     xlab = "Observation Index",
     ylab = "Response Value",
     main = "Credible Intervals vs. Prediction Intervals",
     pch = 19, col = "black", cex = 1.2, ylim = y_lim)

# Add mean predictions
points(obs_index, mean_preds, pch = 4, col = "blue", cex = 1)

# Add credible intervals (narrower, in blue); segments() is vectorised, so
# every interval is drawn in one call.
segments(obs_index, cred_intervals[, 1],
         obs_index, cred_intervals[, 2],
         col = "blue", lwd = 2)

# Add prediction intervals (wider, in red), offset slightly to the left so
# they do not overlap the credible intervals.
segments(obs_index - 0.1, pred_intervals[, 1],
         obs_index - 0.1, pred_intervals[, 2],
         col = "red", lwd = 2, lty = 1)

# Add legend
legend("topright",
       legend = c("Observed Values", "Mean Prediction",
                  "95% Credible Interval", "95% Prediction Interval"),
       col = c("black", "blue", "blue", "red"),
       pch = c(19, 4, NA, NA),
       lty = c(NA, NA, 1, 1),
       lwd = c(NA, NA, 2, 2),
       bty = "n")

## -----------------------------------------------------------------------------
# Mean width of a set of intervals, taking lower bounds from column 1 and
# upper bounds from column 2 of `iv`.
mean_interval_width <- function(iv) {
  mean(iv[, 2] - iv[, 1])
}

# Calculate average interval widths
cred_width <- mean_interval_width(cred_intervals)
pred_width <- mean_interval_width(pred_intervals)

# Report both widths and how many times wider the prediction intervals
# (which include the error variance) are on average.
cat("Average 95% credible interval width:", round(cred_width, 2), "\n")
cat("Average 95% prediction interval width:", round(pred_width, 2), "\n")
cat("Ratio (prediction/credible):", round(pred_width / cred_width, 2), "\n")

