class Model(BaseModel):
    """Base model type for semantic segmentation tasks.

    Parameters
    ----------
    name
        Common name to give to models of this type.
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.  ``None`` (the default) means no arguments.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.  ``None`` (the default)
        means no arguments.
    scheduler_type
        The type of scheduler to use for training (optional).
    scheduler_arguments
        Arguments to the scheduler after ``optimizer``.  ``None`` (the
        default) means no arguments.
    model_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.  ``None``
        (the default) means an empty sequence.
    augmentation_transforms
        An optional sequence of torch modules containing data-augmentation
        transforms.  During training (see :py:meth:`training_step`) they are
        applied to both the input images and the targets of each batch.
        ``None`` (the default) means an empty sequence.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        name: str,
        loss_type: type[torch.nn.Module] = MultiLayerBCELogitsLossWeightedPerBatch,
        loss_arguments: dict[str, typing.Any] | None = None,
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] | None = None,
        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
        scheduler_arguments: dict[str, typing.Any] | None = None,
        model_transforms: TransformSequence | None = None,
        augmentation_transforms: TransformSequence | None = None,
        num_classes: int = 1,
    ):
        # ``None`` sentinels replace the former ``{}``/``[]`` default literals:
        # those were single objects shared by every instance, so any mutation
        # performed on them downstream would leak across models.  Fresh
        # containers are created here on every call instead.
        super().__init__(
            name,
            loss_type,
            {} if loss_arguments is None else loss_arguments,
            optimizer_type,
            {} if optimizer_arguments is None else optimizer_arguments,
            scheduler_type,
            {} if scheduler_arguments is None else scheduler_arguments,
            [] if model_transforms is None else model_transforms,
            [] if augmentation_transforms is None else augmentation_transforms,
            num_classes,
        )

    def training_step(self, batch, batch_idx):
        """Compute the training loss for a single batch.

        The augmentation transforms are applied to the images and to the
        targets, the augmented images are fed through the network, and the
        result is compared to the augmented targets with ``self.train_loss``.

        Parameters
        ----------
        batch
            Mapping holding the batch data; the ``"image"`` and ``"target"``
            entries are used here.
        batch_idx
            Index of the current batch (unused).

        Returns
        -------
            Whatever ``self.train_loss`` returns for the network output and
            the augmented targets.
        """
        import random

        # Random augmentations (flips, crops, rotations, ...) must draw the
        # *same* parameters for an image and its ground truth, otherwise the
        # target no longer matches the input.  Snapshot the RNG states before
        # augmenting the images and replay them before augmenting the targets.
        # NOTE(review): torchvision transforms typically sample their random
        # parameters from the default CPU torch generator (and, for a few v1
        # transforms, Python's ``random``); transforms drawing from other
        # generators (e.g. CUDA) are not synchronised here.
        torch_rng_state = torch.get_rng_state()
        python_rng_state = random.getstate()
        images = self.augmentation_transforms(batch["image"])

        torch.set_rng_state(torch_rng_state)
        random.setstate(python_rng_state)
        targets = self.augmentation_transforms(batch["target"])

        return self.train_loss(self(images), targets)