Лучший способ сохранить обученную модель в PyTorch?


195

Я искал альтернативные способы сохранить обученную модель в PyTorch. Пока что я нашел две альтернативы.

  1. torch.save () для сохранения модели и torch.load () для загрузки модели.
  2. model.state_dict () для сохранения обученной модели и model.load_state_dict () для загрузки сохраненной модели.

Я сталкивался с этим обсуждением, где подход 2 рекомендуется по подходу 1

У меня вопрос, почему второй подход предпочтительнее? Только потому, что модули torch.nn имеют эти две функции, и нам рекомендуется их использовать?


2
Я думаю, это потому, что torch.save () также сохраняет все промежуточные переменные, например, промежуточные выходные данные для обратного распространения. Но вам нужно только сохранить параметры модели, такие как вес / смещение и т. Д. Иногда первый может быть намного больше, чем второй.
Давэй Ян

2
Я проверял torch.save(model, f)и torch.save(model.state_dict(), f). Сохраненные файлы имеют одинаковый размер. Теперь я в замешательстве. Кроме того, я обнаружил, что использование pickle для сохранения model.state_dict () очень медленно. Я думаю, что лучший способ - использовать, torch.save(model.state_dict(), f)так как вы управляете созданием модели, а torch обрабатывает загрузку весов модели, таким образом устраняя возможные проблемы. Ссылка: обсуждения.pytorch.org/t/saving-torch-models/838/4
Давэй Ян

Похоже, что PyTorch рассмотрел этот вопрос более подробно в своем разделе учебных пособий - там есть много полезной информации, которой нет в ответах, включая сохранение более одной модели за раз и теплый запуск моделей.
whlteXbread

что не так с использованием pickle?
Чарли Паркер

1
@CharlieParker torch.save основан на маринаде. Следующее из вышеприведенного учебника: «[torch.save] сохранит весь модуль, используя модуль pickle в Python. Недостатком этого подхода является то, что сериализованные данные связаны с конкретными классами и точной структурой каталогов, используемой при моделировании. Причиной этого является то, что pickle не сохраняет сам класс модели. Скорее, он сохраняет путь к файлу, содержащему класс, который используется во время загрузки. Из-за этого ваш код может сломаться различными способами, когда используется в других проектах или после рефакторинга. "
Дэвид Миллер

Ответы:


215

Я нашел эту страницу в их репозитории на github, я просто вставлю сюда содержимое.


Рекомендуемый подход к сохранению модели

Существует два основных подхода к сериализации и восстановлению модели.

Первый (рекомендуется) сохраняет и загружает только параметры модели:

torch.save(the_model.state_dict(), PATH)

Тогда позже:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

Вторая сохраняет и загружает всю модель:

torch.save(the_model, PATH)

Тогда позже:

the_model = torch.load(PATH)

Однако в этом случае сериализованные данные привязываются к конкретным классам и конкретной используемой структуре каталогов, поэтому они могут ломаться различными способами при использовании в других проектах или после некоторых серьезных рефакторингов.


8
Согласно @smth обсуждения.pytorch.org/ t / saving- and- loading- a- model- in- pytorch/… модель перезагружается в модель поезда по умолчанию. поэтому нужно вручную вызывать the_model.eval () после загрузки, если вы загружаете его для вывода, а не для возобновления обучения.
WillZ

второй метод дает stackoverflow.com/questions/53798009/… ошибка в windows 10. не смог ее решить
Гульзар

Есть ли возможность сохранить без необходимости доступа к модели класса?
Майкл Д

При таком подходе, как вы отслеживаете * args и ** kwargs, которые нужно передать для случая загрузки?
Мариано Камп

что не так с использованием pickle?
Чарли Паркер

145

Это зависит от того, что вы хотите сделать.

Случай № 1: Сохраните модель, чтобы использовать ее для вывода : вы сохраняете модель, восстанавливаете ее, а затем переводите модель в режим оценки. Это сделано потому, что у вас обычно есть BatchNormи Dropoutслои, которые по умолчанию находятся в режиме поезда на строительстве:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Случай № 2. Сохранение модели для возобновления обучения позже . Если вам нужно продолжить обучение модели, которую вы собираетесь сохранить, вам нужно сохранить больше, чем просто модель. Вам также нужно сохранить состояние оптимизатора, эпох, счета и т. Д. Вы бы сделали это так:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Чтобы возобновить обучение, вы должны сделать что-то вроде:, state = torch.load(filepath)а затем, чтобы восстановить состояние каждого отдельного объекта, что-то вроде этого:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Так как вы возобновляете обучение, НЕ звоните, как model.eval()только вы восстанавливаете состояния при загрузке.

