Среднее значение ROC для повторной 10-кратной перекрестной проверки с оценками вероятности


15

Я планирую использовать повторную (10 раз) стратифицированную 10-кратную перекрестную проверку примерно в 10 000 случаев с использованием алгоритма машинного обучения. Каждый раз повторение будет сделано с разными случайными семенами.

В этом процессе я создаю 10 экземпляров вероятностных оценок для каждого случая. 1 случай оценки вероятности для каждого из 10 повторений 10-кратной перекрестной проверки

Могу ли я усреднить 10 вероятностей для каждого случая, а затем создать новую среднюю кривую ROC (представляющую результаты повторного 10-кратного CV), которую можно сравнить с другими кривыми ROC путем парных сравнений?

Ответы:


13

Из вашего описания, кажется, есть смысл: не только вы можете рассчитать среднюю ROC-кривую, но и дисперсию вокруг нее, чтобы построить доверительные интервалы. Это должно дать вам представление о том, насколько стабильна ваша модель.

Например, вот так:

введите описание изображения здесь

Здесь я поместил отдельные кривые ROC, а также среднюю кривую и доверительные интервалы. Есть области, где кривые совпадают, поэтому у нас меньше различий, и есть области, где они не согласны.

Для повторного резюме вы можете просто повторить его несколько раз и получить общее среднее значение по всем отдельным сгибам:

введите описание изображения здесь

Это очень похоже на предыдущую картину, но дает более стабильные (то есть надежные) оценки среднего значения и дисперсии.

Вот код, чтобы получить сюжет:

import matplotlib.pyplot as plt
import numpy as np
from scipy import interp

from sklearn.datasets import make_classification
from sklearn.cross_validation import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve

X, y = make_classification(n_samples=500, random_state=100, flip_y=0.3)

kf = KFold(n=len(y), n_folds=10)

tprs = []
base_fpr = np.linspace(0, 1, 101)

plt.figure(figsize=(5, 5))

for i, (train, test) in enumerate(kf):
    model = LogisticRegression().fit(X[train], y[train])
    y_score = model.predict_proba(X[test])
    fpr, tpr, _ = roc_curve(y[test], y_score[:, 1])

    plt.plot(fpr, tpr, 'b', alpha=0.15)
    tpr = interp(base_fpr, fpr, tpr)
    tpr[0] = 0.0
    tprs.append(tpr)

tprs = np.array(tprs)
mean_tprs = tprs.mean(axis=0)
std = tprs.std(axis=0)

tprs_upper = np.minimum(mean_tprs + std, 1)
tprs_lower = mean_tprs - std


plt.plot(base_fpr, mean_tprs, 'b')
plt.fill_between(base_fpr, tprs_lower, tprs_upper, color='grey', alpha=0.3)

plt.plot([0, 1], [0, 1],'r--')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.axes().set_aspect('equal', 'datalim')
plt.show()

Для повторного резюме:

idx = np.arange(0, len(y))

for j in np.random.randint(0, high=10000, size=10):
    np.random.shuffle(idx)
    kf = KFold(n=len(y), n_folds=10, random_state=j)

    for i, (train, test) in enumerate(kf):
        model = LogisticRegression().fit(X[idx][train], y[idx][train])
        y_score = model.predict_proba(X[idx][test])
        fpr, tpr, _ = roc_curve(y[idx][test], y_score[:, 1])

        plt.plot(fpr, tpr, 'b', alpha=0.05)
        tpr = interp(base_fpr, fpr, tpr)
        tpr[0] = 0.0
        tprs.append(tpr)

Источник вдохновения: http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html


3

Неправильно усреднять вероятности, потому что это не будет представлять прогнозы, которые вы пытаетесь проверить, и включает загрязнение между проверочными образцами.

Обратите внимание, что для достижения достаточной точности может потребоваться 100 повторов 10-кратной перекрестной проверки. Или используйте загрузчик оптимизма Эфрона-Гонга, который требует меньше итераций для той же точности (см., Например rms, validateфункции пакета R ).

с


Не могли бы вы подробнее рассказать, почему усреднение не является правильным?
DataD'oh

Уже заявлено. Вам необходимо подтвердить меру, которую вы будете использовать в поле.
Фрэнк Харрелл
Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.