# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import copy
import logging
import typing
import lightning.pytorch
import torch
import torch.nn
import torch.optim.lr_scheduler
import torch.optim.optimizer
import torch.utils.data
import torchvision.transforms
from ..data.typing import TransformSequence
from .typing import Checkpoint
logger = logging.getLogger(__name__)
[docs]
class Model(lightning.pytorch.LightningModule):
"""Base class for models.
Parameters
----------
name
Common name to give to models of this type.
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
Arguments to the optimizer after ``params``.
scheduler_type
The type of scheduler to use for training.
scheduler_arguments
Arguments to the scheduler after ``optimizer``.
model_transforms
An optional sequence of torch modules containing transforms to be
applied on the input **before** it is fed into the network.
augmentation_transforms
An optional sequence of torch modules containing data augmentation
transforms. They are stored wrapped in a :py:class:`torchvision.transforms.Compose`
and are moved along with the model by :py:meth:`to`.
num_classes
Number of outputs (classes) for this model.
"""
def __init__(
    self,
    name: str,
    loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
    loss_arguments: dict[str, typing.Any] = {},  # noqa: B006 (copied below)
    optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
    optimizer_arguments: dict[str, typing.Any] = {},  # noqa: B006 (copied below)
    scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
    scheduler_arguments: dict[str, typing.Any] = {},  # noqa: B006 (copied below)
    model_transforms: TransformSequence = [],  # noqa: B006 (copied below)
    augmentation_transforms: TransformSequence = [],  # noqa: B006 (copied below)
    num_classes: int = 1,
):
    """Initialize the base model.

    See the class documentation for a description of each parameter.

    Every container argument (dictionaries of arguments and transform
    sequences) is copied before being stored.  The mutable defaults are
    therefore never shared between instances, and later changes to the
    caller's objects do not affect this model (and vice-versa).
    """
    super().__init__()
    self.name = name
    self._num_classes = num_classes
    self.model_transforms = list(model_transforms)
    self.loss_type = loss_type
    # two independent copies: balance_losses() writes ``pos_weight`` into
    # each of them separately
    self.train_loss_arguments = copy.deepcopy(loss_arguments)
    self.validation_loss_arguments = copy.deepcopy(loss_arguments)
    self._optimizer_type = optimizer_type
    self._optimizer_arguments = dict(optimizer_arguments)
    self._scheduler_type = scheduler_type
    self._scheduler_arguments = dict(scheduler_arguments)
    # goes through the property setter, which wraps the transforms in a
    # Compose and logs them
    self.augmentation_transforms = list(augmentation_transforms)
    # initializes losses from input arguments
    self.configure_losses()
@property
def augmentation_transforms(self) -> torchvision.transforms.Compose:
    """Data augmentation pipeline attached to this model.

    Returns
    -------
    torchvision.transforms.Compose
        The composition of the augmentation transforms last assigned to
        this property.
    """
    return self._augmentation_transforms
@augmentation_transforms.setter
def augmentation_transforms(self, v: TransformSequence) -> None:
    """Replace the data augmentation pipeline.

    The given transforms are wrapped in a
    :py:class:`torchvision.transforms.Compose`, and a one-line summary of
    them is logged at INFO level.

    Parameters
    ----------
    v
        The sequence of augmentation transforms (may be empty).
    """
    self._augmentation_transforms = torchvision.transforms.Compose(v)

    if len(v) == 0:
        logger.info("Data augmentations: None")
        return

    # fully-qualified description of each transform, e.g. "module.Repr()"
    descriptions = [
        f"{type(transform).__module__}.{str(transform)}"
        for transform in self._augmentation_transforms.transforms
    ]
    summary = ", ".join(descriptions)
    logger.info(f"Data augmentations: {summary}")
@property
def num_classes(self) -> int:
    """Number of outputs (classes) produced by this model.

    Returns
    -------
    int
        The number of outputs supported by this model.
    """
    return self._num_classes
@num_classes.setter
def num_classes(self, v: int) -> None:
    """Reject any attempt to change the number of outputs.

    Parameters
    ----------
    v
        The requested (and refused) number of classes.

    Raises
    ------
    RuntimeError
        Always: the number of classes is fixed at construction time.
    """
    message = f"Cannot reset number of classes to `{v}` for model."
    raise RuntimeError(message)
[docs]
def forward(self, x):
    """Compute the network output for input ``x``.

    The base model has no network: concrete models must override this.

    Parameters
    ----------
    x
        The input batch.

    Raises
    ------
    NotImplementedError
        Always, when called on the base model.
    """
    raise NotImplementedError(
        "You cannot use the base model without implementing the forward()"
    )
[docs]
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
    """Add extra model state to a checkpoint being saved (Lightning hook).

    Lightning calls this while writing a checkpoint.  The input normalizer
    is written into the checkpoint under the ``"normalizer"`` key so that
    :py:meth:`on_load_checkpoint` can restore it later.

    Parameters
    ----------
    checkpoint
        The checkpoint being saved; it is updated in place.
    """
    normalizer = self.normalizer
    checkpoint["normalizer"] = normalizer
