Как сообщить Керасу о прекращении тренировок на основе величины потерь?


82

В настоящее время я использую следующий код:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

Он говорит Керасу прекратить тренировку, если потери не улучшились в течение 2 эпох. Но я хочу прекратить тренировку после того, как потеря стала меньше некоторого постоянного «THR»:

if val_loss < THR:
    break

Я видел в документации, что есть возможность сделать свой обратный вызов: http://keras.io/callbacks/ Но ничего не нашел, как остановить процесс обучения. Мне нужен совет.

Ответы:


85

Я нашел ответ. Я заглянул в исходники Keras и нашел код для EarlyStopping. Я сделал свой обратный вызов, основываясь на нем:

class EarlyStoppingByLossVal(Callback):
    def __init__(self, monitor='val_loss', value=0.00001, verbose=0):
        super(Callback, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)

        if current < self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

И использование:

callbacks = [
    EarlyStoppingByLossVal(monitor='val_loss', value=0.00001, verbose=1),
    # EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

1
Просто если кому-то это будет полезно - в моем случае я использовал monitor = 'loss', это сработало.
QtRoS

15
Кажется, Керас обновился. В функцию обратного вызова EarlyStopping теперь встроена min_delta. Нет необходимости больше взламывать исходный код, ура! stackoverflow.com/a/41459368/3345375
jkdev

3
Перечитав вопрос и ответы, мне нужно исправить себя: min_delta означает «Остановитесь раньше, если не хватает улучшений за эпоху (или за несколько эпох)». Однако ОП спросил, как «остановить досрочно, когда убыток становится ниже определенного уровня».
jkdev

NameError: имя «Обратный вызов» не определено ... Как это исправить?
alyssaeliyah

2
Элия, попробуй это: from keras.callbacks import Callback
ZFTurbo

26

Обратный вызов keras.callbacks.EarlyStopping имеет аргумент min_delta. Из документации Keras:

min_delta: минимальное изменение отслеживаемого количества, которое квалифицируется как улучшение, т. е. абсолютное изменение менее min_delta не будет считаться улучшением.


3
Для справки, вот документы для более ранней версии Keras (1.1.0), в которой аргумент min_delta еще не был включен: faroit.github.io/keras-docs/1.1.0/callbacks/#earlystopping
jkdev

как я мог сделать так, чтобы это не прекратилось, пока оно не min_deltaсохранялось в течение нескольких эпох?
zyxue

есть еще один параметр EarlyStopping, называемый терпением: количество эпох без улучшений, после которых обучение будет остановлено.
devin

13

Одним из решений является вызов model.fit(nb_epoch=1, ...)внутри цикла for, затем вы можете поместить оператор break внутри цикла for и выполнить любой другой настраиваемый поток управления, который хотите.


Было бы неплохо, если бы они сделали обратный вызов, который принимает единственную функцию, которая может это сделать.
Честность,

8

Я решил ту же проблему, используя собственный обратный вызов.

В следующем пользовательском коде обратного вызова присвойте THR значение, при котором вы хотите остановить обучение, и добавьте обратный вызов в свою модель.

from keras.callbacks import Callback

class stopAtLossValue(Callback):

        def on_batch_end(self, batch, logs={}):
            THR = 0.03 #Assign THR with the value at which you want to stop training.
            if logs.get('loss') <= THR:
                 self.model.stop_training = True

2

Пока я практиковал специализацию TensorFlow , я изучил очень элегантную технику. Просто немного изменен из принятого ответа.

Давайте рассмотрим пример с нашими любимыми данными MNIST.

import tensorflow as tf

class new_callback(tf.keras.callbacks.Callback):
    def epoch_end(self, epoch, logs={}): 
        if(logs.get('accuracy')> 0.90): # select the accuracy
            print("\n !!! 90% accuracy, no further training !!!")
            self.model.stop_training = True

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 #normalize

callbacks = new_callback()

# model = tf.keras.models.Sequential([# define your model here])

model.compile(optimizer=tf.optimizers.Adam(),
          loss='sparse_categorical_crossentropy',
          metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

Итак, здесь я установил metrics=['accuracy'], и, таким образом, в классе обратного вызова условие установлено на'accuracy'> 0.90 .

Вы можете выбрать любую метрику и следить за тренировкой, как в этом примере. Самое главное, что вы можете установить разные условия для разных метрик и использовать их одновременно.

Надеюсь, это поможет!


имя функции должно быть on_epoch_end
xarion

0

Для меня модель остановит обучение только в том случае, если я добавлю оператор return после установки для параметра stop_training значения True, потому что я звонил после self.model.evaluate. Поэтому либо убедитесь, что в конце функции поставили stop_training = True, либо добавьте оператор возврата.

def on_epoch_end(self, batch, logs):
        self.epoch += 1
        self.stoppingCounter += 1
        print('\nstopping counter \n',self.stoppingCounter)

        #Stop training if there hasn't been any improvement in 'Patience' epochs
        if self.stoppingCounter >= self.patience:
            self.model.stop_training = True
            return

        # Test on additional set if there is one
        if self.testingOnAdditionalSet:
            evaluation = self.model.evaluate(self.val2X, self.val2Y, verbose=0)
            self.validationLoss2.append(evaluation[0])
            self.validationAcc2.append(evaluation[1])enter code here

0

Если вы используете настраиваемый цикл обучения, вы можете использовать a collections.deque, который представляет собой «скользящий» список, который можно добавлять, и левые элементы выскакивают, когда список длиннее, чем maxlen. Вот строчка:

loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())
    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history)
            break

Вот полный пример:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Dense
from collections import deque

data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True)

data = data.map(lambda x, y: (tf.cast(x, tf.int32), y))

train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)

model = tf.keras.models.Sequential([
    Dense(8, activation='relu'),
    Dense(16, activation='relu'),
    Dense(info.features['label'].num_classes)])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        loss = loss_object(labels, logits)

    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, logits)


@tf.function
def test_step(inputs, labels):
    logits = model(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)


def fit(epoch):
    template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \
               'Train Acc {:.2f} Test Acc {:.2f}'

    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result()
    ))


def main(epochs=50, early_stopping=10):
    loss_history = deque(maxlen=early_stopping + 1)

    for epoch in range(epochs):
        fit(epoch)
        loss_history.append(test_loss.result().numpy())
        if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break

if __name__ == '__main__':
    main(epochs=250, early_stopping=10)
Epoch  1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch  2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch  3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch  4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch  5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87

Early stopping. No validation loss improvement in 10 epochs.
Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.