## ----include = FALSE----------------------------------------------------------
# Global chunk options: hide package messages, fold code and output together,
# and prefix printed output with "#>"
knitr::opts_chunk$set(
  message  = FALSE,
  comment  = "#>",
  collapse = TRUE
)

## ----setup, warning = FALSE, message = FALSE----------------------------------
library(baserater)
library(tidyverse)
library(knitr)

## -----------------------------------------------------------------------------
# Retrieve the reference datasets used throughout this vignette.

# Group-adjective pairs
material <- download_data("material")

# Base-rate database
database <- download_data("database")

# Human typicality ratings, used as ground truth for validation
ratings <- download_data("validation_ratings")

# Typicality matrices produced by two language models
gpt4_matrix     <- download_data("typicality_matrix_gpt4")
llama3_3_matrix <- download_data("typicality_matrix_llama3.3")

## ----message=FALSE, warning = FALSE, echo = FALSE-----------------------------
# Load pre-generated scores shipped with the package; the API call shown in
# the next chunk is not evaluated when the vignette is built.
# `mustWork = TRUE` makes a missing file fail with a clear error instead of
# system.file() returning "" and readRDS("") failing obscurely.
new_scores <- readRDS(
  system.file(
    "extdata", "new_typicality_scores_llama3.1_8B.rds",
    package = "baserater", mustWork = TRUE
  )
)

## ----eval = FALSE-------------------------------------------------------------
# # Original prompt from the paper
# original_system_prompt_content <- "You are expert at accurately reproducing the stereotypical associations humans make, in order to annotate data for experiments. Your focus is to capture common societal perceptions and stereotypes, rather than factual attributes of the groups, even when they are negative or unfounded."
# 
# original_user_prompt_content_template <- "Rate how well the adjective '{description}' reflects the prototypical member of the group '{group}' on a scale from 0 ('Not at all') to 100 ('Extremely').
# 
# To clarify, consider the following examples:
# 
# 1. 'Rate how well the adjective FUNNY reflects the prototypical member of the group CLOWN on a scale from 0 (Not at all) to 100 (Extremely).' A high rating is expected because the adjective 'FUNNY' closely aligns with the typical characteristics of a 'CLOWN'.
# 
# 2. 'Rate how well the adjective FEARFUL reflects the prototypical member of the group FIREFIGHTER on a scale from 0 (Not at all) to 100 (Extremely).' A low rating is expected because the adjective 'FEARFUL' diverges significantly from the typical characteristics of a 'FIREFIGHTER'.
# 
# 3. 'Rate how well the adjective PATIENT reflects the prototypical member of the group ENGINEER on a scale from 0 (Not at all) to 100 (Extremely).' A mid-scale rating is expected because the adjective 'PATIENT' neither closely aligns nor diverges significantly from the typical characteristics of an 'ENGINEER'.
# 
# Your response should be a single score between 0 and 100, with no additional text, letters, or symbols included."
# 
# 
# # Example using the validation ratings
# groups <- ratings$group
# descriptions <- ratings$adjective
# 
# api_token <- Sys.getenv("PROVIDER_API_TOKEN")
# 
# new_scores <- generate_typicality(
#   groups                = groups,
#   descriptions          = descriptions,
#   api_url               = "https://api.together.xyz/v1/chat/completions", # example for 'Together AI' API
#   api_token             =  api_token,
#   model                 = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", #  model name on Together AI
#   n                     = 3,  # number of responses to generate
#   min_valid             = 2,  # minimum number of valid responses; mean of valid ones is used
#   max_tokens            = 3,  # numeric output between 0 and 100
#   retries               = 2,  # number of retries in case of API errors
#   matrix                = FALSE,
#   return_raw_scores     = TRUE,
#   return_full_responses = TRUE,
#   verbose               = TRUE)

## ----warning = FALSE----------------------------------------------------------
new_scores %>% head() %>% knitr::kable()

## -----------------------------------------------------------------------------
# Histogram of the model's mean typicality scores, in bins 5 points wide
ggplot(data = new_scores, mapping = aes(mean_score)) +
  geom_histogram(fill = "steelblue", color = "white", binwidth = 5) +
  labs(
    x     = "Typicality Rating",
    y     = "Count",
    title = "Distribution of Typicality Ratings from 'LLaMA 3.1-8B-Instruct'"
  ) +
  theme_classic()

## ----warning = FALSE----------------------------------------------------------
# Reshape to the validation-set layout (group, adjective, rating) by renaming
# `description` -> `adjective` and `mean_score` -> `rating` in one step
new_scores <- new_scores %>%
  select(group, adjective = description, rating = mean_score)

