Best approach for fitting the final model after tuning

I am building a regression model for sale prices in the ames dataset, using xgboost with a grid search for parameter tuning.
As far as I understand, once the best parameters are selected, there are two similar ways to create a finalized, fitted workflow for predicting on new data:

  1. Fit the model on the whole training set (as described in Chapter 13.3 of TMWR).
  2. Extract the fitted workflow from a last_fit() object, as described in the tidymodels tutorial under “Tune Model Parameters > The Last Fit”. (I slightly prefer this, since I am using the last_fit object anyway for plotting, metrics, vip, …)

However, I am seeing slight differences between the two approaches. Where do they come from, and which approach is preferable or more accurate?

Here is an example:

library(tidymodels, warn.conflicts = FALSE)
data(ames)

ames <-
  ames |>
  select(
    Sale_Price,
    Neighborhood,
    Gr_Liv_Area,
    Year_Built,
    Bldg_Type,
    Latitude,
    Longitude
  ) |> 
  mutate(Sale_Price = log10(Sale_Price))

ames_split <- initial_split(ames, prop = 0.80)
ames_train <- training(ames_split)
ames_test  <- testing(ames_split)

# recipe, spec and workflow from usemodels::use_xgboost()

xgboost_recipe <- 
  recipe(formula = Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type + 
           Latitude + Longitude, data = ames_train) %>% 
  step_novel(all_nominal_predictors()) %>% 
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>% 
  step_zv(all_predictors()) 

xgboost_spec <- 
  boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
             loss_reduction = tune(), sample_size = tune()) %>% 
  set_mode("regression") %>% 
  set_engine("xgboost") 

xgboost_workflow <- 
  workflow() %>% 
  add_recipe(xgboost_recipe) %>% 
  add_model(xgboost_spec) 

xgboost_control <-
  control_grid(
    save_pred = TRUE,
    save_workflow = TRUE
  )

xgboost_folds <- vfold_cv(ames_train)

set.seed(46914)
xgboost_tune <-
  tune_grid(xgboost_workflow, resamples = xgboost_folds, control = xgboost_control)
#> ! Fold01: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold02: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold03: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold04: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold05: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold06: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold07: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold08: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold09: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold10: internal: A correlation computation is required, but `estimate` is constant and ha...

show_best(xgboost_tune, metric = "rmse")
#> # A tibble: 5 × 12
#>   trees min_n tree_depth learn_rate loss_…¹ sampl…² .metric .esti…³   mean     n
#>   <int> <int>      <int>      <dbl>   <dbl>   <dbl> <chr>   <chr>    <dbl> <int>
#> 1  1007    13         13    0.00823 6.52e-4   0.805 rmse    standa… 0.0722    10
#> 2  1574    37          6    0.0585  1.38e-9   0.825 rmse    standa… 0.0754    10
#> 3   897    24          3    0.0384  1.84e-2   0.382 rmse    standa… 0.0760    10
#> 4  1375    28          5    0.0180  1.33e-6   0.133 rmse    standa… 0.0790    10
#> 5   458     3         14    0.229   3.20e-5   0.536 rmse    standa… 0.0801    10
#> # … with 2 more variables: std_err <dbl>, .config <chr>, and abbreviated
#> #   variable names ¹​loss_reduction, ²​sample_size, ³​.estimator
#> # ℹ Use `colnames()` to see all variable names

xgboost_params <- select_best(xgboost_tune, metric = "rmse")

xgboost_final_workflow <- 
  xgboost_workflow %>% 
  finalize_workflow(xgboost_params)

Now the final workflow needs to be fitted, either directly on the whole training set:

xgboost_direct_fit <- fit(xgboost_final_workflow, ames_train)

or by running last_fit() and then extract_workflow():

xgboost_final_fit <- 
  xgboost_final_workflow %>%
  last_fit(ames_split) 

xgboost_final_tree <- extract_workflow(xgboost_final_fit)

However, the resulting objects are not the same:

all.equal(xgboost_direct_fit, xgboost_final_tree)
#> [1] "Component \"fit\": Component \"fit\": Component \"fit\": Component \"raw\": Lengths (2955088, 2997136) differ (comparison on first 2955088 components)"
#> [2] "Component \"fit\": Component \"fit\": Component \"fit\": Component \"raw\": 2314610 element mismatches"                                                
#> [3] "Component \"fit\": Component \"fit\": Component \"fit\": Component \"evaluation_log\": Column 'training_rmse': Mean relative difference: 0.0002653307"

and they predict slightly different values (shown for the first six rows of ames):

options(pillar.sigfig = 4)
predict(xgboost_direct_fit, head(ames))
#> # A tibble: 6 × 1
#>   .pred
#>   <dbl>
#> 1 5.251
#> 2 5.104
#> 3 5.202
#> 4 5.350
#> 5 5.274
#> 6 5.271
predict(xgboost_final_tree, head(ames))
#> # A tibble: 6 × 1
#>   .pred
#>   <dbl>
#> 1 5.249
#> 2 5.104
#> 3 5.196
#> 4 5.349
#> 5 5.273
#> 6 5.269
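To put the size of the discrepancy in perspective, here is a quick base-R calculation on the predictions printed above. The values are copied from the output, which pillar rounded to four significant digits, so the results are only approximate; 0.0722 is the best cross-validated RMSE reported by show_best().

```r
# Predictions copied from the output above (log10 scale, rounded to 4 significant digits)
shown_direct <- c(5.251, 5.104, 5.202, 5.350, 5.274, 5.271)
shown_last   <- c(5.249, 5.104, 5.196, 5.349, 5.273, 5.269)

gap <- abs(shown_direct - shown_last)
max(gap)                    # about 0.006 log10 units

# On the dollar scale, a log10 gap d corresponds to a price ratio of 10^d
10^max(gap) - 1             # about 0.014, i.e. roughly 1.4 %

# Relative to the best cross-validated RMSE from show_best()
max(gap) / 0.0722           # about 0.08
```

So the largest gap between the two final fits is roughly a twelfth of the model's own resampled error.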
Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23 ucrt)
#>  os       Windows 10 x64 (build 19044)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language en
#>  collate  German_Germany.utf8
#>  ctype    German_Germany.utf8
#>  tz       Europe/Berlin
#>  date     2022-08-17
#>  pandoc   2.18 @ C:/Program Files/RStudio/bin/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.2.0)
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  broom        * 1.0.0      2022-07-01 [1] CRAN (R 4.2.1)
#>  class          7.3-20     2022-01-16 [2] CRAN (R 4.2.1)
#>  cli            3.3.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  codetools      0.2-18     2020-11-04 [2] CRAN (R 4.2.1)
#>  colorspace     2.0-3      2022-02-21 [1] CRAN (R 4.2.0)
#>  crayon         1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>  data.table     1.14.2     2021-09-27 [1] CRAN (R 4.2.0)
#>  DBI            1.1.3      2022-06-18 [1] CRAN (R 4.2.0)
#>  dials        * 1.0.0      2022-06-14 [1] CRAN (R 4.2.0)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.2.0)
#>  digest         0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  dplyr        * 1.0.9      2022-04-28 [1] CRAN (R 4.2.0)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate       0.16       2022-08-09 [1] CRAN (R 4.2.1)
#>  fansi          1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap        1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.2.0)
#>  fs             1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  furrr          0.3.0      2022-05-04 [1] CRAN (R 4.2.0)
#>  future         1.27.0     2022-07-22 [1] CRAN (R 4.2.1)
#>  future.apply   1.9.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.2.1)
#>  ggplot2      * 3.3.6      2022-05-03 [1] CRAN (R 4.2.0)
#>  globals        0.16.0     2022-08-05 [1] CRAN (R 4.2.1)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  gower          1.0.0      2022-02-03 [1] CRAN (R 4.2.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.2.0)
#>  gtable         0.3.0      2019-03-25 [1] CRAN (R 4.2.0)
#>  hardhat        1.2.0      2022-06-30 [1] CRAN (R 4.2.1)
#>  highr          0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  htmltools      0.5.3      2022-07-18 [1] CRAN (R 4.2.1)
#>  infer        * 1.0.2      2022-06-10 [1] CRAN (R 4.2.0)
#>  ipred          0.9-13     2022-06-02 [1] CRAN (R 4.2.0)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.2.0)
#>  jsonlite       1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>  knitr          1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>  lattice        0.20-45    2021-09-22 [2] CRAN (R 4.2.1)
#>  lava           1.6.10     2021-09-02 [1] CRAN (R 4.2.0)
#>  lhs            1.1.5      2022-03-22 [1] CRAN (R 4.2.0)
#>  lifecycle      1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>  listenv        0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>  lubridate      1.8.0      2021-10-07 [1] CRAN (R 4.2.0)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  MASS           7.3-58.1   2022-08-03 [1] CRAN (R 4.2.1)
#>  Matrix         1.4-1      2022-03-23 [2] CRAN (R 4.2.1)
#>  modeldata    * 1.0.0      2022-07-01 [1] CRAN (R 4.2.1)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.2.0)
#>  nnet           7.3-17     2022-01-16 [2] CRAN (R 4.2.1)
#>  parallelly     1.32.1     2022-07-21 [1] CRAN (R 4.2.1)
#>  parsnip      * 1.0.0      2022-06-16 [1] CRAN (R 4.2.0)
#>  pillar         1.8.0      2022-07-18 [1] CRAN (R 4.2.1)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.2.0)
#>  purrr        * 0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>  R.cache        0.16.0     2022-07-21 [1] CRAN (R 4.2.1)
#>  R.methodsS3    1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo           1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils        2.12.0     2022-06-28 [1] CRAN (R 4.2.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp           1.0.9      2022-07-08 [1] CRAN (R 4.2.1)
#>  recipes      * 1.0.1      2022-07-07 [1] CRAN (R 4.2.1)
#>  reprex         2.0.1      2021-08-05 [1] CRAN (R 4.2.0)
#>  rlang          1.0.4      2022-07-12 [1] CRAN (R 4.2.1)
#>  rmarkdown      2.14       2022-04-25 [1] CRAN (R 4.2.0)
#>  rpart          4.1.16     2022-01-24 [2] CRAN (R 4.2.1)
#>  rsample      * 1.1.0      2022-08-08 [1] CRAN (R 4.2.1)
#>  rstudioapi     0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>  scales       * 1.2.0      2022-04-13 [1] CRAN (R 4.2.0)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi        1.7.8      2022-07-11 [1] CRAN (R 4.2.1)
#>  stringr        1.4.0      2019-02-10 [1] CRAN (R 4.2.0)
#>  styler         1.7.0      2022-03-13 [1] CRAN (R 4.2.0)
#>  survival       3.4-0      2022-08-09 [1] CRAN (R 4.2.1)
#>  tibble       * 3.1.8      2022-07-22 [1] CRAN (R 4.2.1)
#>  tidymodels   * 1.0.0      2022-07-13 [1] CRAN (R 4.2.1)
#>  tidyr        * 1.2.0      2022-02-01 [1] CRAN (R 4.2.0)
#>  tidyselect     1.1.2      2022-02-21 [1] CRAN (R 4.2.0)
#>  timeDate       4021.104   2022-07-19 [1] CRAN (R 4.2.1)
#>  tune         * 1.0.0      2022-07-07 [1] CRAN (R 4.2.1)
#>  utf8           1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs          0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>  withr          2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  workflows    * 1.0.0      2022-07-05 [1] CRAN (R 4.2.1)
#>  workflowsets * 1.0.0      2022-07-12 [1] CRAN (R 4.2.1)
#>  xfun           0.32       2022-08-10 [1] CRAN (R 4.2.1)
#>  xgboost      * 1.6.0.1    2022-04-16 [1] CRAN (R 4.2.0)
#>  yaml           2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>  yardstick    * 1.0.0      2022-06-06 [1] CRAN (R 4.2.0)
#> 
#>  [1] C:/Users/Daniel.AK-HAMBURG/AppData/Local/R/win-library/4.2
#>  [2] C:/Program Files/R/R-4.2.1/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

The model fit will use random numbers. Try setting the seed to the same value before each method of producing the final fit.
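A minimal, self-contained sketch of that suggestion. It assumes a reduced version of the data and fixed hyperparameters: `trees = 100`, `sample_size = 0.8` and both seed values are illustrative choices, not the tuned values from the question. With these settings, `sample_size < 1` (row subsampling in xgboost) is the source of randomness in the fit. The seed is also set before `initial_split()` so that the split itself is reproducible. The last line measures the remaining gap between the two fits rather than assuming it is zero, since exact agreement may depend on package internals.

```r
library(tidymodels)  # attaches rsample, parsnip, workflows, tune, modeldata, dplyr, ...
data(ames)

ames_small <- ames |>
  select(Sale_Price, Gr_Liv_Area, Year_Built) |>
  mutate(Sale_Price = log10(Sale_Price))

set.seed(2022)  # seed *before* the split, so the split itself is reproducible
ames_split <- initial_split(ames_small, prop = 0.80)
ames_train <- training(ames_split)

# Illustrative fixed hyperparameters (assumption, not the tuned values);
# sample_size < 1 makes xgboost subsample rows, so each fit uses random numbers
xgb_wf <- workflow() |>
  add_formula(Sale_Price ~ .) |>
  add_model(
    boost_tree(trees = 100, sample_size = 0.8) |>
      set_engine("xgboost") |>
      set_mode("regression")
  )

# Same seed immediately before each way of producing the final fit
set.seed(46914)
direct_fit <- fit(xgb_wf, data = ames_train)

set.seed(46914)
last_fit_res <- last_fit(xgb_wf, split = ames_split)
last_fit_wf  <- extract_workflow(last_fit_res)

pred_direct <- predict(direct_fit, new_data = head(ames_small))
pred_last   <- predict(last_fit_wf, new_data = head(ames_small))

# Largest absolute difference between the two sets of predictions
pred_gap <- max(abs(pred_direct$.pred - pred_last$.pred))
pred_gap
```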


Thanks @Max. Apparently my mental model of set.seed() is wrong: I thought it sets the seed for the whole session. So, for reproducibility, do I need to call set.seed() before every command that uses random numbers?

Back to the initial question: I gather that both approaches to getting the final fit are fine, especially extract_workflow(last_fit(my_not_fitted_workflow, my_split)).

Your understanding of set.seed() is correct. However, initial_split() uses random numbers, so setting the seed after it won't control the randomness of the split (and so everything downstream will be different).
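A base-R sketch of the point above. Here `runif()` and `sample()` merely stand in for the random parts of a model fit and of `initial_split()`; they are not what tidymodels calls internally. One set.seed() fixes the whole random number stream, but every random call consumes part of it, so adding or removing a random step upstream changes everything after it.

```r
set.seed(46914)
model_draws_a <- runif(3)        # stands in for the random part of a model fit

set.seed(46914)
split_rows <- sample(100, 80)    # an extra random step first, e.g. initial_split()
model_draws_b <- runif(3)        # same code as before, but later in the stream

set.seed(46914)
model_draws_c <- runif(3)        # re-seeding right before the step reproduces it

identical(model_draws_a, model_draws_b)  # FALSE
identical(model_draws_a, model_draws_c)  # TRUE
```

In a script, this suggests one set.seed() before initial_split() for an overall reproducible run, plus resetting the same seed right before each final fit when the two fits are to be compared.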

