## ----setup2, message = FALSE, warning = FALSE, results = 'hide'---------------
# Rendering options for every chunk: show source and output in one block
# (collapse) and prefix printed output lines with "#>" (comment).
knitr::opts_chunk$set(comment = "#>", collapse = TRUE)
# Load packages
library(baselinenowcast)
library(dplyr)
library(tidyr)
library(stringr)
library(lubridate)
library(ggplot2)
library(purrr)

## -----------------------------------------------------------------------------
# Display the raw synthetic NSSP line list (not defined in this script;
# presumably a dataset provided by baselinenowcast).
syn_nssp_line_list

## -----------------------------------------------------------------------------
# Diagnosis codes that define a BAR visit (they appear to be ICD-10 and ICD-9
# codes, written with and without the decimal point).
# NOTE: these are later passed to stringr::str_detect() as unanchored regular
# expressions, so "." matches any character and each code also matches any
# longer code string that contains it. unique() removes the entries that are
# repeated in the literal list (e.g. "461", "461.", "033"); the set of codes,
# and therefore which visits match, is unchanged.
diagnoses_codes_defn <- unique(c("A22.1", "A221", "A37", "A48.1", "A481", "B25.0", "B250", "B34.2", "B34.9", "B342", "B349", "B44.0", "B44.9", "B440", "B449", "B44.81", "B4481", "B97.2", "B97.4", "B972", "B974", "J00", "J01", "J02", "J03", "J04", "J05", "J06", "J09", "J10", "J11", "J12", "J13", "J14", "J15", "J16", "J17", "J18", "J20", "J21", "J22", "J39.8", "J398", "J40", "J47.9", "J479", "J80", "J85.1", "J851", "J95.821", "J95821", "J96.0", "J96.00", "J9600", "J96.01", "J9601", "J96.02", "J9602", "J96.2", "J960", "J962", "J96.20", "J9620", "J96.21", "J9621", "J9622", "J96.22", "J96.91", "J9691", "J98.8", "J988", "R05", "R06.03", "R0603", "R09.02", "R0902", "R09.2", "R092", "R43.0", "R43.1", "R43.2", "R430", "R431", "R432", "U07.1", "U07.2", "U071", "U072", "022.1", "0221", "034.0", "0340", "041.5", "0415", "041.81", "04181", "079.1", "079.2", "079.3", "079.6", "0791", "0792", "0793", "0796", "079.82", "079.89", "07982", "07989", "079.99", "07999", "117.3", "1173", "460", "461", "462", "463", "464", "465", "466", "461.", "461", "461.", "464.", "465.", "466.", "461", "464", "465", "466", "478.9", "4789", "480.", "482.", "483.", "484.", "487.", "488.", "480", "481", "482", "483", "484", "485", "486", "487", "488", "490", "494.1", "4941", "517.1", "5171", "518.51", "518.53", "51851", "51853", "518.6", "5186", "518.81", "518.82", "518.84", "51881", "51882", "51884", "519.8", "5198", "073.0", "0730", "781.1", "7811", "786.2", "7862", "799.02", "79902", "799.1", "7991", "033", "033.", "033", "780.60", "78060")) # nolint

## -----------------------------------------------------------------------------
# Split a "{"-delimited event column into one column per event.
#
# Uses tidyr::separate_wider_delim(): the new columns are named by appending
# 1, 2, ... directly to the original column name (names_sep = ""), and rows
# with fewer pieces are filled with NA on the right (too_few = "align_start").
#
# @param line_list Data frame holding the event column.
# @param event_col_name Column to split; forwarded with {{ }} to the `cols`
#   selection (the calls in this script pass a string).
# @return `line_list` with the event column split into numbered columns.
expand_events <- function(line_list, event_col_name) {
  line_list |>
    separate_wider_delim(
      cols = {{ event_col_name }},
      delim = "{",
      names_sep = "",
      too_few = "align_start"
    )
}

