# Source code for mednet.models.segment.lwnet

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""`Little W-Net (LWNET) network architecture <lwnet_>`_, from [GALDRAN-2020]_.

It is based on two simple U-Nets with 3 layers concatenated to each other.  The
first U-Net produces a segmentation map that is used by the second to better
guide segmentation.
"""

import logging
import typing

import torch
import torch.nn

from ...data.typing import TransformSequence
from .losses import MultiWeightedBCELogitsLoss
from .model import Model

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


def _conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 two-dimensional convolution.

    Parameters
    ----------
    in_planes
        Number of input channels.
    out_planes
        Number of output channels.
    stride
        Stride of the convolution.

    Returns
    -------
        A :py:class:`torch.nn.Conv2d` with ``kernel_size=1`` and no bias term.
    """
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return torch.nn.Conv2d(in_planes, out_planes, **conv_options)


class ConvBlock(torch.nn.Module):
    """Convolution block.

    Optionally max-pools the input by a factor of 2, then applies two
    ``Conv2d -> ReLU -> BatchNorm2d`` stages.  When the shortcut is enabled,
    the (possibly pooled) input is also projected by a 1x1 convolution
    followed by batch normalisation, and added to the output of the stages.

    Parameters
    ----------
    in_c
        Number of input channels.
    out_c
        Number of output channels.
    k_sz
        Kernel Size.
    shortcut
        If True, adds a Conv2d layer.
    pool
        If True, adds a MaxPool2d layer.
    """

    def __init__(
        self,
        in_c: int,
        out_c: int,
        k_sz: int = 3,
        shortcut: bool = False,
        pool: bool = True,
    ):
        super().__init__()

        # NOTE: the identity test is deliberate -- only the literal ``True``
        # enables the shortcut; other truthy values leave it disabled.
        if shortcut is True:
            self.shortcut = torch.nn.Sequential(
                _conv1x1(in_c, out_c), torch.nn.BatchNorm2d(out_c)
            )
        else:
            self.shortcut = False

        # The attribute is either a pooling module or the ``False`` sentinel
        # that ``forward`` tests for.
        if pool:
            self.pool = torch.nn.MaxPool2d(kernel_size=2)
        else:
            self.pool = False

        # With stride 1 and an odd kernel size, this padding keeps the spatial
        # size of the input unchanged.
        pad = (k_sz - 1) // 2

        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(in_c, out_c, kernel_size=k_sz, padding=pad),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_c),
            torch.nn.Conv2d(out_c, out_c, kernel_size=k_sz, padding=pad),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_c),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x``.

        Parameters
        ----------
        x
            Input tensor, with ``in_c`` channels.

        Returns
        -------
            The output of the convolution stages, plus the shortcut
            projection of the (possibly pooled) input if enabled.
        """
        if self.pool:
            x = self.pool(x)
        out = self.block(x)
        if self.shortcut:
            # the shortcut is computed from the pooled input so that both
            # terms share the same spatial size
            return out + self.shortcut(x)
        return out
class UpsampleBlock(torch.nn.Module):
    """Upsample block implementation.

    Doubles the spatial size of its input while mapping ``in_c`` channels to
    ``out_c`` channels.

    Parameters
    ----------
    in_c
        Number of input channels.
    out_c
        Number of output channels.
    up_mode
        Upsampling mode.  Either ``"transp_conv"`` (a stride-2 transposed
        convolution) or ``"up_conv"`` (bilinear upsampling followed by a 1x1
        convolution).

    Raises
    ------
    ValueError
        If ``up_mode`` is not one of the supported modes.
    """

    def __init__(self, in_c: int, out_c: int, up_mode: str = "transp_conv"):
        super().__init__()

        if up_mode == "transp_conv":
            stages = [
                torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
            ]
        elif up_mode == "up_conv":
            stages = [
                torch.nn.Upsample(
                    mode="bilinear", scale_factor=2, align_corners=False
                ),
                torch.nn.Conv2d(in_c, out_c, kernel_size=1),
            ]
        else:
            # ValueError is a subclass of Exception, so callers catching the
            # generic exception raised previously keep working.
            raise ValueError(
                f"Upsampling mode not supported: {up_mode!r} "
                f"(expected 'transp_conv' or 'up_conv')"
            )

        self.block = torch.nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``.

        Parameters
        ----------
        x
            Input tensor, with ``in_c`` channels.

        Returns
        -------
            The upsampled tensor, with ``out_c`` channels.
        """
        return self.block(x)
