Целью биомедицинской сегментации является обнаружение различных анатомических структур на изображениях. Это ключевая вспомогательная технология для медицинских приложений, таких как диагностика, планирование и руководство. Поэтому за последние несколько лет было опубликовано несколько методов, посвященных сегментации биомедицинских изображений.

В этом проекте я использовал архитектуру U-Net [1] для сегментации поражений на КТ легких. U-Net — это полностью сверточная нейронная сеть, напоминающая автоэнкодер. Модель состоит из 2-х частей — энкодера и декодера — которые связаны между собой скиповыми соединениями и узким местом, как показано на рисунке ниже.

Даже при наличии нескольких обучающих изображений сетевая архитектура позволяет модели получать информацию о структуре изображения. Эта способность делает модель идеальной для медицинской среды, где получение огромного количества изображений обходится дорого.

В этом проекте я реализовал U-Net в PyTorch, используя оригинальную архитектуру, предложенную Олафом Роннебергером и др. [1]. В предлагаемой архитектуре и декодер, и кодер состоят из 4 двойных сверточных блоков. Каждая свертка сопровождается нормализацией партии и применением функции активации ReLU.

class ConvolutionBlock(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()

        self.conv1 = nn.Conv2d(
            input_channels, output_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(output_channels)

        self.conv2 = nn.Conv2d(
            output_channels, output_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(output_channels)

        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

В блоках кодировщика выходные данные блока двойной свертки затем подвергаются субдискретизации с использованием слоя максимального объединения 2x2. Кроме того, выход блоков свертки служит вторым входом для декодера через пропускные соединения. Блоки декодера выглядят одинаково, но он использует деконволюционные слои (сверточное транспонирование), а не слои с максимальным объединением.

class EncoderBlock(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()

        self.conv_block = ConvolutionBlock(input_channels, output_channels)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        x = self.conv_block(inputs)
        p = self.pool(x)
        return x, p


class DecoderBlock(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()

        self.up = nn.ConvTranspose2d(
            input_channels, output_channels, kernel_size=2, stride=2, padding=0)
        self.conv = ConvolutionBlock(
            output_channels + output_channels, output_channels)

    def forward(self, inputs, skip_connection):
        x = self.up(inputs)
        x = torch.cat([x, skip_connection], axis=1)
        x = self.conv(x)
        return x

Вся U-Net состоит из 4 блоков энкодера, за которыми следует один блок двойной свертки, который служит узким местом. В части декодера есть 4 блока декодера, которые увеличивают размер изображения до его исходного размера. После декодирования остается последний сверточный слой, служащий выходом сегментации.

class UNet(nn.Module):
    def __init__(self):
        super().__init__()

        """ Encoder """
        self.e1 = EncoderBlock(3, 64)
        self.e2 = EncoderBlock(64, 128)
        self.e3 = EncoderBlock(128, 256)
        self.e4 = EncoderBlock(256, 512)

        """ Bottleneck/Bridge connection """
        self.b = ConvolutionBlock(512, 1024)

        """" Decoder """
        self.d1 = DecoderBlock(1024, 512)
        self.d2 = DecoderBlock(512, 256)
        self.d3 = DecoderBlock(256, 128)
        self.d4 = DecoderBlock(128, 64)

        """ Segmentation output """
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        """ Encoding """
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        """ Bottleneck/Bridge connection """
        b = self.b(p4)

        """ Decoding """
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        return self.outputs(d4)

Для обучения модели я использовал набор данных КТ-сканирования COVID-19 от Kaggle [2]. Он состоит из нескольких тысяч сканов и масок сегментации. Из них я выбрал 100 изображений с соответствующими масками, так как хотел смоделировать отсутствие входных данных.
Позже я разбил эти изображения 50–50 на обучающую и тестовую подмножества. Я также использовал аугментацию данных, такую ​​как вертикальное и горизонтальное переворачивание или вращение, чтобы получить 200 обучающих выборок из исходных 50.

Во время обучения я использовал оптимизатор Adam, а также средство масштабирования градиента с автоприведением, чтобы получить более точные результаты. Для функции потерь я использовал бинарную кросс-энтропию и потери в костях во время экспериментов. Я обнаружил, что сумма обоих привела к лучшим результатам.

После обучения модели в течение 500 эпох я смог воспроизвести маски изображений практически без отличий от исходных данных, как показано в сравнении изображений ниже.

В целом, U-Net — это модель, которая отлично подходит для сегментации изображений. Он показал выдающуюся производительность даже при обучении всего на нескольких обучающих образцах. Модель также легко расширяется для многоклассовой сегментации.
Единственным его недостатком является временная сложность обучения модели, когда у вас нет доступа к графическому процессору. Несмотря на временную сложность, я бы рекомендовал U-Net для современных проектов медицинской сегментации.

Ссылки

[1] Олаф Роннебергер, Филипп Фишер и Томас Брокс. U-net: сверточные сети для сегментации биомедицинских изображений. 2015.

[2] Маэде Мафтуни. Набор данных сегментации поражений при компьютерной томографии Covid-19, 2021 г. URL: https://www.kaggle.com/datasets/maedemaftouni/covid19-ct-scan-lesion-segmentation-dataset