Сохранить классификатор на диск в Scikit-Learn


192

Как сохранить обученный наивный байесовский классификатор на диск и использовать его для прогнозирования данных?

У меня есть следующий пример программы с сайта scikit-learn:

from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
y_pred = gnb.fit(iris.data, iris.target).predict(iris.data)
print "Number of mislabeled points : %d" % (iris.target != y_pred).sum()

Ответы:


201

Классификаторы - это просто объекты, которые можно мариновать и выбрасывать, как и любые другие. Чтобы продолжить ваш пример:

import cPickle
# save the classifier
with open('my_dumped_classifier.pkl', 'wb') as fid:
    cPickle.dump(gnb, fid)    

# load it again
with open('my_dumped_classifier.pkl', 'rb') as fid:
    gnb_loaded = cPickle.load(fid)

1
Работает как шарм! Я пытался использовать np.savez и загружать его обратно, и это никогда не помогало. Большое спасибо.
Kartos

7
в python3 используйте модуль pickle, который работает именно так.
MCSH

213

Вы также можете использовать joblib.dump и joblib.load, которые гораздо более эффективны при обработке числовых массивов, чем стандартный выборщик питона.

Joblib включен в scikit-learn:

>>> import joblib
>>> from sklearn.datasets import load_digits
>>> from sklearn.linear_model import SGDClassifier

>>> digits = load_digits()
>>> clf = SGDClassifier().fit(digits.data, digits.target)
>>> clf.score(digits.data, digits.target)  # evaluate training error
0.9526989426822482

>>> filename = '/tmp/digits_classifier.joblib.pkl'
>>> _ = joblib.dump(clf, filename, compress=9)

>>> clf2 = joblib.load(filename)
>>> clf2
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, learning_rate='optimal', loss='hinge', n_iter=5,
       n_jobs=1, penalty='l2', power_t=0.5, rho=0.85, seed=0,
       shuffle=False, verbose=0, warm_start=False)
>>> clf2.score(digits.data, digits.target)
0.9526989426822482

Редактировать: в Python 3.8+ теперь возможно использовать pickle для эффективного выбора объекта с большими числовыми массивами в качестве атрибутов, если вы используете протокол pickle 5 (который не используется по умолчанию).


1
Но из моего понимания конвейерная работа работает, если она является частью единого рабочего процесса. Если я хочу построить модель, сохраните ее на диске и остановите выполнение. Затем я возвращаюсь через неделю и пытаюсь загрузить модель с диска, он выдает ошибку:
venuktan

2
Нет способа остановить и возобновить выполнение fitметода, если это то, что вы ищете. Это, как говорится, joblib.loadне должно вызывать исключение после успешного завершения, joblib.dumpесли вы вызываете его из Python с той же версией библиотеки scikit-learn.
Огризель

10
Если вы используете IPython, не используйте --pylabфлаг командной строки или %pylabмагию, поскольку неявная перегрузка пространства имен, как известно, нарушает процесс выбора. %matplotlib inlineВместо этого используйте явный импорт и магию.
Огризель

2
см scikit-узнать документацию для справки: scikit-learn.org/stable/tutorial/basic/...
user1448319

1
Можно ли переучить ранее сохраненную модель? Конкретно модели SVC?
Uday Sawant

108

То, что вы ищете, называется сохранением модели в словах sklearn и задокументировано во введении и в разделах сохранения модели .

Итак, вы инициализировали свой классификатор и долгое время обучали его

clf = some.classifier()
clf.fit(X, y)

После этого у вас есть два варианта:

1) Используя Pickle

import pickle
# now you can save it to a file
with open('filename.pkl', 'wb') as f:
    pickle.dump(clf, f)

# and later you can load it
with open('filename.pkl', 'rb') as f:
    clf = pickle.load(f)

2) Использование Joblib

from sklearn.externals import joblib
# now you can save it to a file
joblib.dump(clf, 'filename.pkl') 
# and later you can load it
clf = joblib.load('filename.pkl')

Еще раз полезно прочитать вышеупомянутые ссылки


30

Во многих случаях, особенно с классификацией текста, недостаточно просто сохранить классификатор, но вам также необходимо сохранить векторизатор, чтобы вы могли векторизовать свой ввод в будущем.

import pickle
with open('model.pkl', 'wb') as fout:
  pickle.dump((vectorizer, clf), fout)

будущее использование:

with open('model.pkl', 'rb') as fin:
  vectorizer, clf = pickle.load(fin)

