Решение ниже
Если вы просто заинтересованы в решении этой проблемы, вы можете перейти к моему ответу ниже.
Исходный вопрос
Я использую tensorflow для обучения с подкреплением. Рой агентов использует модель параллельно, и один центральный объект обучает ее на собранных данных.
Я нашел здесь: Это потокобезопасный при использовании tf.Session в службе логического вывода? что сеансы tensorflow являются потокобезопасными. Поэтому я просто позволяю прогнозированию и обновлению работать параллельно.
Но теперь я хотел бы изменить установку. Вместо того, чтобы обновлять и тренировать одну модель, теперь мне нужно иметь две модели. Один используется для предсказания, а второй обучается. После нескольких шагов обучения веса из второго копируются в первый. Ниже приведен минимальный пример в keras. Для многопроцессорности рекомендуется доработать график, но тогда я не могу скопировать веса:
# the usual imports
import numpy as np
import tensorflow as tf
from keras.models import *
from keras.layers import *
# set up the first model
i = Input(shape=(10,))
b = Dense(1)(i)
prediction_model = Model(inputs=i, outputs=b)
# set up the second model
i2 = Input(shape=(10,))
b2 = Dense(1)(i2)
training_model = Model(inputs=i2, outputs=b2)
# look at this code, to check if the weights are the same
# here the output is different
prediction_model.predict(np.ones((1, 10)))
training_model.predict(np.ones((1, 10)))
# now to use them in multiprocessing, the following is necessary
prediction_model._make_predict_function()
training_model._make_predict_function()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
default_graph = tf.get_default_graph()
# the following line is the critical part
# if this is uncommented, the two options below both fail
# default_graph.finalize()
# option 1, use keras methods to update the weights
prediction_model.set_weights(training_model.get_weights())
# option 2, use tensorflow to update the weights
update_ops = [tf.assign(to_var, from_var) for to_var, from_var in
zip(prediction_model.trainable_weights, training_model.trainable_weights)]
sess.run(update_ops)
# now the predictions are the same
prediction_model.predict(np.ones((1, 10)))
training_model.predict(np.ones((1, 10)))
По вопросу выше рекомендуется доработать график. Если он не завершен, возможны утечки памяти (!?), так что это кажется настоятельной рекомендацией.
Но если я его доработаю, то уже не смогу обновить веса. Что меня смущает в этом, так это то, что сеть можно обучить, поэтому разрешено изменение весов. Назначение выглядит так, как будто веса просто перезаписываются, почему это отличается от применения шага оптимизатора?