# Source code for mednet.data.segment.jsrt

# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Japanese Society of Radiological Technology dataset for lung segmentation.

The database includes 154 nodule and 93 non-nodule images.  It contains a total
of 247 images with a resolution of 2048 x 2048 pixels, obtained from original
digitized radiographs (laser scanner). One set of ground-truth lung annotations
is available.

* Database references:

  * Original data: :cite:p:`shiraishi_development_2000`
  * Split: :cite:p:`gaal_attention_2020`

.. important:: **Raw data organization**

   The JSRT_ base datadir, which you should configure following the
   :ref:`mednet.setup` instructions, must contain at least the following
   directories:

   - ``All247images/`` (directory containing the CXR images, in raw format)
   - ``scratch/`` (must contain masks downloaded from `JSRT-Annotations`_)

Data specifications:

* Raw data input (on disk):

  * Original images encoded in proprietary 12-bit RAW format.  A PNG-converted
    set of images is provided at JSRT-Kaggle_ for your reference.  Input
    resolution is 2048 x 2048 pixels.
  * Masks: encoded as GIF files with separate portions for left and right
    lungs, with a resolution of 1024 x 1024 pixels
  * Total samples: 247

* Output sample:

  * Image: Load the raw image from folder ``All247images/`` using
    :py:func:`numpy.fromfile`, then apply a simple histogram equalization
    to the 8-bit representation of the image, to obtain something along the
    lines of the PNG (unofficial) version distributed at JSRT-Kaggle_.
    Output images have a size of 1024 x 1024 pixels, achieved by resizing the
    original input with bilinear interpolation.
  * Labels for each of the lungs are read from the provided GIF files and
    merged into a single output image.

The ``default`` split contains 172 samples for training, 25 for validation and
50 for test.

This module contains the base declaration of common data modules and raw-data
loaders for this database. All configured splits inherit from this definition.
"""

import importlib.resources.abc
import os
import pathlib
import typing

import numpy as np
import PIL.Image
import skimage.exposure
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import to_dtype, to_image

from ...utils.rc import load_rc
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from ..typing import RawDataLoader as BaseDataLoader
from ..typing import Sample

DATABASE_SLUG = __name__.split(".")[-1]
"""Short, Pythonic identifier of this database (last component of the module
name)."""

CONFIGURATION_KEY_DATADIR = f"datadir.{DATABASE_SLUG}"
"""Configuration-file key under which the root directory of this database's raw
data is looked up."""


class RawDataLoader(BaseDataLoader):
    """A specialized raw-data-loader for the jsrt dataset."""

    datadir: pathlib.Path
    """This variable contains the base directory where the database raw data is
    stored."""

    def __init__(self):
        # Root directory comes from the user configuration; when the key is
        # absent, fall back to the (resolved) current working directory.
        self.datadir = pathlib.Path(
            load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
        )

    def load_pil_raw_12bit_jsrt(self, path: pathlib.Path) -> PIL.Image.Image:
        """Load a raw 12-bit JSRT image and convert it to an 8-bit RGB image.

        This method was designed to handle the raw images from the JSRT
        dataset.  The file is read as a 2048 x 2048 array of big-endian 16-bit
        unsigned integers holding 12-bit values.  Values are clipped to the
        12-bit range, inverted, reduced to 8 bits, and a simple histogram
        equalization is then applied, to obtain something along the lines of
        the PNG (unofficial) version distributed at `JSRT-Kaggle`.

        Parameters
        ----------
        path
            The full path leading to the image to be loaded.

        Returns
        -------
            A PIL image in RGB mode, with 2048 x 2048 pixels.
        """
        raw_image = np.fromfile(path, np.dtype(">u2")).reshape(2048, 2048)
        raw_image[raw_image > 4095] = 4095  # clip to the 12-bit range
        raw_image = 4095 - raw_image  # invert colors
        raw_image = (raw_image >> 4).astype(np.uint8)  # drop 4 LSBs -> 8-bit uint
        raw_image = skimage.exposure.equalize_hist(raw_image)  # floats in [0, 1]
        return PIL.Image.fromarray((raw_image * 255).astype(np.uint8)).convert("RGB")

    def sample(self, sample: typing.Any) -> Sample:
        """Load a single image sample from the disk.

        Parameters
        ----------
        sample
            A tuple whose first entry is the path suffix, within the dataset
            root folder, of the raw image to be loaded, and whose second and
            third entries are the path suffixes of the left and right lung
            masks (consumed by :py:meth:`target`).

        Returns
        -------
            The sample representation: a dictionary with ``image``,
            ``target``, ``mask`` and ``name`` entries.
        """
        image = self.load_pil_raw_12bit_jsrt(self.datadir / sample[0])

        # N.B.: Original image size is 2048x2048 px, however annotations were
        # done on a 1024x1024 px canvas, so the image is downscaled to match.
        image = image.resize((1024, 1024), PIL.Image.Resampling.BILINEAR)
        image = to_dtype(to_image(image), torch.float32, scale=True)

        target = self.target(sample)

        # use image as a base since target() can be overridden by child class
        mask = torch.ones((1, image.shape[-2], image.shape[-1]), dtype=torch.float32)

        image = tv_tensors.Image(image)
        target = tv_tensors.Mask(target)
        mask = tv_tensors.Mask(mask)

        return dict(image=image, target=target, mask=mask, name=sample[0])

    def target(self, sample: typing.Any) -> torch.Tensor:
        """Load only sample target from its raw representation.

        The left and right lung masks are read, binarized, and merged with a
        logical OR into a single mask.

        Parameters
        ----------
        sample
            A tuple whose second and third entries are the path suffixes,
            within the dataset root folder, of the left and right lung masks,
            respectively.

        Returns
        -------
            The merged lung mask, encapsulated as a torch float tensor.
        """
        # Use context managers so the file handles are released promptly:
        # Pillow may keep (multi-frame capable) GIF files open after loading.
        with PIL.Image.open(self.datadir / sample[1]) as left_image:
            left = np.asarray(left_image.convert(mode="1", dither=None))
        with PIL.Image.open(self.datadir / sample[2]) as right_image:
            right = np.asarray(right_image.convert(mode="1", dither=None))

        # np.logical_or always yields a full boolean array, whereas
        # np.ma.mask_or (shrink=True by default) collapses an all-False result
        # into the scalar ``nomask``, which cannot be converted to an image.
        target = np.logical_or(left, right)
        return to_dtype(to_image(target), torch.float32, scale=True)
class DataModule(CachingDataModule):
    """Japanese Society of Radiological Technology dataset for lung segmentation.

    Parameters
    ----------
    split_path
        Path or traversable (resource) with the JSON split description to load.
        The split name is taken from its file name, with up to two trailing
        dot-separated suffixes removed (e.g. ``default.json`` yields
        ``default``).
    """

    def __init__(self, split_path: pathlib.Path | importlib.resources.abc.Traversable):
        database_split = JSONDatabaseSplit(split_path)
        loader = RawDataLoader()
        split_name = split_path.name.rsplit(".", 2)[0]

        super().__init__(
            database_split=database_split,
            raw_data_loader=loader,
            database_name=DATABASE_SLUG,
            split_name=split_name,
            task="segmentation",
            num_classes=1,
        )