Как MonitoredTrainingSession () работает в режиме восстановления и тестирования?

В Tensorflow мы могли создавать и создавать несколько сеансов Tensorflow, используя Between-graph Replication для распределенного обучения. MonitoredTrainingSession() координирует несколько сеансов Tensorflow, и есть аргумент checkpoint_dir для MonitoredTrainingSession() для восстановления сеанса / графика Tensorflow. Теперь у меня следующие вопросы:

  1. Обычно мы используем объект tf.train.Saver() для восстановления графиков Tensorflow с помощью saver.restore(...). Но как их восстановить с помощью MonitoredTrainingSession()?
  2. Поскольку мы запускаем несколько процессов, и каждый процесс строит и создает сеанс Tensorflow для обучения, мне интересно, должны ли мы также запускать несколько процессов для тестирования (или прогнозирования) после обучения. Другими словами, как MonitoredTrainingSession() работает с режимом тестирования (или прогнозирования)?

Я прочитал Tensorflow Doc, но не нашел ответов на эти 2 вопроса. Я очень ценю, если у кого-то есть решения. Спасибо!


person Ruofan Kong    schedule 29.03.2017    source источник


Ответы (3)


Короткий ответ:

  1. Вам нужно передать глобальный шаг оптимизатору, который вы передаете в mon_sess.run. Это позволяет как сохранять, так и извлекать сохраненные контрольные точки.
  2. Можно запустить сеанс обучения + перекрестной проверки одновременно с помощью одного сеанса MonitoredTrainingSession. Во-первых, вам необходимо пройти через пакеты обучения и пакеты перекрестной проверки через отдельные потоки одного и того же графика (я рекомендую вам поискать это руководство, чтобы узнать, как это сделать). Во-вторых, вы должны - в mon_sess.run () - передать оптимизатор для обучающего потока, а также параметр для потери (/ параметр, который вы хотите отслеживать) потока перекрестной проверки. Если вы хотите запустить тестовый сеанс отдельно от обучения, просто запустите только набор тестов через график и запустите только test_loss (/ другие параметры, которые вы хотите отслеживать) через график. Подробнее о том, как это делается, читайте ниже.

Длинный ответ:

Я обновлю свой ответ, поскольку я сам лучше понимаю, что можно сделать с tf.train.MonitoredSession (tf.train.MonitoredTrainingSession просто создает специализированную версию tf.train.MonitoredSession, как можно увидеть в исходный код).

Ниже приведен пример кода, показывающий, как вы можете сохранять контрольные точки каждые 5 секунд в './ckpt_dir'. При прерывании он перезапустится с последней сохраненной контрольной точки:

def train(inputs, labels_onehot, global_step):
    out = tf.contrib.layers.fully_connected(
                            inputs,
                            num_outputs=10,
                            activation_fn=tf.nn.sigmoid)
    loss = tf.reduce_mean(
             tf.reduce_sum(
                tf.nn.sigmoid_cross_entropy_with_logits(
                            logits=out,
                            labels=labels_onehot), axis=1))
    train_op = opt.minimize(loss, global_step=global_step)
    return train_op

with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()
    inputs = ...
    labels_onehot = ...
    train_op = train(inputs, labels_onehot, global_step)

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir='./ckpt_dir',
        save_checkpoint_secs=5,
        hooks=[ ... ] # Choose your hooks
    ) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(train_op)

То, что происходит в MonitoredTrainingSession для достижения этой цели, на самом деле состоит из трех вещей:

  1. Tf.train.MonitoredTrainingSession создает tf.train. Объект Scaffold, который работает как паук в сети; он собирает элементы, необходимые для обучения, сохранения и загрузки модели.
  2. Он создает объект tf.train.ChiefSessionCreator. Мои знания об этом ограничены, но, насколько я понимаю, он используется, когда ваш алгоритм tf распространяется на несколько серверов. Я считаю, что он сообщает компьютеру, на котором запущен файл, что это главный компьютер, и что именно здесь должен быть сохранен каталог контрольных точек, и что регистраторы должны регистрировать свои данные здесь и т. Д.
  3. Он создает tf.train.CheckpointSaverHook, который используется для сохранения КПП.

Чтобы заставить его работать, tf.train.CheckpointSaverHook и tf.train.ChiefSessionCreator должны иметь одинаковые ссылки на каталог контрольных точек и шаблон. Если бы tf.train.MonitoredTrainingSession с его параметрами из приведенного выше примера был реализован с помощью трех компонентов, указанных выше, это выглядело бы примерно так:

