Source code for mednet.models.classify.model
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Definition of base model type for classification tasks."""
import logging
import typing
import torch
import torch.nn
import torch.optim
import torch.optim.optimizer
import torch.utils.data
from ...data.typing import TransformSequence
from ..model import Model as BaseModel
# Logger for this module, named after its import path (``__name__``).
logger = logging.getLogger(__name__)
[docs]
class Model(BaseModel):
    """Base model type for classification tasks.

    Parameters
    ----------
    name
        Common name to give to models of this type.
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.
    scheduler_type
        The type of scheduler to use for training.
    scheduler_arguments
        Arguments to the scheduler after ``optimizer`` (learning-rate
        schedulers receive the optimizer, not the parameters, as their first
        argument).
    model_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    augmentation_transforms
        An optional sequence of torch modules containing data augmentation
        transforms. They are applied to the input images during training
        (see :py:meth:`training_step`) before those are fed into the network.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        name: str,
        loss_type: type[torch.nn.Module] | None = None,
        loss_arguments: dict[str, typing.Any] | None = None,
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] | None = None,
        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
        scheduler_arguments: dict[str, typing.Any] | None = None,
        model_transforms: TransformSequence | None = None,
        augmentation_transforms: TransformSequence | None = None,
        num_classes: int = 1,
    ):
        # All configuration is stored and handled by the base model type;
        # this class only specializes the training step.
        super().__init__(
            name=name,
            loss_type=loss_type,
            loss_arguments=loss_arguments,
            optimizer_type=optimizer_type,
            optimizer_arguments=optimizer_arguments,
            scheduler_type=scheduler_type,
            scheduler_arguments=scheduler_arguments,
            model_transforms=model_transforms,
            augmentation_transforms=augmentation_transforms,
            num_classes=num_classes,
        )

    def training_step(self, batch: typing.Mapping[str, typing.Any], batch_idx: int):
        """Compute the training loss for a single batch.

        Parameters
        ----------
        batch
            The current batch. It must provide an ``"image"`` entry (the
            network input) and a ``"target"`` entry (the ground truth handed
            to the loss).
        batch_idx
            Index of the batch. It is not used here.

        Returns
        -------
            Whatever ``self.train_loss`` returns for the network output and
            ``batch["target"]``. Per the class docstring, this is expected to
            be a batch average.
        """
        # The index is part of the step signature but irrelevant here.
        del batch_idx

        # Augmentation is applied only on this (training) path, before the
        # forward pass.
        augmented = self.augmentation_transforms(batch["image"])
        return self.train_loss(self(augmented), batch["target"])