I am using library(partykit) to sample from nodes from an rpart model created with library(tidymodels). I need to use repair_call() to use as.party() with output from parsnip::fit(). This breaks when I move the code inside of a custom function:
library(tidymodels)
library(partykit)
cart_model <- parsnip::decision_tree() %>%
  parsnip::set_engine("rpart") %>%
  parsnip::set_mode("regression")
parsnip_model <- fit(cart_model, mpg ~ ., data = mtcars)
predict_sample_rpart <- function(object, old_data, new_data) {
  repaired_model <- repair_call(object, data = old_data)
  node_ecdf <- predict(as.party(repaired_model$fit), newdata = new_data, type = "prob")
  sample(environment(node_ecdf[["1"]])[["x"]], 1)
}
predict_sample_rpart(parsnip_model, old_data = mtcars, new_data = mtcars)
#> Error in is.data.frame(data) : object 'old_data' not found
repair_call() assigns the data in repaired_model as old_data instead of mtcars, and then predict() does not work.
Any help fixing this would be greatly appreciated.
Any suggestions for better ways to sample from conditional distributions created by lm(), rpart(), or ranger() would be doubly appreciated.
Max
October 7, 2020, 7:28pm
This shows the issue with using the call object in computations. Code that re-evaluates the call assumes that the data used to create the model can be found, under the same name, in the scope/environment where the call is evaluated.
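To make the scoping problem concrete, here is a minimal base R sketch. It uses lm() as a stand-in for the rpart fit (an assumption for illustration, not the code from this thread), the helper name fit_inside is hypothetical, and it assumes no object called old_data exists in the global environment.

```r
# Base R only; lm() stands in for the rpart fit (illustrative assumption).
# `fit_inside` is a hypothetical helper, not part of parsnip or partykit.
fit_inside <- function(old_data) {
  lm(mpg ~ wt, data = old_data)
}

fit <- fit_inside(mtcars)

# The stored call refers to the data by name, not by value.
data_ref <- fit$call$data
is.symbol(data_ref)
deparse(data_ref)

# Re-evaluating the call in a scope where `old_data` does not exist fails.
# (Assumes there is no `old_data` in the global environment.)
result <- tryCatch(
  eval(fit$call, envir = globalenv()),
  error = function(e) conditionMessage(e)
)
result
```

The model object only remembers the name old_data, so anything that later re-evaluates the call depends on that name being visible from wherever the evaluation happens.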
Inside of your function, it can't find the right reference for the data to put inside of the call. Outside of the function it works fine (I think you meant type = "node"):
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom 0.7.0 ✓ recipes 0.1.13
#> ✓ dials 0.0.9 ✓ rsample 0.0.8
#> ✓ dplyr 1.0.2 ✓ tibble 3.0.3
#> ✓ ggplot2 3.3.2 ✓ tidyr 1.1.2
#> ✓ infer 0.5.2 ✓ tune 0.1.1
#> ✓ modeldata 0.0.2 ✓ workflows 0.2.0
#> ✓ parsnip 0.1.3 ✓ yardstick 0.0.7
#> ✓ purrr 0.3.4
#> ── Conflicts ───────────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
library(partykit)
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm
cart_model <- parsnip::decision_tree() %>%
  parsnip::set_engine("rpart") %>%
  parsnip::set_mode("regression")
parsnip_model <- fit(cart_model, mpg ~ ., data = mtcars)
repaired_model <- repair_call(parsnip_model, data = mtcars)
node_ecdf <- predict(as.party(repaired_model$fit), newdata = head(mtcars), type = "node")
node_ecdf
#>         Mazda RX4     Mazda RX4 Wag        Datsun 710    Hornet 4 Drive
#>                 4                 4                 5                 4
#> Hornet Sportabout           Valiant
#>                 4                 4
Created on 2020-10-07 by the reprex package (v0.3.0)
I'm not sure what the solution is for using it inside of a function, so I'd try not to use it that way. That's probably unsatisfying, but it is a problem baked into how they use the call object.
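One workaround sometimes used for this class of problem is to store the data frame itself in the call, rather than the name that refers to it. The sketch below shows the idea with lm() in base R only; the helper name fit_with_data is hypothetical. Whether the same trick (for example, object$fit$call$data <- old_data inside predict_sample_rpart()) is enough for as.party() is an untested assumption.

```r
# Base R only; lm() stands in for the rpart fit (illustrative assumption).
# `fit_with_data` is a hypothetical helper, not part of parsnip or partykit.
fit_with_data <- function(old_data) {
  fit <- lm(mpg ~ wt, data = old_data)
  # Replace the symbol `old_data` in the stored call with the data frame
  # itself, so the call no longer depends on any variable name being in scope.
  fit$call$data <- old_data
  fit
}

fit <- fit_with_data(mtcars)

# The call now carries the data, so it can be re-evaluated from anywhere.
refit <- eval(fit$call, envir = globalenv())
same_coefs <- isTRUE(all.equal(coef(fit), coef(refit)))
same_coefs
```

The trade-off is that the call is no longer compact: printing the model deparses the whole data frame into the Call: line.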