# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Functions supporting training of pytorch models."""
import logging
import os
import pathlib
import lightning.pytorch
import lightning.pytorch.callbacks
import lightning.pytorch.loggers
import torch
import torch.utils.data
from ..data.datamodule import ConcatDataModule
from ..data.typing import Dataset
from ..models.model import Model
from ..utils.checkpointer import CHECKPOINT_ALIASES
from ..utils.resources import ResourceMonitor
from .callbacks import LoggingCallback
from .device import DeviceManager
logger = logging.getLogger(__name__)
[docs]
def get_checkpoint_file(results_dir: pathlib.Path) -> pathlib.Path | None:
    """Look up the checkpoint from which training should be resumed.

    Parameters
    ----------
    results_dir
        Directory in which results are saved.

    Returns
    -------
        Path to the checkpoint to resume from, or ``None`` when
        ``results_dir`` is not an existing directory, or when the checkpoint
        lookup raises ``FileNotFoundError``.
    """
    from ..utils.checkpointer import get_checkpoint_to_resume_training

    # Without a results directory there is nothing to resume from.
    if not results_dir.is_dir():
        return None

    try:
        return get_checkpoint_to_resume_training(results_dir)
    except FileNotFoundError:
        # The directory exists but the lookup found nothing usable: this is
        # not an error, training simply starts from the beginning.
        logger.info(
            f"Folder {results_dir} already exists, but I did not"
            f" find any usable checkpoint file to resume training"
            f" from. Starting from scratch...",
        )
        return None
[docs]
def load_checkpoint(
    checkpoint_file: pathlib.Path | None,
    datamodule: ConcatDataModule,
    model: torch.nn.Module,
):
    """Set up the model normalizer, or report the epoch being resumed.

    When there is no checkpoint to resume from, or when the model does not
    implement ``on_load_checkpoint``, the normalizer of framework models is
    set from the unshuffled training dataloader (only if it is not set yet).
    Otherwise, the checkpoint is read solely to log the epoch from which
    training resumes.

    Parameters
    ----------
    checkpoint_file
        Path to the checkpoint, or ``None`` if starting from scratch.
    datamodule
        Instance of a Datamodule, used to set the model's normalizer.
    model
        The model corresponding to the checkpoint.
    """
    if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
        # Sets the model normalizer with the unaugmented-train-subset if we are
        # starting from scratch and/or the model does not contain its own
        # checkpoint loading strategy (e.g. a pytorch stock checkpoint). This
        # call may be a NOOP, if the model comes from outside this framework,
        # and expects different weights for the normalisation layer.
        if isinstance(model, Model):
            if model.normalizer is None:
                model.set_normalizer(datamodule.unshuffled_train_dataloader())
        else:
            # This branch is only reached by models that are NOT framework
            # models, so a ``name`` attribute cannot be assumed to exist on
            # them: fall back to the class name instead of raising
            # AttributeError while trying to emit a mere warning.
            model_name = getattr(model, "name", type(model).__name__)
            logger.warning(
                f"Model {model_name} has no `set_normalizer` method. "
                "Skipping normalization setup (unsupported external model).",
            )
    else:
        # Normalizer will be loaded during model.on_load_checkpoint
        #
        # The checkpoint is opened here only to read the epoch counter, so its
        # tensors are mapped to the CPU: this works regardless of the device
        # the checkpoint was saved from, and avoids allocating accelerator
        # memory for data that is immediately discarded.
        checkpoint = torch.load(checkpoint_file, map_location="cpu")
        start_epoch = checkpoint["epoch"]
        logger.info(
            f"Resuming from epoch {start_epoch} "
            f"(checkpoint file: `{str(checkpoint_file)}`)...",
        )
[docs]
def setup_datamodule(
    datamodule: ConcatDataModule,
    model: Model,
    batch_size: int,
    drop_incomplete_batch: bool,
    cache_samples: bool,
    parallel: int,
) -> None:  # numpydoc ignore=PR01
    """Copy the run settings onto the datamodule and prepare it for fitting.

    Each setting is stored on the datamodule attribute of the same name, the
    model's transforms are handed over to the datamodule, and then the
    datamodule is prepared and set up for the ``fit`` stage.
    """
    # Applied in this exact order; dictionaries preserve insertion order.
    settings = {
        "batch_size": batch_size,
        "drop_incomplete_batch": drop_incomplete_batch,
        "cache_samples": cache_samples,
        "parallel": parallel,
    }
    for attribute, value in settings.items():
        setattr(datamodule, attribute, value)

    # The transforms are read from the model only after the settings above
    # have been applied.
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="fit")
[docs]
def reset_num_outputs_on_model(
    dataloader: torch.utils.data.DataLoader,
    model: Model,
) -> None:
    """Align the number of outputs of a model with the targets in the data.

    Parameters
    ----------
    dataloader
        The dataloader from which one item is taken to verify the number of
        targets, so that the function correctly ponders the number of outputs
        required by the model.
    model
        The model that will have its number of outputs potentially modified.
    """
    dataset = dataloader.dataset

    # Index the dataset directly when it is a framework dataset; otherwise
    # take the first item produced by iterating over the dataloader.
    first_item = (
        dataset[0] if isinstance(dataset, Dataset) else next(iter(dataloader))
    )

    # NOTE(review): the first axis of the target is taken as the number of
    # classes. Presumably this holds for single samples; an item coming out of
    # the dataloader may be batched, in which case axis 0 would be the batch
    # size -- TODO confirm.
    num_classes = first_item["target"].shape[0]

    # Leave the model untouched when it already matches the data.
    if num_classes == model.num_classes:
        return

    logger.warning(
        f"Resetting number-of-output-classes at `{model.name}` model from "
        f"{model.num_classes} to {num_classes}."
    )
    model.num_classes = num_classes
