Я пытаюсь обучить GAN на TPU, поэтому я возился с классом TPUEstimator и соответствующей функцией модели, чтобы попытаться реализовать цикл обучения WGAN. Я пытаюсь использовать tf.cond
для объединения двух тренировочных операций для TPUEstimatorSpec следующим образом:
opt = tf.cond(
tf.equal(tf.mod(tf.train.get_or_create_global_step(),
CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1),
lambda: gen_opt,
lambda: critic_opt
)
gen_opt
и critic_opt
- это функция минимизации оптимизатора, который я использую, также настроенная на обновление глобального шага. CRITIC_UPDATES_PER_GEN_UPDATE
- это константа Python для этого и является частью обучения WGAN. Я пытался найти модель GAN с использованием tf.cond
, но все модели используют tf.group
, который я не могу использовать, потому что вам нужно оптимизировать критик во много раз больше, чем генератор. Однако каждый раз, когда я запускаю 100 пакетов, глобальный шаг увеличивается на 200 в соответствии с номером контрольной точки. Моя модель все еще обучается правильно, или tf.cond
просто нельзя использовать такой способ обучения GAN?