Я хотел бы запустить данную модель как в наборе поездов (is_training=True
), так и в наборе проверки (is_training=False
), в частности, с тем, как применяется dropout
. Прямо сейчас готовые модели предоставить параметр is_training
, который передается на уровень dropout
при построении сети. Проблема в том, что если я вызову метод дважды с разными значениями is_training
, я получу две разные сети, которые не имеют общих весов (я так думаю?). Как мне заставить две сети использовать одинаковые веса, чтобы я мог запустить сеть, которую я обучил на проверочном наборе?
Модель Tensorflow (tf-slim) с is_training True и False
Ответы (3)
Я написал решение с вашим комментарием, чтобы использовать Overfeat в режиме обучения и тестирования. (Я не смог проверить это, чтобы вы могли проверить, работает ли оно?)
Сначала немного импорта и параметров:
import tensorflow as tf
slim = tf.contrib.slim
overfeat = tf.contrib.slim.nets.overfeat
batch_size = 32
inputs = tf.placeholder(tf.float32, [batch_size, 231, 231, 3])
dropout_keep_prob = 0.5
num_classes = 1000
В режиме обучения мы передаем нормальную область в функцию overfeat
:
scope = 'overfeat'
is_training = True
output = overfeat.overfeat(inputs, num_classes, is_training,
dropout_keep_prob, scope=scope)
Затем в тестовом режиме мы создаем тот же объем, но с reuse=True
.
scope = tf.VariableScope(reuse=True, name='overfeat')
is_training = False
output = overfeat.overfeat(inputs, num_classes, is_training,
dropout_keep_prob, scope=scope)
вы можете просто использовать заполнитель для is_training:
isTraining = tf.placeholder(tf.bool)
# create nn
net = ...
net = slim.dropout(net,
keep_prob=0.5,
is_training=isTraining)
net = ...
# training
sess.run([net], feed_dict={isTraining: True})
# testing
sess.run([net], feed_dict={isTraining: False})
Это зависит от случая, решения разные.
Мой первый вариант — использовать другой процесс для проведения оценки. Вам нужно только проверить наличие новой контрольной точки и загрузить ее веса в оценочную сеть (с помощью is_training=False
):
checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
# wait until a new check point is available
while self.lastest_checkpoint == checkpoint:
time.sleep(30) # sleep 30 seconds waiting for a new checkpoint
checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
logging.info('Restoring model from {}'.format(checkpoint))
self.saver.restore(session, checkpoint)
self.lastest_checkpoint = checkpoint
Второй вариант: после каждой эпохи вы выгружаете график и создаете новый график оценки. Это решение тратит много времени на загрузку и выгрузку графиков.
Третий вариант — разделить веса. Но заполнение этих сетей очередями или наборами данных может привести к проблемам, поэтому вы должны быть очень осторожны. Я использую это только для сиамских сетей.
with tf.variable_scope('the_scope') as scope:
your_model(is_training=True)
scope.reuse_variables()
your_model(is_training=False)
tf-slim
используетtf.get_variable()
, который повторно использует переменные между вызовами. - person Olivier Moindrot   schedule 06.09.2016scope
, а затем в целях безопасности лучше также установить значениеreuse=True
. - person Alex Rothberg   schedule 06.09.2016