# Source code for mednet.data.datamodule

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Extension of ``lightning.LightningDataModule`` with dictionary split loading, mini-batching, parallelisation and caching."""

import functools
import itertools
import logging
import sys
import typing

import lightning
import loky
import torch
import torch.backends
import torch.utils.data
import torchvision.transforms
import torchvision.tv_tensors
import tqdm

from .typing import (
    ConcatDatabaseSplit,
    DatabaseSplit,
    DataLoader,
    Dataset,
    RawDataLoader,
    Sample,
    TransformSequence,
)

logger = logging.getLogger(__name__)


def _sample_size_bytes(dataset: Dataset) -> None:
    """Log an estimate of the memory footprint of the first sample of a dataset.

    The estimate adds the *shallow* size of the sample container (as reported
    by :py:func:`sys.getsizeof`) to the storage size of every
    :py:class:`torch.Tensor` found at the top level of the sample dictionary.
    Nothing is returned: the per-tensor shapes and the final estimate are only
    reported through the module logger.

    Parameters
    ----------
    dataset
        The dataset containing the samples to load.  If the dataset is empty,
        a message is logged and no estimate is produced.
    """

    def _tensor_size_bytes(t: torch.Tensor, n: str) -> int:
        """Return a tensor size in bytes.

        Parameters
        ----------
        t
            A torch Tensor.
        n
            Name of the object.

        Returns
        -------
        int
            The size of the Tensor in bytes.
        """

        logger.info(f"`{n}`: {list(t.shape)}@{t.dtype}")
        return int(t.element_size() * t.numel())

    def _dict_size_bytes(d: typing.Mapping[str, typing.Any]) -> int:
        """Return the accumulated size of all tensors in a dictionary.

        Parameters
        ----------
        d
            A dictionary.  Only values that are tensors are accounted for;
            nested containers are not inspected.

        Returns
        -------
        int
            The size of the tensors in the dictionary, in bytes.
        """

        return sum(
            _tensor_size_bytes(v, k)
            for k, v in d.items()
            if isinstance(v, torch.Tensor)
        )

    # This helper is purely informative: an empty dataset must not make the
    # construction of its owner fail with an ``IndexError``.
    if len(dataset) == 0:
        logger.info("Dataset is empty: cannot estimate sample size")
        return

    first_sample = dataset[0]
    # sys.getsizeof() is shallow: it measures the container only, not the
    # objects it references
    size = sys.getsizeof(first_sample)
    size += _dict_size_bytes(first_sample)  # adds torch tensor sizes

    sample_size_mb = size / (1024.0 * 1024.0)
    logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")


def _apply_iff_tv_tensor(
    data: typing.Any, transform: typing.Callable[[torch.Tensor], torch.Tensor]
):
    """Run ``transform`` on ``data`` only when ``data`` is a TVTensor.

    Any other kind of object is handed back untouched.

    Parameters
    ----------
    data
        Sample data that may, or may not, be subject to the transform.
    transform
        Callable containing a single or a composition of transforms to
        potentially apply to ``data``.

    Returns
    -------
        ``transform(data)`` if ``data`` is a TVTensor, ``data`` itself
        otherwise.
    """
    # guard clause: non-TVTensor objects bypass the transform
    if not isinstance(data, torchvision.tv_tensors.TVTensor):
        return data
    return transform(data)


class _DelayedLoadingDataset(Dataset):
    """A list that loads its samples on demand.

    This list mimics a pytorch Dataset, except that raw data loading is done
    on-the-fly, as the samples are requested through the bracket operator.

    Parameters
    ----------
    raw_dataset
        An iterable containing the raw dataset samples representing one of the
        database split datasets.
    loader
        An object instance that can load samples from storage.
    transforms
        A set of transforms that should be applied on-the-fly for this dataset,
        to fit the output of the raw-data-loader to the model of interest.  If
        not set (``None``), no transform is applied.
    disable_pbar
        If set, disables progress bars.
    """

    def __init__(
        self,
        raw_dataset: typing.Sequence[typing.Any],
        loader: RawDataLoader,
        transforms: TransformSequence | None = None,
        disable_pbar: bool = False,
    ):
        self.raw_dataset = raw_dataset
        self.loader = loader
        # A fresh list is created per instance when no transforms are given:
        # Compose keeps a reference to the sequence it receives, so a mutable
        # default argument would be shared by every dataset instance.
        self.transform = torchvision.transforms.Compose(
            [] if transforms is None else transforms
        )
        self.disable_pbar = disable_pbar

        # loads the first sample to log an estimate of the per-sample size
        _sample_size_bytes(self)

    def targets(self) -> list[torch.Tensor]:
        """Return the targets for all samples in the dataset.

        Targets are obtained through ``loader.target()`` for each raw sample,
        without applying the transforms.

        Returns
        -------
            The targets for all samples in the dataset.
        """

        return [
            self.loader.target(k)
            for k in tqdm.tqdm(
                self.raw_dataset, unit="sample", disable=self.disable_pbar
            )
        ]

    def __getitem__(self, key: int) -> Sample:
        # raw-data loading happens here, at access time
        sample = self.loader.sample(self.raw_dataset[key])
        return {k: _apply_iff_tv_tensor(v, self.transform) for k, v in sample.items()}

    def __len__(self):
        return len(self.raw_dataset)

    def __iter__(self):
        for index in range(len(self)):
            yield self[index]


def _apply_loader_and_transforms(
    info: typing.Any,
    load: typing.Callable[[typing.Any], Sample],
    model_transform: typing.Callable[[torch.Tensor], torch.Tensor],
) -> Sample:
    """Load one raw sample and run the model transform over it, in one go.

    Parameters
    ----------
    info
        The sample information, as loaded from its raw dataset dictionary.
    load
        The raw-data loader function to use for loading the sample.
    model_transform
        A callable (typically a composed transform) applied to each entry of
        the loaded sample that is a TVTensor; other entries are kept as they
        are.

    Returns
    -------
    Sample
        The loaded and transformed sample.
    """
    loaded = load(info)

    transformed = {}
    for name, value in loaded.items():
        transformed[name] = _apply_iff_tv_tensor(value, model_transform)
    return transformed


class CachedDataset(Dataset):
    """Basically, a list of preloaded samples.

    This dataset will load all samples from the raw dataset during construction
    instead of delaying that to the indexing.  Beyond raw-data-loading,
    ``transforms`` given upon construction contribute to the cached samples.

    Parameters
    ----------
    raw_dataset
        An iterable containing the raw dataset samples representing one of the
        database split datasets.
    loader
        An object instance that can load samples and targets from storage.
    transforms
        A set of transforms that should be applied to the cached samples for
        this dataset, to fit the output of the raw-data-loader to the model of
        interest.  If not set (``None``), no transform is applied.
    parallel
        Use multiprocessing for data loading: if set to -1 (default), disables
        multiprocessing data loading.  Set to 0 to enable as many data loading
        instances as processing cores available in the system.  Set to >= 1
        to enable that many multiprocessing instances for data loading.
    disable_pbar
        If set, disables progress bars.
    """

    def __init__(
        self,
        raw_dataset: typing.Sequence[typing.Any],
        loader: RawDataLoader,
        transforms: TransformSequence | None = None,
        parallel: int = -1,
        disable_pbar: bool = False,
    ):
        # A fresh list is created per instance when no transforms are given:
        # Compose keeps a reference to the sequence it receives, so a mutable
        # default argument would be shared by every dataset instance.
        self.loader = functools.partial(
            _apply_loader_and_transforms,
            load=loader.sample,
            model_transform=torchvision.transforms.Compose(
                [] if transforms is None else transforms
            ),
        )

        if parallel < 0:
            # sequential loading, within the current process
            self.data = [
                self.loader(k)
                for k in tqdm.tqdm(raw_dataset, unit="sample", disable=disable_pbar)
            ]
        else:
            instances = parallel or torch.multiprocessing.cpu_count()
            logger.info(f"Caching dataset using {instances} processes...")
            # loky executor replaces torch.multiprocessing.Pool, but uses
            # cloudpickle for more robust data exchange between main process
            # and workers.
            with loky.ProcessPoolExecutor(max_workers=instances) as executor:
                # map() alone schedules one task per sample and yields the
                # results in input order.  Samples must NOT be additionally
                # submitted through executor.submit(): that would load every
                # sample twice and throw one copy away.
                mapped = executor.map(self.loader, raw_dataset)
                self.data = list(
                    tqdm.tqdm(
                        mapped,
                        total=len(raw_dataset),
                        disable=disable_pbar,
                        unit="sample",
                    )
                )

        # logs an estimate of the per-sample size using the first cached sample
        _sample_size_bytes(self)

    def targets(self) -> list[torch.Tensor]:
        """Return the targets for all samples in the dataset.

        Targets are read from the ``"target"`` entry of each cached sample.

        Returns
        -------
            The targets for all samples in the dataset.
        """

        return [k["target"] for k in self.data]

    def __getitem__(self, key: int) -> Sample:
        return self.data[key]

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        yield from self.data
class ConcatDataset(Dataset):
    """Present several cached or delayed datasets as a single, longer one.

    Parameters
    ----------
    datasets
        An iterable over pre-instantiated datasets.
    """

    def __init__(self, datasets: typing.Sequence[Dataset]):
        self._datasets = datasets
        # flat lookup table: global position -> (dataset position, sample
        # position within that dataset)
        self._indices = [
            (dataset_pos, sample_pos)
            for dataset_pos, dataset in enumerate(datasets)
            for sample_pos in range(len(dataset))
        ]

    def targets(self) -> list[torch.Tensor]:
        """Return the targets for all samples in the dataset.

        Returns
        -------
            The targets of every underlying dataset, chained in order.
        """

        per_dataset = (dataset.targets() for dataset in self._datasets)
        return list(itertools.chain.from_iterable(per_dataset))

    def __getitem__(self, key: int) -> Sample:
        dataset_pos, sample_pos = self._indices[key]
        return self._datasets[dataset_pos][sample_pos]

    def __len__(self):
        return sum(len(dataset) for dataset in self._datasets)

    def __iter__(self):
        for member in self._datasets:
            yield from member
class ConcatDataModule(lightning.LightningDataModule):
    """A convenient DataModule with dictionary split loading, mini-batching,
    parallelisation and caching, all in one.

    Instances of this class can load and concatenate an arbitrary number of
    data-split (a.k.a. protocol) definitions for (possibly disjoint) databases,
    and can manage raw data-loading from disk.  An optional caching mechanism
    stores the data in associated CPU memory, which can improve data serving
    while training and evaluating models.

    This DataModule defines basic operations to handle data loading and
    mini-batch handling within this package's framework.  It can return
    :py:class:`torch.utils.data.DataLoader` objects for training, validation,
    prediction and testing conditions.  Parallelisation is handled by a simple
    input flag.

    Parameters
    ----------
    splits
        A dictionary that contains string keys representing dataset names, and
        values that are iterables over a 2-tuple containing an iterable over
        arbitrary, user-configurable sample representations (potentially on
        disk or permanent storage), and :py:class:`.data.typing.RawDataLoader`
        (or "sample") loader objects, which concretely implement a mechanism to
        load such samples in memory, from permanent storage.

        Sample representations on permanent storage may be of any iterable
        format (e.g. list, dictionary, etc.), for as long as the assigned
        :py:class:`.data.typing.RawDataLoader` can properly handle it.

        .. tip::

           To check the split and that the loader function works correctly, you
           may use :py:func:`.split.check_database_split_loading`.

        This class expects at least one entry called ``train`` to exist in the
        input dictionary.  Optional entries are ``validation``, and ``test``.
        Entries named ``monitor-...`` will be considered extra datasets that do
        not influence any early stop criteria during training, and are just
        monitored beyond the ``validation`` dataset.
    database_name
        The name of the database, or aggregated database containing the
        raw-samples served by this data module.
    split_name
        The name of the split used to group the samples into the various
        datasets for training, validation and testing.
    task
        The task this datamodule generate samples for (e.g. ``classification``,
        ``segmentation``, or ``detection``).
    num_classes
        The number of target classes samples of this datamodule can have. In a
        classification task, this will dictate the number of outputs for the
        classifier (one-hot-encoded), the number of segmentation outputs for a
        semantic segmentation network, or the types of objects in an object
        detector.
    collate_fn
        A custom function to batch the samples. Uses
        torch.utils.data.default_collate() by default.
    cache_samples
        If set, then issue raw data loading during ``prepare_data()``, and
        serves samples from CPU memory.  Otherwise, loads samples from disk on
        demand. Running from CPU memory will offer increased speeds in exchange
        for CPU memory.  Sufficient CPU memory must be available before you set
        this attribute to ``True``.  It is typically useful for relatively
        small datasets.
    batch_size
        Number of samples in every **training** batch (this parameter affects
        memory requirements for the network).  If the number of samples in the
        batch is larger than the total number of samples available for
        training, this value is truncated.  If this number is smaller, then
        batches of the specified size are created and fed to the network until
        there are no more new samples to feed (epoch is finished).  If the
        total number of training samples is not a multiple of the batch-size,
        the last batch will be smaller than the first, unless
        ``drop_incomplete_batch`` is set to ``true``, in which case this batch
        is not used.
    drop_incomplete_batch
        If set, then may drop the last batch in an epoch in case it is
        incomplete.  If you set this option, you should also consider
        increasing the total number of training epochs, as the total number of
        training steps may be reduced.
    parallel
        Use multiprocessing for data loading: if set to -1 (default), disables
        multiprocessing data loading.  Set to 0 to enable as many data loading
        instances as processing cores available in the system.  Set to >= 1
        to enable that many multiprocessing instances for data loading.
    """

    DatasetDictionary: typing.TypeAlias = dict[str, Dataset]
    """A dictionary of datasets mapping names to actual datasets."""

    def __init__(
        self,
        splits: ConcatDatabaseSplit,
        database_name: str = "",
        split_name: str = "",
        task: str = "",
        num_classes: int = 1,
        collate_fn=torch.utils.data.default_collate,
        cache_samples: bool = False,
        batch_size: int = 1,
        drop_incomplete_batch: bool = False,
        parallel: int = -1,
    ):
        super().__init__()

        self.splits = splits
        self.database_name = database_name
        self.split_name = split_name
        self.task = task
        self.num_classes = num_classes
        self.collate_fn = collate_fn

        # report how many raw samples each dataset of the split aggregates
        for dataset_name, split_loaders in splits.items():
            count = sum([len(k) for k, _ in split_loaders])
            logger.info(
                f"Dataset `{dataset_name}` (`{database_name}`/`{split_name}`) "
                f"contains {count} samples",
            )

        self.cache_samples = cache_samples
        self._model_transforms: TransformSequence | None = None
        self.batch_size = batch_size

        self.drop_incomplete_batch = drop_incomplete_batch
        # goes through the ``parallel`` property setter, which also derives
        # the multiprocessing keyword-arguments passed to every DataLoader
        self.parallel = parallel

        self.pin_memory = (
            torch.cuda.is_available() or torch.backends.mps.is_available()  # type: ignore
        )  # should only be true if GPU available and using it

        # datasets that have been setup() for the current stage
        self._datasets: ConcatDataModule.DatasetDictionary = {}

    @property
    def parallel(self) -> int:
        """Whether to use multiprocessing for data loading.

        Use multiprocessing for data loading: if set to -1 (default), disables
        multiprocessing data loading.  Set to 0 to enable as many data loading
        instances as processing cores available in the system.  Set to >= 1
        to enable that many multiprocessing instances for data loading.

        It sets the parameter ``num_workers`` (from DataLoaders) to match the
        expected pytorch representation.  It also sets the
        ``multiprocessing_context`` to use ``spawn`` instead of the default
        (``fork``, on Linux).

        The mapping between the command-line interface ``parallel`` setting
        works like this:

        .. list-table:: Relationship between ``parallel`` and DataLoader parameters
           :widths: 15 15 70
           :header-rows: 1

           * - CLI ``parallel``
             - :py:class:`torch.utils.data.DataLoader` ``kwargs``
             - Comments
           * - ``<0``
             - 0
             - Disables multiprocessing entirely, executes everything within
               the same processing context
           * - ``0``
             - :py:func:`multiprocessing.cpu_count`
             - Runs mini-batch data loading on as many external processes as
               CPUs available in the current machine
           * - ``>=1``
             - ``parallel``
             - Runs mini-batch data loading on as many external processes as
               set on ``parallel``

        Returns
        -------
        int
            The value of self._parallel.
        """
        return self._parallel

    @parallel.setter
    def parallel(self, value: int) -> None:
        # keyword-arguments forwarded to every DataLoader built by this object
        self._dataloader_multiproc: dict[str, typing.Any] = {}
        self._parallel = value

        if value < 0:
            num_workers = 0
        else:
            num_workers = value or torch.multiprocessing.cpu_count()

        self._dataloader_multiproc["num_workers"] = num_workers

        if num_workers > 0:
            self._dataloader_multiproc["multiprocessing_context"] = "spawn"

        # keep workers hanging around if we have multiple
        if value >= 0:
            self._dataloader_multiproc["persistent_workers"] = True

    @property
    def model_transforms(self) -> TransformSequence | None:
        """Transform required to fit data into the model.

        A list of transforms (torch modules) that will be applied after
        raw-data-loading, and just before data is fed into the model or
        eventual data-augmentation transformations for all data loaders
        produced by this DataModule.  This part of the pipeline receives data
        as output by the raw-data-loader, or model-related transforms (e.g.
        resize adaptions), if any is specified.  If data is cached, it is
        cached **after** model-transforms are applied, as that is a potential
        memory saver (e.g., if it contains a resizing operation to smaller
        images).

        Returns
        -------
        list
            A list containing the model transforms.
        """
        return self._model_transforms

    @model_transforms.setter
    def model_transforms(self, value: TransformSequence | None):
        old_value = self._model_transforms

        if value is None:
            self._model_transforms = value
        else:
            self._model_transforms = list(value)

        # datasets that have been setup() for the current stage are reset.
        # The comparison uses the *stored* (list-normalised) value: comparing
        # the raw input would flag e.g. a tuple holding the very same
        # transforms as a change, and needlessly drop (and reload) datasets.
        if self._model_transforms != old_value and len(self._datasets):
            logger.warning(
                f"Resetting {len(self._datasets)} loaded datasets due "
                "to changes in model-transform properties.  If you were caching "
                "data loading, this will (eventually) trigger a reload.",
            )
            self._datasets = {}

    def _setup_dataset(self, name: str) -> None:
        """Set up a single dataset from the input data split.

        Parameters
        ----------
        name
            Name of the dataset to setup.

        Raises
        ------
        RuntimeError
            If ``model_transforms`` has not been set before calling this
            method.
        """

        if self.model_transforms is None:
            raise RuntimeError(
                "Parameter `model_transforms` has not yet been "
                "set. If you do not have model transforms, then "
                "set it to an empty list.",
            )

        if name in self._datasets:
            logger.info(
                f"Dataset `{name}` is already setup.  Not re-instantiating it.",
            )
            return

        datasets: list[CachedDataset | _DelayedLoadingDataset] = []
        if self.cache_samples:
            logger.info(
                f"Loading dataset:`{name}` into memory (caching)."
                f" Trade-off: CPU RAM usage: more | Disk I/O: less",
            )
            for split, loader in self.splits[name]:
                datasets.append(
                    CachedDataset(split, loader, self.model_transforms, self.parallel)
                )
        else:
            logger.info(
                f"Loading dataset:`{name}` without caching."
                f" Trade-off: CPU RAM usage: less | Disk I/O: more",
            )
            for split, loader in self.splits[name]:
                datasets.append(
                    _DelayedLoadingDataset(split, loader, self.model_transforms)
                )

        # avoid the indirection of a concatenation for a single dataset
        if len(datasets) == 1:
            self._datasets[name] = datasets[0]
        else:
            self._datasets[name] = ConcatDataset(datasets)

    def val_dataset_keys(self) -> list[str]:
        """Return list of validation dataset names.

        If the split has no ``validation`` entry, the ``train`` entry takes
        its place (a warning is logged).  Entries named ``monitor-...`` are
        always appended.

        Returns
        -------
        list[str]
            The list of validation dataset names.
        """

        validation_split_name = "validation"
        if "validation" not in self.splits:
            logger.warning(
                "No split named 'validation', the training split will be used for validation instead."
            )
            validation_split_name = "train"

        return [validation_split_name] + [
            k for k in self.splits if k.startswith("monitor-")
        ]

    def setup(self, stage: str) -> None:
        """Set up datasets for different tasks on the pipeline.

        This method should setup (load, pre-process, etc) all datasets required
        for a particular ``stage`` (fit, validate, test, predict), and keep
        them ready to be used on one of the `_dataloader()` functions that are
        pertinent for such stage.

        If you have set ``cache_samples``, samples are loaded at this stage and
        cached in memory.

        Parameters
        ----------
        stage
            Name of the stage in which the setup is applicable.  Can be one of
            ``fit``, ``validate``, ``test`` or ``predict``.  Each stage
            sets up the following datasets:

            * ``fit``: the train and validation (incl. monitor) datasets
            * ``validate``: only the validation (incl. monitor) datasets
            * ``test``: only the test dataset
            * ``predict``: all datasets available in the split

            Any other value is silently ignored.
        """

        if stage == "fit":
            for k in ["train"] + self.val_dataset_keys():
                self._setup_dataset(k)

        elif stage == "validate":
            for k in self.val_dataset_keys():
                self._setup_dataset(k)

        elif stage == "test":
            self._setup_dataset("test")

        elif stage == "predict":
            for k in self.splits:
                self._setup_dataset(k)

    def teardown(self, stage: str) -> None:
        """Unset-up datasets for different tasks on the pipeline.

        This is the counterpart of :py:meth:`setup`.  The current
        implementation only delegates to the parent class: datasets already
        set up (and cached samples, if any) are kept by this object.

        Parameters
        ----------
        stage
            Name of the stage in which the teardown is applicable.  Can be one
            of ``fit``, ``validate``, ``test`` or ``predict``.
        """

        super().teardown(stage)

    def train_dataloader(self) -> DataLoader:
        """Return the train data loader.

        Returns
        -------
            The train data loader(s).
        """

        return torch.utils.data.DataLoader(
            self._datasets["train"],
            shuffle=True,
            batch_size=self.batch_size,
            drop_last=self.drop_incomplete_batch,
            pin_memory=self.pin_memory,
            collate_fn=self.collate_fn,
            **self._dataloader_multiproc,
        )

    def unshuffled_train_dataloader(self) -> DataLoader:
        """Return the train data loader without shuffling.

        Incomplete batches are never dropped by this loader.

        Returns
        -------
            The train data loader without shuffling.
        """

        return torch.utils.data.DataLoader(
            self._datasets["train"],
            shuffle=False,
            batch_size=self.batch_size,
            drop_last=False,
            collate_fn=self.collate_fn,
            **self._dataloader_multiproc,
        )

    def val_dataloader(self) -> dict[str, DataLoader]:
        """Return the validation data loader(s).

        Returns
        -------
            The validation data loader(s), one per name returned by
            :py:meth:`val_dataset_keys`.
        """

        validation_loader_opts = {
            "batch_size": self.batch_size,
            "shuffle": False,
            "drop_last": self.drop_incomplete_batch,
            "pin_memory": self.pin_memory,
        }
        validation_loader_opts.update(self._dataloader_multiproc)

        return {
            k: torch.utils.data.DataLoader(
                self._datasets[k],
                collate_fn=self.collate_fn,
                **validation_loader_opts,
            )
            for k in self.val_dataset_keys()
        }

    def test_dataloader(self) -> dict[str, DataLoader]:
        """Return the test data loader(s).

        Returns
        -------
            The test data loader(s), in a dictionary with a single ``test``
            entry.
        """

        return dict(
            test=torch.utils.data.DataLoader(
                self._datasets["test"],
                batch_size=self.batch_size,
                shuffle=False,
                drop_last=self.drop_incomplete_batch,
                pin_memory=self.pin_memory,
                collate_fn=self.collate_fn,
                **self._dataloader_multiproc,
            ),
        )

    def predict_dataloader(self) -> dict[str, DataLoader]:
        """Return the prediction data loader(s).

        Returns
        -------
            The prediction data loader(s), one per dataset currently set up.
        """

        return {
            k: torch.utils.data.DataLoader(
                self._datasets[k],
                batch_size=self.batch_size,
                shuffle=False,
                drop_last=self.drop_incomplete_batch,
                pin_memory=self.pin_memory,
                collate_fn=self.collate_fn,
                **self._dataloader_multiproc,
            )
            for k in self._datasets
        }
class CachingDataModule(ConcatDataModule):
    """A simplified version of our DataModule for a single split.

    Apart from construction, the behaviour of this DataModule is very similar
    to its simpler counterpart, serving training, validation and test sets.

    Parameters
    ----------
    database_split
        A dictionary that contains string keys representing dataset names, and
        values that are iterables over sample representations (potentially on
        disk).  These objects are passed to an unique
        :py:class:`.data.typing.RawDataLoader` for loading the
        :py:data:`.typing.Sample` data (and metadata) in memory.  It therefore
        assumes the whole split is homogeneous and can be loaded in the same
        way.

        .. tip::

           To check the split and the loader function works correctly, you may
           use :py:func:`.split.check_database_split_loading`.

        This class expects at least one entry called ``train`` to exist in the
        input dictionary.  Optional entries are ``validation``, and ``test``.
        Entries named ``monitor-...`` will be considered extra datasets that do
        not influence any early stop criteria during training, and are just
        monitored beyond the ``validation`` dataset.
    raw_data_loader
        An object instance that can load samples from storage.
    **kwargs
        List of named parameters matching those of
        :py:class:`ConcatDataModule`, other than ``splits``.
    """

    def __init__(
        self,
        database_split: DatabaseSplit,
        raw_data_loader: RawDataLoader,
        **kwargs,
    ):
        # every dataset becomes a one-element list of (samples, loader) pairs,
        # which is the layout the parent class works with
        wrapped = {}
        for name, samples in database_split.items():
            wrapped[name] = [(samples, raw_data_loader)]

        super().__init__(splits=wrapped, **kwargs)