mednet.models.classify.alexnet
AlexNet network architecture, from [KSH17].
Classes

Alexnet | AlexNet network architecture model, from [KSH17].
- class mednet.models.classify.alexnet.Alexnet(loss_type=<class 'torch.nn.modules.loss.BCEWithLogitsLoss'>, loss_arguments=None, optimizer_type=<class 'torch.optim.adam.Adam'>, optimizer_arguments=None, scheduler_type=None, scheduler_arguments=None, model_transforms=None, augmentation_transforms=None, pretrained=False, num_classes=1)[source]
Bases: Model
AlexNet network architecture model, from [KSH17].
Note: only usable with a normalized dataset
- Parameters:
  - loss_type (type[Module]) – The loss to be used for training and evaluation.
    Warning
    The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it.
  - loss_arguments (dict[str, Any] | None) – Arguments to the loss.
  - optimizer_type (type[Optimizer]) – The type of optimizer to use for training.
  - optimizer_arguments (dict[str, Any] | None) – Arguments to the optimizer after params.
  - scheduler_type (type[LRScheduler] | None) – The type of scheduler to use for training.
  - scheduler_arguments (dict[str, Any] | None) – Arguments to the scheduler after params.
  - model_transforms (Optional[Sequence[Callable[[Tensor], Tensor]]]) – An optional sequence of torch modules containing transforms to be applied on the input before it is fed into the network.
  - augmentation_transforms (Optional[Sequence[Callable[[Tensor], Tensor]]]) – An optional sequence of torch modules containing transforms to be applied on the input before it is fed into the network.
  - pretrained (bool) – If set to True, loads pretrained model weights during initialization, else trains a new model.
  - num_classes (int) – Number of outputs (classes) for this model.
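Example (a minimal construction sketch, not part of the original documentation; it assumes mednet and torch are importable as shown, and the learning-rate value is purely illustrative):

import torch
from mednet.models.classify.alexnet import Alexnet

# Construct the model, spelling out the defaults from the signature above.
# optimizer_arguments are forwarded to the optimizer after the model parameters.
model = Alexnet(
    loss_type=torch.nn.BCEWithLogitsLoss,
    optimizer_type=torch.optim.Adam,
    optimizer_arguments={"lr": 1e-4},  # illustrative value, not a documented default
    pretrained=True,  # load pretrained weights instead of training a new model
    num_classes=1,
)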
- property num_classes: int
Number of outputs (classes) for this model.
- Returns:
The number of outputs supported by this model.
- Return type:
int
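For instance (a hypothetical check, reusing the construction sketch above):

model = Alexnet(num_classes=1)
assert model.num_classes == 1  # reflects the number of outputs configured at construction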
- on_load_checkpoint(checkpoint)[source]
Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint(), this is your chance to restore it.
- Parameters:
  - checkpoint (MutableMapping[str, Any]) – The loaded checkpoint.
- Return type:
None
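This hook is not called directly; Lightning invokes it while restoring a checkpoint. A brief sketch (the checkpoint path is a placeholder, and it assumes the Model base class derives from Lightning's LightningModule, as the hook name suggests):

from mednet.models.classify.alexnet import Alexnet

# Lightning calls on_load_checkpoint() internally while rebuilding module state.
# "path/to/alexnet.ckpt" is a placeholder for a checkpoint saved by a previous run.
model = Alexnet.load_from_checkpoint("path/to/alexnet.ckpt")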