I'm trying to use race_tune
with workflowsets but am having an issue using save_pred = TRUE
.
reprex below shows that i use save_pred = TRUE
within control_race()
, but when i run collect_predictions()
i get an error: The '.predictions' column does not exist. Refit with the control argument 'save_pred = TRUE'
library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.1.1
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
#> Warning: package 'dials' was built under R version 4.1.1
#> Warning: package 'ggplot2' was built under R version 4.1.1
#> Warning: package 'infer' was built under R version 4.1.1
#> Warning: package 'parsnip' was built under R version 4.1.1
#> Warning: package 'tune' was built under R version 4.1.1
#> Warning: package 'workflows' was built under R version 4.1.1
#> Warning: package 'workflowsets' was built under R version 4.1.1
#> Warning: package 'yardstick' was built under R version 4.1.1
library(discrim)
#> Warning: package 'discrim' was built under R version 4.1.1
#>
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#>
#> smoothness
library(workflowsets)
library(finetune)
#> Warning: package 'finetune' was built under R version 4.1.1
data(parabolic)
set.seed(1)
split <- initial_split(parabolic)
train_set <- training(split)
test_set <- testing(split)
set.seed(2)
train_resamples <- bootstraps(train_set, times = 5)
mars_disc_spec <-
discrim_flexible(prod_degree = tune()) %>%
set_engine("earth")
reg_disc_sepc <-
discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>%
set_engine("klaR")
cart_spec <-
decision_tree(cost_complexity = tune(), min_n = tune()) %>%
set_engine("rpart") %>%
set_mode("classification")
all_workflows <-
workflow_set(
preproc = list("formula" = class ~ .),
models = list(regularized = reg_disc_sepc, mars = mars_disc_spec, cart = cart_spec)
)
class_metrics <-
metric_set(roc_auc, accuracy, sensitivity, specificity)
race_ctrl <-
control_race(
verbose = TRUE,
allow_par = TRUE,
save_pred = TRUE,
parallel_over = "everything",
save_workflow = TRUE
)
doParallel::registerDoParallel()
wf_res <-
all_workflows %>%
workflow_map(fn = "tune_race_anova",
resamples = train_resamples,
grid = 10,
metrics = class_metrics,
ctrl = race_ctrl
)
#> Warning: The `...` are not used in this function but one or more objects were
#> passed: 'ctrl'
#> Warning: The `...` are not used in this function but one or more objects were
#> passed: 'ctrl'
#> Warning: The `...` are not used in this function but one or more objects were
#> passed: 'ctrl'
workflowsets::collect_predictions(wf_res)
#> Error: Problem with `mutate()` column `predictions`.
#> i `predictions = purrr::map(...)`.
#> x The `.predictions` column does not exist. Refit with the control argument `save_pred = TRUE` to save predictions.
Created on 2021-10-21 by the reprex package (v2.0.1)