how can i reduce the size of a tuned workflow_set object?

I have a use case where I run a tune-models.R script as an RStudio Job overnight, which at the end writes the results to a folder on my laptop. The script structure is something like:

prep data -> create recipe -> create workflowset -> tune using workflow_map() -> package the results together with the splits -> use saveRDS() to save the results

By "package the results together with the splits", I mean create a nested data frame to save, like:

tune_results <- workflow_map(...)

workflowset_experiment <-
  tibble(workflows = list(tune_results),
         splits = list(splits),
         date_created = Sys.time()
)

saveRDS(workflowset_experiment, path)

I do this so i can have all i need to compare, evaluate and finalize the best model sometime later on in a new R session.

Since I usually only want to compare the best types of models to each other (i.e. plot the accuracy of the best XGB, the best RF, the best KNN...), there are a lot of model/preprocessor combos in tune_results that I don't need. I have been trying to figure out how to remove these, while still keeping it a workflow_set object, since all the nice helper functions only work with a workflow_set object.

My approach so far has been to search for an existing way to do this, but I haven't found one, so now I am trying to hack something together, with no luck. It's very likely I am missing something that already does this, and/or taking the wrong approach, but I have a reprex below.

Is there tidymodels functionality that already does this? Does anyone have advice on whether this is possible and how to approach doing something like this? Thanks!

reprex:

library(tidymodels)
library(stringr)
library(tidyverse)

data(parabolic)

# prep 
set.seed(1)
split <- initial_split(parabolic)
train_set <- training(split)
test_set <- testing(split)

set.seed(2)
train_resamples <- vfold_cv(train_set, v = 10)

logistic_reg_spec <- 
  logistic_reg(penalty = tune(),
               mixture = tune()) %>% 
  set_engine("glmnet")

rec <-
  recipe(class ~ ., data = train_set)

rec_norm <-
  recipe(class ~ ., data = train_set) %>%
  step_normalize(all_numeric_predictors())

workflow <- 
  workflow_set(
    preproc = list(rec = rec, rec_norm = rec_norm),
    models = list(lm = logistic_reg_spec), 
    cross = TRUE
  )

grid_ctrl <-
  control_grid(
    save_pred = TRUE,
    parallel_over = "everything",
    save_workflow = TRUE
  )

tune_results <-
  workflow %>%
  workflow_map(
    seed = 3566,
    verbose = FALSE,
    resamples = train_resamples,
    control = grid_ctrl,
    fn = "tune_grid", 
    grid = 10,
    metrics = yardstick::metric_set(accuracy)
  )

tune_results
#> # A workflow set/tibble: 2 × 4
#>   wflow_id    info             option    result   
#>   <chr>       <list>           <list>    <list>   
#> 1 rec_lm      <tibble [1 × 4]> <opts[4]> <tune[+]>
#> 2 rec_norm_lm <tibble [1 × 4]> <opts[4]> <tune[+]>

# these helper functions all work with the output of workflow_map():
tune_results %>% collect_metrics() # 20 rows (2 workflows * 10 grid candidates * 1 metric; each row is averaged over the 10 folds):
#> # A tibble: 20 × 9
#>    wflow_id    .config      preproc model .metric .estimator  mean     n std_err
#>    <chr>       <chr>        <chr>   <chr> <chr>   <chr>      <dbl> <int>   <dbl>
#>  1 rec_lm      Preprocesso… recipe  logi… accura… binary     0.731    10  0.0188
#>  2 rec_lm      Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#>  3 rec_lm      Preprocesso… recipe  logi… accura… binary     0.440    10  0.0211
#>  4 rec_lm      Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#>  5 rec_lm      Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#>  6 rec_lm      Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#>  7 rec_lm      Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#>  8 rec_lm      Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#>  9 rec_lm      Preprocesso… recipe  logi… accura… binary     0.736    10  0.0171
#> 10 rec_lm      Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#> 11 rec_norm_lm Preprocesso… recipe  logi… accura… binary     0.731    10  0.0188
#> 12 rec_norm_lm Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#> 13 rec_norm_lm Preprocesso… recipe  logi… accura… binary     0.440    10  0.0211
#> 14 rec_norm_lm Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#> 15 rec_norm_lm Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#> 16 rec_norm_lm Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#> 17 rec_norm_lm Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#> 18 rec_norm_lm Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153
#> 19 rec_norm_lm Preprocesso… recipe  logi… accura… binary     0.736    10  0.0171
#> 20 rec_norm_lm Preprocesso… recipe  logi… accura… binary     0.736    10  0.0153


