# Source code for mednet.data.classify.montgomery_shenzhen

# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated DataModule composed of :py:mod:`montgomery's <.data.classify.montgomery>` and :py:mod:`shenzhen's <.data.classify.shenzhen>` splits.

This module contains the base declaration of common data modules and raw-data
loaders for this database. All configured splits inherit from this definition.
"""

import importlib.resources.abc
import pathlib

from ..datamodule import ConcatDataModule
from ..split import JSONDatabaseSplit
from .montgomery import RawDataLoader as MontgomeryLoader
from .shenzhen import RawDataLoader as ShenzhenLoader

# The last component of this module's dotted import path doubles as the
# database identifier (the whole name is used when it contains no dot).
DATABASE_SLUG = __name__.rpartition(".")[2]
"""Pythonic name to refer to this database."""


class DataModule(ConcatDataModule):
    """Aggregate the montgomery and shenzhen splits in one DataModule.

    The splits come from :py:mod:`montgomery <.data.classify.montgomery>`
    and :py:mod:`shenzhen <.data.classify.shenzhen>`.  For each of the
    ``train``, ``validation`` and ``test`` subsets, the montgomery entry is
    listed first and the shenzhen entry second, each paired with the raw-data
    loader of its own database.

    Parameters
    ----------
    split_name
        The name of the split to assign to this data module.
    split_path
        Paths or traversables (resources) with the JSON split descriptions
        to load for the montgomery and shenzhen databases (in this order).
    """

    def __init__(
        self,
        split_name: str,
        split_path: tuple[
            pathlib.Path | importlib.resources.abc.Traversable,
            pathlib.Path | importlib.resources.abc.Traversable,
        ],
    ):
        # Loaders and splits are created in the same order as before:
        # montgomery loader, montgomery split, shenzhen loader, shenzhen split.
        mc_loader = MontgomeryLoader()
        mc_split = JSONDatabaseSplit(split_path[0])
        ch_loader = ShenzhenLoader()
        ch_split = JSONDatabaseSplit(split_path[1])

        # Order matters here: montgomery always precedes shenzhen.
        sources = ((mc_split, mc_loader), (ch_split, ch_loader))

        subsets = {
            subset: [(split[subset], loader) for split, loader in sources]
            for subset in ("train", "validation", "test")
        }

        super().__init__(
            splits=subsets,
            database_name=DATABASE_SLUG,
            split_name=split_name,
            task="classification",
            num_classes=1,
        )