Hi all -
I'm trying to define a custom metric_set for an ElasticNet regression. I'm working with the example copied straight from the tidymodels website for mse. When I run this code myself, the mse and mse_vec functions work totally fine - with both dataframes and grouped dataframes (see reprex below). However, when I try to use this metric_set within a workflow, I get the following issue:
Failed to compute `mse()`.
Caused by error in `UseMethod()`:
! no applicable method for 'mse' applied to an object of class "c('grouped_df', 'tbl_df', 'tbl', 'data.frame')"
Frustratingly, when I render the reprex (below), everything works fine, yet the very same code fails with the error above when I paste it into my console. Does anyone have any insight into how I can get a custom metric_set working within a tidymodels workflow? Thanks in advance!!
# Attach tidymodels and record the package versions this reprex ran against
library(tidymodels)
packageVersion("tidymodels")
#> [1] '1.4.1'
packageVersion("yardstick")
#> [1] '1.3.2'
# Seed the RNG before anything random happens (simulated noise, fold assignment)
set.seed(6735)
# Data for the checks below: simulated truth/estimate pairs and 10-fold CV on mtcars
data("solubility_test")
test_data <- data.frame(
  truth = seq_len(1000),
  estimate = seq_len(1000) + rnorm(1000, mean = 0, sd = 0.1)
)
tr_te_split <- vfold_cv(mtcars, v = 10)
## Define a custom metric (MSE) that can be passed to metric_set()
# Core MSE computation on already-validated vectors.
#
# truth, estimate: numeric vectors of equal length.
# case_weights:    optional weights, one per observation. When NULL the plain
#                  mean of the squared errors is returned; otherwise the
#                  squared errors are averaged using the weights.
# Returns a single numeric value.
mse_impl <- function(truth, estimate, case_weights = NULL) {
  squared_error <- (truth - estimate)^2

  if (is.null(case_weights)) {
    return(mean(squared_error))
  }

  # Coerce first so weight classes that only define a conversion to double
  # can still be multiplied against the squared errors.
  stats::weighted.mean(squared_error, w = as.double(case_weights))
}
# Vector interface for MSE.
#
# Validates the inputs with check_numeric_metric(), then either strips missing
# values (na_rm = TRUE) or returns NA_real_ when any are present
# (na_rm = FALSE), and finally delegates to mse_impl().
mse_vec <- function(truth, estimate, na_rm = TRUE, case_weights = NULL, ...) {
  check_numeric_metric(truth, estimate, case_weights)

  if (na_rm) {
    # Compute on the cleaned truth / estimate / case_weights
    cleaned <- yardstick_remove_missing(truth, estimate, case_weights)
    return(
      mse_impl(
        cleaned$truth,
        cleaned$estimate,
        case_weights = cleaned$case_weights
      )
    )
  }

  # Missing values were not removed: any NA makes the metric NA
  if (yardstick_any_missing(truth, estimate, case_weights)) {
    return(NA_real_)
  }

  mse_impl(truth, estimate, case_weights = case_weights)
}
# S3 generic for the metric: dispatches on the class of `data`, so a
# method such as `mse.data.frame` must be findable when it is called.
mse <- function(data, ...) {
UseMethod("mse")
}
# Re-bind `mse` to the generic wrapped by yardstick's numeric-metric
# constructor; direction = "minimize" declares that smaller values are better.
mse <- new_numeric_metric(mse, direction = "minimize")
# Data-frame method for mse().
#
# `truth`, `estimate` and `case_weights` are column expressions; they are
# captured as quosures and injected into numeric_metric_summarizer(), which
# is given mse_vec() as the function to apply.
mse.data.frame <- function(data, truth, estimate, na_rm = TRUE, case_weights = NULL, ...) {
  truth_quo <- enquo(truth)
  estimate_quo <- enquo(estimate)
  weights_quo <- enquo(case_weights)

  numeric_metric_summarizer(
    name = "mse",
    fn = mse_vec,
    data = data,
    truth = !!truth_quo,
    estimate = !!estimate_quo,
    na_rm = na_rm,
    case_weights = !!weights_quo
  )
}
## Check the metric on the simulated data: vector interface, then data-frame interface
mse_vec(truth = test_data[["truth"]], estimate = test_data[["estimate"]])
#> [1] 0.009884644
test_data %>%
  mse(truth = truth, estimate = estimate)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 mse standard 0.00988
## Workflow: normalize all predictors, then a glmnet linear regression
## whose penalty and mixture are marked for tuning
rec_spec <- step_normalize(
  recipe(mpg ~ ., data = mtcars),
  all_predictors()
)

lin_mod <- linear_reg(penalty = tune(), mixture = tune()) %>%
  set_mode("regression") %>%
  set_engine("glmnet")
# `mod` is a second binding to the same model spec; it is not referenced
# anywhere else in the visible script — TODO confirm it can be dropped
mod <- lin_mod

wf_spec <- workflow() %>%
  add_recipe(rec_spec) %>%
  add_model(lin_mod)
# works: tuning with the built-in rmse metric
tune_grid(
  wf_spec,
  resamples = tr_te_split,
  grid = 10,
  control = control_resamples(save_pred = TRUE, save_workflow = TRUE),
  metrics = yardstick::metric_set(rmse)
)
#> # Tuning results
#> # 10-fold cross-validation
#> # A tibble: 10 × 5
#> splits id .metrics .notes .predictions
#> <list> <chr> <list> <list> <list>
#> 1 <split [28/4]> Fold01 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [40 × 6]>
#> 2 <split [28/4]> Fold02 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [40 × 6]>
#> 3 <split [29/3]> Fold03 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 4 <split [29/3]> Fold04 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 5 <split [29/3]> Fold05 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 6 <split [29/3]> Fold06 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 7 <split [29/3]> Fold07 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 8 <split [29/3]> Fold08 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 9 <split [29/3]> Fold09 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 10 <split [29/3]> Fold10 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
# does not work in my interactive console (fails with "no applicable method
# for 'mse'"), although it runs cleanly when rendered through reprex
tune_grid(
  wf_spec,
  resamples = tr_te_split,
  grid = 10,
  control = control_resamples(save_pred = TRUE, save_workflow = TRUE),
  metrics = yardstick::metric_set(mse)
)
#> # Tuning results
#> # 10-fold cross-validation
#> # A tibble: 10 × 5
#> splits id .metrics .notes .predictions
#> <list> <chr> <list> <list> <list>
#> 1 <split [28/4]> Fold01 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [40 × 6]>
#> 2 <split [28/4]> Fold02 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [40 × 6]>
#> 3 <split [29/3]> Fold03 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 4 <split [29/3]> Fold04 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 5 <split [29/3]> Fold05 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 6 <split [29/3]> Fold06 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 7 <split [29/3]> Fold07 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 8 <split [29/3]> Fold08 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 9 <split [29/3]> Fold09 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 10 <split [29/3]> Fold10 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
# Print the details of the R build that rendered this reprex
R.version
#> _
#> platform aarch64-apple-darwin20
#> arch aarch64
#> os darwin20
#> system aarch64, darwin20
#> status
#> major 4
#> minor 5.0
#> year 2025
#> month 04
#> day 11
#> svn rev 88135
#> language R
#> version.string R version 4.5.0 (2025-04-11)
#> nickname How About a Twenty-Six
Created on 2025-10-27 with reprex v2.1.1