Source code for mednet.models.segment.unet

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""`UNet network architecture <unet_>`_, from [RONNEBERGER-2015]_."""

import logging
import typing

import torch.nn
import torch.utils.data

from ...data.typing import TransformSequence
from ..losses import SoftJaccardAndBCEWithLogitsLoss
from .backbones.vgg import vgg16_for_segmentation
from .make_layers import UnetBlock, conv_with_kaiming_uniform
from .model import Model

logger = logging.getLogger(__name__)


class UNetHead(torch.nn.Module):
    """UNet head module.

    Parameters
    ----------
    in_channels_list
        Number of channels for each feature map that is returned from backbone.
        Must contain exactly 5 entries; the last entry is the channel count of
        the feature map fed into the middle (deepest) decoding block.
    pixel_shuffle
        If True, upsample using PixelShuffleICNR.
    out_channels
        Number of channels produced by the final 1x1 convolution (one per
        output class).  The default of 1 keeps the original single-channel
        output.
    """

    def __init__(
        self,
        in_channels_list: list[int],
        pixel_shuffle: bool = False,
        out_channels: int = 1,
    ):
        super().__init__()

        if len(in_channels_list) != 5:
            raise ValueError(
                f"UNetHead expects 5 feature-map channel counts, got "
                f"{len(in_channels_list)}: {in_channels_list}"
            )

        # number of channels, in the same order as the backbone feature maps
        c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list

        # build layers: decoding starts from the last (deepest) feature map and
        # merges in each shallower one on the way up
        self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle, middle_block=True)
        self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle)
        self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle)
        self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
        # 1x1 convolution mapping decoded features to one channel per output
        self.final = conv_with_kaiming_uniform(c_decode1, out_channels, 1)

    def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x
            List of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level.

        Returns
        -------
            Output of the forward pass.
        """
        # NOTE: x[0] (height and width of input image) is not needed in the
        # U-Net architecture
        decode4 = self.decode4(x[5], x[4])
        decode3 = self.decode3(decode4, x[3])
        decode2 = self.decode2(decode3, x[2])
        decode1 = self.decode1(decode2, x[1])
        return self.final(decode1)


class Unet(Model):
    """`UNet network architecture <unet_>`_, from [RONNEBERGER-2015]_.

    Parameters
    ----------
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.  ``None`` (default) means no arguments.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.  ``None`` (default) means
        no arguments.
    scheduler_type
        The type of scheduler to use for training.
    scheduler_arguments
        Arguments to the scheduler after ``params``.  ``None`` (default) means
        no arguments.
    model_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    augmentation_transforms
        An optional sequence of torch modules containing data augmentation
        transforms, forwarded to the base model class.
    num_classes
        Number of outputs (classes) for this model.  This sets the number of
        output channels of the segmentation head.
    pretrained
        If True, will use VGG16 pretrained weights (and ImageNet
        normalization factors, see :py:meth:`set_normalizer`).
    """

    def __init__(
        self,
        loss_type: type[torch.nn.Module] = SoftJaccardAndBCEWithLogitsLoss,
        loss_arguments: dict[str, typing.Any] | None = None,
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] | None = None,
        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
        scheduler_arguments: dict[str, typing.Any] | None = None,
        model_transforms: TransformSequence | None = None,
        augmentation_transforms: TransformSequence | None = None,
        num_classes: int = 1,
        pretrained: bool = False,
    ):
        # ``None`` sentinels avoid sharing one mutable default object across
        # instances; the base class still receives fresh empty containers.
        super().__init__(
            name="unet",
            loss_type=loss_type,
            loss_arguments={} if loss_arguments is None else loss_arguments,
            optimizer_type=optimizer_type,
            optimizer_arguments=(
                {} if optimizer_arguments is None else optimizer_arguments
            ),
            scheduler_type=scheduler_type,
            scheduler_arguments=(
                {} if scheduler_arguments is None else scheduler_arguments
            ),
            model_transforms=[] if model_transforms is None else model_transforms,
            augmentation_transforms=(
                [] if augmentation_transforms is None else augmentation_transforms
            ),
            num_classes=num_classes,
        )

        self.pretrained = pretrained

        # The channel counts given to UNetHead below must correspond to the
        # backbone feature maps selected here.
        self.backbone = vgg16_for_segmentation(
            pretrained=self.pretrained,
            return_features=[3, 8, 14, 22, 29],
        )

        self.head = UNetHead(
            [64, 128, 256, 512, 512],
            pixel_shuffle=False,
            out_channels=num_classes,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize the input, extract backbone features and decode them.

        Parameters
        ----------
        x
            Input batch.

        Returns
        -------
            Output of the segmentation head.
        """
        x = self.normalizer(x)
        x = self.backbone(x)
        return self.head(x)

    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
        """Initialize the normalizer for the current model.

        If ``pretrained = True``, the dataloader is ignored and the normalizer
        is set to preset ImageNet factors (from torchvision) instead of
        factors computed from the data.

        Parameters
        ----------
        dataloader
            A torch Dataloader from which to compute the mean and std.
            Will not be used if the model is pretrained.
        """
        if self.pretrained:
            from ..normalizer import make_imagenet_normalizer

            logger.warning(
                f"ImageNet pre-trained {self.name} model - NOT "
                f"computing z-norm factors from train dataloader. "
                f"Using preset factors from torchvision.",
            )
            self.normalizer = make_imagenet_normalizer()
        else:
            super().set_normalizer(dataloader)