class Model(BaseModel):
    """Base model type for classification tasks.

    Parameters
    ----------
    name
        Common name to give to models of this type.
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss. Defaults to an empty mapping.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``. Defaults to an empty
        mapping.
    scheduler_type
        The type of scheduler to use for training.
    scheduler_arguments
        Arguments to the scheduler after ``params``. Defaults to an empty
        mapping.
    model_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network. Defaults
        to an empty sequence.
    augmentation_transforms
        An optional sequence of torch modules containing data augmentations.
        They are applied to the input images in :py:meth:`training_step` only
        (not during validation) before the images are fed into the network.
        Defaults to an empty sequence.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        name: str,
        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
        loss_arguments: dict[str, typing.Any] | None = None,
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] | None = None,
        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
        scheduler_arguments: dict[str, typing.Any] | None = None,
        model_transforms: TransformSequence | None = None,
        augmentation_transforms: TransformSequence | None = None,
        num_classes: int = 1,
    ):
        # ``None`` defaults are replaced by *fresh* containers on every call so
        # that instances never share (and accidentally mutate) one default
        # dict/list object, which is what literal ``{}``/``[]`` defaults do.
        super().__init__(
            name,
            loss_type,
            {} if loss_arguments is None else loss_arguments,
            optimizer_type,
            {} if optimizer_arguments is None else optimizer_arguments,
            scheduler_type,
            {} if scheduler_arguments is None else scheduler_arguments,
            [] if model_transforms is None else model_transforms,
            [] if augmentation_transforms is None else augmentation_transforms,
            num_classes,
        )

    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
        """Initialize the input normalizer for the current model.

        Parameters
        ----------
        dataloader
            A torch Dataloader from which to compute the mean and std.
        """
        # Imported lazily, as in the original implementation.
        from .normalizer import make_z_normalizer

        logger.info(
            "Uninitialised %s model - "
            "computing z-norm factors from train dataloader.",
            self.name,
        )
        self.normalizer = make_z_normalizer(dataloader)

    @staticmethod
    def _unpack_batch(batch) -> tuple[torch.Tensor, torch.Tensor]:
        """Split ``batch`` into its images and float targets.

        ``batch[0]["image"]`` holds the images and ``batch[1]["target"]`` the
        targets. 1-D targets are reshaped to ``(N, 1)`` so that single-output
        and multi-output (multiclass) models share the same code path.
        """
        images = batch[0]["image"]
        labels = batch[1]["target"]
        if labels.ndim == 1:
            labels = labels.unsqueeze(1)  # (N,) -> (N, 1)
        return images, labels.float()

    def training_step(self, batch, _):
        """Compute the training loss for one batch.

        The images go through ``self.augmentation_transforms`` before the
        forward pass. The loss is computed with ``self._train_loss``
        (presumably configured by ``BaseModel`` — defined outside this view).
        """
        images, labels = self._unpack_batch(batch)
        outputs = self(self.augmentation_transforms(images))
        return self._train_loss(outputs, labels)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute the validation loss for one batch.

        Unlike :py:meth:`training_step`, no augmentation is applied.
        ``batch_idx`` and ``dataloader_idx`` are accepted for interface
        compatibility but are not used here.
        """
        images, labels = self._unpack_batch(batch)
        outputs = self(images)
        return self._validation_loss(outputs, labels)