Есть ли способ создать слой предварительной обработки keras, который случайным образом вращается под заданными углами?

Я работаю над проектом классификации астрономических изображений, и в настоящее время я использую keras для создания CNN.

Я пытаюсь создать конвейер предварительной обработки, чтобы расширить свой набор данных слоями keras / tensorflow. Для простоты я хотел бы реализовать случайные преобразования двугранной группы (т. Е. Для квадратной изображения, поворот и переворот на 90 градусов), но кажется, что tf.keras.preprocessing.image.random_rotation допускает только случайную степень в непрерывном диапазоне выбора после равномерного распределения.

Мне было интересно, есть ли способ вместо этого выбрать из списка указанных градусов, в моем случае [0, 90, 180, 270].


person eymerich92    schedule 25.02.2021    source источник


Ответы (1)


К счастью, есть функция тензорного потока, которая делает то, что вы хотите: tf.image.rot90 . Следующий шаг - обернуть эту функцию в пользовательский PreprocessingLayer, чтобы она выполнялась случайным образом.

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers.experimental.preprocessing import PreprocessingLayer

class RandomRot90(PreprocessingLayer):
    def __init__(self, name=None, **kwargs) -> None:
        super(RandomRot90, self).__init__(name=name, **kwargs)
        self.input_spec = tf.keras.layers.InputSpec(ndim=4)
    
    def call(self, inputs, training=True):
        if training is None:
            training = K.learning_phase()
        
        def random_rot90():
            # random int between 0 and 3
            rot = tf.random.uniform((),0,4, dtype=tf.int32)
            return tf.image.rot90(inputs, k=rot)
        
        # if not training, do nothing
        outputs = tf.cond(training, random_rot90, lambda:inputs)
        outputs.set_shape(inputs.shape)
        return outputs
    
    def compute_output_shape(self, input_shape):
        return input_shape
  • Обратите внимание, что вы можете захотеть реализовать get_config, если хотите сохранить и загрузить модель с этим слоем. (См. документацию)
  • Также обратите внимание, что этот слой может выйти из строя, если ваши входные данные не квадратные (высота! = Ширина).
person Lescurel    schedule 25.02.2021