class ConvBridgeBlock(torch.nn.Module):
    """ConvBridgeBlock implementation.

    A single ``Conv2d -> ReLU -> BatchNorm2d`` stage that keeps the number of
    channels unchanged.

    Parameters
    ----------
    channels
        Number of channels.
    k_sz
        Kernel Size.
    """

    def __init__(self, channels: int, k_sz: int = 3):
        super().__init__()

        # With stride 1 and an odd kernel size, this padding keeps the spatial
        # size of the input unchanged.
        pad = (k_sz - 1) // 2

        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(channels, channels, kernel_size=k_sz, padding=pad),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the bridge to ``x``.

        Parameters
        ----------
        x
            Input tensor, with ``channels`` channels.

        Returns
        -------
            The transformed tensor, with ``channels`` channels.
        """
        return self.block(x)
class UpConvBlock(torch.nn.Module):
    """UpConvBlock implementation.

    Upsamples the input, concatenates it channel-wise with a skip connection
    (optionally passed through a :py:class:`ConvBridgeBlock` first), and
    applies a non-pooling :py:class:`ConvBlock` to the result.

    Parameters
    ----------
    in_c
        Number of input channels.
    out_c
        Number of output channels.
    k_sz
        Kernel Size.
    up_mode
        Upsampling mode.
    conv_bridge
        If True, adds a ConvBridgeBlock layer.
    shortcut
        If True, adds a Conv2d layer.
    """

    def __init__(
        self,
        in_c: int,
        out_c: int,
        k_sz: int = 3,
        up_mode: str = "up_conv",
        conv_bridge: bool = False,
        shortcut: bool = False,
    ):
        super().__init__()
        self.conv_bridge = conv_bridge

        self.up_layer = UpsampleBlock(in_c, out_c, up_mode=up_mode)
        # The convolution consumes the concatenation of the upsampled input
        # and the skip connection, hence ``2 * out_c`` input channels.
        self.conv_layer = ConvBlock(
            2 * out_c, out_c, k_sz=k_sz, shortcut=shortcut, pool=False
        )
        if self.conv_bridge:
            self.conv_bridge_layer = ConvBridgeBlock(out_c, k_sz=k_sz)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and merge it with ``skip``.

        Parameters
        ----------
        x
            Input tensor, with ``in_c`` channels.
        skip
            Skip-connection tensor.  It is concatenated with the upsampled
            input along dimension 1, so it needs ``out_c`` channels for the
            result to match the ``2 * out_c`` channels that the convolution
            layer expects.

        Returns
        -------
            The output of the convolution block applied to the concatenation.
        """
        up = self.up_layer(x)
        if self.conv_bridge:
            skip = self.conv_bridge_layer(skip)
        return self.conv_layer(torch.cat([up, skip], dim=1))
