Короткий ответ:
- Вам нужно передать глобальный шаг оптимизатору, который вы передаете в mon_sess.run. Это позволяет как сохранять, так и извлекать сохраненные контрольные точки.
- Можно запустить сеанс обучения + перекрестной проверки одновременно с помощью одного сеанса MonitoredTrainingSession. Во-первых, вам необходимо пройти через пакеты обучения и пакеты перекрестной проверки через отдельные потоки одного и того же графика (я рекомендую вам поискать это руководство, чтобы узнать, как это сделать). Во-вторых, вы должны - в mon_sess.run () - передать оптимизатор для обучающего потока, а также параметр для потери (/ параметр, который вы хотите отслеживать) потока перекрестной проверки. Если вы хотите запустить тестовый сеанс отдельно от обучения, просто запустите только набор тестов через график и запустите только test_loss (/ другие параметры, которые вы хотите отслеживать) через график. Подробнее о том, как это делается, читайте ниже.
Длинный ответ:
Я обновлю свой ответ, поскольку я сам лучше понимаю, что можно сделать с tf.train.MonitoredSession (tf.train.MonitoredTrainingSession просто создает специализированную версию tf.train.MonitoredSession, как можно увидеть в исходный код).
Ниже приведен пример кода, показывающий, как вы можете сохранять контрольные точки каждые 5 секунд в './ckpt_dir'. При прерывании он перезапустится с последней сохраненной контрольной точки:
def train(inputs, labels_onehot, global_step):
out = tf.contrib.layers.fully_connected(
inputs,
num_outputs=10,
activation_fn=tf.nn.sigmoid)
loss = tf.reduce_mean(
tf.reduce_sum(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=out,
labels=labels_onehot), axis=1))
train_op = opt.minimize(loss, global_step=global_step)
return train_op
with tf.Graph().as_default():
global_step = tf.train.get_or_create_global_step()
inputs = ...
labels_onehot = ...
train_op = train(inputs, labels_onehot, global_step)
with tf.train.MonitoredTrainingSession(
checkpoint_dir='./ckpt_dir',
save_checkpoint_secs=5,
hooks=[ ... ] # Choose your hooks
) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
То, что происходит в MonitoredTrainingSession для достижения этой цели, на самом деле состоит из трех вещей:
- Tf.train.MonitoredTrainingSession создает tf.train. Объект Scaffold, который работает как паук в сети; он собирает элементы, необходимые для обучения, сохранения и загрузки модели.
- Он создает объект tf.train.ChiefSessionCreator. Мои знания об этом ограничены, но, насколько я понимаю, он используется, когда ваш алгоритм tf распространяется на несколько серверов. Я считаю, что он сообщает компьютеру, на котором запущен файл, что это главный компьютер, и что именно здесь должен быть сохранен каталог контрольных точек, и что регистраторы должны регистрировать свои данные здесь и т. Д.
- Он создает tf.train.CheckpointSaverHook, который используется для сохранения КПП.
Чтобы заставить его работать, tf.train.CheckpointSaverHook и tf.train.ChiefSessionCreator должны иметь одинаковые ссылки на каталог контрольных точек и шаблон. Если бы tf.train.MonitoredTrainingSession с его параметрами из приведенного выше примера был реализован с помощью трех компонентов, указанных выше, это выглядело бы примерно так:
checkpoint_dir = './ckpt_dir'
scaffold = tf.train.Scaffold()
saverhook = tf.train.CheckpointSaverHook(
checkpoint_dir=checkpoint_dir,
save_secs=5
scaffold=scaffold
)
session_creator = tf.train.ChiefSessionCreator(
scaffold=scaffold,
checkpoint_dir=checkpoint_dir
)
with tf.train.MonitoredSession(
session_creator=session_creator,
hooks=[saverhook]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
Чтобы выполнить сеанс проверки поезд + перекрестная проверка, вы можете использовать tf.train.MonitoredSession.run_step_fn () вместе с partial, которое запускает вызов сеанса без вызова каких-либо перехватчиков. Это выглядит так: вы тренируете свою модель для n итераций, а затем запускаете свой набор тестов, повторно инициализируете итераторы и возвращаетесь к обучению своей модели и т. Д. Конечно, вы должны установить свой переменные для повторного использования = tf.AUTO_REUSE при этом. Способ сделать это в коде показан ниже:
from functools import partial
# Build model
...
with tf.variable_scope(..., reuse=tf.AUTO_REUSE):
...
...
def step_fn(fetches, feed_dict, step_context):
return step_context.session.run(fetches=fetches, feed_dict=feed_dict)
with tf.train.MonitoredTrainingSession(
checkpoint_dir=...,
save_checkpoint_steps=...,
hooks=[...],
...
) as mon_sess:
# Initialize iterators (assuming tf.Databases are used)
mon_sess.run_step_fn(
partial(
step_fn,
[train_it.initializer,
test_it.initializer,
...
],
{}
)
)
while not mon_sess.should_stop():
# Train session
for i in range(n):
try:
train_results = mon_sess.run(<train_fetches>)
except Exception as e:
break
# Test session
while True:
try:
test_results = mon_sess.run(<test_fetches>)
except Exception as e:
break
# Reinitialize parameters
mon_sess.run_step_fn(
partial(
step_fn,
[train_it.initializer,
test_it.initializer,
...
],
{}
)
)
Частичная функция просто выполняет каррирование (классическая функция в функциональном программировании) на step_fn, который используется в mon_sess.run_step_fn (). Весь приведенный выше код не был протестирован, и вам, возможно, придется повторно инициализировать train_it перед запуском сеанса тестирования, но, надеюсь, теперь ясно, как можно запустить как обучающий набор, так и набор проверки в одном и том же запуске. Кроме того, это можно использовать вместе с инструментом custom_scalar tenorboard. если вы хотите построить кривую обучения и кривую теста на одном графике.
Наконец, это лучшая реализация этой функциональности, которую мне удалось сделать, и я лично надеюсь, что tenorflow значительно упростит реализацию этой функциональности в будущем, поскольку это довольно утомительно и, вероятно, не так эффективно. Я знаю, что существуют такие инструменты, как оценщик, который может запускать train_and_evaluate, но при этом перестраивается график между каждым поездом и перекрестная проверка, это очень неэффективно, если вы работаете только на одном компьютере. Я где-то читал, что у Keras + tf есть эта функция, но, поскольку я не использую Keras + tf, это не вариант. В любом случае, я надеюсь, что это может помочь кому-то другому, борющемуся с такими же проблемами!
person
Andreas Forslöw
schedule
08.02.2018