step_window() within grouping variable

Hi,

I'm wondering if there is any way to use the step_window() function in a recipe to compute a rolling mean by group. The example in the help docs shows how to do rolling calculations, but it doesn't appear to be possible to specify a grouping variable and compute those rolling means within that grouping variable only.

I've got an example below that shows how it's currently done in step_window() and then how I could achieve this using the zoo package (there are obviously other ways to calculate rolling means).

  library(recipes)
  library(dplyr)
  library(rlang)
  library(zoo)

  
  # Reproducible simulated data: x1 runs from 2.0 to 10.0 in steps of 0.1.
  # The random draws below must stay in this order (y1, y2, x2, x3) so the
  # seed reproduces the same values.
  set.seed(5522)
  sim_dat <- data.frame(x1 = seq(20, 100) / 10)
  n <- nrow(sim_dat)
  # Three consecutive, equally sized groups: all "a" rows, then "b", then "c"
  sim_dat[["group"]] <- rep(c("a", "b", "c"), each = n / 3)
  # Noisy sine and cosine outcomes
  sim_dat[["y1"]] <- sin(sim_dat[["x1"]]) + rnorm(n, sd = 0.1)
  sim_dat[["y2"]] <- cos(sim_dat[["x1"]]) + rnorm(n, sd = 0.1)
  # Two additional random predictors
  sim_dat[["x2"]] <- runif(n)
  sim_dat[["x3"]] <- rnorm(n)
  
  # Recipe-based smoothing, adapted from the step_window() help example.
  # `group` is not part of the recipe, so the 7-point rolling median is
  # computed over all rows in order rather than within each group.
  rec <- recipe(y1 + y2 ~ x1 + x2 + x3, data = sim_dat) %>%
    step_window(
      starts_with("y"),
      size = 7,
      statistic = "median",
      names = paste0("med_7pt_", 1:2),
      role = "outcome"
    ) %>%
    prep(training = sim_dat)

  # Apply the prepped recipe back to the same data, keeping every column
  smoothed_dat <- bake(rec, sim_dat, everything())
  
  # Grouped alternative with zoo: a centred 3-point rolling mean of y1,
  # computed separately within each `group`. `fill = NA` pads the positions
  # where the window does not fit (the first and last row of every group).
  grouped_smoothed_dat <- sim_dat %>%
    group_by(group) %>%
    mutate(
      y1_rollgroup = zoo::rollmean(y1, k = 3, fill = NA, align = "center")
    )

  print(smoothed_dat)
#> # A tibble: 81 × 7
#>       x1     x2      x3    y1     y2 med_7pt_1 med_7pt_2
#>    <dbl>  <dbl>   <dbl> <dbl>  <dbl>     <dbl>     <dbl>
#>  1   2   0.686   0.885  1.03  -0.304     0.741    -0.612
#>  2   2.1 0.210   0.801  0.741 -0.455     0.741    -0.612
#>  3   2.2 0.385  -1.31   0.813 -0.582     0.741    -0.612
#>  4   2.3 0.271  -0.826  0.752 -0.612     0.741    -0.612
#>  5   2.4 0.140   0.0792 0.566 -0.646     0.566    -0.615
#>  6   2.5 0.582  -0.253  0.410 -0.615     0.495    -0.646
#>  7   2.6 0.0130  0.0261 0.479 -0.680     0.479    -0.680
#>  8   2.7 0.679   0.351  0.495 -0.828     0.466    -0.796
#>  9   2.8 0.701  -0.679  0.466 -0.983     0.410    -0.828
#> 10   2.9 0.905   1.45   0.281 -0.899     0.285    -0.899
#> # ℹ 71 more rows
  # Print 30 rows so the boundary between groups "a" and "b" is visible
  print(grouped_smoothed_dat, n = 30)
#> # A tibble: 81 × 7
#> # Groups:   group [3]
#>       x1 group      y1       y2     x2      x3 y1_rollgroup
#>    <dbl> <chr>   <dbl>    <dbl>  <dbl>   <dbl>        <dbl>
#>  1   2   a      1.03   -0.304   0.686   0.885      NA      
#>  2   2.1 a      0.741  -0.455   0.210   0.801       0.861  
#>  3   2.2 a      0.813  -0.582   0.385  -1.31        0.769  
#>  4   2.3 a      0.752  -0.612   0.271  -0.826       0.710  
#>  5   2.4 a      0.566  -0.646   0.140   0.0792      0.576  
#>  6   2.5 a      0.410  -0.615   0.582  -0.253       0.485  
#>  7   2.6 a      0.479  -0.680   0.0130  0.0261      0.461  
#>  8   2.7 a      0.495  -0.828   0.679   0.351       0.480  
#>  9   2.8 a      0.466  -0.983   0.701  -0.679       0.414  
#> 10   2.9 a      0.281  -0.899   0.905   1.45        0.344  
#> 11   3   a      0.285  -0.796   0.198   0.128       0.249  
#> 12   3.1 a      0.180  -1.14    0.443   0.535       0.145  
#> 13   3.2 a     -0.0309 -1.08    0.277   0.661       0.00704
#> 14   3.3 a     -0.128  -1.07    0.724  -1.30       -0.120  
#> 15   3.4 a     -0.202  -0.871   0.225  -0.175      -0.228  
#> 16   3.5 a     -0.354  -0.876   0.830  -0.0944     -0.353  
#> 17   3.6 a     -0.504  -0.867   0.600  -0.742      -0.461  
#> 18   3.7 a     -0.527  -0.895   0.276   0.453      -0.567  
#> 19   3.8 a     -0.669  -0.648   0.676   0.284      -0.669  
#> 20   3.9 a     -0.810  -0.828   0.374  -1.83       -0.753  
#> 21   4   a     -0.779  -0.715   0.160   0.589      -0.823  
#> 22   4.1 a     -0.881  -0.673   0.677  -1.10       -0.845  
#> 23   4.2 a     -0.875  -0.419   0.960   1.57       -0.889  
#> 24   4.3 a     -0.911  -0.301   0.329   0.766      -0.894  
#> 25   4.4 a     -0.895  -0.283   0.381  -0.484      -0.927  
#> 26   4.5 a     -0.973  -0.0302  0.999   0.401      -0.929  
#> 27   4.6 a     -0.919  -0.110   0.107  -1.17       NA      
#> 28   4.7 b     -1.10    0.173   0.527   0.402      NA      
#> 29   4.8 b     -1.05    0.00226 0.178   1.07       -1.08   
#> 30   4.9 b     -1.09    0.140   0.0762 -0.214      -1.07   
#> # ℹ 51 more rows

Sorry. We currently don't support that.

This topic was automatically closed 21 days after the last reply. New replies are no longer allowed.

If you have a query related to it or one of the replies, start a new topic and refer back with a link.