Source code for mednet.engine.callbacks

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Lightning callbacks to log custom measurements."""

import logging
import time
import typing
import warnings

import lightning.pytorch
import lightning.pytorch.callbacks
import torch

from ..utils.resources import ResourceMonitor, aggregate

logger = logging.getLogger(__name__)


[docs] class LoggingCallback(lightning.pytorch.Callback): """Callback to log various training metrics and device information. Rationale: 1. Losses are logged at the end of every batch, accumulated and handled by the lightning framework. 2. Everything else is done at the end of a training or validation epoch and mostly concerns runtime metrics such as memory and cpu/gpu utilisation. Parameters ---------- resource_monitor A monitor that watches resource usage (CPU/GPU) in a separate process and totally asynchronously with the code execution. """ def __init__( self, resource_monitor: ResourceMonitor, ): super().__init__() # timers self._start_training_time = time.time() self._start_training_epoch_time = time.time() self._start_validation_epoch_time = time.time() # log accumulators for a single flush at each training cycle self._to_log: dict[str, float] = {} # helpers for CPU and GPU utilisation self._resource_monitor = resource_monitor self._max_queue_retries = 2
[docs] def on_train_start( self, trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, ): """Execute actions when training starts (lightning callback). This method is executed whenever you *start* training a module. Parameters ---------- trainer The Lightning trainer object. pl_module The lightning module that is being trained. """ self._start_training_time = time.time()
[docs] def on_train_end( self, trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, ): """Execute actions when training ends (lightning callback). This method is executed whenever you *end* training a module. Parameters ---------- trainer The Lightning trainer object. pl_module The lightning module that is being trained. """ pass
[docs] def on_train_epoch_start( self, trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, ) -> None: """Execute actions when a training epoch starts (lightning callback). This method is executed whenever a training epoch starts. Presumably, batches happen as often as possible. You want to make this code very fast. Do not log things to the terminal or the such, or do complicated (lengthy) calculations. .. warning:: This is executed **while** you are training. Be very succint or face the consequences of slow training! Parameters ---------- trainer The Lightning trainer object. pl_module The lightning module that is being trained. """ # summarizes resource usage since the last checkpoint # clears internal buffers and starts accumulating again. self._resource_monitor.clear() self._start_training_epoch_time = time.time()
[docs] def on_train_epoch_end( self, trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, ): """Execute actions after a training epoch ends (lightning callback). This method is executed whenever a training epoch ends. Presumably, epochs happen as often as possible. You want to make this code relatively fast to avoid significative runtime slow-downs. Parameters ---------- trainer The Lightning trainer object. pl_module The lightning module that is being trained. """ # evaluates this training epoch total time, and log it epoch_time = time.time() - self._start_training_epoch_time self._to_log["epoch-duration-seconds/train"] = epoch_time if len(pl_module.trainer.lr_scheduler_configs) == 0: self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"] # type: ignore else: self._to_log["learning-rate"] = pl_module.trainer.lr_scheduler_configs[ 0 ].scheduler.get_last_lr()[0] # type: ignore overall_cycle_time = time.time() - self._start_training_epoch_time self._to_log["cycle-time-seconds/train"] = overall_cycle_time self._to_log["total-execution-time-seconds"] = ( time.time() - self._start_training_time ) self._to_log["eta-seconds"] = overall_cycle_time * ( trainer.max_epochs - trainer.current_epoch # type: ignore ) # the "step" is the tensorboard jargon for "epoch" or "batch", # depending on how we are logging - in a more general way, it simply # means the relative time step. self._to_log["step"] = float(trainer.current_epoch) # Do not log during sanity check as results are not relevant if not trainer.sanity_checking: pl_module.log_dict(self._to_log) self._to_log = {}
[docs] def on_train_batch_end( self, trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, outputs: typing.Mapping[str, torch.Tensor], batch: typing.Mapping[str, typing.Any], batch_idx: int, ) -> None: """Execute actions after a training batch ends (lightning callback). This method is executed whenever a training batch ends. Presumably, batches happen as often as possible. You want to make this code very fast. Do not log things to the terminal or the such, or do complicated (lengthy) calculations. .. warning:: This is executed **while** you are training. Be very succint or face the consequences of slow training! Parameters ---------- trainer The Lightning trainer object. pl_module The lightning module that is being trained. outputs The outputs of the module's ``training_step``. batch The data that the training step received. batch_idx The relative number of the batch. """ batch_size = batch["image"].shape[0] pl_module.log( "loss/train", outputs["loss"].item(), prog_bar=True, on_step=False, on_epoch=True, batch_size=batch_size, )
[docs] def on_validation_epoch_start( self, trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, ) -> None: """Execute actions before a validation batch starts (lightning callback). This method is executed whenever a validation batch starts. Presumably, batches happen as often as possible. You want to make this code very fast. Do not log things to the terminal or the such, or do complicated (lengthy) calculations. .. warning:: This is executed **while** you are training. Be very succint or face the consequences of slow training! Parameters ---------- trainer The Lightning trainer object. pl_module The lightning module that is being trained. """ current_timestamp = time.time() # required because the validation epoch is started **within** the # training epoch START/END. # summarizes resource usage since the last checkpoint metrics = self._resource_monitor.checkpoint() # filter out samples not acquired during the training epoch metrics = [ k for k in metrics if k["timestamp"] > self._start_training_epoch_time and k["timestamp"] < current_timestamp ] if not metrics: warnings.warn( f"CPU/GPU monitor worker did not provide any measures. " f"This can be due to the length of the monitoring interval, " f"currently set to {self._resource_monitor.interval}s, " f"compared to the time I've been accumulating measurements " f"({(time.time() - self._start_training_epoch_time):.1f}s). " f"A way this can be fixed is by reducing the monitoring " f"interval to a suitable value, so it allows some measures " f"to be performed. Note this is only possible if the time " f"to log a single measurement point is smaller than the " f"time it takes to **train** a single epoch." ) else: for metric_name, metric_value in aggregate(metrics).items(): self._to_log[f"{metric_name}/train"] = float(metric_value) self._start_validation_epoch_time = time.time() self._resource_monitor.clear()
[docs] def on_validation_epoch_end( self, trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, ) -> None: """Execute actions after a validation batch ends (lightning callback). This method is executed whenever a validation epoch ends. Presumably, epochs happen as often as possible. You want to make this code relatively fast to avoid significative runtime slow-downs. Parameters ---------- trainer The Lightning trainer object. pl_module The lightning module that is being trained. """ current_timestamp = time.time() # summarizes resource usage since the last checkpoint # clears internal buffers and starts accumulating again. metrics = self._resource_monitor.checkpoint() # filter out samples not acquired during the validation epoch metrics = [ k for k in metrics if k["timestamp"] > self._start_training_epoch_time and k["timestamp"] < current_timestamp ] if not metrics: warnings.warn( f"CPU/GPU monitor worker did not provide any measures. " f"This can be due to the length of the monitoring interval, " f"currently set to {self._resource_monitor.interval}s, " f"compared to the time I've been accumulating measurements " f"({(time.time() - self._start_validation_epoch_time):.1f}s). " f"A way this can be fixed is by reducing the monitoring " f"interval to a suitable value, so it allows some measures " f"to be performed. Note this is only possible if the time " f"to log a single measurement point is smaller than the " f"time it takes to **validate** a single epoch." ) else: for metric_name, metric_value in aggregate(metrics).items(): self._to_log[f"{metric_name}/validation"] = float(metric_value) epoch_time = time.time() - self._start_validation_epoch_time self._to_log["epoch-duration-seconds/validation"] = epoch_time self._to_log["step"] = float(trainer.current_epoch) # Do not log during sanity check as results are not relevant if not trainer.sanity_checking: pl_module.log_dict(self._to_log) self._to_log = {}
[docs] def on_validation_batch_end( self, trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, outputs: torch.Tensor, batch: typing.Mapping[str, typing.Any], batch_idx: int, dataloader_idx: int = 0, ) -> None: """Execute actions after a validation after ends (lightning callback). This method is executed whenever a validation batch ends. Presumably, batches happen as often as possible. You want to make this code very fast. Do not log things to the terminal or the such, or do complicated (lengthy) calculations. .. warning:: This is executed **while** you are training. Be very succint or face the consequences of slow training! Parameters ---------- trainer The Lightning trainer object. pl_module The lightning module that is being trained. outputs The outputs of the module's ``training_step``. batch The data that the training step received. batch_idx The relative number of the batch. dataloader_idx Index of the dataloader used during validation. Use this to figure out which dataset was used for this validation epoch. """ if dataloader_idx == 0: key = "loss/validation" else: key = f"loss/validation-{dataloader_idx}" batch_size = batch["image"].shape[0] pl_module.log( key, outputs.item(), prog_bar=False, on_step=False, on_epoch=True, batch_size=batch_size, add_dataloader_idx=False, )