I cannot figure out how to get tidymodels to retain information beyond the prediction target when saving predictions, i.e. when I fit a model over a grid of hyperparameter values with tune_grid() and want to work further with the out-of-fold predictions from K-fold cross-validation.
myresults = myworkflow %>%
  tune_grid(resamples = my_cv_folds,
            grid = my_hyperparameter_grid,
            metrics = metric_set(rmse),
            control = control_grid(save_pred = TRUE,
                                   verbose = FALSE))
I would really like to attach other information to my out-of-fold predictions. The main reason is that I am trying to predict a variable Y, and from the past (on entirely different data) I already have a very basic prediction model that outputs \tilde{Y}. One thing to try is therefore to predict Y' = Y - \tilde{Y}, but I also want to try simply predicting Y. Ideally, I would get a tibble with the out-of-fold predictions plus the record id, so that I can then merge in the \tilde{Y} values as needed. Additionally, being able to do this is useful for, e.g., looking at performance in subgroups (by gender, race, etc.).
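For illustration, this is the kind of post-processing I have in mind, using the names from the example further below (oof_with_id is a hypothetical tibble of out-of-fold predictions that still carries my record id, i.e. exactly the thing I don't know how to obtain):

# oof_with_id: hypothetical tibble of out-of-fold predictions + my record id
oof_with_id %>%
  left_join(mydata %>% select(id_variable, Y_tilde), by = "id_variable") %>%
  mutate(.pred_Y = .pred + Y_tilde)  # turn a prediction of Y' back into a prediction of Y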
However, I cannot see any option for doing this, and I have already confirmed that the order of the records changes somewhere along the way, so I cannot simply call assessment() on each fold to get the right values (and yes, I have already dealt with the random order of the folds). As far as I can see, control_grid() and collect_predictions() have no relevant options, and they do not keep the order of records that I get from, e.g., assessment(myresults$splits[[1]])$id_variable for the first fold.
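Roughly, this is how I convinced myself that the order changes (a sketch, assuming splits[[1]] corresponds to Fold01 and the usual Preprocessor1_Model1-style .config labels, with outcome column Y):

pred_fold1 <- myresults %>%
  collect_predictions() %>%
  filter(id == "Fold01", .config == "Preprocessor1_Model1")
truth_fold1 <- assessment(myresults$splits[[1]])
# If the assessment order were preserved, this would be TRUE:
all.equal(pred_fold1$Y, truth_fold1$Y)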
Am I overlooking something obvious?
Here's a basic example that illustrates the problem:
library(tidyverse)
library(tidymodels)
library(workflows)
library(tune)
library(doMC)
# Simulate some data
set.seed(55)
mydata = tibble(X1 = rnorm(1000),
                X2 = rnorm(1000),
                X3 = rnorm(1000),
                # note: "C" appears twice, so it is sampled half the time
                X4 = sample(c("A", "B", "C", "C"), replace = TRUE, size = 1000),
                Y = X1 + X2 + X3 + X1*X2*0.01 + X2*X3*0.1 + X1*X2*X3*0.001 -
                  (X4 == "C")*0.2 + (X4 != "C")*0.2 + rnorm(1000),
                Y_tilde = X1 + 1.05*X2 + 0.99*X3,  # the pre-existing basic model
                Y_dash = Y - Y_tilde) %>%          # Y' = Y - \tilde{Y}
  mutate(id_variable = row_number())
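# Note: by construction Y_dash is exactly Y - Y_tilde, which the mismatch
# check at the very end of this post relies on. Quick sanity check:
stopifnot(all(mydata$Y_dash == mydata$Y - mydata$Y_tilde))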
# Make some cross-validation folds
my_folds <- vfold_cv(mydata, v = 10)
# Simple model + recipe = workflow definition
my_model <- linear_reg(penalty = tune(),
                       mixture = tune()) %>%
  set_engine("glmnet")
my_recipe <-
  recipe(Y_dash ~ X1 + X2 + X3 + X4, data = mydata) %>%
  step_indicate_na(all_numeric_predictors(), role = "predictor") %>%
  step_impute_median(all_numeric_predictors()) %>%
  step_interact(terms = ~ X1:X2 + X2:X3 + X1:X3 + X1:X2:X3) %>%
  step_normalize(all_numeric_predictors(), -starts_with("na_ind")) %>%
  step_string2factor(all_nominal_predictors()) %>%
  step_impute_mode(all_nominal_predictors()) %>%
  step_other(all_nominal_predictors(), threshold = 0.025) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_nzv(all_numeric_predictors())
# bake(my_recipe %>% prep(), mydata) %>% data.frame() %>% head()
my_workflow = workflow() %>%
  add_model(my_model) %>%
  add_recipe(my_recipe)
#### Hyperparameters to use ###################
my_hyperparameter_grid = crossing(penalty = c(0.1, 1),
                                  mixture = c(0, 0.5, 1))  # 6 combinations
# Run this all with some multi-processing
ptm <- proc.time()
registerDoMC(cores = 4)
my_results = my_workflow %>%
  tune_grid(resamples = my_folds,
            grid = my_hyperparameter_grid,
            metrics = metric_set(rmse),
            control = control_grid(save_pred = TRUE,
                                   parallel_over = "resamples"))
registerDoSEQ()
proc.time() - ptm
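# For reference: the saved predictions come back with one row per record and
# hyperparameter combination (1000 x 6 = 6000 rows here), carrying the outcome
# Y_dash but neither id_variable, Y nor Y_tilde
my_results %>%
  collect_predictions() %>%
  glimpse()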
# Now, let's try to get the out-of-fold predictions and merge in the other information
oofs = my_results %>%
  collect_predictions() %>%
  mutate(id = factor(id, levels = my_folds$id)) %>%
  arrange(id)
# I thought this should work, but it doesn't
oofs$id_variable = rep(
  unlist(map(seq_along(my_folds$id),
             function(x) assessment(my_results$splits[[x]])$id_variable)),
  nrow(my_hyperparameter_grid))
oofs$Y = rep(
  unlist(map(seq_along(my_folds$id),
             function(x) assessment(my_results$splits[[x]])$Y)),
  nrow(my_hyperparameter_grid))
oofs$Y_tilde = rep(
  unlist(map(seq_along(my_folds$id),
             function(x) assessment(my_results$splits[[x]])$Y_tilde)),
  nrow(my_hyperparameter_grid))
# As we can see, there are plenty of rows with a blatant mismatch that should not exist
# (by construction, Y_dash - (Y - Y_tilde) should be exactly 0):
oofs %>%
  filter((Y_dash - (Y - Y_tilde))^2 > 0.1)
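# Summarizing the same filter by fold and configuration, to show the
# extent of the misalignment:
oofs %>%
  mutate(mismatch = (Y_dash - (Y - Y_tilde))^2 > 0.1) %>%
  count(id, .config, mismatch)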