Hi, I just want to confirm: will the importance weights be passed to the `weight` option of `xgb.train()`? I see no weights in my result — the `xgb.train()` call printed for the fitted model below has no `weight` argument.
library(tidymodels)
library(hardhat)
# Build an unfitted boosted-tree model specification in which the number of
# trees is left as a tuning parameter (tune()).
#
# engine: parsnip engine name, passed to set_engine() (default "xgboost").
# mode:   model mode, passed to set_mode() (default "classification").
# ...:    engine-specific arguments forwarded to set_engine(), e.g. the
#         nthread / tree_method / max_bin values used in this script.
boost_tree_spec <- function(engine = "xgboost", mode = "classification", ...) {
  spec <- boost_tree(trees = tune())
  spec <- set_mode(spec, mode)
  set_engine(spec, engine, ...)
}
# Build the (untrained) preprocessing recipe shared by tuning and the final
# fit.
#
# data: training data containing the outcome `signal` and the columns
#       `symbol` and `date` referenced by the steps below.
#
# Returns a recipe that
#   * keeps `symbol` with the "info" role, so it is carried along but is not
#     used as a predictor,
#   * replaces `date` with a day-of-week feature (original column dropped),
#   * dummy-encodes every nominal predictor.
xgboost_recipe <- function(data) {
  # Supply `data` exactly once, via the `data` argument. Piping it in as
  # well would also pass it as recipe()'s first argument, handing the same
  # data frame to recipe() twice.
  recipes::recipe(signal ~ ., data = data) %>%
    update_role(symbol, new_role = "info") %>%
    step_date(date, features = "dow", keep_original_cols = FALSE) %>%
    step_dummy(all_nominal_predictors())
}
# xgboost engine settings. Tuning and the final fit share the thread count
# and tree method, but use different histogram bin counts.
# NOTE(review): how xgboost interprets nthread = 0 is not shown here —
# confirm it selects the intended number of threads.

# Settings used while tuning.
tune_nthread <- 0
tune_tree_method <- "hist"
tune_max_bin <- 256

# Settings used for the final workflow.
final_nthread <- tune_nthread
final_tree_method <- "hist"
final_max_bin <- 512
# Simulated example data: 100 rows with a date, a numeric feature `f1`, a
# `symbol` identifier, hardhat importance weights, and a two-class outcome
# `signal`.
# Fix the RNG seed so the sampled data — and, when the script is run top to
# bottom, the later split, resamples and tuning grid — are reproducible.
set.seed(2022)
data <- tibble(
  date = sample(
    seq(as.Date("1999/01/01"), as.Date("2000/01/01"), by = "day"),
    100
  ),
  f1 = sample(0:1000, 100),
  symbol = sample(c("s1", "s2", "s3", "s4"), 100, replace = TRUE),
  # Case weights: integers in 0..100 wrapped as hardhat importance weights,
  # later attached to the workflows with add_case_weights(importance).
  importance = importance_weights(sample(0:100, 100)),
  signal = sample(c("open", "close"), 100, replace = TRUE)
)

data %>%
  glimpse()
