Source code for mednet.engine.segment.dumper

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import pathlib
import typing

import h5py
import lightning.pytorch
import torch.utils.data
import tqdm

# Module-level logger, named after this module, used for progress and warning
# messages emitted while dumping annotations.
logger = logging.getLogger(__name__)


def _dump_batch(
    batch: typing.Any, output_folder: pathlib.Path
) -> list[tuple[str, str]]:
    """Write every sample of a batch to its own HDF5 file.

    Parameters
    ----------
    batch
        A segmentation batch as output by a dataloader.  It is indexed with
        the keys ``"name"``, ``"image"``, ``"target"`` and ``"mask"``, each
        expected to hold one entry per sample of the batch.
    output_folder
        Path leading to a folder where to store dumped annotations.

    Returns
    -------
        One ``(sample-name, hdf5-path)`` tuple per sample in the batch.
        ``hdf5-path`` is the sample name with its suffix replaced by
        ``.hdf5``; the file itself is written to
        ``output_folder / hdf5-path``.
    """

    written: list[tuple[str, str]] = []

    # Iterate over the whole batch instead of only looking at index 0, so no
    # sample is silently skipped when a dataloader yields more than one
    # sample per batch.
    for index, name in enumerate(batch["name"]):
        stem = pathlib.Path(name).with_suffix(".hdf5")
        dest = output_folder / stem
        # tqdm.write() prints without breaking the active progress bar
        tqdm.tqdm.write(f"`{name}` -> `{str(dest)}`")
        dest.parent.mkdir(parents=True, exist_ok=True)

        with h5py.File(dest, "w") as f:
            f.create_dataset(
                "image",
                data=batch["image"][index].cpu().numpy(),
                compression="gzip",
                compression_opts=9,
            )
            # target and mask are stored as boolean arrays: a leading
            # singleton dimension (if any) is dropped and values are
            # thresholded at 0.5
            f.create_dataset(
                "target",
                data=(batch["target"][index].squeeze(0).cpu().numpy() > 0.5),
                compression="gzip",
                compression_opts=9,
            )
            f.create_dataset(
                "mask",
                data=(batch["mask"][index].squeeze(0).cpu().numpy() > 0.5),
                compression="gzip",
                compression_opts=9,
            )

        written.append((name, str(stem)))

    return written


def _dump_loader(
    dataloader: typing.Any, output_folder: pathlib.Path
) -> list[tuple[str, str]]:
    """Dump all batches of one dataloader, showing a progress bar.

    Parameters
    ----------
    dataloader
        An iterable of batches, as accepted by :py:func:`_dump_batch`.
    output_folder
        Path leading to a folder where to store dumped annotations.

    Returns
    -------
        A flat list with one ``(sample-name, hdf5-path)`` tuple per sample
        written, in iteration order.
    """

    return [
        entry
        for batch in tqdm.tqdm(dataloader)
        for entry in _dump_batch(batch, output_folder)
    ]


def run(
    datamodule: lightning.pytorch.LightningDataModule,
    output_folder: pathlib.Path,
) -> (
    dict[str, list[tuple[str, str]]]
    | list[list[tuple[str, str]]]
    | list[tuple[str, str]]
    | None
):
    """Dump annotations from input datamodule.

    Parameters
    ----------
    datamodule
        The lightning DataModule to extract annotations from.
    output_folder
        Folder where to store HDF5 representations of annotations.

    Returns
    -------
        A JSON-able representation of sample data stored at
        ``output_folder``.  For every split (dataloader), a list of samples
        in the form ``[sample-name, hdf5-path]`` is produced.  The shape of
        the returned value follows what ``predict_dataloader()`` returns:

        * a single loader: a list of samples;
        * a list of loaders: a list with one list of samples per loader,
          in the same order;
        * a dictionary of loaders: a dictionary mapping each key to its
          list of samples;
        * ``None``: ``None`` (a warning is logged and nothing is written).

    Raises
    ------
    TypeError
        If the DataModule's ``predict_dataloader()`` method does not return
        any of the types described above.
    """

    dataloaders = datamodule.predict_dataloader()

    if isinstance(dataloaders, torch.utils.data.DataLoader):
        logger.info("Dump annotations from a single dataloader...")
        return _dump_loader(dataloaders, output_folder)

    if isinstance(dataloaders, list):
        retval_list = []
        for index, dataloader in enumerate(dataloaders):
            logger.info(f"Dumping annotations from split `{index}`...")
            retval_list.append(_dump_loader(dataloader, output_folder))
        return retval_list

    if isinstance(dataloaders, dict):
        retval_dict = {}
        for name, dataloader in dataloaders.items():
            logger.info(f"Dumping annotations from split `{name}`...")
            retval_dict[name] = _dump_loader(dataloader, output_folder)
        return retval_dict

    if dataloaders is None:
        logger.warning("Datamodule did not return any prediction dataloaders!")
        return None

    # if you get to this point, then the user is returning something that is
    # not supported - complain!
    raise TypeError(
        "Datamodule returned strangely typed prediction "
        f"dataloaders: `{type(dataloaders)}` - Please write code "
        "to support this use-case.",
    )