X_new = vectorizer.transform(new_samples)
X_new_preds = clf.predict(X_new)

Перед сбросом векторизатора можно удалить свойство векторизатора stop_words_:

vectorizer.stop_words_ = None

сделать сброс более эффективным. Также, если параметры вашего классификатора редки (как в большинстве примеров текстовой классификации), вы можете преобразовать параметры из плотного в разреженный, что будет иметь огромное значение с точки зрения потребления памяти, загрузки и выгрузки. Разрежьте модель по:

clf.sparsify()

Это будет автоматически работать для SGDClassifier, но если вы знаете, что ваша модель редкая (много нулей в clf.coef_), то вы можете вручную преобразовать clf.coef_ в скудную матрицу csr scipy :

clf.coef_ = scipy.sparse.csr_matrix(clf.coef_)

и тогда вы можете хранить его более эффективно.


Проницательный ответ! Просто хотел добавить в случае SVC, он возвращает разреженный параметр модели.
Шаян Амани

5

sklearnОценщики реализуют методы, чтобы вам было легко сохранять соответствующие обученные свойства оценщика. Некоторые оценщики __getstate__сами реализуют методы, а другие, например, GMMпросто используют базовую реализацию, которая просто сохраняет внутренний словарь объектов:

def __getstate__(self):
    try:
        state = super(BaseEstimator, self).__getstate__()
    except AttributeError:
        state = self.__dict__.copy()

    if type(self).__module__.startswith('sklearn.'):
        return dict(state.items(), _sklearn_version=__version__)
    else:
        return state

Рекомендуемый метод сохранения вашей модели на диск - использовать pickleмодуль:

from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]
model = SVC()
model.fit(X,y)
import pickle
with open('mymodel','wb') as f:
    pickle.dump(model,f)

Тем не менее, вам следует сохранить дополнительные данные, чтобы вы могли в будущем переучить свою модель или столкнуться с тяжелыми последствиями (такими как блокировка в старой версии sklearn) .

Из документации :

Чтобы перестроить аналогичную модель с будущими версиями scikit-learn, необходимо сохранить дополнительные метаданные вдоль протравленной модели:

Обучающие данные, например, ссылка на неизменный снимок

Исходный код Python, используемый для генерации модели

Версии scikit-learn и его зависимостей

Оценка перекрестной проверки, полученная по данным обучения

Это особенно верно для оценщиков Ensemble, которые полагаются на tree.pyxмодуль, написанный на Cython (например, IsolationForest), поскольку он создает связь с реализацией, которая не гарантируется стабильной между версиями sklearn. Он видел назад несовместимые изменения в прошлом.

Если ваши модели становятся очень большими и загрузка становится неприятной, вы также можете использовать более эффективные joblib. Из документации:

В конкретном случае scikit может быть более интересно использовать joblib замену pickle( joblib.dump& joblib.load), которая более эффективна для объектов, которые несут большие массивы внутри, как это часто бывает для встроенных оценок scikit-learn, но может только мариновать на диск, а не на строку:


1
but can only pickle to the disk and not to a stringНо вы можете засолить это в StringIO из JobLib. Это то, что я делаю все время.
Матфея

Мой текущий проект делает нечто подобное, вы знаете, что The training data, e.g. a reference to a immutable snapshotздесь? ТИА!
Дейзи Цинь

1

sklearn.externals.joblibустарели , так 0.21и будет удален в v0.23:

/usr/local/lib/python3.7/site-packages/sklearn/externals/joblib/ init .py: 15: FutureWarning: sklearn.externals.joblib устарела в 0.21 и будет удалена в 0.23. Пожалуйста, импортируйте эту функцию напрямую из joblib, которую можно установить с помощью: pip install joblib. Если это предупреждение появляется при загрузке протравленных моделей, вам может потребоваться повторная сериализация этих моделей с помощью scikit-learn 0.21+.
warnings.warn (msg, category = FutureWarning)


Поэтому вам необходимо установить joblib:

pip install joblib

и, наконец, запишите модель на диск:

import joblib
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier


digits = load_digits()
clf = SGDClassifier().fit(digits.data, digits.target)

with open('myClassifier.joblib.pkl', 'wb') as f:
    joblib.dump(clf, f, compress=9)

Теперь, чтобы прочитать дамп-файл, все, что вам нужно, это запустить:

with open('myClassifier.joblib.pkl', 'rb') as f:
    my_clf = joblib.load(f)
Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.