Предположим, у нас есть простая модель Keras, использующая BatchNormalization:
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(1,)),
tf.keras.layers.BatchNormalization()
])
Как на самом деле использовать его с GradientTape? Следующее, похоже, не работает, поскольку оно не обновляет скользящие средние?
# model training... we want the output values to be close to 150
for i in range(1000):
x = np.random.randint(100, 110, 10).astype(np.float32)
with tf.GradientTape() as tape:
y = model(np.expand_dims(x, axis=1))
loss = tf.reduce_mean(tf.square(y - 150))
grads = tape.gradient(loss, model.variables)
opt.apply_gradients(zip(grads, model.variables))
В частности, если вы просматриваете скользящие средние, они остаются прежними (проверьте model.variables, средние значения всегда равны 0 и 1). Я знаю, что можно использовать .fit() и .predict(), но я хотел бы использовать GradientTape и не знаю, как это сделать. В некоторых версиях документации предлагается обновить update_ops, но, похоже, это не работает в активном режиме.
В частности, следующий код не выведет ничего близкого к 150 после приведенного выше обучения.
x = np.random.randint(200, 210, 100).astype(np.float32)
print(model(np.expand_dims(x, axis=1)))