Hi,
I'm wondering if there is any way to use the step_window()
function in a recipe to do a rolling mean by group? The example shows how to do rolling calculations, but it doesn't appear to be possible to specify a grouping variable and compute the rolling means within each group only.
I've got an example below that shows how it's currently done in step_window()
and then how I could achieve this using the zoo
package (there are obviously other ways to calculate rolling means).
library(recipes)
library(dplyr)
library(rlang)
library(zoo)
# Simulated data: x1 runs from 2.0 to 10.0 in steps of 0.1 and the rows are
# split into three equally sized, contiguous groups ("a", "b", "c").
set.seed(5522)
sim_dat <- data.frame(x1 = seq(20L, 100L) / 10)
n <- nrow(sim_dat)
sim_dat[["group"]] <- rep(c("a", "b", "c"), each = n / 3)
# The random draws below must stay in this order (y1, y2, x2, x3) so the
# seeded values are reproducible.
sim_dat[["y1"]] <- sin(sim_dat[["x1"]]) + rnorm(n, sd = 0.1)
sim_dat[["y2"]] <- cos(sim_dat[["x1"]]) + rnorm(n, sd = 0.1)
sim_dat[["x2"]] <- runif(n)
sim_dat[["x3"]] <- rnorm(n)
# Using Recipes Example from help docs (can't do grouping)
# Adds a 7-point rolling median for each column whose name starts with "y"
# (y1 and y2). `group` is not part of the recipe formula, so the window is
# applied over all rows of sim_dat.
rec <- recipe(y1 + y2 ~ x1 + x2 + x3, data = sim_dat) %>%
  step_window(
    starts_with("y"),
    size = 7,
    statistic = "median",
    names = c("med_7pt_1", "med_7pt_2"),
    role = "outcome"
  ) %>%
  prep(training = sim_dat)
smoothed_dat <- bake(rec, new_data = sim_dat, everything())
# Using Zoo Package
# 3-point rolling mean of y1 computed separately within each group; positions
# without a full window (the first and last row of each group) are filled
# with NA. The result is left grouped by `group`.
grouped_smoothed_dat <- sim_dat %>%
  dplyr::group_by(group) %>%
  dplyr::mutate(
    y1_rollgroup = zoo::rollmean(x = y1, k = 3, fill = NA)
  )
# Baked recipe output: the original recipe variables plus the two rolling
# median columns (`group` is not in the recipe formula, so it is absent).
print(smoothed_dat)
#> # A tibble: 81 × 7
#> x1 x2 x3 y1 y2 med_7pt_1 med_7pt_2
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 2 0.686 0.885 1.03 -0.304 0.741 -0.612
#> 2 2.1 0.210 0.801 0.741 -0.455 0.741 -0.612
#> 3 2.2 0.385 -1.31 0.813 -0.582 0.741 -0.612
#> 4 2.3 0.271 -0.826 0.752 -0.612 0.741 -0.612
#> 5 2.4 0.140 0.0792 0.566 -0.646 0.566 -0.615
#> 6 2.5 0.582 -0.253 0.410 -0.615 0.495 -0.646
#> 7 2.6 0.0130 0.0261 0.479 -0.680 0.479 -0.680
#> 8 2.7 0.679 0.351 0.495 -0.828 0.466 -0.796
#> 9 2.8 0.701 -0.679 0.466 -0.983 0.410 -0.828
#> 10 2.9 0.905 1.45 0.281 -0.899 0.285 -0.899
#> # ℹ 71 more rows
# Print 30 rows so the first group boundary (a -> b at rows 27-28), where the
# rolling mean is NA on both sides, is visible.
print(grouped_smoothed_dat, n = 30)
#> # A tibble: 81 × 7
#> # Groups: group [3]
#> x1 group y1 y2 x2 x3 y1_rollgroup
#> <dbl> <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 2 a 1.03 -0.304 0.686 0.885 NA
#> 2 2.1 a 0.741 -0.455 0.210 0.801 0.861
#> 3 2.2 a 0.813 -0.582 0.385 -1.31 0.769
#> 4 2.3 a 0.752 -0.612 0.271 -0.826 0.710
#> 5 2.4 a 0.566 -0.646 0.140 0.0792 0.576
#> 6 2.5 a 0.410 -0.615 0.582 -0.253 0.485
#> 7 2.6 a 0.479 -0.680 0.0130 0.0261 0.461
#> 8 2.7 a 0.495 -0.828 0.679 0.351 0.480
#> 9 2.8 a 0.466 -0.983 0.701 -0.679 0.414
#> 10 2.9 a 0.281 -0.899 0.905 1.45 0.344
#> 11 3 a 0.285 -0.796 0.198 0.128 0.249
#> 12 3.1 a 0.180 -1.14 0.443 0.535 0.145
#> 13 3.2 a -0.0309 -1.08 0.277 0.661 0.00704
#> 14 3.3 a -0.128 -1.07 0.724 -1.30 -0.120
#> 15 3.4 a -0.202 -0.871 0.225 -0.175 -0.228
#> 16 3.5 a -0.354 -0.876 0.830 -0.0944 -0.353
#> 17 3.6 a -0.504 -0.867 0.600 -0.742 -0.461
#> 18 3.7 a -0.527 -0.895 0.276 0.453 -0.567
#> 19 3.8 a -0.669 -0.648 0.676 0.284 -0.669
#> 20 3.9 a -0.810 -0.828 0.374 -1.83 -0.753
#> 21 4 a -0.779 -0.715 0.160 0.589 -0.823
#> 22 4.1 a -0.881 -0.673 0.677 -1.10 -0.845
#> 23 4.2 a -0.875 -0.419 0.960 1.57 -0.889
#> 24 4.3 a -0.911 -0.301 0.329 0.766 -0.894
#> 25 4.4 a -0.895 -0.283 0.381 -0.484 -0.927
#> 26 4.5 a -0.973 -0.0302 0.999 0.401 -0.929
#> 27 4.6 a -0.919 -0.110 0.107 -1.17 NA
#> 28 4.7 b -1.10 0.173 0.527 0.402 NA
#> 29 4.8 b -1.05 0.00226 0.178 1.07 -1.08
#> 30 4.9 b -1.09 0.140 0.0762 -0.214 -1.07
#> # ℹ 51 more rows