Как я могу обучить сверточную генерирующую состязательную сеть в Tensorflow.js?

Я следую руководству по адресу https://www.tensorflow.org/tutorials/generative/dcgan.
Хотя руководство написано на python, я пытаюсь реализовать его с помощью tensorflow.js на node.js.
Мне удалось выяснить, как перевести большую часть используемых методов, за исключением случаев, когда речь идет о фактической настройке следующей процедуры шага обучения.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

Ясно, что не все здесь можно перевести в tensorflow.js.
Пока я не могу понять, как получить градиенты и применить их к оптимизатору.
Я пытался использовать функции tf.grad & tf.grads, но чтобы безрезультатно.
Вот что у меня есть:

function trainStep(images) {
    const noise = tf.randomNormal([BATCH_SIZE, noiseDim]);

    const generated = gen.apply(noise, { training: true });
    const realOut = dis.apply(images, { training: true });
    const genOut = dis.apply(generated, { training: true });

    const genLoss = generator.loss(genOut);
    const disLoss = discriminator.loss(realOut, genOut);

    // now what?
}

Есть ли лучший способ сделать это в tensorflow.js, чем показано в руководстве?
Я был бы признателен, если бы у кого-нибудь были ресурсы, чтобы указать мне в правильном направлении.


person Joris Blanken    schedule 29.04.2020    source источник


Ответы (1)


Попробуйте эту официальную лабораторию кода для TensorFlow.js:

https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html.

Это для MNIST, но как только вы это узнаете, вы можете применить его к своему собственному набору данных.

person Jason Mayes    schedule 30.04.2020