Сохранение кодировщика тензорного потока, декодера и внимания

Начните обучение простого NMT (нейронного машинного переводчика) с вниманием, используя кодировщик и декодер, Обучение проводилось на Colab,

encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)

Затем используйте контрольные точки, чтобы сохранить модель,

# On loacl machine dir changed to 'training_checkpoints/' to fit the loaction
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

И сэкономить при обучении с помощью

checkpoint.save(file_prefix = checkpoint_prefix)

После тренировки восстановление контрольных точек отлично работает на Colab, и даже при сохранении всей папки контрольных точек на диске Google и их повторном восстановлении, но при попытке восстановить их на моем локальном компьютере он возвращает разные и ненужные результаты, Запустите контрольную точку перед обучением с использованием

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

Вывод записной книжки Colab:

Input: <start> يلعبون الكرة <end>
Predicted translation: he played soccer . <end> 

Выход локального компьютера:

Input: <start> يلعبون الكرة <end>
Predicted translation: take either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either

Версия Colab tensorflow: 1.13.0-rc1

Версия tenorflow для локальной машины: 1.12.0

Как сохранить модель, не столкнувшись с этой проблемой, зная, что эта проблема связана с разными версиями tenorflow?

Дополнительная ссылка для записной книжки NMT nmt_with_attention_nattention"> Нейронный машинный перевод с вниманием


person Samir    schedule 13.02.2019    source источник


Ответы (1)


TF дает только гарантии прямой совместимости: https://www.tensorflow.org/guide/version_compat#compatibility_of_graphs_and_checkpoints Неудивительно, что 1.13 сохраняет файл, который 1.12 не может восстановить. Обновите тензорный поток вашей локальной машины?

person Ami F    schedule 16.02.2019
comment
Хорошо, я понял, но как я могу сохранить модель с ее весами кодировщика и декодера без контрольных точек и с помощью tf.train.Saver? - person Samir; 18.02.2019