## ----setup, include=FALSE-----------------------------------------------------
# Document-wide chunk defaults, set in a single call: `echo = TRUE` and
# `eval = FALSE` become the default options for the chunks that follow.
knitr::opts_chunk$set(echo = TRUE, eval = FALSE)

## -----------------------------------------------------------------------------
# set.seed(123)
# train_idx <- sample(nrow(iris), nrow(iris) * 2/3)
# 
# iris_train <- iris[train_idx,]
# iris_validation <- iris[-train_idx,]
# iris_sample <- iris_train %>%
#   head(10)
# 
# write.csv(iris_train, "iris_train.csv", row.names = FALSE)
# write.csv(iris_validation, "iris_validation.csv", row.names = FALSE)
# write.csv(iris_sample, "iris_sample.csv", row.names = FALSE)

## -----------------------------------------------------------------------------
# library(tfestimators)
# response <- "Species"
# features <- setdiff(names(iris), response)
# feature_columns <- feature_columns(
#   column_numeric(features)
# )
# 
# classifier <- dnn_classifier(
#   feature_columns = feature_columns,
#   hidden_units = c(16, 32, 16),
#   n_classes = 3,
#   label_vocabulary = c("setosa", "virginica", "versicolor")
# )

## -----------------------------------------------------------------------------
# iris_input_fn <- function(data) {
#   input_fn(data, features = features, response = response)
# }
# 
# iris_spec <- csv_record_spec("iris_sample.csv")
# iris_train <- text_line_dataset(
#   "iris_train.csv", record_spec = iris_spec) %>%
#   dataset_batch(10) %>%
#   dataset_repeat(10)
# iris_validation <- text_line_dataset(
#   "iris_validation.csv", record_spec = iris_spec) %>%
#   dataset_batch(10) %>%
#   dataset_repeat(1)

## -----------------------------------------------------------------------------
# history <- train(classifier, input_fn = iris_input_fn(iris_train))
# plot(history)
# predictions <- predict(classifier, input_fn = iris_input_fn(iris_validation))
# predictions
# evaluation <- evaluate(classifier, input_fn = iris_input_fn(iris_validation))
# evaluation

