Hi Community,
I followed the last tidymodels tutorial, "A predictive modeling case study", and got stuck at the final piece of code, which plots the ROC curve:
# Build the ROC curve from the predictions stored in the final fit and plot it.
autoplot(
  last_rf_fit %>%
    collect_predictions() %>%
    roc_curve(children, .pred_children)
)
After fitting the random forest on the hotels data, my CPU stayed at a constant 22% load.
While the random forest model was being fitted, all 12 cores ran in parallel at 100%.
I killed RStudio from the Windows Task Manager three times.
I rebooted my PC and ran the code again, and it got stuck again.
I hope you can help me with this.
Thank you so much.
My reprex is pasted after the link to the tutorial.
Link to the tutorial: tidymodels - A predictive modeling case study
(https://www.tidymodels.org/start/case-study/)
library(tidymodels)
library(readr)
# Read the hotel bookings CSV and turn every character column into a factor.
hotels <- mutate(
  read_csv("https://tidymodels.org/start/case-study/hotels.csv"),
  across(where(is.character), as.factor)
)

# Quick look at the size and structure of the data.
dim(hotels)
#> [1] 50000 23
glimpse(hotels)

# Class balance of the outcome: counts and proportions per level of `children`.
mutate(count(hotels, children), prop = n / sum(n))
# Data split and resampling
# Fix the RNG so the split is reproducible.
set.seed(123)

# Split stratified on the outcome `children`.
splits <- hotels %>%
  initial_split(strata = children)

# `hotel_other` is the non-test portion; `hotel_test` is held out.
hotel_other <- splits %>% training()
hotel_test <- splits %>% testing()
# Training set proportions by children
# Counts and proportions of each `children` level in the training portion.
mutate(count(hotel_other, children), prop = n / sum(n))
# Test set proportions by children
# Counts and proportions of each `children` level in the test portion.
mutate(count(hotel_test, children), prop = n / sum(n))
# Validation set
# Fix the RNG so the validation split is reproducible.
set.seed(234)

# Single validation split of the non-test data, stratified on `children`,
# with prop = 0.80.
# NOTE(review): validation_split() is reported as deprecated in newer rsample
# releases -- confirm against the installed version.
val_set <- hotel_other %>%
  validation_split(strata = children, prop = 0.80)

val_set
# First model: penalized logistic regression
# Logistic regression spec: the penalty is left to be tuned, mixture is fixed
# at 1, and the model is fitted with the glmnet engine.
lr_mod <- set_engine(
  logistic_reg(penalty = tune(), mixture = 1),
  "glmnet"
)
# Holidays for which step_holiday() should create features.
holidays <- c(
  "AllSouls", "AshWednesday", "ChristmasEve", "Easter",
  "ChristmasDay", "GoodFriday", "NewYearsDay", "PalmSunday"
)

# Preprocessing for the logistic regression, declared on the non-test data.
lr_recipe <-
  recipe(children ~ ., data = hotel_other) %>%
  # Date-derived and holiday features from arrival_date, then drop the raw date.
  step_date(arrival_date) %>%
  step_holiday(arrival_date, holidays = holidays) %>%
  step_rm(arrival_date) %>%
  # Dummy-encode nominal predictors, drop zero-variance columns, normalize.
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())
# Create the workflow
# Bundle the logistic regression spec and its recipe into one workflow.
lr_workflow <- workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(lr_recipe)
# Create the grid for tuning
# Candidate penalties: 30 values evenly spaced on the log10 scale from
# 1e-4 to 1e-1.
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30))

# Five lowest penalty values. slice_min()/slice_max() replace the superseded
# top_n(), which silently chose the last column as the ordering variable.
lr_reg_grid %>%
  slice_min(penalty, n = 5)

# Five highest penalty values.
lr_reg_grid %>%
  slice_max(penalty, n = 5)
