workflowsets parameter grids

Am I missing something needed to create a non-regular grid? The help for `tune_grid()` says that `grid` should be "a data frame of tuning combinations or a positive integer. The data frame should have columns for each parameter being tuned and rows for tuning parameter candidates. An integer denotes the number of candidate parameter sets to be created automatically." In this example I create a data frame of the tuning parameters, but I now get an error from `workflow_map()`:

# Packages ----
library("tidymodels")
library("workflowsets")
library("discrim")
tidymodels_prefer()

# Data ----
# Example data with two numeric predictors and a two-level class outcome
data(parabolic)
str(parabolic)
#> tibble [500 × 3] (S3: tbl_df/tbl/data.frame)
#>  $ X1   : num [1:500] 3.29 1.47 1.66 1.6 2.17 ...
#>  $ X2   : num [1:500] 1.661 0.414 0.791 0.276 3.166 ...
#>  $ class: Factor w/ 2 levels "Class1","Class2": 1 2 2 2 1 1 2 1 2 1 ...

set.seed(1)
split <- initial_split(parabolic)

train_set <- training(split)
test_set <- testing(split)

# Model specifications ----
# Every specification marks at least one argument for tuning with `tune()`.
mars_disc_spec <-
  discrim_flexible(prod_degree = tune()) %>%
  set_engine("earth")

reg_disc_spec <-
  discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>%
  set_engine("klaR")

