I'm unable to sort out how to get use tune, in the person of tune::tune_bayes() or tune::tune_grid() to help select a parameter in custom recipe step. In looking into why that might be, I have chosen to try to tune_grid() a workflow that includes my custom recipe step (step_adstock) with a fixed parameter (ie retention) value. When I do that, tune_grid() reports that there is no prep method for the step_adstock class.
I have prep.step_adstock defined in the environment I'm calling tune_grid from, and I can see prep listed as an available method when I try methods(class='step_adstock').
Any advice is appreciated -- a reproducible (but lengthy) example is pasted in below:
library(tidymodels)
library(tidyverse)
#modified from Julia Silge's blog: https://juliasilge.com/blog/sf-trees-random-tuning/.
sf_trees <- read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-01-28/sf_trees.csv")
trees_df <- sf_trees %>%
mutate(
legal_status = case_when(
legal_status == "DPW Maintained" ~ legal_status,
TRUE ~ "Other"
),
plot_size = parse_number(plot_size)
) %>%
select(-address) %>%
na.omit() %>%
mutate_if(is.character, factor)
set.seed(123)
trees_split <- initial_split(trees_df, strata = legal_status)
trees_train <- training(trees_split)
trees_test <- testing(trees_split)
#model spec
tune_spec <- rand_forest(
mtry = tune(),
trees = 20,
min_n = 3 #originally tune()
) %>%
set_mode("classification") %>%
set_engine("ranger")
#cv folds
trees_folds <- vfold_cv(trees_train,v=2)
### several functions to serve methods for the adstock step class, essentially copied out of this directly -- https://www.tidymodels.org/learn/develop/recipes/
step_adstock <- function(
recipe,
...,
role = NA,
trained = FALSE,
options = list( names = TRUE), #change to be range of retention
skip = FALSE,
retention=.5,
adstocks=NULL,
id = rand_id("adstock")
) {
add_step(
recipe,
step_adstock_new(
terms = enquos(...),
trained = trained,
role = role,
options = options,
skip = skip,
id = id,
retention=retention,
adstocks=adstocks
)
)
}
step_adstock_new <-
function(terms, role, trained, retention, adstocks, options, skip, id) {
step(
subclass = "adstock",
terms = terms,
role = role,
trained = trained,
adstocks=adstocks,
retention=retention,
options = options,
skip = skip,
id = id
)
}
prep.step_adstock <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
if (x$options$names == FALSE) {
rlang::abort("`names` should be set to TRUE")
}
step_adstock_new(terms=x$terms,
role=x$role,
trained=TRUE,
retention=x$retention,
adstocks=col_names,
options=x$options,
skip=x$skip,
id=x$id
)
}
bake.step_adstock<-function(object,new_data,...){
vars<-names(object$adstocks)
groupings<-groups(new_data)
list_stocks<-new_data[,vars] %>% reframe(across(everything(), function(x){ stats::filter(x,object$retention,'recursive')}))
new_data[,vars]<-lapply(list_stocks,as.numeric,na.rm=T)
if(length(groupings)>0){
as_tibble(new_data) %>% group_by(lapply(groupings,sym))}
else{as_tibble(new_data)}
}
print.step_adstock <-
function(x, width = max(20, options()$width - 35), ...) {
title<-"Adstock Transformation on "
print_step(
# Names before prep (could be selectors)
untr_obj = x$terms,
# Names after prep:
tr_obj = names(x$adstocks),
# Has it been prepped?
trained = x$trained,
# An estimate of how many characters to print on a line:
width = width,
title=paste("Adstock Transformation with retention",x$retention,"on"),
case_weights=x$case_weights
)
invisible(x)
}
tunable.step_adstock <- function (x, ...) {
tibble::tibble(
name = c("retention"),
call_info = list(list( fun = "retention")),
source = "recipe",
component = "step_adstock",
component_id = x$id
)
}
### function to create a parameter for dials, following https://www.tidymodels.org/learn/develop/parameters/
retention<-function(range=c(0,.8)){new_quant_param(type='double',range=range,inclusive=c(TRUE,TRUE),
label=c(retention='retention'),finalize = NULL)}
#######################################################################
##Now the example of the prep method not being found in tune_grid can actually begin
######################################################################
#recipe no custom
no_custom_rec <- recipe(legal_status ~ ., data = trees_train) %>%
update_role(tree_id, new_role = "ID") %>%
step_pca(all_numeric(),num_comp =2)
#check that it will prep
prep(no_custom_rec)
#custom step_adstock added
custom_rec<-recipe(legal_status ~ ., data = trees_train) %>%
update_role(tree_id, new_role = "ID") %>%
step_pca(all_numeric(),num_comp =2) %>% step_adstock(all_numeric(),retention=.5)
#check it will prep
prep(custom_rec)
#put tune() calls in
no_custom_tune_rec<-recipe(legal_status ~ ., data = trees_train) %>%
update_role(tree_id, new_role = "ID") %>%
step_pca(all_numeric(),num_comp = tune())
#no_custom_tune_rec won't prep because tune() call leaves num_comp undecided -- this is expected
custom_tune_rec<-recipe(legal_status ~ ., data = trees_train) %>%
update_role(tree_id, new_role = "ID") %>%
step_pca(all_numeric(),num_comp = tune()) %>% step_adstock(all_numeric(),retention = .5)
#custom_tune_rec won't prep because tune() call leaves num_comp undecided -- this is expected
#two workflows and parameter sets to tune, one with custom and one without:
tune_no_custom_wf <- workflow() %>%
add_recipe(no_custom_tune_rec) %>%
add_model(tune_spec)
tune_custom_wf<- workflow() %>%
add_recipe(custom_tune_rec) %>%
add_model(tune_spec)
tune_these_parms<-extract_parameter_set_dials(tune_custom_wf) %>% finalize(trees_train)
#two tune_grid calls -- no_custom works, custom fails because there is no prep method for step_adstock?
#no custom -- works
doParallel::registerDoParallel()
no_custom_results <- tune_grid(
tune_no_custom_wf,
resamples = trees_folds,
grid = 1,
param_info = tune_these_parms
)
doParallel::stopImplicitCluster()
# custom -- fails
doParallel::registerDoParallel()
custom_results <- tune_grid(
tune_custom_wf,
resamples = trees_folds,
grid = 1,
param_info = tune_these_parms
)
doParallel::stopImplicitCluster()
#show_notes(.Last.tune.result)
# unique notes:
# ───────────────────────────────────────────────────────────────────────────────────────────
# Error in `step_adstock()`:
# Caused by error in `UseMethod()`:
# ! no applicable method for 'prep' applied to an object of class "c('step_adstock', 'step')"
And a sessionInfo dump:
sessionInfo()
R version 4.2.2 (2022-10-31 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19044)
Matrix products: default
locale:
[1] LC_COLLATE=English_United States.utf8 LC_CTYPE=English_United States.utf8 LC_MONETARY=English_United States.utf8
[4] LC_NUMERIC=C LC_TIME=English_United States.utf8
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] lubridate_1.9.2 forcats_1.0.0 stringr_1.5.0 readr_2.1.4 tidyverse_2.0.0 yardstick_1.1.0 workflowsets_1.0.0
[8] workflows_1.1.2 tune_1.0.1 tidyr_1.3.0 tibble_3.2.1 rsample_1.1.1 recipes_1.0.7 purrr_1.0.1
[15] parsnip_1.0.3 modeldata_1.0.1 infer_1.0.4 ggplot2_3.4.2 dplyr_1.1.2 dials_1.1.0 scales_1.2.1
[22] broom_1.0.5 tidymodels_1.0.0
loaded via a namespace (and not attached):
[1] nlme_3.1-160 fs_1.5.2 blastula_0.3.3 bigrquery_1.4.1.9000 bit64_4.0.5
[6] doParallel_1.0.17 DiceDesign_1.9 httr_1.4.4 tools_4.2.2 backports_1.4.1
[11] utf8_1.2.2 R6_2.5.1 rpart_4.1.19 DBI_1.1.3 mgcv_1.8-41
[16] colorspace_2.0-3 nnet_7.3-18 withr_2.5.0 tidyselect_1.2.0 curl_4.3.3
[21] bit_4.0.5 compiler_4.2.2 cli_3.6.1 arrow_10.0.1 odbc_1.3.3
[26] tidytable_0.9.1 mvnfast_0.2.8 digest_0.6.30 pkgconfig_2.0.3 htmltools_0.5.4
[31] parallelly_1.32.1 lhs_1.1.6 dbplyr_2.3.3 fastmap_1.1.0 rlang_1.1.0
[36] rstudioapi_0.14 generics_0.1.3 jsonlite_1.8.7 vroom_1.6.0 googlesheets4_1.0.1
[41] magrittr_2.0.3 scam_1.2-13 patchwork_1.1.2 Matrix_1.5-1 GPfit_1.0-8
[46] Rcpp_1.0.9 munsell_0.5.0 fansi_1.0.3 furrr_0.3.1 gratia_0.8.1
[51] lifecycle_1.0.3 stringi_1.7.8 MASS_7.3-58.1 grid_4.2.2 blob_1.2.3
[56] listenv_0.8.0 parallel_4.2.2 crayon_1.5.2 librarian_1.8.1 lattice_0.20-45
[61] haven_2.5.1 splines_4.2.2 hms_1.1.2 pillar_1.9.0 ranger_0.14.1
[66] future.apply_1.10.0 codetools_0.2-18 glue_1.6.2 data.table_1.14.8 vctrs_0.6.1
[71] tzdb_0.3.0 foreach_1.5.2 cellranger_1.1.0 gtable_0.3.3 future_1.29.0
[76] assertthat_0.2.1 gower_1.0.0 prodlim_2019.11.13 class_7.3-20 survival_3.4-0
[81] googledrive_2.0.0 ChannelAttribution_2.0.6 gargle_1.2.1 timeDate_4021.106 iterators_1.0.14
[86] hardhat_1.3.0 lava_1.7.0 globals_0.16.2 timechange_0.1.1 ellipsis_0.3.2
[91] ipred_0.9-13