In my application, the data-generating process requires stratified handling, as the data was sampled within known strata (e.g., by country), and each stratum is assumed to follow a structurally different process. Therefore, each group should have its own imputation model, trained only on its own training data to avoid leakage across strata or between train/test splits. Is this approach correct for imputing by subgroup and then modelling? If not, is there a better way to handle this in tidymodels?
My approach
Using a subset of the diamonds dataset (where cut is treated as a stratum), I:
- Split the data within each cut group (stratum) into training and test sets.
- Fit a bagged trees imputer (step_impute_bag) on the training set of each stratum.
- Bake the training data, then pool the strata back together.
- On the pooled, imputed dataset, I run cross-validated workflows to compare model specifications:
  - linear
  - linear with an interaction
I came across this example, which suggests:
step_impute_linear(var, impute_with = imp_vars(group))
This includes the group variable as a predictor in a single pooled imputation model. However, it does not isolate the imputation by group: it shares information across groups, which violates my data-generating assumptions.
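For concreteness, adapted to this example's variables (my adaptation, not code from the linked example), the pooled version would look something like this, using the diamonds_missing data created in the reproducible example below:

# Pooled alternative: a single linear imputation model for carat,
# with the stratum variable 'cut' as just another predictor
rec_pooled <- recipe(price ~ carat + x + y + cut, data = diamonds_missing) %>%
  step_impute_linear(carat, impute_with = imp_vars(x, y, cut))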
Note: I also posted this question on Stack Overflow.
Reproducible Example:
library(tidymodels)
library(ggplot2)
library(dplyr)
# Simulate missingness in the diamonds dataset
set.seed(42)
diamonds_missing <- diamonds %>%
  select(price, carat, x, y, cut) %>%
  slice_sample(prop = 0.2) %>%
  filter(grepl("Good|Fair", cut)) %>%
  mutate(carat = ifelse(runif(n()) < 0.1, NA, carat))  # ~10% missing in carat
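As a quick sanity check (my addition), the realised missingness rate per stratum should be close to 10%:

# Realised share of missing carat values within each stratum
diamonds_missing %>%
  group_by(cut) %>%
  summarise(n = n(), prop_missing = mean(is.na(carat)))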
# Let's say 'cut' is the stratum: each group was sampled separately.
# What I want: for each 'cut', split into training/test, impute 'carat' with
# bagged trees using only training data from that group, then pool all data
# for model training, tuning with cross-validation, model comparison, etc.
# Split by strata
# group_split() returns the groups in factor-level order, so name them the
# same way; unique() follows row order and could mislabel the groups
diamonds_by_cut <- diamonds_missing %>%
  group_split(cut) %>%
  set_names(levels(droplevels(diamonds_missing$cut)))
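It is worth verifying that the list names line up with the actual group membership (a check I added):

# Every named element should contain only rows from the matching cut
imap_lgl(diamonds_by_cut, ~ all(.x$cut == .y))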
# Group-wise imputation function: split one stratum into train/test,
# fit the imputer on the training partition only, then bake both
group_impute <- function(df_stratum) {
  set.seed(123)
  split <- initial_split(df_stratum)
  train <- training(split)
  test <- testing(split)

  # Recipe for bagged tree imputation; 'cut' is demoted to an id role
  # so it is not used as a predictor within a stratum
  rec <- recipe(price ~ ., data = train) %>%
    update_role(cut, new_role = "id") %>%
    step_impute_bag(carat, impute_with = imp_vars(x, y))

  prep_rec <- prep(rec)                             # estimated on train only
  train_imputed <- bake(prep_rec, new_data = NULL)
  test_imputed <- bake(prep_rec, new_data = test)   # train-fitted imputer on test

  list(train = train_imputed, test = test_imputed)
}
# Apply imputation per group
imputed_list <- map(diamonds_by_cut, group_impute)

# Extract and label test data from all groups
test_data_all <- imputed_list %>%
  imap_dfr(~ mutate(.x$test, cut = .y))

# Extract and label train data from all groups
diamonds_imputed <- imputed_list %>%
  imap_dfr(~ mutate(.x$train, cut = .y))
diamonds_imputed
# A tibble: 2,839 × 5
  carat     x     y cut         price
  <dbl> <dbl> <dbl> <chr>       <int>
1  0.85  6.05  5.96 Very Good    1651
2  0.91  6.12  6.07 Very Good    3180
3  0.5   5.01  4.97 Very Good    1323
4  2.01  7.82  7.78 Very Good   13744
5  1.5   7.52  7.57 Very Good    7912
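Before modelling, a quick check (my addition) that the imputer left no missing carat values in either partition:

# Sanity check: imputation should have filled all NAs in carat
anyNA(diamonds_imputed$carat)
anyNA(test_data_all$carat)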
# Now we can proceed to model on the pooled, imputed dataset:
# e.g., create vfold_cv(diamonds_imputed), run workflows, possibly tune models, etc.
cv_folds <- vfold_cv(diamonds_imputed, v = 5)
cv_folds
# 5-fold cross-validation
# A tibble: 5 × 2
splits id
<list> <chr>
1 <split [2271/568]> Fold1
2 <split [2271/568]> Fold2
3 <split [2271/568]> Fold3
4 <split [2271/568]> Fold4
5 <split [2272/567]> Fold5
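An optional refinement (my suggestion, not part of the original approach): since the data were sampled within strata, the folds themselves could be stratified by cut so each fold preserves the group proportions:

# Alternative: stratify the CV folds by the sampling stratum
cv_folds_strat <- vfold_cv(diamonds_imputed, v = 5, strata = cut)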
# 1. Define candidate preprocessing recipes
ln <- recipe(price ~ carat + x + y, data = diamonds_imputed) %>%
  step_log(carat, base = 10)

inter <- recipe(price ~ carat + x + y, data = diamonds_imputed) %>%
  step_interact(terms = ~ carat:x)
# 2. Define model spec (linear regression)
gl_spec <- linear_reg() %>%
  set_mode("regression") %>%
  set_engine("lm")

# 3. Create workflow set
wf_set <- workflow_set(
  preproc = list(ln = ln, inter = inter),
  models = list(linear = gl_spec)
)
# 4. Fit resamples on CV folds
wf_results <- wf_set %>%
  workflow_map(
    "fit_resamples",
    resamples = cv_folds,
    metrics = metric_set(rsq),
    verbose = TRUE,
    seed = 270225
  )

collect_metrics(wf_results)
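To compare the two specifications at a glance, the standard workflowsets helpers can rank and plot the resampled metrics (sketched as a follow-up step):

# 5. Rank workflows by resampled R-squared and visualise the comparison
rank_results(wf_results, rank_metric = "rsq")
autoplot(wf_results, metric = "rsq")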