тензорный поток: назначение весов после завершения графа

Решение ниже

Если вы просто заинтересованы в решении этой проблемы, вы можете перейти к моему ответу ниже.

Исходный вопрос

Я использую tensorflow для обучения с подкреплением. Рой агентов использует модель параллельно, и один центральный объект обучает ее на собранных данных.

Я нашел здесь: Это потокобезопасный при использовании tf.Session в службе логического вывода? что сеансы tensorflow являются потокобезопасными. Поэтому я просто позволяю прогнозированию и обновлению работать параллельно.

Но теперь я хотел бы изменить установку. Вместо того, чтобы обновлять и тренировать одну модель, теперь мне нужно иметь две модели. Один используется для предсказания, а второй обучается. После нескольких шагов обучения веса из второго копируются в первый. Ниже приведен минимальный пример в keras. Для многопроцессорности рекомендуется доработать график, но тогда я не могу скопировать веса:

# the usual imports
import numpy as np
import tensorflow as tf

from keras.models import *
from keras.layers import *

# set up the first model
i = Input(shape=(10,))
b = Dense(1)(i)
prediction_model = Model(inputs=i, outputs=b)

# set up the second model
i2 = Input(shape=(10,))
b2 = Dense(1)(i2)
training_model = Model(inputs=i2, outputs=b2)

# look at this code, to check if the weights are the same
# here the output is different
prediction_model.predict(np.ones((1, 10)))
training_model.predict(np.ones((1, 10)))

# now to use them in multiprocessing, the following is necessary
prediction_model._make_predict_function()
training_model._make_predict_function()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
default_graph = tf.get_default_graph()

# the following line is the critical part
# if this is uncommented, the two options below both fail
# default_graph.finalize()

# option 1, use keras methods to update the weights
prediction_model.set_weights(training_model.get_weights())

# option 2, use tensorflow to update the weights
update_ops = [tf.assign(to_var, from_var) for to_var, from_var in
              zip(prediction_model.trainable_weights, training_model.trainable_weights)]
sess.run(update_ops)

# now the predictions are the same
prediction_model.predict(np.ones((1, 10)))
training_model.predict(np.ones((1, 10)))

По вопросу выше рекомендуется доработать график. Если он не завершен, возможны утечки памяти (!?), так что это кажется настоятельной рекомендацией.

Но если я его доработаю, то уже не смогу обновить веса. Что меня смущает в этом, так это то, что сеть можно обучить, поэтому разрешено изменение весов. Назначение выглядит так, как будто веса просто перезаписываются, почему это отличается от применения шага оптимизатора?


person lhk    schedule 17.06.2018    source источник


Ответы (1)


Короче говоря, моя проблема заключалась в том, чтобы присвоить значения весам окончательного графа. Если это присваивание сделать после финализации, tensorflow жалуется, что граф больше не может быть изменен.

Я был озадачен, почему это запрещено. Ведь разрешено изменение весов обратным распространением.

Но проблема не связана с изменением весов. Keras set_weights() сбивает с толку, потому что кажется, что веса просто перезаписаны (как в backprop). Собственно, за кулисами добавляются и выполняются операции присваивания. Эти новые операции представляют изменение в графе, и это изменение запрещено.

Таким образом, решение состоит в том, чтобы настроить операции присваивания до окончательной обработки графа. Вы должны изменить порядок кода:

# the usual imports
import numpy as np
import tensorflow as tf

from keras.models import *
from keras.layers import *

# set up the first model
i = Input(shape=(10,))
b = Dense(1)(i)
prediction_model = Model(inputs=i, outputs=b)

# set up the second model
i2 = Input(shape=(10,))
b2 = Dense(1)(i2)
training_model = Model(inputs=i2, outputs=b2)

# set up operations to move weights from training to prediction
update_ops = [tf.assign(to_var, from_var) for to_var, from_var in
              zip(prediction_model.trainable_weights, training_model.trainable_weights)]

# now to use them in multiprocessing, the following is necessary
prediction_model._make_predict_function()
training_model._make_predict_function()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
default_graph = tf.get_default_graph()

default_graph.finalize()

# this can be executed now
sess.run(update_ops)

# now the predictions are the same
prediction_model.predict(np.ones((1, 10)))
training_model.predict(np.ones((1, 10)))
person lhk    schedule 17.06.2018