tune_results %>% rank_results(rank_metric = "accuracy", select_best = TRUE)
#> # A tibble: 2 × 9
#>   wflow_id    .config       .metric  mean std_err     n preprocessor model  rank
#>   <chr>       <chr>         <chr>   <dbl>   <dbl> <int> <chr>        <chr> <int>
#> 1 rec_lm      Preprocessor… accura… 0.736  0.0171    10 recipe       logi…     1
#> 2 rec_norm_lm Preprocessor… accura… 0.736  0.0171    10 recipe       logi…     2


# trying to keep only the best instance of each model: 
# use `rank_results()` to identify the best performing preprocessor/model combo 
top_performers <- 
  tune_results %>% 
  rank_results(rank_metric = "accuracy", select_best = TRUE) %>% 
  select(wflow_id, .config) %>% 
  distinct() 

# unnest tune_results, filter out 18 worst performers, put back together
best_models <- 
  tune_results %>% 
  unnest(result) %>%
  unnest(.metrics) %>%
  inner_join(., top_performers, by = c("wflow_id", ".config")) %>% 
  nest(metrics = id:.config) %>%
  nest(result = splits:metrics)

#looks good?
best_models
#> # A tibble: 2 × 4
#>   wflow_id    info             option    result           
#>   <chr>       <list>           <list>    <list>           
#> 1 rec_lm      <tibble [1 × 4]> <opts[4]> <tibble [10 × 4]>
#> 2 rec_norm_lm <tibble [1 × 4]> <opts[4]> <tibble [10 × 4]>

# but these helper functions no longer work
best_models %>% collect_metrics() 
#> Error in `collect_metrics()`:
#> ! No `collect_metric()` exists for this type of object.
#> Backtrace:
#>     ▆
#>  1. ├─best_models %>% collect_metrics()
#>  2. ├─tune::collect_metrics(.)
#>  3. └─tune:::collect_metrics.default(.)
#>  4.   └─rlang::abort("No `collect_metric()` exists for this type of object.")


best_models %>% rank_results(rank_metric = "accuracy", select_best = TRUE)
#> Error in `collect_metrics()`:
#> ! No `collect_metric()` exists for this type of object.
#> Backtrace:
#>     ▆
#>  1. ├─best_models %>% ...
#>  2. └─workflowsets::rank_results(., rank_metric = "accuracy", select_best = TRUE)
#>  3.   └─workflowsets:::pick_metric(x, rank_metric)
#>  4.     ├─tune::collect_metrics(x)
#>  5.     └─tune:::collect_metrics.default(x)
#>  6.       └─rlang::abort("No `collect_metric()` exists for this type of object.")

# best_models is a regular old tibble instead of a workflowset object 
tune_results
#> # A workflow set/tibble: 2 × 4
#>   wflow_id    info             option    result   
#>   <chr>       <list>           <list>    <list>   
#> 1 rec_lm      <tibble [1 × 4]> <opts[4]> <tune[+]>
#> 2 rec_norm_lm <tibble [1 × 4]> <opts[4]> <tune[+]>

best_models
#> # A tibble: 2 × 4
#>   wflow_id    info             option    result           
#>   <chr>       <list>           <list>    <list>           
#> 1 rec_lm      <tibble [1 × 4]> <opts[4]> <tibble [10 × 4]>
#> 2 rec_norm_lm <tibble [1 × 4]> <opts[4]> <tibble [10 × 4]>
Created on 2023-11-17 with reprex v2.0.2
1 Like

one more request for any ideas on this :sweat_smile:

Am i missing something obvious? is this just a bad approach?

For any future readers, I figured out a solution to this. It needs to be improved in order to scale, but it's a start.

Again, my goal is to reduce the size of the tuned workflowset object in order to save it for later evaluation, while still retaining all that's needed to evaluate and finalize a model in a new R session.

The reprex below strips out all but the best performing iterations of each workflow. In other words if you tuned 2 model specs you can keep the best performing configuration of those 2 models, instead of getting (number models * number of folds * grid size) number of results.

library(tidymodels)
library(stringr)
#> 
#> Attaching package: 'stringr'
#> The following object is masked from 'package:recipes':
#> 
#>     fixed
library(tidyverse)

data(parabolic)

# prep 
set.seed(1)
split <- initial_split(parabolic)
train_set <- training(split)
test_set <- testing(split)

set.seed(2)
train_resamples <- vfold_cv(train_set, v = 10)

logistic_reg_spec <- 
  logistic_reg(penalty = tune(),
               mixture = tune()) %>% 
  set_engine("glmnet")

rec <-
  recipe(class ~ ., data = train_set)

rec_norm <-
  recipe(class ~ ., data = train_set) %>%
  step_normalize(all_numeric_predictors())