## -----------------------------------------------------------------------------
# Expand the update time stamps and the diagnosis updates into wide format as
# two separate tables; each table drops the other event column so it only
# carries the events it describes.
syn_nssp_time_stamps_wide <- syn_nssp_line_list |>
  expand_events(event_col_name = "DischargeDiagnosisMDTUpdates") |>
  select(-DischargeDiagnosisUpdates)

syn_nssp_diagnoses_wide <- syn_nssp_line_list |>
  expand_events(event_col_name = "DischargeDiagnosisUpdates") |>
  select(-DischargeDiagnosisMDTUpdates)

## -----------------------------------------------------------------------------
# Pivot the numbered event columns produced by expand_events() to long format.
#
# @param wide_line_list Wide data frame from expand_events().
# @param event_col_name Prefix shared by the numbered event columns (string).
# @param values_to Name of the output column that holds the event values.
# @param names_to Name of the output column that holds the source column names.
# @param id_col_name Name of the record identifier column (string).
# @return A long data frame with one row per record and event column. Missing
#   events are kept as NA. An `event_id` column pastes the record id to the
#   event number parsed from the column name (e.g. "<id> 1"), so events from
#   two tables pivoted this way can be matched by position.
wide_to_long <- function(wide_line_list,
                         event_col_name,
                         values_to,
                         names_to,
                         id_col_name) {
  wide_line_list |>
    pivot_longer(
      cols = starts_with(event_col_name),
      names_to = names_to,
      values_to = values_to,
      values_drop_na = FALSE
    ) |>
    mutate(
      event_id = paste(
        .data[[id_col_name]],
        .data[[names_to]] |>
          as.character() |>
          str_extract("[0-9.]+") |>
          as.numeric()
      )
    )
}

## -----------------------------------------------------------------------------
# Long format: one row per record (C_Processed_BioSense_ID) and event number.
# Both tables get `event_id` values of the form "<id> <n>", so the n-th time
# stamp can later be paired with the n-th diagnosis update.
syn_nssp_time_stamps_long <- syn_nssp_time_stamps_wide |>
  wide_to_long(
    event_col_name = "DischargeDiagnosisMDTUpdates",
    values_to = "time_stamp",
    names_to = "column_name",
    id_col_name = "C_Processed_BioSense_ID"
  )

syn_nssp_diagnoses_long <- syn_nssp_diagnoses_wide |>
  wide_to_long(
    event_col_name = "DischargeDiagnosisUpdates",
    values_to = "diagnoses_codes",
    names_to = "column_name",
    id_col_name = "C_Processed_BioSense_ID"
  )

## -----------------------------------------------------------------------------
# Parse the update time stamps and the visit date-times as UTC date-times.
#
# For each raw time-stamp piece, everything up to the last "}" and any "|"/";"
# separator characters are stripped before parsing. Both columns are parsed
# with tz = "UTC". as.POSIXct() without `tz` reads character input in the
# session's local time zone, which made the visit-to-update delays and the
# derived calendar dates depend on the machine running the script.
# NOTE(review): assumes C_Visit_Date_Time is recorded on the same clock as the
# update time stamps; confirm against the NSSP data documentation.
syn_nssp_time_stamps <-
  syn_nssp_time_stamps_long |>
  mutate(
    time_stamp = as.POSIXct(
      str_remove_all(
        str_remove(time_stamp, ".*\\}"),
        "[|;]+"
      ),
      format = "%Y-%m-%d %H:%M:%S",
      tz = "UTC"
    ),
    C_Visit_Date_Time = as.POSIXct(C_Visit_Date_Time, tz = "UTC")
  ) |>
  # Missing or unparseable pieces become NA; drop those events.
  drop_na(time_stamp)