[docs]
def validate_model_datamodule(model: Model, datamodule: ConcatDataModule):
    """Check that a model and a datamodule can be used together.

    Parameters
    ----------
    model
        The model to be validated.
    datamodule
        The datamodule to be validated.

    Raises
    ------
    TypeError
        In case the type of the model is not compatible with the task of the
        datamodule, or if the type of the model is not recognized.
    """
    from ..models.classify.model import Model as ClassificationModel
    from ..models.segment.model import Model as SegmentationModel

    # Classification models are tested first, then segmentation models; the
    # first matching type decides the outcome.
    if isinstance(model, ClassificationModel):
        if datamodule.task != "classification":
            raise TypeError(
                f"Classification model `{model.name}` is incompatible with "
                f"`{datamodule.task}` task from datamodule "
                f"`{datamodule.database_name}`."
            )
        return

    if isinstance(model, SegmentationModel):
        if datamodule.task != "segmentation":
            raise TypeError(
                f"Segmentation model `{model.name}` is incompatible with "
                f"`{datamodule.task}` task from datamodule "
                f"`{datamodule.database_name}`."
            )
        return

    # Neither a classification nor a segmentation model.
    raise TypeError(f"Do not know how to handle model of type `{type(model)}`")
[docs]
def run(
    model: lightning.pytorch.LightningModule,
    datamodule: lightning.pytorch.LightningDataModule,
    validation_period: int,
    device_manager: DeviceManager,
    max_epochs: int,
    output_folder: pathlib.Path,
    monitoring_interval: int | float,
    accumulate_grad_batches: int,
    checkpoint: pathlib.Path | None,
):
    """Train a model with supervised learning, saving checkpoints to disk.

    Checkpoints are written periodically to ``output_folder``, together with a
    tensorboard-formatted log tracking the evolution of some figures during
    training.

    Parameters
    ----------
    model
        Neural network model (e.g. pasa).
    datamodule
        The lightning DataModule to use for training **and** validation.
    validation_period
        Number of epochs after which validation happens. With a period of 1,
        validation runs after every training epoch; larger values make
        validation more sparse. Notice that this affects checkpoint saving:
        while the latest checkpoint is overridden independently of validation
        runs, the evaluation of the 'best' model obtained so far depends on
        them, and is therefore influenced by this setting.
    device_manager
        An internal device representation, to be used for training and
        validation. This representation can be converted into a pytorch device
        or a lightning accelerator setup.
    max_epochs
        The maximum number of epochs to train for.
    output_folder
        Folder in which the results will be saved. It is created (with its
        parents) if it does not exist yet.
    monitoring_interval
        Interval, in seconds (or fractions), through which we should monitor
        resources during training.
    accumulate_grad_batches
        Number of batches over which gradients are accumulated before the
        optimizer steps. A value of 1 forces the whole batch to be processed
        at once. Larger values are especially interesting when one is training
        on GPUs with a limited amount of onboard RAM.
    checkpoint
        Path to an optional checkpoint file to load.
    """
    output_folder.mkdir(parents=True, exist_ok=True)

    from .loggers import CustomTensorboardLogger

    log_dir = "logs"
    tb_logger = CustomTensorboardLogger(output_folder, log_dir)
    logger.info(
        f"Monitor training with `tensorboard serve "
        f"--logdir={output_folder}/{log_dir}/`. "
        f"Then, open a browser on the printed address.",
    )

    monitor = ResourceMonitor(
        interval=monitoring_interval,
        device_type=device_manager.device_type,
        main_pid=os.getpid(),
    )

    # Checkpointing callback tracking the lowest "loss/validation" observed,
    # stored under the "best" alias. It also keeps the last trained model
    # (``save_last``), whose file name is overridden below with the "periodic"
    # alias. The "monitor" is checked every ``validation_period`` epochs, and
    # aliased checkpoints are not versioned.
    checkpointer = lightning.pytorch.callbacks.ModelCheckpoint(
        dirpath=output_folder,
        filename=CHECKPOINT_ALIASES["best"],
        save_last=True,
        monitor="loss/validation",
        mode="min",
        save_on_train_epoch_end=True,
        every_n_epochs=validation_period,
        enable_version_counter=False,
    )
    checkpointer.CHECKPOINT_NAME_LAST = CHECKPOINT_ALIASES["periodic"]  # type: ignore

    # Running inside the monitor's context guarantees that the monitoring
    # thread is started and stopped properly, irrespectively of exceptions
    # thrown while training.
    with monitor:
        accelerator, devices = device_manager.lightning_accelerator()

        # Logging period: the number of batches in the training dataloader.
        steps_per_epoch = len(datamodule.train_dataloader())
        trainer_callbacks = [LoggingCallback(monitor), checkpointer]

        trainer = lightning.pytorch.Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=max_epochs,
            accumulate_grad_batches=accumulate_grad_batches,
            logger=tb_logger,
            check_val_every_n_epoch=validation_period,
            log_every_n_steps=steps_per_epoch,
            callbacks=trainer_callbacks,
        )

        # The checkpoint path is handed over as a string (or None if absent).
        resume_from = None if checkpoint is None else str(checkpoint)
        trainer.fit(model, datamodule, ckpt_path=resume_from)