predict from rule_fit model; Error in match(x, table, nomatch = 0L)

I am trying to fit a RuleFit model using the xrf engine in tidymodels, with the default values for all hyperparameters. Unfortunately, I cannot obtain predicted values from the model. I receive the following error message:

```
Error in match(x, table, nomatch = 0L) :
  'match' requires vector arguments
```

Example code using mtcars below:



```r
library(tidymodels)
library(rules) # provides the xrf engine for rule_fit()

set.seed(123) # make the split reproducible

data.mtcars <- mtcars

data.mtcars.split <- initial_split(data.mtcars,
                                   prop = 0.9)

data.mtcars.train <- training(data.mtcars.split)
data.mtcars.test <- testing(data.mtcars.split)

spec_untuned_rule <- rule_fit() %>%
  set_engine("xrf", family = "gaussian") %>%
  set_mode("regression")

rec_untuned_rule <- recipe(mpg ~ ., data = data.mtcars.train) %>%
  step_center(all_predictors()) %>% # mean zero
  step_scale(all_predictors()) # standard deviation one

workflow_untuned_rule <- workflow() %>%
  add_recipe(rec_untuned_rule) %>%
  add_model(spec_untuned_rule)

system.time(model_rule <- fit(workflow_untuned_rule,
                              data = data.mtcars.train))

# this is the step that fails with the error above (rules 1.0.0)
predictions_train_rule <- predict(model_rule,
                                  data.mtcars.train) %>%
  bind_cols(data.mtcars.train) # combine w/ original dataset
```

Fitting the same model using the xrf package and obtaining predictions works fine, so I am wondering whether I am missing anything in my code to get this to work using tidymodels. Any help is much appreciated.

```r
library(xrf)

# fit the same model directly with xrf, using mpg as the outcome
model_rule <- xrf(mpg ~ ., data = data.mtcars.train,
                  family = "gaussian")

test <- predict(model_rule, newdata = data.mtcars.train)
```

Session info:

> sessionInfo()
R version 4.1.1 (2021-08-10)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.5 LTS

Matrix products: default
BLAS:   /software/spackages_prod/apps/linux-ubuntu20.04-zen2/gcc-9.4.0/r-4.1.1-dqazhid3su5gd6f2sexu2cfhyba7ke6m/rlib/R/lib/
LAPACK: /software/spackages_prod/apps/linux-ubuntu20.04-zen2/gcc-9.4.0/r-4.1.1-dqazhid3su5gd6f2sexu2cfhyba7ke6m/rlib/R/lib/

 [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C               LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8     LC_MONETARY=en_GB.UTF-8   
 [6] LC_MESSAGES=en_GB.UTF-8    LC_PAPER=en_GB.UTF-8       LC_NAME=C                  LC_ADDRESS=C               LC_TELEPHONE=C            

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] xrf_0.2.2          yardstick_1.1.0    workflowsets_1.0.0 workflows_1.1.0    tune_1.0.1         tidyr_1.2.1        tibble_3.1.8       rsample_1.1.0     
 [9] recipes_1.0.1      purrr_1.0.0        parsnip_1.0.2      modeldata_1.0.1    infer_1.0.3        ggplot2_3.4.0      dplyr_1.0.10       dials_1.0.0       
[17] scales_1.2.1       broom_1.0.1        tidymodels_1.0.0  

loaded via a namespace (and not attached):
 [1] nlme_3.1-160        matrixStats_0.63.0  lubridate_1.8.0     DiceDesign_1.9      RColorBrewer_1.1-3  tools_4.1.1         backports_1.4.1     utf8_1.2.2         
 [9] R6_2.5.1            rpart_4.1.16        rules_1.0.0         DBI_1.1.3           mgcv_1.8-40         colorspace_2.0-3    nnet_7.3-18         withr_2.5.0        
[17] tidyselect_1.2.0    gridExtra_2.3       compiler_4.1.1      glmnet_4.1-2        cli_3.5.0           labeling_0.4.2      fuzzyjoin_0.1.6     stringr_1.5.0      
[25] digest_0.6.29       pkgconfig_2.0.3     parallelly_1.32.1   lhs_1.1.5           rlang_1.0.6         rstudioapi_0.14     farver_2.1.1        shape_1.4.6        
[33] generics_0.1.3      jsonlite_1.8.2      BiocParallel_1.28.3 magrittr_2.0.3      Matrix_1.5-1        Rcpp_1.0.9          munsell_0.5.0       fansi_1.0.3        
[41] GPfit_1.0-8         lifecycle_1.0.3     furrr_0.3.1         stringi_1.7.8       MASS_7.3-58.1       plyr_1.8.8          grid_4.1.1          parallel_4.1.1     
[49] listenv_0.8.0       ggrepel_0.9.2       crayon_1.5.2        lattice_0.20-45     splines_4.1.1       pillar_1.8.1        igraph_1.3.5        xgboost_1.4.1.1    
[57] corpcor_1.6.10      future.apply_1.9.1  reshape2_1.4.4      codetools_0.2-18    mixOmics_6.23.3     glue_1.6.2          data.table_1.14.2   vctrs_0.5.1        
[65] foreach_1.5.2       gtable_0.3.1        future_1.28.0       assertthat_0.2.1    gower_1.0.0         prodlim_2019.11.13  RSpectra_0.16-1     class_7.3-20       
[73] survival_3.4-0      timeDate_4021.106   rARPACK_0.11-0      iterators_1.0.14    plsmod_1.0.0        hardhat_1.2.0       ellipse_0.4.3       lava_1.6.10        
[81] globals_0.16.1      ellipsis_0.3.2      ipred_0.9-13

This appears to have been a bug in {rules}. Thanks for finding it! It is being fixed in pull request #71 on tidymodels/rules ("Bug family quosure" by EmilHvitfeldt), and you should be able to use the fix by installing the development version from GitHub:

```r
# install.packages("devtools")
devtools::install_github("tidymodels/rules")
```

I will try to get an update of this package on CRAN shortly.
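For anyone wondering why the error mentions `match()`: `match(x, table, nomatch = 0L)` is the body of base R's `%in%`, and `match()` only accepts vectors. Going by the PR title, the `family` engine argument apparently reached the prediction code as a captured quosure rather than the plain string `"gaussian"`. The sketch below reproduces the same error message with a hypothetical `family_arg` object and shows that evaluating the quosure first avoids it. It is an illustration of the error only, not the actual code inside {rules}.

```r
library(rlang)

# Hypothetical stand-in for an engine argument captured as a quosure
family_arg <- quo("gaussian")

# %in% calls match(x, table, nomatch = 0L); a quosure is a call, not a vector,
# so match() stops with "'match' requires vector arguments"
res <- tryCatch(
  family_arg %in% c("gaussian", "binomial"),
  error = function(e) conditionMessage(e)
)
print(res)

# Evaluating the quosure first gives back the character string, which matches
ok <- eval_tidy(family_arg) %in% c("gaussian", "binomial")
print(ok)
```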

Secondly, you can use `augment()` instead of `predict()` + `bind_cols()`, so your last couple of lines would be:

```r
predictions_train_rule <- augment(model_rule, data.mtcars.train)
```
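To check that `augment()` returns the same predictions as `predict()` + `bind_cols()`, here is a minimal sketch. As an assumption for illustration, it swaps in a plain `linear_reg()` workflow fitted on all of `mtcars` (so it runs without the {rules} fix); the object names `lm_wf`, `lm_fit`, `manual` and `augmented` are hypothetical.

```r
library(tidymodels)

# A plain linear regression stands in for the rule_fit model here
lm_wf <- workflow() %>%
  add_formula(mpg ~ .) %>%
  add_model(linear_reg())

lm_fit <- fit(lm_wf, data = mtcars)

# predict() + bind_cols(), as in the original question
manual <- predict(lm_fit, mtcars) %>%
  bind_cols(mtcars)

# augment() adds the .pred column to the supplied data in one step
augmented <- augment(lm_fit, mtcars)

# The predictions should agree either way
print(all.equal(manual$.pred, augmented$.pred))

# The augmented data can be passed straight to a yardstick metric
rmse(augmented, truth = mpg, estimate = .pred)
```

Because `augment()` keeps the outcome column next to `.pred`, the result plugs directly into yardstick metrics such as `rmse()` without a separate join.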

brilliant, thanks very much!
