Параметр "stratify" из метода "train_test_split" (scikit Learn)


94

Я пытаюсь использовать train_test_splitпакет scikit Learn, но у меня проблемы с параметром stratify. Ниже приведен код:

from sklearn import cross_validation, datasets 

X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

Однако у меня все еще возникает следующая проблема:

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

Есть ли у кого-нибудь представление о том, что происходит? Ниже представлена ​​документация по функциям.

[...]

stratify : array-like или None (по умолчанию None)

Если не None, данные разделяются стратифицированным образом, используя это как массив меток.

Новое в версии 0.17: расслоение разбиения

[...]


Нет, все решено.
Дэниел Оливо

Ответы:


58

Scikit-Learn просто говорит вам, что не распознает аргумент «расслоение», а не то, что вы его неправильно используете. Это связано с тем, что параметр был добавлен в версии 0.17, как указано в процитированной вами документации.

Так что вам просто нужно обновить Scikit-Learn.


У меня та же ошибка, хотя у меня scikit-learn версии 0.21.2. scikit-learn 0.21.2 py37h2a6a0b8_0 conda-forge
Карим Джеруди

326

Этот stratifyпараметр выполняет разделение, так что пропорция значений в произведенной выборке будет такой же, как пропорция значений, предоставленных параметру stratify.

Например, если переменная yявляется бинарной категориальная переменная со значениями 0и 1и есть 25% нулей и 75% из них, stratify=yубедитесь , что ваш случайный раскол имеет 25% 0«s и 75% 1» с.


117
На самом деле это не отвечает на вопрос, но очень полезно просто для понимания того, как это работает. Благодаря тонну.
Рид Джессен

6
Мне все еще трудно понять, зачем нужна эта стратификация: если в данных есть несбалансированный класс, не будет ли он сохранен в среднем при случайном разделении данных?
Хольгер Брандл

14
@HolgerBrandl будет сохраняться в среднем; со стратификацией он обязательно сохранится.
Йонатан

7
@HolgerBrandl с очень маленькими или очень несбалансированными наборами данных, вполне возможно, что случайное разбиение могло полностью исключить класс из одного из разбиений.
CDDT

1
@HolgerBrandl Хороший вопрос! Может быть, мы могли бы добавить это сначала, вам нужно разделить на тренировочный и тестовый набор, используя stratify. Затем, во-вторых, для исправления дисбаланса вам в конечном итоге потребуется выполнить передискретизацию или недостаточную выборку на обучающем наборе. Многие классификаторы Sklearn имеют параметр, называемый весовым коэффициентом, который вы можете установить как сбалансированный. Наконец, для несбалансированного набора данных вы также можете выбрать более подходящий показатель, чем точность. Попробуйте, F1 или область под ROC.
Claude COULOMBE

62

Для моего будущего себя, пришедшего сюда через Google:

train_test_splitсейчас в model_selectionигре, следовательно:

from sklearn.model_selection import train_test_split

# given:
# features: xs
# ground truth: ys

x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                    test_size=0.33,
                                                    random_state=0,
                                                    stratify=ys)

это способ его использовать. Установка random_stateжелательна для воспроизводимости.


Это должен быть ответ :) Спасибо
SwimBikeRun

15

В этом контексте стратификация означает, что метод train_test_split возвращает обучающие и тестовые подмножества, которые имеют те же пропорции меток классов, что и входной набор данных.


3

Попробуйте запустить этот код, он «просто работает»:

from sklearn import cross_validation, datasets 

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)

y_test

array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
       1, 2, 1, 1, 0, 2, 1])

@ user5767535 Как вы могли заметить, он работает на моей машине Ubuntu с sklearnверсией '0.17', дистрибутив Anaconda для Python 3,5. Я могу только предложить проверить еще раз, если вы правильно ввели код и обновили программное обеспечение.
Сергей Бушманов

2
@ user5767535 Кстати, «Новое в версии 0.17: разделение на расслоение» дает мне почти уверенность в том, что вам нужно обновить свой sklearn...
Сергей Бушманов
Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.