Я тренировал условную архитектуру GAN, подобную Pix2Pix, со следующим циклом обучения:
for epoch in range(start_epoch, end_epoch):
for batch_i, (input_batch, target_batch) in enumerate(dataLoader.load_batch(batch_size)):
fake_batch= self.generator.predict(input_batch)
d_loss_real = self.discriminator.train_on_batch(target_batch, valid)
d_loss_fake = self.discriminator.train_on_batch(fake_batch, invalid)
d_loss = np.add(d_loss_fake, d_loss_real) * 0.5
g_loss = self.combined.train_on_batch([target_batch, input_batch], [valid, target_batch])
Теперь это работает хорошо, но не очень эффективно, поскольку загрузчик данных быстро становится узким местом с точки зрения времени. Я изучил функцию .fit_generator (), которую предоставляет keras, которая позволяет генератору работать в рабочем потоке и работает намного быстрее.
self.combined.fit_generator(generator=trainLoader,
validation_data=evalLoader
callbacks=[checkpointCallback, historyCallback],
workers=1,
use_multiprocessing=True)
Мне потребовалось некоторое время, чтобы убедиться, что это было неверно, я больше не тренировал свой генератор и дискриминатор по отдельности, а дискриминатор вообще не обучался, так как он был установлен на trainable = False
в комбинированной модели, что по существу разрушило любые виды состязательности. потеря, и я мог бы также обучить свой генератор сам с помощью MSE
.
Теперь мой вопрос: есть ли какая-то работа, например, обучение моего дискриминатора внутри настраиваемого обратного вызова, который запускается каждой партией метода .fit_generator ()? Можно реализовать создание пользовательских обратных вызовов, например, вот так:
class MyCustomCallback(tf.keras.callbacks.Callback):
def on_train_batch_end(self, batch, logs=None):
discriminator.train_on_batch()
Другой возможностью было бы распараллелить исходный цикл обучения, но я боюсь, что у меня сейчас нет времени на это.