Source code for mednet.data.detect.montgomery
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Montgomery database transformed for lung detection.
Check :py:mod:`.segment.montgomery` for details. This module only uses the
segmentation utilities to provide an "object detection" interface.
* Output sample:

  * Image: As per :py:mod:`.segment.montgomery`.
  * Bounding-box: A single bounding-box accounting for the observed lung region.
This module contains the base declaration of common data modules and raw-data loaders
for this database. All configured splits inherit from this definition.
"""
import importlib.resources.abc
import pathlib
import typing
import torch
from torchvision import tv_tensors
from torchvision.ops import masks_to_boxes
from ..datamodule import CachingDataModule
from ..segment.montgomery import CONFIGURATION_KEY_DATADIR, DATABASE_SLUG
from ..segment.montgomery import RawDataLoader as BaseDataLoader
from ..split import JSONDatabaseSplit
from ..typing import Sample
[docs]
class RawDataLoader(BaseDataLoader):
    """Raw-data loader exposing Montgomery samples for lung detection.

    Extends the segmentation loader from :py:mod:`.segment.montgomery` by
    converting its target into bounding boxes.
    """

    def __init__(self):
        super().__init__()

        # keep this on this module for correct database script support!
        assert CONFIGURATION_KEY_DATADIR

    def sample(self, sample: typing.Any) -> Sample:
        """Load a single sample and attach its bounding-box target.

        Parameters
        ----------
        sample
            The raw sample description.  It is forwarded unchanged to the
            base loader, and its first element (``sample[0]``) is reused as
            the sample name.

        Returns
        -------
            A dictionary with the keys ``image``, ``target`` (bounding boxes
            in XYXY format), ``labels``, ``mask`` and ``name``.
        """
        loaded = super().sample(sample)

        # Boxes are computed by :py:meth:`target`, which goes back to the
        # base loader's target rather than reusing ``loaded``.
        box_data = self.target(sample)

        # NOTE(review): assumes the last two image dimensions are
        # (height, width), as ``canvas_size`` expects -- confirm against
        # the base loader.
        image = loaded["image"]
        boxes = tv_tensors.BoundingBoxes(
            data=box_data,
            format=tv_tensors.BoundingBoxFormat.XYXY,
            canvas_size=image.shape[-2:],
        )

        return {
            "image": image,
            "target": boxes,
            "labels": torch.FloatTensor([1]),  # background is 0
            "mask": loaded["mask"],
            "name": sample[0],
        }

    def target(self, sample: typing.Any) -> torch.Tensor:
        """Load only the bounding-box target of a sample.

        Parameters
        ----------
        sample
            The raw sample description, forwarded unchanged to the base
            loader's ``target`` method.

        Returns
        -------
            The result of :py:func:`torchvision.ops.masks_to_boxes` applied
            to the target produced by the base loader.
        """
        # converts the base loader's target into bounding boxes
        base_target = super().target(sample)
        return masks_to_boxes(base_target)
[docs]
class DataModule(CachingDataModule):
    """Data module for the Montgomery database in its lung-detection form.

    Parameters
    ----------
    split_path
        Path or traversable (resource) with the JSON split description to load.
    """

    def __init__(self, split_path: pathlib.Path | importlib.resources.abc.Traversable):
        split = JSONDatabaseSplit(split_path)
        loader = RawDataLoader()

        # The split name is the file name with up to two trailing
        # dot-separated suffixes removed (e.g. ``default.json`` gives
        # ``default``).
        stem = split_path.name.rsplit(".", 2)[0]

        super().__init__(
            database_split=split,
            raw_data_loader=loader,
            database_name=DATABASE_SLUG,
            split_name=stem,
            task="detection",
            num_classes=1,
        )