Source code for mednet.models.detect.faster_rcnn

# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Faster R-CNN object detection (and classification) network architecture, from :cite:p:`ren_faster_2017`."""

import logging
import typing

import torch
import torch.nn
import torch.optim
import torch.optim.lr_scheduler
import torch.optim.optimizer
import torch.utils.data
import torchvision.models as models

from ...data.typing import TransformSequence
from ..typing import Checkpoint
from .model import Model

logger = logging.getLogger(__name__)


[docs] class FasterRCNN(Model): """Faster R-CNN object detection (and classification) network architecture, from :cite:p:`ren_faster_2017`. Parameters ---------- optimizer_type The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. scheduler_type The type of scheduler to use for training. scheduler_arguments Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. augmentation_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. pretrained If set to True, loads pretrained model weights during initialization, else trains a new model. num_classes Number of outputs (classes) for this model. Do not account for the background (we compensate internally). variant One of the torchvision supported variants. """ def __init__( self, optimizer_type: type[torch.optim.Optimizer] = torch.optim.SGD, optimizer_arguments: dict[str, typing.Any] = dict( lr=0.005, momentum=0.9, weight_decay=0.0005 ), scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = torch.optim.lr_scheduler.StepLR, scheduler_arguments: dict[str, typing.Any] = dict(step_size=3, gamma=0.1), model_transforms: TransformSequence | None = None, augmentation_transforms: TransformSequence | None = None, pretrained: bool = False, num_classes: int = 1, variant: typing.Literal[ "resnet50-v1", "resnet50-v2", "mobilenetv3-large", "mobilenetv3-small" ] = "mobilenetv3-small", ): super().__init__( name=f"faster-rcnn[{variant}]", loss_type=None, loss_arguments=None, optimizer_type=optimizer_type, optimizer_arguments=optimizer_arguments, scheduler_type=scheduler_type, scheduler_arguments=scheduler_arguments, model_transforms=model_transforms, augmentation_transforms=augmentation_transforms, num_classes=num_classes, ) self.pretrained = pretrained self.variant = variant # Load pretrained model weights = None if pretrained: from ..normalizer import make_cocov1_normalizer Model.normalizer.fset(self, make_cocov1_normalizer()) # type: ignore[attr-defined] logger.info(f"Loading pretrained `{self.name}` model weights") match self.variant: case "resnet50-v1": weights = "FasterRCNN_ResNet50_FPN_Weights.COCO_V1" case "resnet50-v2": weights = "FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1" case "mobilenetv3-large": weights = "FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1" case "mobilenetv3-small": weights = "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1" match self.variant: case "resnet50-v1": self.model = models.detection.fasterrcnn_resnet50_fpn(weights=weights) case "resnet50-v2": self.model = models.detection.fasterrcnn_resnet50_fpn_v2( weights=weights ) case "mobilenetv3-large": self.model = models.detection.fasterrcnn_mobilenet_v3_large_fpn( weights=weights ) case "mobilenetv3-small": self.model = models.detection.fasterrcnn_mobilenet_v3_large_320_fpn( weights=weights ) # Instantiates model and adapts output features self.num_classes = num_classes @Model.num_classes.setter # type: ignore[attr-defined] def num_classes(self, v: int) -> None: # Faster R-CNN models will have num_classes + 1 outputs accounting for the # background v += 1 if self.model.roi_heads.box_predictor.cls_score.out_features != v: if self.pretrained: logger.info( f"Resetting `{self.name}` pretrained classifier " f"layer weights due to change in output size " f"({self.model.roi_heads.box_predictor.cls_score.out_features} -> {v})" ) self.model.roi_heads.box_predictor = ( models.detection.faster_rcnn.FastRCNNPredictor( self.model.roi_heads.box_predictor.cls_score.in_features, v ) ) self._num_classes = v
[docs] def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: checkpoint["variant"] = self.variant super().on_save_checkpoint(checkpoint)
[docs] def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: # setup model type self.variant = checkpoint.get("variant", "mobilenetv3-small") match self.variant: case "resnet50-v1": self.model = models.detection.fasterrcnn_resnet50_fpn() case "resnet50-v2": self.model = models.detection.fasterrcnn_resnet50_fpn_v2() case "mobilenetv3-large": self.model = models.detection.fasterrcnn_mobilenet_v3_large_fpn() case "mobilenetv3-small": self.model = models.detection.fasterrcnn_mobilenet_v3_large_320_fpn() # reset number of classes if need be num_classes = checkpoint["state_dict"][ "model.roi_heads.box_predictor.cls_score.bias" ].shape[0] if num_classes != self.num_classes: logger.debug( f"Resetting number-of-output-classes at `{self.name}` model from " f"{self.num_classes} to {num_classes} while loading checkpoint." ) self.num_classes = num_classes - 1 super().on_load_checkpoint(checkpoint)
[docs] def forward( self, images: list[torch.Tensor], targets: list[dict[str, torch.Tensor]] | None = None, ): """Forward the input tensor through the network, producing a prediction. Parameters ---------- images Input images, to be analyzed or trained on. targets Targets for the current input images. Targets should be passed only if *training*. It should be alist of dictionaries, each corresponding to one of the input images in ``images``. Each dictionary should have two keys: ``boxes``, that contains the bounding boxes associated with the image, and ``labels``, that contains the labels associated with each bounding box. Returns ------- A list with various dictionaries, each referring to one of the """ images = [self.normalizer(k) for k in images] if targets is None: return self.model(images) return self.model(images, targets)
[docs] def training_step(self, batch, batch_idx): del batch_idx # debug code to inspect images by eye: # from torchvision.transforms.v2.functional import to_pil_image # for k in batch["image"]: # to_pil_image(k).show() # __import__("pdb").set_trace() # during training, detection models __call__() take 2 lists as inputs: # * list of images # * list of dictionaries with "boxes" and "labels" # # -> model returns a Dict[Tensor] during training, containing the # classification and regression losses for both the RPN and the R-CNN. result = self.forward( [k for k in self._augmentation_transforms(batch["image"])], [ {"boxes": k1, "labels": k2.long()} for (k1, k2) in zip( self._augmentation_transforms(batch["target"]), batch["labels"] ) ], ) return sum(result.values())
[docs] def validation_step(self, batch, batch_idx, dataloader_idx=0): del batch_idx, dataloader_idx # satisfies linter # need to be in "train" mode to calculate losses self.model.train() with torch.no_grad(): result = self( [k for k in batch["image"]], [ {"boxes": k1, "labels": k2.long()} for (k1, k2) in zip(batch["target"], batch["labels"]) ], ) return sum(result.values())
[docs] def predict_step(self, batch, batch_idx, dataloader_idx=0): del batch_idx, dataloader_idx # satisfies linter # debug code to inspect images by eye: # from torchvision.transforms.v2.functional import to_pil_image # for k in batch["image"]: # to_pil_image(k).show() # __import__("pdb").set_trace() # during inference, detection models __call__() take 1 list as input: # * list of images # # -> model returns the post-processed predictions as a # list[dict[torch.Tensor]], one for each input image return self([k for k in batch["image"]])