I have a decision tree that I have tuned (code below). I want to extract the rules into a data frame, with one row per rule and columns for the rule and the classification. I have managed to tune it, but when I extract the engine it won't let me convert it into rules.
Can anyone help?
# SET UP METRICS AND CONTROL GRID -----------------------------------------
mset <- metric_set(recall, precision, f_meas, j_index)
grid_control <- control_grid(save_workflow = TRUE,
save_pred = TRUE,
extract = extract_fit_engine)
uv_split <- initial_time_split(mod_df)
train_data <- training(uv_split)
test_data <- testing(uv_split)
folds <- rsample::vfold_cv(train_data, v = 10,
strata = cons_ipr_confirmed)
# DECISION TREE -----------------------------------------------------------
library(rpart)
dt_mod <-
decision_tree(tree_depth = tune(),
min_n = tune(),
cost_complexity = tune()) %>%
set_engine('rpart') %>%
set_mode('classification')
tree_wf <- workflow() %>%
add_model(dt_mod) %>%
add_recipe(uv_rec,
blueprint = hardhat::default_recipe_blueprint(allow_novel_levels = TRUE))
tree_grid <- tree_wf %>%
extract_parameter_set_dials() %>%
grid_max_entropy(size = 20)
my_res <- tree_wf %>%
tune_grid(
resamples = folds,
grid = tree_grid,
control = grid_control,
metrics = metric_set(j_index)
)
best_j_index <- select_best(my_res)
final_tree <- finalize_workflow(tree_wf, best_j_index)
final_res_tree <- last_fit(final_tree, uv_split, metrics = mset)
collect_metrics(final_res_tree)
show_best(my_res)
tree_fit_rpart <- extract_fit_engine(final_res_tree)
# Print the rules (rpart.rules() comes from the rpart.plot package)
library(rpart.plot)
rpart.rules(tree_fit_rpart)
# Prints the rules, but they are missing the class label
# (output redacted, as it's company data)
Warning message:
Cannot retrieve the data used to build the model (so cannot determine roundint and is.binary for the variables).
To silence this warning:
Call rpart.rules with roundint=FALSE,
or rebuild the rpart model with model=TRUE.
What happens if you follow the suggestion made in the warning? Does that solve your problem?
library(parsnip)
library(rpart.plot)
#> Loading required package: rpart
data(two_class_dat, package = "modeldata")
set.seed(1)
tree_fit <-
decision_tree(tree_depth = 30) %>%
set_mode("classification") %>%
set_engine("rpart") %>%
fit(Class ~ ., data = two_class_dat)
tree_fit %>%
extract_fit_engine() %>%
rpart.rules()
#> Warning: Cannot retrieve the data used to build the model (so cannot determine roundint and is.binary for the variables).
#> To silence this warning:
#> Call rpart.rules with roundint=FALSE,
#> or rebuild the rpart model with model=TRUE.
#> Class
#> 0.15 when B < 1.5
#> 0.28 when B is 1.5 to 2.1 & A >= 2.6
#> 0.75 when B is 1.5 to 2.1 & A < 2.6
#> 0.88 when B >= 2.1
# suggestion 1 from the warning: set `roundint = FALSE`
tree_fit %>%
extract_fit_engine() %>%
rpart.rules(roundint = FALSE)
#> Class
#> 0.15 when B < 1.5
#> 0.28 when B is 1.5 to 2.1 & A >= 2.6
#> 0.75 when B is 1.5 to 2.1 & A < 2.6
#> 0.88 when B >= 2.1
# suggestion 2: rebuild the model with `model = TRUE`
set.seed(1)
tree_fit <-
decision_tree(tree_depth = 30) %>%
set_mode("classification") %>%
set_engine("rpart", model = TRUE) %>%
fit(Class ~ ., data = two_class_dat)
tree_fit %>%
extract_fit_engine() %>%
rpart.rules()
#> Class
#> 0.15 when B < 1.5
#> 0.28 when B is 1.5 to 2.1 & A >= 2.6
#> 0.75 when B is 1.5 to 2.1 & A < 2.6
#> 0.88 when B >= 2.1
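Applied to the tuned workflow from the question, the `model = TRUE` suggestion goes into `set_engine()` on the model spec; the tuning, finalizing and `last_fit()` steps stay the same. Below is a miniature, self-contained sketch of that workflow on `two_class_dat`. The `initial_split()`, 5-fold CV, `add_formula()` preprocessor and 5-point grid are assumptions standing in for the question's `initial_time_split()`, `uv_rec` and 20-point grid, which aren't available here.

```r
library(tidymodels)
library(rpart.plot)

data(two_class_dat, package = "modeldata")

set.seed(1)
uv_split <- initial_split(two_class_dat, strata = Class)
folds <- vfold_cv(training(uv_split), v = 5, strata = Class)

# model = TRUE is an rpart::rpart() argument; set_engine() passes it through,
# so the fitted rpart object keeps its model frame for rpart.rules()
dt_mod <-
  decision_tree(tree_depth = tune(),
                min_n = tune(),
                cost_complexity = tune()) %>%
  set_engine("rpart", model = TRUE) %>%
  set_mode("classification")

# add_formula() stands in for the question's recipe
tree_wf <- workflow() %>%
  add_model(dt_mod) %>%
  add_formula(Class ~ .)

my_res <- tune_grid(tree_wf,
                    resamples = folds,
                    grid = 5,
                    metrics = metric_set(j_index))

final_res_tree <- tree_wf %>%
  finalize_workflow(select_best(my_res, metric = "j_index")) %>%
  last_fit(uv_split)

# No roundint warning expected here, since the model frame was kept
tree_rules <- final_res_tree %>%
  extract_fit_engine() %>%
  rpart.rules()
tree_rules
```

In the original code the only change should be `set_engine('rpart', model = TRUE)` in `dt_mod`; the other route from the warning is to keep the spec as-is and call `rpart.rules(tree_fit_rpart, roundint = FALSE)`.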
Thanks very much. I wasn't aware of where in the model setup to put the suggestions.
I'm curious, though: the rules you have shown me don't actually say how each case is classified.
If I print the tree using your code above as a base, the classification is clear, but it is not within the rules themselves. Do you know how I can show which classification each rule gives?
That seems more like an rpart.plot question than a tidymodels question, and I'm not really familiar with that package. You could derive the hard classification from the class probability, but maybe the rpart.plot docs have a better suggestion!
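One option that skips rpart.plot entirely is to read the rules and classes straight off the fitted rpart object: `rpart::path.rpart()` returns the chain of splits leading to each terminal node, and `fit$frame` stores the predicted class index (`yval`) and the class counts and probabilities (`yval2`) for every node. The helper `rpart_rules_df()` below is hypothetical (my own name, not part of rpart or rpart.plot), and the built-in `iris` data is only a stand-in for the redacted company data.

```r
library(rpart)

# Hypothetical helper (not part of rpart or rpart.plot): one row per terminal
# node with the full rule, the predicted class, the node size and the
# estimated probability of that class.
rpart_rules_df <- function(fit) {
  stopifnot(inherits(fit, "rpart"), fit$method == "class")
  frame <- fit$frame
  is_leaf <- frame$var == "<leaf>"
  leaf_ids <- as.integer(rownames(frame)[is_leaf])

  # path.rpart() lists the splits from the root down to each requested node;
  # the first element of every path is "root", so it is dropped here
  paths <- path.rpart(fit, nodes = leaf_ids, print.it = FALSE)
  rule <- vapply(paths, function(p) paste(p[-1], collapse = " & "),
                 character(1))

  # For classification trees, yval is the index of the predicted class and
  # yval2 holds: class index, k class counts, k class probabilities, node prob
  class_idx <- frame$yval[is_leaf]
  ylevels <- attr(fit, "ylevels")
  k <- length(ylevels)
  yval2 <- frame$yval2[is_leaf, , drop = FALSE]

  data.frame(
    node  = leaf_ids,
    rule  = unname(rule),
    class = ylevels[class_idx],
    n     = frame$n[is_leaf],
    prob  = yval2[cbind(seq_along(class_idx), 1 + k + class_idx)],
    stringsAsFactors = FALSE
  )
}

fit <- rpart(Species ~ ., data = iris)
rules <- rpart_rules_df(fit)
rules
```

With the tuned workflow from the question, the equivalent call would presumably be `rpart_rules_df(extract_fit_engine(final_res_tree))`. Because the tree is fit after the recipe runs, the variable names in the rules will be the preprocessed predictor names rather than the raw columns.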