mednet.models.losses

Custom losses for different tasks.

Functions

pos_weight_for_bcewithlogitsloss(datamodule)

Generate the pos_weight argument for losses of type torch.nn.BCEWithLogitsLoss.

Classes

BCEWithLogitsLossWeightedPerBatch()

Calculates the binary cross entropy loss for every batch.

MOONBCEWithLogitsLoss([weights])

Calculates the weighted binary cross entropy loss based on [GGA24].

MultiLayerBCELogitsLossWeightedPerBatch()

Weighted Binary Cross-Entropy Loss for multi-layered inputs.

MultiLayerSoftJaccardAndBCELogitsLoss([alpha])

Implement Equation (3) of [ISBS18] for multi-output networks.

SoftJaccardAndBCEWithLogitsLoss([alpha])

Implement the generalized loss function of Equation (3) in [ISBS18].

mednet.models.losses.pos_weight_for_bcewithlogitsloss(datamodule)[source]

Generate the pos_weight argument for losses of type torch.nn.BCEWithLogitsLoss.

This function can generate the pos_weight parameters for both the training and validation losses, given a datamodule.

Parameters:

datamodule (ConcatDataModule) – The datamodule to probe for training and validation datasets.

Return type:

tuple[dict[str, Tensor], dict[str, Tensor]]

Returns:

A tuple containing the training and validation pos_weight arguments, each wrapped in a dictionary.
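
The exact computation is internal to this function, but the idea behind pos_weight can be illustrated with a short, self-contained sketch; the toy targets tensor below is purely illustrative and not part of this API:

import torch

# Sketch of the idea behind pos_weight (not this function's actual source):
# for each class, pos_weight is the ratio of negative to positive targets, so
# that rare positives are up-weighted by torch.nn.BCEWithLogitsLoss.
targets = torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 1.0], [0.0, 1.0]])  # [n, c]
positives = targets.sum(dim=0)
negatives = targets.shape[0] - positives
pos_weight = negatives / positives.clamp(min=1)  # guard against empty classes

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.randn(4, 2)
print(criterion(logits, targets))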

class mednet.models.losses.BCEWithLogitsLossWeightedPerBatch[source]

Bases: Module

Calculates the binary cross entropy loss for every batch.

This loss is similar to torch.nn.BCEWithLogitsLoss, except that it updates the pos_weight parameter (the ratio between negative and positive target pixels) for every batch, based on the accumulated target pixels across all samples in the batch.

Implements Equation 1 in [MPTAVG16]. The weight depends on the current proportion between negatives and positives in the ground-truth sample being analyzed.

forward(input_, target)[source]

Forward pass.

Parameters:
  • input_ (Tensor) – Logits produced by the model to be evaluated, with the shape [n, c] (classification), or [n, c, h, w] (segmentation).

  • target (Tensor) – Ground-truth information with the shape [n, c] (classification), or [n, c, h, w] (segmentation), containing zeroes and ones.

Return type:

Tensor

Returns:

The average loss for all input data.
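
A minimal sketch of the per-batch weighting idea follows; it is not the module's actual source, and the helper name bce_weighted_per_batch is a hypothetical stand-in:

import torch

def bce_weighted_per_batch(input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Recompute the negative/positive ratio from the current batch targets and
    # use it as pos_weight for the underlying BCE-with-logits computation.
    num_pos = target.sum()
    num_neg = target.numel() - num_pos
    pos_weight = num_neg / num_pos.clamp(min=1)
    return torch.nn.functional.binary_cross_entropy_with_logits(
        input_, target, pos_weight=pos_weight
    )

logits = torch.randn(8, 1, 16, 16)                 # e.g. segmentation logits
target = (torch.rand(8, 1, 16, 16) > 0.7).float()  # toy binary ground-truth
print(bce_weighted_per_batch(logits, target))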

class mednet.models.losses.SoftJaccardAndBCEWithLogitsLoss(alpha=0.7)[source]

Bases: Module

Implement the generalized loss function of Equation (3) at [ISBS18].

In the paper, the authors suggest a value of \(\alpha = 0.7\), which we set as the default for instances of this type.

\[L = \alpha H + (1-\alpha)(1-J)\]

Here, J is the soft Jaccard index (so 1-J is the Jaccard distance) and H the binary cross-entropy loss. Our implementation is based on torch.nn.BCEWithLogitsLoss.

Parameters:

alpha (float) – Determines the weighting of J and H. Default: 0.7.

forward(input_, target)[source]

Forward pass.

Parameters:
  • input_ (Tensor) – Logits produced by the model to be evaluated, with the shape [n, c] (classification), or [n, c, h, w] (segmentation).

  • target (Tensor) – Ground-truth information with the shape [n, c] (classification), or [n, c, h, w] (segmentation), containing zeroes and ones.

Return type:

Tensor

Returns:

Loss, in a single entry.
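
A self-contained sketch of Equation (3) is given below; it assumes a standard soft-Jaccard formulation (intersection over union of probabilities), and the actual class may differ in details such as smoothing constants:

import torch

def soft_jaccard_and_bce(input_, target, alpha=0.7, eps=1e-8):
    # H: binary cross-entropy on logits; J: soft Jaccard index on probabilities.
    probs = torch.sigmoid(input_)
    intersection = (probs * target).sum()
    union = probs.sum() + target.sum() - intersection
    j = intersection / (union + eps)
    h = torch.nn.functional.binary_cross_entropy_with_logits(input_, target)
    return alpha * h + (1 - alpha) * (1 - j)  # L = alpha*H + (1-alpha)*(1-J)

logits = torch.randn(4, 1, 32, 32)
target = (torch.rand(4, 1, 32, 32) > 0.5).float()
print(soft_jaccard_and_bce(logits, target))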

class mednet.models.losses.MultiLayerBCELogitsLossWeightedPerBatch[source]

Bases: BCEWithLogitsLossWeightedPerBatch

Weighted Binary Cross-Entropy Loss for multi-layered inputs.

This loss can be used in networks that produce more than one output that must match the output targets. For example, architectures such as hed.HED or lwnet.LittleWNet require this feature.

It follows the behaviour of the inherited superclass, applying on-the-fly pos_weight updates for every batch.

forward(input_, target)[source]

Forward pass.

Parameters:
  • input_ (Tensor) – Logits produced by the model to be evaluated, with the shape [n, c] (classification), or [n, c, h, w] (segmentation).

  • target (Tensor) – Ground-truth information with the shape [n, c] (classification), or [n, c, h, w] (segmentation), containing zeroes and ones.

Return type:

Tensor

Returns:

The average loss for all input data.
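
The multi-layer behaviour can be sketched as follows, assuming the per-layer losses are simply averaged (the actual class may weight or combine layers differently; multi_layer_loss is a hypothetical helper name):

import torch

def multi_layer_loss(inputs: list[torch.Tensor], target: torch.Tensor, base_loss) -> torch.Tensor:
    # Compare each output (e.g. the intermediate and final maps of HED or
    # Little W-Net) against the same target, then average the per-layer losses.
    losses = [base_loss(x, target) for x in inputs]
    return torch.stack(losses).mean()

base = torch.nn.BCEWithLogitsLoss()
outputs = [torch.randn(2, 1, 16, 16) for _ in range(4)]  # four network outputs
target = (torch.rand(2, 1, 16, 16) > 0.5).float()
print(multi_layer_loss(outputs, target, base))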

class mednet.models.losses.MultiLayerSoftJaccardAndBCELogitsLoss(alpha=0.7)[source]

Bases: SoftJaccardAndBCEWithLogitsLoss

Implement Equation (3) of [ISBS18] for multi-output networks.

This loss can be used in networks that produce more than one output that must match the output targets. For example, architectures such as hed.HED or lwnet.LittleWNet require this feature.

Parameters:

alpha (float) – Determines the weighting of SoftJaccard and BCE. Default: 0.7.

forward(input_, target)[source]

Forward pass.

Parameters:
  • input_ (Tensor) – Logits produced by the model to be evaluated, with the shape [n, c] (classification), or [n, c, h, w] (segmentation).

  • target (Tensor) – Ground-truth information with the shape [n, c] (classification), or [n, c, h, w] (segmentation), containing zeroes and ones.

Return type:

Tensor

Returns:

The average loss for all input data.

class mednet.models.losses.MOONBCEWithLogitsLoss(weights=None)[source]

Bases: Module

Calculates the weighted binary cross entropy loss based on [GGA24].

This loss implements the domain-adapted multitask loss function of Equation (2) in [GGA24]. The vector of input weights must be calculated from the input dataset in advance and set either during initialization or later, before the loss can be fully used.

Parameters:

weights (Tensor | None) – The positive weight of each class in the dataset given as input as a [2, C] tensor, with \(w_i^-\) at position 0, and \(w_i^+\) at position 1, as defined in Equation (1) of [GGA24].

classmethod get_arguments_from_datamodule(datamodule)[source]

Compute the MOON weights for train and validation sets of a datamodule.

This function takes a data.datamodule.ConcatDataModule and, for both the training and validation sets and for each class in the respective dataloader targets, computes the negative and positive weights as follows:

\[w_i^+ = \begin{cases} 1 & \text{if } S^-_i > S^+_i \\ \frac{S^-_i}{S^+_i} & \text{otherwise} \end{cases} \qquad w_i^- = \begin{cases} 1 & \text{if } S^+_i > S^-_i \\ \frac{S^+_i}{S^-_i} & \text{otherwise} \end{cases}\]

This weight vector is used during runtime to balance individual batch losses respecting individual class distributions.

Parameters:

datamodule (ConcatDataModule) – The datamodule to probe for training and validation datasets.

Return type:

tuple[dict[str, Tensor], dict[str, Tensor]]

Returns:

A tuple containing the training and validation weight arguments, each wrapped in a dictionary. Each weight entry contains the weights of each class in the target dataset as a [2, C] tensor, with \(w_i^-\) at position 0 and \(w_i^+\) at position 1, as defined in Equation (1) of [GGA24].
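
The weight computation above can be reproduced with a short sketch, assuming binary targets of shape [n, C]; the helper name moon_weights and the toy data are illustrative only:

import torch

def moon_weights(targets: torch.Tensor) -> torch.Tensor:
    # S+ and S- are the per-class counts of positive and negative samples; the
    # majority side keeps a weight of 1 and the other side is scaled by the ratio.
    s_pos = targets.sum(dim=0)
    s_neg = targets.shape[0] - s_pos
    w_pos = torch.where(s_neg > s_pos, torch.ones_like(s_pos), s_neg / s_pos)
    w_neg = torch.where(s_pos > s_neg, torch.ones_like(s_neg), s_pos / s_neg)
    return torch.stack([w_neg, w_pos])  # [2, C]: w- at row 0, w+ at row 1

targets = (torch.rand(100, 3) > 0.8).float()  # imbalanced toy dataset
print(moon_weights(targets))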

forward(input_, target)[source]

Forward pass.

This function takes the output of the model and a set of binary targets (as a float tensor containing zeroes and ones), and implements Equation (2) from [GGA24]:

\[\mathcal J = -\sum_{i=1}^M w_i^{t_i} \bigl[t_i\log f_i(x) + (1-t_i)\log (1-f_i(x)) \bigr]\]

Parameters:
  • input_ (Tensor) – Logits produced by the model to be evaluated, with the shape [n, c] (classification), or [n, c, h, w] (segmentation).

  • target (Tensor) – Ground-truth information with the shape [n, c] (classification), or [n, c, h, w] (segmentation), containing zeroes and ones.

Return type:

Tensor

Returns:

The result of Equation (2) from [GGA24].

Raises:

AssertionError – In case the weights have not been initialized by calling get_arguments_from_datamodule().
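
A sketch of Equation (2) follows, assuming weights is the [2, C] tensor described above and input_/target have shape [n, C]; whether the actual implementation sums or averages over the batch is not specified here, so the sketch averages:

import torch

def moon_bce(input_: torch.Tensor, target: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # Select w+ where the target is 1 and w- where it is 0, element-wise, then
    # weight the per-element binary cross-entropy terms accordingly.
    w = torch.where(target > 0.5, weights[1], weights[0])
    per_element = torch.nn.functional.binary_cross_entropy_with_logits(
        input_, target, weight=w, reduction="none"
    )
    return per_element.sum(dim=1).mean()  # sum over classes, average over batch

weights = torch.tensor([[1.0, 0.3], [0.4, 1.0]])  # toy [2, C] weight tensor
logits = torch.randn(8, 2)
target = (torch.rand(8, 2) > 0.5).float()
print(moon_bce(logits, target, weights))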

to(*args, **kwargs)[source]

Move the loss parameters to the specified device.

Refer to the method torch.nn.Module.to() for details.

Parameters:
  • *args (Any) – Parameters forwarded to the underlying implementation.

  • **kwargs (Any) – Parameters forwarded to the underlying implementation.

Return type:

Self

Returns:

Self.
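
Hedged usage sketch (the weight values below are toy numbers): because the MOON weights are stored as tensors inside the loss object, the loss should be moved to the same device as the model outputs before being called:

import torch

from mednet.models.losses import MOONBCEWithLogitsLoss

criterion = MOONBCEWithLogitsLoss(weights=torch.tensor([[1.0, 0.3], [0.4, 1.0]]))
if torch.cuda.is_available():
    criterion = criterion.to("cuda")  # moves the stored weights along with the module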