workflow <- 
  workflow_set(
    preproc = list(rec = rec, rec_norm = rec_norm),
    models = list(lm = logistic_reg_spec), 
    cross = TRUE
  )

grid_ctrl <-
  control_grid(
    save_pred = TRUE,
    parallel_over = "everything",
    save_workflow = TRUE
  )

tune_results <-
  workflow %>%
  workflow_map(
    seed = 3566,
    verbose = FALSE,
    resamples = train_resamples,
    control = grid_ctrl,
    fn = "tune_grid", 
    grid = 10,
    metrics = yardstick::metric_set(accuracy)
  )

tune_results
#> # A workflow set/tibble: 2 × 4
#>   wflow_id    info             option    result   
#>   <chr>       <list>           <list>    <list>   
#> 1 rec_lm      <tibble [1 × 4]> <opts[4]> <tune[+]>
#> 2 rec_norm_lm <tibble [1 × 4]> <opts[4]> <tune[+]>

# trying to keep only the best instance of each model: 
# use `rank_results()` to identify the best performing preprocessor/model combo 
top_performers <- 
  tune_results %>% 
  rank_results(rank_metric = "accuracy", select_best = TRUE) %>% 
  select(wflow_id, .config) %>% 
  distinct() 

# turn the results in a list, each element of which is of class `tune_results`
# (see "https://workflowsets.tidymodels.org/reference/as_workflow_set.html#ref-examples" for the inspiration of this approach)
# Pull the tuning results out of the workflow set as a list named by workflow
# id; each element is a `tune_results` object.
best_models <-
  tune_results %>%
  pluck("result") %>%
  setNames(tune_results$wflow_id)

# Winning `.config` of every workflow, looked up from `top_performers` instead
# of being hard-coded, so the same code works for any number of workflows.
best_configs <- setNames(top_performers$.config, top_performers$wflow_id)

# Keep, inside every per-resample tibble of `nested`, only the rows whose
# `.config` equals `config`. Returns a list of the same length as `nested`.
keep_config <- function(nested, config) {
  map(nested, function(fold) filter(fold, .config == .env$config))
}

# Overwrite the original .metrics and .predictions columns with versions that
# contain only the best configuration. Assigning with `$<-` on the
# `tune_results` object (rather than unnesting and re-nesting it) leaves its
# class intact, which is what the class check further down relies on.
for (id in names(best_models)) {
  # `[[` errors if a workflow id has no entry in `top_performers`.
  best_config <- best_configs[[id]]

  best_models[[id]]$.metrics <-
    keep_config(best_models[[id]]$.metrics, best_config)
  best_models[[id]]$.predictions <-
    keep_config(best_models[[id]]$.predictions, best_config)
}

# now .metrics is 1x6 instead of 10x6 
best_models
#> $rec_lm
#> # Tuning results
#> # 10-fold cross-validation 
#> # A tibble: 10 × 5
#>    splits           id     .metrics         .notes           .predictions     
#>    <list>           <chr>  <list>           <list>           <list>           
#>  1 <split [337/38]> Fold01 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [38 × 6]>
#>  2 <split [337/38]> Fold02 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [38 × 6]>
#>  3 <split [337/38]> Fold03 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [38 × 6]>
#>  4 <split [337/38]> Fold04 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [38 × 6]>
#>  5 <split [337/38]> Fold05 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [38 × 6]>
#>  6 <split [338/37]> Fold06 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [37 × 6]>
#>  7 <split [338/37]> Fold07 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [37 × 6]>
#>  8 <split [338/37]> Fold08 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [37 × 6]>
#>  9 <split [338/37]> Fold09 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [37 × 6]>
#> 10 <split [338/37]> Fold10 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [37 × 6]>
#> 
#> $rec_norm_lm
#> # Tuning results
#> # 10-fold cross-validation 
#> # A tibble: 10 × 5
#>    splits           id     .metrics         .notes           .predictions     
#>    <list>           <chr>  <list>           <list>           <list>           
#>  1 <split [337/38]> Fold01 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [38 × 6]>
#>  2 <split [337/38]> Fold02 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [38 × 6]>
#>  3 <split [337/38]> Fold03 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [38 × 6]>
#>  4 <split [337/38]> Fold04 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [38 × 6]>
#>  5 <split [337/38]> Fold05 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [38 × 6]>
#>  6 <split [338/37]> Fold06 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [37 × 6]>
#>  7 <split [338/37]> Fold07 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [37 × 6]>
#>  8 <split [338/37]> Fold08 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [37 × 6]>
#>  9 <split [338/37]> Fold09 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [37 × 6]>
#> 10 <split [338/37]> Fold10 <tibble [1 × 6]> <tibble [0 × 3]> <tibble [37 × 6]>

