mednet.models.segment.unet

UNet network architecture, from [RFB15].

Classes

UNetHead(in_channels_list[, pixel_shuffle])

UNet head module.

Unet([loss_type, loss_arguments, ...])

UNet network architecture, from [RFB15].

class mednet.models.segment.unet.UNetHead(in_channels_list, pixel_shuffle=False)

Bases: Module

UNet head module.

Parameters:
  • in_channels_list (list[int]) – Number of channels of each feature map returned by the backbone.

  • pixel_shuffle (bool) – If True, upsample using PixelShuffleICNR (sub-pixel convolution with ICNR initialization).
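A minimal sketch of how a generic UNet decoder might consume in_channels_list, assuming entries are visited deepest first and each deeper map is paired with the next shallower one, which acts as its skip connection. The DecoderStage and plan_decoder names, the pairing policy, and the example widths (the five VGG16 stage widths 64, 128, 256, 512, 512) are illustrative assumptions, not necessarily what UNetHead does internally. The pixel-shuffle branch uses the standard sub-pixel arithmetic: torch.nn.PixelShuffle with upscale factor r turns C * r**2 channels into C channels, so the preceding convolution must first expand the channel count.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DecoderStage:
    """One hypothetical decoder step of a generic UNet head."""

    deep_channels: int  # channels of the coarser map being upsampled
    skip_channels: int  # channels of the skip connection it is merged with
    pre_upsample_channels: int  # channels produced right before upsampling


def plan_decoder(in_channels_list, pixel_shuffle=False, scale=2):
    """Pair feature levels deepest-first (illustrative assumption only)."""
    if len(in_channels_list) < 2:
        raise ValueError("need at least two feature levels")
    deeper = list(reversed(in_channels_list[1:]))
    shallower = list(reversed(in_channels_list[:-1]))
    stages = []
    for deep, skip in zip(deeper, shallower):
        # PixelShuffle(scale) maps C * scale**2 channels to C channels,
        # so a sub-pixel upsampler must expand the channel count first.
        pre = deep * scale**2 if pixel_shuffle else deep
        stages.append(DecoderStage(deep, skip, pre))
    return stages


if __name__ == "__main__":
    vgg16_widths = [64, 128, 256, 512, 512]  # assumed backbone feature widths
    for stage in plan_decoder(vgg16_widths, pixel_shuffle=True):
        print(stage)
```

Visiting levels deepest first mirrors the usual UNet decoder, which repeatedly upsamples the coarsest map and merges it with the next finer skip connection until the highest-resolution level is reached.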

forward(x)

Forward pass.

Parameters:

x (list[Tensor]) – List of tensors as returned from the backbone network. First element: height and width of input image. Remaining elements: feature maps for each feature level.

Returns:

Output of the forward pass.
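A minimal sketch of the input layout described above, using plain shape tuples instead of tensors so it runs without PyTorch. The head_input_strides helper is hypothetical; it assumes each feature map is laid out as (N, C, H, W) and that its channel count matches the corresponding entry of in_channels_list. The example shapes (a 544x544 input and feature maps at strides 1 to 16) are illustrative only.

```python
def head_input_strides(x_shapes, in_channels_list):
    """Check a UNetHead-style input list and return each level's stride.

    x_shapes[0] is the (height, width) of the input image; the remaining
    entries are the (N, C, H, W) shapes of the backbone feature maps.
    """
    (height, width), *features = x_shapes
    if len(features) != len(in_channels_list):
        raise ValueError("expected one feature map per in_channels_list entry")
    batch_size = features[0][0]
    strides = []
    for level, ((n, c, h, w), expected_c) in enumerate(zip(features, in_channels_list)):
        if n != batch_size:
            raise ValueError(f"level {level}: batch size {n} != {batch_size}")
        if c != expected_c:
            raise ValueError(f"level {level}: {c} channels, expected {expected_c}")
        if h > height or w > width:
            raise ValueError(f"level {level}: feature map larger than the input")
        strides.append(height // h)  # downsampling factor relative to the input
    return strides


shapes = [
    (544, 544),  # input image height and width
    (2, 64, 544, 544),
    (2, 128, 272, 272),
    (2, 256, 136, 136),
    (2, 512, 68, 68),
    (2, 512, 34, 34),
]
print(head_input_strides(shapes, [64, 128, 256, 512, 512]))  # [1, 2, 4, 8, 16]
```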

class mednet.models.segment.unet.Unet(loss_type=mednet.models.losses.SoftJaccardAndBCEWithLogitsLoss, loss_arguments=None, optimizer_type=torch.optim.Adam, optimizer_arguments=None, scheduler_type=None, scheduler_arguments=None, model_transforms=None, augmentation_transforms=None, pretrained=False)

Bases: Model

UNet network architecture, from [RFB15].

Parameters:
  • loss_type (type[Module]) –

    The loss to be used for training and evaluation.

    Warning

    The loss should always return batch averages (as opposed to batch sums), because our logging system expects it.

  • loss_arguments (dict[str, Any] | None) – Arguments to the loss.

  • optimizer_type (type[Optimizer]) – The type of optimizer to use for training.

  • optimizer_arguments (dict[str, Any] | None) – Keyword arguments passed to the optimizer, after the model parameters (params).

  • scheduler_type (type[LRScheduler] | None) – The type of scheduler to use for training.

  • scheduler_arguments (dict[str, Any] | None) – Keyword arguments passed to the scheduler, after its first (optimizer) argument.

  • model_transforms (Optional[Sequence[Callable[[Tensor], Tensor]]]) – An optional sequence of torch modules containing transforms to be applied on the input before it is fed into the network.

  • augmentation_transforms (Optional[Sequence[Callable[[Tensor], Tensor]]]) – An optional sequence of torch modules containing data augmentation transforms to be applied on training inputs before they are fed into the network.

  • pretrained (bool) – If True, initialize the backbone with pretrained VGG16 weights.
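The warning on loss_type can be illustrated with a toy binary cross-entropy on logits, written in plain Python. With mean reduction the value does not depend on the batch size, so losses logged for batches of different sizes stay comparable; with sum reduction, duplicating the batch doubles the loss. ToyBCEWithLogitsLoss and build_loss are hypothetical: build_loss only assumes that loss_arguments are passed as keyword arguments to loss_type, and neither stands in for the library's actual SoftJaccardAndBCEWithLogitsLoss.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable BCE on one logit: max(z, 0) - z*y + log(1 + exp(-|z|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


class ToyBCEWithLogitsLoss:
    """Hypothetical loss: per-sample pixel mean, then batch mean or sum."""

    def __init__(self, reduction="mean"):
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unsupported reduction: {reduction}")
        self.reduction = reduction

    def __call__(self, logits, targets):
        # logits/targets: one list of pixel values per sample in the batch
        per_sample = [
            sum(bce_with_logits(z, y) for z, y in zip(zs, ys)) / len(zs)
            for zs, ys in zip(logits, targets)
        ]
        total = sum(per_sample)
        return total / len(per_sample) if self.reduction == "mean" else total


def build_loss(loss_type, loss_arguments=None):
    """Assumed wiring: instantiate loss_type with loss_arguments as kwargs."""
    return loss_type(**(loss_arguments or {}))


small_logits = [[2.0, -1.0], [0.5, -3.0]]
small_targets = [[1.0, 0.0], [1.0, 0.0]]
big_logits, big_targets = small_logits * 2, small_targets * 2  # same samples, twice

mean_loss = build_loss(ToyBCEWithLogitsLoss)  # reduction="mean" by default
sum_loss = build_loss(ToyBCEWithLogitsLoss, {"reduction": "sum"})
print(mean_loss(small_logits, small_targets), mean_loss(big_logits, big_targets))
print(sum_loss(small_logits, small_targets), sum_loss(big_logits, big_targets))
```

The first line prints the same value twice, while the second line's values differ by a factor of two, which is why a batch-summing loss would make logged values depend on the batch size.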

forward(x)

Forward pass.

Parameters:
  • x (Tensor) – Batch of input images.

Returns:

Output of the network for the input batch.
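Because the default loss, SoftJaccardAndBCEWithLogitsLoss, operates on logits, it seems reasonable to assume that forward() returns unnormalized scores rather than probabilities. Under that assumption, a binary mask could be derived by applying a sigmoid and thresholding the result. The logits_to_mask helper and the 0.5 threshold are illustrative choices, written with nested lists instead of tensors; with PyTorch, torch.sigmoid(output) >= 0.5 expresses the same operation.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logits_to_mask(logits, threshold=0.5):
    """Turn a 2-D grid of logits into a 0/1 mask (hypothetical helper)."""
    return [[1 if sigmoid(z) >= threshold else 0 for z in row] for row in logits]


print(logits_to_mask([[-2.0, 0.0], [3.0, -0.1]]))  # [[0, 1], [1, 0]]
```

With a 0.5 threshold this is equivalent to testing logit >= 0, since the sigmoid is monotonic and equals 0.5 at zero; an explicit sigmoid only matters when a different threshold is chosen.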