Extending dials:: with 'filter' parameter from wavelets::dwt()

I really do promise I'm trying my best to self-teach here. Moving on to the next step in my tidymodels implementation, I tried to make the 'filter' parameter from the wavelets::dwt() function tunable, but for some odd reason:

  • If I enter the functions by hand in the global environment, it works fine: tunable(recipe(...) %>% step_dwt(..., filter = tune())) sees a tunable filter parameter and uses it.
  • But if I wrap it up into a package, even with a filter = tune() argument, tunable() doesn't see the tunable filter argument, and so everything downstream fails.

It's hard to make a minimal example here. I thought maybe the dials parameter functions had to live in a separate package, but looking at embed:: that isn't the case. However, embed:: does have some fun .onLoad logic:

.onLoad <- function(libname, pkgname) {
  embed_exports <- getNamespaceExports(ns = "embed")
  names <- names(embed_exports)
  tunable_steps <- grep("tunable.step", embed_exports, fixed = TRUE, 
                        value = TRUE)
  for (i in tunable_steps) {
    s3_register("tune::tunable", i)
  }
## a whole lot of other stuff follows - also relevant?

Do I also need to include some or all of the code above to make the tunable parameter visible? Once again, I suspect my limits are showing here.
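For reference, the grep in embed's .onLoad just picks out every exported function whose name contains "tunable.step" (that is, the package's tunable() methods) and then loops over them to register each one. A base R sketch of that selection step, using a made-up export list for a package like stepdwt (the names are hypothetical). The last step strips the "tunable." prefix to get the bare class name, which is what a helper such as vctrs::s3_register() expects as its class argument:

```r
# Hypothetical export list for a package like 'stepdwt'
pkg_exports <- c("step_dwt", "tunable.step_dwt", "prep.step_dwt",
                 "bake.step_dwt", "dwt_filter")

# Same selection embed's .onLoad performs: fixed-string match on "tunable.step"
tunable_steps <- grep("tunable.step", pkg_exports, fixed = TRUE, value = TRUE)

# Bare class name each method belongs to (prefix "tunable." removed)
step_classes <- sub("tunable.", "", tunable_steps, fixed = TRUE)
print(step_classes)
```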

Code for all of it:

step_dwt <- function(
  recipe,
  ...,
  role = NA,
  trained = FALSE,
  ref_dist = NULL,
  filter = "haar",
  options = list(),
  skip = FALSE,
  id = recipes::rand_id("dwt")
) {

  ## The variable selectors are not immediately evaluated by using
  ##  the `quos()` function in `rlang`. `ellipse_check()` captures
  ##  the values and also checks to make sure that they are not empty.
  terms <- recipes::ellipse_check(...)

  recipes::add_step(
    recipe,
    step_dwt_new(
      terms = terms,
      trained = trained,
      role = role,
      ref_dist = ref_dist,
      filter = filter,
      options = options,
      skip = skip,
      id = id
    )
  )
}

step_dwt_new <-
  function(terms, role, trained, ref_dist, filter, options, skip, id) {
    recipes::step(
      subclass = "dwt",
      terms = terms,
      role = role,
      trained = trained,
      ref_dist = ref_dist,
      filter = filter,
      options = options,
      skip = skip,
      id = id
    )
  }

prep.step_dwt <- function(x, training, info = NULL, ...) {
  col_names <- recipes::terms_select(terms = x$terms, info = info)
  ## We actually only need the training set here
  ## Since there's nothing about the trained data that's useful
  ## you could probably even just return the variable names?
  ref_dist <- training[, col_names]

  ## Use the constructor function to return the updated object.
  ## Note that `trained` is now set to TRUE
  step_dwt_new(
    terms = x$terms,
    trained = TRUE,
    role = x$role,
    ref_dist = ref_dist,
    filter = x$filter,
    options = x$options,
    skip = x$skip,
    id = x$id
  )
}

bake.step_dwt <- function(object, new_data, ...) {
  ## I use expr(), mod_call_args and eval to evaluate map_dwt
  ## this probably is a little roundabout?
  vars <- names(object$ref_dist)
  dwt_call <- dplyr::expr(map_dwt_over_df(filter = NULL))
  dwt_call$filter <- dplyr::expr(object$filter)
  dwt_call$df <- dplyr::expr(new_data[, vars])

  new_data_cols <- eval(dwt_call)
  new_data <- dplyr::bind_cols(new_data, tibble::as_tibble(new_data_cols))
  ## get rid of the original columns
  ## -vars will not do this!
  new_data <-
    new_data[, !(colnames(new_data) %in% vars), drop = FALSE]

  ## Always convert to tibbles on the way out
  tibble::as_tibble(new_data)
}

dwt_filter <- function(values = values_dwt_filter) {
  dials::new_qual_param(
    type     = c("character"),
    values   = values,
    default  = "haar",
    label    = c(filter = "DWT Filter"),
    finalize = NULL
  )
}

values_dwt_filter <- c(
  paste0("d", seq(2, 20, by = 2)),
  paste0("la", seq(8, 20, by = 2)),
  paste0("bl", c(14, 18, 20)),
  paste0("c", seq(6, 30, by = 6))
)

tunable.step_dwt <- function(x, ...) {
  tibble::tibble(
    name = c("filter"),
    call_info = list(list(pkg = "stepdwt", fun = "dwt_filter")),
    source = "recipe",
    component = "step_dwt",
    component_id = x$id
  )
}
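A quick base R sanity check on the candidate set (values copied from the code above): it holds 25 filters, and the parameter's default, "haar", is not one of them, so a grid drawn from values_dwt_filter would never try it. If that's unintended, adding "haar" to the vector is one option.

```r
# Candidate filters, copied from the post
values_dwt_filter <- c(
  paste0("d", seq(2, 20, by = 2)),   # 10 values
  paste0("la", seq(8, 20, by = 2)),  # 7 values
  paste0("bl", c(14, 18, 20)),       # 3 values
  paste0("c", seq(6, 30, by = 6))    # 5 values
)

n_filters <- length(values_dwt_filter)
haar_included <- "haar" %in% values_dwt_filter
print(n_filters)
print(haar_included)
```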

The problem (I think) is with the S3 registration of the tunable() method. I'll work on a PR for your repo.
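A likely reason for the global-environment vs. package difference: when tune calls tunable(x), S3 dispatch eventually falls back to searching the global environment and the search path, so a tunable.step_dwt() defined at top level is found. Inside a package, the method generally has to be registered with the generic instead. The base R sketch below mimics that with a hypothetical generic, tunable_demo(), and registerS3method(). In a real package the equivalent would typically be a NAMESPACE directive such as S3method(tune::tunable, step_dwt) (delayed registration, R >= 3.6.0) or a call to vctrs::s3_register("tune::tunable", "step_dwt") in .onLoad(). This is an illustration of the mechanism, not the contents of the PR.

```r
# Hypothetical generic standing in for tune::tunable()
tunable_demo <- function(x, ...) UseMethod("tunable_demo")
tunable_demo.default <- function(x, ...) "no tunable parameters found"

# Build the method inside a function so it is not visible by name from the
# global environment, much like an unregistered method inside a package
make_step_dwt_method <- function() {
  function(x, ...) "filter"
}

dwt_step <- structure(list(id = "dwt_abc"), class = c("step_dwt", "step"))

# No tunable_demo.step_dwt can be found, so dispatch uses the default
before <- tunable_demo(dwt_step)

# Register the method in the S3 methods table of the generic's environment
registerS3method("tunable_demo", "step_dwt", make_step_dwt_method())

# Dispatch now finds the registered method
after <- tunable_demo(dwt_step)
print(c(before = before, after = after))
```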

On a separate note, I'd rename the tuning parameter to filter_type or something. Using a function name (that is used in multiple packages) is going to bite you eventually.

All done in a PR.

As always Max - thank you so much for taking the time! I really appreciate it, and, again, hopefully this is helpful for anyone else working through implementing these. It's a huge gift to the community that you're around to work through this stuff, I don't take it lightly.

Good point about the filter parameter - that reminds me about another question I have about how rpart_train() is used in parsnip::, and generally how those model-wrapper functions are called, but I might spend some more time staring into the code before/instead of coming here.

If you are wondering why we use rlang::call2() so much, there are a few reasons.

First, when we don't use a wrapper, we can execute the model functions without having to include the corresponding package as a formal dependency.

The other reason is that it allows us to intercept arguments passed in via ... that we might want to modify. We can save the dots as data, alter them, then splice them into the call object created by call2(). We talk about it in this blog post.
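A minimal sketch of both points, assuming rlang is installed. build_fit_call() is a hypothetical helper, and stats::lm() stands in for an engine function such as rpart::rpart(): the function and package are given only as strings, the dots are captured as a list with list2(), one argument is overridden, and the result is spliced into a namespaced call with call2(..., .ns = ) before being evaluated.

```r
library(rlang)

# Hypothetical helper: build `pkg::fn(...)` from strings, so the calling code
# never refers to `pkg` directly.
build_fit_call <- function(fn, pkg, ..., .overrides = list()) {
  args <- list2(...)                     # capture the dots as data
  args[names(.overrides)] <- .overrides  # intercept and modify arguments
  call2(fn, !!!args, .ns = pkg)          # splice them into a call object
}

fit_call <- build_fit_call(
  "lm", "stats",
  formula = quote(mpg ~ wt),
  data = quote(mtcars),
  model = TRUE,
  .overrides = list(model = FALSE)       # the wrapper overrides the user
)

print(fit_call)          # the unevaluated stats::lm(...) call
fit <- eval(fit_call)    # evaluate it only when needed
```

Because the call is just data until eval() runs, the wrapper can inspect or rewrite any argument before the engine package is ever touched.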

It may seem over-the-top but it is the best way that we could effectively do stuff (based on previous caret experiences).

Ah! That does make sense. I did read that blog post before, but it definitely means a lot more to me now on the other side of this journey.

Still not 100% clear on the quosure magic that makes set_fit use rpart_train even though it references func = c(pkg = "rpart", fun = "rpart"), but I'm motivated to try and figure it out myself first before I come back here!
