I am creating a regression model for sale prices from the Ames housing dataset, using xgboost with a grid search for parameter tuning.
After selecting the best model parameters, as far as I understand there are two similar approaches to creating a finalized, fitted workflow for predictions on new data:
- fit the final workflow on the whole training set (as described in Chapter 13.3 of TMWR), or
- extract the fitted workflow from a `last_fit()` object, as described in the tidymodels tutorial under “Tune Model Parameters > The Last Fit”. (I slightly prefer this, since I am using the `last_fit()` object anyway for plotting, metrics, vip, …)
However, I am seeing slight differences between the two approaches. Where do they come from, and which approach is preferable/more accurate?
Here is an example:
library(tidymodels, warn.conflicts = FALSE)
data(ames)

# Keep a handful of predictors and model the sale price on the log10 scale.
ames <-
  ames |>
  select(
    Sale_Price,
    Neighborhood,
    Gr_Liv_Area,
    Year_Built,
    Bldg_Type,
    Latitude,
    Longitude
  ) |>
  mutate(Sale_Price = log10(Sale_Price))

# initial_split() assigns rows to training/testing at random. Seed it so the
# partition -- and every model fitted on it -- can be reproduced.
set.seed(502)
ames_split <- initial_split(ames, prop = 0.80)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)
# Recipe, model specification and workflow, as generated by
# usemodels::use_xgboost().
#
# Preprocessing: add a "new" level for factor levels unseen during training,
# one-hot encode every nominal predictor, then drop zero-variance columns.
xgboost_recipe <-
  recipe(
    formula = Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type +
      Latitude + Longitude,
    data = ames_train
  ) |>
  step_novel(all_nominal_predictors()) |>
  step_dummy(all_nominal_predictors(), one_hot = TRUE) |>
  step_zv(all_predictors())

# Boosted-tree regression on the xgboost engine; these six hyperparameters
# are left as tune() placeholders for the grid search.
xgboost_spec <-
  boost_tree(
    trees = tune(),
    min_n = tune(),
    tree_depth = tune(),
    learn_rate = tune(),
    loss_reduction = tune(),
    sample_size = tune()
  ) |>
  set_mode("regression") |>
  set_engine("xgboost")

# Bundle preprocessing and model so they are tuned and fitted together.
xgboost_workflow <-
  workflow() |>
  add_recipe(xgboost_recipe) |>
  add_model(xgboost_spec)
# Keep the out-of-sample predictions and the workflow in the tuning result
# so they can be inspected and reused later.
xgboost_control <-
  control_grid(
    save_pred = TRUE,
    save_workflow = TRUE
  )

# Seed BEFORE creating the folds: vfold_cv() assigns rows to folds at random.
# Previously the seed was set only after the folds were built, so the folds
# (and therefore the tuning results) could not be reproduced. The same seed
# also fixes the random draws made later inside tune_grid().
set.seed(46914)
xgboost_folds <- vfold_cv(ames_train)
xgboost_tune <-
  tune_grid(xgboost_workflow, resamples = xgboost_folds, control = xgboost_control)
#> ! Fold01: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold02: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold03: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold04: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold05: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold06: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold07: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold08: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold09: internal: A correlation computation is required, but `estimate` is constant and ha...
#> ! Fold10: internal: A correlation computation is required, but `estimate` is constant and ha...
# Top candidates from the grid, ranked by mean cross-validated RMSE.
xgboost_tune |> show_best(metric = "rmse")
#> # A tibble: 5 × 12
#> trees min_n tree_depth learn_rate loss_…¹ sampl…² .metric .esti…³ mean n
#> <int> <int> <int> <dbl> <dbl> <dbl> <chr> <chr> <dbl> <int>
#> 1 1007 13 13 0.00823 6.52e-4 0.805 rmse standa… 0.0722 10
#> 2 1574 37 6 0.0585 1.38e-9 0.825 rmse standa… 0.0754 10
#> 3 897 24 3 0.0384 1.84e-2 0.382 rmse standa… 0.0760 10
#> 4 1375 28 5 0.0180 1.33e-6 0.133 rmse standa… 0.0790 10
#> 5 458 3 14 0.229 3.20e-5 0.536 rmse standa… 0.0801 10
#> # … with 2 more variables: std_err <dbl>, .config <chr>, and abbreviated
#> # variable names ¹loss_reduction, ²sample_size, ³.estimator
#> # ℹ Use `colnames()` to see all variable names
# Candidate with the lowest mean cross-validated RMSE.
xgboost_params <- xgboost_tune |> select_best(metric = "rmse")