#check that the lists are still of class "tune_results"
purrr::map_chr(best_models, ~ class(.x)[1])
#>         rec_lm    rec_norm_lm 
#> "tune_results" "tune_results"

# combine into a new workflow
best_workflows <- as_workflow_set(!!!best_models)

# note the class of the result column is <tune[+]>
best_workflows
#> # A workflow set/tibble: 2 × 4
#>   wflow_id    info             option    result   
#>   <chr>       <list>           <list>    <list>   
#> 1 rec_lm      <tibble [1 × 4]> <opts[0]> <tune[+]>
#> 2 rec_norm_lm <tibble [1 × 4]> <opts[0]> <tune[+]>

# helper functions work!
best_workflows %>% collect_metrics() 
#> # A tibble: 2 × 9
#>   wflow_id    .config       preproc model .metric .estimator  mean     n std_err
#>   <chr>       <chr>         <chr>   <chr> <chr>   <chr>      <dbl> <int>   <dbl>
#> 1 rec_lm      Preprocessor… recipe  logi… accura… binary     0.736    10  0.0171
#> 2 rec_norm_lm Preprocessor… recipe  logi… accura… binary     0.736    10  0.0171

# confirm that the results only contain the best models
best_workflows %>% rank_results(rank_metric = "accuracy", select_best = FALSE)
#> # A tibble: 2 × 9
#>   wflow_id    .config       .metric  mean std_err     n preprocessor model  rank
#>   <chr>       <chr>         <chr>   <dbl>   <dbl> <int> <chr>        <chr> <int>
#> 1 rec_lm      Preprocessor… accura… 0.736  0.0171    10 recipe       logi…     1
#> 2 rec_norm_lm Preprocessor… accura… 0.736  0.0171    10 recipe       logi…     2
best_workflows %>% rank_results(rank_metric = "accuracy", select_best = TRUE)
#> # A tibble: 2 × 9
#>   wflow_id    .config       .metric  mean std_err     n preprocessor model  rank
#>   <chr>       <chr>         <chr>   <dbl>   <dbl> <int> <chr>        <chr> <int>
#> 1 rec_lm      Preprocessor… accura… 0.736  0.0171    10 recipe       logi…     1
#> 2 rec_norm_lm Preprocessor… accura… 0.736  0.0171    10 recipe       logi…     2
Created on 2023-12-06 with reprex v2.0.2

Adding one more reprex for posterity.

Created a function called keep_best_workflows() which reduces the size of the tuned model object from 2.58 MB to 419.90 kB.

This improves the speed of saving the tuned model object (complete with full workflow and predictions), and also improves the speed when loading it into a new R session to perform model evaluation.

library(tidymodels)
library(stringr)
library(tidyverse)
library(finetune)
library(lobstr)

data(parabolic)

# prep ----
set.seed(1)
split <- initial_split(parabolic)
train_set <- training(split)
test_set <- testing(split)

set.seed(2)
train_resamples <- vfold_cv(train_set, v = 10)

# recipes ----
rec <-
  recipe(class ~ ., data = train_set)

rec_norm <-
  recipe(class ~ ., data = train_set) %>%
  step_normalize(all_numeric_predictors())

# model specifications ----
logistic_reg_spec <- 
  logistic_reg(penalty = tune(),
               mixture = tune()) %>% 
  set_engine("glmnet")

xgb_spec <- 
  boost_tree(
    tree_depth = tune(),
    learn_rate = tune(),
    loss_reduction = tune(), 
    min_n = tune(),
    mtry = tune(),
    sample_size = tune(),
    trees = tune()
  ) %>% 
  set_engine("xgboost", counts = FALSE) %>% 
  set_mode("classification")

# workflows ----
# Build the tuning-parameter set for one workflow of a workflow set, with the
# `mtry` parameter replaced by `mtry_prop()`.
#
# Arguments:
#   workflowset  Object handed to hardhat::extract_parameter_set_dials() as its
#                first argument (a workflow set at the call sites below).
#   workflow     Passed on as the second argument of
#                extract_parameter_set_dials(); the call sites below supply a
#                workflow id string such as "rec_xgb".
#
# Returns: the extracted parameter set after `update(mtry = mtry_prop())`.
#
# NOTE(review): presumably needed because the xgboost engine is set with
# `counts = FALSE`, i.e. mtry is meant as a proportion -- confirm.
update_mtry <- function(workflowset, workflow) {
  workflow_params <-
    hardhat::extract_parameter_set_dials(workflowset, workflow)

  update(workflow_params, mtry = mtry_prop())
}

