Source code for mednet.data.classify.tbpoc
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""TB-POC dataset for computer-aided diagnosis.
This databases contain only the tuberculosis final diagnosis (0 or 1) and come
from HIV infected patients.
* Database reference: :cite:p:`griesel_optimizing_2018`
.. important:: **Raw data organization**
The TB-POC base datadir, which you should configure following the
:ref:`mednet.setup` instructions, must contain at least the directory
``TBPOC_CXR`` with all JPEG images.
Data specifications:
* Raw data input (on disk):
* JPEG 8-bit Grayscale images
* Original resolution (height x width or width x height): 2048 x 2500 pixels
or 2500 x 2048 pixels
* Total samples: 407
* Output image:
* Transforms:
* Load raw grayscale jpeg with :py:mod:`PIL`
* Remove black borders
* Convert to torch tensor
* Torch center cropping to get square image
* Final specifications:
* Grayscale, encoded as a single plane tensor, 32-bit floats,
square with varying resolutions (2048 x 2048 being the maximum),
but also depending on black borders' sizes on the input image.
* Labels: 0 (healthy), 1 (active tuberculosis)
This module contains the base declaration of common data modules and raw-data
loaders for this database. All configured splits inherit from this definition.
"""
import importlib.resources.abc
import os
import pathlib
import typing
import PIL.Image
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import to_dtype, to_image
from ...utils.rc import load_rc
from ..datamodule import CachingDataModule
from ..image_utils import remove_black_borders
from ..split import JSONDatabaseSplit
from ..typing import RawDataLoader as BaseDataLoader
from ..typing import Sample
DATABASE_SLUG = __name__.rsplit(".", 1)[-1]
"""Pythonic name of this database."""
CONFIGURATION_KEY_DATADIR = "datadir." + DATABASE_SLUG
"""Key to search for in the configuration file for the root directory of this
database."""
[docs]
class RawDataLoader(BaseDataLoader):
"""A specialized raw-data-loader for the Shenzen dataset."""
datadir: pathlib.Path
"""This variable contains the base directory where the database raw data is
stored."""
def __init__(self):
self.datadir = pathlib.Path(
load_rc().get(
CONFIGURATION_KEY_DATADIR,
os.path.realpath(os.curdir),
),
)
[docs]
def sample(self, sample: typing.Any) -> Sample:
"""Load a single image sample from the disk.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample target.
Returns
-------
The sample representation.
"""
# images from TBPOC are encoded as grayscale JPEGs, no need to
# call convert("L") here.
image = PIL.Image.open(self.datadir / sample[0])
image, _ = remove_black_borders(image)
image = to_dtype(to_image(image), torch.float32, scale=True)
image = tv_tensors.Image(image)
# use the code below to view generated images
# from torchvision.transforms.v2.functional import to_pil_image
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return dict(image=image, target=self.target(sample), name=sample[0])
[docs]
def target(self, sample: typing.Any) -> torch.Tensor:
"""Load only sample target from its raw representation.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing
the sample target.
Returns
-------
The label corresponding to the specified sample, encapsulated as a
1D torch float tensor.
"""
return torch.FloatTensor([sample[1]])
[docs]
class DataModule(CachingDataModule):
"""TB-POC dataset for computer-aided diagnosis.
Parameters
----------
split_path
Path or traversable (resource) with the JSON split description to load.
"""
def __init__(self, split_path: pathlib.Path | importlib.resources.abc.Traversable):
super().__init__(
database_split=JSONDatabaseSplit(split_path),
raw_data_loader=RawDataLoader(),
database_name=DATABASE_SLUG,
split_name=split_path.name.rsplit(".", 2)[0],
task="classification",
num_classes=1,
)