ml-tuning

Spark ML – Tuning


Description

Perform hyper-parameter tuning using either K-fold cross validation or train-validation split.

Usage

ml_sub_models(model)

ml_validation_metrics(model)

ml_cross_validator(
  x,
  estimator = NULL,
  estimator_param_maps = NULL,
  evaluator = NULL,
  num_folds = 3,
  collect_sub_models = FALSE,
  parallelism = 1,
  seed = NULL,
  uid = random_string("cross_validator_"),
  ...
)

ml_train_validation_split(
  x,
  estimator = NULL,
  estimator_param_maps = NULL,
  evaluator = NULL,
  train_ratio = 0.75,
  collect_sub_models = FALSE,
  parallelism = 1,
  seed = NULL,
  uid = random_string("train_validation_split_"),
  ...
)

Arguments

model

A cross validation or train-validation-split model.

x

A spark_connection, ml_pipeline, or a tbl_spark.

estimator

An ml_estimator object.

estimator_param_maps

A named list of stages and hyper-parameter sets to tune. See details.

evaluator

An ml_evaluator object, see ml_evaluator.

num_folds

Number of folds for cross validation. Must be >= 2. Default: 3

collect_sub_models

Whether to collect a list of sub-models trained during tuning. If set to FALSE, then only the single best sub-model will be available after fitting. If set to TRUE, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause out-of-memory errors on the Spark driver.

parallelism

The number of threads to use when running parallel algorithms. Default is 1 for serial execution.

seed

A random seed. Set this value if you need your results to be reproducible across repeated calls.

uid

A character string used to uniquely identify the ML estimator.

...

Optional arguments; currently unused.

train_ratio

Fraction of the data used for training; the remainder is used for validation. Must be between 0 and 1. Default: 0.75

Details

ml_cross_validator() performs k-fold cross validation while ml_train_validation_split() performs tuning on one pair of train and validation datasets.
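To get a feel for how much work a tuning run involves, the following base R sketch counts the candidate settings in the grid used in the Examples section. It assumes that the values listed for each parameter are crossed into every possible combination and that each combination is fit once per fold (cross validation) or once in total (train-validation split); the variable names are illustrative only and are not part of sparklyr.

```r
# Same grid as in the Examples section of this page
grid <- list(
  random_forest = list(
    num_trees = c(5, 10),
    max_depth = c(5, 10),
    impurity = c("entropy", "gini")
  )
)

# Cross every value of every parameter: one row per candidate combination
# (assumption: the tuner evaluates the full cross product of the grid)
combos <- expand.grid(grid$random_forest, stringsAsFactors = FALSE)
n_combos <- nrow(combos)

# Number of sub-models fit during tuning under the assumptions above
num_folds <- 3
cv_fits <- n_combos * num_folds  # k-fold cross validation
tvs_fits <- n_combos             # single train-validation split

print(combos)
cat("combinations:", n_combos, "\n")
cat("fits with 3-fold cross validation:", cv_fits, "\n")
cat("fits with train-validation split:", tvs_fits, "\n")
```

Because cross validation repeats every candidate on every fold, it costs roughly num_folds times as much as a train-validation split on the same grid, in exchange for a metric averaged over several splits.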

Value

The object returned depends on the class of x.

  • spark_connection: When x is a spark_connection, the function returns an instance of an ml_cross_validator or ml_train_validation_split object.

  • ml_pipeline: When x is an ml_pipeline, the function returns an ml_pipeline with the tuning estimator appended to the pipeline.

  • tbl_spark: When x is a tbl_spark, a tuning estimator is constructed then immediately fit with the input tbl_spark, returning an ml_cross_validation_model or an ml_train_validation_split_model object.

For cross validation, ml_sub_models() returns a nested list of models, where the first layer represents fold indices and the second layer represents param maps. For train-validation split, ml_sub_models() returns a list of models, corresponding to the order of the estimator param maps.

ml_validation_metrics() returns a data frame of performance metrics and hyperparameter combinations.
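As a rough sketch of how these accessors fit together for a train-validation split, the helper below builds the tuner, fits it, and returns both the metrics and the collected sub-models. The function tune_with_split() and its return structure are hypothetical and not part of sparklyr; the sketch assumes sparklyr is installed, that sc is a live Spark connection, and that tbl has a Species column as in the iris data.

```r
# Hypothetical helper (not part of sparklyr): tunes a random forest using a
# single train/validation split. Requires sparklyr and a Spark connection.
tune_with_split <- function(sc, tbl, train_ratio = 0.75, seed = 42) {
  # Pipeline: formula-based feature assembly followed by a random forest
  pipeline <- sparklyr::ml_pipeline(sc)
  pipeline <- sparklyr::ft_r_formula(pipeline, Species ~ .)
  pipeline <- sparklyr::ml_random_forest_classifier(pipeline)

  # Hyperparameter grid, keyed by stage name as in the Examples section
  grid <- list(
    random_forest = list(
      num_trees = c(5, 10),
      max_depth = c(5, 10)
    )
  )

  # collect_sub_models = TRUE keeps every fitted candidate, not only the best
  tvs <- sparklyr::ml_train_validation_split(
    sc,
    estimator = pipeline,
    estimator_param_maps = grid,
    evaluator = sparklyr::ml_multiclass_classification_evaluator(sc),
    train_ratio = train_ratio,
    collect_sub_models = TRUE,
    seed = seed
  )

  model <- sparklyr::ml_fit(tvs, tbl)

  list(
    metrics = sparklyr::ml_validation_metrics(model),  # data frame
    sub_models = sparklyr::ml_sub_models(model)        # list, in param-map order
  )
}

# Usage (needs a running Spark installation, so it is left commented out):
# sc <- sparklyr::spark_connect(master = "local")
# iris_tbl <- sparklyr::sdf_copy_to(sc, iris, name = "iris_tbl", overwrite = TRUE)
# result <- tune_with_split(sc, iris_tbl)
# result$metrics
```

Setting seed makes the train/validation partition repeatable, and collect_sub_models = TRUE is what makes ml_sub_models() usable on the fitted model; on large models it may be safer to leave it at FALSE, given the driver memory warning above.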

Examples

## Not run: 
sc <- spark_connect(master = "local")
iris_tbl <- sdf_copy_to(sc, iris, name = "iris_tbl", overwrite = TRUE)

# Create a pipeline
pipeline <- ml_pipeline(sc) %>%
  ft_r_formula(Species ~ .) %>%
  ml_random_forest_classifier()

# Specify hyperparameter grid
grid <- list(
  random_forest = list(
    num_trees = c(5, 10),
    max_depth = c(5, 10),
    impurity = c("entropy", "gini")
  )
)

# Create the cross validator object
cv <- ml_cross_validator(
  sc,
  estimator = pipeline, estimator_param_maps = grid,
  evaluator = ml_multiclass_classification_evaluator(sc),
  num_folds = 3,
  parallelism = 4
)

# Train the models
cv_model <- ml_fit(cv, iris_tbl)

# Print the metrics
ml_validation_metrics(cv_model)

## End(Not run)

sparklyr

R Interface to Apache Spark

v1.6.2
Apache License 2.0 | file LICENSE
Authors
Javier Luraschi [aut], Kevin Kuo [aut] (<https://orcid.org/0000-0001-7803-7901>), Kevin Ushey [aut], JJ Allaire [aut], Samuel Macedo [ctb], Hossein Falaki [aut], Lu Wang [aut], Andy Zhang [aut], Yitao Li [aut, cre] (<https://orcid.org/0000-0002-1261-905X>), Jozef Hajnala [ctb], Maciej Szymkiewicz [ctb] (<https://orcid.org/0000-0003-1469-9396>), Wil Davis [ctb], RStudio [cph], The Apache Software Foundation [aut, cph]
Initial release