class LittleUNet(torch.nn.Module):
    """Base little U-Net (LUNET) network architecture, from [GALDRAN-2020]_.

    Parameters
    ----------
    in_c
        Number of input channels.
    n_classes
        Number of outputs (classes) for this model.
    layers
        Number of channels of each level of the model, from the shallowest
        to the deepest.  The number of entries sets the depth of the network.
    k_sz
        Kernel Size.
    up_mode
        Upsampling mode.
    conv_bridge
        If True, adds a ConvBridgeBlock layer.
    shortcut
        If True, adds a Conv2d layer.
    """

    def __init__(
        self,
        in_c: int,
        n_classes: int,
        layers: typing.Sequence[int],
        k_sz: int = 3,
        up_mode: str = "transp_conv",
        conv_bridge: bool = True,
        shortcut: bool = True,
    ):
        super().__init__()
        self.n_classes = n_classes

        widths = list(layers)

        self.first = ConvBlock(
            in_c=in_c,
            out_c=widths[0],
            k_sz=k_sz,
            shortcut=shortcut,
            pool=False,
        )

        # Encoder: one pooling block per pair of consecutive widths.
        self.down_path = torch.nn.ModuleList(
            [
                ConvBlock(
                    in_c=narrow,
                    out_c=wide,
                    k_sz=k_sz,
                    shortcut=shortcut,
                    pool=True,
                )
                for narrow, wide in zip(widths[:-1], widths[1:])
            ]
        )

        # Decoder: mirrors the encoder, walking the widths in reverse order.
        reversed_widths = widths[::-1]
        self.up_path = torch.nn.ModuleList(
            [
                UpConvBlock(
                    in_c=wide,
                    out_c=narrow,
                    k_sz=k_sz,
                    up_mode=up_mode,
                    conv_bridge=conv_bridge,
                    shortcut=shortcut,
                )
                for wide, narrow in zip(
                    reversed_widths[:-1], reversed_widths[1:]
                )
            ]
        )

        # init, shamelessly lifted from torchvision/models/resnet.py
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.GroupNorm)):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)

        # Created after the initialisation loop above, so this layer is not
        # touched by it and keeps the initialisation done by its constructor.
        self.final = torch.nn.Conv2d(widths[0], n_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder/decoder on ``x``.

        Parameters
        ----------
        x
            Input tensor, with ``in_c`` channels.

        Returns
        -------
            The output of the final 1x1 convolution, with ``n_classes``
            channels.  No activation is applied to it here.
        """
        x = self.first(x)

        # Keep the activation entering each encoder block; the decoder uses
        # them as skip connections, deepest first.
        skips = []
        for down in self.down_path:
            skips.append(x)
            x = down(x)

        for up, skip in zip(self.up_path, reversed(skips)):
            x = up(x, skip)

        return self.final(x)
class LittleWNet(Model):
    """`Little W-Net (LWNET) network architecture <lwnet_>`_, from [GALDRAN-2020]_.

    Concatenates two :py:class:`Little U-Net <LittleUNet>` models.

    Parameters
    ----------
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.  ``None`` (the default) means no arguments.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.  ``None`` (the default)
        means no arguments.
    scheduler_type
        The type of scheduler to use for training.
    scheduler_arguments
        Arguments to the scheduler after ``params``.  ``None`` (the default)
        means no arguments.
    model_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    augmentation_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        loss_type: type[torch.nn.Module] = MultiWeightedBCELogitsLoss,
        loss_arguments: dict[str, typing.Any] | None = None,
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] | None = None,
        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler]
        | None = None,
        scheduler_arguments: dict[str, typing.Any] | None = None,
        model_transforms: TransformSequence | None = None,
        augmentation_transforms: TransformSequence | None = None,
        num_classes: int = 1,
    ):
        # Default containers are created per call.  Mutable defaults in the
        # signature would be evaluated once and then shared by every instance
        # built without explicit arguments.  The parameter names are rebound
        # so the base class receives the same kind of values as before
        # (empty dict / empty list).
        loss_arguments = {} if loss_arguments is None else loss_arguments
        optimizer_arguments = (
            {} if optimizer_arguments is None else optimizer_arguments
        )
        scheduler_arguments = (
            {} if scheduler_arguments is None else scheduler_arguments
        )
        model_transforms = [] if model_transforms is None else model_transforms
        augmentation_transforms = (
            []
            if augmentation_transforms is None
            else augmentation_transforms
        )

        super().__init__(
            name="lwnet",
            loss_type=loss_type,
            loss_arguments=loss_arguments,
            optimizer_type=optimizer_type,
            optimizer_arguments=optimizer_arguments,
            scheduler_type=scheduler_type,
            scheduler_arguments=scheduler_arguments,
            model_transforms=model_transforms,
            augmentation_transforms=augmentation_transforms,
            num_classes=num_classes,
        )

        # NOTE(review): ``self.num_classes`` is not assigned in this class; it
        # is presumably set by the base class from ``num_classes`` -- confirm.
        self.unet1 = LittleUNet(
            in_c=3,
            n_classes=self.num_classes,
            layers=(8, 16, 32),
            conv_bridge=True,
            shortcut=True,
        )
        # The second network takes the 3 input channels concatenated with the
        # output of the first network.
        self.unet2 = LittleUNet(
            in_c=3 + self.num_classes,
            n_classes=self.num_classes,
            layers=(8, 16, 32),
            conv_bridge=True,
            shortcut=True,
        )

    def forward(self, x):
        """Run both U-Nets on ``x``.

        Parameters
        ----------
        x
            Input tensor.  The first network is built for 3 input channels.

        Returns
        -------
            A tuple ``(x1, x2)`` with the outputs of the first and of the
            second U-Net, respectively.  No activation is applied here.
        """
        # NOTE(review): ``self.normalizer`` is not defined in this class; it
        # is presumably provided by the base class -- confirm.
        xn = self.normalizer(x)
        x1 = self.unet1(xn)
        x2 = self.unet2(torch.cat([xn, x1], dim=1))
        return x1, x2

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Predict from the ``"image"`` entry of the first element of ``batch``.

        Parameters
        ----------
        batch
            The batch to predict on.  Only ``batch[0]["image"]`` is read.
        batch_idx
            Index of the batch (unused here).
        dataloader_idx
            Index of the dataloader (unused here).

        Returns
        -------
            The sigmoid of the output of the second U-Net.
        """
        # prediction only returns the result of the last unet
        return torch.sigmoid(self(batch[0]["image"])[1])