Я пытаюсь реализовать функцию Multiclass Hybrid loss в Python из следующей статьи https://arxiv.org/pdf/1808.05238.pdf для моей проблемы семантической сегментации с использованием несбалансированного набора данных. Мне удалось сделать свою реализацию достаточно правильной, чтобы начать обучение модели, но результаты очень плохие. Архитектура модели - U-net, скорость обучения в оптимизаторе Adam составляет 1e-5. Форма маски (Нет, 512, 512, 3) с 3 классами (в моем случае лес, вырубка леса, другое). Формула, которую я использовал для реализации своей потери:
Код, который я создал:
def build_hybrid_loss(_lambda_=1, _alpha_=0.5, _beta_=0.5, smooth=1e-6):
def hybrid_loss(y_true, y_pred):
C = 3
tversky = 0
# Calculate Tversky Loss
for index in range(C):
inputs_fl = tf.nest.flatten(y_pred[..., index])
targets_fl = tf.nest.flatten(y_true[..., index])
#True Positives, False Positives & False Negatives
TP = tf.reduce_sum(tf.math.multiply(inputs_fl, targets_fl))
FP = tf.reduce_sum(tf.math.multiply(inputs_fl, 1-targets_fl[0]))
FN = tf.reduce_sum(tf.math.multiply(1-inputs_fl[0], targets_fl))
tversky_i = (TP + smooth) / (TP + _alpha_ * FP + _beta_ * FN + smooth)
tversky += tversky_i
tversky += C
# Calculate Focal loss
loss_focal = 0
for index in range(C):
f_loss = - (y_true[..., index] * (1 - y_pred[..., index])**2 * tf.math.log(y_pred[..., index]))
# Average over each data point/image in batch
axis_to_reduce = range(1, 3)
f_loss = tf.math.reduce_mean(f_loss, axis=axis_to_reduce)
loss_focal += f_loss
result = tversky + _lambda_ * loss_focal
return result
return hybrid_loss
Прогноз модели после окончания эпохи (у меня проблема с перестановкой цветов, поэтому красный цвет в прогнозе на самом деле зеленый, что означает лес, поэтому прогноз в основном лес, а не обезлесение) ):
Вопрос в том, что не так с моей реализацией гибридных потерь, что нужно изменить, чтобы она заработала?
I have a problem with swapped colors
: если вы используете opencv, обратите внимание, что цветовое пространство по умолчанию —BGR
, в то время как другие библиотеки обработки изображений обычно работают сRGB
. - person Lescurel   schedule 03.02.2021prediction[..., 0]
- это предсказанная маска для леса. Но когда я используюplt.imshow(prediction)
, прогноз имеет три канала, как и изображение RGB, поэтому первая маска (маска леса) красная. Теперь это побочный вопрос. - person Петр Воротинцев   schedule 03.02.2021