cart_spec <-
  decision_tree(cost_complexity = tune(), min_n = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

# Resamples and workflow set ----
set.seed(2)
train_resamples <- bootstraps(train_set)

# One formula preprocessor crossed with three models gives the workflow ids
# `formula_regularized`, `formula_mars` and `formula_cart`.
all_workflows <-
  workflow_set(
    preproc = list("formula" = class ~ .),
    models = list(regularized = reg_disc_spec,
                  mars = mars_disc_spec,
                  cart = cart_spec)
  )


# Build a space-filling tuning grid for a single workflow of a workflow set.
#
# Arguments:
#   workflow_id    id of the workflow whose tuning parameters are extracted
#   workflow_sets  workflow set that contains `workflow_id`
#   data           data passed to `finalize()`
#   size, type     forwarded to `grid_space_filling()`
#
# Returns the grid, or an empty tibble when the workflow has no tuning
# parameters.
generate_param_grid <- function(workflow_id, workflow_sets, data, size = 50, type = "latin_hypercube") {
  tuning_params <- extract_parameter_set_dials(workflow_sets, id = workflow_id)

  # Guard clause: nothing is marked for tuning, so there is no grid to build
  if (nrow(tuning_params) <= 0) {
    return(tibble::tibble())
  }

  # Finalize the parameter set against `data`, then sample the candidates
  tuning_params <- finalize(tuning_params, data)
  grid_space_filling(tuning_params, size = size, original = TRUE, type = type)
}

# Apply the function to each workflow
# NOTE: `list_rbind()` stacks the per-workflow grids into ONE tibble. Because
# each model tunes different parameters, every row is NA in the columns that
# belong to the other models (see the printed grid below).
param_grids <- purrr::map(
  set_names(all_workflows$wflow_id),
  \(x) generate_param_grid(x, all_workflows, parabolic)
) %>%
  list_rbind()

# Print the parameter grids
param_grids
#> # A tibble: 102 × 5
#>    frac_common_cov frac_identity prod_degree cost_complexity min_n
#>              <dbl>         <dbl>       <int>           <dbl> <int>
#>  1           0.874        0.0399          NA              NA    NA
#>  2           0.398        0.244           NA              NA    NA
#>  3           0.524        0.523           NA              NA    NA
#>  4           0.786        0.542           NA              NA    NA
#>  5           0.994        0.923           NA              NA    NA
#>  6           0.701        0.993           NA              NA    NA
#>  7           0.148        0.474           NA              NA    NA
#>  8           0.774        0.313           NA              NA    NA
#>  9           0.371        0.747           NA              NA    NA
#> 10           0.554        0.859           NA              NA    NA
#> # ℹ 92 more rows

# The same combined grid is handed to every workflow here, so each workflow
# is given columns for parameters it does not tune; this is what triggers the
# `check_grid()` errors shown in the output below.
all_workflows_res <- 
  all_workflows %>% 
  # Specifying arguments here adds to any previously set with `option_add()`:
  workflow_map(resamples = train_resamples, 
               grid = param_grids, 
               verbose = TRUE)
#> i 1 of 3 tuning:     formula_regularized
#> ✖ 1 of 3 tuning:     formula_regularized failed with: Error in check_grid(grid = grid, workflow = workflow, pset = pset) :   The provided `grid` has the following parameter columns that have not been marked for tuning by `tune()`: 'prod_degree', 'cost_complexity', 'min_n'.
#> i 2 of 3 tuning:     formula_mars
#> ✖ 2 of 3 tuning:     formula_mars failed with: Error in check_grid(grid = grid, workflow = workflow, pset = pset) :   The provided `grid` has the following parameter columns that have not been marked for tuning by `tune()`: 'frac_common_cov', 'frac_identity', 'cost_complexity', 'min_n'.
#> i 3 of 3 tuning:     formula_cart
#> ✖ 3 of 3 tuning:     formula_cart failed with: Error in check_grid(grid = grid, workflow = workflow, pset = pset) :   The provided `grid` has the following parameter columns that have not been marked for tuning by `tune()`: 'frac_common_cov', 'frac_identity', 'prod_degree'.


# Report the R version, platform and package versions used for this reprex
sessionInfo()
#> R version 4.4.1 (2024-06-14 ucrt)
#> Platform: x86_64-w64-mingw32/x64
#> Running under: Windows 11 x64 (build 22631)
#> 
#> Matrix products: default
#> 
#> 
#> locale:
#> [1] LC_COLLATE=English_United States.utf8 
#> [2] LC_CTYPE=English_United States.utf8   
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C                          
#> [5] LC_TIME=English_United States.utf8    
#> 
#> time zone: America/New_York
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] discrim_1.0.1      yardstick_1.3.1    workflowsets_1.1.0 workflows_1.1.4   
#>  [5] tune_1.2.1         tidyr_1.3.1        tibble_3.2.1       rsample_1.2.1     
#>  [9] recipes_1.1.0      purrr_1.0.2        parsnip_1.2.1      modeldata_1.4.0   
#> [13] infer_1.0.7        ggplot2_3.5.1      dplyr_1.1.4        dials_1.3.0       
#> [17] scales_1.3.0       broom_1.0.6        tidymodels_1.2.0  
#> 
#> loaded via a namespace (and not attached):
#>  [1] conflicted_1.2.0    rlang_1.1.4         magrittr_2.0.3     
#>  [4] furrr_0.3.1         compiler_4.4.1      vctrs_0.6.5        
#>  [7] combinat_0.0-8      lhs_1.2.0           pkgconfig_2.0.3    
#> [10] fastmap_1.2.0       backports_1.5.0     utf8_1.2.4         
#> [13] promises_1.3.0      rmarkdown_2.28      prodlim_2024.06.25 
#> [16] haven_2.5.4         klaR_1.7-3          xfun_0.47          
#> [19] reprex_2.1.1        cachem_1.1.0        labelled_2.13.0    
#> [22] highr_0.11          later_1.3.2         parallel_4.4.1     
#> [25] R6_2.5.1            parallelly_1.38.0   rpart_4.1.23       
#> [28] lubridate_1.9.3     Rcpp_1.0.13         iterators_1.0.14   
#> [31] knitr_1.48          future.apply_1.11.2 httpuv_1.6.15      
#> [34] Matrix_1.7-0        splines_4.4.1       nnet_7.3-19        
#> [37] timechange_0.3.0    tidyselect_1.2.1    rstudioapi_0.16.0  
#> [40] yaml_2.3.10         timeDate_4032.109   codetools_0.2-20   
#> [43] miniUI_0.1.1.1      listenv_0.9.1       lattice_0.22-6     
#> [46] shiny_1.9.1         withr_3.0.1         evaluate_0.24.0    
#> [49] future_1.34.0       survival_3.7-0      pillar_1.9.0       
#> [52] foreach_1.5.2       generics_0.1.3      hms_1.1.3          
#> [55] munsell_0.5.1       plotmo_3.6.4        globals_0.16.3     
#> [58] xtable_1.8-4        class_7.3-22        glue_1.8.0         
#> [61] mda_0.5-4           tools_4.4.1         data.table_1.16.0  
#> [64] gower_1.0.1         forcats_1.0.0       fs_1.6.4           
#> [67] grid_4.4.1          plotrix_3.8-4       ipred_0.9-15       
#> [70] colorspace_2.1-1    earth_5.3.3         Formula_1.2-5      
#> [73] cli_3.6.3           DiceDesign_1.10     fansi_1.0.6        
#> [76] lava_1.8.0          gtable_0.3.5        GPfit_1.0-8        
#> [79] digest_0.6.37       memoise_2.0.1       htmltools_0.5.8.1  
#> [82] questionr_0.7.8     lifecycle_1.0.4     hardhat_1.4.0      
#> [85] mime_0.12           MASS_7.3-61

