What I'd like to do
I am trying to build a model in tidymodels that will predict the efficacy of drugs on cell lines (like bacteria). The model will rank drugs by efficacy for a given cell line, so I want to use Spearman's correlation (ρ) as a metric. In the following example data set, each cell line (column Sample) is represented by a letter, Q, R, S, ..., Z, and each sample was treated with 50 drugs.
When I split the data for cross-validation, the training/test splits for each fold will contain more than one cell line (e.g. Q, R in the test split for fold 1). When calculating the metric (ρ), however, I want to calculate it for each cell line individually and then take the average across all the cell lines in the test split, rather than calculating it over all the observations in aggregate. For example, if the test split for fold 1 consists of Q, R, then I want to calculate ρ for the 50 drugs tested against Q, then a separate ρ for the 50 drugs tested against R, average these two ρ values, and have that average be the metric calculated for fold 1.
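To make the intended calculation concrete, here is a minimal base R sketch on a toy test split, outside of tidymodels. The helper name mean_spearman_by_group and the toy columns truth and estimate are made up for illustration; they are not part of any package.

```r
# Hypothetical helper: Spearman's rho per group, then the mean across groups.
mean_spearman_by_group = function(truth, estimate, group) {
  # Split the row indices by group so each cell line is scored on its own.
  idx_by_group = split(seq_along(truth), group)
  rho_by_group = vapply(
    idx_by_group,
    function(idx) cor(truth[idx], estimate[idx], method = "spearman"),
    numeric(1)
  )
  mean(rho_by_group)
}

# Toy test split: two cell lines (Q and R) with 5 drugs each.
toy = data.frame(
  Sample = rep(c("Q", "R"), each = 5),
  truth = c(1:5, 11:15),
  estimate = c(1:5, 25:21)
)

# Q is ranked perfectly (rho = 1) and R is ranked in reverse (rho = -1),
# so the per-cell-line average is 0.
fold_metric = mean_spearman_by_group(toy$truth, toy$estimate, toy$Sample)

# Pooling all 10 observations gives a different answer (about 0.76), because
# the offset between the two cell lines dominates the ranking.
pooled_rho = cor(toy$truth, toy$estimate, method = "spearman")

fold_metric
pooled_rho
```

The gap between the two numbers is the reason I don't want the metric calculated over the whole test split at once.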
What I've tried
I was thinking that I'd have to calculate the metric on a tibble/data.frame grouped by the Sample column, but I can't figure out how to pass that variable into tune_grid(). I don't think I can include the variable in add_formula() when creating the workflow object, since I don't want it as a predictor variable. I just discovered tidymodels yesterday, so maybe there's a straightforward solution I'm unaware of, but I haven't been able to find anything on Google so far. The code below is what I've tried, but obviously it doesn't work. Thank you in advance for any advice you can give.
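One fallback I can imagine is to skip the custom metric and compute ρ after tuning from saved out-of-fold predictions, recovering Sample through a row index. The sketch below uses base R on stand-in data frames only. The layout of preds (columns id, .row, TargetVariable, .pred) is my assumption about what saved predictions would look like; I have not verified it against tune, and real tuning output would presumably also need to be grouped by the hyperparameter columns.

```r
# Stand-in for the original data: only the grouping column matters here.
original = data.frame(Sample = rep(c("Q", "R", "S", "T"), each = 4))

# Stand-in for saved out-of-fold predictions (assumed layout): a fold id,
# the row index into `original`, the observed value, and the prediction.
preds = data.frame(
  id = rep(c("Resample1", "Resample2"), each = 8),
  .row = 1:16,
  TargetVariable = rep(c(1, 2, 3, 4), times = 4),
  .pred = c(1, 2, 3, 4, 4, 3, 2, 1, 1, 2, 3, 4, 1, 2, 3, 4)
)

# Recover the cell line of each prediction from its row index.
preds$Sample = original$Sample[preds$.row]

# One piece per (fold, cell line) combination that actually occurs.
pieces = split(preds, list(preds$id, preds$Sample), drop = TRUE)

# Spearman's rho within each piece.
rho_by_piece = vapply(
  pieces,
  function(d) cor(d$TargetVariable, d$.pred, method = "spearman"),
  numeric(1)
)

# Fold label of each piece, then the mean rho across cell lines per fold.
fold_of_piece = vapply(pieces, function(d) as.character(d$id[1]), character(1))
fold_means = tapply(rho_by_piece, fold_of_piece, mean)
fold_means
```

This gives one averaged ρ per fold, but it happens outside of tune_grid(), so it can't be used directly with show_best(). That is why I'd still prefer a proper custom metric.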
Error
i Resample1: preprocessor 1/1
✓ Resample1: preprocessor 1/1
i Resample1: preprocessor 1/1, model 1/20
✓ Resample1: preprocessor 1/1, model 1/20
i Resample1: preprocessor 1/1, model 1/20 (predictions)
x Resample1: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample2: preprocessor 1/1
✓ Resample2: preprocessor 1/1
i Resample2: preprocessor 1/1, model 1/20
✓ Resample2: preprocessor 1/1, model 1/20
i Resample2: preprocessor 1/1, model 1/20 (predictions)
x Resample2: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample3: preprocessor 1/1
✓ Resample3: preprocessor 1/1
i Resample3: preprocessor 1/1, model 1/20
✓ Resample3: preprocessor 1/1, model 1/20
i Resample3: preprocessor 1/1, model 1/20 (predictions)
x Resample3: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample4: preprocessor 1/1
✓ Resample4: preprocessor 1/1
i Resample4: preprocessor 1/1, model 1/20
✓ Resample4: preprocessor 1/1, model 1/20
i Resample4: preprocessor 1/1, model 1/20 (predictions)
x Resample4: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample5: preprocessor 1/1
✓ Resample5: preprocessor 1/1
i Resample5: preprocessor 1/1, model 1/20
✓ Resample5: preprocessor 1/1, model 1/20
i Resample5: preprocessor 1/1, model 1/20 (predictions)
x Resample5: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
Warning message:
All models failed. See the `.notes` column.
Upon running glmnet_tuning_results:
Warning message:
This tuning result has notes. Example notes on model fitting include:
internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm = ~na_rm)
internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm = ~na_rm)
internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm = ~na_rm)
Code
Example data set
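# Loads tibble, rsample, parsnip, workflows, dials, tune, yardstick, and dplyr.
library(tidymodels)

data = tibble(
  Sample = rep(LETTERS[17:26], each = 50),
  TargetVariable = rnorm(500, mean = 0, sd = 1),
  PredictorVariable1 = rnorm(500, mean = 5, sd = 1),
  PredictorVariable2 = rpois(500, lambda = 5)
)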
Model
# Splitting for cross-validation.
set.seed(1026)
folds = group_vfold_cv(data, group = Sample, v = 5)

# Model specification.
glmnet_model = linear_reg(
  mode = "regression",
  penalty = tune(),
  mixture = tune()
) %>%
  set_engine("glmnet")

# Workflow.
glmnet_wf = workflow() %>%
  add_model(glmnet_model) %>%
  add_formula(TargetVariable ~ . - Sample)

# Grid specification.
glmnet_params = parameters(penalty(), mixture())
set.seed(1026)
glmnet_grid = grid_max_entropy(glmnet_params, size = 20)

# Hyperparameter tuning.
glmnet_tuning_results = tune_grid(
  glmnet_wf,
  resamples = folds,
  grid = glmnet_grid,
  metrics = metric_set(spearman_cor),
  control = control_grid(verbose = TRUE)
)

glmnet_tuning_results %>% show_best(n = 10)
Custom metric
# Vector version.
spearman_cor_vec = function(truth, estimate, na_rm = TRUE) {
  spearman_cor_impl = function(truth, estimate) {
    cor(truth, estimate, method = "spearman")
  }

  metric_vec_template(
    metric_impl = spearman_cor_impl,
    truth = truth,
    estimate = estimate,
    na_rm = na_rm,
    cls = "numeric"
  )
}

# Data frame version.
spearman_cor = function(data) {
  UseMethod("spearman_cor")
}

spearman_cor = new_numeric_metric(spearman_cor, direction = "maximize")

spearman_cor.data.frame = function(data, truth, estimate, na_rm = TRUE) {
  data_grouped = data %>%
    group_by(Sample)

  metric_summarizer(
    metric_nm = "spearman_cor",
    metric_fn = spearman_cor_vec,
    data = data_grouped,
    truth = !! enquo(truth),
    estimate = !! enquo(estimate),
    na_rm = na_rm
  )
}
Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 3.6.3 (2020-02-29)
#> os macOS Catalina 10.15.7
#> system x86_64, darwin15.6.0
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz America/Chicago
#> date 2021-08-25
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date lib source
#> backports 1.1.6 2020-04-05 [1] CRAN (R 3.6.2)
#> cli 3.0.1 2021-07-17 [1] CRAN (R 3.6.2)
#> crayon 1.3.4 2017-09-16 [1] CRAN (R 3.6.0)
#> digest 0.6.25 2020-02-23 [1] CRAN (R 3.6.0)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 3.6.2)
#> evaluate 0.14 2019-05-28 [1] CRAN (R 3.6.0)
#> fansi 0.4.1 2020-01-08 [1] CRAN (R 3.6.0)
#> fs 1.3.1 2019-05-06 [1] CRAN (R 3.6.0)
#> glue 1.4.0 2020-04-03 [1] CRAN (R 3.6.2)
#> highr 0.8 2019-03-20 [1] CRAN (R 3.6.0)
#> htmltools 0.5.1.1 2021-01-22 [1] CRAN (R 3.6.2)
#> knitr 1.27 2020-01-16 [1] CRAN (R 3.6.0)
#> lifecycle 1.0.0 2021-02-15 [1] CRAN (R 3.6.2)
#> magrittr 2.0.1 2020-11-17 [1] CRAN (R 3.6.2)
#> pillar 1.6.2 2021-07-29 [1] CRAN (R 3.6.2)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 3.6.0)
#> purrr 0.3.4 2020-04-17 [1] CRAN (R 3.6.2)
#> Rcpp 1.0.4.6 2020-04-09 [1] CRAN (R 3.6.1)
#> reprex 2.0.1 2021-08-05 [1] CRAN (R 3.6.2)
#> rlang 0.4.10 2020-12-30 [1] CRAN (R 3.6.2)
#> rmarkdown 2.1 2020-01-20 [1] CRAN (R 3.6.0)
#> rstudioapi 0.13 2020-11-12 [1] CRAN (R 3.6.2)
#> sessioninfo 1.1.1 2018-11-05 [1] CRAN (R 3.6.0)
#> stringi 1.4.5 2020-01-11 [1] CRAN (R 3.6.0)
#> stringr 1.4.0 2019-02-10 [1] CRAN (R 3.6.0)
#> styler 1.5.1 2021-07-13 [1] CRAN (R 3.6.2)
#> tibble 3.1.3 2021-07-23 [1] CRAN (R 3.6.2)
#> utf8 1.1.4 2018-05-24 [1] CRAN (R 3.6.0)
#> vctrs 0.3.8 2021-04-29 [1] CRAN (R 3.6.2)
#> withr 2.4.2 2021-04-18 [1] CRAN (R 3.6.2)
#> xfun 0.12 2020-01-13 [1] CRAN (R 3.6.0)
#> yaml 2.2.0 2018-07-25 [1] CRAN (R 3.6.0)
#>
#> [1] /Library/Frameworks/R.framework/Versions/3.6/Resources/library