Stratified (group-wise) imputation with tidymodels::recipes

In my application, the data were sampled within known strata (e.g., by country), and each stratum is assumed to follow a structurally different data-generating process. Each group should therefore have its own imputation model, trained only on its own training data, to avoid leakage across strata or between train/test splits. Is this approach correct for imputing by subgroup and then modelling? If not, is there a better way to handle this in tidymodels?

My approach

Using a subset of the diamonds dataset (where cut is treated as a stratum), I:

  1. Split the data within each cut group (strata) into training and test sets.
  2. Fit a bagged-tree imputer (step_impute_bag) on the training set of each stratum.
  3. Bake the training and test data with the prepped recipe, then pool the strata back together.
  4. On the pooled, imputed training data, run cross-validated workflows to compare model specifications:
  • a linear model
  • a linear model with an interaction term

I came across this example, which suggests:

step_impute_linear(var, impute_with = imp_vars(group))

This includes the group variable as a predictor in a single pooled imputation model. However, it does not isolate the imputation by group — it shares information across groups, which violates my data-generating assumptions.
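
For concreteness, the pooled version would look roughly like this on the diamonds example used below (a sketch only; diamonds_missing is defined in the reprex that follows):

# One imputation model fit on all strata at once,
# with the stratum indicator as the imputation predictor
rec_pooled <- recipe(price ~ ., data = diamonds_missing) %>%
  step_impute_linear(carat, impute_with = imp_vars(cut))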

Note: I also posted this question on Stack Overflow.

Reproducible Example:

library(tidymodels)
library(ggplot2)
library(dplyr)

# Simulate missingness in diamonds dataset
set.seed(42)
diamonds_missing <- diamonds %>%
  select(price, carat, x, y, cut) %>% 
  slice_sample(prop = .2) %>% 
  filter(grepl("Good|Fair", cut)) %>%   # keeps Fair, Good, and Very Good
  mutate(
    carat = ifelse(runif(n()) < 0.1, NA, carat))  # ~10% missing

# Let's say 'cut' is the stratum — each group was sampled separately
# What I want: for each 'cut', split into training/test, impute 'carat' using
# bagged trees trained only on that group's training data, then pool all data
# for model training, tuning with cross-validation, model comparison, etc.

# Split by strata
diamonds_by_cut <- diamonds_missing %>%
  group_split(cut) %>%
  # group_split() orders groups by factor level; unique() would give order of
  # appearance, which differs after slice_sample(), so name by level instead
  set_names(levels(droplevels(diamonds_missing$cut)))

# Group-wise imputation function: split within the stratum, train the
# imputer on the stratum's training data only, then bake both portions
group_impute <- function(df_stratum) {
  set.seed(123)
  
  split <- initial_split(df_stratum)
  train <- training(split)
  test  <- testing(split)
  
  # Recipe for bagged-tree imputation
  rec <- recipe(price ~ ., data = train) %>%
    update_role(cut, new_role = "id") %>%
    step_impute_bag(
      carat, 
      impute_with = imp_vars(x, y))
  
  prep_rec <- prep(rec)
  
  train_imputed <- bake(prep_rec, new_data = NULL)
  test_imputed  <- bake(prep_rec, new_data = test)
  
  list(
    train = train_imputed,
    test  = test_imputed)
}

# Apply imputation per group
imputed_list <- map(diamonds_by_cut, group_impute)

# Extract and label test data from all groups
test_data_all <- imputed_list %>%
  imap_dfr(~ mutate(.x$test, cut = .y))

# Extract and label train data from all groups
diamonds_imputed <- imputed_list %>%
  imap_dfr(~ mutate(.x$train, cut = .y))

diamonds_imputed
#> # A tibble: 2,839 × 5
#>    carat     x     y cut       price
#>    <dbl> <dbl> <dbl> <chr>     <int>
#>  1  0.85  6.05  5.96 Very Good  1651
#>  2  0.91  6.12  6.07 Very Good  3180
#>  3  0.5   5.01  4.97 Very Good  1323
#>  4  2.01  7.82  7.78 Very Good 13744
#>  5  1.5   7.52  7.57 Very Good  7912

# Now we can proceed to model on the pooled, imputed dataset:
# e.g., create vfold_cv(diamonds_imputed), run workflows, possibly tune models, etc.

cv_folds <- vfold_cv(diamonds_imputed, v = 5)

cv_folds
#> #  5-fold cross-validation 
#> # A tibble: 5 × 2
#>   splits             id   
#>   <list>             <chr>
#> 1 <split [2271/568]> Fold1
#> 2 <split [2271/568]> Fold2
#> 3 <split [2271/568]> Fold3
#> 4 <split [2271/568]> Fold4
#> 5 <split [2272/567]> Fold5


# 1. Define candidate recipes
ln <- recipe(price ~ carat + x + y, data = diamonds_imputed) %>%
  step_log(carat, base = 10)

inter <- recipe(price ~ carat + x + y, data = diamonds_imputed) %>%
  step_interact(terms = ~ carat:x)

# 2. Define model spec (linear regression)
gl_spec <- linear_reg() %>%
  set_mode("regression") %>%
  set_engine("lm")

# 3. Create workflow set
wf_set <- workflow_set(
  preproc = list(
    ln = ln,
    inter = inter),
  models = list(
    linear = gl_spec))

# 4. Fit resamples on CV folds
wf_results <- wf_set %>%
  workflow_map(
    "fit_resamples",
    resamples = cv_folds,
    metrics = metric_set(rsq),
    verbose = TRUE,
    seed = 270225)

collect_metrics(wf_results)

Answer

step_impute_linear(var, impute_with = imp_vars(group)) fits the model lm(var ~ group, data = some_data) via least squares. That does take all of the data at once, but it doesn't technically pool across strata: with a single factor predictor, each group's estimate depends only on that group's data.

Here's an example to show that the imputed values are the same as the sample means and that, if I change one stratum, the estimates for the others don't change:

library(tidymodels)

car_fac <- mtcars |> mutate(gear = factor(gear))

lm_fit <- lm(mpg ~ gear + 0, data = car_fac)
tidy(lm_fit)
#> # A tibble: 3 × 5
#>   term  estimate std.error statistic  p.value
#>   <chr>    <dbl>     <dbl>     <dbl>    <dbl>
#> 1 gear3     16.1      1.22      13.2 7.87e-14
#> 2 gear4     24.5      1.36      18.1 2.59e-17
#> 3 gear5     21.4      2.11      10.2 4.66e-11

gear_means <- 
  car_fac |> 
  summarize(estimate = mean(mpg), .by = c(gear)) |> 
  arrange(gear)
gear_means
#>   gear estimate
#> 1    3 16.10667
#> 2    4 24.53333
#> 3    5 21.38000

all.equal(gear_means$estimate, tidy(lm_fit)$estimate)
#> [1] TRUE

# two fewer 4-gear cars
car_fac_sub <- car_fac |> slice(-(1:2))
lm_sub_fit <- lm(mpg ~ gear + 0, data = car_fac_sub)

# compare results
tidy(lm_sub_fit) |> 
  select(term, fewer_4_gear = estimate) |> 
  full_join(
    tidy(lm_fit) |> select(term, original = estimate),
    by = "term"
  ) |> 
  mutate(same_estimate = original == fewer_4_gear)
#> # A tibble: 3 × 4
#>   term  fewer_4_gear original same_estimate
#>   <chr>        <dbl>    <dbl> <lgl>        
#> 1 gear3         16.1     16.1 TRUE         
#> 2 gear4         25.2     24.5 FALSE        
#> 3 gear5         21.4     21.4 TRUE

Created on 2025-07-16 with reprex v2.1.1

It's not a particularly great imputation but it is stratified.

Maybe I missed something, but do you impute the test set at all?

To me, the process above looks like a classic example of information leakage. The cross-validation has no way of evaluating how good or bad the imputation is, since that method is baked into the data (pun intended). You'll never see poor performance if your imputation is overfitting or otherwise damaging the overall modeling process. Maybe if you have a very simple (= high-bias) imputation method that risk is low, but I would avoid that process altogether.

Also, are you fitting a model after the recipe? If so, won't that borrow across strata?
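
For illustration, here is a minimal sketch of the leakage-free pattern implied above (my reading of it, not code from the thread): one split of the pooled data, with the stratified imputation kept inside the recipe so it is re-estimated on the analysis set of every resample.

library(tidymodels)

set.seed(42)
# One split of the pooled data, stratified so every cut level
# appears in both the training and test sets
split <- initial_split(diamonds_missing, strata = cut)
train <- training(split)

# The imputer lives inside the recipe, so each resample re-trains it on
# its own analysis set -- assessment rows never inform the imputation
rec <- recipe(price ~ carat + x + y + cut, data = train) %>%
  step_impute_linear(carat, impute_with = imp_vars(cut))  # per-group means

wf <- workflow() %>%
  add_recipe(rec) %>%
  add_model(linear_reg())

folds <- vfold_cv(train, v = 5, strata = cut)
res <- fit_resamples(wf, resamples = folds, metrics = metric_set(rsq))
collect_metrics(res)

Note that cut is also a predictor in the downstream model here, so the final fit does borrow across strata; that is the modelling choice questioned above, and it is separate from how the imputation itself is estimated.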