Tidyverse k-fold cross-validation: how to manipulate data within each fold?

Hi all, my goal is to carry out some data cleaning in a pipe within each cross-validation fold, before the model is fit. Any help is much appreciated.

The simple reprex below uses the iris dataset, with some Sepal.Length values set to NA, to model Petal.Length as the response variable.

I want to do this in my real application to avoid data leakage: no information from a testing fold should be used to define anything in the corresponding training folds.

In the example, I replace some of the values in a feature column with NA to simulate a dataset with missing values. Once the training folds have been defined, I want to replace the NA values in that column with the column mean (computed excluding NA values). The catch is that I want to do this separately for each fold, using only the rows in that fold's training set.


## Load data and replace Sepal.Length 5.1 with NA values
library(tidyverse)
#> -- Attaching packages ----------------------------------------------------------- tidyverse 1.2.1 --
#> v ggplot2 2.2.1     v purrr   0.2.4
#> v tibble  1.4.2     v dplyr   0.7.4
#> v tidyr   0.8.0     v stringr 1.2.0
#> v readr   1.1.1     v forcats 0.3.0
#> -- Conflicts -------------------------------------------------------------- tidyverse_conflicts() --
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag()    masks stats::lag()
library(modelr)

# Start from iris as a tibble and blank out every Sepal.Length equal to 5.1,
# so the column has missing values to impute. as_tibble() replaces
# as.tibble(), which tibble 2.0.0 deprecated.
iris2 <- as_tibble(iris) %>%
  mutate(Sepal.Length = replace(Sepal.Length, Sepal.Length == 5.1, NA))

## dataPrep(): mean-impute Sepal.Length.
## If `x` has a column named Sepal.Length, every NA in that column is
## replaced by the mean of its non-missing values and the modified data is
## returned. Any input without such a name is handed back untouched.
dataPrep <- function(x){
  if (!("Sepal.Length" %in% names(x))) {
    return(x)
  }
  x %>%
    mutate(
      Sepal.Length = replace(
        Sepal.Length,
        is.na(Sepal.Length),
        mean(Sepal.Length, na.rm = TRUE)
      )
    )
}

## Apply dataPrep to each training fold.
## Each element of `train` is a modelr `resample` object (the full data plus
## the row indices `idx`), not a data frame. Passed directly, dataPrep()
## only sees the names "data" and "idx", skips the imputation, and returns
## the object unchanged. as.data.frame() materialises just that fold's
## training rows, so the imputation mean is computed from that fold alone.
set.seed(1)
fits.dt <- iris2 %>%
  crossv_kfold(5) %>%
  mutate(train2 = map(train, ~ dataPrep(as.data.frame(.x)))) %>%
  mutate(model = map(train2, ~ lm(Petal.Length ~ Sepal.Length, data = .x)))

## Each train2 element is now a tibble holding only that fold's 120
## training rows (4/5 of 150), with no NA left in Sepal.Length
str(fits.dt$train2[[1]])
stopifnot(all(map_lgl(fits.dt$train2, ~ !anyNA(.x$Sepal.Length))))

## Record the R version and platform this reprex was run under
## (the printed result is captured in the #> lines below)
R.Version()
#> $platform
#> [1] "x86_64-w64-mingw32"
#> 
#> $arch
#> [1] "x86_64"
#> 
#> $os
#> [1] "mingw32"
#> 
#> $system
#> [1] "x86_64, mingw32"
#> 
#> $status
#> [1] ""
#> 
#> $major
#> [1] "3"
#> 
#> $minor
#> [1] "4.3"
#> 
#> $year
#> [1] "2017"
#> 
#> $month
#> [1] "11"
#> 
#> $day
#> [1] "30"
#> 
#> $`svn rev`
#> [1] "73796"
#> 
#> $language
#> [1] "R"
#> 
#> $version.string
#> [1] "R version 3.4.3 (2017-11-30)"
#> 
#> $nickname
#> [1] "Kite-Eating Tree"
1 Like

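You would be better off using rsample and recipes to do the cross-validation and imputation, respectively. Because the recipe is prepped separately on each fold's analysis (training) set, the imputation mean is estimated from those rows only, so nothing leaks in from the assessment set.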

library(tidyverse)
library(rsample)
library(recipes)

# Start from iris as a tibble and blank out every Sepal.Length equal to 5.1,
# so the column has missing values to impute. as_tibble() replaces
# as.tibble(), which tibble 2.0.0 deprecated.
iris2 <- as_tibble(iris) %>%
  mutate(Sepal.Length = replace(Sepal.Length, Sepal.Length == 5.1, NA))

# Use recipes for imputation: a recipe for Petal.Length ~ Sepal.Length whose
# single step fills missing Sepal.Length values with a mean. The mean is
# only estimated when the recipe is prepped, from the data given to prep().
# NOTE: recipes 0.1.16 renamed step_meanimpute() to step_impute_mean().
mean_impute <- step_meanimpute(
  recipe(Petal.Length ~ Sepal.Length, data = iris2),
  Sepal.Length
)


# Use rsample: 5-fold CV. For every split, prep the imputation recipe on the
# analysis rows (so the imputation happens within resampling), then fit the
# linear model to the recipe's retained, already-imputed training data.
fits.dt <- iris2 %>%
  vfold_cv(v = 5) %>%
  mutate(
    recipe = map(splits, prepper, recipe = mean_impute, retain = TRUE),
    model = map(recipe, function(prepped) {
      lm(Petal.Length ~ Sepal.Length, data = juice(prepped))
    })
  )
6 Likes

Thank you very much Max, this is quite cool.

I'm struggling to obtain the out-of-fold predictions, i.e. the predictions on the holdout (assessment) data associated with each fold. My approach is below:

# Out-of-fold (holdout) predictions for each resample.
# Each split's assessment rows are processed with the recipe that was
# prepped on the matching *analysis* rows (the `recipe` column). Prepping a
# fresh recipe on the assessment rows would re-estimate the imputation mean
# from the holdout data, which is exactly the leakage to avoid. (Also,
# `as.data.frame(x, data = "assessment")` is meant for a single rsplit,
# whereas inside mutate() `splits` is the whole list column.)
fits.dt <- fits.dt %>%
  mutate(
    pred_oof = pmap(
      list(splits, recipe, model),
      function(split, prepped, fit) {
        holdout <- bake(prepped, assessment(split))
        predict(fit, newdata = holdout, type = "response")
      }
    )
  )

No problem. You should check out the modeling notes from the conference workshop.

Here's how you would do it:

library(tidyverse)
library(rsample)
library(recipes)

# Start from iris as a tibble and blank out every Sepal.Length equal to 5.1,
# so the column has missing values to impute. as_tibble() replaces
# as.tibble(), which tibble 2.0.0 deprecated.
iris2 <- as_tibble(iris) %>%
  mutate(Sepal.Length = replace(Sepal.Length, Sepal.Length == 5.1, NA))

# Use recipes for imputation: a recipe for Petal.Length ~ Sepal.Length whose
# single step fills missing Sepal.Length values with a mean. The mean is
# only estimated when the recipe is prepped, from the data given to prep().
# NOTE: recipes 0.1.16 renamed step_meanimpute() to step_impute_mean().
mean_impute <- step_meanimpute(
  recipe(Petal.Length ~ Sepal.Length, data = iris2),
  Sepal.Length
)

# Add a function to get the holdout predictions for one resample.
#
# splits: an rsplit; its assessment (holdout) rows are the ones predicted.
# recipe: the recipe prepped on the same split's analysis rows, so the
#         imputation uses statistics from the analysis data only.
# model:  the model fit on that split's (imputed) analysis data.
# Returns the baked assessment data with two extra columns: `pred`, the
# model predictions, and `fold`, the resample id from labels(splits)$id.
get_pred <- function(splits, recipe, model) {
  # Get holdout (aka "assessment") data
  holdout <- assessment(splits)
  # Do the imputation (based on the corresponding analysis data). The data
  # is passed positionally because bake()'s `newdata` argument was renamed
  # `new_data` in recipes 0.1.4; a named call breaks on one side or the other.
  bake(recipe, holdout) %>%
    # Get predictions from the model
    mutate(
      pred = predict(model, newdata = .),
      # Add a fold label for later
      fold = labels(splits)$id
    )
}

# Use rsample: for each of the 5 folds, prep the imputation recipe on the
# analysis rows, fit the model to the juiced (imputed) analysis data, and
# collect the holdout predictions with get_pred().
fits.dt <- iris2 %>%
  vfold_cv(v = 5) %>%
  mutate(
    recipe = map(splits, prepper, recipe = mean_impute, retain = TRUE),
    model = map(recipe, function(prepped) {
      lm(Petal.Length ~ Sepal.Length, data = juice(prepped))
    }),
    pred = pmap(
      list(splits = splits, recipe = recipe, model = model),
      get_pred
    )
  )

# Then stack the per-fold prediction tibbles into a single data frame and
# plot holdout predictions against observed Petal.Length, coloured by fold.
predictions <- fits.dt$pred %>%
  bind_rows()

ggplot(predictions) +
  geom_point(aes(x = Petal.Length, y = pred, colour = fold))
3 Likes