Tensorflow: как сохранить / восстановить модель?


553

После того, как вы тренируете модель в Tensorflow:

  1. Как сохранить обученную модель?
  2. Как вы позже восстановите эту сохраненную модель?

Удалось ли вам восстановить переменные, использованные в начальной модели? Я также пытаюсь решить ту же проблему, но не могу написать набор переменных, которые использовались при обучении начальной модели (у меня есть файл ckpt)
exAres

Я не пробовал с начальной моделью. У вас есть структура сети модели с ее названиями? Вы должны реплицировать сеть, а затем загрузить весовые коэффициенты и смещения (файл ckpt), как объясняет Райан. Может быть, что-то изменилось с 15 ноября, и теперь есть более простой подход, я не уверен
mathetes

Ох, хорошо. Ранее я загружал другие предварительно обученные модели тензорного потока, но искал переменные спецификации начальной модели. Спасибо.
exAres

1
Если вы восстановите, чтобы продолжить тренировку, просто используйте контрольные точки Saver. Если вы сохраните модель, чтобы сделать ссылку, просто API-интерфейсы tenorsflow SavedModel.
HY G

Также, если вы используете LSTM, у вас будет карта от строки до списка символов, обязательно сохраните и загрузите этот список в том же порядке! Это не покрывается сохранением весов модели и сети графов моделей и создает впечатление, что ваша модель не была загружена при смене сеансов или данных.
Devssh

Ответы:


119

Документы

Из документов:

Сохранить

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)

Восстановить

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

Tensorflow 2

Это все еще бета, поэтому я бы посоветовал против пока. Если вы все еще хотите пойти по этому пути, вот руководство по tf.saved_modelиспользованию

Тензор потока <2

simple_save

Много хорошего ответа, для полноты я добавлю свои 2 цента: simple_save . Также автономный пример кода с использованием tf.data.DatasetAPI.

Python 3; Тензор потока 1.14

import tensorflow as tf
from tensorflow.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        ...

        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )

Восстановление:

graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })

Автономный пример

Оригинальное сообщение в блоге

Следующий код генерирует случайные данные ради демонстрации.

  1. Мы начнем с создания заполнителей. Они будут хранить данные во время выполнения. Из них мы создаем, Datasetа затем ее Iterator. Мы получаем сгенерированный тензор итератора, input_tensorкоторый называется входом для нашей модели.
  2. Сама модель построена из input_tensor: двунаправленного RNN на основе GRU, за которым следует плотный классификатор. Потому что почему бы и нет.
  3. Потеря softmax_cross_entropy_with_logitsоптимизирована с Adam. После 2 эпох (по 2 партии в каждой) мы сохраняем «обученную» модель с помощью tf.saved_model.simple_save. Если вы запустите код как есть, то модель будет сохранена в папке, которая называется simple/в вашем текущем рабочем каталоге.
  4. В новом графике мы затем восстановим сохраненную модель с помощью tf.saved_model.loader.load. Мы берем заполнители и логиты с graph.get_tensor_by_nameи Iteratorинициализирующую операцию с graph.get_operation_by_name.
  5. Наконец, мы запускаем логический вывод для обоих пакетов в наборе данных и проверяем, что обе сохраненные и восстановленные модели дают одинаковые значения. Они делают!

Код:

import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


def model(graph, input_tensor):
    """Create the model which consists of
    a bidirectional rnn (GRU(10)) followed by a dense classifier

    Args:
        graph (tf.Graph): Tensors' graph
        input_tensor (tf.Tensor): Tensor fed as input to the model

    Returns:
        tf.Tensor: the model's output layer Tensor
    """
    cell = tf.nn.rnn_cell.GRUCell(10)
    with graph.as_default():
        ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell,
            cell_bw=cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None)
        outputs = tf.concat((fw_outputs, bw_outputs), 2)
        mean = tf.reduce_mean(outputs, axis=1)
        dense = tf.layers.dense(mean, 5, activation=None)

        return dense


