Source code for mednet.models.classify.densenet

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""`DenseNet-121 network architecture <densenet-pytorch_>`_, from :cite:p:`huang_densely_2017`."""

import logging
import typing

import torch
import torch.nn
import torch.optim.optimizer
import torch.utils.data
import torchvision.models as models

from ...data.typing import TransformSequence
from ..typing import Checkpoint
from .model import Model

logger = logging.getLogger(__name__)


class Densenet(Model):
    """`DenseNet-121 network architecture <densenet-pytorch_>`_, from :cite:p:`huang_densely_2017`.

    Parameters
    ----------
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.
    scheduler_type
        The type of scheduler to use for training.
    scheduler_arguments
        Arguments to the scheduler after ``params``.
    model_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    augmentation_transforms
        An optional sequence of torch modules containing augmentation
        transforms for the input.  They are forwarded unchanged to the base
        :py:class:`Model`, which decides when they are applied.
    pretrained
        If set to True, loads pretrained model weights
        (``DenseNet121_Weights.DEFAULT``) during initialization and installs
        an ImageNet input normalizer, else trains a new model.
    dropout
        Dropout rate after each dense layer.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
        loss_arguments: dict[str, typing.Any] | None = None,
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] | None = None,
        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
        scheduler_arguments: dict[str, typing.Any] | None = None,
        model_transforms: TransformSequence | None = None,
        augmentation_transforms: TransformSequence | None = None,
        pretrained: bool = False,
        dropout: float = 0.1,
        num_classes: int = 1,
    ):
        super().__init__(
            name="densenet",
            loss_type=loss_type,
            loss_arguments=loss_arguments,
            optimizer_type=optimizer_type,
            optimizer_arguments=optimizer_arguments,
            scheduler_type=scheduler_type,
            scheduler_arguments=scheduler_arguments,
            model_transforms=model_transforms,
            augmentation_transforms=augmentation_transforms,
            num_classes=num_classes,
        )

        self.pretrained = pretrained
        self.dropout = dropout

        # Load pretrained model
        weights = None
        if self.pretrained:
            from ..normalizer import make_imagenet_normalizer

            # Pretrained weights are paired with an ImageNet normalizer.  The
            # base-class property setter is invoked explicitly (``fset``) so
            # it is used regardless of any override in a subclass.
            Model.normalizer.fset(self, make_imagenet_normalizer())  # type: ignore[attr-defined]

            logger.info(f"Loading pretrained `{self.name}` model weights")
            weights = models.DenseNet121_Weights.DEFAULT

        self.model = models.densenet121(weights=weights, drop_rate=self.dropout)

        # Replace the stock output layer so its size matches ``num_classes``.
        self.model.classifier = torch.nn.Linear(
            self.model.classifier.in_features, self.num_classes
        )

    @Model.num_classes.setter  # type: ignore[attr-defined]
    def num_classes(self, v: int) -> None:
        """Set the number of output classes, resizing the classifier if needed.

        When ``v`` differs from the current number of classes, the output
        classifier layer is replaced by a freshly initialized
        :py:class:`torch.nn.Linear` with ``v`` outputs (previous classifier
        weights are discarded).

        Parameters
        ----------
        v
            The new number of outputs (classes).
        """
        if self.num_classes != v:
            logger.info(
                f"Resetting `{self.name}` output classifier layer weights due "
                f"to a change in output size ({self.num_classes} -> {v})"
            )
            previous = self.model.classifier
            # Create the replacement on the same device/dtype as the layer it
            # replaces; otherwise a model already moved to an accelerator (or
            # cast to another precision) would get a CPU/float32 classifier
            # and fail in ``forward``.
            self.model.classifier = torch.nn.Linear(
                previous.in_features,
                v,
                device=previous.weight.device,
                dtype=previous.weight.dtype,
            )
        self._num_classes = v

    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Adapt a checkpoint before the standard loading routine runs.

        Two adjustments are made:

        1. Checkpoints from a previous version of this model stored the
           network under the ``model_ft`` attribute; such state-dict keys are
           renamed so they start with ``model`` instead.
        2. The number of output classes is reset from the shape of
           ``model.classifier.bias`` in the checkpoint, which resizes the
           classifier layer when needed.

        Parameters
        ----------
        checkpoint
            The checkpoint being loaded; its ``"state_dict"`` entry may be
            replaced in place.
        """
        legacy_prefix = "model_ft"
        state_dict = checkpoint["state_dict"]

        # support previous version of densenet (model_ft -> model)
        if any(k.startswith(legacy_prefix) for k in state_dict):
            # only rewrite the leading attribute name, not every occurrence of
            # the substring inside a key
            checkpoint["state_dict"] = {
                (
                    "model" + k[len(legacy_prefix) :]
                    if k.startswith(legacy_prefix)
                    else k
                ): v
                for k, v in state_dict.items()
            }

        # reset number of output classes if need be
        self.num_classes = checkpoint["state_dict"]["model.classifier.bias"].shape[0]

        # perform routine checkpoint loading
        super().on_load_checkpoint(checkpoint)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize the input and run it through DenseNet-121.

        Parameters
        ----------
        x
            Input batch (presumably images; shape is whatever
            ``torchvision.models.densenet121`` accepts — confirm with the
            data pipeline).

        Returns
        -------
            The raw network outputs (one column per class), as produced by
            the classifier layer.
        """
        x = self.normalizer(x)
        return self.model(x)