workflow <- 
  workflow_set(
    preproc = list(rec = rec, rec_norm = rec_norm),
    models = list(lm = logistic_reg_spec, xgb = xgb_spec), 
    cross = TRUE
  ) %>%
  option_add(param_info = update_mtry(workflowset = ., workflow = "rec_xgb"), id = "rec_xgb") %>%
  option_add(param_info = update_mtry(workflowset = ., workflow = "rec_norm_xgb"), id = "rec_norm_xgb")


# tune workflows ----
sim_anneal_ctrl <- 
  control_sim_anneal(
    save_pred = TRUE,
    parallel_over = "everything",
    save_workflow = TRUE,
    restart = 5L,
    verbose = FALSE
  ) 

tune_results <-
  workflow %>%
  workflow_map(
    seed = 3566,
    verbose = FALSE,
    resamples = train_resamples,
    control = sim_anneal_ctrl,
    fn = "tune_sim_anneal", 
    iter = 10,
    metrics = yardstick::metric_set(accuracy, roc_auc)
  )
#> Optimizing accuracy
#> Initial best: 0.71465
#> 1 ♥ new best           accuracy=0.72248 (+/-0.01668)
#> 2 ♥ new best           accuracy=0.73321 (+/-0.0174)
#> 3 ─ discard suboptimal accuracy=0.69068 (+/-0.01908)
#> 4 ♥ new best           accuracy=0.73585 (+/-0.01527)
#> 5 ◯ accept suboptimal  accuracy=0.73585 (+/-0.01527)
#> 6 ◯ accept suboptimal  accuracy=0.73585 (+/-0.01527)
#> 7 ◯ accept suboptimal  accuracy=0.73585 (+/-0.01527)
#> 8 ◯ accept suboptimal  accuracy=0.73585 (+/-0.01527)
#> 9 ✖ restart from best  accuracy=0.73585 (+/-0.01527)
#> 10 ◯ accept suboptimal  accuracy=0.73585 (+/-0.01527)
#> Optimizing accuracy
#> Initial best: 0.42696
#> 1 ◯ accept suboptimal  accuracy=0.42696 (+/-0.0157)
#> 2 ♥ new best           accuracy=0.44047 (+/-0.02111)
#> 3 ◯ accept suboptimal  accuracy=0.44047 (+/-0.02111)
#> 4 ◯ accept suboptimal  accuracy=0.42696 (+/-0.0157)
#> 5 ♥ new best           accuracy=0.45882 (+/-0.02723)
#> 6 ♥ new best           accuracy=0.56017 (+/-0.0219)
#> 7 ♥ new best           accuracy=0.65057 (+/-0.01422)
#> 8 ♥ new best           accuracy=0.66913 (+/-0.01195)
#> 9 ♥ new best           accuracy=0.7995 (+/-0.02756)
#> 10 ♥ new best           accuracy=0.87198 (+/-0.01713)
#> Optimizing accuracy
#> Initial best: 0.71465
#> 1 ♥ new best           accuracy=0.72248 (+/-0.01668)
#> 2 ♥ new best           accuracy=0.73321 (+/-0.0174)
#> 3 ─ discard suboptimal accuracy=0.69068 (+/-0.01908)
#> 4 ♥ new best           accuracy=0.73585 (+/-0.01527)
#> 5 ◯ accept suboptimal  accuracy=0.73585 (+/-0.01527)
#> 6 ◯ accept suboptimal  accuracy=0.73585 (+/-0.01527)
#> 7 ◯ accept suboptimal  accuracy=0.73585 (+/-0.01527)
#> 8 ◯ accept suboptimal  accuracy=0.73585 (+/-0.01527)
#> 9 ✖ restart from best  accuracy=0.73585 (+/-0.01527)
#> 10 ◯ accept suboptimal  accuracy=0.73585 (+/-0.01527)
#> Optimizing accuracy
#> Initial best: 0.42696
#> 1 ◯ accept suboptimal  accuracy=0.42696 (+/-0.0157)
#> 2 ♥ new best           accuracy=0.44047 (+/-0.02111)
#> 3 ◯ accept suboptimal  accuracy=0.44047 (+/-0.02111)
#> 4 ◯ accept suboptimal  accuracy=0.42696 (+/-0.0157)
#> 5 ♥ new best           accuracy=0.45882 (+/-0.02723)
#> 6 ♥ new best           accuracy=0.56017 (+/-0.0219)
#> 7 ♥ new best           accuracy=0.65057 (+/-0.01422)
#> 8 ♥ new best           accuracy=0.66913 (+/-0.01195)
#> 9 ♥ new best           accuracy=0.7995 (+/-0.02756)
#> 10 ♥ new best           accuracy=0.87198 (+/-0.01713)

