Dear All,
Please have a look at the snippet below.
I need to train an elastic net (from the glmnet package) on a small dataset.
For reasons we do not discuss here, the training set consists of all the observations apart from the most recent one, whereas the test set is one observation only.
Unfortunately, the code fails and I do not understand why. My choice of the test and training set may look odd, but there is nothing illegal about it.
Any suggestions is appreciated.
library(tidyverse)
library(tidymodels)
df_ini <- structure(list(year = c(1998, 2002, 2004, 2005, 2006, 2007, 2008,
2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018),
capital_n1132g_lag_1 = c(3446.5, 4091.1, 3655.1, 3633.3,
3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6,
3718.6, 3467.9, 4214.2, 4237.4, 4450.2), capital_n117g_lag_1 = c(4920.9,
7810.6, 8560.3, 8679.9, 8938.9, 9823.8, 10467.1, 11047.1,
11554.3, 11849.9, 13465.4, 13927.5, 15510.2, 15754.4, 16584.7,
17647.1, 18273.8), capital_n11mg_lag_1 = c(16846, 19605,
19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6,
20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2,
29790.1), employment_be_lag_1 = c(2834.42, 2839.72, 2765.53,
2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1,
2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9
), employment_c_lag_1 = c(2612.76, 2623.69, 2552.89, 2518.57,
2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1,
2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48), employment_j_lag_1 = c(292.93,
389.2, 389.45, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91,
392.18, 410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4
), employment_k_lag_1 = c(505.33, 507.12, 510.25, 504.63,
515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98,
524.13, 518.89, 511.57, 505.32, 496.41), employment_mn_lag_1 = c(945.59,
1217.96, 1289.55, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76,
1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83,
2021.51, 2109.71), employment_oq_lag_1 = c(3065.87, 3191.75,
3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85,
3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59,
4171.72), employment_total_lag_1 = c(14509.58, 15127.99,
15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 16356.53,
16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 17039.6,
17142.13, 17365.32, 17650.21), gdp_b1gq_lag_1 = c(187849.7,
220525, 231862.5, 242348.3, 254075, 267824.4, 283978, 293761.9,
288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 333146.1,
344269.3, 357608, 369341.3), gdp_p3_lag_1 = c(139695.2, 161175.8,
169405.6, 176316.4, 185871.1, 194102, 200944.4, 208857.1,
213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6,
249404.3, 257166.5, 265900.2), gdp_p61_lag_1 = c(50117.6,
71948.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1,
91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3,
129183.6, 131524, 140057.8), gdp_p62_lag_1 = c(19441, 26444.4,
28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9,
39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5,
59584.7), price_index_lag_1 = c(1.2, 2.3, 1.3, 2, 2.1, 1.7,
2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 1, 2.2), value_be_lag_1 = c(40533.1,
48207.1, 48673.2, 50737.6, 52955.2, 56872.4, 60864.9, 61029,
56837.8, 58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4,
71152.6, 72698.8), value_c_lag_1 = c(33441.8, 40446.6, 40467.4,
42014.6, 44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3,
51467.7, 53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196
), value_j_lag_1 = c(5483.7, 7326.1, 7934.1, 7756.1, 8134.2,
8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1,
10361.4, 10695.4, 11455.3, 11720.6), value_k_lag_1 = c(9210.6,
9977.3, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2,
12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9,
13236.4, 13744.1), value_mn_lag_1 = c(10444, 14061.4, 15706.6,
16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2,
24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6
), value_oq_lag_1 = c(29902.7, 34179.2, 36126.8, 37329.6,
38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9,
49381.5, 50261.7, 51624.3, 53715, 55926.4, 57637.1), value_total_lag_1 = c(167323.4,
197076.7, 207247.6, 216098.3, 225888.1, 239076, 253604.6,
262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1,
307037.7, 318952.7, 329396.1), capital_n1132g_lag_2 = c(3599.2,
3996.9, 3638.4, 3655.1, 3633.3, 3616.2, 3450.7, 3596.8, 3867.2,
3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 3467.9, 4214.2, 4237.4
), capital_n117g_lag_2 = c(4636.2, 7008.5, 8369.6, 8560.3,
8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9,
13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1), capital_n11mg_lag_2 = c(17181.5,
19677.8, 18749.6, 19381.2, 19433.5, 20051.6, 20569.8, 22646.1,
23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 24057.7, 23832.8,
25019.2, 27608.2), employment_be_lag_2 = c(2870.33, 2840.19,
2775.22, 2765.53, 2731.08, 2709.59, 2708.39, 2774.06, 2795.6,
2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 2725.66, 2735.69,
2750.52), employment_c_lag_2 = c(2626.2, 2621.08, 2562.53,
2552.89, 2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97,
2447.65, 2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97
), employment_j_lag_2 = c(275.08, 374.56, 400.75, 389.45,
387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 410.85,
419.75, 427.59, 438.96, 440.33, 460.84), employment_k_lag_2 = c(500.9,
505.13, 502.42, 510.25, 504.63, 515.39, 523.45, 536.6, 550.14,
546.68, 539.96, 536.58, 534.98, 524.13, 518.89, 511.57, 505.32
), employment_mn_lag_2 = c(904.38, 1143.78, 1248.01, 1289.55,
1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 1762.55,
1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51), employment_oq_lag_2 = c(3028.85,
3162.77, 3241.36, 3280.36, 3317.09, 3401.65, 3476.63, 3508.01,
3577.75, 3683.85, 3759.23, 3798.35, 3850.17, 3877.24, 3924.06,
4002.74, 4095.59), employment_total_lag_2 = c(14404.29, 15019.87,
15113.52, 15212.11, 15307.28, 15491.61, 15762.92, 16050.92,
16356.53, 16269.97, 16392.87, 16647.79, 16820.66, 16879.06,
17039.6, 17142.13, 17365.32), gdp_b1gq_lag_2 = c(186928.7,
213606.4, 226735.3, 231862.5, 242348.3, 254075, 267824.4,
283978, 293761.9, 288044.1, 295896.6, 310128.6, 318653.1,
323910.2, 333146.1, 344269.3, 357608), gdp_p3_lag_2 = c(140335.8,
156117.3, 164107.8, 169405.6, 176316.4, 185871.1, 194102,
200944.4, 208857.1, 213630.1, 218947.2, 227250.8, 233638.1,
238329.3, 243860.6, 249404.3, 257166.5), gdp_p61_lag_2 = c(44541.4,
67701.6, 74691.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2,
113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1,
126109.3, 129183.6, 131524), gdp_p62_lag_2 = c(19504.2, 24888.9,
28063.4, 28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8,
38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8,
55885.5), value_be_lag_2 = c(40076.7, 46109.4, 47967.1, 48673.2,
50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6,
61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6), value_c_lag_2 = c(32955.4,
38908.4, 40192.9, 40467.4, 42014.6, 44229, 47735.5, 51552.4,
51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169,
57458.7, 60962.8), value_j_lag_2 = c(5576.8, 6313.9, 7737.1,
7934.1, 7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9,
9217.1, 9405.1, 9802.1, 10361.4, 10695.4, 11455.3), value_k_lag_2 = c(9191,
10458, 10225.2, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7,
13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4,
13482.9, 13236.4), value_mn_lag_2 = c(10092, 12942.5, 15074,
15706.6, 16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490,
23255.2, 24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7
), value_oq_lag_2 = c(30224.3, 33251.5, 35065.6, 36126.8,
37329.6, 38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6,
47980.9, 49381.5, 50261.7, 51624.3, 53715, 55926.4), value_total_lag_2 = c(167141.8,
190624.9, 202353.5, 207247.6, 216098.3, 225888.1, 239076,
253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3,
297230.1, 307037.7, 318952.7), berd = c(2146.085, 3130.884,
3556.479, 4207.669, 4448.676, 4845.861, 5232.63, 5092.902,
5520.422, 5692.841, 6540.457, 6778.42, 7324.679, 7498.488,
7824.51, 7888.444, 8461.72)), row.names = c(NA, -17L), class = c("tbl_df",
"tbl", "data.frame"))
set.seed(1234) ## to make the results reproducible
## I need a particular custom split of my dataset: the test set consists of only the most recent observation, whereas all the rest is the training set
## see https://github.com/tidymodels/rsample/issues/158
indices <-
list(analysis = seq(nrow(df_ini)-1),
assessment = nrow(df_ini)
)
df_split <- make_splits(indices, df_ini)
## df_split <- initial_split(df_ini) ## with the default splitting,
## ## the code works
df_train <- training(df_split)
df_test <- testing(df_split)
folded_data <- vfold_cv(df_train,3)
glmnet_recipe <-
recipe(formula = berd ~ ., data = df_train) %>%
update_role(year, new_role = "ID") %>%
step_zv(all_predictors()) %>%
step_normalize(all_predictors(), -all_nominal())
glmnet_spec <-
linear_reg(penalty = tune(), mixture = tune()) %>%
set_mode("regression") %>%
set_engine("glmnet")
glmnet_workflow <-
workflow() %>%
add_recipe(glmnet_recipe) %>%
add_model(glmnet_spec)
glmnet_grid <- tidyr::crossing(penalty = 10^seq(-6, -1, length.out = 20), mixture = c(0.05,
0.2, 0.4, 0.6, 0.8, 1))
glmnet_tune <-
tune_grid(glmnet_workflow, resamples = folded_data, grid = glmnet_grid,control = control_grid(save_pred = TRUE) )
print(collect_metrics(glmnet_tune))
#> # A tibble: 240 x 8
#> penalty mixture .metric .estimator mean n std_err .config
#> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.000001 0.05 rmse standard 375. 3 48.9 Model001
#> 2 0.000001 0.05 rsq standard 0.929 3 0.0420 Model001
#> 3 0.00000183 0.05 rmse standard 375. 3 48.9 Model002
#> 4 0.00000183 0.05 rsq standard 0.929 3 0.0420 Model002
#> 5 0.00000336 0.05 rmse standard 375. 3 48.9 Model003
#> 6 0.00000336 0.05 rsq standard 0.929 3 0.0420 Model003
#> 7 0.00000616 0.05 rmse standard 375. 3 48.9 Model004
#> 8 0.00000616 0.05 rsq standard 0.929 3 0.0420 Model004
#> 9 0.0000113 0.05 rmse standard 375. 3 48.9 Model005
#> 10 0.0000113 0.05 rsq standard 0.929 3 0.0420 Model005
#> # … with 230 more rows
print(show_best(glmnet_tune, "rmse"))
#> # A tibble: 5 x 8
#> penalty mixture .metric .estimator mean n std_err .config
#> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.000001 0.05 rmse standard 375. 3 48.9 Model001
#> 2 0.00000183 0.05 rmse standard 375. 3 48.9 Model002
#> 3 0.00000336 0.05 rmse standard 375. 3 48.9 Model003
#> 4 0.00000616 0.05 rmse standard 375. 3 48.9 Model004
#> 5 0.0000113 0.05 rmse standard 375. 3 48.9 Model005
best_net <- select_best(glmnet_tune, "rmse")
final_net <- finalize_workflow(
glmnet_workflow,
best_net
)
final_res_net <- last_fit(final_net, df_split)
#> x : internal: Error in data.frame(..., check.names = FALSE): arguments imply...
#> Warning: All models failed in [fit_resamples()]. See the `.notes` column.
print(final_res_net)
#> Warning: This tuning result has notes. Example notes on model fitting include:
#> internal: Error in data.frame(..., check.names = FALSE): arguments imply differing number of rows: 2, 0
#> # Resampling results
#> # Monte Carlo cross-validation (0.94/0.059) with 1 resamples
#> # A tibble: 1 x 5
#> splits id .metrics .notes .predictions
#> <list> <chr> <list> <list> <list>
#> 1 <split [16/1]> train/test split <NULL> <tibble [1 × 1]> <NULL>
final_fit <- final_res_net %>%
collect_predictions()
Created on 2020-10-15 by the reprex package (v0.3.0.9001)