# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Simple CNN network model from :cite:p:`pasa_efficient_2019`."""
import logging
import typing
import torch
import torch.nn
import torch.nn.functional as F # noqa: N812
import torch.optim.optimizer
import torch.utils.data
from ...data.typing import TransformSequence
from ..typing import Checkpoint
from .model import Model
logger = logging.getLogger(__name__)
class Pasa(Model):
    """Simple CNN network model from :cite:p:`pasa_efficient_2019`.

    This network has a linear output.  You should use losses with ``WithLogit``
    instead of cross-entropy versions when training.

    The network is made of five convolution blocks.  Each block averages a
    two-convolution main path with a single 1x1-convolution parallel path.
    The first four blocks are followed by max-pooling; the last one is
    followed by a global average over the spatial dimensions and a linear
    (dense) output layer.

    Parameters
    ----------
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.
    scheduler_type
        The type of scheduler to use for training.
    scheduler_arguments
        Arguments to the scheduler after ``optimizer``.
    model_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    augmentation_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
        loss_arguments: dict[str, typing.Any] | None = None,
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] | None = None,
        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
        scheduler_arguments: dict[str, typing.Any] | None = None,
        model_transforms: TransformSequence | None = None,
        augmentation_transforms: TransformSequence | None = None,
        num_classes: int = 1,
    ):
        super().__init__(
            name="pasa",
            loss_type=loss_type,
            loss_arguments=loss_arguments,
            optimizer_type=optimizer_type,
            optimizer_arguments=optimizer_arguments,
            scheduler_type=scheduler_type,
            scheduler_arguments=scheduler_arguments,
            model_transforms=model_transforms,
            augmentation_transforms=augmentation_transforms,
            num_classes=num_classes,
        )

        # NOTE: attribute names and creation order below are deliberately kept
        # stable: they define the ``state_dict`` keys, the parameter ordering
        # and the order in which random initialisation is drawn.

        # First convolution block (the only one that downsamples by striding:
        # 2 x 2 on the main path, 4 on the parallel path)
        self.fc1 = torch.nn.Conv2d(
            1, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
        )
        self.fc2 = torch.nn.Conv2d(
            4, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
        )
        self.fc3 = torch.nn.Conv2d(1, 16, kernel_size=(1, 1), stride=(4, 4))
        self.batchNorm2d_4 = torch.nn.BatchNorm2d(4)
        self.batchNorm2d_16 = torch.nn.BatchNorm2d(16)
        self.batchNorm2d_16_2 = torch.nn.BatchNorm2d(16)

        # Second convolution block
        self.fc4 = torch.nn.Conv2d(
            16, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.fc5 = torch.nn.Conv2d(
            24, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.fc6 = torch.nn.Conv2d(
            16,
            32,
            kernel_size=(1, 1),
            stride=(1, 1),
        )  # Original stride (2, 2)
        self.batchNorm2d_24 = torch.nn.BatchNorm2d(24)
        self.batchNorm2d_32 = torch.nn.BatchNorm2d(32)
        self.batchNorm2d_32_2 = torch.nn.BatchNorm2d(32)

        # Third convolution block
        self.fc7 = torch.nn.Conv2d(
            32, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.fc8 = torch.nn.Conv2d(
            40, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.fc9 = torch.nn.Conv2d(
            32,
            48,
            kernel_size=(1, 1),
            stride=(1, 1),
        )  # Original stride (2, 2)
        self.batchNorm2d_40 = torch.nn.BatchNorm2d(40)
        self.batchNorm2d_48 = torch.nn.BatchNorm2d(48)
        self.batchNorm2d_48_2 = torch.nn.BatchNorm2d(48)

        # Fourth convolution block
        self.fc10 = torch.nn.Conv2d(
            48, 56, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.fc11 = torch.nn.Conv2d(
            56, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.fc12 = torch.nn.Conv2d(
            48,
            64,
            kernel_size=(1, 1),
            stride=(1, 1),
        )  # Original stride (2, 2)
        self.batchNorm2d_56 = torch.nn.BatchNorm2d(56)
        self.batchNorm2d_64 = torch.nn.BatchNorm2d(64)
        self.batchNorm2d_64_2 = torch.nn.BatchNorm2d(64)

        # Fifth convolution block
        self.fc13 = torch.nn.Conv2d(
            64, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.fc14 = torch.nn.Conv2d(
            72, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.fc15 = torch.nn.Conv2d(
            64,
            80,
            kernel_size=(1, 1),
            stride=(1, 1),
        )  # Original stride (2, 2)
        self.batchNorm2d_72 = torch.nn.BatchNorm2d(72)
        self.batchNorm2d_80 = torch.nn.BatchNorm2d(80)
        self.batchNorm2d_80_2 = torch.nn.BatchNorm2d(80)

        # Pool after conv. block (shared by the first four blocks)
        self.pool2d = torch.nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))

        # Fully connected layer: 80 features (output channels of the fifth
        # block) -> one output per class
        self.dense = torch.nn.Linear(80, self.num_classes)

    # Re-uses the getter of ``Model.num_classes`` and only replaces the setter,
    # so that changing the number of classes also rebuilds the output layer.
    @Model.num_classes.setter  # type: ignore[attr-defined]
    def num_classes(self, v: int) -> None:
        if self.num_classes != v:
            logger.info(
                f"Resetting `{self.name}` output classifier layer weights due "
                f"to change in output size ({self.num_classes} -> {v})"
            )
            # A brand new (re-initialised) layer replaces the previous one;
            # its former weights are discarded.
            self.dense = torch.nn.Linear(80, v)  # Fully connected layer
            self._num_classes = v

    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Resize the output layer to match the checkpoint, then load it.

        Parameters
        ----------
        checkpoint
            The checkpoint being loaded.  Its ``"state_dict"`` entry must
            contain a ``"dense.bias"`` tensor, whose first dimension gives the
            number of output classes the checkpoint was saved with.
        """
        # reset number of output classes if need be
        self.num_classes = checkpoint["state_dict"]["dense.bias"].shape[0]

        # perform routine checkpoint loading
        super().on_load_checkpoint(checkpoint)

    @staticmethod
    def _residual_block(
        x: torch.Tensor,
        conv_1: torch.nn.Module,
        norm_1: torch.nn.Module,
        conv_2: torch.nn.Module,
        norm_2: torch.nn.Module,
        conv_parallel: torch.nn.Module,
        norm_parallel: torch.nn.Module,
    ) -> torch.Tensor:
        """Average a two-convolution path with a parallel one-convolution path.

        Every convolution is followed by its batch normalization and a ReLU.

        Parameters
        ----------
        x
            Input of the block, fed to both paths.
        conv_1
            First convolution of the main path.
        norm_1
            Batch normalization applied after ``conv_1``.
        conv_2
            Second convolution of the main path.
        norm_2
            Batch normalization applied after ``conv_2``.
        conv_parallel
            Convolution of the parallel path.
        norm_parallel
            Batch normalization applied after ``conv_parallel``.

        Returns
        -------
            The element-wise mean of the outputs of both paths.
        """
        main = F.relu(norm_1(conv_1(x)))  # 1st convolution
        main = F.relu(norm_2(conv_2(main)))  # 2nd convolution
        parallel = F.relu(norm_parallel(conv_parallel(x)))  # Parallel
        return (main + parallel) / 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the (linear) network output for a batch of inputs.

        Parameters
        ----------
        x
            The input batch.  It is first passed through ``self.normalizer``
            (not defined in this class; presumably provided by the base
            model).  The result must have a single channel, as required by the
            first convolutions - assumed layout is
            ``(batch, 1, height, width)``.

        Returns
        -------
            A tensor of shape ``(batch, num_classes)`` with the raw (pre
            activation) outputs of the dense layer.
        """
        x = self.normalizer(x)  # type: ignore

        # First convolution block
        x = self._residual_block(
            x,
            self.fc1,
            self.batchNorm2d_4,
            self.fc2,
            self.batchNorm2d_16,
            self.fc3,
            self.batchNorm2d_16_2,
        )
        x = self.pool2d(x)  # Pooling

        # Second convolution block
        x = self._residual_block(
            x,
            self.fc4,
            self.batchNorm2d_24,
            self.fc5,
            self.batchNorm2d_32,
            self.fc6,
            self.batchNorm2d_32_2,
        )
        x = self.pool2d(x)  # Pooling

        # Third convolution block
        x = self._residual_block(
            x,
            self.fc7,
            self.batchNorm2d_40,
            self.fc8,
            self.batchNorm2d_48,
            self.fc9,
            self.batchNorm2d_48_2,
        )
        x = self.pool2d(x)  # Pooling

        # Fourth convolution block
        x = self._residual_block(
            x,
            self.fc10,
            self.batchNorm2d_56,
            self.fc11,
            self.batchNorm2d_64,
            self.fc12,
            self.batchNorm2d_64_2,
        )
        x = self.pool2d(x)  # Pooling

        # Fifth convolution block (no pooling)
        x = self._residual_block(
            x,
            self.fc13,
            self.batchNorm2d_72,
            self.fc14,
            self.batchNorm2d_80,
            self.fc15,
            self.batchNorm2d_80_2,
        )

        # Global average pooling: flatten the spatial dimensions and average
        # over them, leaving one value per (sample, channel)
        x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)

        # Dense layer
        return self.dense(x)