library(tidymodels)

set.seed(234)
doParallel::registerDoParallel()
res <- fit_resamples(
  wf,
  folds,
  control = control_resamples(extract = extract_fit_engine)
)
res %>%
  select(id, .extracts) %>%
  unnest(.extracts)
This works with lm and glm, but not with xgboost, multinom_reg ...
Is there a way to extract the coefficients?
In xgboost there aren't really coefficients to extract...
Which computational engine are you using for the multinomial regression? I can't imagine coefficients not being available for that, though you have to do a little extra work to get them if you're using glmnet.
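For context, the default engine for `multinom_reg()` is nnet, and there `coef()` on the engine fit already returns the coefficients. A minimal sketch, assuming you fit `nnet::multinom()` directly on `iris` rather than through a workflow (the reshaping step is illustrative, not a tidymodels function):

```r
library(nnet)  # recommended package that ships with standard R installs

# Same kind of model that multinom_reg(engine = "nnet") wraps
fit <- multinom(Species ~ ., data = iris, trace = FALSE)

# coef() gives a matrix: one row per non-reference class
# (setosa, the first factor level, is the baseline), one column per term
cf <- coef(fit)

# Reshape to one row per term and one column per class
coefs <- data.frame(term = colnames(cf), t(cf), row.names = NULL)
coefs
```

Inside `fit_resamples()` you would wrap these steps in a function of the workflow fit (calling `extract_fit_engine()` first) and pass it as the `extract` argument.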
For example, if you are using glmnet for multinomial regression, you might first want to construct a function to extract the coefficients:
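library(tidymodels)

# Return the glmnet coefficients at `penalty` as a tibble:
# one row per term, one column per outcome class
extract_glmnet_coefs <- function(x, penalty) {
  mod <- extract_fit_engine(x)            # underlying glmnet fit
  classes <- mod$classnames
  out <-
    mod %>%
    glmnet::coef.glmnet(s = penalty) %>%  # list: one sparse matrix per class
    lapply(as.matrix) %>%                 # densify each matrix
    Reduce(f = "cbind", .) %>%            # bind classes side by side
    as.data.frame() %>%
    set_names(classes) %>%
    as_tibble(rownames = "term")
  out
}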
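The middle of `extract_glmnet_coefs()` does the reshaping: for a multinomial fit, glmnet returns one sparse coefficient matrix per class, and those get densified and bound side by side. A small base-R sketch of just those steps on toy input, so each one is visible. The terms, class names (`a`, `b`) and numbers are made up, and the sparse matrices are built with the Matrix package, which glmnet uses for its coefficients:

```r
library(Matrix)  # recommended package; provides the sparse dgCMatrix class

# Toy stand-in for coef() on a multinomial glmnet fit at one penalty value:
# a named list with one single-column sparse matrix per class
terms <- c("(Intercept)", "x1", "x2")
to_sparse <- function(v) {
  Matrix(matrix(v, ncol = 1, dimnames = list(terms, "s1")), sparse = TRUE)
}
coef_list <- list(a = to_sparse(c(1, 0, 2)), b = to_sparse(c(-1, 3, 0)))

# Same steps as extract_glmnet_coefs(): densify each matrix,
# bind them column-wise, then label the columns by class
wide <- Reduce(f = "cbind", lapply(coef_list, as.matrix))
wide <- as.data.frame(wide)
names(wide) <- names(coef_list)
wide <- data.frame(term = rownames(wide), wide, row.names = NULL)
wide
```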
Then the rest should be fairly close to what you were already doing:
# Set up workflow
set.seed(1)
folds <- vfold_cv(iris, v = 3)
lambda <- 0.01
wf <- workflow(
  preprocessor = recipe(iris, Species ~ .),
  spec = multinom_reg(engine = "glmnet", penalty = lambda, mixture = 0)
)

# Fit resamples with custom function
res <- fit_resamples(
  object = wf,
  resamples = folds,
  control = control_resamples(
    extract = function(x) extract_glmnet_coefs(x, penalty = lambda)
  )
)

# Extract the coefficients
res %>%
  unnest(.extracts) %>%   # one row per fold; .extracts is still a list-column
  select(id, .extracts) %>%
  unnest(.extracts)       # the coefficient tibbles returned by the extractor
# A tibble: 15 x 5
   id    term         setosa versicolor virginica
   <chr> <chr>         <dbl>      <dbl>     <dbl>
 1 Fold1 (Intercept)   3.12     2.93       -6.05
 2 Fold1 Sepal.Length -0.733    0.320       0.412
 3 Fold1 Sepal.Width   1.39    -1.19       -0.208
 4 Fold1 Petal.Length -0.543    0.00604     0.537
 5 Fold1 Petal.Width  -1.15    -0.422       1.57
 6 Fold2 (Intercept)   2.68     2.65       -5.33
 7 Fold2 Sepal.Length -0.759    0.213       0.546
 8 Fold2 Sepal.Width   1.64    -0.814      -0.830
 9 Fold2 Petal.Length -0.550    0.0151      0.535
10 Fold2 Petal.Width  -1.20    -0.392       1.59
11 Fold3 (Intercept)   3.88     3.16       -7.04
12 Fold3 Sepal.Length -0.813    0.221       0.592
13 Fold3 Sepal.Width   1.35    -0.992      -0.358
14 Fold3 Petal.Length -0.569    0.0244      0.545
15 Fold3 Petal.Width  -1.24    -0.417       1.66
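Why two `unnest()` calls: each `.extracts` cell of a `fit_resamples()` result is itself a tibble, holding the extractor's return value in its own `.extracts` list-column next to a `.config` label. A sketch on a toy object that only mimics that nesting (assumptions: just the `id` and `.extracts` columns are reproduced, the coefficient numbers are made up, and the `.config` value is a placeholder):

```r
library(tibble)
library(tidyr)
library(dplyr)

# Hypothetical stand-in for what an extract function returns for one fold
toy_coefs <- function(i) tibble(term = c("(Intercept)", "x"), estimate = c(i, -i))

# Toy object with the same nesting as the resampling result
res_mock <- tibble(
  id = c("Fold1", "Fold2"),
  .extracts = lapply(1:2, function(i) {
    tibble(.extracts = list(toy_coefs(i)), .config = "placeholder")
  })
)

coefs_long <- res_mock %>%
  unnest(.extracts) %>%   # level 1: id, .extracts (still a list), .config
  select(id, .extracts) %>%
  unnest(.extracts)       # level 2: the tibbles returned by the extractor
coefs_long
```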
Everything works fine with glm, lm and poisson_reg, but unnest(.extracts) doesn't work with other models. When I put the extract function in control_resamples() or control_grid() and run
res %>%
  unnest(.extracts) %>%
  select(id, .extracts) %>%
  unnest(.extracts)
the first unnest(.extracts) shows a tibble that still needs to be extracted, but I can't get at its contents.
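One likely reason the second `unnest()` stalls: `tidyr::unnest()` can only flatten list-columns whose cells are vectors or data frames, so if the extract function hands back a raw model object there is nothing to spread into columns. A hedged sketch of an extractor that always returns a data frame; `coef_frame()` is a hypothetical name (not part of tidymodels or broom), and it assumes `coef()` works on the engine fit whenever coefficients exist:

```r
# Hypothetical helper: always return a plain data frame so that
# a later unnest() has something it can flatten
coef_frame <- function(fit) {
  # coef() may error or return NULL for models without coefficients
  cf <- tryCatch(coef(fit), error = function(e) NULL)
  if (is.null(cf)) {
    # Still a (zero-row) data frame, so unnest() does not choke
    return(data.frame(term = character(), estimate = numeric()))
  }
  if (is.matrix(cf)) {
    # Multiclass fits (e.g. nnet::multinom): rows = classes, cols = terms.
    # as.vector() reads column by column, so repeat the labels to match.
    return(data.frame(
      class = rep(rownames(cf), times = ncol(cf)),
      term = rep(colnames(cf), each = nrow(cf)),
      estimate = as.vector(cf)
    ))
  }
  data.frame(term = names(cf), estimate = unname(cf))
}

# Works on ordinary stats fits
coef_frame(lm(mpg ~ wt, data = mtcars))
```

Passing `extract = function(x) coef_frame(extract_fit_engine(x))` to `control_resamples()` should then let the same double `unnest()` run for lm, glm and nnet-based fits; for tree ensembles such as xgboost you would most likely get empty frames, since there are no coefficients to report.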
I think the structure that tidy() or unnest() expects does not fit every model. I wish it did, because then you
[quote="ttrodrigz, post:4, topic:132531, full:true"]
For example, if you are using glmnet for multinomial regression you might first want to construct a function to extract the coefficients
[/quote]