# Source code for mednet.data.segment.stare

# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""STARE dataset for vessel segmentation.

A subset of the original STARE dataset contains 20 annotated eye fundus images
with a resolution of 700 x 605 (width x height). Two sets of ground-truth
vessel annotations are available. The first set by Adam Hoover ("ah") is
commonly used for training and testing. The second set by Valentina Kouznetsova
("vk") is typically used as a “human” baseline.

* Database references:

  * Original data: :cite:p:`hoover_locating_2000`
  * Split reference: :cite:p:`maninis_deep_2016`

Data specifications:

* Raw data input (on disk):

  * RGB images encoded in PPM format with resolution (HxW) = 605 x 700
  * Total samples: 397 (out of which only 20 are annotated for vessel
    segmentation)

* Output sample:

    * Image: Load raw PPM images with :py:mod:`PIL`, with auto-conversion to RGB.
    * Vessel annotations: Load annotations with :py:mod:`PIL`, with
      auto-conversion to mode ``1`` with no dithering.
    * Eye fundus mask: Load mask with :py:mod:`PIL`, with
      auto-conversion to mode ``1`` with no dithering.

Protocol ``ah`` (default baseline, with first, more detailed annotator)
includes 10 training samples and 10 test samples.  Protocol ``vk`` (second
annotator) includes the same samples but annotated by a second expert.

This module contains the base declaration of common data modules and raw-data
loaders for this database. All configured splits inherit from this definition.
"""

import importlib.resources
import importlib.resources.abc
import os
import pathlib
import typing

import PIL.Image
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import to_dtype, to_image

from ...models.transforms import crop_image_to_mask
from ...utils.rc import load_rc
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from ..typing import RawDataLoader as BaseDataLoader
from ..typing import Sample

# The slug is the last dotted component of this module's import path.
DATABASE_SLUG = __name__.rpartition(".")[-1]
"""Pythonic name to refer to this database."""

CONFIGURATION_KEY_DATADIR = f"datadir.{DATABASE_SLUG}"
"""Key to search for in the configuration file for the root directory of this
database."""


class RawDataLoader(BaseDataLoader):
    """A specialized raw-data-loader for the STARE database.

    Samples handled by this loader are indexable sequences holding, in order:

    0. the path of the eye fundus image, relative to :py:attr:`datadir`;
    1. the path of the vessel annotation, relative to :py:attr:`datadir`;
    2. the file name of the eye fundus mask, looked up among the resources
       shipped with this package (under ``masks/<database-slug>/``).
    """

    datadir: pathlib.Path
    """This variable contains the base directory where the database raw data is
    stored."""

    def __init__(self):
        # Fall back to the (resolved) current working directory when the user
        # configuration does not define a data directory for this database.
        self.datadir = pathlib.Path(
            load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
        )

    @staticmethod
    def _load_binary(path: pathlib.Path) -> torch.Tensor:
        """Load an image file as a bilevel (PIL mode ``1``) float32 tensor.

        Parameters
        ----------
        path
            Location of the image file to load.

        Returns
        -------
            The image, converted to PIL mode ``1`` without dithering, as a
            ``torch.float32`` tensor (converted with value scaling enabled).
        """
        with PIL.Image.open(path) as img:
            # ``dither=None`` means "use the default" for Pillow, which is
            # Floyd-Steinberg dithering.  Dithering has to be switched off
            # explicitly so that the conversion is a plain thresholding.
            binary = img.convert(mode="1", dither=PIL.Image.Dither.NONE)
        return to_dtype(to_image(binary), torch.float32, scale=True)

    def sample(self, sample: typing.Any) -> Sample:
        """Load a single image sample from the disk.

        Parameters
        ----------
        sample
            A sequence containing, in this order: the path suffix, within the
            database root folder, where to find the image to be loaded; the
            path suffix, within the database root folder, of the vessel
            annotation; and the file name of the eye fundus mask distributed
            with this package.

        Returns
        -------
            The sample representation: a dictionary with keys ``image``,
            ``target``, ``mask`` and ``name`` (the image path suffix).
        """
        with PIL.Image.open(self.datadir / sample[0]) as img:
            rgb = img.convert(mode="RGB")
        image = to_dtype(to_image(rgb), torch.float32, scale=True)

        target = self.target(sample)

        # Masks are not part of the raw database: they are package resources.
        mask_path = (
            importlib.resources.files(__package__) / "masks" / DATABASE_SLUG / sample[2]
        )
        # ``as_file`` guarantees a real filesystem path for the duration of
        # the ``with`` block; the pixel data is fully loaded before it exits.
        with importlib.resources.as_file(mask_path) as path:
            mask = self._load_binary(path)

        # Image and target go through ``crop_image_to_mask``; the mask itself
        # is only wrapped, it is not passed through that helper.
        image = tv_tensors.Image(crop_image_to_mask(image, mask))
        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
        mask = tv_tensors.Mask(mask)

        return dict(image=image, target=target, mask=mask, name=sample[0])

    def target(self, sample: typing.Any) -> torch.Tensor:
        """Load only sample target from its raw representation.

        Parameters
        ----------
        sample
            A sequence whose second entry is the path suffix, within the
            database root folder, where to find the vessel annotation to be
            loaded.  Other entries are not used by this method.

        Returns
        -------
            The vessel annotation corresponding to the specified sample,
            encapsulated as a torch float tensor.
        """
        return self._load_binary(self.datadir / sample[1])
class DataModule(CachingDataModule):
    """Data module giving access to the STARE vessel-segmentation database.

    Parameters
    ----------
    split_path
        Path or traversable (resource) with the JSON split description to load.
    """

    def __init__(self, split_path: pathlib.Path | importlib.resources.abc.Traversable):
        split = JSONDatabaseSplit(split_path)
        loader = RawDataLoader()
        # Keep only what precedes the last (up to) two dot-separated
        # components of the file name - presumably the split file extensions.
        name_parts = split_path.name.rsplit(".", maxsplit=2)
        super().__init__(
            database_split=split,
            raw_data_loader=loader,
            database_name=DATABASE_SLUG,
            split_name=name_parts[0],
            task="segmentation",
            num_classes=1,
        )