checkpoint_dir = './ckpt_dir'

scaffold = tf.train.Scaffold()
saverhook = tf.train.CheckpointSaverHook(
    checkpoint_dir=checkpoint_dir,
    save_secs=5
    scaffold=scaffold
)
session_creator = tf.train.ChiefSessionCreator(
    scaffold=scaffold,
    checkpoint_dir=checkpoint_dir
)

with tf.train.MonitoredSession(
    session_creator=session_creator,
    hooks=[saverhook]) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(train_op)

Чтобы выполнить сеанс проверки поезд + перекрестная проверка, вы можете использовать tf.train.MonitoredSession.run_step_fn () вместе с partial, которое запускает вызов сеанса без вызова каких-либо перехватчиков. Это выглядит так: вы тренируете свою модель для n итераций, а затем запускаете свой набор тестов, повторно инициализируете итераторы и возвращаетесь к обучению своей модели и т. Д. Конечно, вы должны установить свой переменные для повторного использования = tf.AUTO_REUSE при этом. Способ сделать это в коде показан ниже:

from functools import partial

# Build model
...

with tf.variable_scope(..., reuse=tf.AUTO_REUSE):
    ...

...

def step_fn(fetches, feed_dict, step_context):
    return step_context.session.run(fetches=fetches, feed_dict=feed_dict)

with tf.train.MonitoredTrainingSession(
                checkpoint_dir=...,
                save_checkpoint_steps=...,
                hooks=[...],
                ...
                ) as mon_sess:

                # Initialize iterators (assuming tf.Databases are used)
                mon_sess.run_step_fn(
                           partial(
                               step_fn, 
                               [train_it.initializer, 
                                test_it.initializer, 
                                ...
                               ], 
                               {}
                            )
                )

                while not mon_sess.should_stop():
                    # Train session
                    for i in range(n):
                        try:
                            train_results = mon_sess.run(<train_fetches>)
                        except Exception as e:
                            break

                    # Test session
                    while True:
                        try:
                            test_results = mon_sess.run(<test_fetches>)
                        except Exception as e:
                            break

                    # Reinitialize parameters
                    mon_sess.run_step_fn(
                               partial(
                                  step_fn, 
                                  [train_it.initializer, 
                                   test_it.initializer, 
                                   ...
                                  ], 
                                  {}
                               )
                    )

Частичная функция просто выполняет каррирование (классическая функция в функциональном программировании) на step_fn, который используется в mon_sess.run_step_fn (). Весь приведенный выше код не был протестирован, и вам, возможно, придется повторно инициализировать train_it перед запуском сеанса тестирования, но, надеюсь, теперь ясно, как можно запустить как обучающий набор, так и набор проверки в одном и том же запуске. Кроме того, это можно использовать вместе с инструментом custom_scalar tenorboard. если вы хотите построить кривую обучения и кривую теста на одном графике.

Наконец, это лучшая реализация этой функциональности, которую мне удалось сделать, и я лично надеюсь, что tenorflow значительно упростит реализацию этой функциональности в будущем, поскольку это довольно утомительно и, вероятно, не так эффективно. Я знаю, что существуют такие инструменты, как оценщик, который может запускать train_and_evaluate, но при этом перестраивается график между каждым поездом и перекрестная проверка, это очень неэффективно, если вы работаете только на одном компьютере. Я где-то читал, что у Keras + tf есть эта функция, но, поскольку я не использую Keras + tf, это не вариант. В любом случае, я надеюсь, что это может помочь кому-то другому, борющемуся с такими же проблемами!

person Andreas Forslöw    schedule 08.02.2018

Вы должны импортировать мета-график, а затем восстановить модель. Вдохновляйтесь приведенным ниже фрагментом, он должен работать на вас.

    self.sess = tf.Session()
    ckpt = tf.train.latest_checkpoint("location-of/model")
    saver = tf.train.import_meta_graph(ckpt + '.meta', clear_devices=True)
    saver.restore(self.sess, ckpt)
person BJC    schedule 20.05.2018

  1. Кажется, что восстановление сделано за вас. В документации API говорится, что вызов MonitoredTrainingSession возвращает экземпляр MonitoredSession, который при создании "... восстанавливает переменные, если существует контрольная точка ..."

  2. Ознакомьтесь с tf.contrib.learn.Estimator(..).predict(..) и более конкретно tf.contrib.learn.Estimator(..)._infer_model(..) методами здесь и здесь. Они также создают там MonitoredSession.

person Misha E    schedule 17.04.2017