Я работаю над реализацией генерирующей состязательной сети (GAN) в PyTorch 1.5.0.
Для вычисления потерь генератора я вычисляю как отрицательные вероятности того, что дискриминатор неправильно классифицирует полностью реальную мини-серию, так и фальшивую мини-серию полностью (генерируемую генератором). Затем я последовательно повторяю обе части в обратном направлении и, наконец, применяю пошаговую функцию.
Вычисление и обратное распространение части потерь, которая является функцией неправильной классификации сгенерированных поддельных данных, кажется прямым, поскольку во время обратного распространения этого члена потерь обратный путь проходит через генератор, который произвел поддельные данные. данные в первую очередь.
Однако классификация мини-пакетов, состоящих только из реальных данных, не включает передачу данных через генератор. Поэтому мне было интересно, будет ли следующий фрагмент кода по-прежнему вычислять градиенты для генератора или он вообще не будет вычислять какие-либо градиенты (поскольку обратный путь не ведет через генератор, а дискриминатор находится в режиме оценки при обновлении генератора )?
# Update generator #
net.generator.train()
net.discriminator.eval()
net.generator.zero_grad()
# All-real minibatch
x_real = get_all_real_minibatch()
y_true = torch.full((batch_size,), label_fake).long() # Pretend true targets were fake
y_pred = net.discriminator(x_real) # Produces softmax probability distribution over (0=label_fake,1=label_real)
loss_real = NLLLoss(torch.log(y_pred), y_true)
loss_real.backward()
optimizer_generator.step()
Если это не работает должным образом, как я могу заставить это работать? Заранее спасибо!