Tidymodels using tune_race_anova with a workflow

There were three problems:

  • There isn't a perf argument; I think you meant metrics. I didn't see that either when I looked at your code.
  • In the recipe, step_dummy(all_nominal()) was also capturing the outcome, Status. This is a common mistake, and the development version of recipes adds all_nominal_predictors() to avoid it. Until then, use step_dummy(all_nominal(), -Status). That said, the ranger package does not require dummy variables for its predictors, so you can skip that step if you want.
  • Your grid code returns a parameter set rather than a grid of candidate values. You could pass that object to the param_info argument, or build the grid with one of the grid functions, for example:
rf_grid <- dials::parameters(
   finalize(mtry(), select(credit_data, -Status)),
   trees(),
   min_n()) %>% 
   grid_random(5)
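A minimal sketch of the dummy-variable fix from the second point, assuming a current tidymodels install in which the modeldata package supplies credit_data. Incomplete rows are dropped with na.omit() only to keep the example short, and `credit_complete`, `dummy_rec`, and `baked` are illustrative names.

```r
library(tidymodels)  # attaches recipes, dplyr and modeldata

data("credit_data", package = "modeldata")
credit_complete <- na.omit(credit_data)  # illustrative: skip imputation here

# Exclude the outcome so step_dummy() only converts the nominal predictors
dummy_rec <- recipe(Status ~ ., data = credit_complete) %>%
  step_dummy(all_nominal(), -Status)

# prep() estimates the step; bake(new_data = NULL) returns the processed training set
baked <- dummy_rec %>% prep() %>% bake(new_data = NULL)

# Status is still a single factor column; predictors such as Home become indicators
str(baked$Status)
```

With ranger you can simply leave this step out, as noted above.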

One other thing: this data set has some missing values, so you might want to add one of the imputation steps to the recipe (otherwise ranger will error). In the script below I simply dropped the incomplete rows with na.omit().
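If you would rather keep every row, here is a sketch of the imputation route. It assumes recipes 0.1.16 or later, where the steps are named step_impute_median() and step_impute_mode() (older releases call them step_medianimpute() and step_modeimpute()). Median/mode imputation is just one simple choice; step_impute_knn() and step_impute_bag() are alternatives. `impute_rec` and `imputed` are illustrative names.

```r
library(tidymodels)

data("credit_data", package = "modeldata")

# Fill numeric predictors with their training-set median and nominal
# predictors with their most frequent level; the outcome is left alone.
impute_rec <- recipe(Status ~ ., data = credit_data) %>%
  step_impute_median(all_numeric()) %>%
  step_impute_mode(all_nominal(), -Status)

imputed <- impute_rec %>% prep() %>% bake(new_data = NULL)

# Remaining missing values per column (predictors should all be zero)
colSums(is.na(imputed))
```

In the workflow you would add these two steps to the `preprocess` recipe instead of calling na.omit().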

Here's my script:

library(tidyverse)
library(tidymodels)
library(finetune)

set.seed(4595)
data("credit_data")

credit_data <- credit_data %>% na.omit()

data_split <- initial_split(credit_data, strata = "Status", prop = 0.75)

train_explore <- training(data_split)
test_explore  <- testing(data_split)

# 10-fold cross-validation stratified by Status (a single repeat)
report_resamples <- vfold_cv(train_explore, v = 10, repeats = 1, strata = Status)

# Preprocessing recipe: downsample the majority class of Status
preprocess <- train_explore %>%
   recipe(Status ~ .) %>%
   themis::step_downsample(Status) 

# BUILD A RANDOM FOREST MODEL ---------------------------------------------
rf_mod <- rand_forest(
   mtry = tune(),
   trees = tune(),
   min_n = tune()) %>%
   set_mode("classification") %>%
   set_engine("ranger")

rf_grid <- dials::parameters(
   finalize(mtry(), select(credit_data, -Status)),
   trees(),
   min_n()) %>% 
   grid_random(5)

tune_wf <- workflow() %>%
   add_recipe(preprocess) %>%
   add_model(rf_mod)

set.seed(345)
tune_res_rf <- tune_race_anova(tune_wf,
                               resamples = report_resamples,
                               grid = rf_grid,
                               metrics = metric_set(roc_auc, sens, kap, accuracy)
)