mednet.models.classify.loss_weights

Helpers for computing sample or label weights for loss terms.

Functions

make_balanced_bcewithlogitsloss(dataloader)

Return a balanced binary-cross-entropy loss.

mednet.models.classify.loss_weights.make_balanced_bcewithlogitsloss(dataloader)

Return a balanced binary-cross-entropy loss.

The loss is weighted using the ratio between the number of positive examples and the total number of examples available in the DataLoader.
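The weighting described above comes down to counting labels across the whole DataLoader. The following is a minimal sketch, not mednet's implementation, and the exact formula mednet uses may differ. It assumes binary 0/1 labels stored under a hypothetical "target" key in the second mapping of each batch. The helper name positive_ratio_and_pos_weight is also hypothetical. It returns the positives-to-total ratio named above, plus the negatives-to-positives ratio that is a common choice for the pos_weight argument of torch.nn.BCEWithLogitsLoss. A plain list of batches stands in for the DataLoader, since both are iterables.

```python
from collections.abc import Iterable, Mapping
from typing import Any


def positive_ratio_and_pos_weight(
    batches: Iterable[tuple[Mapping[str, Any], Mapping[str, Any]]],
) -> tuple[float, float]:
    """Hypothetical helper: count binary labels over all batches.

    Assumes each batch is an (inputs, metadata) pair whose metadata holds a
    sequence of 0/1 labels under the (assumed) key "target".
    Returns (positives / total, negatives / positives).
    """
    positives = 0
    total = 0
    for _, metadata in batches:
        for label in metadata["target"]:
            positives += int(label)  # works for ints, floats or 0-d tensors
            total += 1
    if positives == 0 or positives == total:
        raise ValueError("need at least one positive and one negative label")
    negatives = total - positives
    return positives / total, negatives / positives


# Toy stand-in for a DataLoader: two batches, 1 positive out of 5 labels.
toy_batches = [
    ({}, {"target": [0, 1, 0]}),
    ({}, {"target": [0, 0]}),
]
ratio, pos_weight = positive_ratio_and_pos_weight(toy_batches)
print(ratio, pos_weight)  # 0.2 4.0
```

With one positive among five labels, the positive ratio is 0.2 and a negatives-to-positives weight would scale each positive term by 4.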

Parameters:

dataloader (DataLoader[tuple[Mapping[str, Any], Mapping[str, Any]]]) – The DataLoader from which the BCE weights are computed.

Returns:

An instance of the weighted loss.

Return type:

torch.nn.BCEWithLogitsLoss
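The returned object is a standard torch.nn.BCEWithLogitsLoss, so it is applied to raw logits (no sigmoid) and float targets. The following is a minimal sketch that builds an equivalent loss by hand with an illustrative pos_weight of 4.0. That value is an assumption, and this page does not state whether mednet passes its weights through pos_weight or weight. It only shows how such a weighted loss behaves once constructed.

```python
import torch

# Illustrative (assumed) weight: 4 negatives for every positive.
# Not necessarily the value or formula mednet computes.
pos_weight = torch.tensor([4.0])
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.tensor([0.5, -1.0, 2.0])  # raw model outputs, no sigmoid applied
targets = torch.tensor([1.0, 0.0, 1.0])  # binary labels as floats
loss = loss_fn(logits, targets)  # positive terms are scaled by pos_weight
print(loss.item())
```

Because pos_weight multiplies only the positive-label terms, a value above 1 increases the penalty for missed positives relative to an unweighted loss on the same batch.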