## ----setup, include=FALSE-----------------------------------------------------
# Set the knitr chunk options `comment` and `collapse` for the chunks below.
knitr::opts_chunk$set(comment = "#>", collapse = TRUE)

# 'bigmemory' is loaded right after this guard; when its namespace cannot be
# loaded, print a notice and call knitr::knit_exit() instead.
if (isFALSE(requireNamespace("bigmemory", quietly = TRUE))) {
  cat("This vignette requires the 'bigmemory' package.\n")
  knitr::knit_exit()
}

library(bigKNN)
library(bigmemory)

## ----helpers, include=FALSE---------------------------------------------------
# Flatten a k-nearest-neighbour result into a long data frame with one row per
# (query, rank) pair.
#
# Arguments:
#   result    Object providing `k` plus `index` and `distance` matrices that
#             are indexed as [query, rank].
#   query_ids Labels for the queries; row i of the matrices is reported under
#             query_ids[i].
#   ref_ids   Labels for the reference rows; the values in `result$index` are
#             used to subscript this vector.
#
# Returns a data frame with columns `query`, `rank`, `neighbor` and `distance`
# (distances rounded to 5 significant digits). An empty `query_ids` yields a
# zero-row data frame with those columns rather than NULL.
knn_table <- function(result, query_ids, ref_ids) {
  n_query <- length(query_ids)
  k <- result$k
  query_rows <- seq_len(n_query)

  # `drop = FALSE` keeps the matrix shape for a single query or k = 1.
  # Transposing before flattening walks the values query by query and, within
  # each query, rank by rank, which matches the row order built below.
  neighbor_index <- as.vector(t(result$index[query_rows, , drop = FALSE]))
  neighbor_distance <- as.vector(t(result$distance[query_rows, , drop = FALSE]))

  data.frame(
    query = rep(query_ids, each = k),
    rank = rep(seq_len(k), times = n_query),
    neighbor = ref_ids[neighbor_index],
    distance = signif(neighbor_distance, 5),
    row.names = NULL
  )
}

## ----create-reference---------------------------------------------------------
# Working directory under the session temp dir for the backing files.
scratch_dir <- file.path(tempdir(), "bigknn-prepared-search")
dir.create(scratch_dir, showWarnings = FALSE, recursive = TRUE)

# Eight labelled reference points with three numeric coordinates each.
reference_points <- data.frame(
  id = paste0("r", seq_len(8)),
  x1 = rep(c(1, 2, 3, 4), each = 2),
  x2 = c(1, 2, 1, 2, 2, 3, 3, 4),
  x3 = c(0.5, 0.5, 1.0, 1.0, 1.5, 1.5, 2.0, 2.5)
)

# File-backed double matrix with one row per reference point and one column
# per coordinate; backing and descriptor files go into `scratch_dir`.
reference <- filebacked.big.matrix(
  backingpath = scratch_dir,
  backingfile = "reference.bin",
  descriptorfile = "reference.desc",
  type = "double",
  nrow = nrow(reference_points),
  ncol = 3
)

# Copy the coordinate columns (everything except `id`) into the big.matrix.
reference[, ] <- as.matrix(reference_points[, c("x1", "x2", "x3")])

# Two query batches, each holding two 3-column query rows.
query_batch_a <- rbind(
  c(1.1, 1.2, 0.5),
  c(2.7, 2.2, 1.4)
)

query_batch_b <- rbind(
  c(3.6, 3.1, 1.9),
  c(1.5, 1.8, 0.8)
)

# Labels for the rows of each batch.
query_ids_a <- paste0("a", 1:2)
query_ids_b <- paste0("b", 1:2)

reference_points

## ----prepare-reference--------------------------------------------------------
# Call knn_prepare_bigmatrix() on the reference matrix with the cosine metric
# and auto-print the object it returns.
prepared <- knn_prepare_bigmatrix(
  reference,
  metric = "cosine"
)
prepared