## -----------------------------------------------------------------------------
# Extract the diagnosis-code string for each event, keyed by event_id.
#
# Everything up to the last "}" is removed, and events with an empty or
# missing code string are dropped. Only `diagnoses_codes` is checked for NA,
# which matches drop_na(time_stamp) on the time-stamp side. A bare drop_na()
# would also discard events whose unrelated line-list columns contain NA.
syn_nssp_diagnoses <-
  syn_nssp_diagnoses_long |>
  mutate(diagnoses_codes = str_remove(diagnoses_codes, ".*\\}")) |>
  # nzchar(NA) is TRUE, so missing values are removed explicitly.
  filter(!is.na(diagnoses_codes), nzchar(diagnoses_codes)) |>
  select(event_id, diagnoses_codes)

## -----------------------------------------------------------------------------
# Pair each update time stamp with the diagnosis update that has the same
# event_id. Base merge() performs an inner join here, sorted by event_id.
# Updates whose code string is exactly ";;|" (presumably an update with no
# codes; confirm) are then discarded.
nssp_merged <- syn_nssp_time_stamps |>
  merge(y = syn_nssp_diagnoses, by = "event_id") |>
  filter(diagnoses_codes != ";;|")

## -----------------------------------------------------------------------------
# Delay, in fractional days, from the ED visit to each diagnosis update.
nssp_updates <- mutate(
  nssp_merged,
  arrival_to_update_delay = difftime(
    time_stamp,
    C_Visit_Date_Time,
    units = "days"
  ) |>
    as.numeric()
)

## -----------------------------------------------------------------------------
# Keep only the updates whose diagnosis string contains at least one code from
# diagnoses_codes_defn.
#
# All codes are joined into a single regex alternation, so one vectorised
# str_detect() call tests every row. This replaces a per-row map_lgl() that
# ran one regex per code. Matching semantics are unchanged: each code is still
# an unanchored regular expression ("." matches any character, and "J12" also
# matches "J120"), and a row is kept if any code matches. Rows with an NA code
# string give NA and are dropped by filter(), as before.
bar_updates <- nssp_updates |>
  filter(
    str_detect(diagnoses_codes, paste(diagnoses_codes_defn, collapse = "|"))
  )

## -----------------------------------------------------------------------------
# For each record, keep the update with the smallest visit-to-update delay
# among those carrying a BAR code. Rows are sorted by delay before grouping,
# so the first row of each group is the earliest (NA delays sort last).
# The result is still grouped by C_Processed_BioSense_ID.
first_bar_diagnosis <- bar_updates |>
  arrange(arrival_to_update_delay) |>
  group_by(C_Processed_BioSense_ID) |>
  slice_head(n = 1)

## -----------------------------------------------------------------------------
# One ungrouped row per record with calendar dates added: the visit date
# (reference_date) and the date of the first BAR-coded update (report_date).
clean_line_list <- first_bar_diagnosis |>
  ungroup() |>
  mutate(
    reference_date = as.Date(C_Visit_Date_Time),
    report_date = as.Date(time_stamp)
  )
head(clean_line_list)

## -----------------------------------------------------------------------------
# Number of records per (visit date, report date) pair, plus the reporting
# delay in days. `.groups = "drop"` returns an ungrouped table; the default
# would leave the result grouped by reference_date and print a message.
count_df_raw <- clean_line_list |>
  group_by(reference_date, report_date) |>
  summarise(count = n(), .groups = "drop") |>
  mutate(delay = as.integer(report_date - reference_date))

## -----------------------------------------------------------------------------
# Discard counts reported on a date before the visit date (negative delays).
count_df <- count_df_raw |>
  filter(delay >= 0)
head(count_df)

## -----------------------------------------------------------------------------
# Display the pre-aggregated counts (reference_date, report_date, count) used
# for the rest of the analysis. Not defined in this script; presumably a
# dataset provided by baselinenowcast.
syn_nssp_df

## -----------------------------------------------------------------------------
# Add the reporting delay (days from reference date to report date).
long_df <- mutate(
  syn_nssp_df,
  delay = as.integer(report_date - reference_date)
)

