Восстановите обученную модель тензорного потока, отредактируйте значение, связанное с узлом, и сохраните его.

Я обучил модель с помощью tensorflow и использовал пакетную нормализацию во время обучения. Для пакетной нормализации требуется, чтобы пользователь передал логическое значение is_training, чтобы указать, находится ли модель в фазе обучения или тестирования.

Когда модель обучалась, is_training был установлен как константа, как показано ниже.

is_training = tf.constant(True, dtype=tf.bool, name='is_training')

Я сохранил обученную модель, файлы включают контрольную точку, файл .meta, файл .index и .data. Я хочу восстановить модель и запустить логический вывод с ее помощью. Модель не может быть переобучена. Итак, я хотел бы восстановить существующую модель, установить значение is_training на False, а затем снова сохранить модель. Как я могу отредактировать логическое значение, связанное с этим узлом, и снова сохранить модель?


person Effective_cellist    schedule 17.08.2017    source источник
comment
было бы проще, если бы вы использовали is_training=tf.Variable.., а не константу   -  person Ishant Mrinal    schedule 17.08.2017
comment
Есть ли причина, по которой is_training должна быть постоянной тензорного потока? Разве это не может быть Python bool? Обратите внимание, что изменение is_training на python bool не должно приводить к ошибкам при восстановлении модели.   -  person GeertH    schedule 17.08.2017
comment
@GeertH Может быть, вопрос в том, как мне установить is_training в False после загрузки модели, а затем сохранить ее обратно. Чтобы при повторном восстановлении узел имел значение False.   -  person Effective_cellist    schedule 17.08.2017


Ответы (1)


Вы можете использовать аргумент input_map в tf.train.import_meta_graph чтобы переназначить тензор графика на обновленное значение.

config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
    # define the new is_training tensor
    is_training = tf.constant(False, dtype=tf.bool, name='is_training')

    # now import the graph using the .meta file of the checkpoint
    saver = tf.train.import_meta_graph(
    '/path/to/model.meta', input_map={'is_training:0':is_training})

    # restore all weights using the model checkpoint 
    saver.restore(sess, '/path/to/model')

    # save updated graph and variables values
    saver.save(sess, '/path/to/new-model-name')
person Ishant Mrinal    schedule 17.08.2017
comment
Приведенный выше код вызывает ошибку ValueError: tf.import_graph_def() requires a non-empty name if input_map is used - person Effective_cellist; 17.08.2017
comment
Я протестировал этот код, используя tensorflow==1.2.0, надеюсь, это поможет; ТАКЖЕ это не tf.import_graph_def. см. мой код. - person Ishant Mrinal; 17.08.2017
comment
Я пробовал ваш код таким, какой он есть, ошибка выдается этой строкой, saver = tf.train.import_meta_graph(r'D:\code\iprings\k-fold-model\VanillaCNN_24.0000.meta', input_map={'is_training':is_training}) - person Effective_cellist; 17.08.2017
comment
я бы посоветовал вам использовать tf == 1.2.0 - person Ishant Mrinal; 17.08.2017