## ----setup, output=FALSE, warning=FALSE, message=FALSE------------------------
library(serosv)
library(dplyr)
library(magrittr)

## ----warning=FALSE------------------------------------------------------------
# Line-listing data: rows ordered by increasing age, with the `seropositive`
# column exposed under the name `status`.
data <- rename(
  parvob19_fi_1997_1998[order(parvob19_fi_1997_1998$age), ],
  status = seropositive
)

# Second representation of the same survey, built by `transform_data()` with
# `age` as the stratum column and `status` as the status column.
aggregated <- data %>%
  transform_data(stratum_col = "age", status_col = "status")

## ----message=FALSE------------------------------------------------------------
# ----- Line-listing data: AIC / BIC for each candidate model -----
aic_bic_out <- suppressWarnings(
  compare_models(
    data = data,
    method = "AIC/BIC",
    griffith = ~ polynomial_model(.x, k = 2),
    penalized_spline = penalized_spline_model,
    farrington = ~ farrington_model(
      .x,
      start = list(alpha = 0.07, beta = 0.1, gamma = 0.03)
    ),
    local_polynomial = lp_model # expect to not return any values
  )
)

# ----- Line-listing data: cross-validation metrics for the same models -----
cv_out <- suppressWarnings(
  compare_models(
    data = data,
    method = "CV",
    griffith = ~ polynomial_model(.x, k = 2),
    penalized_spline = penalized_spline_model,
    farrington = ~ farrington_model(
      .x,
      start = list(alpha = 0.07, beta = 0.1, gamma = 0.03)
    ),
    local_polynomial = lp_model
  )
)

# Show both comparison results.
aic_bic_out
cv_out

## ----message=FALSE------------------------------------------------------------
# ----- Aggregated data: AIC / BIC for each candidate model -----
aic_bic_out <- suppressWarnings(
  compare_models(
    data = aggregated,
    method = "AIC/BIC",
    griffith = ~ polynomial_model(.x, k = 2),
    penalized_spline = penalized_spline_model,
    farrington = ~ farrington_model(
      .x,
      start = list(alpha = 0.07, beta = 0.1, gamma = 0.03)
    ),
    local_polynomial = lp_model # expect to not return any values
  )
)

# ----- Aggregated data: cross-validation metrics (default with 4 folds) -----
cv_out <- suppressWarnings(
  compare_models(
    data = aggregated,
    method = "CV",
    griffith = ~ polynomial_model(.x, k = 2),
    penalized_spline = penalized_spline_model,
    farrington = ~ farrington_model(
      .x,
      start = list(alpha = 0.07, beta = 0.1, gamma = 0.03)
    ),
    local_polynomial = lp_model
  )
)

# Show both comparison results.
aic_bic_out
cv_out

## -----------------------------------------------------------------------------
# Score a model by its mean absolute error (MAE) on a held-out tail of `dat`.
#
# The leading ~80% of the rows of `dat` (in their current order) are used to
# fit the model; the remaining rows are used only for evaluation. The two
# parts never share a row and each contains at least one row.
#
# Args:
#   dat:        data frame. Its first column is passed to `predict()` as the
#               new data (presumably the age column -- confirm against the
#               model functions used). Must also contain `status` when the
#               fitted model reports datatype "linelisting", or `pos` and
#               `tot` otherwise.
#   model_func: function taking a data frame and returning a fitted model
#               that exposes a `datatype` element and supports `predict()`.
#
# Returns:
#   A one-row data.frame with a single column `mae`.
generate_mae <- function(dat, model_func) {
  n_obs <- nrow(dat)
  if (n_obs < 2) {
    stop("`dat` needs at least 2 rows to build a train/test split.",
         call. = FALSE)
  }

  # Clamp so that both the training and the test part are non-empty.
  n_train <- min(max(round(n_obs * 0.8), 1), n_obs - 1)

  # Logical masks avoid the `1:0` pitfall and guarantee disjoint parts.
  row_id <- seq_len(n_obs)
  train <- dat[row_id <= n_train, , drop = FALSE]
  test <- dat[row_id > n_train, , drop = FALSE]

  # Fit on the training rows only, so the test rows are truly held out.
  fit <- model_func(train)
  pred <- predict(fit, test[, 1, drop = FALSE])

  # The observed quantity depends on the datatype reported by the fit:
  # the individual status for line-listing data, the positive share otherwise.
  observed <- if (fit$datatype == "linelisting") {
    test$status
  } else {
    test$pos / test$tot
  }

  # NA errors are dropped from the sum but still counted in the denominator.
  data.frame(
    mae = sum(abs(observed - pred), na.rm = TRUE) / nrow(test)
  )
}

## -----------------------------------------------------------------------------
# Custom metric on the aggregated data: `generate_mae` is supplied as `method`.
suppressWarnings(
  compare_models(
    data = aggregated,
    method = generate_mae,
    griffith = ~ polynomial_model(.x, k = 2),
    penalized_spline = penalized_spline_model,
    farrington = ~ farrington_model(
      .x,
      start = list(alpha = 0.07, beta = 0.1, gamma = 0.03)
    ),
    local_polynomial = lp_model
  )
)

# Same custom metric on the line-listing data.
suppressWarnings(
  compare_models(
    data = data,
    method = generate_mae,
    griffith = ~ polynomial_model(.x, k = 2),
    penalized_spline = penalized_spline_model,
    farrington = ~ farrington_model(
      .x,
      start = list(alpha = 0.07, beta = 0.1, gamma = 0.03)
    ),
    local_polynomial = lp_model
  )
)