# Count-weighted mean delay for each reference date.
delay_df_t <- long_df |>
  group_by(reference_date) |>
  summarise(mean_delay = sum(count * delay) / sum(count))

# Count-weighted mean delay over the whole table, stored on every row of
# long_df (assuming long_df is ungrouped) so it can be drawn as a flat
# reference line.
delay_summary <- mutate(
  long_df,
  mean_delay_overall = sum(count * delay) / sum(count)
)

# Empirical delay distribution: the share of all counts reported at each delay
# (pmf) and its running total over increasing delay (cdf).
avg_delays <- long_df |>
  group_by(delay) |>
  summarise(pmf = sum(count) / sum(long_df[["count"]])) |>
  mutate(cdf = cumsum(pmf))

# Mean delay by reference date (solid line) against the overall mean delay
# (dashed line).
delay_t <- ggplot(delay_df_t) +
  geom_line(aes(x = reference_date, y = mean_delay)) +
  geom_line(
    aes(x = reference_date, y = mean_delay_overall),
    data = delay_summary,
    linetype = "dashed"
  ) +
  labs(x = "", y = "Mean delay") +
  theme_bw()

# Empirical CDF of the reporting delay, with a dashed reference line at 0.95.
# yintercept is passed as a fixed parameter rather than mapped in aes(). A
# constant mapped in aes() inherits the plot data and draws one identical
# line per row of avg_delays, and recent ggplot2 versions warn about this.
cdf_delay <- ggplot(avg_delays) +
  geom_line(aes(x = delay, y = cdf)) +
  geom_hline(yintercept = 0.95, linetype = "dashed") +
  theme_bw()

## -----------------------------------------------------------------------------
# Display the delay plots: the cumulative delay distribution, then the mean
# delay over time.
cdf_delay
delay_t

## -----------------------------------------------------------------------------
# Longest delay (days) kept when truncating the reporting triangle and when
# building the "final" counts below.
max_delay <- 25
# The nowcast is made as of 30 days before the last reference date in the data.
nowcast_date <- max(long_df[["reference_date"]]) - days(30)

## -----------------------------------------------------------------------------
# Training data: only reports available on or before the nowcast date.
training_df <- long_df |>
  filter(report_date <= nowcast_date)

## -----------------------------------------------------------------------------
# Total count per reference date as known on the nowcast date (the "initially
# reported" counts). training_df is already defined as long_df filtered to
# report_date <= nowcast_date, so it is summed directly; the repeated filter
# here was a no-op.
training_df_by_ref_date <- training_df |>
  group_by(reference_date) |>
  summarise(initial_count = sum(count))

## -----------------------------------------------------------------------------
# Initially reported counts for reference dates from 60 days before the
# nowcast date onward, drawn as a dark red line.
init_data <- filter(
  training_df_by_ref_date,
  reference_date >= nowcast_date - days(60)
)

plot_inits <- ggplot(init_data) +
  geom_line(
    aes(x = reference_date, y = initial_count),
    color = "darkred"
  ) +
  theme_bw() +
  labs(x = "Date of ED visit", y = "Initially reported BAR cases")

## -----------------------------------------------------------------------------
# Display the initially reported counts.
plot_inits

## -----------------------------------------------------------------------------
# Reporting triangle built from the training data, covering all delays.
rep_tri_full <- training_df |>
  as_reporting_triangle()

## -----------------------------------------------------------------------------
rep_tri_full

## -----------------------------------------------------------------------------
summary(rep_tri_full)

## -----------------------------------------------------------------------------
# Restrict the triangle to delays up to max_delay.
rep_tri <- rep_tri_full |>
  truncate_to_delay(max_delay = max_delay)

## -----------------------------------------------------------------------------
rep_tri

