# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Prediction engine for classification tasks."""
import logging
import typing
import lightning.pytorch
import torch.utils.data
from ...engine.device import DeviceManager
from ...models.classify.typing import (
BinaryPrediction,
BinaryPredictionSplit,
MultiClassPrediction,
MultiClassPredictionSplit,
)
logger = logging.getLogger(__name__)
class _JSONMetadataCollector(lightning.pytorch.callbacks.BasePredictionWriter):
"""Collects further sample metadata to store with predictions.
This object collects further sample metadata we typically keep with
predictions.
Parameters
----------
write_interval
When will this callback be active.
"""
def __init__(
self,
write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"] = "batch",
):
super().__init__(write_interval=write_interval)
self._data: list[BinaryPrediction] | list[MultiClassPrediction] = []
def write_on_batch_end(
self,
trainer: lightning.pytorch.Trainer,
pl_module: lightning.pytorch.LightningModule,
prediction: typing.Any,
batch_indices: typing.Sequence[int] | None,
batch: typing.Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
"""Write batch predictions to disk.
Parameters
----------
trainer
The trainer being used.
pl_module
The pytorch module.
prediction
The actual predictions to record.
batch_indices
The relative position of samples on the epoch.
batch
The current batch.
batch_idx
Index of the batch overall.
dataloader_idx
Index of the dataloader overall.
"""
for k, sample_pred in enumerate(prediction):
sample_name: str = batch[1]["name"][k]
target_shape = batch[1]["target"][k].shape
self._data.append(
(
sample_name,
batch[1]["target"][k].cpu().numpy().tolist(),
sample_pred.cpu().numpy().reshape(target_shape).tolist(),
)
)
def reset(self) -> list[BinaryPrediction] | list[MultiClassPrediction]:
"""Summary of written objects.
Also resets the internal state.
Returns
-------
A list containing a summary of all samples written.
"""
retval = self._data
self._data = []
return retval
[docs]
def run(
model: lightning.pytorch.LightningModule,
datamodule: lightning.pytorch.LightningDataModule,
device_manager: DeviceManager,
) -> (
list[BinaryPrediction]
| list[MultiClassPrediction]
| list[list[BinaryPrediction]]
| list[list[MultiClassPrediction]]
| BinaryPredictionSplit
| MultiClassPredictionSplit
| None
):
"""Run inference on input data, output predictions.
Parameters
----------
model
Neural network model (e.g. pasa).
datamodule
The lightning DataModule to use for training **and** validation.
device_manager
An internal device representation, to be used for training and
validation. This representation can be converted into a pytorch device
or a lightning accelerator setup.
Returns
-------
(
list[BinaryPrediction]
| list[MultiClassPrediction]
| list[list[BinaryPrediction]]
| list[list[MultiClassPrediction]]
| BinaryPredictionSplit
| MultiClassPredictionSplit
| None
)
Depending on the return type of the DataModule's
``predict_dataloader()`` method:
* if :py:class:`torch.utils.data.DataLoader`, then returns a
:py:class:`list` of predictions.
* if :py:class:`list` of :py:class:`torch.utils.data.DataLoader`, then
returns a list of lists of predictions, each list corresponding to
the iteration over one of the dataloaders.
* if :py:class:`dict` of :py:class:`str` to
:py:class:`torch.utils.data.DataLoader`, then returns a dictionary
mapping names to lists of predictions.
* if ``None``, then returns ``None``.
Raises
------
TypeError
If the DataModule's ``predict_dataloader()`` method does not return any
of the types described above.
"""
from lightning.pytorch.loggers.logger import DummyLogger
collector = _JSONMetadataCollector()
accelerator, devices = device_manager.lightning_accelerator()
trainer = lightning.pytorch.Trainer(
accelerator=accelerator,
devices=devices,
logger=DummyLogger(),
callbacks=[collector],
)
dataloaders = datamodule.predict_dataloader()
if isinstance(dataloaders, torch.utils.data.DataLoader):
logger.info("Running prediction on a single dataloader...")
trainer.predict(model, dataloaders, return_predictions=False)
return collector.reset()
if isinstance(dataloaders, list):
retval_list = []
for k, dataloader in enumerate(dataloaders):
logger.info(f"Running prediction on split `{k}`...")
trainer.predict(model, dataloader, return_predictions=False)
retval_list.append(collector.reset())
return retval_list # type: ignore
if isinstance(dataloaders, dict):
retval_dict = {}
for name, dataloader in dataloaders.items():
logger.info(f"Running prediction on `{name}` split...")
trainer.predict(model, dataloader, return_predictions=False)
retval_dict[name] = collector.reset()
return retval_dict # type: ignore
if dataloaders is None:
logger.warning("Datamodule did not return any prediction dataloaders!")
return None
# if you get to this point, then the user is returning something that is
# not supported - complain!
raise TypeError(
f"Datamodule returned strangely typed prediction "
f"dataloaders: `{type(dataloaders)}` - Please write code "
f"to support this use-case.",
)