# tuned results object size
lobstr::obj_size(tune_results)
#> 2.58 MB

# select best workflows reducing tuned results object size
top_performers <- 
  tune_results %>% 
  rank_results(rank_metric = "accuracy", select_best = TRUE) %>% 
  filter(.metric == 'accuracy') %>%
  select(wflow_id, .config) %>% 
  distinct() 

top_performers
#> # A tibble: 4 × 2
#>   wflow_id     .config
#>   <chr>        <chr>  
#> 1 rec_xgb      Iter10 
#> 2 rec_norm_xgb Iter10 
#> 3 rec_lm       Iter4  
#> 4 rec_norm_lm  Iter4

result_list <- 
  tune_results %>% 
  pluck("result")

# workflow_map returns 4 models, each with 110 rows
result_list
#> [[1]]
#> # Tuning results
#> # 10-fold cross-validation 
#> # A tibble: 110 × 6
#>    splits           id     .metrics         .notes           .predictions .iter
#>    <list>           <chr>  <list>           <list>           <list>       <int>
#>  1 <split [337/38]> Fold01 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  2 <split [337/38]> Fold02 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  3 <split [337/38]> Fold03 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  4 <split [337/38]> Fold04 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  5 <split [337/38]> Fold05 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  6 <split [338/37]> Fold06 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  7 <split [338/37]> Fold07 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  8 <split [338/37]> Fold08 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  9 <split [338/37]> Fold09 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#> 10 <split [338/37]> Fold10 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#> # ℹ 100 more rows
#> 
#> [[2]]
#> # Tuning results
#> # 10-fold cross-validation 
#> # A tibble: 110 × 6
#>    splits           id     .metrics          .notes           .predictions .iter
#>    <list>           <chr>  <list>            <list>           <list>       <int>
#>  1 <split [337/38]> Fold01 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  2 <split [337/38]> Fold02 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  3 <split [337/38]> Fold03 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  4 <split [337/38]> Fold04 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  5 <split [337/38]> Fold05 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  6 <split [338/37]> Fold06 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  7 <split [338/37]> Fold07 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  8 <split [338/37]> Fold08 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  9 <split [338/37]> Fold09 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#> 10 <split [338/37]> Fold10 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#> # ℹ 100 more rows
#> 
#> [[3]]
#> # Tuning results
#> # 10-fold cross-validation 
#> # A tibble: 110 × 6
#>    splits           id     .metrics         .notes           .predictions .iter
#>    <list>           <chr>  <list>           <list>           <list>       <int>
#>  1 <split [337/38]> Fold01 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  2 <split [337/38]> Fold02 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  3 <split [337/38]> Fold03 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  4 <split [337/38]> Fold04 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  5 <split [337/38]> Fold05 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  6 <split [338/37]> Fold06 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  7 <split [338/37]> Fold07 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  8 <split [338/37]> Fold08 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#>  9 <split [338/37]> Fold09 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#> 10 <split [338/37]> Fold10 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         0
#> # ℹ 100 more rows
#> 
#> [[4]]
#> # Tuning results
#> # 10-fold cross-validation 
#> # A tibble: 110 × 6
#>    splits           id     .metrics          .notes           .predictions .iter
#>    <list>           <chr>  <list>            <list>           <list>       <int>
#>  1 <split [337/38]> Fold01 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  2 <split [337/38]> Fold02 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  3 <split [337/38]> Fold03 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  4 <split [337/38]> Fold04 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  5 <split [337/38]> Fold05 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  6 <split [338/37]> Fold06 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  7 <split [338/37]> Fold07 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  8 <split [338/37]> Fold08 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#>  9 <split [338/37]> Fold09 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#> 10 <split [338/37]> Fold10 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>         0
#> # ℹ 100 more rows

names(result_list) <- tune_results$wflow_id