#> Rows: 100
#> Columns: 5
#> $ date <date> 1999-01-25, 1999-03-06, 1999-06-24, 1999-10-23, 1999-08-19…
#> $ f1 <int> 314, 814, 983, 888, 620, 750, 913, 297, 439, 155, 460, 698,…
#> $ symbol <chr> "s1", "s4", "s3", "s2", "s4", "s3", "s2", "s2", "s2", "s2",…
#> $ importance <imp_wts> 50, 0, 89, 74, 98, 82, 65, 31, 75, 32, 100, 52, 35, 71,…
#> $ signal <chr> "open", "close", "open", "open", "open", "close", "close", …
# Stratified train/test split of `data`; the training portion is what the
# hyperparameter search below resamples from.
split <- initial_split(data, strata = signal)
tune_data <- training(split)
print(tune_data)
#> # A tibble: 74 × 5
#> date f1 symbol importance signal
#> <date> <int> <chr> <imp_wts> <chr>
#> 1 1999-03-06 814 s4 0 close
#> 2 1999-08-02 750 s3 82 close
#> 3 1999-01-26 913 s2 65 close
#> 4 1999-11-02 297 s2 31 close
#> 5 1999-07-26 460 s1 100 close
#> 6 1999-01-23 338 s2 71 close
#> 7 1999-12-18 465 s3 39 close
#> 8 1999-02-17 294 s2 64 close
#> 9 1999-11-26 749 s4 7 close
#> 10 1999-09-24 333 s1 37 close
#> # … with 64 more rows
# Tune the number of trees with 2-fold stratified CV over a 5-candidate
# grid, scoring by ROC AUC. The workflow combines the boosted-tree spec
# (tuning engine settings), the shared recipe, and the `importance` column
# as case weights.
tune_result <- tune_grid(
  workflow() %>%
    add_model(boost_tree_spec(
      nthread = tune_nthread,
      tree_method = tune_tree_method,
      max_bin = tune_max_bin
    )) %>%
    add_recipe(xgboost_recipe(training(split))) %>%
    add_case_weights(importance),
  resamples = vfold_cv(tune_data, v = 2, strata = signal),
  grid = 5,
  metrics = metric_set(roc_auc),
  control = control_grid(verbose = TRUE)
)
#> i Fold1: preprocessor 1/1
#> ✓ Fold1: preprocessor 1/1
#> i Fold1: preprocessor 1/1, model 1/1
#> ✓ Fold1: preprocessor 1/1, model 1/1
#> i Fold1: preprocessor 1/1, model 1/1 (predictions)
#> i Fold2: preprocessor 1/1
#> ✓ Fold2: preprocessor 1/1
#> i Fold2: preprocessor 1/1, model 1/1
#> ✓ Fold2: preprocessor 1/1, model 1/1
#> i Fold2: preprocessor 1/1, model 1/1 (predictions)
# Pick the tuning candidate with the best ROC AUC. The metric is named
# explicitly so the selection does not silently change if more metrics are
# later added to the metric_set() used in tune_grid().
best_params <- tune_result %>%
  select_best(metric = "roc_auc")
