Source code for mednet.models.losses

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Custom losses for different tasks."""

import logging
import typing

import numpy.typing
import torch
import torch.utils.data

from ..data.datamodule import ConcatDataModule
from ..data.typing import Dataset
from .typing import TaskType

logger = logging.getLogger(__name__)


def _task_type(
    targets: torch.Tensor
    | numpy.typing.NDArray
    | typing.Iterable[typing.Iterable[int]]
    | typing.Iterable[typing.Iterable[typing.Iterable[typing.Iterable[int]]]],
) -> TaskType:
    """Determine the type of task from combined targets available.

    This function will look into the provided targets of a dataset and will
    determine the type of task.

    Parameters
    ----------
    targets
        The complete target set, for the whole dataset being analyzed. This
        matrix should be ``[n, C]`` where ``n`` is the number of samples, and
        ``C`` the number of classes.  All values should be either 0 or 1.

    Returns
    -------
        The type of tas
    """

    int_targets = torch.Tensor(targets).int()

    task_type: TaskType = "classification"
    if len(int_targets.shape) > 2:
        task_type = "segmentation"

    return task_type


def _get_positive_weights_from_targets(targets: torch.Tensor) -> torch.Tensor:
    """Compute the weights of each class of a set of targets.

    This function inputs a set of targets and computes the ratio between number
    of negative and positive samples (scalar).  The weight can be used to
    adjust minimisation criteria to in cases there is a huge data imbalance.

    It returns a vector with weights (inverse counts) for each target.

    Parameters
    ----------
    targets
        A :py:class:`torch.Tensor` containing the targets, in the format
        ``[n, C]`` where ``n`` is the number of samples and ``C`` the
        number of classes.

    Returns
    -------
        The positive weight of each class in the dataset given as input.
    """

    task_type = _task_type(targets)

    if task_type == "segmentation":
        # rearranges ``targets`` vector so the problem looks like a simpler
        # classification problem where each pixel is a "separate sample"
        targets = targets.transpose(0, 2).transpose(1, 3).reshape(-1, targets.shape[1])

    positive_count = targets.sum(dim=0)
    negative_count = targets.shape[0] - positive_count
    return negative_count / positive_count


def _get_positive_weights_from_dataloader(
    dataloader: torch.utils.data.DataLoader,
) -> torch.Tensor:
    """Compute the weights of each class of a DataLoader.

    This function inputs a :py:class:`torch.utils.data.DataLoader` and computes
    the ratio between number of negative and positive samples (scalar).  The
    weight can be used to adjust minimisation criteria to in cases there is a
    huge data imbalance.

    It returns a vector with weights (inverse counts) for each target.

    Parameters
    ----------
    dataloader
        A DataLoader from which to compute the positive weights.  Entries must
        be a dictionary which must contain a ``target`` key.

    Returns
    -------
        The positive weight of each class in the dataset given as input.

    Raises
    ------
    NotImplementedError
        In the case of "multilabel" datasets, which are currently not
        supported.
    """

    if isinstance(dataloader.dataset, Dataset):
        # there is a faster way to access the targets!
        targets = dataloader.dataset.targets()
    else:
        targets = [batch["target"] for batch in dataloader]

    return _get_positive_weights_from_targets(torch.vstack(targets))


[docs] def pos_weight_for_bcewithlogitsloss( datamodule: ConcatDataModule, ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: """Generate the ``pos_weight`` argument for losses of type :py:class:`torch.nn.BCEWithLogitsLoss`. This function can generate the ``pos_weight`` parameters for both train and validation losses given a datamodule. Parameters ---------- datamodule The datamodule to probe for training and validation datasets. Returns ------- A tuple containing the training and validation ``pos_weight`` arguments, wrapped in a dictionary. """ train_weights = _get_positive_weights_from_dataloader( datamodule.unshuffled_train_dataloader() ) logger.info(f"train: BCEWithLogitsLoss(pos_weight={train_weights})") if "validation" in datamodule.val_dataloader().keys(): validation_weights = _get_positive_weights_from_dataloader( datamodule.val_dataloader()["validation"] ) else: logger.warning( "Datamodule does not contain a validation dataloader. " "The training dataloader will be used instead." ) validation_weights = train_weights logger.info(f"validation: BCEWithLogitsLoss(weight={validation_weights})") return (dict(pos_weight=train_weights), dict(pos_weight=validation_weights))
[docs] class BCEWithLogitsLossWeightedPerBatch(torch.nn.Module): """Calculates the binary cross entropy loss for every batch. This loss is similar to :py:class:`torch.nn.BCEWithLogitsLoss`, except it updates the ``pos_weight`` (ratio between negative and positive target pixels) parameter for the loss term for every batch, based on the accumulated taget pixels for all samples in the batch. Implements Equation 1 in :cite:p:`maninis_deep_2016`. The weight depends on the current proportion between negatives and positives in the ground- truth sample being analyzed. """ def __init__(self): super().__init__()
[docs] def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- input_ Logits produced by the model to be evaluated, with the shape ``[n, c]`` (classification), or ``[n, c, h, w]`` (segmentation). target Ground-truth information with the shape ``[n, c]`` (classification), or ``[n, c, h, w]`` (segmentation), containing zeroes and ones. Returns ------- The average loss for all input data. """ # calculates the proportion of negatives to the total number of pixels # available in the masked region num_pos = target.sum() return torch.nn.functional.binary_cross_entropy_with_logits( input_, target, reduction="mean", pos_weight=(input_.numel() - num_pos) / num_pos, )
[docs] class SoftJaccardAndBCEWithLogitsLoss(torch.nn.Module): r"""Implement the generalized loss function of Equation (3) at :cite:p:`iglovikov_ternausnetv2_2018`. At the paper, authors suggest a value of :math:`\alpha = 0.7`, which we set as default for instances of this type. .. math:: L = \alpha H + (1-\alpha)(1-J) J is the Jaccard distance, and H, the Binary Cross-Entropy Loss. Our implementation is based on :py:class:`torch.nn.BCEWithLogitsLoss`. Parameters ---------- alpha Determines the weighting of J and H. Default: ``0.7``. """ def __init__(self, alpha: float = 0.7): super().__init__() self.alpha = alpha
[docs] def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- input_ Logits produced by the model to be evaluated, with the shape ``[n, c]`` (classification), or ``[n, c, h, w]`` (segmentation). target Ground-truth information with the shape ``[n, c]`` (classification), or ``[n, c, h, w]`` (segmentation), containing zeroes and ones. Returns ------- Loss, in a single entry. """ eps = 1e-8 probabilities = torch.sigmoid(input_) intersection = (probabilities * target).sum() sums = probabilities.sum() + target.sum() j = intersection / (sums - intersection + eps) # this implements the support for looking just into the RoI h = torch.nn.functional.binary_cross_entropy_with_logits( input_, target, reduction="mean" ) return (self.alpha * h) + ((1 - self.alpha) * (1 - j))
[docs] class MultiLayerBCELogitsLossWeightedPerBatch(BCEWithLogitsLossWeightedPerBatch): """Weighted Binary Cross-Entropy Loss for multi-layered inputs. This loss can be used in networks that produce more than one output that has to match output targets. For example, architectures such as as :py:class:`.hed.HED` or :py:class:`.lwnet.LittleWNet` require this feature. It follows the inherited super class applying on-the-fly `pos_weight` updates per batch. """ def __init__(self): super().__init__()
[docs] def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- input_ Logits produced by the model to be evaluated, with the shape ``[n, c]`` (classification), or ``[n, c, h, w]`` (segmentation). target Ground-truth information with the shape ``[n, c]`` (classification), or ``[n, c, h, w]`` (segmentation), containing zeroes and ones. Returns ------- The average loss for all input data. """ fwd = super().forward return torch.cat([fwd(i, target).unsqueeze(0) for i in input_]).mean()
[docs] class MultiLayerSoftJaccardAndBCELogitsLoss(SoftJaccardAndBCEWithLogitsLoss): """Implement Equation 3 in :cite:p:`iglovikov_ternausnetv2_2018` for the multi-output networks. This loss can be used in networks that produce more than one output that has to match output targets. For example, architectures such as as :py:class:`.hed.HED` or :py:class:`.lwnet.LittleWNet` require this feature. Parameters ---------- alpha : float Determines the weighting of SoftJaccard and BCE. Default: ``0.7``. """ def __init__(self, alpha: float = 0.7): super().__init__(alpha=alpha)
[docs] def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- input_ Logits produced by the model to be evaluated, with the shape ``[n, c]`` (classification), or ``[n, c, h, w]`` (segmentation). target Ground-truth information with the shape ``[n, c]`` (classification), or ``[n, c, h, w]`` (segmentation), containing zeroes and ones. Returns ------- The average loss for all input data. """ fwd = super().forward return torch.cat([fwd(i, target).unsqueeze(0) for i in input_]).mean()
[docs] class MOONBCEWithLogitsLoss(torch.nn.Module): """Calculates the weighted binary cross entropy loss based on :cite:p:`guler_refining_2024`. This loss implements the domain-adapted multitask loss function in Equation (2) on :cite:p:`guler_refining_2024`. The vector of input weights must be calculated from the input dataset in advance, and set during initialization, or later, before the loss can be fully used. Parameters ---------- weights The positive weight of each class in the dataset given as input as a ``[2, C]`` tensor, with :math:`w_i^-` at position 0, and :math:`w_i^+` at position 1, as defined in Equation (1) of :cite:p:`guler_refining_2024`. """ def __init__(self, weights: torch.Tensor | None = None): super().__init__() self.weights = weights @classmethod def _get_weight_from_targets(cls, targets: torch.Tensor) -> torch.Tensor: r"""Compute the MOON weights from a set of targets as per Equation (1) in :cite:p:`guler_refining_2024`. Parameters ---------- targets A :py:class:`torch.Tensor` containing the targets, in the format ``[n, C]`` where ``n`` is the number of samples and ``C`` the number of classes. Returns ------- The weight of each class in the dataset given as input as a ``[2, C]`` tensor, with :math:`w_i^-` at position 0, and :math:`w_i^+` at position 1, as defined in Equation (1) of :cite:p:`guler_refining_2024`. """ task_type = _task_type(targets) if task_type == "segmentation": # rearranges ``targets`` vector so the problem looks like a simpler # classification problem where each pixel is a "separate sample" targets = ( targets.transpose(0, 2).transpose(1, 3).reshape(-1, targets.shape[1]) ) # at this point targets should be [n x C] s_plus = targets.sum(dim=0) s_minus = targets.shape[0] - s_plus w_minus = (s_plus / s_minus).clamp(min=0, max=1) w_plus = (s_minus / s_plus).clamp(min=0, max=1) return torch.vstack((w_minus, w_plus)) @classmethod def _get_weight_from_dataloader( cls, dataloader: torch.utils.data.DataLoader ) -> torch.Tensor: r"""Compute the MOON weights of each class of a DataLoader as per Equation (1) in :cite:p:`guler_refining_2024`. Parameters ---------- dataloader A DataLoader from which to compute the positive weights. Entries must be a dictionary which must contain a ``target`` key. Returns ------- The weight of each class in the dataset given as input as a ``[2, C]`` tensor, with :math:`w_i^-` at position 0, and :math:`w_i^+` at position 1, as defined in Equation (1) of :cite:p:`guler_refining_2024`. """ if isinstance(dataloader.dataset, Dataset): # there is a faster way to access the targets! targets = dataloader.dataset.targets() else: targets = [batch["target"] for batch in dataloader] return cls._get_weight_from_targets(torch.vstack(targets))
[docs] @classmethod def get_arguments_from_datamodule( cls, datamodule: ConcatDataModule ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: r"""Compute the MOON weights for train and validation sets of a datamodule. This function inputs a :py:class:`.data.datamodule.ConcatDataModule`, and for both the training and validation sets, and for each class on the respective dataloader targets, computes negative and positive weights as such: .. math:: \begin{align} w_i^+ &= \begin{cases} 1 & \text{if } S^{-}_{i} > S^{+}_{i} \\ \frac{S^{-}_{i}}{S^{+}_{i}} & \text{otherwise} \end{cases} & w_i^- &= \begin{cases} 1 & \text{if } S^{+}_{i} > S^{-}_{i} \\ \frac{S^{+}_{i}}{S^{-}_{i}} & \text{otherwise} \end{cases} \end{align} This weight vector is used during runtime to balance individual batch losses respecting individual class distributions. Parameters ---------- datamodule The datamodule to probe for training and validation datasets. Returns ------- A tuple containing the training and validation ``weight`` arguments, wrapped in a dictionary. Each ``weight`` variable contains the weights of each class in the target dataset as a ``[2, C]`` tensor, with :math:`w_i^-` at position 0, and :math:`w_i^+` at position 1, as defined in Equation (1) of :cite:p:`guler_refining_2024`. """ train_weights = cls._get_weight_from_dataloader( datamodule.unshuffled_train_dataloader() ) logger.info(f"train: MOONBCEWithLogitsLoss(weight={train_weights})") if "validation" in datamodule.val_dataloader().keys(): validation_weights = cls._get_weight_from_dataloader( datamodule.val_dataloader()["validation"] ) else: logger.warning( "Datamodule does not contain a validation dataloader. " "The training dataloader will be used instead." ) validation_weights = train_weights logger.info(f"validation: MOONBCEWithLogitsLoss(weight={validation_weights})") return (dict(weights=train_weights), dict(weights=validation_weights))
[docs] def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: r"""Forward pass. This function inputs the output of the model and a set of binary targets (as a float tensor containing zeroes and ones), and implements Equation (2) from :cite:p:`guler_refining_2024`: .. math:: \mathcal J = -\sum_{i=1}^M w_i^{t_i} \bigl[t_i\log f_i(x) + (1-t_i)\log (1-f_i(x)) \bigr] Parameters ---------- input_ Logits produced by the model to be evaluated, with the shape ``[n, c]`` (classification), or ``[n, c, h, w]`` (segmentation). target Ground-truth information with the shape ``[n, c]`` (classification), or ``[n, c, h, w]`` (segmentation), containing zeroes and ones. Returns ------- The result of Equation (2) from :cite:p:`guler_refining_2024`. Raises ------ AssertionError In case the weights have not be initialized by calling :py:meth:`get_arguments_from_datamodule`. """ assert self.weights is not None, ( f"Weights are not initialized. Call " f"{self.__class__.__name__}.get_arguments_from_datamodule() to sort this." ) if len(input_.shape) > 2: # segmentation input_ = input_.transpose(0, 2).transpose(1, 3).reshape(-1, input_.shape[1]) target = target.transpose(0, 2).transpose(1, 3).reshape(-1, target.shape[1]) weights = ((1.0 - target) * self.weights[0]) + (target * self.weights[1]) individual_losses = torch.nn.functional.binary_cross_entropy_with_logits( input_, target, reduction="none" ) return (individual_losses * weights).mean()
[docs] def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self: """Move loss parameters to specified device. Refer to the method :py:meth:`torch.nn.Module.to` for details. Parameters ---------- *args Parameter forwarded to the underlying implementations. **kwargs Parameter forwarded to the underlying implementations. Returns ------- Self. """ if self.weights is None: logger.warning( f"Weights are not initialized. Call {self.__class__.__name__}." f"get_arguments_from_datamodule() to sort this." ) return self self.weights = self.weights.to(*args, **kwargs) return self