## -----------------------------------------------------------------------------
# Tuning settings forwarded to baselinenowcast(); their exact meaning is
# defined by that package.
scale_factor <- 3
prop_delay <- 0.5

## -----------------------------------------------------------------------------
# Generate 1000 draws from the probabilistic nowcast.
nowcast_draws_df <- rep_tri |>
  baselinenowcast(
    scale_factor = scale_factor,
    prop_delay = prop_delay,
    draws = 1000
  )

head(nowcast_draws_df)

## -----------------------------------------------------------------------------
# Summarise the draws for each reference date: the median, plus the bounds of
# the 50% (25th-75th percentile) and 95% (2.5th-97.5th percentile) intervals.
nowcast_summary_df <- nowcast_draws_df |>
  group_by(reference_date) |>
  summarise(
    median = median(pred_count),
    q50th_lb = quantile(pred_count, probs = 0.25),
    q50th_ub = quantile(pred_count, probs = 0.75),
    q95th_lb = quantile(pred_count, probs = 0.025),
    q95th_ub = quantile(pred_count, probs = 0.975)
  )

## -----------------------------------------------------------------------------
# "Final" counts per reference date up to the nowcast date. These use every
# report in long_df with a delay of at most max_delay, the same delay range
# the truncated reporting triangle covers.
eval_data <- long_df |>
  filter(reference_date <= nowcast_date) |>
  filter(delay <= max_delay) |>
  group_by(reference_date) |>
  summarise(final_count = sum(count))

## -----------------------------------------------------------------------------
# Attach the initially reported and the final counts to the nowcast summary,
# keeping every nowcast reference date.
nowcast_w_data <- nowcast_summary_df |>
  left_join(training_df_by_ref_date, by = "reference_date") |>
  left_join(eval_data, by = "reference_date")
head(nowcast_w_data)

## -----------------------------------------------------------------------------
# Observed counts in long format (one row per reference date and type) for
# reference dates from 60 days before the nowcast date onward, labelled for
# the plot legend.
combined_data <- nowcast_w_data |>
  filter(reference_date >= nowcast_date - days(60)) |>
  select(reference_date, initial_count, final_count) |>
  distinct() |>
  pivot_longer(
    cols = c(initial_count, final_count),
    names_to = "type",
    values_to = "count"
  ) |>
  mutate(
    type = case_when(
      type == "initial_count" ~ "Initially observed data",
      type == "final_count" ~ "Final observed data"
    )
  )

# Nowcast summaries for the same 60-day window.
nowcast_data_recent <- nowcast_w_data |>
  filter(reference_date >= nowcast_date - days(60))

# Median nowcast (gray line) with gray 50% and 95% interval ribbons, overlaid
# with the initially observed (dark red) and final (black) counts.
plot_prob_nowcast <- ggplot(nowcast_data_recent) +
  geom_line(aes(x = reference_date, y = median), color = "gray") +
  geom_ribbon(
    aes(x = reference_date, ymin = q50th_lb, ymax = q50th_ub),
    alpha = 0.5,
    fill = "gray"
  ) +
  geom_ribbon(
    aes(x = reference_date, ymin = q95th_lb, ymax = q95th_ub),
    alpha = 0.5,
    fill = "gray"
  ) +
  # Both observed series come from combined_data in a single layer, coloured
  # by type.
  geom_line(
    aes(x = reference_date, y = count, color = type),
    data = combined_data
  ) +
  scale_color_manual(
    name = "",
    values = c(
      "Initially observed data" = "darkred",
      "Final observed data" = "black"
    )
  ) +
  theme_bw() +
  theme(legend.position = "bottom") +
  labs(
    x = "Date of ED visit",
    y = "Number of BAR cases",
    title = "Comparison of cases of BAR as of the nowcast date, later observed,\n and generated as a probabilistic nowcast" # nolint
  )

## -----------------------------------------------------------------------------
# Display the comparison of the probabilistic nowcast with the observed data.
plot_prob_nowcast