# Reduce the tuning results of one workflow to its best-performing iteration.
#
# Arguments:
#   result_list     Named list of tuning results, one element per workflow id.
#   workflow_id     Single string: the workflow to reduce. Must be a name of
#                   `result_list` and appear in `top_performers$wflow_id`.
#   top_performers  Data frame with columns `wflow_id` and `.config` that
#                   names the best configuration of each workflow.
#
# Returns:
#   * If the best `.config` contains "Preprocessor" (a grid-search style
#     result), the element of `result_list` is returned unchanged.
#   * Otherwise `.config` is read as "Iter<k>" and only the rows with
#     `.iter == k` are kept. The result gets class "tune_results" and the
#     non-structural attributes of the original object are copied onto it.
#
# Errors if `workflow_id` is not a single string, is not a name of
# `result_list`, does not match exactly one distinct `.config`, or if the
# iteration number cannot be parsed.
keep_best_workflows <-
  function(result_list, workflow_id, top_performers) {

    if (!is.character(workflow_id) || length(workflow_id) != 1L ||
        is.na(workflow_id)) {
      stop("`workflow_id` must be a single string.", call. = FALSE)
    }
    if (!workflow_id %in% names(result_list)) {
      stop(
        "No element named '", workflow_id, "' in `result_list`.",
        call. = FALSE
      )
    }

    # `if` below needs a scalar, so insist on exactly one best configuration
    # instead of failing later with an opaque "condition has length" error.
    is_match <- top_performers$wflow_id %in% workflow_id
    best_config <- unique(as.character(top_performers$.config[is_match]))
    if (length(best_config) != 1L) {
      stop(
        "Expected exactly one best `.config` for workflow '", workflow_id,
        "' in `top_performers`, found ", length(best_config), ".",
        call. = FALSE
      )
    }

    tuned <- result_list[[workflow_id]]

    # Grid-search style configurations are passed through untouched.
    if (grepl("Preprocessor", best_config, fixed = TRUE)) {
      return(tuned)
    }

    best_iteration <-
      suppressWarnings(as.numeric(sub("Iter", "", best_config, fixed = TRUE)))
    if (is.na(best_iteration)) {
      stop(
        "Cannot read an iteration number from `.config` '", best_config, "'.",
        call. = FALSE
      )
    }

    # Remember everything except the structural attributes before filtering.
    # The original code had to re-class the filtered rows by hand, which
    # suggests the row filter does not hand back a full tuning-results object;
    # copying these attributes over keeps whatever metadata the original
    # carried. NOTE(review): intended to let helpers that read that metadata
    # (e.g. rank_results()) work on the reduced object -- confirm on your
    # installed tune / workflowsets versions.
    tune_attrs <- attributes(tuned)
    tune_attrs <-
      tune_attrs[setdiff(names(tune_attrs), c("names", "row.names", "class"))]

    reduced <- dplyr::filter(tuned, .data$.iter == best_iteration)
    reduced <- tibble::new_tibble(reduced, class = "tune_results")

    for (attr_name in names(tune_attrs)) {
      attr(reduced, attr_name) <- tune_attrs[[attr_name]]
    }

    reduced
  }

# Run keep_best_workflows() once per workflow id, in the row order of
# `tune_results`, collecting the reduced results in an unnamed list.
best_result_list <-
  map(
    tune_results$wflow_id,
    function(id) {
      keep_best_workflows(
        result_list = result_list,
        workflow_id = id,
        top_performers = top_performers
      )
    }
  )

# after reducing we also have 4 models, but now each has only 10 rows
best_result_list
#> [[1]]
#> # Tuning results
#> # A tibble: 10 × 6
#>    splits           id     .metrics         .notes           .predictions .iter
#>    <list>           <chr>  <list>           <list>           <list>       <int>
#>  1 <split [337/38]> Fold01 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  2 <split [337/38]> Fold02 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  3 <split [337/38]> Fold03 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  4 <split [337/38]> Fold04 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  5 <split [337/38]> Fold05 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  6 <split [338/37]> Fold06 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  7 <split [338/37]> Fold07 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  8 <split [338/37]> Fold08 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  9 <split [338/37]> Fold09 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#> 10 <split [338/37]> Fold10 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#> 
#> [[2]]
#> # Tuning results
#> # A tibble: 10 × 6
#>    splits           id     .metrics          .notes           .predictions .iter
#>    <list>           <chr>  <list>            <list>           <list>       <int>
#>  1 <split [337/38]> Fold01 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  2 <split [337/38]> Fold02 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  3 <split [337/38]> Fold03 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  4 <split [337/38]> Fold04 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  5 <split [337/38]> Fold05 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  6 <split [338/37]> Fold06 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  7 <split [338/37]> Fold07 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  8 <split [338/37]> Fold08 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  9 <split [338/37]> Fold09 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#> 10 <split [338/37]> Fold10 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#> 
#> [[3]]
#> # Tuning results
#> # A tibble: 10 × 6
#>    splits           id     .metrics         .notes           .predictions .iter
#>    <list>           <chr>  <list>           <list>           <list>       <int>
#>  1 <split [337/38]> Fold01 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  2 <split [337/38]> Fold02 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  3 <split [337/38]> Fold03 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  4 <split [337/38]> Fold04 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  5 <split [337/38]> Fold05 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  6 <split [338/37]> Fold06 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  7 <split [338/37]> Fold07 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  8 <split [338/37]> Fold08 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#>  9 <split [338/37]> Fold09 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#> 10 <split [338/37]> Fold10 <tibble [2 × 6]> <tibble [0 × 3]> <tibble>         4
#> 
#> [[4]]
#> # Tuning results
#> # A tibble: 10 × 6
#>    splits           id     .metrics          .notes           .predictions .iter
#>    <list>           <chr>  <list>            <list>           <list>       <int>
#>  1 <split [337/38]> Fold01 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  2 <split [337/38]> Fold02 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  3 <split [337/38]> Fold03 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  4 <split [337/38]> Fold04 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  5 <split [337/38]> Fold05 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  6 <split [338/37]> Fold06 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  7 <split [338/37]> Fold07 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  8 <split [338/37]> Fold08 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#>  9 <split [338/37]> Fold09 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10
#> 10 <split [338/37]> Fold10 <tibble [2 × 11]> <tibble [0 × 3]> <tibble>        10

