class RawDataLoader(BaseDataLoader):
    """Raw-data loader tailored to the Montgomery dataset.

    Parameters
    ----------
    config_variable
        Key looked up in the configuration file to find the root directory
        of this database.
    multiclass
        When ``True``, targets are emitted as 2 distinct classes rather than
        as a single (0/1) output.
    """

    datadir: pathlib.Path
    """Base directory under which the raw data of the database is stored."""

    # ``config_variable`` is exposed as a parameter so that the same loader
    # can also serve the small version of the Montgomery database.
    def __init__(
        self,
        config_variable: str = CONFIGURATION_KEY_DATADIR,
        multiclass: bool = False,
    ):
        rc = load_rc()
        # When the key is absent from the configuration, fall back to the
        # resolved current working directory.
        fallback_dir = os.path.realpath(os.curdir)
        self.datadir = pathlib.Path(rc.get(config_variable, fallback_dir))
        self.multiclass = multiclass

    def sample(self, sample: tuple[str, int, typing.Any | None]) -> Sample:
        """Read one image sample from disk.

        Parameters
        ----------
        sample
            A tuple whose first element is the path suffix (relative to the
            dataset root folder) of the image to load, and whose second
            element is an integer representing the sample target.

        Returns
        -------
            A dictionary with keys ``image`` (the loaded image tensor),
            ``target`` (see :py:meth:`target`) and ``name`` (the path suffix
            given in ``sample``).
        """
        relative_path = sample[0]

        # N.B.: Montgomery images are stored as grayscale PNGs, therefore no
        # extra ``Image.convert("L")`` step is applied here.
        raw_image = PIL.Image.open(self.datadir / relative_path)
        cropped_image, _ = remove_black_borders(raw_image)
        pixels = tv_tensors.Image(to_tensor(cropped_image))

        # Debugging hint: the resulting tensor can be inspected visually with
        # torchvision.transforms.functional.to_pil_image(pixels).show()

        return {
            "image": pixels,
            "target": self.target(sample),
            "name": relative_path,
        }

    def target(self, sample: typing.Any) -> torch.Tensor:
        """Build only the target of a sample from its raw representation.

        Parameters
        ----------
        sample
            A tuple whose first element is the path suffix (relative to the
            dataset root folder) of the image, and whose second element is an
            integer representing the sample target.

        Returns
        -------
            The label of the given sample as a 1D torch float tensor: a
            single-element tensor holding the raw label, or, in multiclass
            mode, ``[1, 0]`` for label 0 and ``[0, 1]`` for any other label.
        """
        label = sample[1]

        if not self.multiclass:
            return torch.FloatTensor([label])

        # Two-class encoding: position 0 flags label 0, position 1 flags
        # every other label.
        two_class = [1, 0] if label == 0 else [0, 1]
        return torch.FloatTensor(two_class)
class DataModule(CachingDataModule):
    """Montgomery DataModule for TB detection.

    Parameters
    ----------
    split_path
        Path or traversable (resource) pointing to the JSON split
        description to load.
    multiclass
        When ``True``, targets are emitted as 2 distinct classes rather than
        as a single (0/1) output.
    """

    def __init__(
        self,
        split_path: pathlib.Path | importlib.resources.abc.Traversable,
        multiclass: bool = False,
    ):
        database_split = JSONDatabaseSplit(split_path)
        raw_data_loader = RawDataLoader(multiclass=multiclass)

        # The split name is the file name with at most its two right-most
        # dot-separated components removed (e.g. "default.json" -> "default").
        name_parts = split_path.name.rsplit(".", 2)
        split_name = name_parts[0]

        super().__init__(
            database_split=database_split,
            raw_data_loader=raw_data_loader,
            database_name=DATABASE_SLUG,
            split_name=split_name,
            task="classification",
        )