# Source code for mednet.data.typing

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines most common types used in code."""

import abc
import collections.abc
import typing

import torch
import torch.utils.data

# A sample is annotated as a string-keyed ``Mapping``; values are left as
# ``typing.Any`` so that arbitrary data and metadata can be attached.
Sample: typing.TypeAlias = typing.Mapping[str, typing.Any]
"""Definition of a sample.

A dictionary containing an arbitrary number of keys and values.  Some of the
keys are reserved, others ignored within the framework, and can be re-used to
hold sample metadata required for further analysis.

Reserved keys:

* ``input``: This is typically a 1, 2 or 3D torch float tensor containing the
  input data to be analysed.
* ``target``: This is typically a torch float tensor containing the target the
  network must try to achieve. In the case of classification, it can be a 1D
  tensor containing a single entry (binary classification) or multiple entries
  (multi-class classification).  In the case of semantic segmentation, this
  entry typically contains a float representation of the target mask the
  network must decode from the ``input`` data.
* ``mask``: A torch float tensor containing a mask over which the input (and
  the output) may be ignored.  Typically used in semantic segmentation tasks.
* ``name``: A name for the sample.  Typically set to the name of the file or
  file-stem holding the ``input`` data.
"""


class RawDataLoader(abc.ABC):
    """Abstract interface for objects that load samples from storage.

    Both :py:meth:`sample` and :py:meth:`target` are abstract, so concrete
    loaders have to implement the two of them.
    """

    @abc.abstractmethod
    def sample(self, sample: typing.Any) -> Sample:
        """Load a whole sample from media.

        Parameters
        ----------
        sample
            Information about the sample to load.  Its exact form is
            implementation dependent.

        Returns
        -------
            The instantiated sample, a dictionary whose keys name the
            sample's data and metadata.
        """
        ...

    @abc.abstractmethod
    def target(self, sample: typing.Any) -> torch.Tensor:
        """Load only the target of a sample from its raw representation.

        Parameters
        ----------
        sample
            A tuple containing the path suffix, within the dataset root
            folder, where to find the image to be loaded, and an integer,
            representing the sample target.

        Returns
        -------
            The label corresponding to the specified sample, encapsulated
            as a torch float tensor.
        """
        ...
Transform: typing.TypeAlias = typing.Callable[[torch.Tensor], torch.Tensor]
"""A callable that receives a tensor and returns a (possibly different) tensor.

Typically used as a building block of pytorch data-processing pipelines.
"""

TransformSequence: typing.TypeAlias = typing.Sequence[Transform]
"""A sequence of :py:data:`.Transform` callables."""

DatabaseSplit: typing.TypeAlias = collections.abc.Mapping[
    str,
    typing.Sequence[typing.Any],
]
"""The definition of a database split.

A database split maps dataset (subset) names to sequences of objects that,
through a :py:class:`.RawDataLoader`, eventually become :py:data:`.Sample`
objects in the processing pipeline.
"""

ConcatDatabaseSplit: typing.TypeAlias = collections.abc.Mapping[
    str,
    typing.Sequence[tuple[typing.Sequence[typing.Any], RawDataLoader]],
]
"""The definition of a complex database split composed of several other splits.

Like a :py:data:`.DatabaseSplit`, it maps dataset (subset) names to objects
that, through a :py:class:`.RawDataLoader`, eventually become
:py:data:`.Sample` objects in the processing pipeline.

Here, each subset is a sequence of ``(objects, loader)`` pairs, which allows
the construction of complex splits composed of cannibalized parts of other
splits.  Each part may be assigned a different :py:class:`.RawDataLoader`.
"""
class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
    """Pytorch ``Dataset`` specialisation that iterates over Sample objects.

    On top of the pytorch dataset base class, this type also derives from
    ``typing.Iterable`` and ``typing.Sized``: our datasets always provide a
    dunder len method.
    """

    def targets(self) -> list[torch.Tensor]:
        """Return the integer targets for all samples in the dataset.

        Returns
        -------
            A list of tensors, as declared by the return annotation.

        Raises
        ------
        NotImplementedError
            Always raised by this base implementation; subclasses are
            expected to override the method.
        """
        message = "You must implement the `targets()` method"
        raise NotImplementedError(message)
DataLoader: typing.TypeAlias = torch.utils.data.DataLoader[Sample]
"""Our own specialisation of the pytorch ``DataLoader`` type.

Loaders of this type iterate over Sample objects.
"""