Okay, I think I figured out what I was doing wrong when setting up parameter grids for a workflow set. The key is the function `option_add()`, which can be used to add options to specific workflow ids. This means that you can manually create a grid for each preprocessing-model combination.

I created a helper function, `generate_param_grid()` (shown below), to generalize the process. I apply it to each workflow in the workflow set with `purrr::map()` and bind the results with `list_rbind()`; you can then proceed with the fit as in the normal workflowsets process:

# Packages ----
library("tidymodels")
library("workflowsets")
library("discrim")
tidymodels_prefer()

# Data ----
# Example data with two numeric predictors and a two-level class outcome
data(parabolic)
str(parabolic)
#> tibble [500 × 3] (S3: tbl_df/tbl/data.frame)
#>  $ X1   : num [1:500] 3.29 1.47 1.66 1.6 2.17 ...
#>  $ X2   : num [1:500] 1.661 0.414 0.791 0.276 3.166 ...
#>  $ class: Factor w/ 2 levels "Class1","Class2": 1 2 2 2 1 1 2 1 2 1 ...

set.seed(1)
split <- initial_split(parabolic)

train_set <- training(split)
test_set <- testing(split)

# Model specifications ----
# Every specification marks at least one argument for tuning with `tune()`.
mars_disc_spec <-
  discrim_flexible(prod_degree = tune()) %>%
  set_engine("earth")

reg_disc_spec <-
  discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>%
  set_engine("klaR")

cart_spec <-
  decision_tree(cost_complexity = tune(), min_n = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

# Resamples and workflow set ----
set.seed(2)
train_resamples <- bootstraps(train_set)

# One formula preprocessor crossed with three models gives the workflow ids
# `formula_regularized`, `formula_mars` and `formula_cart`.
all_workflows <-
  workflow_set(
    preproc = list("formula" = class ~ .),
    models = list(regularized = reg_disc_spec,
                  mars = mars_disc_spec,
                  cart = cart_spec)
  )


# Attach a space-filling tuning grid to one workflow of a workflow set.
#
# Arguments:
#   workflow_id    id (value of `wflow_id`) of the workflow to update
#   workflow_sets  workflow set that contains `workflow_id`
#   data           data passed to `finalize()`
#                  NOTE(review): `finalize()` documents this argument as
#                  predictor data; a frame that still contains the outcome
#                  may widen column-count-based ranges -- confirm for your
#                  parameters.
#   size, type     forwarded to `grid_space_filling()`
#
# Returns a one-row workflow set holding only `workflow_id`. When the workflow
# has tuning parameters, its `grid` option is set to the generated grid;
# otherwise the row is returned unchanged.
generate_param_grid <- function(workflow_id, workflow_sets, data, size = 50,
                                type = "latin_hypercube") {
  param_set <- workflow_sets %>%
    extract_parameter_set_dials(id = workflow_id)

  if (nrow(param_set) > 0) {
    # Finalize the parameter set, then generate the grid
    param_grid <- param_set %>%
      finalize(data) %>%
      grid_space_filling(size = size,
                         original = TRUE,
                         type = type)

    # Store the grid as an option of this workflow only, so each workflow is
    # tuned over its own parameters
    workflow_sets <- workflow_sets %>%
      option_add(grid = param_grid, id = workflow_id)
  }

  # Always return this workflow's row. Returning an empty tibble for a
  # workflow without tuning parameters would silently drop it once the
  # per-workflow results are row-bound back together.
  workflow_sets %>%
    filter(wflow_id == workflow_id)
}


# Quick check of the helper on a single workflow
generate_param_grid("formula_mars", all_workflows, parabolic)
#> # A workflow set/tibble: 1 × 4
#>   wflow_id     info             option    result    
#>   <chr>        <list>           <list>    <list>    
#> 1 formula_mars <tibble [1 × 4]> <opts[1]> <list [0]>

# Attach a grid to every workflow, then recombine the rows into one object
all_workflows_grid <-
  all_workflows$wflow_id %>%
  set_names() %>%
  purrr::map(
    function(id) generate_param_grid(id, all_workflows, data = parabolic)
  ) %>%
  list_rbind()

# No `grid` argument here: each workflow carries its own grid as an option
all_workflows_res <- workflow_map(
  all_workflows_grid,
  resamples = train_resamples,
  verbose = TRUE,
  control = control_grid(
    save_pred = TRUE,
    parallel_over = "everything",
    save_workflow = TRUE
  )
)
#> i 1 of 3 tuning:     formula_regularized
#> ✔ 1 of 3 tuning:     formula_regularized (2m 44.7s)
#> i 2 of 3 tuning:     formula_mars
#> ✔ 2 of 3 tuning:     formula_mars (2.5s)
#> i 3 of 3 tuning:     formula_cart
#> ✔ 3 of 3 tuning:     formula_cart (52.5s)

# Plot the results ranked by accuracy, with the legend below the plot
autoplot(all_workflows_res, rank_metric = "accuracy") +
  theme(legend.position = "bottom") +
  guides(color = guide_legend(nrow = 3))

image


# Report the R version, platform and package versions used for this reprex
sessionInfo()
#> R version 4.4.1 (2024-06-14)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 20.04.6 LTS
#> 
#> Matrix products: default
#> BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.8.so;  LAPACK version 3.9.0
#> 
#> locale:
#>  [1] LC_CTYPE=C.UTF-8       LC_NUMERIC=C           LC_TIME=C.UTF-8       
#>  [4] LC_COLLATE=C.UTF-8     LC_MONETARY=C.UTF-8    LC_MESSAGES=C.UTF-8   
#>  [7] LC_PAPER=C.UTF-8       LC_NAME=C              LC_ADDRESS=C          
#> [10] LC_TELEPHONE=C         LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C   
#> 
#> time zone: UTC
#> tzcode source: system (glibc)
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] rpart_4.1.23       earth_5.3.3        plotmo_3.6.4       plotrix_3.8-4     
#>  [5] Formula_1.2-5      mda_0.5-4          class_7.3-22       klaR_1.7-3        
#>  [9] MASS_7.3-60.2      discrim_1.0.1      yardstick_1.3.1    workflowsets_1.1.0
#> [13] workflows_1.1.4    tune_1.2.1         tidyr_1.3.1        tibble_3.2.1      
#> [17] rsample_1.2.1      recipes_1.1.0      purrr_1.0.2        parsnip_1.2.1     
#> [21] modeldata_1.4.0    infer_1.0.7        ggplot2_3.5.1      dplyr_1.1.4       
#> [25] dials_1.3.0        scales_1.3.0       broom_1.0.7        tidymodels_1.2.0  
#> 
#> loaded via a namespace (and not attached):
#>  [1] conflicted_1.2.0    rlang_1.1.4         magrittr_2.0.3     
#>  [4] furrr_0.3.1         compiler_4.4.1      vctrs_0.6.5        
#>  [7] combinat_0.0-8      lhs_1.2.0           pkgconfig_2.0.3    
#> [10] fastmap_1.2.0       backports_1.5.0     labeling_0.4.3     
#> [13] utf8_1.2.4          promises_1.3.0      rmarkdown_2.28     
#> [16] prodlim_2024.06.25  haven_2.5.4         xfun_0.47          
#> [19] reprex_2.1.1        cachem_1.1.0        labelled_2.13.0    
#> [22] highr_0.11          later_1.3.2         parallel_4.4.1     
#> [25] prettyunits_1.2.0   R6_2.5.1            parallelly_1.38.0  
#> [28] lubridate_1.9.3     Rcpp_1.0.13         iterators_1.0.14   
#> [31] knitr_1.48          future.apply_1.11.2 httpuv_1.6.15      
#> [34] Matrix_1.7-0        splines_4.4.1       nnet_7.3-19        
#> [37] timechange_0.3.0    tidyselect_1.2.1    rstudioapi_0.16.0  
#> [40] yaml_2.3.10         timeDate_4041.110   codetools_0.2-20   
#> [43] miniUI_0.1.1.1      listenv_0.9.1       lattice_0.22-6     
#> [46] shiny_1.9.1         withr_3.0.1         evaluate_1.0.0     
#> [49] future_1.34.0       survival_3.6-4      pillar_1.9.0       
#> [52] foreach_1.5.2       generics_0.1.3      hms_1.1.3          
#> [55] munsell_0.5.1       globals_0.16.3      xtable_1.8-4       
#> [58] glue_1.8.0          tools_4.4.1         data.table_1.16.0  
#> [61] modelenv_0.1.1      gower_1.0.1         forcats_1.0.0      
#> [64] fs_1.6.4            grid_4.4.1          ipred_0.9-15       
#> [67] colorspace_2.1-1    cli_3.6.3           DiceDesign_1.10    
#> [70] fansi_1.0.6         lava_1.8.0          gtable_0.3.5       
#> [73] GPfit_1.0-8         digest_0.6.37       farver_2.1.2       
#> [76] memoise_2.0.1       htmltools_0.5.8.1   questionr_0.7.8    
#> [79] lifecycle_1.0.4     hardhat_1.4.0       mime_0.12

