Hi!
I'm having trouble with resampling and collecting metrics (especially ROC-AUC) from a glm
model. When using yes
as the positive and first level for the outcome of a classification problem, the coefficients have the wrong sign (I think because glm
uses the first level as reference level), but the resampled and manually calculated metrics are the same and correct (see below).
When using no
as the first level and setting yardstick.event_first = FALSE
, the coefficients are correct, but the resampled metrics are wrong. The predictions and manually calculated metrics are correct, though.
Does anybody have a suggestion if I'm doing something wrong?
Regards, Pascal
Reproducible example
library(tidyverse)
library(tidymodels)
set.seed(123)
data("titanic_imputed", package = "DALEX")
options(yardstick.event_first = TRUE)
multi_metric <- metric_set(accuracy, sens, spec, roc_auc)
titanic_data_yesfirst <-
titanic_imputed %>%
as_tibble() %>%
mutate(survived = factor(survived, levels = c("1", "0"), labels = c("yes", "no")))
training_folds_yesfirst <-
vfold_cv(titanic_data_yesfirst, v = 10, repeats = 2, strata = survived)
model.glm <-
logistic_reg() %>%
set_engine("glm")
workflow.glm <-
workflow() %>%
add_formula(survived ~ .) %>%
add_model(model.glm)
fit.glm_yesfirst <-
workflow.glm %>%
fit(data = titanic_data_yesfirst)
resamples.glm_yesfirst <-
workflow.glm %>%
fit_resamples(
training_folds_yesfirst,
metrics = multi_metric,
control = control_resamples(
save_pred = TRUE
)
)
#> ! Fold01, Repeat1: internal: The `yardstick.event_first` option has been deprecated as of y...
summary(fit.glm_yesfirst %>%
pull_workflow_fit() %>% .$fit)
#>
#> Call:
#> stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)
#>
#> Deviance Residuals:
#> Min 1Q Median 3Q Max
#> -2.6043 -0.6020 0.4934 0.7003 2.5663
#>
#> Coefficients:
#> Estimate Std. Error z value Pr(>|z|)
#> (Intercept) -3.3817176 0.4165414 -8.119 4.72e-16 ***
#> gendermale 2.7196959 0.1560095 17.433 < 2e-16 ***
#> age 0.0360196 0.0053290 6.759 1.39e-11 ***
#> class2nd 1.1424261 0.2475800 4.614 3.94e-06 ***
#> class3rd 2.1003085 0.2471237 8.499 < 2e-16 ***
#> `classdeck crew` -1.0803128 0.3466852 -3.116 0.001832 **
#> `classengineering crew` 1.0041907 0.2602157 3.859 0.000114 ***
#> `classrestaurant staff` 3.3041097 0.6540119 5.052 4.37e-07 ***
#> `classvictualling crew` 1.1217197 0.2564026 4.375 1.22e-05 ***
#> embarkedCherbourg -0.7606857 0.2822736 -2.695 0.007042 **
#> embarkedQueenstown -0.1714282 0.3391971 -0.505 0.613282
#> embarkedSouthampton -0.1972937 0.2115860 -0.932 0.351103
#> fare -0.0008322 0.0019043 -0.437 0.662085
#> sibsp 0.3603939 0.0918656 3.923 8.74e-05 ***
#> parch 0.0086057 0.0920058 0.094 0.925479
#> ---
#> Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#>
#> (Dispersion parameter for binomial family taken to be 1)
#>
#> Null deviance: 2774.1 on 2206 degrees of freedom
#> Residual deviance: 2039.3 on 2192 degrees of freedom
#> AIC: 2069.3
#>
#> Number of Fisher Scoring iterations: 5
collect_metrics(resamples.glm_yesfirst)
#> # A tibble: 4 x 5
#> .metric .estimator mean n std_err
#> <chr> <chr> <dbl> <int> <dbl>
#> 1 accuracy binary 0.799 20 0.00478
#> 2 roc_auc binary 0.807 20 0.00687
#> 3 sens binary 0.570 20 0.0112
#> 4 spec binary 0.908 20 0.00517
collect_predictions(resamples.glm_yesfirst) %>%
group_by(id, id2) %>%
multi_metric(truth = survived, .pred_yes, estimate = .pred_class, event_level = "first") %>%
group_by(.metric) %>%
summarise(.estimate = mean(.estimate))
#> `summarise()` ungrouping output (override with `.groups` argument)
#> # A tibble: 4 x 2
#> .metric .estimate
#> <chr> <dbl>
#> 1 accuracy 0.799
#> 2 roc_auc 0.807
#> 3 sens 0.570
#> 4 spec 0.908
options(yardstick.event_first = FALSE)
titanic_data_nofirst <-
titanic_imputed %>%
as_tibble() %>%
mutate(survived = factor(survived, levels = c("0", "1"), labels = c("no", "yes")))
training_folds_nofirst <-
vfold_cv(titanic_data_nofirst, v = 10, repeats = 2, strata = survived)
fit.glm_nofirst <-
workflow.glm %>%
fit(data = titanic_data_nofirst)
resamples.glm_nofirst <-
workflow.glm %>%
fit_resamples(
training_folds_nofirst,
metrics = multi_metric,
control = control_resamples(
save_pred = TRUE
)
)
summary(fit.glm_nofirst %>%
pull_workflow_fit() %>% .$fit)
#>
#> Call:
#> stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)
#>
#> Deviance Residuals:
#> Min 1Q Median 3Q Max
#> -2.5663 -0.7003 -0.4934 0.6020 2.6043
#>
#> Coefficients:
#> Estimate Std. Error z value Pr(>|z|)
#> (Intercept) 3.3817176 0.4165414 8.119 4.72e-16 ***
#> gendermale -2.7196959 0.1560095 -17.433 < 2e-16 ***
#> age -0.0360196 0.0053290 -6.759 1.39e-11 ***
#> class2nd -1.1424261 0.2475800 -4.614 3.94e-06 ***
#> class3rd -2.1003085 0.2471237 -8.499 < 2e-16 ***
#> `classdeck crew` 1.0803128 0.3466852 3.116 0.001832 **
#> `classengineering crew` -1.0041907 0.2602157 -3.859 0.000114 ***
#> `classrestaurant staff` -3.3041097 0.6540119 -5.052 4.37e-07 ***
#> `classvictualling crew` -1.1217197 0.2564026 -4.375 1.22e-05 ***
#> embarkedCherbourg 0.7606857 0.2822736 2.695 0.007042 **
#> embarkedQueenstown 0.1714282 0.3391971 0.505 0.613282
#> embarkedSouthampton 0.1972937 0.2115860 0.932 0.351103
#> fare 0.0008322 0.0019043 0.437 0.662085
#> sibsp -0.3603939 0.0918656 -3.923 8.74e-05 ***
#> parch -0.0086057 0.0920058 -0.094 0.925479
#> ---
#> Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#>
#> (Dispersion parameter for binomial family taken to be 1)
#>
#> Null deviance: 2774.1 on 2206 degrees of freedom
#> Residual deviance: 2039.3 on 2192 degrees of freedom
#> AIC: 2069.3
#>
#> Number of Fisher Scoring iterations: 5
collect_metrics(resamples.glm_nofirst)
#> # A tibble: 4 x 5
#> .metric .estimator mean n std_err
#> <chr> <chr> <dbl> <int> <dbl>
#> 1 accuracy binary 0.798 20 0.00569
#> 2 roc_auc binary 0.194 20 0.00798
#> 3 sens binary 0.570 20 0.0134
#> 4 spec binary 0.907 20 0.00594
collect_predictions(resamples.glm_nofirst) %>%
group_by(id, id2) %>%
multi_metric(truth = survived, .pred_yes, estimate = .pred_class, event_level = "second") %>%
group_by(.metric) %>%
summarise(.estimate = mean(.estimate))
#> `summarise()` ungrouping output (override with `.groups` argument)
#> # A tibble: 4 x 2
#> .metric .estimate
#> <chr> <dbl>
#> 1 accuracy 0.798
#> 2 roc_auc 0.806
#> 3 sens 0.570
#> 4 spec 0.907
Created on 2020-10-05 by the reprex package (v0.3.0)
Session info
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 3.6.1 (2019-07-05)
#> os macOS Sierra 10.12.6
#> system x86_64, darwin15.6.0
#> ui X11
#> language (EN)
#> collate de_CH.UTF-8
#> ctype de_CH.UTF-8
#> tz Europe/Zurich
#> date 2020-10-05
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 3.6.0)
#> backports 1.1.8 2020-06-17 [1] CRAN (R 3.6.2)
#> blob 1.2.1 2020-01-20 [1] CRAN (R 3.6.0)
#> broom * 0.7.0 2020-07-09 [1] CRAN (R 3.6.2)
#> callr 3.4.3 2020-03-28 [1] CRAN (R 3.6.2)
#> cellranger 1.1.0 2016-07-27 [1] CRAN (R 3.6.0)
#> class 7.3-17 2020-04-26 [1] CRAN (R 3.6.2)
#> cli 2.0.2 2020-02-28 [1] CRAN (R 3.6.0)
#> codetools 0.2-16 2018-12-24 [1] CRAN (R 3.6.1)
#> colorspace 1.4-1 2019-03-18 [1] CRAN (R 3.6.0)
#> crayon 1.3.4 2017-09-16 [1] CRAN (R 3.6.0)
#> DBI 1.1.0 2019-12-15 [1] CRAN (R 3.6.0)
#> dbplyr 1.4.4 2020-05-27 [1] CRAN (R 3.6.2)
#> desc 1.2.0 2018-05-01 [1] CRAN (R 3.6.0)
#> devtools 2.3.0 2020-04-10 [1] CRAN (R 3.6.2)
#> dials * 0.0.9 2020-09-16 [1] CRAN (R 3.6.2)
#> DiceDesign 1.8-1 2019-07-31 [1] CRAN (R 3.6.0)
#> digest 0.6.25 2020-02-23 [1] CRAN (R 3.6.0)
#> dplyr * 1.0.0 2020-05-29 [1] CRAN (R 3.6.2)
#> ellipsis 0.3.1 2020-05-15 [1] CRAN (R 3.6.2)
#> evaluate 0.14 2019-05-28 [1] CRAN (R 3.6.0)
#> fansi 0.4.1 2020-01-08 [1] CRAN (R 3.6.0)
#> forcats * 0.5.0 2020-03-01 [1] CRAN (R 3.6.0)
#> foreach 1.5.0 2020-03-30 [1] CRAN (R 3.6.2)
#> fs 1.4.2 2020-06-30 [1] CRAN (R 3.6.2)
#> furrr 0.1.0 2018-05-16 [1] CRAN (R 3.6.0)
#> future 1.17.0 2020-04-18 [1] CRAN (R 3.6.2)
#> generics 0.0.2 2018-11-29 [1] CRAN (R 3.6.0)
#> ggplot2 * 3.3.2 2020-06-19 [1] CRAN (R 3.6.2)
#> globals 0.12.5 2019-12-07 [1] CRAN (R 3.6.0)
#> glue 1.4.1 2020-05-13 [1] CRAN (R 3.6.2)
#> gower 0.2.1 2019-05-14 [1] CRAN (R 3.6.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 3.6.0)
#> gtable 0.3.0 2019-03-25 [1] CRAN (R 3.6.0)
#> hardhat 0.1.4 2020-07-02 [1] CRAN (R 3.6.2)
#> haven 2.3.1 2020-06-01 [1] CRAN (R 3.6.2)
#> highr 0.8 2019-03-20 [1] CRAN (R 3.6.0)
#> hms 0.5.3 2020-01-08 [1] CRAN (R 3.6.0)
#> htmltools 0.5.0 2020-06-16 [1] CRAN (R 3.6.2)
#> httr 1.4.2 2020-07-20 [1] CRAN (R 3.6.2)
#> infer * 0.5.3 2020-07-14 [1] CRAN (R 3.6.2)
#> ipred 0.9-9 2019-04-28 [1] CRAN (R 3.6.0)
#> iterators 1.0.12 2019-07-26 [1] CRAN (R 3.6.0)
#> jsonlite 1.7.0 2020-06-25 [1] CRAN (R 3.6.2)
#> knitr 1.28 2020-02-06 [1] CRAN (R 3.6.0)
#> lattice 0.20-41 2020-04-02 [1] CRAN (R 3.6.2)
#> lava 1.6.7 2020-03-05 [1] CRAN (R 3.6.0)
#> lhs 1.0.2 2020-04-13 [1] CRAN (R 3.6.2)
#> lifecycle 0.2.0 2020-03-06 [1] CRAN (R 3.6.0)
#> listenv 0.8.0 2019-12-05 [1] CRAN (R 3.6.0)
#> lubridate 1.7.9 2020-06-08 [1] CRAN (R 3.6.2)
#> magrittr 1.5 2014-11-22 [1] CRAN (R 3.6.0)
#> MASS 7.3-51.6 2020-04-26 [1] CRAN (R 3.6.2)
#> Matrix 1.2-18 2019-11-27 [1] CRAN (R 3.6.0)
#> memoise 1.1.0 2017-04-21 [1] CRAN (R 3.6.0)
#> modeldata * 0.0.2 2020-06-22 [1] CRAN (R 3.6.2)
#> modelr 0.1.8 2020-05-19 [1] CRAN (R 3.6.2)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 3.6.0)
#> nnet 7.3-14 2020-04-26 [1] CRAN (R 3.6.2)
#> parsnip * 0.1.3 2020-08-04 [1] CRAN (R 3.6.2)
#> pillar 1.4.6 2020-07-10 [1] CRAN (R 3.6.2)
#> pkgbuild 1.0.8 2020-05-07 [1] CRAN (R 3.6.2)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 3.6.0)
#> pkgload 1.1.0 2020-05-29 [1] CRAN (R 3.6.2)
#> plyr 1.8.6 2020-03-03 [1] CRAN (R 3.6.0)
#> prettyunits 1.1.1 2020-01-24 [1] CRAN (R 3.6.0)
#> pROC 1.16.2 2020-03-19 [1] CRAN (R 3.6.0)
#> processx 3.4.2 2020-02-09 [1] CRAN (R 3.6.0)
#> prodlim 2019.11.13 2019-11-17 [1] CRAN (R 3.6.0)
#> ps 1.3.3 2020-05-08 [1] CRAN (R 3.6.2)
#> purrr * 0.3.4 2020-04-17 [1] CRAN (R 3.6.2)
#> R6 2.4.1 2019-11-12 [1] CRAN (R 3.6.0)
#> Rcpp 1.0.4.6 2020-04-09 [1] CRAN (R 3.6.1)
#> readr * 1.3.1 2018-12-21 [1] CRAN (R 3.6.0)
#> readxl 1.3.1 2019-03-13 [1] CRAN (R 3.6.0)
#> recipes * 0.1.13 2020-06-23 [1] CRAN (R 3.6.2)
#> remotes 2.1.1 2020-02-15 [1] CRAN (R 3.6.0)
#> reprex 0.3.0 2019-05-16 [1] CRAN (R 3.6.0)
#> rlang 0.4.7 2020-07-09 [1] CRAN (R 3.6.2)
#> rmarkdown 2.3 2020-06-18 [1] CRAN (R 3.6.2)
#> rpart 4.1-15 2019-04-12 [1] CRAN (R 3.6.1)
#> rprojroot 1.3-2 2018-01-03 [1] CRAN (R 3.6.0)
#> rsample * 0.0.8 2020-09-23 [1] CRAN (R 3.6.2)
#> rstudioapi 0.11 2020-02-07 [1] CRAN (R 3.6.0)
#> rvest 0.3.6 2020-07-25 [1] CRAN (R 3.6.2)
#> scales * 1.1.1 2020-05-11 [1] CRAN (R 3.6.2)
#> sessioninfo 1.1.1 2018-11-05 [1] CRAN (R 3.6.0)
#> stringi 1.4.6 2020-02-17 [1] CRAN (R 3.6.0)
#> stringr * 1.4.0 2019-02-10 [1] CRAN (R 3.6.0)
#> survival 3.2-3 2020-06-13 [1] CRAN (R 3.6.2)
#> testthat 2.3.2 2020-03-02 [1] CRAN (R 3.6.0)
#> tibble * 3.0.3 2020-07-10 [1] CRAN (R 3.6.2)
#> tidymodels * 0.1.1 2020-07-14 [1] CRAN (R 3.6.2)
#> tidyr * 1.1.0 2020-05-20 [1] CRAN (R 3.6.2)
#> tidyselect 1.1.0 2020-05-11 [1] CRAN (R 3.6.2)
#> tidyverse * 1.3.0 2019-11-21 [1] CRAN (R 3.6.0)
#> timeDate 3043.102 2018-02-21 [1] CRAN (R 3.6.0)
#> tune * 0.1.1 2020-07-08 [1] CRAN (R 3.6.2)
#> usethis 1.6.1 2020-04-29 [1] CRAN (R 3.6.2)
#> utf8 1.1.4 2018-05-24 [1] CRAN (R 3.6.0)
#> vctrs 0.3.1 2020-06-05 [1] CRAN (R 3.6.2)
#> withr 2.2.0 2020-04-20 [1] CRAN (R 3.6.2)
#> workflows * 0.2.0 2020-09-15 [1] CRAN (R 3.6.2)
#> xfun 0.14 2020-05-20 [1] CRAN (R 3.6.2)
#> xml2 1.3.2 2020-04-23 [1] CRAN (R 3.6.2)
#> yaml 2.2.1 2020-02-01 [1] CRAN (R 3.6.0)
#> yardstick * 0.0.7 2020-07-13 [1] CRAN (R 3.6.2)
#>
#> [1] /Library/Frameworks/R.framework/Versions/3.6/Resources/library