## ----prepared-summary---------------------------------------------------------
# Summarise the prepared object, then show the length and the first four
# entries of its `row_cache` element.
summary(prepared)
length(prepared$row_cache)
head(prepared$row_cache, n = 4)

## ----prepared-search----------------------------------------------------------
# Search the same prepared object with both query batches, asking for two
# neighbours per query and with `exclude_self` turned off.
batch_a_result <- knn_search_prepared(
  prepared,
  k = 2,
  exclude_self = FALSE,
  query = query_batch_a
)

batch_b_result <- knn_search_prepared(
  prepared,
  k = 2,
  exclude_self = FALSE,
  query = query_batch_b
)

# Print the raw result for batch A, then both batches as labelled tables.
batch_a_result
knn_table(
  result = batch_a_result,
  ref_ids = reference_points[["id"]],
  query_ids = query_ids_a
)
knn_table(
  result = batch_b_result,
  ref_ids = reference_points[["id"]],
  query_ids = query_ids_b
)

## ----prepared-vs-direct-------------------------------------------------------
# Run batch A through knn_bigmatrix() on the reference matrix directly, using
# the same metric, k and exclude_self settings as the prepared search.
direct_batch_a <- knn_bigmatrix(
  reference,
  metric = "cosine",
  k = 2,
  exclude_self = FALSE,
  query = query_batch_a
)

# Compare the prepared and direct results: indices exactly, distances with
# all.equal()'s numeric tolerance.
identical(
  batch_a_result$index,
  direct_batch_a$index
)
all.equal(
  batch_a_result$distance,
  direct_batch_a$distance
)

## ----prepared-stream----------------------------------------------------------
# Destination big.matrix objects: one row per query in batch B and two
# columns, matching k = 2 in the call below.
index_store <- big.matrix(
  nrow = nrow(query_batch_b),
  ncol = 2,
  type = "integer"
)
distance_store <- big.matrix(
  nrow = nrow(query_batch_b),
  ncol = 2,
  type = "double"
)

# Streamed search for batch B, handing the two stores in as xpIndex and
# xpDistance.
streamed_batch_b <- knn_search_stream_prepared(
  prepared,
  query = query_batch_b,
  k = 2,
  exclude_self = FALSE,
  xpIndex = index_store,
  xpDistance = distance_store
)

# Convert the returned index and distance to ordinary matrices for display,
# and compare the distances with the earlier in-memory result for batch B.
bigmemory::as.matrix(streamed_batch_b$index)
round(bigmemory::as.matrix(streamed_batch_b$distance), digits = 6)
all.equal(
  bigmemory::as.matrix(streamed_batch_b$distance),
  batch_b_result$distance
)

## ----persist-prepared---------------------------------------------------------
# Location of the cache file inside the scratch directory.
cache_path <- file.path(scratch_dir, "prepared-cosine-cache.rds")

# Prepare the reference again, this time passing `cache_path`, then print the
# result and check whether a file now exists at that path.
prepared_cached <- knn_prepare_bigmatrix(
  reference,
  cache_path = cache_path,
  metric = "cosine"
)

prepared_cached
file.exists(cache_path)

## ----load-prepared------------------------------------------------------------
# Call knn_load_prepared() on the same path and print what it returns.
loaded <- knn_load_prepared(cache_path)
loaded

## ----validate-prepared--------------------------------------------------------
# TRUE only when knn_validate_prepared() returns a single TRUE.
isTRUE(knn_validate_prepared(loaded))

## ----loaded-search------------------------------------------------------------
# Repeat the batch B search against the object loaded from the cache file.
loaded_batch_b <- knn_search_prepared(
  loaded,
  k = 2,
  exclude_self = FALSE,
  query = query_batch_b
)

# Compare with the batch B result from the original prepared object: indices
# exactly, distances with all.equal()'s numeric tolerance.
identical(
  loaded_batch_b$index,
  batch_b_result$index
)
all.equal(
  loaded_batch_b$distance,
  batch_b_result$distance
)