def get_opt_op(graph, logits, labels_tensor):
    """Create optimization operation from model's logits and labels

    Args:
        graph (tf.Graph): Tensors' graph
        logits (tf.Tensor): The model's output without activation
        labels_tensor (tf.Tensor): Target labels

    Returns:
        tf.Operation: the operation performing a stem of Adam optimizer
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=labels_tensor, name='xent'),
                    name="mean-xent"
                    )
        with tf.variable_scope('optimizer'):
            opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
        return opt_op


if __name__ == '__main__':
    # Set random seed for reproducibility
    # and create synthetic data
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    graph1 = tf.Graph()
    with graph1.as_default():
        # Random seed for reproducibility
        tf.set_random_seed(0)
        # Placeholders
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
        # Dataset
        dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
        dataset = dataset.batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model
        logits = model(graph1, input_tensor)
        # Optimization
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            # Initialize variables
            tf.global_variables_initializer().run(session=sess)
            for epoch in range(3):
                batch = 0
                # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    })
                values = []
                while True:
                    try:
                        if epoch < 2:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                            batch += 1
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                            batch += 1
                    except tf.errors.OutOfRangeError:
                        break
            # Save model state
            print('\nSaving...')
            cwd = os.getcwd()
            path = os.path.join(cwd, 'simple')
            shutil.rmtree(path, ignore_errors=True)
            inputs_dict = {
                "batch_size_ph": batch_size_ph,
                "features_data_ph": features_data_ph,
                "labels_data_ph": labels_data_ph
            }
            outputs_dict = {
                "logits": logits
            }
            tf.saved_model.simple_save(
                sess, path, inputs_dict, outputs_dict
            )
            print('Ok')
    # Restoring
    graph2 = tf.Graph()
    with graph2.as_default():
        with tf.Session(graph=graph2) as sess:
            # Restore saved values
            print('\nRestoring...')
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                path
            )
            print('Ok')
            # Get restored placeholders
            labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
            features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
            batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
            # Get restored model output
            restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
            # Get dataset initializing operation
            dataset_init_op = graph2.get_operation_by_name('dataset_init')

            # Initialize restored dataset
            sess.run(
                dataset_init_op,
                feed_dict={
                    features_data_ph: features,
                    labels_data_ph: labels,
                    batch_size_ph: 32
                }

            )
            # Compute inference for both batches in dataset
            restored_values = []
            for i in range(2):
                restored_values.append(sess.run(restored_logits))
                print('Restored values: ', restored_values[i][0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)

Это напечатает:

$ python3 save_and_restore.py

Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595   0.12804556  0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045  -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792  -0.00602257  0.07465433  0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984  0.05981954 -0.15913513 -0.3244143   0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok

Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values:  [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Restored values:  [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Inferences match:  True

1
Я новичок, и мне нужно больше объяснений ...: Если у меня есть модель CNN, я должен хранить только 1. input_placeholder 2. label_placeholder и 3. output_of_cnn? Или все промежуточные tf.contrib.layers?
дождь

2
График полностью восстановлен. Вы можете проверить это работает [n.name for n in graph2.as_graph_def().node]. Как сказано в документации, простое сохранение направлено на упрощение взаимодействия с обслуживанием тензорного потока, в этом суть аргументов; другие переменные, тем не менее, все еще восстанавливаются, иначе вывод не произойдет. Просто возьмите интересующие вас переменные, как я сделал в примере. Проверьте документацию
Тед

@ted когда я буду использовать tf.saved_model.simple_save против tf.train.Saver ()? По своей интуиции я использовал tf.train.Saver () во время тренировок и для сохранения разных моментов времени. Я бы использовал tf.saved_model.simple_save, когда обучение сделано для использования в производстве. (Я спросил то же самое и в комментарии здесь )
loco.loop

1
Хорошо, я думаю, но это также работает с моделями в режиме Eager и tfe.Saver?
Джеффри Андерсон

1
без global_stepаргумента, если вы остановитесь, а затем попытаетесь возобновить тренировку, он будет думать, что вы на шаг впереди. Это как минимум
испортит

252

Я улучшаю свой ответ, чтобы добавить больше деталей для сохранения и восстановления моделей.

В (и после) версии 0.11 Tensorflow :

Сохранить модель:

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

Восстановить модель:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 

Этот и некоторые более продвинутые варианты использования были очень хорошо объяснены здесь.

Краткое полное руководство по сохранению и восстановлению моделей Tensorflow


3
+1 для этого # Доступ сохранен Переменные напрямую печатаются (sess.run ('bias: 0')) # Это выведет 2, которое является значением смещения, которое мы сохранили. Это очень помогает в целях отладки, чтобы увидеть, правильно ли загружена модель. переменные могут быть получены с помощью «All_varaibles = tf.get_collection (tf.GraphKeys.GLOBAL_VARIABLES». Кроме того, «sess.run (tf.global_variables_initializer ())» должен быть перед восстановлением.
LGG

1
Вы уверены, что нам нужно снова запустить global_variables_initializer? Я восстановил свой график с помощью global_variable_initialization, и он каждый раз дает разные результаты для одних и тех же данных. Поэтому я закомментировал инициализацию и просто восстановил график, входную переменную и операции, и теперь он работает нормально.
Адитья Шинде,

@AdityaShinde Я не понимаю, почему я всегда получаю разные значения каждый раз. И я не включил шаг инициализации переменной для восстановления. Я использую свой собственный код между прочим.
Чейн

@AdityaShinde: вам не нужна инициализация, так как значения уже инициализированы функцией восстановления, поэтому удалите ее. Однако я не уверен, почему вы получили другой вывод, используя init op.
Санкит

5
@sankit Когда вы восстанавливаете тензоры, почему вы добавляете :0к именам?
Сахар Рабиновиз

177

В (и после) TensorFlow версии 0.11.0RC1, вы можете сохранить и восстановить вашу модель непосредственно по телефону tf.train.export_meta_graphи в tf.train.import_meta_graphсоответствии с https://www.tensorflow.org/programmers_guide/meta_graph .

Сохранить модель

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta

Восстановить модель

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

4
Как загрузить переменные из сохраненной модели? Как скопировать значения в другую переменную?
neel

9
Я не могу заставить этот код работать. Модель действительно сохраняется, но я не могу ее восстановить. Это дает мне эту ошибку. <built-in function TF_Run> returned a result with an error set
Саад Куреши

2
Когда после восстановления я получаю доступ к переменным, как показано выше, это работает. Но я не могу получить переменные более напрямую, используя tf.get_variable_scope().reuse_variables()затем var = tf.get_variable("varname"). Это дает мне ошибку: «ValueError: Переменная varname не существует или не была создана с помощью tf.get_variable ()». Почему? Разве это не возможно?
Иоганн Петрак

4
Это хорошо работает только для переменных, но как вы можете получить доступ к заполнителю и передать ему значения после восстановления графика?
kbrose

11
Это только показывает, как восстановить переменные. Как вы можете восстановить всю модель и протестировать ее на новых данных, не переопределяя сеть?
Chaine

127

Для версии TensorFlow <0.11.0RC1:

Сохраненные контрольные точки содержат значения для Variables в вашей модели, а не саму модель / график, что означает, что график должен быть таким же, когда вы восстанавливаете контрольную точку.

Вот пример для линейной регрессии, где есть обучающий цикл, который сохраняет переменные контрольные точки, и раздел оценки, который будет восстанавливать переменные, сохраненные в предыдущем прогоне, и вычислять прогнозы. Конечно, вы также можете восстановить переменные и продолжить обучение, если хотите.

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

Вот документы для Variables, которые охватывают сохранение и восстановление. А вот документы для Saver.


1
ФЛАГИ определяются пользователем. Вот пример их определения: github.com/tensorflow/tensorflow/blob/master/tensorflow/…
Райан Сепасси

в каком формате batch_xнужно быть? Binary? Numpy массив?
Пепе

@pepe Numpy Arrary должно быть хорошо. И тип элемента должен соответствовать типу заполнителя. [link] tenorflow.org/versions/r0.9/api_docs/python/…
Донни

ФЛАГИ дают ошибку undefined. Можете ли вы сказать мне, что является определением FLAGS для этого кода. @RyanSepassi
Мухаммед Ханнан,

Для того, чтобы сделать его явным: Последние версии Tensorflow действительно позволяют сохранить модель / график. [Мне было неясно, какие аспекты ответа относятся к ограничению <0.11. Учитывая большое количество голосов, я испытал искушение полагать, что это общее утверждение все еще верно для последних версий.]
bluenote10

78

Моя среда: Python 3.6, Tensorflow 1.3.0

Хотя было много решений, большинство из них основано на tf.train.Saver. Когда мы загружаем .ckptспасены Saver, мы должны либо пересмотреть сеть tensorflow или использовать какое - то странное и с трудом вспомнил имя, например 'placehold_0:0', 'dense/Adam/Weight:0'. Здесь я рекомендую использовать tf.saved_modelодин из простейших примеров, приведенных ниже, вы можете узнать больше об обслуживании модели TensorFlow :

Сохранить модель:

import tensorflow as tf

# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")

h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save the model
export_path =  './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'x_input': tensor_info_x},
      outputs={'y_output': tensor_info_y},
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          prediction_signature 
  },
  )
builder.save()

Загрузите модель:

import tensorflow as tf
sess=tf.Session() 
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'

export_path =  './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
           sess,
          [tf.saved_model.tag_constants.SERVING],
          export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: 3.0})

4
+1 за отличный пример API SavedModel. Тем не менее, я бы хотел, чтобы в разделе « Сохранить модель » была показана тренировочная петля, подобная ответу Райана Сепасси! Я понимаю, что это старый вопрос, но этот ответ является одним из немногих (и ценных) примеров SavedModel, которые я нашел в Google.
Дилан Ф

@ Это отличный ответ - только один, нацеленный на новую SavedModel. Не могли бы вы взглянуть на этот вопрос SavedModel? stackoverflow.com/questions/48540744/…
bluesummers

Теперь заставьте все это работать корректно с моделями TF Eager. В своей презентации 2018 года Google посоветовал всем уйти от графического кода TF.
Джеффри Андерсон

55

Модель состоит из двух частей: определение модели, сохраненное Supervisorкак graph.pbtxtв каталоге модели, и числовые значения тензоров, сохраненные в файлах контрольных точек, например model.ckpt-1003418.

Определение модели может быть восстановлено с помощью tf.import_graph_def, а веса восстановлены с помощью Saver.

Тем не менее, Saverиспользуется специальная коллекция, содержащая список переменных, которые прикреплены к модели Graph, и эта коллекция не инициализируется с помощью import_graph_def, поэтому вы не можете использовать их вместе в данный момент (это исправлено в нашей дорожной карте). На данный момент вы должны использовать подход Райана Сепасси - вручную построить график с одинаковыми именами узлов и использовать Saverдля загрузки в него весов.

(В качестве альтернативы вы можете взломать его, используя, используя import_graph_def, создавая переменные вручную и используя tf.add_to_collection(tf.GraphKeys.VARIABLES, variable)для каждой переменной, затем используя Saver)


В примере classify_image.py, который использует inceptionv3, загружается только graphdef. Означает ли это, что теперь GraphDef также содержит переменную?
jrabary

1
@jrabary Модель, вероятно, была заморожена .
Эрик Платон

1
Эй, я новичок в tenorflow и у меня проблемы с сохранением моей модели. Я был бы очень признателен, если бы вы могли помочь мне stackoverflow.com/questions/48083474/…
Ruchir Baronia

39

Вы также можете пойти по этому более простому пути.

Шаг 1: инициализируйте все ваши переменные

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

Шаг 2: сохранить сессию внутри модели Saverи сохранить ее

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

Шаг 3: восстановить модель

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

Шаг 4: проверьте вашу переменную

W1 = session.run(W1)
print(W1)

Работая в другом экземпляре Python, используйте

with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Initalize the variables
    sess.run(tf.global_variables_initializer())

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = session.run(W1)

Привет, Как я могу сохранить модель после 3000 итераций, как в Caffe. Я обнаружил, что тензорный поток сохраняет только последние модели, несмотря на то, что я объединяю номер итерации с моделью, чтобы дифференцировать его среди всех итераций. Я имею в виду model_3000.ckpt, model_6000.ckpt, --- model_100000.ckpt. Можете ли вы объяснить, почему он не сохраняет все, а сохраняет только последние 3 итерации.
хан


3
Есть ли способ получить все имена переменных / операций, сохраненные в графе?
Мундра

21

В большинстве случаев сохранение и восстановление с диска - tf.train.Saverэто ваш лучший вариант:

... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

Вы также можете сохранить / восстановить саму структуру графика (подробности см. В документации MetaGraph ). По умолчанию Saverструктура графика сохраняется в .metaфайл. Вы можете позвонить, import_meta_graph()чтобы восстановить его. Он восстанавливает структуру графа и возвращает значение, Saverкоторое вы можете использовать для восстановления состояния модели:

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

Однако есть случаи, когда вам нужно что-то гораздо быстрее. Например, если вы реализуете раннюю остановку, вы хотите сохранять контрольные точки каждый раз, когда модель улучшается во время обучения (как измерено в наборе проверки), а затем, если в течение некоторого времени нет прогресса, вы захотите вернуться к лучшей модели. Если вы будете сохранять модель на диск каждый раз, когда она улучшается, это значительно замедлит процесс обучения. Хитрость заключается в том, чтобы сохранить состояния переменных в памяти , а затем просто восстановить их позже:

... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

Краткое объяснение: когда вы создаете переменную X, TensorFlow автоматически создает операцию присваивания, X/Assignчтобы установить начальное значение переменной. Вместо того, чтобы создавать заполнители и дополнительные операции присваивания (которые просто испортили бы график), мы просто используем эти существующие операции присваивания. Первый вход каждого присваивания op является ссылкой на переменную, которую он должен инициализировать, а второй input ( assign_op.inputs[1]) является начальным значением. Таким образом, чтобы установить любое значение, которое мы хотим (вместо начального значения), нам нужно использовать a feed_dictи заменить начальное значение. Да, TensorFlow позволяет указывать значение для любой операции, а не только для заполнителей, так что это прекрасно работает.


Спасибо за ответ. У меня похожий вопрос о том, как преобразовать один файл .ckpt в два .index и .data (скажем, для предварительно обученных начальных моделей, доступных на tf.slim). Мой вопрос здесь: stackoverflow.com/questions/47762114/…
Амир

Эй, я новичок в tenorflow и у меня проблемы с сохранением моей модели. Я был бы очень признателен, если бы вы могли помочь мне stackoverflow.com/questions/48083474/…
Ruchir Baronia

17

Как сказал Ярослав, вы можете взломать восстановление из graph_def и контрольной точки, импортировав график, создав переменные вручную, а затем используя Saver.

Я реализовал это для личного использования, поэтому я решил поделиться этим кодом здесь.

Ссылка: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(Это, конечно, взлом, и нет никакой гарантии, что модели, сохраненные таким образом, останутся читаемыми в будущих версиях TensorFlow.)


14

Если это внутренне сохраненная модель, вы просто указываете восстановитель для всех переменных как

restorer = tf.train.Saver(tf.all_variables())

и использовать его для восстановления переменных в текущем сеансе:

restorer.restore(self._sess, model_file)

Для внешней модели вам необходимо указать соответствие имен ее переменных именам ваших переменных. Вы можете просмотреть имена переменных модели, используя команду

python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

Сценарий inspect_checkpoint.py можно найти в папке «./tensorflow/python/tools» источника Tensorflow.

Чтобы указать отображение, вы можете использовать мой Tensorflow-Worklab , который содержит набор классов и сценариев для обучения и переподготовки различных моделей. Включает пример переподготовки моделей ResNet, расположенный здесь


all_variables()сейчас устарела
MiniQuark

Эй, я новичок в tenorflow и у меня проблемы с сохранением моей модели. Я был бы очень признателен, если бы вы могли помочь мне stackoverflow.com/questions/48083474/…
Ruchir Baronia

12

Вот мое простое решение для двух основных случаев, отличающихся тем, хотите ли вы загрузить график из файла или построить его во время выполнения.

Этот ответ верен для Tensorflow 0.12+ (включая 1.0).

Восстановление графика в коде

Сохранение

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

загрузка

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

Загрузка также графика из файла

При использовании этого метода убедитесь, что все ваши слои / переменные явно задают уникальные имена.В противном случае Tensorflow сам сделает имена уникальными, и поэтому они будут отличаться от имен, хранящихся в файле. В предыдущей технике это не проблема, поскольку имена «искажаются» одинаково как при загрузке, так и при сохранении.

Сохранение

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

загрузка

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection

-1 Начинать свой ответ, отбрасывая «все остальные ответы здесь», довольно резко. Тем не менее, я проголосовал по другим причинам: вы должны обязательно сохранить все глобальные переменные, а не только обучаемые переменные. Например, global_stepпеременная и скользящие средние значения нормализации партии являются необучаемыми переменными, но обе они, безусловно, заслуживают сохранения. Кроме того, вы должны более четко отличать построение графика от запуска сеанса, например Saver(...).save(), каждый раз при его запуске будут создаваться новые узлы. Наверное, не то, что вы хотите. И это еще не все ...: /
MiniQuark

@MiniQuark ок, спасибо за ваш отзыв, я отредактирую ответ в соответствии с вашими предложениями;)
Мартин

10

Вы также можете проверить примеры в TensorFlow / skflow , который предлагает saveи restoreметоды, которые могут помочь вам легко управлять вашими моделями. Он имеет параметры, которые вы также можете контролировать, как часто вы хотите создавать резервные копии вашей модели.


9

Если вы используете tf.train.MonitoredTrainingSession в качестве сеанса по умолчанию, вам не нужно добавлять дополнительный код для сохранения / восстановления. Просто передайте имя dir контрольной точки в конструктор MonitoredTrainingSession, он будет использовать сессионные перехватчики для их обработки.


Использование tf.train.Supervisor поможет вам создать такой сеанс и предоставит более полное решение.
Марк

1
@Mark tf.train.Supervisor устарела
Changming Sun

Есть ли у вас какая-либо ссылка, подтверждающая утверждение, что Supervisor устарел? Я не видел ничего, что указывает на то, что это так.
Марк


Спасибо за URL - я проверил с оригинальным источником информации, и мне сказали, что, вероятно, он будет примерно до конца серии TF 1.x, но никаких гарантий после этого.
Марк

8

Все ответы здесь великолепны, но я хочу добавить две вещи.

Во-первых, чтобы уточнить ответ @ user7505159, важно добавить «./» в начало восстанавливаемого имени файла.

Например, вы можете сохранить график без "./" в имени файла следующим образом:

# Some graph defined up here with specific names

saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

Но чтобы восстановить график, вам может понадобиться добавить «./» к имени файла:

# Same graph defined up here

saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)

Вам не всегда нужен «./», но это может вызвать проблемы в зависимости от вашей среды и версии TensorFlow.

Также необходимо упомянуть, что это sess.run(tf.global_variables_initializer())может быть важно перед восстановлением сеанса.

Если при попытке восстановить сохраненный сеанс вы получаете сообщение об ошибке в отношении неинициализированных переменных, убедитесь, что вы указали sess.run(tf.global_variables_initializer())перед saver.restore(sess, save_file)строкой. Это может спасти вас от головной боли.


7

Как описано в выпуске 6255 :

use '**./**model_name.ckpt'
saver.restore(sess,'./my_model_final.ckpt')

вместо

saver.restore('my_model_final.ckpt')

7

Согласно новой версии Tensorflow, tf.train.Checkpointэто предпочтительный способ сохранения и восстановления модели:

Checkpoint.saveи Checkpoint.restoreзаписывать и читать объектные контрольные точки, в отличие от tf.train.Saver, который записывает и читает контрольные точки на основе variable.name. Объектная контрольная точка сохраняет график зависимостей между объектами Python (Layers, Optimizer, Variables и т. Д.) С именованными ребрами, и этот график используется для сопоставления переменных при восстановлении контрольной точки. Он может быть более устойчивым к изменениям в программе Python и помогает поддерживать восстановление при создании переменных при их активном выполнении. Предпочитаю tf.train.Checkpointболее tf.train.Saverдля нового кода .

Вот пример:

import tensorflow as tf
import os

tf.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Больше информации и пример здесь.


7

Для tenorflow 2.0 это так же просто, как

# Save the model
model.save('path_to_my_model.h5')

Для восстановления:

new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')

Как насчет всех пользовательских операций и переменных, которые не являются частью объекта модели? Будут ли они каким-то образом сохранены при вызове save () для модели? У меня есть различные пользовательские выражения потерь и тензорных вероятностей, которые используются в сети логического вывода и генерации, но они не являются частью моей модели. Мой объект модели keras содержит только плотные и конверсионные слои. В TF 1 я просто вызвал метод save, и я мог быть уверен, что все операции и тензоры, используемые в моем графике, будут сохранены. В TF2 я не вижу, как будут сохраняться операции, которые каким-то образом не добавляются в модель keras.
Кристоф

Есть ли еще информация о восстановлении моделей в TF 2.0? Я не могу восстановить веса из контрольных файлов , генерируемых с помощью C API, см: stackoverflow.com/questions/57944786/...
jregalad


5

tf.keras Сохранение модели с помощью TF2.0

Я вижу отличные ответы для сохранения моделей с использованием TF1.x. Я хочу предоставить еще несколько советов по сохранению tensorflow.kerasмоделей, что немного сложно, так как существует множество способов сохранить модель.

Здесь я приведу пример сохранения tensorflow.kerasмодели в model_pathпапку в текущем каталоге. Это хорошо работает с самым последним тензорным потоком (TF2.0). Я обновлю это описание, если будут какие-либо изменения в ближайшем будущем.

Сохранение и загрузка всей модели

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

#import data
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# create a model
def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
# compile the model
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
  return model

# Create a basic model instance
model=create_model()

model.fit(x_train, y_train, epochs=1)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save entire model to a HDF5 file
model.save('./model_path/my_model.h5')

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('./model_path/my_model.h5')
loss, acc = new_model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Сохранение и загрузка модели Только веса

Если вас интересует только сохранение весов моделей, а затем загрузка весов для восстановления модели, тогда

model.fit(x_train, y_train, epochs=5)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Сохранение и восстановление с помощью обратного вызова контрольной точки keras

# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    period=5)

model = create_model()
model.save_weights(checkpoint_path.format(epoch=0))
model.fit(train_images, train_labels,
          epochs = 50, callbacks = [cp_callback],
          validation_data = (test_images,test_labels),
          verbose=0)

latest = tf.train.latest_checkpoint(checkpoint_dir)

new_model = create_model()
new_model.load_weights(latest)
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

сохранение модели с пользовательскими метриками

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Custom Loss1 (for example) 
@tf.function() 
def customLoss1(yTrue,yPred):
  return tf.reduce_mean(yTrue-yPred) 

# Custom Loss2 (for example) 
@tf.function() 
def customLoss2(yTrue, yPred):
  return tf.reduce_mean(tf.square(tf.subtract(yTrue,yPred))) 

def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),  
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', customLoss1, customLoss2])
  return model

# Create a basic model instance
model=create_model()

# Fit and evaluate model 
model.fit(x_train, y_train, epochs=1)
loss, acc,loss1, loss2 = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

model.save("./model.h5")

new_model=tf.keras.models.load_model("./model.h5",custom_objects={'customLoss1':customLoss1,'customLoss2':customLoss2})

Сохранение модели keras с пользовательскими операциями

Когда у нас есть пользовательские операции, как в следующем случае ( tf.tile), нам нужно создать функцию и обернуть ее слоем Lambda. В противном случае модель не может быть сохранена.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras import Model

def my_fun(a):
  out = tf.tile(a, (1, tf.shape(a)[0]))
  return out

a = Input(shape=(10,))
#out = tf.tile(a, (1, tf.shape(a)[0]))
out = Lambda(lambda x : my_fun(x))(a)
model = Model(a, out)

x = np.zeros((50,10), dtype=np.float32)
print(model(x).numpy())

model.save('my_model.h5')

#load the model
new_model=tf.keras.models.load_model("my_model.h5")

Я думаю, что я рассмотрел несколько способов сохранения модели tf.keras. Однако есть много других способов. Пожалуйста, прокомментируйте ниже, если вы видите, что ваш случай использования не описан выше. Спасибо!


3

Используйте tf.train.Saver для сохранения модели, напомните, вам нужно указать var_list, если вы хотите уменьшить размер модели. Val_list может быть tf.trainable_variables или tf.global_variables.


3

Вы можете сохранить переменные в сети, используя

saver = tf.train.Saver() 
saver.save(sess, 'path of save/fileName.ckpt')

Чтобы восстановить сеть для повторного использования позже или в другом сценарии, используйте:

saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('path of save/')
sess.run(....) 

Важные точки:

  1. sess должно быть одинаковым между первым и последующим прогонами (связная структура).
  2. saver.restore нужен путь к папке с сохраненными файлами, а не отдельный путь к файлу.

2

Везде, где вы хотите сохранить модель,

self.saver = tf.train.Saver()
with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ...
            self.saver.save(sess, filename)

Убедитесь, что у всех tf.Variableесть имена, потому что вы можете восстановить их позже, используя их имена. И где вы хотите предсказать,

saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file' 
with tf.Session() as sess:
      saver.restore(sess, name)
      print(sess.run('W1:0')) #example to retrieve by variable name

Убедитесь, что заставка работает внутри соответствующего сеанса. Помните, что если вы используете tf.train.latest_checkpoint('./'), то будет использоваться только последняя контрольная точка.


2

Я на версии:

tensorflow (1.13.1)
tensorflow-gpu (1.13.1)

Простой способ

Сохранить:

model.save("model.h5")

Восстановить:

model = tf.keras.models.load_model("model.h5")

2

Для tensflow-2.0

это очень просто

import tensorflow as tf

СПАСТИ

model.save("model_name")

ВОССТАНОВИТЬ

model = tf.keras.models.load_model('model_name')

1

После ответа @Vishnuvardhan Janapati, вот еще один способ сохранить и перезагрузить модель с пользовательским слоем / метрикой / потерей в TensorFlow 2.0.0.

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils.generic_utils import get_custom_objects

# custom loss (for example)  
def custom_loss(y_true,y_pred):
  return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss}) 

# custom loss (for example) 
class CustomLayer(Layer):
  def __init__(self, ...):
      ...
  # define custom layer and all necessary custom operations inside custom layer

get_custom_objects().update({'CustomLayer': CustomLayer})  

Таким образом, после того, как вы выполнили такие коды и сохранили свою модель с помощью tf.keras.models.save_modelили model.saveили с помощью функции ModelCheckpointобратного вызова, вы можете перезагрузить вашу модель без необходимости точных пользовательских объектов, таких простых, как

new_model = tf.keras.models.load_model("./model.h5"})

0

В новой версии tenorflow 2.0 процесс сохранения / загрузки модели стал намного проще. Из-за реализации API Keras, высокоуровневого API для TensorFlow.

Чтобы сохранить модель: проверьте документацию для справки: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model

tf.keras.models.save_model(model_name, filepath, save_format)

Чтобы загрузить модель:

https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model

model = tf.keras.models.load_model(filepath)

0

Вот простой пример использования Tensorflow 2.0 SavedModel формата (который является рекомендуемым форматом, в соответствии с Документами ) для простого MNIST набора данных классификатора, используя Keras функционального API не слишком много фантазии происходит:

# Imports
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt

# Load data
mnist = tf.keras.datasets.mnist # 28 x 28
(x_train,y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixels [0,255] -> [0,1]
x_train = tf.keras.utils.normalize(x_train,axis=1)
x_test = tf.keras.utils.normalize(x_test,axis=1)

# Create model
input = Input(shape=(28,28), dtype='float64', name='graph_input')
x = Flatten()(input)
x = Dense(128, activation='relu')(x)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
model = Model(inputs=input, outputs=output)

model.compile(optimizer='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

# Train
model.fit(x_train, y_train, epochs=3)

# Save model in SavedModel format (Tensorflow 2.0)
export_path = 'model'
tf.saved_model.save(model, export_path)

# ... possibly another python program 

# Reload model
loaded_model = tf.keras.models.load_model(export_path) 

# Get image sample for testing
index = 0
img = x_test[index] # I normalized the image on a previous step

# Predict using the signature definition (Tensorflow 2.0)
predict = loaded_model.signatures["serving_default"]
prediction = predict(tf.constant(img))

# Show results
print(np.argmax(prediction['graph_output']))  # prints the class number
plt.imshow(x_test[index], cmap=plt.cm.binary)  # prints the image

Что такое serving_default?

Это имя определения подписи выбранного вами тега (в данном случае serveбыл выбран тег по умолчанию ). Также здесь объясняется, как найти теги и подписи модели с помощью saved_model_cli.

Отказ от ответственности

Это просто базовый пример, если вы просто хотите запустить его, но он ни в коем случае не является полным ответом - возможно, я смогу обновить его в будущем. Я просто хотел привести простой пример, используяSavedModel TF 2.0, потому что я нигде не видел, даже такого простого.

@ Ответ Тома - пример SavedModel, но он не будет работать на Tensorflow 2.0, потому что, к сожалению, есть некоторые серьезные изменения.

Ответ @ Vishnuvardhan Janapati говорит о TF 2.0, но это не для формата SavedModel.

Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.