Использование tf.cond () в функции модели оценки для обучения WGAN на TPU приводит к удвоению global_step

Я пытаюсь обучить GAN на TPU, поэтому я возился с классом TPUEstimator и соответствующей функцией модели, чтобы попытаться реализовать цикл обучения WGAN. Я пытаюсь использовать tf.cond для объединения двух тренировочных операций для TPUEstimatorSpec следующим образом:

opt = tf.cond(
    tf.equal(tf.mod(tf.train.get_or_create_global_step(), 
    CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1), 
    lambda: gen_opt, 
    lambda: critic_opt
)

gen_opt и critic_opt - это функция минимизации оптимизатора, который я использую, также настроенная на обновление глобального шага. CRITIC_UPDATES_PER_GEN_UPDATE - это константа Python для этого и является частью обучения WGAN. Я пытался найти модель GAN с использованием tf.cond, но все модели используют tf.group, который я не могу использовать, потому что вам нужно оптимизировать критик во много раз больше, чем генератор. Однако каждый раз, когда я запускаю 100 пакетов, глобальный шаг увеличивается на 200 в соответствии с номером контрольной точки. Моя модель все еще обучается правильно, или tf.cond просто нельзя использовать такой способ обучения GAN?


person Justin Zhang    schedule 27.01.2019    source источник
comment
Не могли бы вы показать, как вы используете tf.cond?   -  person nessuno    schedule 27.01.2019
comment
Прости за это! Я выложил это прямо перед сном, но теперь обновил.   -  person Justin Zhang    schedule 27.01.2019
comment
mod (x, N) должен дать результат ‹N кстати   -  person ziyuang    schedule 19.05.2020


Ответы (1)


tf.cond не предполагается использовать таким образом для обучения GAN.

Вы получаете 200, потому что на каждом этапе обучения оцениваются побочные эффекты (например, операции присваивания) обоих true_fn и false_fn. Одним из побочных эффектов является глобальная операция step tf.assign_add, которую определяют оба оптимизатора.

Следовательно, то, что происходит, похоже на

  • Казнь global_step++ (gen_opt) и global_step++ (critic_op)
  • Оценка состояния
  • Выполнение true_fn тела или false_fn тела (в зависимости от условия).

Если вы хотите обучить GAN с помощью tf.cond, вам нужно удалить все побочные операции (например, назначение, следовательно, определение шага оптимизации) снаружи _10 _ / _ 11_ и объявить все внутри них.

В качестве справки вы можете увидеть этот ответ о поведении tf.cond: https://stackoverflow.com/a/37064128/2891324

person nessuno    schedule 27.01.2019