TensorFlow сохраняет / загружает график из файла


98

Из того, что я собрал до сих пор, существует несколько различных способов сбросить график TensorFlow в файл и затем загрузить его в другую программу, но я не смог найти четких примеров / информации о том, как они работают. Я уже знаю следующее:

  1. Сохраните переменные модели в файл контрольной точки (.ckpt) с помощью tf.train.Saver()и восстановите их позже ( источник )
  2. Сохраните модель в файл .pb и загрузите ее обратно с помощью tf.train.write_graph()и tf.import_graph_def()( источник )
  3. Загрузите модель из файла .pb, переобучите ее и выгружайте в новый файл .pb с помощью Bazel ( источник )
  4. Заморозьте график, чтобы сохранить график и веса вместе ( источник )
  5. Используйте as_graph_def()для сохранения модели, а для весов / переменных сопоставьте их с константами ( источник )

Однако мне не удалось прояснить несколько вопросов, касающихся этих различных методов:

  1. Что касается файлов контрольных точек, они сохраняют только обученные веса модели? Могут ли файлы контрольных точек быть загружены в новую программу и использоваться для запуска модели, или они просто служат в качестве способов сохранения весов в модели в определенное время / этап?
  2. Что касается tf.train.write_graph(), сохраняются ли веса / переменные?
  3. Что касается Базела, может он только для переподготовки сохранять / загружать из .pb файлов? Есть ли простая команда Bazel для выгрузки графика в .pb?
  4. Что касается замораживания, можно ли загрузить замороженный график с помощью tf.import_graph_def()?
  5. Демонстрация Android для TensorFlow загружается в модель Google Inception из файла .pb. Если бы я хотел заменить свой собственный файл .pb, как бы я это сделал? Нужно ли мне менять какой-либо собственный код / ​​методы?
  6. В общем, в чем именно разница между всеми этими методами? Или, в более широком смысле, в чем разница между as_graph_def()/.ckpt/.pb?

Короче говоря, я ищу способ сохранить как график (например, различные операции и т. Д.), Так и его веса / переменные в файл, который затем можно использовать для загрузки графика и весов в другую программу. , для использования (не обязательно для продолжения / переподготовки).

Документация по этой теме не очень проста, поэтому мы будем благодарны за любые ответы / информацию.


2
Самый новый / наиболее полный API - это мета-график, который дает вам возможность сохранить все три сразу - 1) график 2) значения параметров 3) коллекции: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Ярослав Булатов

Ответы:


80

Есть много способов подойти к проблеме сохранения модели в TensorFlow, что может немного сбить с толку. Рассмотрим по очереди каждый из ваших подвопросов:

  1. Файлы контрольных точек (например , производится путем вызова saver.save()на tf.train.Saverобъект) содержат только веса, и любые другие переменные , определенные в одной и той же программе. Чтобы использовать их в другой программе, вы должны воссоздать связанную структуру графа (например, запустив код для ее повторного построения или вызвав tf.import_graph_def()), которая сообщает TensorFlow, что делать с этими весами. Обратите внимание, что при вызове saver.save()также создается файл, содержащий MetaGraphDefграф и подробности того, как связать веса из контрольной точки с этим графом. Смотрите руководство для более подробной информации.

  2. tf.train.write_graph()пишет только структуру графа; не веса.

  3. Bazel не имеет отношения к чтению или написанию графиков TensorFlow. (Возможно, я неправильно понял ваш вопрос: не стесняйтесь прояснить его в комментарии.)

  4. Замороженный график можно загрузить с помощью tf.import_graph_def(). В этом случае веса (обычно) встроены в график, поэтому вам не нужно загружать отдельную контрольную точку.

  5. Основное изменение будет заключаться в обновлении имен тензоров, которые вводятся в модель, и имен тензоров, которые извлекаются из модели. В демке TensorFlow Android, это будет соответствовать inputNameи outputNameстроки, которые передаются TensorFlowClassifier.initializeTensorFlow().

  6. Это GraphDefструктура программы, которая обычно не меняется в процессе обучения. Контрольная точка - это моментальный снимок состояния тренировочного процесса, который обычно изменяется на каждом этапе тренировочного процесса. В результате TensorFlow использует разные форматы хранения для этих типов данных, а низкоуровневый API предоставляет разные способы их сохранения и загрузки. Библиотеки более высокого уровня, такие как MetaGraphDefбиблиотеки, Keras и skflow на основе этих механизмов , чтобы обеспечить более удобные способы сохранения и восстановления целой модели.


Означает ли это, что документация C ++ API лжет, когда говорится, что вы можете загрузить сохраненный граф tf.train.write_graph()и затем выполнить его?
mnicky

2
Документация C ++ API не врет, но в ней отсутствуют некоторые детали. Самая важная деталь заключается в том, что в дополнение к GraphDefсохраненному tf.train.write_graph(), вам также необходимо запомнить имена тензоров, которые вы хотите кормить и получать при выполнении графика (пункт 5 выше).
мрри

@mrry: Я пробовал использовать пример с тензорными потоками DeepDream. но вроде нужны предварительно обученные модели в формате pb! Я запустил пример Cifar10, но он создает только контрольные точки! Я не мог найти ни pb-файлов, ни чего-то еще! как я могу преобразовать свои контрольные точки в формат pb, который используется в примере deepdream?
Rika

2
@ Coderx7 Я действительно думаю, что вы не можете преобразовать .ckpt в .pb, поскольку контрольная точка содержит только веса и переменные и ничего не знает о структуре графа
Давидивад

1
есть ли простой код для загрузки файла .pb, а затем его запуска?
Kong

1

Вы можете попробовать следующий код:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.