# Replace the tune() placeholders in the workflow with the selected values.
xgboost_final_workflow <- finalize_workflow(xgboost_workflow, xgboost_params)
Now we need to fit the final workflow, either by fitting it on the whole training set:
# When the selected sample_size is below 1 (0.805 in the run shown above),
# xgboost subsamples rows at random while boosting, so an unseeded fit()
# produces a slightly different model on every run. Seeding right before the
# fit makes this model reproducible.
set.seed(2022)
xgboost_direct_fit <- fit(xgboost_final_workflow, ames_train)
or by running `last_fit()` and extracting the fitted workflow with `extract_workflow()`:
# last_fit() refits the finalized workflow on training(ames_split) -- the
# same rows as ames_train -- and evaluates it once on the test set. The
# xgboost fit inside it is stochastic when sample_size < 1, so seed it for
# reproducibility.
# NOTE(review): tune may derive its own RNG stream from the current seed, so
# even with the same set.seed() value this model can differ slightly from one
# fitted with fit(). Such differences are subsampling noise, not a difference
# in method; neither model is systematically more accurate.
set.seed(2022)
xgboost_final_fit <-
  xgboost_final_workflow %>%
  last_fit(ames_split)

# Fitted workflow (trained recipe + xgboost model) from last_fit(), ready
# for predict() on new data.
xgboost_final_tree <- extract_workflow(xgboost_final_fit)
However, the resulting objects are not the same:
# Compare the two fitted workflows: all.equal() returns TRUE, or a character
# vector describing each component that differs.
all.equal(target = xgboost_direct_fit, current = xgboost_final_tree)
#> [1] "Component \"fit\": Component \"fit\": Component \"fit\": Component \"raw\": Lengths (2955088, 2997136) differ (comparison on first 2955088 components)"
#> [2] "Component \"fit\": Component \"fit\": Component \"fit\": Component \"raw\": 2314610 element mismatches"
#> [3] "Component \"fit\": Component \"fit\": Component \"fit\": Component \"evaluation_log\": Column 'training_rmse': Mean relative difference: 0.0002653307"
and they predict slightly different values (shown for the first six rows only):
# Print 4 significant digits so the small prediction differences are visible.
options(pillar.sigfig = 4)
predict(xgboost_direct_fit, new_data = head(ames, n = 6L))
#> # A tibble: 6 × 1
#> .pred
#> <dbl>
#> 1 5.251
#> 2 5.104
#> 3 5.202
#> 4 5.350
#> 5 5.274
#> 6 5.271
# Same six rows, predicted with the workflow extracted from last_fit().
predict(xgboost_final_tree, new_data = head(ames, n = 6L))
#> # A tibble: 6 × 1
#> .pred
#> <dbl>
#> 1 5.249
#> 2 5.104
#> 3 5.196
#> 4 5.349
#> 5 5.273
#> 6 5.269
Session info
# Report the R version, platform and package versions used for this reprex.
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.2.1 (2022-06-23 ucrt)
#> os Windows 10 x64 (build 19044)
#> system x86_64, mingw32
#> ui RTerm
#> language en
#> collate German_Germany.utf8
#> ctype German_Germany.utf8
#> tz Europe/Berlin
#> date 2022-08-17
#> pandoc 2.18 @ C:/Program Files/RStudio/bin/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.2.0)
#> backports 1.4.1 2021-12-13 [1] CRAN (R 4.2.0)
#> broom * 1.0.0 2022-07-01 [1] CRAN (R 4.2.1)
#> class 7.3-20 2022-01-16 [2] CRAN (R 4.2.1)
#> cli 3.3.0 2022-04-25 [1] CRAN (R 4.2.0)
#> codetools 0.2-18 2020-11-04 [2] CRAN (R 4.2.1)
#> colorspace 2.0-3 2022-02-21 [1] CRAN (R 4.2.0)
#> crayon 1.5.1 2022-03-26 [1] CRAN (R 4.2.0)
#> data.table 1.14.2 2021-09-27 [1] CRAN (R 4.2.0)
#> DBI 1.1.3 2022-06-18 [1] CRAN (R 4.2.0)
#> dials * 1.0.0 2022-06-14 [1] CRAN (R 4.2.0)
#> DiceDesign 1.9 2021-02-13 [1] CRAN (R 4.2.0)
#> digest 0.6.29 2021-12-01 [1] CRAN (R 4.2.0)
#> dplyr * 1.0.9 2022-04-28 [1] CRAN (R 4.2.0)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.2.0)
#> evaluate 0.16 2022-08-09 [1] CRAN (R 4.2.1)
#> fansi 1.0.3 2022-03-24 [1] CRAN (R 4.2.0)
#> fastmap 1.1.0 2021-01-25 [1] CRAN (R 4.2.0)
#> foreach 1.5.2 2022-02-02 [1] CRAN (R 4.2.0)
#> fs 1.5.2 2021-12-08 [1] CRAN (R 4.2.0)
#> furrr 0.3.0 2022-05-04 [1] CRAN (R 4.2.0)
#> future 1.27.0 2022-07-22 [1] CRAN (R 4.2.1)
#> future.apply 1.9.0 2022-04-25 [1] CRAN (R 4.2.0)
#> generics 0.1.3 2022-07-05 [1] CRAN (R 4.2.1)
#> ggplot2 * 3.3.6 2022-05-03 [1] CRAN (R 4.2.0)
#> globals 0.16.0 2022-08-05 [1] CRAN (R 4.2.1)
#> glue 1.6.2 2022-02-24 [1] CRAN (R 4.2.0)
#> gower 1.0.0 2022-02-03 [1] CRAN (R 4.2.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.2.0)
#> gtable 0.3.0 2019-03-25 [1] CRAN (R 4.2.0)
#> hardhat 1.2.0 2022-06-30 [1] CRAN (R 4.2.1)
#> highr 0.9 2021-04-16 [1] CRAN (R 4.2.0)
#> htmltools 0.5.3 2022-07-18 [1] CRAN (R 4.2.1)
#> infer * 1.0.2 2022-06-10 [1] CRAN (R 4.2.0)
#> ipred 0.9-13 2022-06-02 [1] CRAN (R 4.2.0)
#> iterators 1.0.14 2022-02-05 [1] CRAN (R 4.2.0)
#> jsonlite 1.8.0 2022-02-22 [1] CRAN (R 4.2.0)
#> knitr 1.39 2022-04-26 [1] CRAN (R 4.2.0)
#> lattice 0.20-45 2021-09-22 [2] CRAN (R 4.2.1)
#> lava 1.6.10 2021-09-02 [1] CRAN (R 4.2.0)
#> lhs 1.1.5 2022-03-22 [1] CRAN (R 4.2.0)
#> lifecycle 1.0.1 2021-09-24 [1] CRAN (R 4.2.0)
#> listenv 0.8.0 2019-12-05 [1] CRAN (R 4.2.0)
#> lubridate 1.8.0 2021-10-07 [1] CRAN (R 4.2.0)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.2.0)
#> MASS 7.3-58.1 2022-08-03 [1] CRAN (R 4.2.1)
#> Matrix 1.4-1 2022-03-23 [2] CRAN (R 4.2.1)
#> modeldata * 1.0.0 2022-07-01 [1] CRAN (R 4.2.1)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.2.0)
#> nnet 7.3-17 2022-01-16 [2] CRAN (R 4.2.1)
#> parallelly 1.32.1 2022-07-21 [1] CRAN (R 4.2.1)
#> parsnip * 1.0.0 2022-06-16 [1] CRAN (R 4.2.0)
#> pillar 1.8.0 2022-07-18 [1] CRAN (R 4.2.1)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.2.0)
#> prodlim 2019.11.13 2019-11-17 [1] CRAN (R 4.2.0)
#> purrr * 0.3.4 2020-04-17 [1] CRAN (R 4.2.0)
#> R.cache 0.16.0 2022-07-21 [1] CRAN (R 4.2.1)
#> R.methodsS3 1.8.2 2022-06-13 [1] CRAN (R 4.2.0)
#> R.oo 1.25.0 2022-06-12 [1] CRAN (R 4.2.0)
#> R.utils 2.12.0 2022-06-28 [1] CRAN (R 4.2.1)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.2.0)
#> Rcpp 1.0.9 2022-07-08 [1] CRAN (R 4.2.1)
#> recipes * 1.0.1 2022-07-07 [1] CRAN (R 4.2.1)
#> reprex 2.0.1 2021-08-05 [1] CRAN (R 4.2.0)
#> rlang 1.0.4 2022-07-12 [1] CRAN (R 4.2.1)
#> rmarkdown 2.14 2022-04-25 [1] CRAN (R 4.2.0)
#> rpart 4.1.16 2022-01-24 [2] CRAN (R 4.2.1)
#> rsample * 1.1.0 2022-08-08 [1] CRAN (R 4.2.1)
#> rstudioapi 0.13 2020-11-12 [1] CRAN (R 4.2.0)
#> scales * 1.2.0 2022-04-13 [1] CRAN (R 4.2.0)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.2.0)
#> stringi 1.7.8 2022-07-11 [1] CRAN (R 4.2.1)
#> stringr 1.4.0 2019-02-10 [1] CRAN (R 4.2.0)
#> styler 1.7.0 2022-03-13 [1] CRAN (R 4.2.0)
#> survival 3.4-0 2022-08-09 [1] CRAN (R 4.2.1)
#> tibble * 3.1.8 2022-07-22 [1] CRAN (R 4.2.1)
#> tidymodels * 1.0.0 2022-07-13 [1] CRAN (R 4.2.1)
#> tidyr * 1.2.0 2022-02-01 [1] CRAN (R 4.2.0)
#> tidyselect 1.1.2 2022-02-21 [1] CRAN (R 4.2.0)
#> timeDate 4021.104 2022-07-19 [1] CRAN (R 4.2.1)
#> tune * 1.0.0 2022-07-07 [1] CRAN (R 4.2.1)
#> utf8 1.2.2 2021-07-24 [1] CRAN (R 4.2.0)
#> vctrs 0.4.1 2022-04-13 [1] CRAN (R 4.2.0)
#> withr 2.5.0 2022-03-03 [1] CRAN (R 4.2.0)
#> workflows * 1.0.0 2022-07-05 [1] CRAN (R 4.2.1)
#> workflowsets * 1.0.0 2022-07-12 [1] CRAN (R 4.2.1)
#> xfun 0.32 2022-08-10 [1] CRAN (R 4.2.1)
#> xgboost * 1.6.0.1 2022-04-16 [1] CRAN (R 4.2.0)
#> yaml 2.3.5 2022-02-21 [1] CRAN (R 4.2.0)
#> yardstick * 1.0.0 2022-06-06 [1] CRAN (R 4.2.0)
#>
#> [1] C:/Users/Daniel.AK-HAMBURG/AppData/Local/R/win-library/4.2
#> [2] C:/Program Files/R/R-4.2.1/library
#>
#> ──────────────────────────────────────────────────────────────────────────────