Загрузка обученной модели Keras и продолжение обучения


102

Мне было интересно, можно ли сохранить частично обученную модель Keras и продолжить обучение после повторной загрузки модели.

Причина этого в том, что в будущем у меня будет больше обучающих данных, и я не хочу снова переобучать всю модель.

Я использую следующие функции:

#Partly train model
model.fit(first_training, first_classes, batch_size=32, nb_epoch=20)

#Save partly trained model
model.save('partly_trained.h5')

#Load partly trained model
from keras.models import load_model
model = load_model('partly_trained.h5')

#Continue training
model.fit(second_training, second_classes, batch_size=32, nb_epoch=20)

Изменить 1: добавлен полностью рабочий пример

С первым набором данных после 10 эпох потеря последней эпохи будет 0,0748, а точность 0,9863.

После сохранения, удаления и перезагрузки модели потеря и точность модели, обученной на втором наборе данных, составят 0,1711 и 0,9504 соответственно.

Это вызвано новыми данными обучения или полностью переобученной моделью?

"""
Model by: http://machinelearningmastery.com/
"""
# load (downloaded if needed) the MNIST dataset
import numpy
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import np_utils
from keras.models import load_model
numpy.random.seed(7)

def baseline_model():
    model = Sequential()
    model.add(Dense(num_pixels, input_dim=num_pixels, init='normal', activation='relu'))
    model.add(Dense(num_classes, init='normal', activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

if __name__ == '__main__':
    # load data
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # flatten 28*28 images to a 784 vector for each image
    num_pixels = X_train.shape[1] * X_train.shape[2]
    X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')
    X_test = X_test.reshape(X_test.shape[0], num_pixels).astype('float32')
    # normalize inputs from 0-255 to 0-1
    X_train = X_train / 255
    X_test = X_test / 255
    # one hot encode outputs
    y_train = np_utils.to_categorical(y_train)
    y_test = np_utils.to_categorical(y_test)
    num_classes = y_test.shape[1]

    # build the model
    model = baseline_model()

    #Partly train model
    dataset1_x = X_train[:3000]
    dataset1_y = y_train[:3000]
    model.fit(dataset1_x, dataset1_y, nb_epoch=10, batch_size=200, verbose=2)

    # Final evaluation of the model
    scores = model.evaluate(X_test, y_test, verbose=0)
    print("Baseline Error: %.2f%%" % (100-scores[1]*100))

    #Save partly trained model
    model.save('partly_trained.h5')
    del model

    #Reload model
    model = load_model('partly_trained.h5')

    #Continue training
    dataset2_x = X_train[3000:]
    dataset2_y = y_train[3000:]
    model.fit(dataset2_x, dataset2_y, nb_epoch=10, batch_size=200, verbose=2)
    scores = model.evaluate(X_test, y_test, verbose=0)
    print("Baseline Error: %.2f%%" % (100-scores[1]*100))

3
Вы это проверяли? Я не вижу причин, чтобы это не работало.
маз

Сейчас я вижу, что моя точность падает примерно на 10 процентов после загрузки модели (только в первые эпохи). Если перезагрузка работает, это, конечно, вызвано новыми данными обучения. Но я просто хочу убедиться, что это действительно так.
Wilmar van Ommeren 08

7
Вы сохраняете свою модель напрямую с помощью model.save или используете контрольную точку модели ( keras.io/callbacks/#example-model-checkpoints )? Если вы используете model.save, есть ли вероятность, что вы сохраняете последнюю модель (т.е. последнюю эпоху) вместо лучшей (наименьшая ошибка)? Вы можете предоставить актуальный код?
маз

Сохраняю свою последнюю модель, не самую лучшую (до этого момента я не знал, что это возможно). Я подготовлю код
Вильмар ван Оммерен

3
Так не могли бы вы перезагрузить это и продолжить обучение на тех же данных поезда? Это должно гарантировать вам, что перезагрузка в порядке, если результаты будут сопоставимы.
Marcin Moejko 08

Ответы:


36

Фактически - model.saveсохраняет всю информацию, необходимую для перезапуска тренировки в вашем случае. Единственное, что может испортить перезагрузка модели, - это состояние вашего оптимизатора. Чтобы проверить это - попробуйте saveперезагрузить модель и обучить ее на обучающих данных.


1
@Marcin: при использовании keras save()сохраняет ли он лучший результат (наименьшие потери) модели или последний результат (последнее обновление) модели? спасибо
Lion Lai

5
последнее обновление. Обратный вызов контрольной точки модели предназначен для сохранения лучшего.
Холи

2
@Khaj Вы имеете в виду этот keras.io/callbacks/#modelcheckpoint ? Вроде по умолчанию сохраняет последнее обновление (не самое лучшее); лучший сохраняется, только если save_best_only=Trueон установлен явно.
flow2k

9

Большинство приведенных выше ответов касались важных моментов. Если вы используете последнюю версию Tensorflow ( TF2.1или выше), следующий пример вам поможет. Модельная часть кода взята с сайта Tensorflow.

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),  
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])

  model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',metrics=['accuracy'])
  return model

