TensorFlow, почему после сохранения модели остается 3 файла?


114

Прочитав документацию , я сохранил модель TensorFlow, вот мой демонстрационный код:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

но после этого я обнаружил, что есть 3 файла

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

И я не могу восстановить модель путем восстановления model.ckptфайла, так как такого файла нет. Вот мой код

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Итак, почему здесь 3 файла?


2
Вы придумали, как с этим справиться? Как мне снова загрузить модель (используя Keras)?
rajkiran

Ответы:


117

Попробуй это:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

Метод сохранения TensorFlow сохраняет файлы трех типов, поскольку он хранит структуру графика отдельно от значений переменных . .metaФайл описывает сохраненную структуру графа, так что вам нужно импортировать его перед восстановлением контрольной точки ( в противном случае он не знает , какие переменные сохраненные значения контрольных точек соответствуют).

В качестве альтернативы вы можете сделать это:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

Несмотря на то, что имени файла нет model.ckpt, вы все равно будете ссылаться на сохраненную контрольную точку с этим именем при ее восстановлении. Из saver.pyисходного кода :

Пользователям нужно только взаимодействовать с указанным пользователем префиксом ... вместо любого физического пути.


1
Значит, .index и .data не используются? Когда же тогда используются эти 2 файла?
ajfbiw.s

26
@ ajfbiw.s .meta хранит структуру графика, .data хранит значения каждой переменной в графе, .index идентифицирует контрольную точку. Итак, в приведенном выше примере: import_meta_graph использует .meta, а saver.restore использует .data и .index
TK Bartel

О, я вижу. Спасибо.
ajfbiw.s

1
Есть ли шанс, что вы сохранили модель с другой версией TensorFlow, чем вы использовали для ее загрузки? ( github.com/tensorflow/tensorflow/issues/5639 )
TK Bartel

5
Кто-нибудь знает, что это значит 00000и 00001цифры? в variables.data-?????-of-?????файле
Иван Талалаев

55
  • метафайл : описывает сохраненную структуру графа, включает GraphDef, SaverDef и так далее; потом применим tf.train.import_meta_graph('/tmp/model.ckpt.meta'), восстановлю Saverи Graph.

  • индексный файл : это неизменяемая таблица типа строка-строка (tensorflow :: table :: Table). Каждый ключ - это имя тензора, а его значение - это сериализованный BundleEntryProto. Каждый BundleEntryProto описывает метаданные тензора: какой из файлов «данных» содержит содержимое тензора, смещение в этом файле, контрольную сумму, некоторые вспомогательные данные и т. Д.

  • файл данных : это коллекция TensorBundle, сохраните значения всех переменных.


У меня есть pb-файл для классификации изображений. Могу ли я использовать его для классификации видео в реальном времени?

Не могли бы вы сообщить мне, используя Keras 2, как мне загрузить модель, если она сохранена как 3 файла?
rajkiran

5

Я восстанавливаю обученные вложения слов из учебника по тензорному потоку Word2Vec.

Если вы создали несколько контрольных точек:

например, созданные файлы выглядят так

model.ckpt-55695.data-00000-of-00001

model.ckpt-55695.index

model.ckpt-55695.meta

попробуй это

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

при вызове restore_session ():

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")

Что означает "00000-of-00001" в "model.ckpt-55695.data-00000-of-00001"?
hafiz031,

0

Если вы, например, тренировали CNN с отсеиванием, вы могли бы сделать это:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.