I'm trying to implement gradient accumulation on TF 2.x. All the implementations I've found are either for TF 1.x or for the old standalone keras interface. I don't think an implementation exists yet (though I'd be very glad to be proven wrong about that).
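For context, the behaviour I'm after is the standard one: split each batch into micro-batches, accumulate the per-micro-batch gradients, and apply a single optimizer update. A minimal standalone sketch of that idea as a plain eager loop (illustration only, not my actual code; micro_batches is a hypothetical list of (x, y) slices):

import tensorflow as tf

def accumulated_step(model, loss_fn, optimizer, micro_batches):
    # Start from zero accumulators, one per trainable variable.
    accum = [tf.zeros_like(v) for v in model.trainable_variables]
    for x, y in micro_batches:
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x, training=True))
        grads = tape.gradient(loss, model.trainable_variables)
        accum = [a + g for a, g in zip(accum, grads)]
    # Average so the update matches one large-batch gradient
    # (assuming the loss is itself a mean over each micro-batch).
    accum = [a / len(micro_batches) for a in accum]
    optimizer.apply_gradients(zip(accum, model.trainable_variables))

My goal below is to get the same behaviour from inside a Keras train_step.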
Here's what I'm working with:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense
from tqdm import tqdm
import matplotlib.pyplot as plt


class SimpleTrainStepModel(Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        # FIRST GRADIENT
        with tf.GradientTape() as tape:
            y_pred = self(x, training = True)  # Forward pass
            loss = self.compiled_loss(y, y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.compiled_metrics.update_state(y, y_pred)

        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {m.name: m.result() for m in self.metrics}
class GradAccumModel(Model):
    def fit(self, *args, batch_size = 32, grad_accum = 1, **kwargs):
        self.train_function = None
        if batch_size % grad_accum != 0:
            raise ValueError('Batch size must be divisible by the Gradient accumulation steps, dummy!')
        self.grad_accum = grad_accum
        self.batch_size = batch_size
        return super(GradAccumModel, self).fit(*args,
                                               batch_size = self.batch_size,
                                               #validation_batch_size = validation_batch_size,#self.batch_size//grad_accum if validation_batch_size is None else validation_batch_size,
                                               **kwargs)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        step = self.batch_size // self.grad_accum

        # def _slice_nested(obj, i, j):
        #     if type(obj) is list:
        #         return [o[i:j] for o in obj]
        #     else:
        #         return obj[i:j]

        # FIRST GRADIENT
        with tf.GradientTape() as tape:
            y_pred = self(x[:step], training = True)  # Forward pass
            loss = self.compiled_loss(y[:step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.compiled_metrics.update_state(y[:step], y_pred)

        i = tf.constant(step)
        # tf.print('TF - HERE!')

        def cond(i, *args):
            return i < self.batch_size

        def body(i, grad):
            # tf.print('\tTF - HERE!')
            with tf.GradientTape() as tape:
                y_pred = self(x[i:i + step], training = True)  # Forward pass
                loss = self.compiled_loss(y[i:i + step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
            _grad = tape.gradient(loss, self.trainable_variables)
            for g, _g in zip(grad, _grad):
                g += _g
            self.compiled_metrics.update_state(y[i:i + step], y_pred)
            return [i + step, grad]

        i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations = 1)

        # for g in gradients:         # I tried with and without division to calculate the mean
        #     g *= 1/self.grad_accum  #

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update metrics (includes the metric that tracks the loss)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
if __name__ == '__main__':
    (x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.mnist.load_data()

    for MODEL, ga_kwarg, colour in list(zip([Model, SimpleTrainStepModel, GradAccumModel, GradAccumModel],
                                            [{}, {}, {'grad_accum': 1}, {'grad_accum': 6}],
                                            ['blue', 'green', 'yellow', 'red'])):
        for _ in tqdm(range(10)):
            # tf.random.set_seed(0)
            x = Input((28, 28))
            y = x
            y = Flatten()(y)
            y = Dense(128, activation = 'sigmoid')(y)
            y = Dense(10, activation = 'softmax')(y)
            model = MODEL(x, y)

            model.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(),
                          optimizer = tf.keras.optimizers.Adam(1e-4),
                          metrics = ['acc'])

            hist = model.fit(x_train, y_train, validation_data = (x_valid, y_valid), verbose = 0, batch_size = 6000, epochs = 100, **ga_kwarg)

            plt.plot(hist.history['val_acc'], color = colour, alpha = .25)

    plt.title('')
    plt.xscale('symlog')
    plt.yscale('logit')
    plt.show()
I was able to verify that this does save GPU memory. However, the end result differs from that of a plain Model.fit.

As you can see in the plot, the first three Model.fit variants are tightly grouped and give the same results. But as soon as the while loop comes into play, the training becomes completely different.

Does anyone know why this is happening?
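For reference, the equality I'm expecting is that the mean of the micro-batch gradients reproduces the full-batch gradient, which can be checked directly outside of fit(). A sketch of that sanity check (illustrative only; model, loss_fn, x and y are placeholders for a built model and one batch of data):

import tensorflow as tf

def full_batch_grads(model, loss_fn, x, y):
    # Gradient of the loss over the whole batch in one pass.
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    return tape.gradient(loss, model.trainable_variables)

def accumulated_grads(model, loss_fn, x, y, n_splits):
    # Mean of the per-slice gradients over n_splits equal slices.
    step = x.shape[0] // n_splits
    accum = [tf.zeros_like(v) for v in model.trainable_variables]
    for i in range(0, x.shape[0], step):
        with tf.GradientTape() as tape:
            loss = loss_fn(y[i:i + step], model(x[i:i + step], training=True))
        grads = tape.gradient(loss, model.trainable_variables)
        accum = [a + g / n_splits for a, g in zip(accum, grads)]
    return accum

# full = full_batch_grads(model, loss_fn, x, y)
# acc  = accumulated_grads(model, loss_fn, x, y, 6)
# print(max(tf.reduce_max(tf.abs(f - a)).numpy() for f, a in zip(full, acc)))

If the while_loop version diverges from this check, that would at least localize the problem.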