Source code for mednet.models.segment.losses

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Specialized losses for semantic segmentation."""

import torch


class BCEWithLogitsLossWeightedPerBatch(torch.nn.Module):
    """Calculates the binary cross entropy loss for every batch.

    This loss is similar to :py:class:`torch.nn.BCEWithLogitsLoss`, except
    it updates the ``pos_weight`` (ratio between negative and positive target
    pixels) parameter for the loss term for every batch, based on the
    accumulated target pixels for all samples in the batch.

    Implements Equation 1 in [MANINIS-2016]_.  The weight depends on the
    current proportion between negatives and positives in the ground-truth
    sample being analyzed.

    If a batch contains no positive target pixels, ``pos_weight`` falls back
    to ``1``.  In that case the positive weight has no effect on the result
    (it only scales terms where the target is one), and the loss reduces to
    plain binary cross-entropy instead of ``NaN``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        input_
            Logits produced by the model to be evaluated, with the shape ``[n,
            c, h, w]``.
        target
            Ground-truth information with the shape ``[n, c, h, w]``,
            containing zeroes and ones.

        Returns
        -------
            The average loss for all input data.
        """
        # ratio of negative to positive target pixels over the whole batch
        num_pos = target.sum()
        num_neg = input_.numel() - num_pos
        ratio = num_neg / num_pos

        # with no positives, ``ratio`` is ``inf`` and BCE would compute
        # ``inf * 0 == NaN`` for every element; any finite weight yields the
        # same (correct) loss there, so use 1.  ``torch.where`` keeps this on
        # the tensor's device without forcing a host synchronisation.
        pos_weight = torch.where(num_pos > 0, ratio, torch.ones_like(ratio))

        return torch.nn.functional.binary_cross_entropy_with_logits(
            input_,
            target,
            reduction="mean",
            pos_weight=pos_weight,
        )
class SoftJaccardAndBCEWithLogitsLoss(torch.nn.Module):
    r"""Implement the generalized loss function of Equation (3) at [IGLOVIKOV-2018]_.

    At the paper, authors suggest a value of :math:`\alpha = 0.7`, which we
    set as default for instances of this type.

    .. math::

       L = \alpha H + (1-\alpha)(1-J)

    :math:`J` is the soft Jaccard index (intersection over union computed on
    sigmoid probabilities, so :math:`1-J` is the Jaccard distance), and
    :math:`H` is the Binary Cross-Entropy Loss.  Our implementation is based
    on :py:class:`torch.nn.BCEWithLogitsLoss`.

    Parameters
    ----------
    alpha
        Determines the weighting of J and H. Default: ``0.7``.
    """

    def __init__(self, alpha: float = 0.7):
        super().__init__()
        self.alpha = alpha

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        input_
            Logits produced by the model to be evaluated, with the shape ``[n,
            c, h, w]``.
        target
            Ground-truth information with the shape ``[n, c, h, w]``,
            containing zeroes and ones.

        Returns
        -------
            Loss, in a single entry.
        """
        probs = input_.sigmoid()

        # soft intersection-over-union over every element of the batch; the
        # small constant keeps the denominator away from zero
        overlap = torch.sum(probs * target)
        union = (torch.sum(probs) + torch.sum(target)) - overlap
        jaccard = overlap / (union + 1e-8)

        bce = torch.nn.functional.binary_cross_entropy_with_logits(
            input_, target, reduction="mean"
        )

        alpha = self.alpha
        return (alpha * bce) + ((1 - alpha) * (1 - jaccard))
class MultiLayerBCELogitsLossWeightedPerBatch(BCEWithLogitsLossWeightedPerBatch):
    """Weighted Binary Cross-Entropy Loss for multi-layered inputs.

    This loss can be used in networks that produce more than one output that
    has to match output targets.  For example, architectures such as
    :py:class:`.hed.HED` or :py:class:`.lwnet.LittleWNet` require this
    feature.

    Each output layer is scored with the parent class loss (which recomputes
    ``pos_weight`` on-the-fly for every batch), and the per-layer losses are
    averaged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        input_
            Value produced by the model to be evaluated, with the shape ``[L,
            n, c, h, w]``.  It is iterated over its first dimension, one
            output layer at a time.
        target
            Ground-truth information with the shape ``[n, c, h, w]``.

        Returns
        -------
            The average loss for all input data.
        """
        # bind the parent implementation outside the comprehension: zero-arg
        # ``super()`` is not usable inside a comprehension scope
        single_layer_loss = super().forward
        per_layer = [single_layer_loss(layer, target) for layer in input_]
        return torch.stack(per_layer).mean()
class MultiLayerSoftJaccardAndBCELogitsLoss(SoftJaccardAndBCEWithLogitsLoss):
    """Implement Equation 3 in [IGLOVIKOV-2018]_ for the multi-output networks.

    This loss can be used in networks that produce more than one output that
    has to match output targets.  For example, architectures such as
    :py:class:`.hed.HED` or :py:class:`.lwnet.LittleWNet` require this
    feature.

    Each output layer is scored with the parent class loss, and the
    per-layer losses are averaged.

    Parameters
    ----------
    alpha
        Determines the weighting of SoftJaccard and BCE. Default: ``0.7``.
    """

    def __init__(self, alpha: float = 0.7):
        super().__init__(alpha=alpha)

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        input_
            Value produced by the model to be evaluated, with the shape ``[L,
            n, c, h, w]``.  It is iterated over its first dimension, one
            output layer at a time.
        target
            Ground-truth information with the shape ``[n, c, h, w]``.

        Returns
        -------
            The average loss for all input data.
        """
        # bind the parent implementation outside the comprehension: zero-arg
        # ``super()`` is not usable inside a comprehension scope
        single_layer_loss = super().forward
        per_layer = [single_layer_loss(layer, target) for layer in input_]
        return torch.stack(per_layer).mean()