knitr::kable(head(new_scores))

## -----------------------------------------------------------------------------
# Attach each model rating to the matching human rating; every validation
# pair is kept, so pairs without a model rating get NA
comparison_df <- ratings %>%
  select(group, adjective, human = mean_human_rating) %>%
  left_join(new_scores, by = c("group", "adjective"))

# Model ratings (x) against human ratings (y), with a linear fit line
ggplot(comparison_df) +
  aes(x = rating, y = human) +
  geom_point(alpha = 0.6) +
  geom_smooth(method = "lm", se = FALSE, color = "darkred") +
  labs(
    x     = "Average 'LLaMA 3.1' Rating",
    y     = "Average Human Rating",
    title = "Scatterplot of 'LLaMA 3.1' and Human Typicality Ratings"
  ) +
  theme_classic()

## -----------------------------------------------------------------------------
# Correlation summary against human ground truth and baselines. It is computed
# once and stored in `results`, which is then both displayed and available
# for later use.
results <- evaluate_external_ratings(new_scores)
knitr::kable(results)

## ----warning = FALSE----------------------------------------------------------
# The 'GPT-4' typicality matrix (already downloaded at the top of this script)
# has a `group` column plus one typicality-score column per adjective.
# Preview its first rows.
knitr::kable(head(gpt4_matrix))

## -----------------------------------------------------------------------------
# Derive base-rate items from the GPT-4 typicality matrix
base_rate_items <- gpt4_matrix %>% extract_base_rate_items()

## ----warning = FALSE----------------------------------------------------------
# Show the ten base-rate items with the highest stereotype strength
base_rate_items %>%
  arrange(desc(StereotypeStrength)) %>%
  head(10) %>%
  knitr::kable()

## -----------------------------------------------------------------------------
# Stereotype-strength heatmap for one adjective. For every pair of groups where
# g1 is more typical than g2, plot log(typicality(g1) / typicality(g2)).
# Change `target_adjective` to visualise another column of `gpt4_matrix`.
target_adjective <- "selfish"

# Floor applied to scores before dividing, so a zero score cannot produce a
# division by zero or log(0)
eps <- 1e-9

# Pick the adjective's column and sort groups from most to least typical
df <- gpt4_matrix %>%
  select(group, score = all_of(target_adjective)) %>%
  arrange(desc(score))

# Group order (most to least typical) and a group-name -> score lookup
group_order <- df$group
typ_values  <- setNames(df$score, df$group)

# Build all group pairs and compute log-ratios. The typ1/typ2 lookups run
# before g1/g2 are converted to factors ordered by typicality.
res_df <- expand.grid(
  g1 = group_order,
  g2 = group_order,
  KEEP.OUT.ATTRS = FALSE,
  stringsAsFactors = FALSE
) %>%
  mutate(
    typ1      = unname(typ_values[as.character(g1)]),
    typ2      = unname(typ_values[as.character(g2)]),
    log_ratio = log(pmax(typ1, eps) / pmax(typ2, eps)),
    g1        = factor(g1, levels = group_order),
    g2        = factor(g2, levels = group_order)
  )

# Keep only the upper triangle (g1 ranked above g2) with a strictly positive
# log-ratio; ties between groups are therefore dropped
res_df <- res_df %>%
  filter(as.integer(g1) < as.integer(g2), log_ratio > 0)

# Add identity pairs (g1 == g2) with NA so the diagonal is drawn in grey
diag_df <- tibble(
  g1 = factor(group_order, levels = group_order),
  g2 = factor(group_order, levels = group_order),
  log_ratio = NA_real_
)

# Combine with the filtered upper-triangle pairs
res_df <- bind_rows(res_df, diag_df)

# Plot the heatmap
ggplot(res_df, aes(x = g2, y = g1, fill = log_ratio)) +
  geom_tile(color = "white") +
  scale_fill_gradient2(
    low       = "steelblue",
    mid       = "white",
    high      = "firebrick",
    midpoint  = 0,
    na.value  = "grey90",
    name      = "Log(Group 1 / Group 2)"
  ) +
  labs(
    title = paste0(
      "Stereotype Strength for Adjective '",
      stringr::str_to_title(target_adjective), "'"
    ),
    x = "Group 2",
    y = "Group 1"
  ) +
  theme_classic() +
  theme(
    axis.text.y = element_text(angle = 15, hjust = 1, size = 5),
    axis.text.x = element_text(angle = 45, hjust = 1, size = 5),
    panel.grid  = element_blank()
  )

