Уменьшение памяти Tensorflow TPU v2 / v3 bfloat16

Моя модель слишком велика, чтобы получить партию> 64 с обычными устройствами TPU v2. На сайте устранения неполадок упоминается, что в следующих версиях tenorflow будет bfloat16 служба поддержки. Могут ли недавно поддерживаемые версии tf 1.9–1.12 использовать bfloat16, и если да, то есть ли ограниченный набор оптимизаторов, которые я могу использовать? Я не нашел дополнительной документации по этому поводу, но видел использование bfloat16 в модели tensor2tensor, поэтому я думаю, что должен быть способ.

Кроме того, я читал, что TPU v3 также поддерживает более крупные модели, но модель потребует минимальных изменений, но я не нахожу документации, что нужно изменить.

Я уже использую Adafactor и пытался уменьшить мои слои, если у вас есть дополнительные советы по уменьшению, это тоже было бы здорово. Я использую матрицы изображений и векторы слов (на данный момент float32) в качестве входных данных.


person user2368505    schedule 24.11.2018    source источник


Ответы (1)


Вы можете использовать bfloat16 с TPU. Есть две основные вещи, которые нужно сделать:

  1. Передайте ввод в bfloat16 в своем входном конвейере
  2. Окружите свою сеть областью видимости bfloat16 и преобразуйте выходные данные в F32 для дальнейших вычислений.

Вот фрагмент кода, иллюстрирующий необходимые изменения:

def input_fn():

  def dataset_parser(self, value):
    """Parse an ImageNet record from a serialized string Tensor."""
    image = self.image_preprocessing_fn(
        image_bytes=image_bytes,
        is_training=self.is_training,
    )

    if self.use_bfloat16:
      image = tf.cast(image, tf.bfloat16)

    return image, label


def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator."""

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=LABEL_CLASSES,
        data_format=FLAGS.data_format)
    return network(
        inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if FLAGS.precision == 'bfloat16':
    with bfloat16.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()

Вы также можете увидеть второе условие, показанное в эта модель TPU.

person Alex Ilchenko    schedule 18.12.2018