---
title: "Dataset API"
output: 
  rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Dataset API}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
type: docs
repo: https://github.com/rstudio/tfestimators
menu:
  main:
    name: "Dataset API"
    identifier: "tfestimators-dataset-api"
    parent: "tfestimators-advanced"
    weight: 40
---


```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
knitr::opts_chunk$set(eval = FALSE)
```


## Overview

We can access the TensorFlow Dataset API via the [tfdatasets](https://tensorflow.rstudio.com/tools/tfdatasets/) package, which enables us to create scalable input pipelines that can be used with **tfestimators**. In this vignette, we demonstrate the capability to stream datasets stored on disk for training by building a classifier on the `iris` dataset.

## Dataset Preparation

Let's assume we're given a dataset (which could be arbitrarily large) that is already split into training and validation files, along with a small sample of the data. To simulate this scenario, we'll create a few CSV files as follows:

```{r}
set.seed(123)
train_idx <- sample(nrow(iris), nrow(iris) * 2/3)

iris_train <- iris[train_idx,]
iris_validation <- iris[-train_idx,]
# Use base R here: no package providing the pipe has been loaded yet
iris_sample <- head(iris_train, 10)

write.csv(iris_train, "iris_train.csv", row.names = FALSE)
write.csv(iris_validation, "iris_validation.csv", row.names = FALSE)
write.csv(iris_sample, "iris_sample.csv", row.names = FALSE)
```
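As a quick sanity check on the split described above, the following base R sketch rebuilds the same split in a scratch directory and reads the files back. The directory and the `*_check` and `*_path` names are hypothetical and exist only for this illustration; they are not part of the vignette's pipeline.

```{r}
# Rebuild the same split (same seed) in a temporary scratch directory
set.seed(123)
train_idx <- sample(nrow(iris), nrow(iris) * 2/3)

tmp_dir <- tempfile("iris_split_")  # hypothetical scratch location
dir.create(tmp_dir)
train_path <- file.path(tmp_dir, "iris_train.csv")
validation_path <- file.path(tmp_dir, "iris_validation.csv")

write.csv(iris[train_idx, ], train_path, row.names = FALSE)
write.csv(iris[-train_idx, ], validation_path, row.names = FALSE)

# Read the files back and confirm sizes and column names
train_check <- read.csv(train_path)
validation_check <- read.csv(validation_path)

nrow(train_check)       # 100 rows for training
nrow(validation_check)  # 50 rows for validation
identical(names(train_check), names(iris))  # TRUE
```

Two thirds of the 150 rows go to training and the remaining third to validation, and the column names survive the round trip unchanged.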

## Estimator Construction

We construct the classifier as usual -- see [Estimator Basics](https://github.com/rstudio/tfestimators/blob/main/vignettes/estimator_basics.Rmd) for details on feature columns and creating estimators.

```{r}
library(tfestimators)
response <- "Species"
features <- setdiff(names(iris), response)
feature_columns <- feature_columns(
  column_numeric(features)
)

classifier <- dnn_classifier(
  feature_columns = feature_columns,
  hidden_units = c(16, 32, 16),
  n_classes = 3,
  label_vocabulary = c("setosa", "virginica", "versicolor")
)
```
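Because the labels are matched against `label_vocabulary` as strings, it can be worth confirming that the vocabulary covers every class in the data. The sketch below uses only base R; the variable names are illustrative, and the remark about ordering reflects our understanding rather than a documented guarantee.

```{r}
# Vocabulary copied from the classifier definition above
label_vocabulary <- c("setosa", "virginica", "versicolor")

# The vocabulary and the factor levels should contain the same labels.
# The order differs from levels(iris$Species); we assume this is harmless
# as long as the same vocabulary is used for training and evaluation.
setequal(label_vocabulary, levels(iris$Species))  # TRUE

# The features are all columns except the response
feature_names <- setdiff(names(iris), "Species")
feature_names
```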

## Input Function

Creating the input function is similar to the [in-memory case](https://github.com/rstudio/tfestimators/blob/main/vignettes/estimator_basics.Rmd). However, instead of passing data frames or matrices to `iris_input_fn()`, we pass TensorFlow dataset objects, which read records from the dataset files as they are needed rather than loading everything into memory.

```{r}
iris_input_fn <- function(data) {
  input_fn(data, features = features, response = response)
}

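# csv_record_spec() and text_line_dataset() come from the tfdatasets package
library(tfdatasets)

# Infer column names and types from the small sample file
iris_spec <- csv_record_spec("iris_sample.csv")

# Note: iris_train and iris_validation are reassigned here, replacing the
# data frames of the same name with dataset objects
iris_train <- text_line_dataset(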
  "iris_train.csv", record_spec = iris_spec) %>%
  dataset_batch(10) %>% 
  dataset_repeat(10)
iris_validation <- text_line_dataset(
  "iris_validation.csv", record_spec = iris_spec) %>%
  dataset_batch(10) %>%
  dataset_repeat(1)
```

The `csv_record_spec()` helper creates a specification from a sample file; `text_line_dataset()` requires this specification to parse the files. There are many [transformations](https://tensorflow.rstudio.com/guide/tfdatasets/introduction/#transformations) available for dataset objects, but here we demonstrate only `dataset_batch()` and `dataset_repeat()`, which control the batch size and how many times we iterate through the dataset files, respectively.
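To make the effect of batching and repeating concrete, here is a small arithmetic sketch in base R. The helper `batches_per_run()` is hypothetical and not part of **tfdatasets**; it assumes batching is applied before repeating (as in the pipeline above) and that a final partial batch is kept rather than dropped.

```{r}
# Hypothetical helper: number of batches a pipeline yields when
# dataset_batch() is applied first and dataset_repeat() second.
# Assumes a trailing partial batch is kept, hence ceiling().
batches_per_run <- function(n_rows, batch_size, n_repeats) {
  ceiling(n_rows / batch_size) * n_repeats
}

# Training: 100 rows, batches of 10, repeated 10 times
batches_per_run(100, 10, 10)  # 100

# Validation: 50 rows, batches of 10, a single pass
batches_per_run(50, 10, 1)    # 5
```

Under these assumptions, training sees 100 batches in total, while evaluation makes one pass of 5 batches over the validation file.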

## Training and Evaluation

Once the input functions and datasets are defined, the training and evaluation interface is exactly the same as in the in-memory case.

```{r}
history <- train(classifier, input_fn = iris_input_fn(iris_train))
plot(history)
predictions <- predict(classifier, input_fn = iris_input_fn(iris_validation))
predictions
evaluation <- evaluate(classifier, input_fn = iris_input_fn(iris_validation))
evaluation
```

## Learning More

See the documentation for the [tfdatasets](https://tensorflow.rstudio.com/tools/tfdatasets/) package for additional details on using TensorFlow datasets.