[docs]
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
    """Restore extra model state from a loaded checkpoint (Lightning hook).

    Counterpart of :py:meth:`on_save_checkpoint`.  The input normalizer is
    read back from the ``"normalizer"`` key, then the parent class hook is
    invoked with the same checkpoint.

    Parameters
    ----------
    checkpoint
        The loaded checkpoint.
    """
    logger.info("Restoring normalizer from checkpoint.")
    restored_normalizer = checkpoint["normalizer"]
    self.normalizer = restored_normalizer
    super().on_load_checkpoint(typing.cast(typing.Any, checkpoint))
[docs]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
    """Compute and attach the input z-normalizer for this model.

    The normalizer is built by ``make_z_normalizer`` from the given
    dataloader and stored on ``self.normalizer``.

    Parameters
    ----------
    dataloader
        A torch DataLoader from which the mean and std are computed.
    """
    from .normalizer import make_z_normalizer

    message = (
        f"Uninitialised {self.name} model - "
        "computing z-norm factors from train dataloader."
    )
    logger.info(message)
    self.normalizer = make_z_normalizer(dataloader)
[docs]
def balance_losses(self, datamodule) -> None:
    """Balance the losses based on the distribution of positives.

    Balancing only happens if the loss supports it, i.e. if an instance of
    ``self.loss_type`` built without arguments has a ``pos_weight``
    attribute.  In that case:

    * positive weights for the training loss are computed from the
      training dataloader;
    * positive weights for the validation loss are computed from the
      ``"validation"`` entry of the validation dataloaders, falling back to
      the training dataloader if that entry is missing.

    In every case, :py:meth:`configure_losses` is called at the end so the
    loss objects reflect the current loss arguments.

    Parameters
    ----------
    datamodule
        Instance of a datamodule from where targets will be loaded.  It must
        provide ``train_dataloader()`` and ``val_dataloader()``.  The latter
        must return a mapping of dataloaders indexed by name.
    """
    from .loss_weights import get_positive_weights

    # a default-constructed loss is only used to probe for ``pos_weight``
    if not hasattr(self.loss_type(), "pos_weight"):
        logger.warning(
            f"Loss `{self.loss_type}` does not possess a `pos_weight` "
            "attribute and will not be balanced."
        )
    else:
        train_weights = get_positive_weights(datamodule.train_dataloader())
        self.train_loss_arguments["pos_weight"] = train_weights
        logger.info(
            f"Balanced training loss `{self.loss_type}`: "
            f"`pos_weight={train_weights}`."
        )

        # query the validation dataloaders only once: each call may build
        # new dataloader objects
        val_loaders = datamodule.val_dataloader()
        if "validation" in val_loaders:
            validation_weights = get_positive_weights(val_loaders["validation"])
        else:
            logger.warning(
                "Datamodule does not contain a validation dataloader. "
                "The training dataloader will be used instead."
            )
            validation_weights = get_positive_weights(
                datamodule.train_dataloader()
            )
        self.validation_loss_arguments["pos_weight"] = validation_weights
        logger.info(
            f"Balanced validation loss `{self.loss_type}`: "
            f"`pos_weight={validation_weights}`."
        )

    # (re)build the loss objects so that updated arguments take effect
    self.configure_losses()
[docs]
def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
    """Move model, augmentations and losses to the specified device/dtype.

    Refer to the method :py:meth:`torch.nn.Module.to` for details on the
    accepted arguments.  Augmentation transforms exposing a ``to()`` method
    are moved as well.  Transforms without it (e.g. plain callables) are
    kept unchanged in the pipeline.

    Parameters
    ----------
    *args
        Parameter forwarded to the underlying implementations.
    **kwargs
        Parameter forwarded to the underlying implementations.

    Returns
    -------
    typing.Self
        This model instance, allowing call chaining.
    """
    super().to(*args, **kwargs)

    # keep transforms that cannot be moved instead of dropping them, so the
    # augmentation pipeline is never silently shortened by a device move
    moved_transforms = [
        transform.to(*args, **kwargs) if hasattr(transform, "to") else transform
        for transform in self._augmentation_transforms.transforms
    ]
    self._augmentation_transforms = torchvision.transforms.Compose(
        moved_transforms
    )

    self.train_loss.to(*args, **kwargs)
    self.validation_loss.to(*args, **kwargs)
    return self
[docs]
def training_step(self, batch, batch_idx):
    """Run one training step; concrete models must override this.

    Parameters
    ----------
    batch
        The current batch (unused by the base implementation).
    batch_idx
        Index of the batch (unused by the base implementation).

    Raises
    ------
    NotImplementedError
        Always, since the base model has no training logic.
    """
    message = (
        "You cannot use the base model without implementing the training_step()"
    )
    raise NotImplementedError(message)
[docs]
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """Compute the validation loss for one batch.

    Parameters
    ----------
    batch
        Mapping holding the input under ``"image"`` and the expected
        output under ``"target"``.
    batch_idx
        Index of the batch (unused).
    dataloader_idx
        Index of the validation dataloader (unused).

    Returns
    -------
        The value returned by ``self.validation_loss`` for the model output
        and the batch targets.
    """
    predictions = self(batch["image"])
    return self.validation_loss(predictions, batch["target"])
[docs]
def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Return the sigmoid of the model output for one batch.

    Parameters
    ----------
    batch
        Mapping holding the input under ``"image"``.
    batch_idx
        Index of the batch (unused).
    dataloader_idx
        Index of the prediction dataloader (unused).

    Returns
    -------
        ``torch.sigmoid`` applied to the model output.
    """
    logits = self(batch["image"])
    return torch.sigmoid(logits)