class Alexnet(Model):
    """`AlexNet network architecture <alexnet-pytorch_>`_ model, from [ALEXNET-2012]_.

    Note: only usable with a normalized dataset

    Parameters
    ----------
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.
    scheduler_type
        The type of scheduler to use for training.
    scheduler_arguments
        Arguments to the scheduler after ``params``.
    model_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    augmentation_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    pretrained
        If set to True, loads pretrained model weights during initialization,
        else trains a new model.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
        loss_arguments: dict[str, typing.Any] = {},
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] = {},
        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
        scheduler_arguments: dict[str, typing.Any] = {},
        model_transforms: TransformSequence = [],
        augmentation_transforms: TransformSequence = [],
        pretrained: bool = False,
        num_classes: int = 1,
    ):
        super().__init__(
            name="alexnet",
            loss_type=loss_type,
            loss_arguments=loss_arguments,
            optimizer_type=optimizer_type,
            optimizer_arguments=optimizer_arguments,
            scheduler_type=scheduler_type,
            scheduler_arguments=scheduler_arguments,
            model_transforms=model_transforms,
            augmentation_transforms=augmentation_transforms,
            num_classes=num_classes,
        )

        self.pretrained = pretrained

        # Only fetch the torchvision weights when the caller asked for a
        # pretrained model; ``weights=None`` gives a randomly initialised
        # network.  This must agree with ``set_normalizer``, which selects the
        # ImageNet normalizer under the very same ``self.pretrained`` flag.
        weights = None
        if self.pretrained:
            logger.info(f"Loading pretrained `{self.name}` model weights")
            weights = models.AlexNet_Weights.DEFAULT

        self.model_ft = models.alexnet(weights=weights)

        # Goes through the property setter below, which adapts the classifier
        # head of ``model_ft`` to the requested number of outputs.
        self.num_classes = num_classes

    @Model.num_classes.setter  # type: ignore[attr-defined]
    def num_classes(self, v: int) -> None:
        """Set the number of outputs, rebuilding the classifier head if needed.

        When the current last layer (``classifier[6]``) already has ``v``
        outputs, only the stored value is updated.  Otherwise layers 4 and 6 of
        the classifier are replaced by freshly-initialised linear layers
        (hidden width 512, then ``v`` outputs), discarding any weights these
        two layers had.

        Parameters
        ----------
        v
            The new number of outputs (classes).
        """
        classifier = self.model_ft.classifier

        if classifier[6].out_features != v:
            if self.pretrained:
                logger.info(
                    f"Resetting `{self.name}` pretrained classifier layer weights due "
                    f"to change in output size "
                    f"({self.model_ft.classifier[6].out_features} -> {v})"
                )
            # Layer 4 must be replaced first: layer 6 takes its input width
            # from the *new* layer 4.
            classifier[4] = torch.nn.Linear(
                in_features=classifier[1].out_features,
                out_features=512,
            )
            classifier[6] = torch.nn.Linear(
                in_features=classifier[4].out_features,
                out_features=v,
            )

        self._num_classes = v

    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Adapt the classifier head to the checkpoint before it is loaded.

        The number of outputs is read from the shape of the bias of the last
        classifier layer stored in the checkpoint.  If it differs from the
        current setting, the head is rebuilt so that the stored weights fit.

        Parameters
        ----------
        checkpoint
            The checkpoint being loaded; its ``"state_dict"`` entry must hold
            the key ``"model_ft.classifier.6.bias"``.
        """
        num_classes = checkpoint["state_dict"]["model_ft.classifier.6.bias"].shape[0]

        if num_classes != self.num_classes:
            logger.debug(
                f"Resetting number-of-output-classes at `{self.name}` model from "
                f"{self.num_classes} to {num_classes} while loading checkpoint."
            )
            self.num_classes = num_classes

        super().on_load_checkpoint(checkpoint)

    def forward(self, x):
        """Forward the input tensor through the network, producing a prediction.

        Parameters
        ----------
        x
            The tensor input to be forwarded.

        Returns
        -------
            The prediction, as a tensor.
        """
        # Inputs are always normalized before reaching the AlexNet backbone.
        x = self.normalizer(x)
        return self.model_ft(x)

    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
        """Initialize the normalizer for the current model.

        If ``pretrained = True``, the dataloader is ignored and the normalizer
        is set to the preset ImageNet factors.  Otherwise the decision is
        delegated to the parent class implementation.

        Parameters
        ----------
        dataloader
            A torch Dataloader from which to compute the mean and std.
            Will not be used if the model is pretrained.
        """
        if self.pretrained:
            # Imported lazily, exactly as in the original implementation.
            from ..normalizer import make_imagenet_normalizer

            logger.info(
                f"ImageNet pre-trained {self.name} model - NOT "
                f"computing z-norm factors from train dataloader. "
                f"Using preset factors from torchvision.",
            )
            self.normalizer = make_imagenet_normalizer()
        else:
            super().set_normalizer(dataloader)