# Train and tune the model
# Evaluate every penalty in lr_reg_grid on the validation set, scoring with
# ROC AUC and keeping the predictions for later ROC curves.
lr_res <- tune_grid(
  lr_workflow,
  val_set,
  grid = lr_reg_grid,
  control = control_grid(save_pred = TRUE),
  metrics = metric_set(roc_auc)
)
library(ggpubr)
library(scales)
# ROC AUC versus penalty, with the penalty on a log10 x axis.
lr_plot <-
  ggplot(collect_metrics(lr_res), aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  labs(y = "area under the roc curve") +
  scale_x_log10(labels = scales::label_number())

lr_plot
# Top models
# The 15 best candidates by ROC AUC, listed in order of increasing penalty.
# The metric is passed by name, matching the other show_best()/select_best()
# calls in this script; newer tune releases no longer accept it positionally.
top_models <-
  lr_res %>%
  show_best(metric = "roc_auc", n = 15) %>%
  arrange(penalty)

top_models
# Pick the candidate in the 12th row once the metrics are sorted by penalty.
lr_best <- slice(
  arrange(collect_metrics(lr_res), penalty),
  12
)

lr_best
# Validation-set ROC curve for the chosen penalty, labelled with the model name
# so it can be combined with other models' curves.
lr_auc <- lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  roc_curve(children, .pred_children) %>%
  mutate(model = "Logistic Regression")

lr_auc %>% autoplot()
# Second model: tree-based ensemble
# Number of CPU cores, used as the thread count for ranger.
cores <- parallel::detectCores()

# detectCores() is documented to return NA when the count cannot be
# determined; fall back to a single thread rather than passing NA on.
if (is.na(cores)) {
  cores <- 1L
}

cores
# Parsnip model
# Random forest classifier with 1000 trees; mtry and min_n are left to be
# tuned. ranger is given `cores` threads.
rf_mod <- set_mode(
  set_engine(
    rand_forest(mtry = tune(), min_n = tune(), trees = 1000),
    "ranger",
    num.threads = cores
  ),
  "classification"
)
# Create the recipe and workflow
# Preprocessing for the random forest: date-derived and holiday features from
# arrival_date (step_holiday() defaults), then drop the raw date column.
rf_recipe <- recipe(children ~ ., data = hotel_other) %>%
  step_date(arrival_date) %>%
  step_holiday(arrival_date) %>%
  step_rm(arrival_date)
# Then the workflow
# Bundle the random forest spec and its recipe into one workflow.
rf_workflow <- workflow() %>%
  add_model(rf_mod) %>%
  add_recipe(rf_recipe)
# Train and tune
# Show the model spec, then the parameters that are marked for tuning.
rf_mod
rf_mod %>% extract_parameter_set_dials()
# Fix the RNG so the tuning run is reproducible.
set.seed(345)

# Tune the random forest on the validation set over 25 candidate parameter
# combinations, scoring with ROC AUC and keeping the predictions.
rf_res <- tune_grid(
  rf_workflow,
  val_set,
  grid = 25,
  control = control_grid(save_pred = TRUE),
  metrics = metric_set(roc_auc)
)
# Best candidates by ROC AUC, followed by a plot of the tuning results.
show_best(rf_res, metric = "roc_auc")

rf_res %>% autoplot()
# Parameter combination with the best ROC AUC.
rf_best <- select_best(rf_res, metric = "roc_auc")

rf_best
# Collect predictions
# All saved validation-set predictions from the tuning run.
collect_predictions(rf_res)

# Validation-set ROC curve for the best random forest candidate, labelled with
# the model name so it can be combined with other models' curves.
rf_auc <- rf_res %>%
  collect_predictions(parameters = rf_best) %>%
  roc_curve(children, .pred_children) %>%
  mutate(model = "Random Forest")
# Now we can compare the validation-set ROC curves for our top penalized logistic regression model and our random forest model.
# Overlay both ROC curves, coloured by model, with the diagonal as reference.
ggplot(
  bind_rows(rf_auc, lr_auc),
  aes(x = 1 - specificity, y = sensitivity, color = model)
) +
  geom_path(lwd = 1.5, alpha = 0.8) +
  geom_abline(lty = 3) +
  coord_equal() +
  scale_color_viridis_d()
# Last fit
# Last model
# Final random forest spec with fixed hyperparameters.
# NOTE(review): mtry = 8 and min_n = 7 are hard-coded; presumably copied from
# the tuning results -- confirm they match rf_best.
# importance = "impurity" is forwarded to the ranger engine.
last_rf_mod <- set_mode(
  set_engine(
    rand_forest(mtry = 8, min_n = 7, trees = 1000),
    "ranger",
    num.threads = cores,
    importance = "impurity"
  ),
  "classification"
)
# Last workflow
# Reuse the random forest workflow, swapping in the final model spec.
last_rf_workflow <- update_model(rf_workflow, last_rf_mod)
# Last fit
# Fix the RNG so the final fit is reproducible.
set.seed(345)

# Run the final workflow through last_fit() on the initial split object.
last_rf_fit <- last_fit(last_rf_workflow, splits)

last_rf_fit

# Performance metrics recorded by last_fit().
collect_metrics(last_rf_fit)
library(vip)
# Variable-importance plot for the fitted model, limited to 20 features.
vip(extract_fit_parsnip(last_rf_fit), num_features = 20)
# ROC curve built from the predictions stored in the final fit.
autoplot(
  roc_curve(
    collect_predictions(last_rf_fit),
    children,
    .pred_children
  )
)