Source code for mednet.data.segment.cxr8

# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""ChestX-ray8: Hospital-scale Chest X-ray Database.

The database contains a total of 112'120 images. Image size for each X-ray is
1024 x 1024. One set of automatically generated mask annotations is available
for all images.

* Database references:

  * Original data: :cite:p:`wang_chestx-ray8_2017`
  * Split reference: :cite:p:`gaal_attention_2020`

.. important:: **Raw data organization**

    The CXR8_ base datadir, which you should configure following the
    :ref:`mednet.setup` instructions, must contain at least the following
    directories:

    - ``images/`` (directory containing the CXR images, in PNG format)
    - ``segmentations/`` (must contain masks downloaded from `CXR8-Annotations`_)

    The flag ``idiap_folder_structure`` makes the loader search for files
    named, e.g. ``images/00030621_006.png``, as
    ``images/00030/00030621_006.png`` (this is valid for both images and
    segmentation masks).

* Raw data input (on disk):

  * PNG RGB 8-bit depth images
  * Resolution: 1024 x 1024 pixels
  * Total samples available: 112'120

* Output image:

  * Transforms:

    * Load raw PNG with :py:mod:`PIL`, with auto-conversion to RGB, convert to
      tensor
    * The target is read from the segmentation mask file listed for each
      sample, converted to a 1-bit image and then to a float tensor.

The ``default`` split contains 78'484 images for training, 11'212 images for
validation, and 22'424 images for testing.

This module contains the base declaration of common data modules and raw-data
loaders for this database. All configured splits inherit from this definition.
"""

import importlib.resources.abc
import os
import pathlib
import typing

import PIL.Image
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import to_dtype, to_image

from ...utils.rc import load_rc
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from ..typing import RawDataLoader as BaseDataLoader
from ..typing import Sample

DATABASE_SLUG = __name__.rpartition(".")[2]
"""Short identifier for this database: the last component of this module's
dotted name."""

CONFIGURATION_KEY_DATADIR = f"datadir.{DATABASE_SLUG}"
"""Configuration-file key under which the root directory of this database is
looked up."""

CONFIGURATION_KEY_IDIAP_FILESTRUCTURE = f"{DATABASE_SLUG}.idiap_folder_structure"
"""Configuration-file key selecting between the standard and the Idiap-specific
on-disk file organisation.

When set, the internal loader searches for files in a slightly different folder
structure that was adapted to Idiap's requirements (fewer than 10k files per
folder).
"""


class RawDataLoader(BaseDataLoader):
    """A specialized raw-data-loader for the cxr8 dataset."""

    datadir: pathlib.Path
    """This variable contains the base directory where the database raw data is
    stored."""

    idiap_file_organisation: bool
    """If true, files are searched for using the Idiap-specific folder layout
    (an extra sub-directory named after the first 5 characters of the file
    name)."""

    def __init__(self):
        # Read the user configuration once and pull both settings from it.
        rc = load_rc()
        self.datadir = pathlib.Path(
            rc.get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
        )
        self.idiap_file_organisation = rc.get(
            CONFIGURATION_KEY_IDIAP_FILESTRUCTURE,
            False,
        )

    def _resolve_path(self, relative: str) -> pathlib.Path:
        """Return the on-disk location of a file listed in the split.

        Parameters
        ----------
        relative
            Path suffix of the file within the dataset root folder, e.g.
            ``images/00030621_006.png``.

        Returns
        -------
            ``datadir`` joined with ``relative``.  With the Idiap file
            organisation enabled, a sub-directory named after the first 5
            characters of what follows the first ``/`` is inserted, e.g.
            ``images/00030/00030621_006.png``.
        """
        if not self.idiap_file_organisation:
            return self.datadir / pathlib.Path(relative)

        # Split only on the first "/": <top-level dir>/<rest of the path>.
        # A suffix without any "/" is not supported in this mode.
        parts = relative.split("/", 1)
        return self.datadir / pathlib.Path(
            f"{parts[0]}/{parts[1][:5]}/{parts[1]}"
        )

    def sample(self, sample: typing.Any) -> Sample:
        """Load a single image sample from the disk.

        Parameters
        ----------
        sample
            A tuple whose first element is the path suffix of the image and
            whose second element is the path suffix of the target (see
            :py:meth:`target`), both within the dataset root folder.

        Returns
        -------
            The sample representation: a dictionary with keys ``image`` (the
            RGB image as a float32 tensor), ``target`` (the output of
            :py:meth:`target`), ``mask`` (an all-ones float32 tensor with one
            channel and the same height and width as the image) and ``name``
            (the first element of ``sample``, unmodified).
        """
        # The context manager guarantees the file handle is released as soon
        # as the pixel data has been copied by ``convert()``.
        with PIL.Image.open(self._resolve_path(sample[0])) as raw:
            rgb = raw.convert(mode="RGB")
        image = to_dtype(to_image(rgb), torch.float32, scale=True)

        target = self.target(sample)

        # use image as a base since target() can be overridden by child class
        mask = torch.ones((1, image.shape[-2], image.shape[-1]), dtype=torch.float32)

        return dict(
            image=tv_tensors.Image(image),
            target=tv_tensors.Mask(target),
            mask=tv_tensors.Mask(mask),
            name=sample[0],
        )

    def target(self, sample: typing.Any) -> torch.Tensor:
        """Load only sample target from its raw representation.

        Parameters
        ----------
        sample
            A tuple whose second element is the path suffix, within the
            dataset root folder, of the segmentation mask to be loaded.

        Returns
        -------
            The segmentation mask of the specified sample, converted to a
            1-bit image and then to a float32 tensor (presumably holding only
            the values 0 and 1 -- TODO confirm).
        """
        with PIL.Image.open(self._resolve_path(sample[1])) as raw:
            # NOTE(review): ``dither=None`` selects Pillow's default dithering
            # (Floyd-Steinberg), not "no dithering"; this only matters if the
            # mask files contain grey levels -- confirm whether
            # ``PIL.Image.Dither.NONE`` was intended.
            binary = raw.convert(mode="1", dither=None)

        return to_dtype(to_image(binary), torch.float32, scale=True)
class DataModule(CachingDataModule):
    """Data module for ChestX-ray8 (hospital-scale chest X-ray database).

    Configured as a single-class ``"segmentation"`` task, with samples loaded
    through :py:class:`RawDataLoader`.

    Parameters
    ----------
    split_path
        Path or traversable (resource) with the JSON split description to
        load.  The split name is taken from its file name, with up to two
        trailing dot-separated extensions removed.
    """

    def __init__(self, split_path: pathlib.Path | importlib.resources.abc.Traversable):
        # Collaborators are built in the same order the base-class call
        # would evaluate them as keyword arguments.
        split = JSONDatabaseSplit(split_path)
        loader = RawDataLoader()
        stem = split_path.name.rsplit(".", 2)[0]

        super().__init__(
            database_split=split,
            raw_data_loader=loader,
            database_name=DATABASE_SLUG,
            split_name=stem,
            task="segmentation",
            num_classes=1,
        )