Using custom metric_sets in tidymodels workflows

Hi all -

I'm trying to define a custom metric_set for an elastic net regression, working from the mse example copied straight from the tidymodels website. When I run the code myself, the mse() and mse_vec() functions work fine on both data frames and grouped data frames (see reprex below). However, when I try to use the metric_set within a workflow, I get the following error:

Failed to compute `mse()`.
Caused by error in `UseMethod()`:
! no applicable method for 'mse' applied to an object of class "c('grouped_df', 'tbl_df', 'tbl', 'data.frame')"

Frustratingly, when I render the reprex (below), everything seems to work; the same code fails when I paste it into my console. Does anyone have insight into how I can get a custom metric_set working within a tidymodels workflow? Thanks in advance!

library(tidymodels)

packageVersion('tidymodels')
#> [1] '1.4.1'
packageVersion('yardstick')
#> [1] '1.3.2'
set.seed(6735)

# Set up data for tests
data("solubility_test")
test_data <- data.frame(truth = 1:1000, estimate = 1:1000 + rnorm(1000, 0, 0.1))
tr_te_split <- vfold_cv(mtcars, v = 10)


## Define custom metric set 

mse_impl <- function(truth, estimate, case_weights = NULL) {
  mean((truth - estimate) ^ 2)
}

mse_vec <- function(truth, estimate, na_rm = TRUE, case_weights = NULL, ...) {
  check_numeric_metric(truth, estimate, case_weights)
  
  if (na_rm) {
    result <- yardstick_remove_missing(truth, estimate, case_weights)
    
    truth <- result$truth
    estimate <- result$estimate
    case_weights <- result$case_weights
  } else if (yardstick_any_missing(truth, estimate, case_weights)) {
    return(NA_real_)
  }
  
  mse_impl(truth, estimate, case_weights = case_weights)
}

mse <- function(data, ...) {
  UseMethod("mse")
}

mse <- new_numeric_metric(mse, direction = "minimize")

mse.data.frame <- function(data, truth, estimate, na_rm = TRUE, case_weights = NULL, ...) {
  
  numeric_metric_summarizer(
    name = "mse",
    fn = mse_vec,
    data = data,
    truth = !!enquo(truth),
    estimate = !!enquo(estimate),
    na_rm = na_rm,
    case_weights = !!enquo(case_weights)
  )
}

## test the custom metric

mse_vec(truth = test_data$truth, estimate = test_data$estimate)
#> [1] 0.009884644
mse(test_data, truth = truth, estimate = estimate)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 mse     standard     0.00988
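## The grouped case mentioned above, shown with a hypothetical grouping column
## added purely for illustration; the data-frame method returns one row per group
test_grouped <- test_data
test_grouped$grp <- rep(c("a", "b"), each = 500)
test_grouped %>%
  group_by(grp) %>%
  mse(truth = truth, estimate = estimate)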

## workflow 
rec_spec <- recipe(mpg ~ ., data = mtcars) %>%
  step_normalize(all_predictors())

lin_mod <- linear_reg(penalty = tune(),
                      mixture = tune()) %>%
  set_mode("regression") %>%
  set_engine("glmnet")

wf_spec <-
  workflow() %>%
  add_recipe(rec_spec) %>%
  add_model(lin_mod)

# works 
wf_spec %>%
  tune_grid(resamples = tr_te_split,
            grid = 10,
            control = control_resamples(save_pred = TRUE,
                                        save_workflow = TRUE),
            metrics = yardstick::metric_set(rmse))
#> # Tuning results
#> # 10-fold cross-validation 
#> # A tibble: 10 × 5
#>    splits         id     .metrics          .notes           .predictions     
#>    <list>         <chr>  <list>            <list>           <list>           
#>  1 <split [28/4]> Fold01 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [40 × 6]>
#>  2 <split [28/4]> Fold02 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [40 × 6]>
#>  3 <split [29/3]> Fold03 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  4 <split [29/3]> Fold04 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  5 <split [29/3]> Fold05 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  6 <split [29/3]> Fold06 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  7 <split [29/3]> Fold07 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  8 <split [29/3]> Fold08 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  9 <split [29/3]> Fold09 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 10 <split [29/3]> Fold10 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>

# does not work 
wf_spec %>%
  tune_grid(resamples = tr_te_split,
            grid = 10,
            control = control_resamples(save_pred = TRUE,
                                        save_workflow = TRUE),
            metrics = yardstick::metric_set(mse))
#> # Tuning results
#> # 10-fold cross-validation 
#> # A tibble: 10 × 5
#>    splits         id     .metrics          .notes           .predictions     
#>    <list>         <chr>  <list>            <list>           <list>           
#>  1 <split [28/4]> Fold01 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [40 × 6]>
#>  2 <split [28/4]> Fold02 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [40 × 6]>
#>  3 <split [29/3]> Fold03 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  4 <split [29/3]> Fold04 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  5 <split [29/3]> Fold05 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  6 <split [29/3]> Fold06 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  7 <split [29/3]> Fold07 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  8 <split [29/3]> Fold08 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#>  9 <split [29/3]> Fold09 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>
#> 10 <split [29/3]> Fold10 <tibble [10 × 6]> <tibble [0 × 4]> <tibble [30 × 6]>

R.version
#>                _                           
#> platform       aarch64-apple-darwin20      
#> arch           aarch64                     
#> os             darwin20                    
#> system         aarch64, darwin20           
#> status                                     
#> major          4                           
#> minor          5.0                         
#> year           2025                        
#> month          04                          
#> day            11                          
#> svn rev        88135                       
#> language       R                           
#> version.string R version 4.5.0 (2025-04-11)
#> nickname       How About a Twenty-Six

Created on 2025-10-27 with reprex v2.1.1

I didn't have any issues using your code (see reprex below). Could your failing run have been using parallel processing? It would also be good to know your version of the tune package.

library(tidymodels)

packageVersion('tidymodels')
#> [1] '1.4.1'

packageVersion('yardstick')
#> [1] '1.3.2'

R.version
#>                _                           
#> platform       aarch64-apple-darwin20      
#> arch           aarch64                     
#> os             darwin20                    
#> system         aarch64, darwin20           
#> status                                     
#> major          4                           
#> minor          5.1                         
#> year           2025                        
#> month          06                          
#> day            13                          
#> svn rev        88306                       
#> language       R                           
#> version.string R version 4.5.1 (2025-06-13)
#> nickname       Great Square Root

set.seed(6735)

# Set up data for tests
data("solubility_test")
test_data <- data.frame(truth = 1:1000, estimate = 1:1000 + rnorm(1000, 0, 0.1))
tr_te_split <- vfold_cv(mtcars, v = 10)


## Define custom metric set

mse_impl <- function(truth, estimate, case_weights = NULL) {
  mean((truth - estimate)^2)
}

mse_vec <- function(truth, estimate, na_rm = TRUE, case_weights = NULL, ...) {
  check_numeric_metric(truth, estimate, case_weights)

  if (na_rm) {
    result <- yardstick_remove_missing(truth, estimate, case_weights)

    truth <- result$truth
    estimate <- result$estimate
    case_weights <- result$case_weights
  } else if (yardstick_any_missing(truth, estimate, case_weights)) {
    return(NA_real_)
  }

  mse_impl(truth, estimate, case_weights = case_weights)
}

mse <- function(data, ...) {
  UseMethod("mse")
}

mse <- new_numeric_metric(mse, direction = "minimize")

mse.data.frame <- function(
  data,
  truth,
  estimate,
  na_rm = TRUE,
  case_weights = NULL,
  ...
) {
  numeric_metric_summarizer(
    name = "mse",
    fn = mse_vec,
    data = data,
    truth = !!enquo(truth),
    estimate = !!enquo(estimate),
    na_rm = na_rm,
    case_weights = !!enquo(case_weights)
  )
}

## test the custom metric

mse_vec(truth = test_data$truth, estimate = test_data$estimate)
#> [1] 0.009884644

mse(test_data, truth = truth, estimate = estimate)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 mse     standard     0.00988

## workflow
rec_spec <- recipe(mpg ~ ., data = mtcars) %>%
  step_normalize(all_predictors())

lin_mod <- linear_reg(penalty = tune(), mixture = tune()) %>%
  set_mode("regression") %>%
  set_engine("glmnet")

wf_spec <-
  workflow() %>%
  add_recipe(rec_spec) %>%
  add_model(lin_mod)

# works
rmse_res <-
  wf_spec %>%
  tune_grid(
    resamples = tr_te_split,
    grid = 10,
    control = control_resamples(save_pred = TRUE, save_workflow = TRUE),
    metrics = yardstick::metric_set(rmse)
  )

# does not work
mse_res <-
  wf_spec %>%
  tune_grid(
    resamples = tr_te_split,
    grid = 10,
    control = control_resamples(save_pred = TRUE, save_workflow = TRUE),
    metrics = yardstick::metric_set(mse)
  )

collect_metrics(mse_res)
#> # A tibble: 10 × 8
#>          penalty mixture .metric .estimator  mean     n std_err .config         
#>            <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>           
#>  1 0.0000000001    0.367 mse     standard   13.2     10    2.94 pre0_mod01_post0
#>  2 0.00000000129   0.789 mse     standard   13.3     10    2.97 pre0_mod02_post0
#>  3 0.0000000167    0.05  mse     standard   13.1     10    2.92 pre0_mod03_post0
#>  4 0.000000215     0.472 mse     standard   13.2     10    2.95 pre0_mod04_post0
#>  5 0.00000278      0.894 mse     standard   13.2     10    2.96 pre0_mod05_post0
#>  6 0.0000359       0.156 mse     standard   13.2     10    2.94 pre0_mod06_post0
#>  7 0.000464        0.578 mse     standard   13.2     10    2.95 pre0_mod07_post0
#>  8 0.00599         1     mse     standard   13.0     10    2.87 pre0_mod08_post0
#>  9 0.0774          0.261 mse     standard   10.8     10    2.29 pre0_mod09_post0
#> 10 1               0.683 mse     standard    8.80    10    2.06 pre0_mod10_post0

Created on 2025-10-30 with reprex v2.1.1

Hi Max -

I was using tune 2.0.0.

I just tried my reprex without using a parallel backend and it does indeed work.

Do you have a good guide for setting up tune_grid() with a parallel backend? I'm getting some odd errors when following the guidance from this page with the future package, and I'm not quite sure how to debug them. Thanks!
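For reference, this is roughly the setup I was trying (a minimal sketch; the worker count is arbitrary, and the plan() call follows the tune parallel-processing article):

library(future)
plan(multisession, workers = 4)  # tune_grid() picks up the registered plan automatically

# same tuning call as in the reprex above, now run on the parallel workers
wf_spec %>%
  tune_grid(
    resamples = tr_te_split,
    grid = 10,
    control = control_resamples(save_pred = TRUE, save_workflow = TRUE),
    metrics = yardstick::metric_set(mse)
  )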

Any global data from your current workspace has to be copied over to the worker processes so that the code, when run in parallel, has access to it.

(In the next version of tune, we can add an option to formally pass the data to the workers.)

Using future

For background, tune_grid() calls future_lapply() with a set of globals that come from the package internals, not from the global environment. We need to send that data to the worker processes.

If you are using the future package, I thought that we could solve this by using the %<-% operator to send objects to the workers. For example:

library(future)
plan("multisession")

mse %<-% {mse}
mse_impl %<-% {mse_impl}
mse_vec %<-% {mse_vec}
mse.data.frame %<-% {mse.data.frame}

mse_par_res <-
  wf_spec %>%
  tune_grid(
    resamples = tr_te_split,
    grid = 10,
    control = control_resamples(save_pred = TRUE, save_workflow = TRUE),
    metrics = yardstick::metric_set(mse)
  )

However, I get the same error.

Perhaps @heinreich can drop in...
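One possible workaround on macOS or Linux (untested here, and not available on Windows or inside RStudio) is a forked plan, since forked workers inherit the parent session's global environment and should therefore find the custom mse() methods:

library(future)
plan(multicore, workers = 4)  # forked workers inherit objects defined in the
                              # global environment, including mse.data.frame()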

Using mirai

For background, tune_grid() calls mirai_map() with a set of globals that come from the package internals, not from the global environment. We need to send that data to the worker processes.

Honestly, I'm not sure at all. Maybe @gaoce can help out.

Using mirai, you can add those functions to the global environment of your parallel processes with everywhere():

library(mirai)
daemons(4)  # start four persistent background workers

# export the custom metric functions to the global environment of each daemon
everywhere({}, mse.data.frame = mse.data.frame, mse_vec = mse_vec, mse_impl = mse_impl)

You can find more details in the mirai documentation for everywhere() (Evaluate Everywhere).
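With the daemons set up and the functions exported, the same tuning call should then dispatch correctly on the workers (a sketch of the follow-up, not re-run here):

mse_par_res <-
  wf_spec %>%
  tune_grid(
    resamples = tr_te_split,
    grid = 10,
    control = control_resamples(save_pred = TRUE, save_workflow = TRUE),
    metrics = yardstick::metric_set(mse)
  )

daemons(0)  # shut the daemons down when finished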