# Final workflow: same recipe and case weights as during tuning, but with
# the final engine settings, and with the tune() placeholder replaced by
# the selected hyperparameters in `best_params`.
final_workflow <- finalize_workflow(
  workflow() %>%
    add_model(boost_tree_spec(
      nthread = final_nthread,
      tree_method = final_tree_method,
      max_bin = final_max_bin
    )) %>%
    add_recipe(xgboost_recipe(training(split))) %>%
    add_case_weights(importance),
  best_params
)
# Fit the final workflow via last_fit() on `split` and print the collected
# metrics.
print(collect_metrics(last_fit(final_workflow, split)))
#> # A tibble: 2 × 4
#> .metric .estimator .estimate .config
#> <chr> <chr> <dbl> <chr>
#> 1 accuracy binary 0.5 Preprocessor1_Model1
#> 2 roc_auc binary 0.542 Preprocessor1_Model1
# Refit the finalized workflow on the training data, keep it as `model`,
# and print it.
# NOTE(review): the printed xgb.train() call has no `weight` argument; the
# case weights presumably travel inside the xgb.DMatrix passed as
# `data = x$data` rather than as an xgb.train() argument — confirm, e.g. by
# comparing predictions against a fit without add_case_weights().
model <- fit(final_workflow, data = training(split))
print(model)
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 2 Recipe Steps
#>
#> • step_date()
#> • step_dummy()
#>
#> ── Case Weights ────────────────────────────────────────────────────────────────
#> importance
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#> ##### xgb.Booster
#> raw: 959.3 Kb
#> call:
#> xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0,
#> colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1,
#> subsample = 1, objective = "binary:logistic"), data = x$data,
#> nrounds = 1019L, watchlist = x$watchlist, verbose = 0, nthread = 0,
#> tree_method = "hist", max_bin = 512)
#> params (as set within xgb.train):
#> eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "binary:logistic", nthread = "0", tree_method = "hist", max_bin = "512", validate_parameters = "TRUE"
#> xgb.attributes:
#> niter
#> callbacks:
#> cb.evaluation.log()
#> # of features: 7
#> niter: 1019
#> nfeatures : 7
#> evaluation_log:
#> iter training_logloss
#> 1 0.564311269
#> 2 0.496895742
#> ---
#> 1018 0.001333734
#> 1019 0.001333456
Created on 2022-07-16 by the reprex package (v2.0.1)
Session info
# Report the R version and package versions used for this reprex.
print(sessioninfo::session_info())
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.2.1 (2022-06-23)
#> os macOS Monterey 12.4
#> system x86_64, darwin21.5.0
#> ui unknown
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Asia/Shanghai
#> date 2022-07-16
#> pandoc 2.18 @ /usr/local/bin/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.2.1)
#> backports 1.4.1 2021-12-13 [1] CRAN (R 4.2.1)
#> broom * 1.0.0 2022-07-01 [1] CRAN (R 4.2.1)
#> class 7.3-20 2022-01-16 [2] CRAN (R 4.2.1)
#> cli 3.3.0 2022-04-25 [1] CRAN (R 4.2.1)
#> codetools 0.2-18 2020-11-04 [2] CRAN (R 4.2.1)
#> colorspace 2.0-3 2022-02-21 [1] CRAN (R 4.2.1)
#> crayon 1.5.1 2022-03-26 [1] CRAN (R 4.2.1)
#> data.table 1.14.2 2021-09-27 [1] CRAN (R 4.2.1)
#> DBI 1.1.3 2022-06-18 [1] CRAN (R 4.2.1)
#> dials * 1.0.0 2022-06-14 [1] CRAN (R 4.2.1)
#> DiceDesign 1.9 2021-02-13 [1] CRAN (R 4.2.1)
#> digest 0.6.29 2021-12-01 [1] CRAN (R 4.2.1)
#> dplyr * 1.0.9 2022-04-28 [1] CRAN (R 4.2.1)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.2.1)
#> evaluate 0.15 2022-02-18 [1] CRAN (R 4.2.1)
#> fansi 1.0.3 2022-03-24 [1] CRAN (R 4.2.1)
#> fastmap 1.1.0 2021-01-25 [1] CRAN (R 4.2.1)
#> foreach 1.5.2 2022-02-02 [1] CRAN (R 4.2.1)
#> fs 1.5.2 2021-12-08 [1] CRAN (R 4.2.1)
#> furrr 0.3.0 2022-05-04 [1] CRAN (R 4.2.1)
#> future 1.26.1 2022-05-27 [1] CRAN (R 4.2.1)
#> future.apply 1.9.0 2022-04-25 [1] CRAN (R 4.2.1)
#> generics 0.1.3 2022-07-05 [1] CRAN (R 4.2.1)
#> ggplot2 * 3.3.6 2022-05-03 [1] CRAN (R 4.2.1)
#> globals 0.15.1 2022-06-24 [1] CRAN (R 4.2.1)
#> glue 1.6.2 2022-02-24 [1] CRAN (R 4.2.1)
#> gower 1.0.0 2022-02-03 [1] CRAN (R 4.2.1)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.2.1)
#> gtable 0.3.0 2019-03-25 [1] CRAN (R 4.2.1)
#> hardhat * 1.2.0 2022-06-30 [1] CRAN (R 4.2.1)
#> highr 0.9 2021-04-16 [1] CRAN (R 4.2.1)
#> htmltools 0.5.2 2021-08-25 [1] CRAN (R 4.2.1)
#> infer * 1.0.2 2022-06-10 [1] CRAN (R 4.2.1)
#> ipred 0.9-13 2022-06-02 [1] CRAN (R 4.2.1)
#> iterators 1.0.14 2022-02-05 [1] CRAN (R 4.2.1)
#> jsonlite 1.8.0 2022-02-22 [1] CRAN (R 4.2.1)
#> knitr 1.39 2022-04-26 [1] CRAN (R 4.2.1)
#> lattice 0.20-45 2021-09-22 [2] CRAN (R 4.2.1)
#> lava 1.6.10 2021-09-02 [1] CRAN (R 4.2.1)
#> lhs 1.1.5 2022-03-22 [1] CRAN (R 4.2.1)
#> lifecycle 1.0.1 2021-09-24 [1] CRAN (R 4.2.1)
#> listenv 0.8.0 2019-12-05 [1] CRAN (R 4.2.1)
#> lubridate 1.8.0 2021-10-07 [1] CRAN (R 4.2.1)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.2.1)
#> MASS 7.3-58 2022-07-14 [2] CRAN (R 4.2.1)
#> Matrix 1.4-1 2022-03-23 [2] CRAN (R 4.2.1)
#> modeldata * 1.0.0 2022-07-01 [1] CRAN (R 4.2.1)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.2.1)
#> nnet 7.3-17 2022-01-16 [2] CRAN (R 4.2.1)
#> parallelly 1.32.0 2022-06-07 [1] CRAN (R 4.2.1)
#> parsnip * 1.0.0 2022-06-16 [1] CRAN (R 4.2.1)
#> pillar 1.7.0 2022-02-01 [1] CRAN (R 4.2.1)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.2.1)
#> prodlim 2019.11.13 2019-11-17 [1] CRAN (R 4.2.1)
#> purrr * 0.3.4 2020-04-17 [1] CRAN (R 4.2.1)
#> R.cache 0.15.0 2021-04-30 [1] CRAN (R 4.2.1)
#> R.methodsS3 1.8.2 2022-06-13 [1] CRAN (R 4.2.1)
#> R.oo 1.25.0 2022-06-12 [1] CRAN (R 4.2.1)
#> R.utils 2.12.0 2022-06-28 [1] CRAN (R 4.2.1)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.2.1)
#> Rcpp 1.0.9 2022-07-08 [1] CRAN (R 4.2.1)
#> recipes * 1.0.1 2022-07-07 [1] CRAN (R 4.2.1)
#> reprex 2.0.1 2021-08-05 [1] CRAN (R 4.2.1)
#> rlang 1.0.4 2022-07-12 [1] CRAN (R 4.2.1)
#> rmarkdown 2.14 2022-04-25 [1] CRAN (R 4.2.1)
#> rpart 4.1.16 2022-01-24 [2] CRAN (R 4.2.1)
#> rsample * 1.0.0 2022-06-24 [1] CRAN (R 4.2.1)
#> scales * 1.2.0 2022-04-13 [1] CRAN (R 4.2.1)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.2.1)
#> stringi 1.7.8 2022-07-11 [1] CRAN (R 4.2.1)
#> stringr 1.4.0 2019-02-10 [1] CRAN (R 4.2.1)
#> styler 1.7.0 2022-03-13 [1] CRAN (R 4.2.1)
#> survival 3.3-1 2022-03-03 [2] CRAN (R 4.2.1)
#> tibble * 3.1.7 2022-05-03 [1] CRAN (R 4.2.1)
#> tidymodels * 1.0.0 2022-07-13 [1] CRAN (R 4.2.1)
#> tidyr * 1.2.0 2022-02-01 [1] CRAN (R 4.2.1)
#> tidyselect 1.1.2 2022-02-21 [1] CRAN (R 4.2.1)
#> timeDate 3043.102 2018-02-21 [1] CRAN (R 4.2.1)
#> tune * 1.0.0 2022-07-07 [1] CRAN (R 4.2.1)
#> utf8 1.2.2 2021-07-24 [1] CRAN (R 4.2.1)
#> vctrs 0.4.1 2022-04-13 [1] CRAN (R 4.2.1)
#> withr 2.5.0 2022-03-03 [1] CRAN (R 4.2.1)
#> workflows * 1.0.0 2022-07-05 [1] CRAN (R 4.2.1)
#> workflowsets * 1.0.0 2022-07-12 [1] CRAN (R 4.2.1)
#> xfun 0.31 2022-05-10 [1] CRAN (R 4.2.1)
#> xgboost * 1.6.0.1 2022-04-16 [1] CRAN (R 4.2.1)
#> yaml 2.3.5 2022-02-21 [1] CRAN (R 4.2.1)
#> yardstick * 1.0.0 2022-06-06 [1] CRAN (R 4.2.1)
#>
#> [1] /usr/local/lib/R/4.2/site-library
#> [2] /usr/local/Cellar/r/4.2.1/lib/R/library
#>
#> ──────────────────────────────────────────────────────────────────────────────