Extracting Rules from an rpart Model

Hi,

I have a decision tree that I have tuned (code below). I want to extract the rules into a data frame with one row per rule, with a column for the rule and a column for the classification. I have managed to tune it, but when I extract the engine it won't let me convert it into rules.

Can anyone help?

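library(tidymodels)  # parsnip, workflows, tune, dials, rsample, yardstick, recipes
# mod_df (the data) and uv_rec (the recipe) are defined elsewhere and not shown

# SET UP METRICS AND CONTROL GRID -----------------------------------------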
mset <- metric_set(recall, precision, f_meas, j_index)

grid_control <- control_grid(save_workflow = TRUE,
                             save_pred = TRUE,
                             extract = extract_fit_engine)  # extract_model() is deprecated

uv_split <- initial_time_split(mod_df)
train_data <- training(uv_split)
test_data <- testing(uv_split)

folds <- rsample::vfold_cv(train_data, v = 10, 
                           strata = cons_ipr_confirmed)

# DECISION TREE -----------------------------------------------------------
library(rpart.plot)  # provides rpart.rules(); also loads rpart

dt_mod <-
  decision_tree(tree_depth = tune(), 
                min_n = tune(), 
                cost_complexity = tune()) %>%
  set_engine('rpart') %>%
  set_mode('classification')

tree_wf <- workflow() %>%
  add_model(dt_mod) %>%
  add_recipe(uv_rec, 
             blueprint = hardhat::default_recipe_blueprint(allow_novel_levels = TRUE))

tree_grid <- tree_wf %>% 
  extract_parameter_set_dials() %>% 
  grid_max_entropy(size = 20)

my_res <- tree_wf %>% 
  tune_grid(
    resamples = folds,
    grid = tree_grid,
    control = grid_control,
    metrics = metric_set(j_index)
  )


best_j_index <- select_best(my_res)
final_tree <- finalize_workflow(tree_wf, best_j_index)
final_res_tree <- last_fit(final_tree, uv_split, metrics = mset)
collect_metrics(final_res_tree)
show_best(my_res)

tree_fit_rpart <- extract_fit_engine(final_res_tree)
# Print the rules
rpart.rules(tree_fit_rpart)

# Prints rules but is missing the label
# (Output redacted because it is company data.)

Warning message:
Cannot retrieve the data used to build the model (so cannot determine roundint and is.binary for the variables).
To silence this warning:
    Call rpart.rules with roundint=FALSE,
    or rebuild the rpart model with model=TRUE. 
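For context, here is a minimal base-rpart sketch of what the warning's second suggestion changes. It uses the built-in `iris` data as a stand-in for the redacted company data (an assumption) and simply checks whether a copy of the training data (the model frame) is kept on the fitted object:

```r
library(rpart)

# Stand-in data (assumption): iris instead of the redacted company data.
fit_default <- rpart(Species ~ ., data = iris, method = "class")
fit_keep    <- rpart(Species ~ ., data = iris, method = "class", model = TRUE)

# Default fit: no copy of the model frame is stored on the object.
print(is.null(fit_default$model))     # [1] TRUE

# model = TRUE keeps the model frame, i.e. the "data used to build the
# model" that the rpart.rules() warning says it cannot retrieve.
print(is.data.frame(fit_keep$model))  # [1] TRUE
print(dim(fit_keep$model))            # [1] 150   5  (Species + 4 predictors)
```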

What happens if you follow the suggestion made in the warning? Does that solve your problem?

library(parsnip)
library(rpart.plot)
#> Loading required package: rpart
data(two_class_dat, package = "modeldata")

set.seed(1)
tree_fit <- 
  decision_tree(tree_depth = 30) %>% 
  set_mode("classification") %>% 
  set_engine("rpart") %>% 
  fit(Class ~ ., data = two_class_dat)

tree_fit %>% 
  extract_fit_engine() %>% 
  rpart.rules()
#> Warning: Cannot retrieve the data used to build the model (so cannot determine roundint and is.binary for the variables).
#> To silence this warning:
#>     Call rpart.rules with roundint=FALSE,
#>     or rebuild the rpart model with model=TRUE.
#>  Class                                
#>   0.15 when B <  1.5                  
#>   0.28 when B is 1.5 to 2.1 & A >= 2.6
#>   0.75 when B is 1.5 to 2.1 & A <  2.6
#>   0.88 when B >=        2.1

# suggestion 1 from the warning: set `roundint = FALSE`
tree_fit %>% 
  extract_fit_engine() %>% 
  rpart.rules(roundint = FALSE)
#>  Class                                
#>   0.15 when B <  1.5                  
#>   0.28 when B is 1.5 to 2.1 & A >= 2.6
#>   0.75 when B is 1.5 to 2.1 & A <  2.6
#>   0.88 when B >=        2.1

# suggestion 2: rebuild the model with `model = TRUE`
set.seed(1)
tree_fit <- 
  decision_tree(tree_depth = 30) %>% 
  set_mode("classification") %>% 
  set_engine("rpart", model = TRUE) %>% 
  fit(Class ~ ., data = two_class_dat)

tree_fit %>% 
  extract_fit_engine() %>% 
  rpart.rules()
#>  Class                                
#>   0.15 when B <  1.5                  
#>   0.28 when B is 1.5 to 2.1 & A >= 2.6
#>   0.75 when B is 1.5 to 2.1 & A <  2.6
#>   0.88 when B >=        2.1

Created on 2023-12-12 with reprex v2.0.2
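The numbers in the `Class` column above appear to be the fitted probability of the second factor level of the response (for `two_class_dat`, presumably `Class2`) rather than a class label. A minimal base-rpart sketch of that interpretation, assuming a two-species subset of `iris` as a stand-in for `two_class_dat`: each leaf's fitted probability should equal the share of second-level rows among the training rows that fall into that leaf.

```r
library(rpart)

# Stand-in for two_class_dat (assumption): keep only two iris species
# so the response has exactly two levels.
iris2 <- droplevels(subset(iris, Species != "setosa"))
fit <- rpart(Species ~ ., data = iris2, method = "class", model = TRUE)

second_level <- levels(iris2$Species)[2]  # "virginica"

# Fitted probability of the second level for every training row.
prob_second <- predict(fit, type = "prob")[, second_level]

# fit$where maps each training row to its leaf (a row of fit$frame);
# the proportion of second-level rows per leaf should match prob_second.
share_second <- tapply(iris2$Species == second_level, fit$where, mean)

print(round(share_second, 2))
```

Under the default 0.5 cutoff, a rule whose number is above 0.5 would then correspond to the second level and one below 0.5 to the first.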


Hi @hannah

Thanks very much. I wasn't aware of where in the model setup to apply the suggestions.
I'm curious, though: the rules you've shown me don't actually tell me how each case was classified.

If I plot the tree using your code above as a base, the classification is clear, but it isn't part of the rules themselves. Do you know how I can show which rules give which classification?

tree_fit %>% 
  extract_fit_engine() %>% 
  rpart.plot()

Thanks very much for your time

That seems more like an rpart.plot question than a tidymodels question, and I'm not really familiar with that package. You could derive the hard classification from the class probability, but maybe the rpart.plot docs have a better suggestion!
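Following that suggestion, here is a minimal sketch that skips `rpart.rules()` and builds the requested data frame (one row per rule, with the rule text and its classification) directly from the fitted rpart object. It assumes the built-in `iris` data as a stand-in, reads each leaf's class probabilities from `fit$frame$yval2`, takes the most probable class as the hard classification, and uses `rpart::path.rpart()` to recover the rule text:

```r
library(rpart)

# Stand-in data (assumption): iris instead of the redacted company data.
fit <- rpart(Species ~ ., data = iris, method = "class", model = TRUE)

frame   <- fit$frame
is_leaf <- frame$var == "<leaf>"
leaf_id <- as.integer(rownames(frame)[is_leaf])  # rpart node numbers

# path.rpart() returns, for each node, the splits from the root down;
# drop the leading "root" entry and join the remaining splits with " & ".
paths <- path.rpart(fit, nodes = leaf_id, print.it = FALSE)
rule  <- vapply(paths, function(p) paste(p[-1], collapse = " & "), character(1))

# For a class tree, yval2 columns are: fitted class, k class counts,
# k class probabilities, node probability.
ylevels <- attr(fit, "ylevels")
k       <- length(ylevels)
prob    <- frame$yval2[is_leaf, (k + 2):(2 * k + 1), drop = FALSE]
colnames(prob) <- ylevels

# Hard classification = most probable class in each leaf.
classification <- ylevels[apply(prob, 1, which.max)]

# Sanity check against rpart's own fitted class (frame$yval indexes
# ylevels); with default priors and losses the two should agree.
stopifnot(identical(classification, ylevels[frame$yval[is_leaf]]))

rules_df <- data.frame(
  node           = leaf_id,
  rule           = unname(rule),
  classification = classification,
  prob           = round(apply(prob, 1, max), 3),
  n              = frame$n[is_leaf],
  stringsAsFactors = FALSE
)

print(rules_df)
```

With the tuned workflow from the question, the same steps would presumably be run on `tree_fit_rpart <- extract_fit_engine(final_res_tree)` in place of `fit`. Because the sketch only reads `fit$frame` and the stored splits, it should not need `model = TRUE`, though that is untested here. The rule text comes out in the same split notation that `print(fit)` uses, not in the `rpart.rules()` style.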
