mednet.data.datamodule

Extension of lightning.LightningDataModule with dictionary split loading, mini-batching, parallelisation and caching.

Classes

CachedDataset(raw_dataset, loader[, ...])

Basically, a list of preloaded samples.

CachingDataModule(database_split, ...)

A simplified version of our DataModule for a single split.

ConcatDataModule(splits[, database_name, ...])

A convenient DataModule with dictionary split loading, mini-batching, parallelisation and caching, all in one.

ConcatDataset(datasets)

A dataset that represents a concatenation of other cached or delayed datasets.

class mednet.data.datamodule.CachedDataset(raw_dataset, loader, transforms=[], parallel=-1, multiprocessing_context=None, disable_pbar=False)[source]

Bases: Dataset

Basically, a list of preloaded samples.

This dataset loads all samples from the raw dataset at construction time, instead of delaying loading until indexing. Beyond raw data loading, the transforms given upon construction are also applied before samples are cached.

Parameters:
  • raw_dataset (Sequence[Any]) – An iterable containing the raw dataset samples representing one of the database split datasets.

  • loader (RawDataLoader) – An object instance that can load samples and targets from storage.

  • transforms (Sequence[Callable[[Tensor], Tensor]]) – A set of transforms that should be applied to the cached samples for this dataset, to fit the output of the raw-data-loader to the model of interest.

  • parallel (int) – Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data loading instances as processing cores available in the system. Set to >= 1 to enable that many multiprocessing instances for data loading.

  • multiprocessing_context (str | None) – Which implementation of the multiprocessing context to use. Options are defined at multiprocessing. If set to None, use the default for the current platform.

  • disable_pbar (bool) – If set, disables progress bars.

targets()[source]

Return the targets for all samples in the dataset.

Return type:

list[Tensor]

Returns:

The targets for all samples in the dataset.
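
A minimal construction sketch is given below. MyRawDataLoader and raw_split are hypothetical placeholders (a data.typing.RawDataLoader subclass and one entry of a database split, respectively); the resize transform is purely illustrative:

    # Sketch only: MyRawDataLoader and raw_split are hypothetical placeholders.
    import torchvision.transforms

    from mednet.data.datamodule import CachedDataset

    dataset = CachedDataset(
        raw_dataset=raw_split,            # e.g. the "train" entry of a split
        loader=MyRawDataLoader(),         # reads samples/targets from storage
        transforms=[torchvision.transforms.Resize((224, 224))],
        parallel=0,                       # pre-load using all available cores
    )
    targets = dataset.targets()           # list of per-sample target tensors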

class mednet.data.datamodule.ConcatDataset(datasets)[source]

Bases: Dataset

A dataset that represents a concatenation of other cached or delayed datasets.

Parameters:

datasets (Sequence[Dataset]) – An iterable over pre-instantiated datasets.

targets()[source]

Return the targets for all samples in the dataset.

Return type:

list[Tensor]

Returns:

The targets for all samples in the dataset.
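
A short sketch of how such a concatenation could be assembled; dataset_a and dataset_b are placeholders for previously built (cached or delayed) datasets:

    from mednet.data.datamodule import ConcatDataset

    # dataset_a and dataset_b are placeholders for pre-instantiated datasets.
    combined = ConcatDataset([dataset_a, dataset_b])
    all_targets = combined.targets()   # targets for all samples across both datasets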

class mednet.data.datamodule.ConcatDataModule(splits, database_name='', split_name='', task='', num_classes=1, collate_fn=<function default_collate>, cache_samples=False, batch_size=1, drop_incomplete_batch=False, parallel=-1)[source]

Bases: LightningDataModule

A convenient DataModule with dictionary split loading, mini-batching, parallelisation and caching, all in one.

Instances of this class can load and concatenate an arbitrary number of data-split (a.k.a. protocol) definitions for (possibly disjoint) databases, and can manage raw data-loading from disk. An optional caching mechanism stores the data in associated CPU memory, which can improve data serving while training and evaluating models.

This DataModule defines basic operations to handle data loading and mini-batch handling within this package’s framework. It can return torch.utils.data.DataLoader objects for training, validation, prediction and testing conditions. Parallelisation is handled by a simple input flag.

Parameters:
  • splits (Mapping[str, Sequence[tuple[Sequence[Any], RawDataLoader]]]) –

    A dictionary that contains string keys representing dataset names, and values that are iterables of 2-tuples, each containing an iterable over arbitrary, user-configurable sample representations (potentially on disk or permanent storage) and a data.typing.RawDataLoader (or “sample” loader) object, which concretely implements a mechanism to load such samples in memory, from permanent storage.

    Sample representations on permanent storage may be of any iterable format (e.g. list, dictionary, etc.), as long as the assigned data.typing.RawDataLoader can properly handle it.

    Tip

    To check that the split and the loader function work correctly, you may use split.check_database_split_loading().

    This class expects at least one entry called train to exist in the input dictionary. Optional entries are validation and test. Entries named monitor-... will be considered extra datasets that do not influence any early stop criteria during training, and are just monitored beyond the validation dataset.

  • database_name (str) – The name of the database, or aggregated database containing the raw-samples served by this data module.

  • split_name (str) – The name of the split used to group the samples into the various datasets for training, validation and testing.

  • task (str) – The task this datamodule generates samples for (e.g. classification, segmentation, or detection).

  • num_classes (int) – The number of target classes that samples of this datamodule can have. In a classification task, this dictates the number of outputs for the classifier (one-hot-encoded); in semantic segmentation, the number of segmentation outputs of the network; in object detection, the number of object types the detector can handle.

  • collate_fn – A custom function to batch the samples. Uses torch.utils.data.default_collate() by default.

  • cache_samples (bool) – If set, raw data loading is performed during prepare_data() and samples are served from CPU memory. Otherwise, samples are loaded from disk on demand. Serving from CPU memory offers increased speed in exchange for CPU memory usage; sufficient CPU memory must be available before you set this attribute to True. It is typically useful for relatively small datasets.

  • batch_size (int) – Number of samples in every training batch (this parameter affects memory requirements for the network). If the batch size is larger than the total number of samples available for training, it is truncated. Otherwise, batches of the specified size are created and fed to the network until there are no more new samples to feed (the epoch is finished). If the total number of training samples is not a multiple of the batch size, the last batch will be smaller than the others, unless drop_incomplete_batch is set to True, in which case it is not used.

  • drop_incomplete_batch (bool) – If set, then may drop the last batch in an epoch in case it is incomplete. If you set this option, you should also consider increasing the total number of training epochs, as the total number of training steps may be reduced.

  • parallel (int) – Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data loading instances as processing cores available in the system. Set to >= 1 to enable that many multiprocessing instances for data loading.
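
A hedged construction sketch illustrating the expected splits layout; MyRawDataLoader and the *_samples lists are placeholders for database-specific objects, and the remaining parameter values are arbitrary examples:

    from mednet.data.datamodule import ConcatDataModule

    # Placeholders: train_samples/val_samples/test_samples are arbitrary sample
    # representations on permanent storage; MyRawDataLoader is a
    # data.typing.RawDataLoader subclass that knows how to load them.
    splits = {
        "train": [(train_samples, MyRawDataLoader())],
        "validation": [(val_samples, MyRawDataLoader())],
        "test": [(test_samples, MyRawDataLoader())],
    }

    datamodule = ConcatDataModule(
        splits,
        database_name="my-database",
        split_name="default",
        task="classification",
        num_classes=2,
        cache_samples=False,   # set to True to pre-load everything into CPU memory
        batch_size=16,
        parallel=0,            # one data-loading worker per CPU core
    )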

DatasetDictionary

A dictionary of datasets mapping names to actual datasets.

alias of dict[str, Dataset]

property parallel: int

Whether to use multiprocessing for data loading.

Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data loading instances as processing cores available in the system. Set to >= 1 to enable that many multiprocessing instances for data loading.

This property sets the DataLoader parameter num_workers to match the expected PyTorch representation. It also sets the multiprocessing_context to use spawn instead of the platform default (fork, on Linux).

The mapping between the command-line interface parallel setting and the DataLoader parameters works like this:

Table 4 Relationship between parallel and DataLoader parameters

  CLI parallel   torch.utils.data.DataLoader kwargs   Comments
  <0             0                                     Disables multiprocessing entirely, executes everything within the same processing context
  0              multiprocessing.cpu_count()           Runs mini-batch data loading on as many external processes as CPUs available in the current machine
  >=1            parallel                              Runs mini-batch data loading on as many external processes as set on parallel

Returns:

The value of self._parallel.

Return type:

int
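
The table above corresponds, roughly, to the following selection of DataLoader keyword arguments (a simplified sketch, not the actual implementation):

    import multiprocessing

    def dataloader_kwargs(parallel: int) -> dict:
        """Sketch of the parallel -> DataLoader keyword-argument mapping."""
        if parallel < 0:
            return {"num_workers": 0}  # no multiprocessing at all
        workers = multiprocessing.cpu_count() if parallel == 0 else parallel
        # "spawn" is used instead of the platform default (fork, on Linux)
        return {"num_workers": workers, "multiprocessing_context": "spawn"}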

property model_transforms: Sequence[Callable[[Tensor], Tensor]] | None

Transform required to fit data into the model.

A list of transforms (torch modules) that are applied after raw data loading and just before data is fed into the model or into eventual data-augmentation transformations, for all data loaders produced by this DataModule. This part of the pipeline receives the data output by the raw-data loader and applies any model-related adaptations (e.g. resizing), if specified. If data is cached, it is cached after the model transforms are applied, as that is a potential memory saver (e.g., if they contain a resizing operation to smaller images).

Returns:

A list containing the model transforms.

Return type:

list
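
A conceptual sketch of the order of operations described above; raw_loader, sample_representation and the sample() call are placeholders, not part of this module's documented API:

    # Conceptual pipeline order (placeholders only, not module API):
    #   raw-data loader -> model_transforms -> (optional cache) -> augmentations -> model
    tensor = raw_loader.sample(sample_representation)  # hypothetical raw load
    for transform in datamodule.model_transforms or []:
        tensor = transform(tensor)    # e.g. a resize to the model's input size
    # If cache_samples is set, it is this transformed tensor that gets cached.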

val_dataset_keys()[source]

Return list of validation dataset names.

Returns:

The list of validation dataset names.

Return type:

list[str]

setup(stage)[source]

Set up datasets for the different stages of the pipeline.

This method should set up (load, pre-process, etc.) all datasets required for a particular stage (fit, validate, test, predict), and keep them ready to be used by one of the _dataloader() methods that are pertinent for that stage.

If you have set cache_samples, samples are loaded at this stage and cached in memory.

Parameters:

stage (str) –

Name of the stage in which the setup is applicable. Can be one of fit, validate, test or predict. Each stage typically uses the following data loaders:

  • fit: uses both train and validation datasets

  • validate: uses only the validation dataset

  • test: uses only the test dataset

  • predict: uses only the test dataset

Return type:

None

teardown(stage)[source]

Tear down datasets for the different stages of the pipeline.

This method unsets (unloads, removes from memory, etc.) all datasets required for a particular stage (fit, validate, test, predict).

If you have set cache_samples, samples were cached in memory during setup, and this call may effectively release the associated memory.

Parameters:

stage (str) –

Name of the stage in which the teardown is applicable. Can be one of fit, validate, test or predict. Each stage typically uses the following data loaders:

  • fit: uses both train and validation datasets

  • validate: uses only the validation dataset

  • test: uses only the test dataset

  • predict: uses only the test dataset

Return type:

None

train_dataloader()[source]

Return the train data loader.

Return type:

DataLoader[Mapping[str, Any]]

Returns:

The train data loader(s).

unshuffled_train_dataloader()[source]

Return the train data loader without shuffling.

Return type:

DataLoader[Mapping[str, Any]]

Returns:

The train data loader without shuffling.

val_dataloader()[source]

Return the validation data loader(s).

Return type:

dict[str, DataLoader[Mapping[str, Any]]]

Returns:

The validation data loader(s).

test_dataloader()[source]

Return the test data loader(s).

Return type:

dict[str, DataLoader[Mapping[str, Any]]]

Returns:

The test data loader(s).

predict_dataloader()[source]

Return the prediction data loader(s).

Return type:

dict[str, DataLoader[Mapping[str, Any]]]

Returns:

The prediction data loader(s).
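
A minimal usage sketch tying the stages together; here datamodule refers to an already constructed ConcatDataModule (see the construction sketch above), and a lightning Trainer would normally invoke these hooks for you:

    datamodule.setup("fit")                     # builds train + validation datasets
    train_loader = datamodule.train_dataloader()
    val_loaders = datamodule.val_dataloader()   # dict: dataset name -> DataLoader

    for batch in train_loader:                  # each batch is a Mapping[str, Any]
        ...                                     # feed the batch to a model here

    datamodule.teardown("fit")                  # releases cached samples, if any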

class mednet.data.datamodule.CachingDataModule(database_split, raw_data_loader, **kwargs)[source]

Bases: ConcatDataModule

A simplified version of our DataModule for a single split.

Apart from construction, the behaviour of this DataModule is very similar to its more general counterpart, ConcatDataModule, serving training, validation and test sets.

Parameters:
  • database_split (Mapping[str, Sequence[Any]]) –

    A dictionary that contains string keys representing dataset names, and values that are iterables over sample representations (potentially on disk). These objects are passed to a single data.typing.RawDataLoader for loading the typing.Sample data (and metadata) in memory. It therefore assumes the whole split is homogeneous and can be loaded in the same way.

    Tip

    To check that the split and the loader function work correctly, you may use split.check_database_split_loading().

    This class expects at least one entry called train to exist in the input dictionary. Optional entries are validation and test. Entries named monitor-... will be considered extra datasets that do not influence any early stop criteria during training, and are just monitored beyond the validation dataset.

  • raw_data_loader (RawDataLoader) – An object instance that can load samples from storage.

  • **kwargs – List of named parameters matching those of ConcatDataModule, other than splits.
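
A hedged construction sketch for this single-split variant; MyRawDataLoader and the *_samples lists are placeholders as before:

    from mednet.data.datamodule import CachingDataModule

    # Placeholders: a single RawDataLoader serves the whole (homogeneous) split.
    database_split = {
        "train": train_samples,
        "validation": val_samples,
        "test": test_samples,
    }

    datamodule = CachingDataModule(
        database_split,
        raw_data_loader=MyRawDataLoader(),
        database_name="my-database",   # any other ConcatDataModule parameter may follow
        batch_size=16,
    )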