Created on 2024-10-03 with reprex v2.1.1

Here is the internal function that updates a single workflow:

#' Generate Parameter Grid for a Single Workflow
#'
#' Extracts the tuning parameters of one workflow, finalizes them against
#' \code{data}, builds a space-filling grid and stores that grid as the
#' workflow's \code{grid} option.
#'
#' @param workflow_id workflow id (a value of \code{wflow_id} in
#'   \code{workflow_sets})
#' @param workflow_sets workflow sets
#' @param data data frame to use for \code{finalize}. NOTE(review):
#'   \code{dials::finalize} documents this argument as predictor data; a frame
#'   that still contains the outcome may widen column-count-based ranges --
#'   confirm for your parameters.
#' @param size size of the grid
#' @param type type of grid
#'
#' @return a one-row workflow set containing only \code{workflow_id}. When the
#'   workflow has tuning parameters its \code{grid} option holds the generated
#'   grid; otherwise the row is returned unchanged.

.generate_param_grid <- function(workflow_id, workflow_sets, data, size = 50,
                                 type = "latin_hypercube") {
  param_set <- workflow_sets %>%
    workflowsets::extract_parameter_set_dials(id = workflow_id)

  if (nrow(param_set) > 0) {
    # Finalize the parameter set, then generate the grid
    param_grid <- param_set %>%
      dials::finalize(data) %>%
      dials::grid_space_filling(size = size,
                                original = TRUE,
                                type = type)

    # Store the grid as an option of this workflow only, so each workflow is
    # tuned over its own parameters
    workflow_sets <- workflow_sets %>%
      workflowsets::option_add(grid = param_grid, id = workflow_id)
  }

  # Always return this workflow's row. Returning an empty tibble for a
  # workflow without tuning parameters would silently drop it once the
  # per-workflow results are row-bound back together.
  workflow_sets %>%
    dplyr::filter(wflow_id == workflow_id)
}

And this is the user-facing function, which applies it to every workflow in the set:

#' Generate Parameter Grid for all Workflows
#'
#' Calls \code{.generate_param_grid} once for every workflow id in \code{x}
#' and row-binds the per-workflow results.
#'
#' @param x workflow set
#' @param df data frame to use for finalization
#' @param ... additional arguments passed to \code{.generate_param_grid}
#'
#' @return updated workflow set with parameter grid added to each workflow

generate_param_grid <- function(x, df, ...) {
  workflow_ids <- purrr::set_names(x$wflow_id)

  # One result per workflow id; `...` is forwarded unchanged to the helper
  per_workflow <- purrr::map(
    workflow_ids,
    function(id) {
      .generate_param_grid(workflow_id = id, workflow_sets = x, data = df, ...)
    }
  )

  purrr::list_rbind(per_workflow)
}

In testing, I am noticing that if both the preprocessors and the models have unfinalized parameters, then the function needs to be applied twice:


# Two passes over the same workflow set. As reported by the author from
# testing:
#   first call  -> finalizes the model parameters
#   second call -> finalizes the preprocessing parameters
all_workflows_grid <-
  all_workflows %>%
  generate_param_grid(parabolic) %>%
  generate_param_grid(parabolic)