class RawDataLoader(BaseDataLoader):
    """A specialized raw-data-loader for the jsrt dataset."""

    datadir: pathlib.Path
    """This variable contains the base directory where the database raw data
    is stored."""

    def __init__(self):
        # Fall back to the current working directory when the configuration
        # file does not define a data directory for this database.
        self.datadir = pathlib.Path(
            load_rc().get(
                CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
            )
        )

    def load_pil_raw_12bit_jsrt(self, path: pathlib.Path) -> PIL.Image.Image:
        """Load a raw 12-bit sample stored as big-endian 16-bit words.

        This method was designed to handle the raw images from the JSRT
        dataset. It reads the data file and applies a simple histogram
        equalization to the 8-bit representation of the image to obtain
        something along the lines of the PNG (unofficial) version distributed
        at `JSRT-Kaggle`.

        Parameters
        ----------
        path
            The full path leading to the image to be loaded.

        Returns
        -------
            A PIL image in RGB mode, with 2048x2048 pixels.
        """
        # The file is a headerless dump of 2048 * 2048 unsigned big-endian
        # 16-bit integers; ``reshape`` raises if the element count differs.
        raw_image = np.fromfile(path, np.dtype(">u2")).reshape(2048, 2048)

        # Clamp anything above the 12-bit range so the inversion below cannot
        # wrap around in unsigned arithmetic.
        raw_image[raw_image > 4095] = 4095
        raw_image = 4095 - raw_image  # invert colors

        # Drop the 4 least-significant bits: 12-bit -> 8-bit.
        raw_image = (raw_image >> 4).astype(np.uint8)

        # The equalized image is rescaled by 255 before the cast back to
        # 8-bit (skimage documents the output as floating point in [0, 1]).
        raw_image = skimage.exposure.equalize_hist(raw_image)
        return PIL.Image.fromarray((raw_image * 255).astype(np.uint8)).convert(
            "RGB"
        )

    def sample(self, sample: typing.Any) -> Sample:
        """Load a single image sample from the disk.

        Parameters
        ----------
        sample
            A sequence whose first entry is the path suffix, within the
            dataset root folder, of the raw image to be loaded, and whose
            second and third entries are the path suffixes of the two lung
            mask images that are merged to form the segmentation target.

        Returns
        -------
            The sample representation: a dictionary with keys ``image``,
            ``target``, ``mask`` and ``name``.
        """
        image = self.load_pil_raw_12bit_jsrt(self.datadir / sample[0])
        assert image.size == (2048, 2048)
        image = to_tensor(
            image.resize((1024, 1024), PIL.Image.Resampling.BILINEAR)
        )

        # Combine left and right lung masks into a single tensor
        assert sample[2] is not None

        # Context managers guarantee both mask files are closed once their
        # pixels have been copied into arrays.
        with PIL.Image.open(self.datadir / sample[1]) as first_mask:
            with PIL.Image.open(self.datadir / sample[2]) as second_mask:
                # ``np.logical_or`` always keeps the 2-D shape, whereas
                # ``np.ma.mask_or`` (with its default ``shrink=True``)
                # collapses an all-False result into the scalar ``nomask``.
                lungs = np.logical_or(
                    np.asarray(first_mask.convert(mode="1", dither=None)),
                    np.asarray(second_mask.convert(mode="1", dither=None)),
                )

        target = tv_tensors.Image(to_tensor(lungs).float())

        image = tv_tensors.Image(image)
        target = tv_tensors.Mask(target)
        # No region of the image is excluded: the validity mask is all ones.
        mask = tv_tensors.Mask(torch.ones_like(target))

        return dict(image=image, target=target, mask=mask, name=sample[0])
class DataModule(CachingDataModule):
    """Japanese Society of Radiological Technology dataset for lung
    segmentation.

    Parameters
    ----------
    split_path
        Path or traversable (resource) with the JSON split description to
        load.
    """

    def __init__(
        self, split_path: pathlib.Path | importlib.resources.abc.Traversable
    ):
        # Collaborators are built in the same order the keyword arguments
        # were originally evaluated.
        database_split = JSONDatabaseSplit(split_path)
        raw_data_loader = RawDataLoader()

        # The split name is the file name with at most two trailing
        # dot-separated components removed.
        split_name, *_ = split_path.name.rsplit(".", 2)

        super().__init__(
            database_split=database_split,
            raw_data_loader=raw_data_loader,
            database_name=DATABASE_SLUG,
            split_name=split_name,
            task="segmentation",
        )