class LogisticRegression(Model):
    """`Logistic regression model <logistic-regression_>`_ for multi-class classification.

    Parameters
    ----------
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.  If ``None`` (the default), an empty
        dictionary is used.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.  If ``None`` (the
        default), ``{"lr": 1e-2}`` is used.
    num_classes
        Number of outputs (classes) for this model.
    input_size
        The number of inputs this classifier shall process.
    """

    def __init__(
        self,
        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
        loss_arguments: typing.Optional[dict[str, typing.Any]] = None,
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: typing.Optional[dict[str, typing.Any]] = None,
        num_classes: int = 1,
        input_size: int = 14,
    ):
        # Build the default dictionaries per call: a dictionary literal used
        # as a default value is created once and shared by every instance, so
        # a mutation made through one model would leak into all later ones.
        if loss_arguments is None:
            loss_arguments = {}
        if optimizer_arguments is None:
            optimizer_arguments = {"lr": 1e-2}

        super().__init__(
            name="logistic-regression",
            loss_type=loss_type,
            loss_arguments=loss_arguments,
            optimizer_type=optimizer_type,
            optimizer_arguments=optimizer_arguments,
            scheduler_type=None,
            scheduler_arguments={},
            model_transforms=[],
            augmentation_transforms=[],
            num_classes=num_classes,
        )

        # ``input_size`` must be set first: assigning ``num_classes`` goes
        # through the setter below, which reads ``self.input_size`` to build
        # the linear layer.
        self.input_size = input_size
        self.num_classes = num_classes

    @Model.num_classes.setter  # type: ignore[attr-defined]
    def num_classes(self, v: int) -> None:
        """Set the number of output classes and rebuild the linear layer.

        Parameters
        ----------
        v
            The new number of outputs (classes).  A new
            :py:class:`torch.nn.Linear` mapping ``input_size`` inputs to
            ``v`` outputs replaces ``self.linear``.
        """
        self.linear = torch.nn.Linear(self.input_size, v)
        self._num_classes = v

    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Match the output layer to the checkpoint, then defer to the parent.

        The number of classes stored in the checkpoint is read from the first
        dimension of its ``linear.bias`` entry.  If it differs from the
        current ``num_classes``, the model is resized (which rebuilds
        ``self.linear``) before the parent hook is invoked.

        Parameters
        ----------
        checkpoint
            The checkpoint being loaded.  It must provide
            ``checkpoint["state_dict"]["linear.bias"]``.
        """
        num_classes = checkpoint["state_dict"]["linear.bias"].shape[0]

        if num_classes != self.num_classes:
            logger.debug(
                f"Resetting number-of-output-classes at `{self.name}` model from "
                f"{self.num_classes} to {num_classes} while loading checkpoint."
            )
            self.num_classes = num_classes

        super().on_load_checkpoint(checkpoint)