mednet.models.losses
Custom losses for different tasks.
Functions
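pos_weight_for_bcewithlogitsloss(datamodule) | Generate the pos_weight argument for losses of type torch.nn.BCEWithLogitsLoss.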
Classes
BCEWithLogitsLossWeightedPerBatch | Calculates the binary cross entropy loss for every batch.
MOONBCEWithLogitsLoss | Calculates the weighted binary cross entropy loss based on [GGA24].
MultiLayerBCELogitsLossWeightedPerBatch | Weighted Binary Cross-Entropy Loss for multi-layered inputs.
MultiLayerSoftJaccardAndBCELogitsLoss | Implement Equation 3 in [ISBS18] for multi-output networks.
SoftJaccardAndBCEWithLogitsLoss | Implement the generalized loss function of Equation (3) in [ISBS18].
- mednet.models.losses.pos_weight_for_bcewithlogitsloss(datamodule)
Generate the pos_weight argument for losses of type torch.nn.BCEWithLogitsLoss.
This function can generate the pos_weight parameters for both the training and validation losses, given a datamodule.
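To illustrate what such a pos_weight represents, the framework-free sketch below computes, for each class, the ratio of negative to positive samples. This follows the convention documented for torch.nn.BCEWithLogitsLoss (for example, 300 negatives and 100 positives give a weight of 3). The helper compute_pos_weight and its list-of-labels input are hypothetical: the real function works on a datamodule, and its exact computation may differ.

```python
from collections.abc import Sequence


def compute_pos_weight(targets: Sequence[Sequence[int]]) -> list[float]:
    """Hypothetical helper: per-class ratio of negative to positive samples.

    ``targets`` holds one binary label vector (length C) per sample.
    """
    num_samples = len(targets)
    num_classes = len(targets[0])
    weights = []
    for c in range(num_classes):
        positives = sum(sample[c] for sample in targets)
        if positives == 0:
            # The ratio is undefined without positives, so fail loudly here.
            raise ValueError(f"class {c} has no positive samples")
        negatives = num_samples - positives
        weights.append(negatives / positives)
    return weights


# Four samples, two classes: class 0 has 1 positive, class 1 has 2.
labels = [[1, 0], [0, 1], [0, 1], [0, 0]]
print(compute_pos_weight(labels))  # [3.0, 1.0]
```

Classes with few positives receive a weight above 1, so their positive terms count more in the loss.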
- class mednet.models.losses.BCEWithLogitsLossWeightedPerBatch
Bases: Module
Calculates the binary cross entropy loss for every batch.
This loss is similar to torch.nn.BCEWithLogitsLoss, except that it updates the pos_weight parameter (the ratio between negative and positive target pixels) of the loss term for every batch, based on the accumulated target pixels of all samples in the batch. Implements Equation 1 in [MPTAVG16]. The weight therefore depends on the current proportion between negatives and positives in the ground truth of the batch being analyzed.
- forward(input_, target)
Forward pass.
- Parameters:
input_ – Logits produced by the model to be evaluated, with shape [n, c] (classification) or [n, c, h, w] (segmentation).
target (Tensor) – Ground-truth information with shape [n, c] (classification) or [n, c, h, w] (segmentation), containing zeroes and ones.
- Return type:
- Returns:
The average loss for all input data.
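The per-batch weighting described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the library's implementation. It takes flattened lists of logits and binary targets covering the whole batch, uses a hypothetical function name, and falls back to a weight of 1.0 when the batch has no positive targets (an assumed guard). Each term is the BCEWithLogitsLoss formula with pos_weight, -[w t log σ(x) + (1 - t) log(1 - σ(x))], written with a numerically stable softplus.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits_weighted_per_batch(logits: list[float], targets: list[float]) -> float:
    """Mean BCE-with-logits whose pos_weight is recomputed from this batch.

    ``logits`` and ``targets`` are flattened over every sample (and pixel)
    in the batch; targets contain zeroes and ones.
    """
    num_pos = sum(targets)
    num_neg = len(targets) - num_pos
    # Assumed guard: without positives the ratio is undefined, so use 1.0.
    pos_weight = num_neg / num_pos if num_pos > 0 else 1.0
    total = 0.0
    for x, t in zip(logits, targets):
        # -log(sigmoid(x)) == softplus(-x); -log(1 - sigmoid(x)) == softplus(x)
        total += pos_weight * t * softplus(-x) + (1 - t) * softplus(x)
    return total / len(targets)


# One positive and three negatives in the batch, so pos_weight == 3.
print(bce_with_logits_weighted_per_batch([0.0, 0.0, 0.0, 0.0], [1, 0, 0, 0]))  # about 1.0397 (1.5 * ln 2)
```

Because pos_weight is recomputed from each batch, the same prediction can yield different losses depending on how many positives happen to be in the batch.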
- class mednet.models.losses.SoftJaccardAndBCEWithLogitsLoss(alpha=0.7)
Bases: Module
Implement the generalized loss function of Equation (3) in [ISBS18].
In the paper, the authors suggest a value of \(\alpha = 0.7\), which we set as the default for instances of this type.
\[L = \alpha H + (1-\alpha)(1-J)\]
where J is the (soft) Jaccard index, so that \(1-J\) is the Jaccard distance, and H is the Binary Cross-Entropy Loss. Our implementation is based on torch.nn.BCEWithLogitsLoss.
- Parameters:
alpha (float) – Weight of the binary cross-entropy term H; the Jaccard term is weighted by \(1-\alpha\). Default: 0.7.
- forward(input_, target)
Forward pass.
- Parameters:
input_ – Logits produced by the model to be evaluated, with shape [n, c] (classification) or [n, c, h, w] (segmentation).
target (Tensor) – Ground-truth information with shape [n, c] (classification) or [n, c, h, w] (segmentation), containing zeroes and ones.
- Return type:
- Returns:
Loss, in a single entry.
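The combination \(L = \alpha H + (1-\alpha)(1-J)\) can be sketched as below. This sketch makes several assumptions: it computes a single global soft Jaccard index over the whole flattened batch, with sigmoid probabilities in place of hard labels, and it adds a small hypothetical eps to the denominator. The library may reduce or stabilize these terms differently.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def soft_jaccard_and_bce(logits, targets, alpha=0.7, eps=1e-7):
    """L = alpha * H + (1 - alpha) * (1 - J) over a flattened batch."""
    probs = [sigmoid(x) for x in logits]
    # H: unweighted binary cross-entropy, averaged over all entries.
    bce = sum(t * softplus(-x) + (1 - t) * softplus(x) for x, t in zip(logits, targets)) / len(targets)
    # J: soft Jaccard index, using probabilities instead of hard labels.
    intersection = sum(p * t for p, t in zip(probs, targets))
    union = sum(probs) + sum(targets) - intersection
    jaccard = intersection / (union + eps)
    return alpha * bce + (1 - alpha) * (1 - jaccard)


# Undecided logits: H = ln 2 and J = 0.5 / 1.5 = 1/3.
print(soft_jaccard_and_bce([0.0, 0.0], [1, 0]))  # about 0.6852 (0.7 * ln 2 + 0.3 * 2/3)
```

Confident, correct logits drive both H and \(1-J\) toward zero. The Jaccard term adds a region-overlap signal that plain BCE lacks.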
- class mednet.models.losses.MultiLayerBCELogitsLossWeightedPerBatch
Bases: BCEWithLogitsLossWeightedPerBatch
Weighted Binary Cross-Entropy Loss for multi-layered inputs.
This loss can be used in networks that produce more than one output, each of which has to match the output targets. For example, architectures such as hed.HED or lwnet.LittleWNet require this feature. Like its superclass, it updates pos_weight on the fly for every batch.
- forward(input_, target)
Forward pass.
- Parameters:
input_ – Logits produced by the model to be evaluated, with shape [n, c] (classification) or [n, c, h, w] (segmentation).
target (Tensor) – Ground-truth information with shape [n, c] (classification) or [n, c, h, w] (segmentation), containing zeroes and ones.
- Return type:
- Returns:
The average loss for all input data.
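A multi-output loss of this kind can be sketched as the single-output, per-batch weighted loss applied to every output and then averaged. The averaging is an assumption based on the "average loss" return description. It also assumes that every output is compared against the same target, and the function names and list-based inputs are hypothetical.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_bce_per_batch(logits, targets):
    """Single-output loss: BCE-with-logits with pos_weight from this batch."""
    num_pos = sum(targets)
    # Assumed guard for batches without positives.
    pos_weight = (len(targets) - num_pos) / num_pos if num_pos > 0 else 1.0
    total = sum(pos_weight * t * softplus(-x) + (1 - t) * softplus(x) for x, t in zip(logits, targets))
    return total / len(targets)


def multi_layer_loss(outputs, targets):
    """Average the single-output loss over every network output."""
    losses = [weighted_bce_per_batch(logits, targets) for logits in outputs]
    return sum(losses) / len(losses)


# Two outputs (e.g. side outputs of intermediate layers) against one target.
side_outputs = [[0.0, 0.0, 0.0, 0.0], [2.0, -2.0, -2.0, -2.0]]
print(multi_layer_loss(side_outputs, [1, 0, 0, 0]))
```

Averaging keeps the loss on the same scale as a single-output loss, whatever the number of outputs. A sum would grow with the number of side outputs.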
- class mednet.models.losses.MultiLayerSoftJaccardAndBCELogitsLoss(alpha=0.7)
Bases: SoftJaccardAndBCEWithLogitsLoss
Implement Equation 3 in [ISBS18] for multi-output networks.
This loss can be used in networks that produce more than one output, each of which has to match the output targets. For example, architectures such as hed.HED or lwnet.LittleWNet require this feature.
- Parameters:
alpha (float) – Weight of the BCE term; the SoftJaccard term is weighted by \(1-\alpha\). Default: 0.7.
- forward(input_, target)
Forward pass.
- Parameters:
input_ – Logits produced by the model to be evaluated, with shape [n, c] (classification) or [n, c, h, w] (segmentation).
target (Tensor) – Ground-truth information with shape [n, c] (classification) or [n, c, h, w] (segmentation), containing zeroes and ones.
- Return type:
- Returns:
The average loss for all input data.
- class mednet.models.losses.MOONBCEWithLogitsLoss(weights=None)
Bases: Module
Calculates the weighted binary cross entropy loss based on [GGA24].
This loss implements the domain-adapted multitask loss function in Equation (2) of [GGA24]. The vector of input weights must be computed from the input dataset in advance and set either during initialization or later, before the loss can be used.
- Parameters:
weights (Tensor | None) – The negative and positive weights of each class in the input dataset, given as a [2, C] tensor, with \(w_i^-\) at position 0 and \(w_i^+\) at position 1, as defined in Equation (1) of [GGA24].
- classmethod get_arguments_from_datamodule(datamodule)
Compute the MOON weights for the training and validation sets of a datamodule.
This function takes a data.datamodule.ConcatDataModule and, for both the training and validation sets, computes a negative and a positive weight for each class in the respective dataloader targets:
\[w_i^+ = \begin{cases} 1 & \text{if } S^{-}_{i} > S^{+}_{i} \\ \frac{S^{-}_{i}}{S^{+}_{i}} & \text{otherwise} \end{cases} \qquad w_i^- = \begin{cases} 1 & \text{if } S^{+}_{i} > S^{-}_{i} \\ \frac{S^{+}_{i}}{S^{-}_{i}} & \text{otherwise} \end{cases}\]
where \(S^{+}_{i}\) and \(S^{-}_{i}\) are the numbers of positive and negative samples of class \(i\). This weight vector is used at runtime to balance individual batch losses according to the individual class distributions.
- Parameters:
datamodule (ConcatDataModule) – The datamodule to probe for training and validation datasets.
- Return type:
- Returns:
A tuple containing the training and validation weight arguments, wrapped in a dictionary. Each weight variable contains the weights of each class in the target dataset as a [2, C] tensor, with \(w_i^-\) at position 0 and \(w_i^+\) at position 1, as defined in Equation (1) of [GGA24].
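The case analysis above can be checked with a small plain-Python sketch that counts \(S^{+}_{i}\) and \(S^{-}_{i}\) per class and returns the weights in the [2, C] layout (negative weights in row 0, positive weights in row 1). The helper moon_weights and its list-of-labels input are hypothetical. The real method reads targets from the datamodule's dataloaders and returns tensors.

```python
def moon_weights(targets):
    """Hypothetical helper: [w_minus, w_plus] per class (a 2 x C layout).

    ``targets`` holds one binary label vector (length C) per sample.
    Row 0 holds the negative weights, row 1 the positive weights.
    """
    num_classes = len(targets[0])
    w_neg, w_pos = [], []
    for i in range(num_classes):
        s_pos = sum(sample[i] for sample in targets)  # S_i^+
        s_neg = len(targets) - s_pos  # S_i^-
        # The minority side keeps weight 1; the majority side is scaled down.
        w_pos.append(1.0 if s_neg > s_pos else s_neg / s_pos)
        w_neg.append(1.0 if s_pos > s_neg else s_pos / s_neg)
    return [w_neg, w_pos]


labels = [[1, 0, 1], [0, 1, 1], [0, 1, 1], [0, 0, 0]]
# Class 0 has rare positives (w- = 1/3), class 1 is balanced,
# class 2 has common positives (w+ = 1/3).
print(moon_weights(labels))
```

Weights never exceed 1. MOON rebalances by scaling down the majority side rather than boosting the minority side.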
- forward(input_, target)
Forward pass.
This function takes the output of the model and a set of binary targets (as a float tensor containing zeroes and ones), and implements Equation (2) from [GGA24]:
\[\mathcal J = -\sum_{i=1}^M w_i^{t_i} \bigl[t_i\log f_i(x) + (1-t_i)\log (1-f_i(x)) \bigr]\]
where \(M\) is the number of classes, \(t_i\) is the target of class \(i\), \(f_i(x)\) is the predicted probability of class \(i\), and \(w_i^{t_i}\) is \(w_i^+\) when \(t_i = 1\) and \(w_i^-\) when \(t_i = 0\).
- Parameters:
input_ – Logits produced by the model to be evaluated, with shape [n, c] (classification) or [n, c, h, w] (segmentation).
target (Tensor) – Ground-truth information with shape [n, c] (classification) or [n, c, h, w] (segmentation), containing zeroes and ones.
- Return type:
- Returns:
The result of Equation (2) from [GGA24].
- Raises:
AssertionError – If the weights have not been initialized by calling get_arguments_from_datamodule().
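Equation (2) can be sketched per sample in plain Python, with \(f_i(x)\) taken as the sigmoid of the logit and the weight selected by the target value. The equation itself defines only the per-sample sum over classes. Averaging over the samples of a batch, as done here, is an assumption, and the function name and list-based inputs are hypothetical.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def moon_loss(logits, targets, weights):
    """Equation (2) per sample, then averaged over the batch (assumed).

    ``logits`` and ``targets`` hold one length-M list per sample;
    ``weights`` is [w_minus, w_plus] with M entries per row.
    """
    w_neg, w_pos = weights
    total = 0.0
    for sample_logits, sample_targets in zip(logits, targets):
        for i, (x, t) in enumerate(zip(sample_logits, sample_targets)):
            # w_i^{t_i}: positive weight if t_i = 1, negative weight if t_i = 0.
            w = w_pos[i] if t == 1 else w_neg[i]
            # -[t log f(x) + (1 - t) log(1 - f(x))] with f = sigmoid.
            total += w * (t * softplus(-x) + (1 - t) * softplus(x))
    return total / len(logits)


weights = [[1 / 3, 1.0], [1.0, 1.0]]  # class 0 has rare positives
print(moon_loss([[0.0, 0.0]], [[0, 1]], weights))  # about 0.9242 ((1/3 + 1) * ln 2)
```

In this example, the negative target of class 0 contributes only a third of its unweighted BCE. This is how the weights counter the dominance of that class's abundant negatives.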
- to(*args, **kwargs)
Move loss parameters to the specified device.
Refer to the method torch.nn.Module.to() for details.