Случай № 3: Модель для использования кем-то другим, не имеющим доступа к вашему коду : В Tensorflow вы можете создать .pbфайл, который определяет как архитектуру, так и вес модели. Это очень удобно, особенно при использовании Tensorflow serve. Эквивалентный способ сделать это в Pytorch будет:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Этот способ по-прежнему не является пуленепробиваемым, и поскольку в pytorch все еще происходит множество изменений, я бы не рекомендовал это делать.


1
Есть ли рекомендуемое окончание файла для 3 случаев? Или это всегда .pth?
Верена

1
В случае № 3 torch.loadвозвращается только OrderedDict. Как вы получаете модель для того, чтобы делать прогнозы?
Alber8295

Привет, могу ли я узнать, как выполнить упомянутое «Дело № 2: Сохранить модель, чтобы продолжить обучение позже»? Мне удалось загрузить контрольную точку в модель, а затем я не смог запустить или возобновить тренировку модели, например "model.to (device) model = train_model_epoch (model, критерий, оптимизатор, sched, эпохи)"
dnez

1
Привет, для случая, который предназначен для вывода, в официальном документе Pytorch говорят, что должны сохранить оптимизатор state_dict для вывода или завершения обучения. «При сохранении общей контрольной точки, которая будет использоваться либо для вывода, либо для возобновления обучения, вы должны сохранить больше, чем просто state_dict модели. Важно также сохранить state_dict оптимизатора, поскольку он содержит буферы и параметры, которые обновляются по мере обучения модели. . "
Мохаммед Авни

1
В случае № 3 класс модели должен быть где-то определен.
Майкл Д

12

Библиотека Python pickle реализует двоичные протоколы для сериализации и десериализации объекта Python.

Когда вы import torch(или когда вы используете PyTorch) это будет import pickleдля вас, и вам не нужно вызывать pickle.dump()и pickle.load()напрямую, которые являются методами для сохранения и загрузки объекта.

На самом деле torch.save()и torch.load()заверну pickle.dump()и pickle.load()для вас.

А state_dictдругой упомянутый ответ заслуживает еще несколько заметок.

Что state_dictу нас внутри PyTorch? На самом деле есть два state_dictс.

Модель PyTorch torch.nn.Moduleимеет model.parameters()вызов для получения обучаемых параметров (w и b). Эти обучаемые параметры, однажды установленные случайным образом, будут обновляться с течением времени по мере нашего обучения. Изучаемые параметры являются первыми state_dict.

Второе state_dict- это диктатор состояния оптимизатора. Вы помните, что оптимизатор используется для улучшения наших усваиваемых параметров. Но оптимизатор state_dictисправлен. Там нечему учиться.

Поскольку state_dictобъекты являются словарями Python, их можно легко сохранять, обновлять, изменять и восстанавливать, добавляя большую модульность моделям и оптимизаторам PyTorch.

Давайте создадим супер простую модель, чтобы объяснить это:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Этот код выведет следующее:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Обратите внимание, что это минимальная модель. Вы можете попробовать добавить стек последовательных

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Обратите внимание, что только слои с усваиваемыми параметрами (сверточные слои, линейные слои и т. Д.) И зарегистрированные буферы (слои группового набора) имеют записи в модели state_dict.

Неизучаемые вещи принадлежат объекту оптимизатора state_dict, который содержит информацию о состоянии оптимизатора, а также используемые гиперпараметры.

Остальная часть истории такая же; на этапе вывода (это этап, когда мы используем модель после обучения) для прогнозирования; мы делаем прогноз на основе параметров, которые мы узнали. Поэтому для вывода нам просто нужно сохранить параметры model.state_dict().

torch.save(model.state_dict(), filepath)

И использовать позже model.load_state_dict (torch.load (filepath)) model.eval ()

Примечание: не забывайте последнюю строку, model.eval()это важно после загрузки модели.

Также не пытайтесь сохранить torch.save(model.parameters(), filepath). Это model.parameters()просто объект генератора.

С другой стороны, torch.save(model, filepath)сохраняет сам объект модели, но имейте в виду, что модель не имеет оптимизатора state_dict. Посмотрите другой отличный ответ @Jadiel de Armas, чтобы сохранить информацию о состоянии оптимизатора.


Хотя это не простое решение, суть проблемы глубоко проанализирована! Upvote.
Джейсон Янг

7

Общепринятым соглашением PyTorch является сохранение моделей с использованием расширения файлов .pt или .pth.

Сохранить / загрузить всю модель Сохранить:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Загрузить:

Класс модели должен быть определен где-то

model = torch.load(PATH)
model.eval()

4

Если вы хотите сохранить модель и хотите продолжить обучение позже:

Одиночный графический процессор: Сохранить:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Загрузить:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Несколько GPU: Сохранить

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Загрузить:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.