Source code for mednet.models.segment.losses

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Specialized losses for semanatic segmentation."""

import torch


[docs] class WeightedBCELogitsLoss(torch.nn.Module): """Calculates sum of weighted cross entropy loss. Implements Equation 1 in [MANINIS-2016]_. The weight depends on the current proportion between negatives and positives in the ground- truth sample being analyzed. """ def __init__(self): super().__init__()
[docs] def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- input_ Value produced by the model to be evaluated, with the shape ``[n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. Returns ------- The average loss for all input data. """ # calculates the proportion of negatives to the total number of pixels # available in the masked region num_pos = target.sum() return torch.nn.functional.binary_cross_entropy_with_logits( input_, target, reduction="mean", pos_weight=(input_.shape.numel() - num_pos) / num_pos, )
[docs] class SoftJaccardBCELogitsLoss(torch.nn.Module): r"""Implement the generalized loss function of Equation (3) in. [IGLOVIKOV-2018]_, with J being the Jaccard distance, and H, the Binary Cross-Entropy Loss: .. math:: L = \alpha H + (1-\alpha)(1-J) Our implementation is based on :py:class:`torch.nn.BCEWithLogitsLoss`. Parameters ---------- alpha Determines the weighting of J and H. Default: ``0.7``. """ def __init__(self, alpha: float = 0.7): super().__init__() self.alpha = alpha
[docs] def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- input_ Value produced by the model to be evaluated, with the shape ``[n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. Returns ------- Loss, in a single entry. """ eps = 1e-8 probabilities = torch.sigmoid(input_) intersection = (probabilities * target).sum() sums = probabilities.sum() + target.sum() j = intersection / (sums - intersection + eps) # this implements the support for looking just into the RoI h = torch.nn.functional.binary_cross_entropy_with_logits( input_, target, reduction="mean" ) return (self.alpha * h) + ((1 - self.alpha) * (1 - j))
[docs] class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss): """Weighted Binary Cross-Entropy Loss for multi-layered inputs (e.g. for Holistically-Nested Edge Detection in [XIE-2015]_). """ def __init__(self): super().__init__()
[docs] def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- input_ Value produced by the model to be evaluated, with the shape ``[L, n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. Returns ------- The average loss for all input data. """ return torch.cat( [ super(MultiWeightedBCELogitsLoss, self).forward(i, target).unsqueeze(0) for i in input_ ] ).mean()
[docs] class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss): """Implement Equation 3 in [IGLOVIKOV-2018]_ for the multi-output networks such as HED or Little W-Net. Parameters ---------- alpha : float Determines the weighting of SoftJaccard and BCE. Default: ``0.3``. """ def __init__(self, alpha: float = 0.7): super().__init__(alpha=alpha)
[docs] def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass. Parameters ---------- input_ Value produced by the model to be evaluated, with the shape ``[L, n, c, h, w]``. target Ground-truth information with the shape ``[n, c, h, w]``. Returns ------- The average loss for all input data. """ return torch.cat( [ super(MultiSoftJaccardBCELogitsLoss, self) .forward(i, target) .unsqueeze(0) for i in input_ ] ).mean()