mednet.models.classify.pasa

Simple CNN network model from [PGP+19].

Classes

Pasa([loss_type, loss_arguments, ...])

Simple CNN network model from [PGP+19].

class mednet.models.classify.pasa.Pasa(loss_type=<class 'torch.nn.modules.loss.BCEWithLogitsLoss'>, loss_arguments=None, optimizer_type=<class 'torch.optim.adam.Adam'>, optimizer_arguments=None, scheduler_type=None, scheduler_arguments=None, model_transforms=None, augmentation_transforms=None, num_classes=1)

Bases: Model

Simple CNN network model from [PGP+19].

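This network has a linear output (logits). When training, use losses that operate on logits, such as torch.nn.BCEWithLogitsLoss (the default), instead of their plain counterparts that expect probabilities, such as torch.nn.BCELoss.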
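To see why a linear output pairs with a WithLogits loss, the sketch below compares torch.nn.BCEWithLogitsLoss on raw logits against torch.nn.BCELoss on sigmoid outputs. It uses only standard PyTorch with made-up values, not mednet itself: both agree for moderate logits, but the sigmoid route saturates for extreme ones.

```python
import torch

# Made-up raw outputs of a 1-output model with a linear final layer.
logits = torch.tensor([[2.0], [-1.0], [0.5]])
targets = torch.tensor([[1.0], [0.0], [1.0]])

# Preferred: pass raw logits; the loss applies the sigmoid internally
# using a numerically stable formulation.
with_logits = torch.nn.BCEWithLogitsLoss()(logits, targets)

# Plain binary cross-entropy needs probabilities, so apply the sigmoid first.
plain = torch.nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(with_logits, plain, atol=1e-6))  # True

# For an extreme logit, sigmoid(-200) underflows to 0 in float32 and
# BCELoss clamps log(0) to -100, so the plain loss caps at 100.
extreme = torch.tensor([[-200.0]])
one = torch.tensor([[1.0]])
stable_value = torch.nn.BCEWithLogitsLoss()(extreme, one).item()
clamped_value = torch.nn.BCELoss()(torch.sigmoid(extreme), one).item()
print(stable_value, clamped_value)  # 200.0 100.0
```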

Parameters:
  • loss_type (type[Module]) –

    The loss to be used for training and evaluation.

    Warning

    The loss should be set to always return batch averages (as opposed to batch sums), because the logging system expects them.

  • loss_arguments (dict[str, Any] | None) – Arguments to the loss.

  • optimizer_type (type[Optimizer]) – The type of optimizer to use for training.

  • optimizer_arguments (dict[str, Any] | None) – Arguments passed to the optimizer after the model parameters (params).

  • scheduler_type (type[LRScheduler] | None) – The type of scheduler to use for training.

  • scheduler_arguments (dict[str, Any] | None) – Arguments passed to the scheduler after the optimizer.

  • model_transforms (Optional[Sequence[Callable[[Tensor], Tensor]]]) – An optional sequence of torch modules containing transforms to be applied to the input before it is fed into the network.

  • augmentation_transforms (Optional[Sequence[Callable[[Tensor], Tensor]]]) – An optional sequence of torch modules containing data augmentation transforms, applied to the input during training before it is fed into the network.

  • num_classes (int) – Number of outputs (classes) for this model.
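How these arguments are consumed is internal to mednet. The sketch below shows one plausible wiring, assuming the usual PyTorch conventions: the optimizer type is called with the model parameters followed by optimizer_arguments, and the scheduler type with the optimizer followed by scheduler_arguments. The helpers build_training_objects and apply_transforms are hypothetical (not part of mednet), and a torch.nn.Linear layer stands in for the Pasa network.

```python
from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any

import torch


def build_training_objects(
    model: torch.nn.Module,
    loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
    loss_arguments: dict[str, Any] | None = None,
    optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
    optimizer_arguments: dict[str, Any] | None = None,
    scheduler_type: type | None = None,
    scheduler_arguments: dict[str, Any] | None = None,
):
    """Hypothetical helper: turn the types and argument dicts into objects."""
    loss = loss_type(**(loss_arguments or {}))
    # The optimizer receives the model parameters first, then its arguments.
    optimizer = optimizer_type(model.parameters(), **(optimizer_arguments or {}))
    scheduler = None
    if scheduler_type is not None:
        # PyTorch schedulers receive the optimizer first, then their arguments.
        scheduler = scheduler_type(optimizer, **(scheduler_arguments or {}))
    return loss, optimizer, scheduler


def apply_transforms(
    x: torch.Tensor,
    transforms: Sequence[Callable[[torch.Tensor], torch.Tensor]] | None,
) -> torch.Tensor:
    """Hypothetical helper: apply each transform in sequence order."""
    for transform in transforms or []:
        x = transform(x)
    return x


# A single linear layer stands in for the Pasa network.
net = torch.nn.Linear(4, 1)
loss, optimizer, scheduler = build_training_objects(
    net,
    # "mean" (PyTorch's default) keeps the batch averages the logging expects.
    loss_arguments={"reduction": "mean"},
    optimizer_arguments={"lr": 1e-3},
    scheduler_type=torch.optim.lr_scheduler.StepLR,
    scheduler_arguments={"step_size": 10, "gamma": 0.5},
)
x = apply_transforms(torch.ones(2, 4), [lambda t: t * 2.0, torch.nn.Identity()])
print(net(x).shape)  # torch.Size([2, 1])
```

Passing reduction="sum" instead would make the reported loss grow with the batch size, which is what the warning on loss_type rules out.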

property num_classes: int

Number of outputs (classes) for this model.

Returns:

The number of outputs supported by this model.

Return type:

int
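With the default loss, targets must match both the shape and the floating-point dtype of the network output. Assuming the output is shaped (batch, num_classes), a sketch of preparing integer labels for num_classes=1 follows; the random logits are a stand-in for real network outputs.

```python
import torch

torch.manual_seed(0)
num_classes = 1
batch_size = 3

# Stand-in for the network output: logits shaped (batch, num_classes).
logits = torch.randn(batch_size, num_classes)

# Integer labels as they might come from a dataset, shaped (batch,).
labels = torch.tensor([1, 0, 1])

# BCEWithLogitsLoss needs float targets with exactly the logits' shape.
targets = labels.to(torch.float32).reshape(batch_size, num_classes)

loss = torch.nn.BCEWithLogitsLoss()(logits, targets)
print(targets.shape, loss.dim())  # torch.Size([3, 1]) 0
```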

on_load_checkpoint(checkpoint)

Perform actions during model loading (called by Lightning).

If you saved something with on_save_checkpoint(), this is your chance to restore it.

Parameters:

checkpoint (MutableMapping[str, Any]) – The loaded checkpoint.

Return type:

None
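Lightning calls on_save_checkpoint() just before writing a checkpoint and on_load_checkpoint() right after reading one, passing the checkpoint dictionary to each. The sketch below mimics that round trip with a plain dict and a hypothetical class; it does not reproduce what Pasa itself stores, and both the "extra_state" key and its contents are assumptions.

```python
from __future__ import annotations

from collections.abc import MutableMapping
from typing import Any


class CheckpointHooksSketch:
    """Hypothetical stand-in for a LightningModule's checkpoint hooks."""

    def __init__(self) -> None:
        self.extra_state: dict[str, float] | None = None

    def on_save_checkpoint(self, checkpoint: MutableMapping[str, Any]) -> None:
        # "extra_state" is a hypothetical key chosen for this sketch.
        checkpoint["extra_state"] = self.extra_state

    def on_load_checkpoint(self, checkpoint: MutableMapping[str, Any]) -> None:
        # Restore whatever the save hook stored (None if the key is absent).
        self.extra_state = checkpoint.get("extra_state")


# Simulate a save/load round trip with a plain dictionary.
saved = CheckpointHooksSketch()
saved.extra_state = {"mean": 0.5, "std": 0.25}
checkpoint: dict[str, Any] = {}
saved.on_save_checkpoint(checkpoint)

restored = CheckpointHooksSketch()
restored.on_load_checkpoint(checkpoint)
print(restored.extra_state)  # {'mean': 0.5, 'std': 0.25}
```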

forward(x)

Same as torch.nn.Module.forward().

Parameters:

x (Tensor) – The input batch to be fed into the network.

Returns:

The network output. Since the output layer is linear, these are logits.
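Calling a module instance runs its forward(). Because the output is linear, probabilities come from applying a sigmoid to the returned logits. The sketch below uses a hypothetical TinyLinearOutputNet as a stand-in for Pasa; its layers, the single-channel 64x64 input and the 0.5 decision threshold are all assumptions for illustration.

```python
import torch


class TinyLinearOutputNet(torch.nn.Module):
    """Hypothetical stand-in: a small CNN with a linear (logit) output."""

    def __init__(self, num_classes: int = 1) -> None:
        super().__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(1, 4, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d(1),
            torch.nn.Flatten(),
        )
        # No sigmoid here: the output stays linear, as described for Pasa.
        self.head = torch.nn.Linear(4, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.features(x))


torch.manual_seed(0)
model = TinyLinearOutputNet(num_classes=1)
model.eval()  # switch off training-only behaviour such as dropout

batch = torch.rand(2, 1, 64, 64)  # two single-channel 64x64 images
with torch.no_grad():
    logits = model(batch)  # calling the module runs forward()
probabilities = torch.sigmoid(logits)
predictions = (probabilities >= 0.5).long()
print(logits.shape)  # torch.Size([2, 1])
```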