Накопление градиента в tensorflow 2.x / keras

Я пытаюсь реализовать накопление градиентов в TF2.x. Все реализации, которые мне удалось найти, рассчитаны либо на TF1.x, либо на старый интерфейс Keras. Не думаю, что готовая реализация для TF2.x существует (хотя был бы очень рад ошибиться).

Вот с чем я работаю:

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense
from tqdm import tqdm
import matplotlib.pyplot as plt


class SimpleTrainStepModel(Model):
    """Keras model with a hand-written, single-pass ``train_step``.

    Each step runs one forward/backward pass over the whole batch and applies
    one optimizer update, without any gradient accumulation.
    """

    def train_step(self, data):
        """
        Run one optimisation step on a batch.

        Parameters
        ----------
        data : tuple
            Either ``(x, y)`` or ``(x, y, sample_weight)``.

        Returns
        -------
        dict
            Maps the name of every metric in ``self.metrics`` to its current result.
        """
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        with tf.GradientTape() as tape:
            y_pred = self(x, training = True)  # Forward pass
            loss = self.compiled_loss(y, y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)

        # Metrics must see the same sample weights as the loss; otherwise a
        # weighted run would report unweighted metric values.
        self.compiled_metrics.update_state(y, y_pred, sample_weight = sample_weight)

        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {m.name: m.result() for m in self.metrics}


class GradAccumModel(Model):
    """Keras model whose ``train_step`` processes each batch in ``grad_accum`` slices.

    ``fit`` records ``batch_size`` and ``grad_accum`` on the instance, and
    ``train_step`` runs one forward/backward pass per slice of
    ``batch_size // grad_accum`` samples before a single optimizer update.

    NOTE(review): the accumulation loop in ``train_step`` does not actually
    add the per-slice gradients together (see the note inside ``body``).
    """

    def fit(self, *args, batch_size = 32, grad_accum = 1, **kwargs):
        """
        Store the accumulation settings on the instance, then call the parent ``fit``.

        Parameters
        ----------
        batch_size : int
            Forwarded to the parent ``fit`` and stored as ``self.batch_size``.
        grad_accum : int
            Number of slices each batch is split into; stored as ``self.grad_accum``.

        Raises
        ------
        ValueError
            If ``batch_size`` is not divisible by ``grad_accum``.
        """
        # Drop the cached train function -- presumably so that Keras rebuilds
        # it with the new settings; TODO confirm against the Keras version in use.
        self.train_function = None
        if batch_size % grad_accum != 0:
            raise ValueError('Batch size must be divisible by the Gradient accumulation steps, dummy!')
        self.grad_accum = grad_accum
        self.batch_size = batch_size
        return super(GradAccumModel, self).fit(*args,
                                               batch_size = self.batch_size,
                                               #validation_batch_size = validation_batch_size,#self.batch_size//grad_accum if validation_batch_size is None else validation_batch_size,
                                               **kwargs)

    def train_step(self, data):
        """
        Run one training step, visiting the batch slice by slice.

        Parameters
        ----------
        data : tuple
            Either ``(x, y)`` or ``(x, y, sample_weight)``.

        Returns
        -------
        dict
            Maps the name of every metric in ``self.metrics`` to its current result.
        """
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        # Number of samples in one slice
        step = self.batch_size // self.grad_accum

        # def _slice_nested(obj, i, j):
        #     if type(obj) is list:
        #         return [o[i:j] for o in obj]
        #     else:
        #         return obj[i:j]

        # FIRST GRADIENT: the first slice is handled outside the loop and its
        # gradients are used as the initial loop variable of tf.while_loop.
        # NOTE(review): x and y are sliced but sample_weight is passed whole,
        # so its length would presumably not match the slice when weights are
        # supplied -- TODO confirm.
        with tf.GradientTape() as tape:
            y_pred = self(x[:step], training = True)  # Forward pass
            loss = self.compiled_loss(y[:step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.compiled_metrics.update_state(y[:step], y_pred)

        # Loop counter: start index of the next slice
        i = tf.constant(step)
        # tf.print('TF - HERE!')
        def cond(i, *args):
            # Keep looping while the slice start is inside the batch
            return i < self.batch_size
        def body(i, grad):
            # tf.print('\tTF - HERE!')
            with tf.GradientTape() as tape:
                y_pred = self(x[i:i + step], training = True) # Forward pass
                loss = self.compiled_loss(y[i:i + step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
            _grad = tape.gradient(loss, self.trainable_variables)

            # NOTE(review): `g += _g` only rebinds the loop-local name `g`
            # (tensors are not modified in place); the elements of `grad` are
            # never replaced, so the list returned below is the one that was
            # passed in and this slice's gradients are discarded. Building a
            # new list, e.g. `[g + _g for g, _g in zip(grad, _grad)]`, would
            # accumulate them.
            for g,_g in zip(grad, _grad):
                g += _g

            self.compiled_metrics.update_state(y[i:i + step], y_pred)
            return [i + step, grad]

        # parallel_iterations = 1 is requested so that slices are processed one at a time
        i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations = 1)


        # for g in gradients:        # I tried with and without division to calculate the mean
        #     g *= 1/self.grad_accum #


        # Update weights
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Metrics were already updated slice by slice above

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


if __name__ == '__main__':
    # MNIST digits; the arrays are fed to fit() exactly as the loader returns them.
    (x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.mnist.load_data()

    # One row per experiment: model class, extra fit() keyword arguments, curve colour.
    experiments = (
        (Model, {}, 'blue'),
        (SimpleTrainStepModel, {}, 'green'),
        (GradAccumModel, {'grad_accum': 1}, 'yellow'),
        (GradAccumModel, {'grad_accum': 6}, 'red'),
    )

    for model_cls, fit_extras, curve_colour in experiments:
        # Ten runs per configuration; no random seed is set, so every run
        # starts from its own weight initialisation.
        for _run in tqdm(range(10)):
            inputs = Input((28, 28))
            flat = Flatten()(inputs)
            hidden = Dense(128, activation = 'sigmoid')(flat)
            outputs = Dense(10, activation = 'softmax')(hidden)

            net = model_cls(inputs, outputs)
            net.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(),
                        optimizer = tf.keras.optimizers.Adam(1e-4),
                        metrics = ['acc'])

            history = net.fit(x_train, y_train,
                              validation_data = (x_valid, y_valid),
                              verbose = 0,
                              batch_size = 6000,
                              epochs = 100,
                              **fit_extras)

            # Overlay the validation-accuracy curve of this run
            plt.plot(history.history['val_acc'], color = curve_colour, alpha = .25)

    plt.title('')
    plt.xscale('symlog')
    plt.yscale('logit')
    plt.show()

Я убедился, что такой подход действительно экономит память графического процессора. Однако итоговый результат обучения отличается от результата обычного Model.fit.

Проверка

Крупный план

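Как видите, первые три варианта (стандартный Model, SimpleTrainStepModel и GradAccumModel с grad_accum = 1) ложатся близко друг к другу и дают одинаковые результаты. Но как только начинает выполняться цикл while (grad_accum = 6), обучение идёт совершенно иначе.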

Кто-нибудь знает, почему это происходит?


person mbtg    schedule 10.03.2021    source источник


Ответы (1)


После множества попыток я нашёл решение. Похоже, основная проблема была в составных присваиваниях (`g += _g`) при накоплении градиентов: они работают не так, как я ожидал. Ниже моё окончательное решение для всех, кому это может быть интересно. Оно дополнительно поддерживает распределённое обучение, обучение со смешанной точностью и вложенные структуры входов/выходов.

from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.util import nest
from tensorflow.keras.models import Model as _Model


class Model(_Model):
    """Keras ``Model`` whose ``fit`` can accumulate gradients over slices of each batch.

    With ``grad_accum_steps > 1`` every per-replica batch is processed in
    ``grad_accum_steps`` forward/backward passes; the resulting gradients are
    averaged and applied in a single optimizer update.
    """

    def fit(self, *args, batch_size: int = 32, grad_accum_steps: int = 1, **kwargs):
        """
        Shallow wrapper of Model.fit that captures batch_size and the additional kwarg grad_accum_steps.

        Parameters
        ----------
        batch_size : int
            same as in Model.fit
        grad_accum_steps : int
            Number of slices each batch is split into (defaults to 1, i.e. no accumulation).
            `batch_size` must be divisible by `grad_accum_steps` times the number of replicas.

        Returns
        -------
        Whatever the parent ``fit`` returns.

        Raises
        ------
        ValueError
            If `grad_accum_steps` is smaller than 1, or if `batch_size` is not divisible by
            `grad_accum_steps` times the number of replicas.
        """
        if grad_accum_steps < 1:
            raise ValueError(f'grad_accum_steps must be a positive integer, got {grad_accum_steps}')

        if grad_accum_steps == 1:
            # Nothing to accumulate: defer to the stock implementation and
            # return its result instead of falling through to a second run.
            return super().fit(*args, batch_size = batch_size, **kwargs)

        num_workers = ds_context.get_strategy().num_replicas_in_sync
        if batch_size % (grad_accum_steps * num_workers) != 0:
            raise ValueError(f'Batch size ({batch_size}) must be divisible by the Gradient accumulation steps ({grad_accum_steps}), and the number of replicas ({num_workers}), dummy!')

        # Drop any cached train function so that it is rebuilt around the
        # accumulating train step.
        self.train_function = None

        self._grad_accum_ = grad_accum_steps
        self._batch_size_ = batch_size
        self._num_workers_ = num_workers
        train_step_backup = self.train_step
        self.train_step = self._train_step_
        try:
            return super().fit(*args,
                               batch_size = self._batch_size_, # TODO maybe consider validation batch size
                               **kwargs)
        finally:
            # Restore the model even when training fails part-way through
            del self._grad_accum_
            del self._batch_size_
            del self._num_workers_
            self.train_step = train_step_backup
            # The cached train function was built around the accumulating
            # step; discard it so that a later fit() does not reuse it.
            self.train_function = None

    def _train_step_(self, data):
        """
        Custom training step taking into account gradient accumulation for low memory training

        Parameters
        ----------
        data : tuple
            Either ``(x, y)`` or ``(x, y, sample_weight)``; each element may be a nested structure.

        Returns
        -------
        dict
            Maps the name of every metric in ``self.metrics`` to its current result.
        """

        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        def slice_map(struct, start, stop): # dealing with nasty nested structures
            if struct is None:
                return None # special case for sample_weight

            return nest.map_structure(lambda t: t[start:stop], struct)

        # ---------- GRAD ACCUM STUFF ----------------------------------------------------------------------------------
        # Samples per slice, and samples per replica. The loop is bounded by
        # the per-replica size: each replica only receives its own share of
        # the global batch, so iterating up to the global batch size would
        # visit empty slices when there is more than one replica.
        step = self._batch_size_ // self._num_workers_ // self._grad_accum_
        local_batch_size = step * self._grad_accum_
        scale = 1./self._grad_accum_
        uses_loss_scaling = isinstance(self.optimizer, lso.LossScaleOptimizer)

        def slice_gradients(start):
            # Forward/backward pass over samples [start, start + step);
            # returns the gradients already weighted by 1 / grad_accum.
            x_ = slice_map(x, start, start + step)
            y_ = slice_map(y, start, start + step)
            w_ = slice_map(sample_weight, start, start + step)

            with tf.GradientTape() as tape:
                y_pred = self(x_, training = True)  # Forward pass
                loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
                if uses_loss_scaling:
                    loss = self.optimizer.get_scaled_loss(loss)

            grads = tape.gradient(loss, self.trainable_variables)
            # Metrics see the same sample weights as the loss
            self.compiled_metrics.update_state(y_, y_pred, w_)
            return [g * scale for g in grads]

        # The first slice is evaluated outside the loop so that its gradients
        # can serve as the initial value of the loop variable.
        gradients = slice_gradients(0)

        def cond(i, *args):
            return i < local_batch_size

        def body(i, grad):
            # Build a NEW list: an in-place `g += _g` on the loop variable
            # would only rebind a local name and accumulate nothing.
            grad = [g + _g for g, _g in zip(grad, slice_gradients(i))]
            return [i + step, grad]

        _, gradients = tf.while_loop(cond, body, [tf.constant(step), gradients], parallel_iterations = 1)
        # --------------------------------------------------------------------------------------------------------------

        # ---------- STUFF FROM Model._minimize ------------------------------------------------------------------------
        aggregate_grads_outside_optimizer = (self.optimizer._HAS_AGGREGATE_GRAD and not isinstance(self.distribute_strategy.extended, parameter_server_strategy.ParameterServerStrategyExtended))

        if aggregate_grads_outside_optimizer: # TODO there might be some issues with the scaling, due to the extra accumulation steps
            gradients = self.optimizer._aggregate_gradients(zip(gradients, self.trainable_variables))

        if uses_loss_scaling:
            gradients = self.optimizer.get_unscaled_gradients(gradients)

        gradients = self.optimizer._clip_gradients(gradients)
        if self.trainable_variables:
            if aggregate_grads_outside_optimizer:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables), experimental_aggregate_gradients = False)
            else:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # --------------------------------------------------------------------------------------------------------------

        return {m.name: m.result() for m in self.metrics}
person mbtg    schedule 17.03.2021