# Source code for mednet.models.segment.model

# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Definition of base model type for semantic segmentation tasks."""

import logging
import typing

import torch
import torch.nn
import torch.optim.optimizer
import torch.utils.data

from ...data.typing import TransformSequence
from ..losses import BCEWithLogitsLossWeightedPerBatch
from ..model import Model as BaseModel

logger = logging.getLogger(__name__)


class Model(BaseModel):
    """Base model type for semantic segmentation tasks.

    Parameters
    ----------
    name
        Common name to give to models of this type.
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.
    scheduler_type
        The type of scheduler to use for training.
    scheduler_arguments
        Arguments to the scheduler after the ``optimizer`` argument.
    model_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    augmentation_transforms
        An optional sequence of torch modules containing data augmentation
        transforms.  During training they are applied to both the input image
        and its segmentation target, replaying the same random draws for both
        so that the two stay aligned.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        name: str,
        loss_type: type[torch.nn.Module] | None = BCEWithLogitsLossWeightedPerBatch,
        loss_arguments: dict[str, typing.Any] | None = None,
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] | None = None,
        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
        scheduler_arguments: dict[str, typing.Any] | None = None,
        model_transforms: TransformSequence | None = None,
        augmentation_transforms: TransformSequence | None = None,
        num_classes: int = 1,
    ):
        super().__init__(
            name=name,
            loss_type=loss_type,
            loss_arguments=loss_arguments,
            optimizer_type=optimizer_type,
            optimizer_arguments=optimizer_arguments,
            scheduler_type=scheduler_type,
            scheduler_arguments=scheduler_arguments,
            model_transforms=model_transforms,
            augmentation_transforms=augmentation_transforms,
            num_classes=num_classes,
        )

    def _augment_image_and_target(
        self, image: torch.Tensor, target: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply ``augmentation_transforms`` to an image and its target.

        ``augmentation_transforms`` is invoked once per tensor.  A random
        transform samples new parameters on every call, so transforming the
        image and the target independently could, for example, flip one but
        not the other.  To keep them aligned, the state of Python's
        :py:mod:`random` generator and of torch's generators (CPU, plus the
        CUDA generator of ``image``'s device when it lives on a GPU) is saved
        before the image is transformed and restored before the target is
        transformed.  Both calls therefore observe the same random draws.

        .. note::

           This only synchronises transforms whose randomness comes from those
           generators.  It assumes the draws depend on the inputs only through
           their spatial size, which is presumed equal for image and target
           (TODO: confirm for custom augmentations).  Every transform,
           including photometric ones, is still applied to the target, exactly
           as before.

        Parameters
        ----------
        image
            The input images of the batch.
        target
            The segmentation targets matching ``image``.

        Returns
        -------
            The augmented ``(image, target)`` pair.
        """
        import random  # local import: only needed for the RNG snapshot below

        python_rng_state = random.getstate()
        cuda_devices = [image.device] if image.is_cuda else []
        with torch.random.fork_rng(devices=cuda_devices):
            augmented_image = self.augmentation_transforms(image)
        # fork_rng restored the torch generators on exit; rewind Python's too.
        random.setstate(python_rng_state)
        augmented_target = self.augmentation_transforms(target)
        return augmented_image, augmented_target

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        Parameters
        ----------
        batch
            A mapping providing at least the ``"image"`` and ``"target"``
            tensors of the current batch.
        batch_idx
            Index of the batch (unused).

        Returns
        -------
            The value returned by ``self.train_loss`` for the network output on
            the augmented images and the identically augmented targets.
        """
        del batch_idx  # unused here

        images, targets = self._augment_image_and_target(
            batch["image"], batch["target"]
        )
        return self.train_loss(self(images), targets)