Differences between last_fit and fit

Hi tidymodels team,

I'm building a model with the ranger engine for a classification task. I tuned the model and got the best mtry and min_n, then used last_fit() on the split object to fit the model on the full training set with the best mtry and min_n:

# the last model
last_rf_mod <- 
  rand_forest(mtry = rf_best$mtry, min_n = rf_best$min_n, trees = 1000) %>% 
  set_engine("ranger", num.threads = cores, importance = "impurity") %>% 
  set_mode("classification")

last_rf_workflow <- 
  rf_workflow %>% 
  update_model(last_rf_mod)

set.seed(10086)
last_rf_fit <- 
  last_rf_workflow %>% 
  last_fit(., split = splits) 

Then I used collect_metrics() to get the model performance on the test set:

last_rf_fit %>% collect_metrics()

However, I also tried the fit() function together with predict() to evaluate the performance on the test data:

last_rf_mod <- 
  rand_forest(mtry = rf_best$mtry, min_n = rf_best$min_n, trees = 1000) %>% 
  set_engine("ranger", num.threads = cores, importance = "impurity") %>% 
  set_mode("classification")

set.seed(10086)
rf_cls_fit <- last_rf_mod %>% fit(outcome ~ ., data = TrainSet)
rf_cls_fit

predict_res <- bind_cols(
    predict(rf_cls_fit, TestSet),
    predict(rf_cls_fit, TestSet, type = "prob")
)

Then I calculated the AUC with the roc_auc() function. But this gives a different number from the AUC in collect_metrics(). Am I doing something totally wrong? Looking forward to your reply!
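
For reference, the AUC step looked roughly like this (the probability column name is a placeholder for my data):

predict_res <- predict_res %>% mutate(outcome = TestSet$outcome)

# .pred_yes is a placeholder for the event-class probability column
predict_res %>% roc_auc(outcome, .pred_yes)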

Best,

Ben

It's hard to tell without a reproducible example and some results. We don't know if the difference is large or within the noise.

I would suggest passing the seed option as an engine argument to both model fit calls. Otherwise we use a seed based on the current random number stream; the two fits should end up the same, but you never know.
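
For example, something along these lines (the seed value is arbitrary; just use the same one in both fits):

last_rf_mod <- 
  rand_forest(mtry = rf_best$mtry, min_n = rf_best$min_n, trees = 1000) %>% 
  # `seed` is passed through to ranger() so both fits use the same internal random seed
  set_engine("ranger", num.threads = cores, importance = "impurity", seed = 63233) %>% 
  set_mode("classification")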

Hi Max,

Thanks for your fast response. I've attached reproducible code with example data below:

library(tidymodels)
library(readr)

Read and split data

hotels <- 
  read_csv("https://tidymodels.org/start/case-study/hotels.csv") %>%
  mutate(across(where(is.character), as.factor))

set.seed(123)
splits      <- initial_split(hotels, strata = children)

hotel_other <- training(splits)
hotel_test  <- testing(splits)


set.seed(234)
val_set <- validation_split(hotel_other, 
                            strata = children, 
                            prop = 0.80)

Set workflow and tune

cores <- 2

rf_mod <- 
  rand_forest(mtry = tune(), min_n = tune(), trees = 2) %>% 
  set_engine("ranger", num.threads = cores) %>% 
  set_mode("classification")

rf_recipe <- 
  recipe(children ~ ., data = hotel_other) %>% 
  step_date(arrival_date) %>% 
  step_holiday(arrival_date) %>% 
  step_rm(arrival_date) 

rf_workflow <- 
  workflow() %>% 
  add_model(rf_mod) %>% 
  add_recipe(rf_recipe)

set.seed(345)
rf_res <- 
  rf_workflow %>% 
  tune_grid(val_set,
            grid = 1,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc))

Select the best hyperparameters and build the final model and workflow

rf_best <- 
  rf_res %>% 
  select_best(metric = "roc_auc")

last_rf_mod <- 
  rand_forest(mtry = rf_best$mtry, min_n = rf_best$min_n, trees = 2) %>% 
  set_engine("ranger", num.threads = cores, importance = "impurity") %>% 
  set_mode("classification")

# the last workflow
last_rf_workflow <- 
  rf_workflow %>% 
  update_model(last_rf_mod)

Try the last_fit() function

# the last fit
set.seed(345)
last_rf_fit <- 
  last_rf_workflow %>% 
  last_fit(splits)

last_rf_fit %>% 
  collect_metrics()

The AUC from collect_metrics() after using last_fit() is 0.76945944 for me.

But when I try the fit() function with the same final model spec:

rf_cls_fit <- last_rf_mod %>% fit(children ~ ., data = hotel_other)
rf_cls_fit

Predict on the test data and calculate the AUC:

predict_res <- bind_cols(
    predict(rf_cls_fit, hotel_test),
    predict(rf_cls_fit, hotel_test, type = "prob")
)

predict_res <- predict_res %>% mutate(children = hotel_test$children)

predict_res %>% roc_auc(children, .pred_children)

It gave me 0.7734074, and the confusion matrices differ even more.
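
For reference, the two confusion matrices can be compared like this, using the hard-class predictions:

# test-set predictions stored by last_fit()
last_rf_fit %>% 
  collect_predictions() %>% 
  conf_mat(children, .pred_class)

# predictions from the manually fitted model
predict_res %>% 
  conf_mat(children, .pred_class)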

Best,

Ben

I believe that the issue is that you've separated the model from the recipe; last_rf_mod is a parsnip model spec and doesn't use the recipe.

Here's some code that finalizes the workflow and then estimates the final model via last_fit() and via plain fit(), both using the workflow:

library(tidymodels)
library(readr)
#> 
#> Attaching package: 'readr'
#> The following object is masked from 'package:yardstick':
#> 
#>     spec
#> The following object is masked from 'package:scales':
#> 
#>     col_factor

hotels <- 
  read_csv("https://tidymodels.org/start/case-study/hotels.csv") %>%
  mutate(across(where(is.character), as.factor))
#> Rows: 50000 Columns: 23
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> chr  (11): hotel, children, meal, country, market_segment, distribution_chan...
#> dbl  (11): lead_time, stays_in_weekend_nights, stays_in_week_nights, adults,...
#> date  (1): arrival_date
#> 
#> ℹ Use `spec()` to retrieve the full column specification for this data.
#> ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

set.seed(123)
splits      <- initial_split(hotels, strata = children)

hotel_other <- training(splits)
hotel_test  <- testing(splits)


set.seed(234)
val_set <- validation_split(hotel_other, 
                            strata = children, 
                            prop = 0.80)
#> Warning: `validation_split()` was deprecated in rsample 1.2.0.
#> ℹ Please use `initial_validation_split()` instead.
#> This warning is displayed once every 8 hours.
#> Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
#> generated.

cores <- 2

rf_mod <- 
  rand_forest(mtry = tune(), min_n = tune(), trees = 3) %>% 
  set_engine("ranger", num.threads = cores, importance = "impurity", seed = 24) %>% 
  set_mode("classification")

rf_recipe <- 
  recipe(children ~ ., data = hotel_other) %>% 
  step_date(arrival_date) %>% 
  step_holiday(arrival_date) %>% 
  step_rm(arrival_date) 

rf_workflow <- 
  workflow() %>% 
  add_model(rf_mod) %>% 
  add_recipe(rf_recipe)

set.seed(345)
rf_res <- 
  rf_workflow %>% 
  tune_grid(val_set,
            grid = 1,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc))
#> i Creating pre-processing data to finalize unknown parameter: mtry
# Choose final model

rf_best <- 
  rf_res %>% 
  select_best(metric = "roc_auc")

# the last workflow
last_rf_workflow <- 
  rf_workflow %>% 
  finalize_workflow(rf_best)
# Last fit based off of `last_fit()`

set.seed(345)
last_fit_mod_fit <- 
  last_rf_workflow %>% 
  last_fit(splits)

last_fit_mod_fit %>% 
  collect_metrics()
#> # A tibble: 3 × 4
#>   .metric     .estimator .estimate .config             
#>   <chr>       <chr>          <dbl> <chr>               
#> 1 accuracy    binary        0.930  Preprocessor1_Model1
#> 2 roc_auc     binary        0.804  Preprocessor1_Model1
#> 3 brier_class binary        0.0580 Preprocessor1_Model1
# Directly use `fit()`

set.seed(345)
fit_mod_fit <- last_rf_workflow %>% fit(data = hotel_other)

fit_mod_fit %>% 
  augment(hotel_test) %>% 
  roc_auc(truth = children, .pred_children)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.804

Created on 2024-05-14 with reprex v2.0.2
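
If you want to double-check, extract_workflow() pulls the fitted workflow out of the last_fit() results, so its test-set predictions can be compared directly with the fit() version (a quick sketch):

# the workflow fitted by last_fit() on the training data
extract_workflow(last_fit_mod_fit) %>% 
  augment(hotel_test) %>% 
  roc_auc(truth = children, .pred_children)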


Thanks! That is much clearer! :100:
