In the code below, I am trying to pass a vector of weights to xgboost using the tidymodels framework. My understanding is that arguments can be passed to the underlying model functions with parsnip::set_engine()
. However, I am unclear on how to appropriately pass the weights argument to xgboost::xgb.train()
. I've tried a few ideas and tinkered with the variable roles, but have not had success. Any help is appreciated!
library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.0.2
#> -- Attaching packages -------------------------------------------------------------- tidymodels 0.1.1 --
#> v broom 0.7.0 v recipes 0.1.13
#> v dials 0.0.8 v rsample 0.0.7
#> v dplyr 1.0.0 v tibble 3.0.3
#> v ggplot2 3.3.2 v tidyr 1.1.0
#> v infer 0.5.3 v tune 0.1.1
#> v modeldata 0.0.2 v workflows 0.1.2
#> v parsnip 0.1.2 v yardstick 0.0.7
#> v purrr 0.3.4
#> Warning: package 'broom' was built under R version 4.0.2
#> Warning: package 'dials' was built under R version 4.0.2
#> Warning: package 'scales' was built under R version 4.0.2
#> Warning: package 'dplyr' was built under R version 4.0.2
#> Warning: package 'ggplot2' was built under R version 4.0.2
#> Warning: package 'infer' was built under R version 4.0.2
#> Warning: package 'modeldata' was built under R version 4.0.2
#> Warning: package 'parsnip' was built under R version 4.0.2
#> Warning: package 'purrr' was built under R version 4.0.2
#> Warning: package 'recipes' was built under R version 4.0.2
#> Warning: package 'rsample' was built under R version 4.0.2
#> Warning: package 'tibble' was built under R version 4.0.2
#> Warning: package 'tidyr' was built under R version 4.0.2
#> Warning: package 'tune' was built under R version 4.0.2
#> Warning: package 'workflows' was built under R version 4.0.2
#> Warning: package 'yardstick' was built under R version 4.0.2
#> -- Conflicts ----------------------------------------------------------------- tidymodels_conflicts() --
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
data <- tibble(outcome = rnorm(3000, 100, 15),
pred_1 = outcome + rnorm(3000, 0, .6),
pred_2 = sample(c("lev1", "lev2", "lev3"),
size = 3000,
replace = TRUE),
the_weights = round(runif(3000, 1, 7), 0))
data <- mutate_if(data, is.character, factor)
data_split <- initial_split(data,
prop = .75,
strata = outcome)
training <- training(data_split)
testing <- testing(data_split)
my_recipe <- recipe(outcome ~ ., data = training) %>%
step_nzv(all_nominal()) %>%
step_dummy(all_nominal(), one_hot = TRUE) %>%
update_role(the_weights, new_role = "weights") #do i need a new role here?
xgb_spec <- boost_tree(trees = 200,
tree_depth = tune(),
mtry = tune(),
learn_rate = tune()) %>%
set_engine("xgboost", params = list(weight = the_weights)) %>% #attempting to pass the weights to xgb.train()
set_mode("regression")
xgb_grid <- grid_latin_hypercube(
tree_depth(),
finalize(mtry(), training),
learn_rate(),
size = 6
)
xgb_wf <- workflow() %>%
add_recipe(my_recipe) %>%
add_model(xgb_spec)
xgb_folds <- vfold_cv(training, strata = outcome, v = 2)
xgb_res <- tune_grid(
object = xgb_wf,
grid = xgb_grid,
resamples = xgb_folds,
control = control_grid(save_pred = TRUE)
)
#> x Fold1: model 1/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold1: model 2/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold1: model 3/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold1: model 4/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold1: model 5/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold1: model 6/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 1/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 2/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 3/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 4/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 5/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 6/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> Warning: All models failed in tune_grid(). See the `.notes` column.
Created on 2020-08-05 by the reprex package (v0.3.0)