Перенос обучения с tf.keras и Inception-v3: обучение не происходит

Я пытаюсь обучить модель на основе замороженной модели Inception_v3 с 3 классами в качестве выходных данных. Когда я запускаю обучение, точность обучения повышается, но не точность проверки, которая составляет более или менее точно 33,33%, то есть показывает полностью случайное предсказание. Я не могу понять, где ошибка в моем коде и / или подходе

Я пробовал различные формы вывода после ядра Inception v3 без каких-либо различий.

# Model definition
# InceptionV3 frozen, flatten, dense 1024, dropout 50%, dense 1024, dense 3, lr 0.001 --> does not train
# InceptionV3 frozen, flatten, dense 1024, dense 3, lr 0.001 --> does not train
# InceptionV3 frozen, flatten, dense 1024, dense 3, lr 0.005 --> does not train
# InceptionV3 frozen, GlobalAvgPooling, dense 1024, dense 1024, dense 512, dense 3, lr 0.001 --> does not train
# InceptionV3 frozen, GlobalAvgPooling dropout 0.4 dense 3, lr 0.001, custom pre-process --> does not train
# InceptionV3 frozen, GlobalAvgPooling dropout 0.4 dense 3, lr 0.001, custom pre-process, batch=32 --> does not train
# InceptionV3 frozen, GlobalAvgPooling dropout 0.4 dense 3, lr 0.001, custom pre-process, batch=32, rebalance train/val sets --> does not train

IMAGE_SIZE = 150
BATCH_SIZE = 32

def build_model(image_size):
  input_tensor = tf.keras.layers.Input(shape=(image_size, image_size, 3))

  inception_base = InceptionV3(include_top=False, weights='imagenet', input_tensor=input_tensor)
  for layer in inception_base.layers:
    layer.trainable = False

  x = inception_base.output
  x = tf.keras.layers.GlobalAveragePooling2D()(x)
  x = tf.keras.layers.Dropout(0.2)(x)
  output_tensor = tf.keras.layers.Dense(3, activation="softmax")(x)

  model = tf.keras.Model(inputs=input_tensor, outputs=output_tensor)

  return model

model = build_model(IMAGE_SIZE)
model.compile(optimizer=RMSprop(lr=0.002), loss='categorical_crossentropy', metrics=['acc'])

# Data generators with Image augmentations
train_datagen = ImageDataGenerator(
      rescale=1./255,
      preprocessing_function=tf.keras.applications.inception_v3.preprocess_input,
      rotation_range=40,
      width_shift_range=0.2,
      height_shift_range=0.2,
      shear_range=0.2,
      zoom_range=0.2,
      horizontal_flip=True,
      fill_mode='nearest')

# Do not augment validation!
validation_datagen = ImageDataGenerator(
    rescale=1./255,
    preprocessing_function=tf.keras.applications.inception_v3.preprocess_input)

train_generator = train_datagen.flow_from_directory(
      train_dir,
      target_size=(IMAGE_SIZE, IMAGE_SIZE),
      batch_size=BATCH_SIZE,
      class_mode='categorical')

validation_generator = validation_datagen.flow_from_directory(
      valid_dir,
      target_size=(IMAGE_SIZE, IMAGE_SIZE),
      batch_size=BATCH_SIZE,
      class_mode='categorical')

Вывод этой ячейки:

Найдено 1697 изображений, относящихся к 3 классам. Найдено 712 изображений, относящихся к 3 классам.

Результат двух последних эпох обучения:

Эпоха 19/20
23/23 [==============================] - 6 с 257 мс / шаг - потеря: 1.1930 - соотв: 0.3174
54/54 [================================] - 20 с 363 мс / шаг - потеря : 0.7870 - acc: 0.6912 - val_loss: 1.1930 - val_acc: 0.3174
Эпоха 20/20 - 23/23 [====================== ========] - 6 с 255 мс / шаг - потеря: 1.1985 - согласно: 0.3160 ​​
54/54 [===================== =========] - 20 с 362 мс / шаг - потеря: 0,7819 - acc: 0,7018 - val_loss: 1,1985 - val_acc: 0,3160


person whobbes    schedule 11.06.2019    source источник
comment
Подхожу к модели с: history = model.fit_generator(train_generator, epochs=20, verbose=1, validation_data=validation_generator)   -  person whobbes    schedule 11.06.2019


Ответы (1)


Единственное, что меня бросает в глаза, - это отказаться от rescale=1./255 ImageDataGenerators, потому что этим также занимается tf.keras.applications.inception_v3.preprocess_input, который масштабирует от -1 до 1; ожидаемый вход сети.

person TheLoneDeranger    schedule 11.06.2019
comment
Большое спасибо за это, сеть сейчас тренируется. По какой-то причине я сделал предположение, что всем сетям нужен вход в диапазоне [0,1]. - person whobbes; 12.06.2019