# Create a basic model instance
model=create_model()
model.fit(x_train, y_train, epochs = 10, validation_data = (x_test,y_test),verbose=1)

Сохраните модель в формате * .tf. По моему опыту, если у вас определен какой-либо custom_loss, формат * .h5 не сохранит статус оптимизатора и, следовательно, не будет служить вашей цели, если вы захотите переобучить модель с того места, где мы оставили.

# saving the model in tensorflow format
model.save('./MyModel_tf',save_format='tf')


# loading the saved model
loaded_model = tf.keras.models.load_model('./MyModel_tf')

# retraining the model
loaded_model.fit(x_train, y_train, epochs = 10, validation_data = (x_test,y_test),verbose=1)

Этот подход перезапустит обучение с того места, где мы остановились перед сохранением модели. Как уже отмечалось другими, если вы хотите сохранить вес лучшей модели или вы хотите сохранить веса модели каждую эпоху вам нужно использовать keras обратные вызовы функций (ModelCheckpoint) с опциями , такими как save_weights_only=True, save_freq='epoch'и save_best_only.

Для получения дополнительной информации, пожалуйста, проверьте здесь и еще один пример здесь .


1
хорошо, это выглядит очень многообещающе - спасибо за информацию. в этом примере мне кажется, что вы переобучаете модель на тех же данных, которые использовались для обучения. Если да, то я бы подумал, что правильным подходом было бы загрузить новое подмножество обучающих данных для повторного обучения (чтобы отразить новую информацию, вводимую в процесс)?
bibzzzz

1
@bibzzzz Согласен с вами. Очень хороший комментарий. Я хотел продемонстрировать переобучение на тех же данных, чтобы повысить производительность. Суть ясно показывает улучшение производительности там, где оно было остановлено перед сохранением модели. Полностью согласен с тем, что переучивайтесь на других данных и попробую позже. Благодарность!
Вишнувардхана Джанапати

1
отлично - вы очень хорошо это продемонстрировали, спасибо.
bibzzzz

8

Проблема может заключаться в том, что вы используете другой оптимизатор - или другие аргументы оптимизатора. У меня была такая же проблема с пользовательской предварительно обученной моделью, используя

reduce_lr = ReduceLROnPlateau(monitor='loss', factor=lr_reduction_factor,
                              patience=patience, min_lr=min_lr, verbose=1)

для предварительно обученной модели, при этом исходная скорость обучения начинается с 0,0003, а во время предварительного обучения она снижается до min_learning rate, которая составляет 0,000003

Я просто скопировал эту строку в сценарий, который использует предварительно обученную модель, и получил очень плохую точность. Пока я не заметил, что последней скоростью обучения предварительно обученной модели была минимальная скорость обучения, то есть 0,000003. И если я начну с этой скорости обучения, я получу точно такую ​​же точность, что и результат предварительно обученной модели, - что имеет смысл, если начать со скорости обучения, которая в 100 раз больше, чем последняя скорость обучения, использованная в предварительно обученной модели. Модель приведет к значительному превышению GD и, следовательно, к значительному снижению точности.


2

Обратите внимание, что у Keras иногда возникают проблемы с загруженными моделями, как здесь . Это может объяснить случаи, когда вы не начинаете с той же обученной точности.


1

Все вышеперечисленное помогает, вы должны продолжить с той же скорости обучения (), что и LR, когда модель и веса были сохранены. Установите его прямо в оптимизаторе.

Обратите внимание, что улучшение оттуда не гарантируется, потому что модель могла достичь локального минимума, который может быть глобальным. Нет смысла возобновлять модель для поиска другого локального минимума, если только вы не намерены увеличивать скорость обучения контролируемым образом и подталкивать модель к возможно лучшему минимуму неподалеку.


Почему это? Разве я не могу использовать меньший LR, чем раньше?
lte__

На самом деле, продолжение обучения МОЖЕТ привести вас к лучшей модели, если вы получите больше данных. Так что есть смысл возобновить модель, чтобы найти другой локальный минимум.
Кори Левинсон

0

Вы также можете нажать Concept Drift, см. Следует ли переобучать модель при появлении новых наблюдений . Также существует концепция катастрофического забывания, о которой говорится в ряде научных статей. Вот один с MNIST Эмпирическое исследование катастрофического забывания

Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.