Повторное использование переменной внедрения для логического вывода в Tf.Estimator API

В NMT, использующем архитектуру seq2seq, во время вывода нам нужна переменная внедрения, обученная на этапе обучения, в качестве входных данных для GreedyEmbeddingHelper или BeamSearchDecoder.

Вопрос в том, в контексте обучения и вывода с использованием Estimator API, как мы можем извлечь эту обученную переменную внедрения, чтобы использовать ее для прогнозирования?


person cad86    schedule 25.03.2018    source источник
comment
stackoverflow.com/questions/ 37660685/ вам помочь?   -  person bantmen    schedule 26.03.2018
comment
Не совсем. В реализации seq2seq в Estimator API выходные встраивания обычно обучаются в предложении IF, к которому можно получить доступ только во время обучения и оценки, поскольку на этих двух этапах вы уже знаете выходные данные. Для предсказания у вас нет, поэтому вы не можете получить доступ к этому биту. Хотя спасибо за ссылку.   -  person cad86    schedule 26.03.2018


Ответы (1)


Я нашел решение, основанное на следующем stackoverflow ">ответить. На этапе прогнозирования вы можете использовать tf.contrib.framework.load_variable для извлечения переменной внедрения из обученной и сохраненной модели Tensorflow следующим образом:

if mode == tf.estimator.ModeKeys.PREDICT:
    embeddings = tf.constant(tf.contrib.framework.load_variable('.','embed/embeddings'))
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding=embeddings,
    start_tokens=tf.fill([batch_size], 1),end_token=0)

Итак, в моем случае я запускал код из той же папки, содержащей сохраненную модель, и имя моей переменной было «внедрение/встраивание». Обратите внимание, что это работает только с вложениями, обученными с помощью модели тензорного потока. В противном случае обратитесь к ответу, указанному выше.

Чтобы найти имя переменной с помощью API оценки, вы можете использовать метод get_variable_names() для получения списка всех имен переменных, сохраненных на графике.

person cad86    schedule 26.03.2018