# Rebuild the workflow set around the reduced results: drop the original
# `result` column, bind `best_result_list` in its place, then set the class to
# "workflow_set" with new_tibble().
# The new column is matched to rows purely by position; `best_result_list` was
# built by mapping over `tune_results$wflow_id`, so the order lines up.
# NOTE(review): the explicit re-classing is presumably needed because the
# select()/bind_cols() steps do not hand back a workflow_set -- confirm.
new_workflowset <- 
  tune_results %>%
  select(-result) %>%
  bind_cols(., tibble(result = best_result_list)) %>%
  new_tibble(., class = "workflow_set")

new_workflowset
#> # A workflow set/tibble: 4 × 4
#>   wflow_id     info             option    result   
#>   <chr>        <list>           <list>    <list>   
#> 1 rec_lm       <tibble [1 × 4]> <opts[4]> <tune[+]>
#> 2 rec_xgb      <tibble [1 × 4]> <opts[5]> <tune[+]>
#> 3 rec_norm_lm  <tibble [1 × 4]> <opts[4]> <tune[+]>
#> 4 rec_norm_xgb <tibble [1 × 4]> <opts[5]> <tune[+]>

# reduced workflowset object size
lobstr::obj_size(new_workflowset)
#> 419.90 kB

# helper functions work!
new_workflowset %>% collect_metrics() 
#> # A tibble: 8 × 10
#>   wflow_id    .config .iter preproc model .metric .estimator  mean     n std_err
#>   <chr>       <chr>   <int> <chr>   <chr> <chr>   <chr>      <dbl> <int>   <dbl>
#> 1 rec_lm      Iter4       4 recipe  logi… accura… binary     0.736    10  0.0153
#> 2 rec_lm      Iter4       4 recipe  logi… roc_auc binary     0.804    10  0.0183
#> 3 rec_xgb     Iter10     10 recipe  boos… accura… binary     0.872    10  0.0171
#> 4 rec_xgb     Iter10     10 recipe  boos… roc_auc binary     0.916    10  0.0126
#> 5 rec_norm_lm Iter4       4 recipe  logi… accura… binary     0.736    10  0.0153
#> 6 rec_norm_lm Iter4       4 recipe  logi… roc_auc binary     0.804    10  0.0183
#> 7 rec_norm_x… Iter10     10 recipe  boos… accura… binary     0.872    10  0.0171
#> 8 rec_norm_x… Iter10     10 recipe  boos… roc_auc binary     0.916    10  0.0126

# rank_results() doesn't work on the reduced set (new_workflowset %>% rank_results(rank_metric = "accuracy", select_best = FALSE)), but it can easily be replicated:
new_workflowset %>% 
  collect_metrics() %>%
  filter(.metric == 'accuracy') %>%
  arrange(desc(mean))
#> # A tibble: 4 × 10
#>   wflow_id    .config .iter preproc model .metric .estimator  mean     n std_err
#>   <chr>       <chr>   <int> <chr>   <chr> <chr>   <chr>      <dbl> <int>   <dbl>
#> 1 rec_xgb     Iter10     10 recipe  boos… accura… binary     0.872    10  0.0171
#> 2 rec_norm_x… Iter10     10 recipe  boos… accura… binary     0.872    10  0.0171
#> 3 rec_lm      Iter4       4 recipe  logi… accura… binary     0.736    10  0.0153
#> 4 rec_norm_lm Iter4       4 recipe  logi… accura… binary     0.736    10  0.0153
Created on 2023-12-27 with reprex v2.0.2
1 Like

This topic was automatically closed 21 days after the last reply. New replies are no longer allowed.

If you have a query related to it or one of the replies, start a new topic and refer back with a link.