Is there something that I am missing to create a nonregular grid? the help for tune_grid
says that the grid
should be A data frame of tuning combinations or a positive integer. The data frame should have columns for each parameter being tuned and rows for tuning parameter candidates. An integer denotes the number of candidate parameter sets to be created automatically.
so in this example I create a data frame of the tuning parameters but now I get an error from workflow_map
library('tidymodels')
library('workflowsets')
tidymodels_prefer()
data(parabolic)
parabolic <- parabolic
str(parabolic)
#> tibble [500 × 3] (S3: tbl_df/tbl/data.frame)
#> $ X1 : num [1:500] 3.29 1.47 1.66 1.6 2.17 ...
#> $ X2 : num [1:500] 1.661 0.414 0.791 0.276 3.166 ...
#> $ class: Factor w/ 2 levels "Class1","Class2": 1 2 2 2 1 1 2 1 2 1 ...
set.seed(1)
split <- initial_split(parabolic)
train_set <- training(split)
test_set <- testing(split)
library(discrim)
mars_disc_spec <-
discrim_flexible(prod_degree = tune()) %>%
set_engine("earth")
reg_disc_sepc <-
discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>%
set_engine("klaR")
cart_spec <-
decision_tree(cost_complexity = tune(), min_n = tune()) %>%
set_engine("rpart") %>%
set_mode("classification")
set.seed(2)
train_resamples <- bootstraps(train_set)
all_workflows <-
workflow_set(
preproc = list("formula" = class ~ .),
models = list(regularized = reg_disc_sepc,
mars = mars_disc_spec,
cart = cart_spec)
)
# Function to finalize and generate parameter grid
generate_param_grid <- function(workflow_id, workflow_sets, data, size = 50, type="latin_hypercube") {
param_set <- workflow_sets %>%
extract_parameter_set_dials(id = workflow_id)
if (nrow(param_set) > 0) {
# Finalize the parameter set
finalized_param_set <- param_set %>%
finalize(data)
# Generate the grid
finalized_param_set %>%
grid_space_filling(size = size,
original = TRUE,
type = type)
} else {
tibble::tibble() # Return an empty tibble if no parameters
}
}
# Apply the function to each workflow
param_grids <- purrr::map(
set_names(all_workflows$wflow_id),
\(x) generate_param_grid(x, all_workflows, parabolic)
) %>%
list_rbind()
# Print the parameter grids
param_grids
#> # A tibble: 102 × 5
#> frac_common_cov frac_identity prod_degree cost_complexity min_n
#> <dbl> <dbl> <int> <dbl> <int>
#> 1 0.874 0.0399 NA NA NA
#> 2 0.398 0.244 NA NA NA
#> 3 0.524 0.523 NA NA NA
#> 4 0.786 0.542 NA NA NA
#> 5 0.994 0.923 NA NA NA
#> 6 0.701 0.993 NA NA NA
#> 7 0.148 0.474 NA NA NA
#> 8 0.774 0.313 NA NA NA
#> 9 0.371 0.747 NA NA NA
#> 10 0.554 0.859 NA NA NA
#> # ℹ 92 more rows
all_workflows_res <-
all_workflows %>%
# Specifying arguments here adds to any previously set with `option_add()`:
workflow_map(resamples = train_resamples,
grid = param_grids,
verbose = TRUE)
#> i 1 of 3 tuning: formula_regularized
#> ✖ 1 of 3 tuning: formula_regularized failed with: Error in check_grid(grid = grid, workflow = workflow, pset = pset) : The provided `grid` has the following parameter columns that have not been marked for tuning by `tune()`: 'prod_degree', 'cost_complexity', 'min_n'.
#> i 2 of 3 tuning: formula_mars
#> ✖ 2 of 3 tuning: formula_mars failed with: Error in check_grid(grid = grid, workflow = workflow, pset = pset) : The provided `grid` has the following parameter columns that have not been marked for tuning by `tune()`: 'frac_common_cov', 'frac_identity', 'cost_complexity', 'min_n'.
#> i 3 of 3 tuning: formula_cart
#> ✖ 3 of 3 tuning: formula_cart failed with: Error in check_grid(grid = grid, workflow = workflow, pset = pset) : The provided `grid` has the following parameter columns that have not been marked for tuning by `tune()`: 'frac_common_cov', 'frac_identity', 'prod_degree'.
sessionInfo()
#> R version 4.4.1 (2024-06-14 ucrt)
#> Platform: x86_64-w64-mingw32/x64
#> Running under: Windows 11 x64 (build 22631)
#>
#> Matrix products: default
#>
#>
#> locale:
#> [1] LC_COLLATE=English_United States.utf8
#> [2] LC_CTYPE=English_United States.utf8
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C
#> [5] LC_TIME=English_United States.utf8
#>
#> time zone: America/New_York
#> tzcode source: internal
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] discrim_1.0.1 yardstick_1.3.1 workflowsets_1.1.0 workflows_1.1.4
#> [5] tune_1.2.1 tidyr_1.3.1 tibble_3.2.1 rsample_1.2.1
#> [9] recipes_1.1.0 purrr_1.0.2 parsnip_1.2.1 modeldata_1.4.0
#> [13] infer_1.0.7 ggplot2_3.5.1 dplyr_1.1.4 dials_1.3.0
#> [17] scales_1.3.0 broom_1.0.6 tidymodels_1.2.0
#>
#> loaded via a namespace (and not attached):
#> [1] conflicted_1.2.0 rlang_1.1.4 magrittr_2.0.3
#> [4] furrr_0.3.1 compiler_4.4.1 vctrs_0.6.5
#> [7] combinat_0.0-8 lhs_1.2.0 pkgconfig_2.0.3
#> [10] fastmap_1.2.0 backports_1.5.0 utf8_1.2.4
#> [13] promises_1.3.0 rmarkdown_2.28 prodlim_2024.06.25
#> [16] haven_2.5.4 klaR_1.7-3 xfun_0.47
#> [19] reprex_2.1.1 cachem_1.1.0 labelled_2.13.0
#> [22] highr_0.11 later_1.3.2 parallel_4.4.1
#> [25] R6_2.5.1 parallelly_1.38.0 rpart_4.1.23
#> [28] lubridate_1.9.3 Rcpp_1.0.13 iterators_1.0.14
#> [31] knitr_1.48 future.apply_1.11.2 httpuv_1.6.15
#> [34] Matrix_1.7-0 splines_4.4.1 nnet_7.3-19
#> [37] timechange_0.3.0 tidyselect_1.2.1 rstudioapi_0.16.0
#> [40] yaml_2.3.10 timeDate_4032.109 codetools_0.2-20
#> [43] miniUI_0.1.1.1 listenv_0.9.1 lattice_0.22-6
#> [46] shiny_1.9.1 withr_3.0.1 evaluate_0.24.0
#> [49] future_1.34.0 survival_3.7-0 pillar_1.9.0
#> [52] foreach_1.5.2 generics_0.1.3 hms_1.1.3
#> [55] munsell_0.5.1 plotmo_3.6.4 globals_0.16.3
#> [58] xtable_1.8-4 class_7.3-22 glue_1.8.0
#> [61] mda_0.5-4 tools_4.4.1 data.table_1.16.0
#> [64] gower_1.0.1 forcats_1.0.0 fs_1.6.4
#> [67] grid_4.4.1 plotrix_3.8-4 ipred_0.9-15
#> [70] colorspace_2.1-1 earth_5.3.3 Formula_1.2-5
#> [73] cli_3.6.3 DiceDesign_1.10 fansi_1.0.6
#> [76] lava_1.8.0 gtable_0.3.5 GPfit_1.0-8
#> [79] digest_0.6.37 memoise_2.0.1 htmltools_0.5.8.1
#> [82] questionr_0.7.8 lifecycle_1.0.4 hardhat_1.4.0
#> [85] mime_0.12 MASS_7.3-61