I am using tidymodels to optimize and evaluate models, and I have written a custom metric used in virtual drug screening: the enrichment factor at 1%. I've followed the vignettes/tutorials, and the metric works after model fitting when applied to the predictions on the hold-out data, but I can't get it to work during tuning/racing. Because it is a class probability metric, I need to pass the class probability column in addition to the truth and estimate columns, but I can't get that to work in tune_grid() or in workflow_map()/control_race(). Below is a reprex:
library(tidyverse)
library(tidymodels)
library(modeldata)
# metric ------------------------------------------------------------------
# get the correct event level
event_col <- function(xtab, event_level) {
  if (identical(event_level, "first")) {
    colnames(xtab)[[1]]
  } else {
    colnames(xtab)[[2]]
  }
}
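# e.g. if truth has levels c("Class1", "Class2"), event_col() returns
# "Class1" when event_level = "first" and "Class2" otherwise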
# control what type of data can be used
finalize_estimator_internal.ef1 <- function(metric_dispatcher, x, estimator) {
  validate_estimator(estimator, estimator_override = "binary")
  if (!is.null(estimator)) {
    return(estimator)
  }
  lvls <- levels(x)
  if (length(lvls) > 2) {
    stop("A multiclass `truth` input was provided, but only `binary` is supported.")
  }
  "binary"
}
# vector implementation
ef1_vec <- function(truth, estimate, estimate_val, estimator = NULL,
                    event_level = "first", na_rm = TRUE, ...) {
  estimator <- finalize_estimator(truth, estimator, metric_class = "ef1")
  ef1_impl <- function(truth, estimate, estimate_val) {
    xtab <- table(estimate, truth)
    col <- event_col(xtab, event_level)
    # overall hit rate: events (A) out of all observations (N)
    N <- length(truth)
    A <- sum(truth == col)
    # hit rate in the top-scoring fraction (top 10% here)
    df <- bind_cols(truth, estimate, estimate_val) %>%
      rename(truth = 1, estimate = 2, estimate_val = 3) %>%
      arrange(desc(estimate_val)) %>%
      slice_head(prop = 0.1) %>%
      group_by(truth) %>%
      tally()
    # there's probably a better way of getting the event name, but this works for now
    a <- filter(df, truth == col)$n
    n <- sum(df$n)
    (a / n) / (A / N)
  }
  metric_vec_template(
    metric_impl = ef1_impl,
    truth = truth,
    estimate = estimate,
    estimate_val = estimate_val,
    na_rm = na_rm,
    cls = "factor",
    estimator = estimator,
    ...
  )
}
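# for reference, the vector version should also be callable directly on
# columns, e.g. (once two_class_example is loaded below):
# ef1_vec(two_class_example$truth, two_class_example$predicted,
#         estimate_val = two_class_example$Class1)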
# data frame implementation
ef1 <- function(data, ...) {
  UseMethod("ef1")
}
ef1 <- new_prob_metric(ef1, direction = "maximize")
ef1.data.frame <- function(data, truth, estimate, estimate_val,
                           estimator = NULL, na_rm = TRUE,
                           event_level = "first", ...) {
  metric_summarizer(
    metric_nm = "ef1",
    metric_fn = ef1_vec,
    data = data,
    truth = !!enquo(truth),
    estimate = !!enquo(estimate),
    metric_fn_options = list(estimate_val = enquo(estimate_val)),
    estimator = estimator,
    na_rm = na_rm,
    event_level = event_level,
    ...
  )
}
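# note: metric_summarizer() splices metric_fn_options into the metric call
# with !!!, which (as far as I can tell) is why passing the extra column as
# enquo(estimate_val) resolves against the data here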
# test metric -------------------------------------------------------------
data(two_class_example)
two_class_example %>%
  ef1(truth, predicted, estimate_val = Class1)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 ef1 binary 1.94
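# sanity check by hand: EF = (a / n) / (A / N), i.e. the hit rate in the
# top-scoring fraction divided by the overall hit rate; this should
# reproduce the 1.94 above
top_scored <- two_class_example %>%
  arrange(desc(Class1)) %>%
  slice_head(prop = 0.1)
mean(top_scored$truth == "Class1") / mean(two_class_example$truth == "Class1")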
# tune parameters ---------------------------------------------------------------
data(bivariate)
base_rec <-
  recipe(Class ~ ., data = bivariate_train)
rand_forest_ranger_spec <-
  rand_forest(mtry = tune(), min_n = tune(), trees = tune()) %>%
  set_engine("ranger") %>%
  set_mode("classification")
# finalize mtry() on the predictors only, so its upper bound excludes the outcome
rf_grid <- grid_regular(
  finalize(mtry(), bivariate_train %>% select(-Class)),
  min_n(),
  trees(),
  levels = 3
)
ef1_set <- metric_set(ef1)
folds <- vfold_cv(bivariate_train, v = 2, strata = Class)
rf_tune <- tune_grid(
  workflow(base_rec, rand_forest_ranger_spec),
  resamples = folds,
  grid = rf_grid,
  # my attempt at passing the columns to the metric set:
  metrics = ef1_set(Class, .pred_class, estimate = .pred_class),
  control = control_grid(
    save_pred = TRUE,
    verbose = TRUE
  )
)
#> Error:
#> ! In metric: `ef1`
#> object 'Class' not found
Created on 2022-04-13 by the reprex package (v2.0.1)