I'm having some trouble using tidymodels workflows to fit a tuned xgboost model with cross-validation. When I check the "notes" column, I see quite a few errors. The reprex below mimics my data.
Am I missing something obvious?
library(doParallel)
#> Warning: package 'doParallel' was built under R version 4.0.2
#> Loading required package: foreach
#> Warning: package 'foreach' was built under R version 4.0.2
#> Loading required package: iterators
#> Warning: package 'iterators' was built under R version 4.0.2
#> Loading required package: parallel
library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.0.2
#> -- Attaching packages ---------------------------------------------------------------------- tidymodels 0.1.1 --
#> v broom 0.7.0 v recipes 0.1.13
#> v dials 0.0.8 v rsample 0.0.7
#> v dplyr 1.0.0 v tibble 3.0.3
#> v ggplot2 3.3.2 v tidyr 1.1.0
#> v infer 0.5.3 v tune 0.1.1
#> v modeldata 0.0.2 v workflows 0.1.2
#> v parsnip 0.1.2 v yardstick 0.0.7
#> v purrr 0.3.4
#> Warning: package 'broom' was built under R version 4.0.2
#> Warning: package 'dials' was built under R version 4.0.2
#> Warning: package 'scales' was built under R version 4.0.2
#> Warning: package 'dplyr' was built under R version 4.0.2
#> Warning: package 'ggplot2' was built under R version 4.0.2
#> Warning: package 'infer' was built under R version 4.0.2
#> Warning: package 'modeldata' was built under R version 4.0.2
#> Warning: package 'parsnip' was built under R version 4.0.2
#> Warning: package 'purrr' was built under R version 4.0.2
#> Warning: package 'recipes' was built under R version 4.0.2
#> Warning: package 'rsample' was built under R version 4.0.2
#> Warning: package 'tibble' was built under R version 4.0.2
#> Warning: package 'tidyr' was built under R version 4.0.2
#> Warning: package 'tune' was built under R version 4.0.2
#> Warning: package 'workflows' was built under R version 4.0.2
#> Warning: package 'yardstick' was built under R version 4.0.2
#> -- Conflicts ------------------------------------------------------------------------- tidymodels_conflicts() --
#> x purrr::accumulate() masks foreach::accumulate()
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
#> x purrr::when() masks foreach::when()
library(tidyverse)
#> Warning: package 'tidyverse' was built under R version 4.0.2
#> Warning: package 'readr' was built under R version 4.0.2
#> Warning: package 'stringr' was built under R version 4.0.2
#> Warning: package 'forcats' was built under R version 4.0.2
set.seed(3434)
data <- tibble(outcome = rnorm(3000, 100, 15),
pred_1 = rnorm(3000, 20, 10),
pred_2 = sample(c("lev1", "lev2", "lev3"),
size = 3000,
replace = TRUE),
pred_3 = sample(c("lev1", "lev2", "lev3"),
size = 3000,
replace = TRUE),
pred_4 = sample(c("lev1", "lev2", "lev3"),
size = 3000,
replace = TRUE))
data <- mutate_if(data, is.character, factor)
data_split <- initial_split(data,
prop = .75,
strata = outcome)
training <- training(data_split)
testing <- testing(data_split)
my_recipe <- recipe(outcome ~ ., data = training) %>%
step_nzv(all_nominal()) #remove near-zero-variance predictors
xgb_spec <- boost_tree(trees = 200,
tree_depth = tune(), #number of splits
mtry = tune(), #introducing randomness
learn_rate = tune()) %>% #step size
set_engine("xgboost") %>%
set_mode("regression")
xgb_grid <- grid_latin_hypercube(
tree_depth(),
finalize(mtry(), training), #based on # of predictors
learn_rate(),
size = 6
)
xgb_wf <- workflow() %>%
add_recipe(my_recipe) %>%
add_model(xgb_spec)
xgb_folds <- vfold_cv(training, strata = outcome, v = 10)
registerDoParallel()
xgb_res <- tune_grid(
object = xgb_wf,
grid = xgb_grid,
resamples = xgb_folds,
control = control_grid(save_pred = TRUE)
)
#> Warning: All models failed in tune_grid(). See the `.notes` column.