What is the workflowsets equivalent of select_best() + finalize_workflow() |> fit()?

In the reprex below I use collect_predictions() and then cbind() after some filtering, inner joining, etc. (see the 'PROBLEM STARTS HERE' comment).

For a single workflow, I can simply use something along the lines of best_model = select_best() followed by finalize_workflow(best_model) |> fit(df), as sketched below. However, I was unable to find anything on how to do the same for a workflow set. Is there a more elegant solution?
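For reference, this is the single-workflow pattern I mean (just a sketch: tune_res stands for a tune_grid() result on a single workflow, wf for that workflow):

best_params = select_best(tune_res, metric = 'huber_loss')  # best hyperparameters

final_fit = wf |>
  finalize_workflow(best_params) |>  # plug them into the workflow
  fit(df)                            # refit on the full data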

library(tidytable)
library(tidymodels)
library(parsnip)
library(readr)
library(lubridate)
library(workflowsets)
library(rsample)
library(recipes)

### Sample data + resampling
df = data.frame(yearr = sample(2015:2021, 2000, replace = TRUE),
                monthh = sample(1:12, 2000, replace = TRUE),
                dayy = sample(1:29, 2000, replace = TRUE)) |>
  mutate(datee = ymd(paste(yearr, monthh, dayy)),
         yy = sample(0:100, 2000, replace = TRUE) + (130 * yearr) + (2 * monthh)) |>
  filter(!is.na(datee)) |>  # ymd() returns NA for impossible dates (e.g. Feb 29 off leap years)
  arrange(datee) |>
  mutate(ii = row_number())

cross_folds = df |>
  vfold_cv(v = 3, repeats = 2)  # vfold_cv() takes v, not times

### Recipe specifications
rec_basic_formula = df |>
  recipe(yy ~ .) |>
  update_role(datee, new_role = 'date') |>
  step_zv(all_predictors())

rec_additional_formula = df |>
  recipe(yy ~ .) |>
  update_role(datee, new_role = 'date') |>
  step_zv(all_predictors()) |>        # drop zero-variance columns first ...
  step_normalize(all_predictors())    # ... so normalizing never divides by a zero SD
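As a quick sanity check of the preprocessing, prep() estimates the steps from df and bake(new_data = NULL) returns the processed training set:

rec_additional_formula |>
  prep() |>
  bake(new_data = NULL) |>
  head()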

### Modelling algorithms
linear_reg_glmnet_spec = linear_reg(penalty = tune(), mixture = tune()) |>
  set_engine('glmnet')

svm_poly_kernlab_spec = svm_poly(cost = tune(), degree = tune(), scale_factor = tune(), margin = tune()) |>
  set_engine('kernlab') |>
  set_mode('regression')

boost_tree_xgboost_spec = boost_tree(tree_depth = tune(),
                                     trees = tune(),
                                     learn_rate = tune(),
                                     min_n = tune(),
                                     loss_reduction = tune(),
                                     sample_size = tune(),
                                     stop_iter = tune()) |>
  # the validation engine option holds out a share of the training data,
  # so stop_iter (early stopping) has something to monitor
  set_engine('xgboost', validation = 0.2) |>
  set_mode('regression')
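To see how a spec maps onto its engine call (and where the tune() placeholders land), translate() prints the underlying fit template:

boost_tree_xgboost_spec |> translate()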

#parsnip_addin() # Other available algorithms (GPU-usage is a different story)

### Set up
wfs_models = workflow_set(preproc = list(basic = rec_basic_formula,
                                         additional = rec_additional_formula),
                          models = list(glmnet = linear_reg_glmnet_spec,
                                        svm = svm_poly_kernlab_spec,
                                        boosted_tree = boost_tree_xgboost_spec),
                          cross = TRUE)  # all recipe x model combinations
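Each row of the set gets an id of the form preproc_model (e.g. basic_glmnet with the names above), and a single unfit workflow can be pulled back out by id:

wfs_models  # one row per recipe x model combination
wfs_models |> extract_workflow('basic_glmnet')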

### Run the models
tune_models = wfs_models |>
  workflow_map('tune_grid',
               resamples = cross_folds,
               grid = 3,
               metrics = metric_set(huber_loss),
               control = control_grid(save_pred = TRUE))  # control_grid(), not control_resamples(), for tune_grid
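Before ranking, the raw resampling metrics for every workflow and grid point are available via collect_metrics():

collect_metrics(tune_models) |> head()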

### Graphing and numbers
autoplot(tune_models)

best_model = rank_results(tune_models,
                          rank_metric = 'huber_loss',
                          select_best = TRUE)  # keep only each workflow's best config

### Fits -- PROBLEM STARTS HERE
best_fits = collect_predictions(tune_models)  # out-of-sample predictions for every workflow and config

df_yy_hat = best_model |>
  filter.(rank == 1) |>
  inner_join.(best_fits, by = c('wflow_id', '.config')) |>
  # with repeats = 2, each row can carry more than one prediction; average per row
  summarise.(yy_hat = mean(.pred), .by = .row) |>
  arrange.(.row) |>  # restore df's row order before the cbind() below
  select.(yy_hat)

### Output
df_fitted = df |>
  cbind(df_yy_hat)

df_fitted |> head()

### Graphing the fit
df_fitted |>
  select.(datee, yy, yy_hat) |>
  ggplot() +
  geom_point(aes(y = yy,
                 x = datee)) +
  geom_point(aes(y = yy_hat,
                 x = datee),
             colour = 'red')

Thank you for the post! Max and I (workflowsets maintainers) just chatted about this briefly, and I've opened up an issue on the GitHub repository for further discussion. Your current approach seems solid, though we may eventually have a more, as you say, elegant API for this in the future. 🙂
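In the meantime, the closest equivalent to your single-workflow pattern is to pull the pieces out of the workflow set yourself with the extract_*() helpers. A sketch, using the huber_loss metric from your reprex:

# id of the top-ranked workflow
best_id = rank_results(tune_models, rank_metric = 'huber_loss', select_best = TRUE) |>
  head(1) |>
  pull(wflow_id)

# its best hyperparameters, taken from that workflow's tuning results
best_params = tune_models |>
  extract_workflow_set_result(best_id) |>
  select_best(metric = 'huber_loss')

# finalize the (unfit) workflow and refit it on the full data,
# just as with a single workflow
final_fit = tune_models |>
  extract_workflow(best_id) |>
  finalize_workflow(best_params) |>
  fit(df)

predict(final_fit, df) |> head()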
