I have a question regarding performance metrics for a multi-class model. Unless I'm missing something, it seems like there isn't any support in yardstick for within-class (per-class) metrics in the multi-class case.
Here's some data to motivate this:
library(dplyr)
library(tidyr)

# cell counts for a 3-class confusion matrix
counts <- tribble(
  ~truth, ~pred, ~n,
  "A", "A", 20,
  "A", "B", 3,
  "A", "C", 0,
  "B", "A", 4,
  "B", "B", 23,
  "B", "C", 6,
  "C", "A", 8,
  "C", "B", 5,
  "C", "C", 29
)

# expand the counts into one row per observation
pred_df <-
  counts %>%
  uncount(n) %>%
  mutate_all(as.factor)
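As a quick sanity check, counting the expanded rows recovers the original tallies (minus the zero-count truth A / pred C cell, which uncount() drops):

pred_df %>% count(truth, pred)
#> # A tibble: 8 x 3
#>   truth pred      n
#>   <fct> <fct> <int>
#> 1 A     A        20
#> 2 A     B         3
#> 3 B     A         4
#> 4 B     B        23
#> 5 B     C         6
#> 6 C     A         8
#> 7 C     B         5
#> 8 C     C        29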
Here's the confusion matrix for that data (yes, I realize that yardstick::conf_mat() exists; see below):
cm <- with(pred_df, table(Prediction = pred, Truth = truth))
cm
#>           Truth
#> Prediction  A  B  C
#>          A 20  4  8
#>          B  3 23  5
#>          C  0  6 29
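For completeness, yardstick::conf_mat() produces the same table straight from the predictions:

yardstick::conf_mat(pred_df, truth = truth, estimate = pred)
#>           Truth
#> Prediction  A  B  C
#>          A 20  4  8
#>          B  3 23  5
#>          C  0  6 29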
We can obtain the relevant micro/macro/macro-weighted aggregate metrics with something like this:
library(yardstick)
my_metrics <- metric_set(precision, recall)
my_metrics(pred_df, truth = truth, estimate = pred, estimator = "macro")
#> # A tibble: 2 x 3
#>   .metric   .estimator .estimate
#>   <chr>     <chr>          <dbl>
#> 1 precision macro          0.732
#> 2 recall    macro          0.752
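Swapping the estimator argument gives the other aggregations. For example, with estimator = "micro", precision and recall both collapse to overall accuracy (72 correct out of 98):

my_metrics(pred_df, truth = truth, estimate = pred, estimator = "micro")
#> # A tibble: 2 x 3
#>   .metric   .estimator .estimate
#>   <chr>     <chr>          <dbl>
#> 1 precision micro          0.735
#> 2 recall    micro          0.735

The same call with estimator = "macro_weighted" weights each class by its prevalence instead.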
Now, if I want to look at precision and recall at the class level rather than in aggregate, it seems like I need to step outside of the tidymodels workflow. That would be a bummer given how smooth and seamless the current workflow is, especially when incorporating resampling and gathering all of the metrics with tune::collect_metrics().
# recall (per truth class): diagonal of the column-wise proportions
cm %>% prop.table(margin = 2) %>% diag()
#>         A         B         C
#> 0.8695652 0.6969697 0.6904762

# precision (per predicted class): diagonal of the row-wise proportions
cm %>% prop.table(margin = 1) %>% diag()
#>         A         B         C
#> 0.6250000 0.7419355 0.8285714

(Averaging each trio recovers the macro numbers above: recall 0.752, precision 0.732.)
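The closest I've come to a tibble-native version is manually looping over the classes and scoring each one-vs-rest relabeling with the default binary estimator. This is just a sketch of my own (the per_class name and the "other" relabeling are mine, not a yardstick API):

library(purrr)

# sketch: per-class precision/recall via one-vs-rest relabeling
per_class <- map_dfr(levels(pred_df$truth), function(cls) {
  pred_df %>%
    mutate(
      # put the target class first so it is the yardstick event level
      truth_bin = factor(ifelse(truth == cls, cls, "other"), levels = c(cls, "other")),
      pred_bin  = factor(ifelse(pred == cls, cls, "other"), levels = c(cls, "other"))
    ) %>%
    my_metrics(truth = truth_bin, estimate = pred_bin) %>%
    mutate(class = cls)
})

Its .estimate column reproduces the diag() numbers above, but this still feels like working around the framework rather than with it.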
So, I guess my question is: does anyone know whether within-class metrics are currently supported, or whether support is planned? If not, does anyone have suggestions on how to incorporate them without having to break the tidymodels workflow?
P.S. Perhaps this post is better served as a feature request on GitHub; if so, I'll add something there.