С помощью пакета каретки можно ли получить матрицы путаницы для конкретных пороговых значений?


13

Я получил модель логистической регрессии (через train) для бинарного ответа, и я получил логистическую матрицу спутанности через confusionMatrixв caret. Это дает мне путаницу в логистической модели, хотя я не уверен, какой порог используется для ее получения. Как получить матрицу путаницы для определенных пороговых значений, используя confusionMatrixin caret?


У меня нет ответа, но часто такие вопросы даются в файле справки. Если это не помогло, вы можете посмотреть на исходный код. Вы можете распечатать исходный текст в консоли, набрав confusionmatrixбез скобок.
борец с тенью

Не совсем понятно, что именно вы сделали. Вы вызывали glmфункцию из statsпакета и передавали ее результат confusionMatrix? Я не знал, что это можно сделать, и, читая руководство, совсем не понятно. Или ты predictчто-то сделал ? Краткий пример поможет.
Calimo

1
@Calimo Я использовал trainфункцию, caretчтобы соответствовать модели, которая позволяет мне указать ее как glm с биномиальным семейством. Затем я использовал predictфункцию для объекта, созданного с помощью train.
Черное молоко

Ответы:


11

Большинство классификационных моделей в R дают как прогнозирование класса, так и вероятности для каждого класса. Для двоичных данных почти в каждом случае прогнозирование класса основано на 50% -ной вероятности отсечения.

glmта же. С caretпомощью использование predict(object, newdata)дает вам прогнозируемый класс и predict(object, new data, type = "prob")даст вам специфичные для класса вероятности (когда objectгенерируется train).

Вы можете делать вещи по-другому, определяя свою собственную модель и применяя любую обрезку, которую вы хотите. На caret сайте также есть пример, который использует повторную выборку для оптимизации вероятности отсечки.

ТЛ; др

confusionMatrix использует предсказанные классы и, таким образом, 50% вероятности отсечки

Максимум


14

Есть довольно простой способ, предполагая tune <- train(...):

probsTest <- predict(tune, test, type = "prob")
threshold <- 0.5
pred      <- factor( ifelse(probsTest[, "yes"] > threshold, "yes", "no") )
pred      <- relevel(pred, "yes")   # you may or may not need this; I did
confusionMatrix(pred, test$response)

Очевидно, что вы можете установить пороговое значение для того, что хотите попробовать, или выбрать «лучший», где лучший означает максимальную комбинированную специфичность и чувствительность:

library(pROC)
probsTrain <- predict(tune, train, type = "prob")
rocCurve   <- roc(response = train$response,
                      predictor = probsTrain[, "yes"],
                      levels = rev(levels(train$response)))
plot(rocCurve, print.thres = "best")

Посмотрев на пример, который выложил Макс, я не уверен, есть ли какие-то статистические нюансы, которые делают мой подход менее желательным.


Что означают эти три значения в выводимом графике rocCurve? например, по моим данным это говорит 0,289 (0,853, 0,831). Означает ли 0,289 лучший порог, который следует использовать при разграничении двоичного результата? то есть каждый случай с прогнозируемой вероятностью> 0,289 будет кодироваться «1», а каждый случай с прогнозируемой вероятностью <0,289 будет кодироваться «0», а не порогом 0,5 по умолчанию для caretпакета?
COIP

2
да, это совершенно верно, а два других значения в скобках - это чувствительность и специфичность (правда, я забыл, что есть что)
efh0888

2
Кроме того, с тех пор я понял, что вы можете извлечь его из кривой roc, используя, rocCurve$thresholds[which(rocCurve$sensitivities + rocCurve$specificities == max(rocCurve$sensitivities + rocCurve$specificities))]что также дает вам возможность по-разному взвешивать их, если вы хотите ... Еще одна вещь, на которую стоит обратить внимание, это то, что реалистично, вы, вероятно, хотите настроить порог (например, вы бы с любой моделью гиперпараметр), как Макс описывает здесь .
efh0888
Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.