I am trying to fit a RuleFit model using the xrf engine in tidymodels with the default values for all hyperparameters. The model fits, but I cannot obtain predicted values from it: calling predict() fails with the following error message:
Error in match(x, table, nomatch = 0L) :
'match' requires vector arguments
Example code using mtcars below:
set.seed(13)
options(scipen=999)
library(tidymodels)
library(rules)
library(xrf)
data.mtcars <- mtcars
data.mtcars.split <- initial_split(data.mtcars,
                                   prop = 0.9)
data.mtcars.train <- training(data.mtcars.split)
data.mtcars.test <- testing(data.mtcars.split)
spec_untuned_rule <- rule_fit() %>%
  set_engine("xrf", family = "gaussian") %>%
  set_mode("regression")
rec_untuned_rule <- recipe(mpg ~ ., data = data.mtcars.train) %>%
  step_center(all_predictors()) %>% # mean zero
  step_scale(all_predictors()) # standard deviation one
workflow_untuned_rule <- workflow() %>%
  add_recipe(rec_untuned_rule) %>%
  add_model(spec_untuned_rule)
system.time(model_rule <- fit(workflow_untuned_rule,
                              data = data.mtcars.train))
predictions_train_rule <- predict(model_rule,
                                  data.mtcars.train) %>%
  bind_cols(data.mtcars.train) # combine w/ original dataset
Fitting the same model directly with the xrf package and obtaining predictions works fine, so I am wondering whether I am missing anything in my code to get this to work using tidymodels. Any help is much appreciated.
model_xrf <- xrf(mpg ~ ., data.mtcars.train,
                 family = 'gaussian')
predictions_xrf <- predict(model_xrf, newdata = data.mtcars.train)
Session info below:
> sessionInfo()
R version 4.1.1 (2021-08-10)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.5 LTS
Matrix products: default
BLAS: /software/spackages_prod/apps/linux-ubuntu20.04-zen2/gcc-9.4.0/r-4.1.1-dqazhid3su5gd6f2sexu2cfhyba7ke6m/rlib/R/lib/libRblas.so
LAPACK: /software/spackages_prod/apps/linux-ubuntu20.04-zen2/gcc-9.4.0/r-4.1.1-dqazhid3su5gd6f2sexu2cfhyba7ke6m/rlib/R/lib/libRlapack.so
locale:
[1] LC_CTYPE=en_GB.UTF-8 LC_NUMERIC=C LC_TIME=en_GB.UTF-8 LC_COLLATE=en_GB.UTF-8 LC_MONETARY=en_GB.UTF-8
[6] LC_MESSAGES=en_GB.UTF-8 LC_PAPER=en_GB.UTF-8 LC_NAME=C LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] xrf_0.2.2 yardstick_1.1.0 workflowsets_1.0.0 workflows_1.1.0 tune_1.0.1 tidyr_1.2.1 tibble_3.1.8 rsample_1.1.0
[9] recipes_1.0.1 purrr_1.0.0 parsnip_1.0.2 modeldata_1.0.1 infer_1.0.3 ggplot2_3.4.0 dplyr_1.0.10 dials_1.0.0
[17] scales_1.2.1 broom_1.0.1 tidymodels_1.0.0
loaded via a namespace (and not attached):
[1] nlme_3.1-160 matrixStats_0.63.0 lubridate_1.8.0 DiceDesign_1.9 RColorBrewer_1.1-3 tools_4.1.1 backports_1.4.1 utf8_1.2.2
[9] R6_2.5.1 rpart_4.1.16 rules_1.0.0 DBI_1.1.3 mgcv_1.8-40 colorspace_2.0-3 nnet_7.3-18 withr_2.5.0
[17] tidyselect_1.2.0 gridExtra_2.3 compiler_4.1.1 glmnet_4.1-2 cli_3.5.0 labeling_0.4.2 fuzzyjoin_0.1.6 stringr_1.5.0
[25] digest_0.6.29 pkgconfig_2.0.3 parallelly_1.32.1 lhs_1.1.5 rlang_1.0.6 rstudioapi_0.14 farver_2.1.1 shape_1.4.6
[33] generics_0.1.3 jsonlite_1.8.2 BiocParallel_1.28.3 magrittr_2.0.3 Matrix_1.5-1 Rcpp_1.0.9 munsell_0.5.0 fansi_1.0.3
[41] GPfit_1.0-8 lifecycle_1.0.3 furrr_0.3.1 stringi_1.7.8 MASS_7.3-58.1 plyr_1.8.8 grid_4.1.1 parallel_4.1.1
[49] listenv_0.8.0 ggrepel_0.9.2 crayon_1.5.2 lattice_0.20-45 splines_4.1.1 pillar_1.8.1 igraph_1.3.5 xgboost_1.4.1.1
[57] corpcor_1.6.10 future.apply_1.9.1 reshape2_1.4.4 codetools_0.2-18 mixOmics_6.23.3 glue_1.6.2 data.table_1.14.2 vctrs_0.5.1
[65] foreach_1.5.2 gtable_0.3.1 future_1.28.0 assertthat_0.2.1 gower_1.0.0 prodlim_2019.11.13 RSpectra_0.16-1 class_7.3-20
[73] survival_3.4-0 timeDate_4021.106 rARPACK_0.11-0 iterators_1.0.14 plsmod_1.0.0 hardhat_1.2.0 ellipse_0.4.3 lava_1.6.10
[81] globals_0.16.1 ellipsis_0.3.2 ipred_0.9-13