I am trying to use the "caret" library to train a multiclass classification algorithm.
I found the following article, which supposedly contains a function that can be used for multiclass classification problems: "Error metrics for multi-class problems in R: beyond Accuracy and Kappa" (R-bloggers).
require(compiler)
#' Multi-class summary function for use as caret's `summaryFunction`
#'
#' Computes overall confusion-matrix statistics plus class-averaged
#' (unweighted) per-class statistics, including a one-vs-all AUC ("ROC")
#' and a one-vs-all log loss ("logLoss") for every class.
#'
#' @param data A data frame with factor columns `obs` and `pred` and one
#'   numeric probability column per class level, named after that level.
#' @param lev Unused; kept for the summary-function signature.
#' @param model Unused; kept for the summary-function signature.
#' @return A named numeric vector. "AccuracyNull", "Prevalence" and
#'   "Detection Prevalence" are dropped, and blanks in the remaining names
#'   are replaced by underscores.
multiClassSummary <- cmpfun(function(data, lev = NULL, model = NULL) {
  obs <- data[, "obs"]
  pred <- data[, "pred"]

  # Check data. Non-factor columns have NULL levels, which would slip through
  # a plain level comparison and fail later with a cryptic error.
  if (!is.factor(obs) || !is.factor(pred)) {
    stop("'obs' and 'pred' must both be factors", call. = FALSE)
  }
  classes <- levels(pred)
  # identical() compares length and content; `==` would recycle the shorter
  # vector when the number of levels differs.
  if (!identical(classes, levels(obs))) {
    stop("levels of observed and predicted data do not match")
  }
  missing_cols <- setdiff(classes, colnames(data))
  if (length(missing_cols) > 0) {
    stop(
      "no class probability column(s) for: ",
      paste(missing_cols, collapse = ", "),
      call. = FALSE
    )
  }

  # Fail loudly when a dependency is missing; require() would only return
  # FALSE and let the function break further down.
  for (pkg in c("Metrics", "caret")) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop(
        "package '", pkg, "' is required by multiClassSummary()",
        call. = FALSE
      )
    }
  }

  # Calculate custom one-vs-all stats for each class
  prob_stats <- lapply(classes, function(class) {
    # One-vs-all indicator of the observed class and its predicted probability
    obs_bin <- ifelse(obs == class, 1, 0)
    prob <- data[, class]
    # Cap probabilities away from 0 and 1 before computing the log loss
    cap_prob <- pmin(pmax(prob, .000001), .999999)
    # Namespace-qualified so that another attached package exporting `auc`
    # or `logLoss` cannot mask the intended functions.
    c(
      ROC = Metrics::auc(obs_bin, prob),
      logLoss = Metrics::logLoss(obs_bin, cap_prob)
    )
  })
  prob_stats <- do.call(rbind, prob_stats)
  rownames(prob_stats) <- paste("Class:", classes)

  # Calculate confusion matrix-based statistics
  cm <- caret::confusionMatrix(pred, obs)

  # Aggregate and average class-wise stats
  # Todo: add weights
  by_class <- cm$byClass
  if (!is.matrix(by_class)) {
    # Two-class outcomes yield a plain named vector; treat it as a single row
    # so the averaging below works the same way.
    by_class <- matrix(
      by_class,
      nrow = 1L,
      dimnames = list(NULL, names(by_class))
    )
  }
  class_stats <- c(colMeans(by_class), colMeans(prob_stats))

  # Combine overall with class-wise stats and remove some stats we don't want
  stats <- c(cm$overall, class_stats)
  unwanted <- c("AccuracyNull", "Prevalence", "Detection Prevalence")
  stats <- stats[!names(stats) %in% unwanted]

  # Clean names and return
  names(stats) <- gsub("[[:blank:]]+", "_", names(stats))
  stats
})
I would like to use this function to train a "Decision Tree" model using the "F1 Score" (or any metric suitable for a multiclass problem). For example:
library(caret)
library(plyr)
library(C50)
library(dplyr)
library(compiler)
# Repeated 10-fold cross-validation (3 repeats), scored with the custom
# multi-class summary. classProbs = TRUE is requested because
# multiClassSummary reads one probability column per class.
train.control <- trainControl(
  method = "repeatedcv",
  number = 10,
  repeats = 3,
  summaryFunction = multiClassSummary,
  classProbs = TRUE
)

# Name the outcome by its column (`response`), not as `my_data$response`:
# with the `$` form, `.` still expands to every column of `my_data`,
# including `response` itself, so the outcome would leak into the predictors.
train_model <- train(
  response ~ .,
  data = my_data,
  method = "C5.0",
  trControl = train.control,
  preProcess = c("center", "scale"),
  tuneLength = 15,
  metric = "F1"
)
But this produces the following error:
Error in ctrl$summaryFunction(testOutput, lev, method):
Your outcome has 4 levels. The prSummary() function isn't appropriate.
Can someone please show me how to fix this error? Thanks!