mednet.models.segment.m2unet
Mobile2 UNet network architecture, from [LAIBACHER-2018].
Classes
- DecoderBlock – Decoder block: upsample and concatenate with feature maps from the encoder part.
- LastDecoderBlock – Last decoder block.
- M2UNetHead – M2U-Net head module.
- M2Unet – Mobile2 UNet network architecture, from [LAIBACHER-2018].
- class mednet.models.segment.m2unet.DecoderBlock(up_in_c, x_in_c, upsamplemode='bilinear', expand_ratio=0.15)
Bases: torch.nn.Module
Decoder block: upsample and concatenate with feature maps from the encoder part.
- Parameters:
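up_in_c – Number of channels of up_in, the input that is upsampled.
x_in_c – Number of channels of x_in, the encoder feature map that is concatenated (cat) with the upsampled input.
upsamplemode – Interpolation mode used for upsampling.
expand_ratio – Channel expansion ratio used inside the block.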
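The parameter list does not say how expand_ratio acts. M2U-Net is described as using MobileNetV2-style inverted residual blocks in its decoder; assuming DecoderBlock follows that design, the sketch below shows what such a ratio does. InvertedResidualSketch is a hypothetical name, not mednet's class: the ratio scales the width of the hidden (depthwise) layer, so the default of 0.15 shrinks it rather than expanding it.

```python
import torch
from torch import nn


class InvertedResidualSketch(nn.Module):
    """Hypothetical MobileNetV2-style inverted residual block (not mednet's class).

    ``expand_ratio`` scales the number of hidden channels; values below 1,
    such as the documented default of 0.15, shrink them instead.
    """

    def __init__(self, in_c: int, out_c: int, expand_ratio: float = 0.15):
        super().__init__()
        # Hidden width is the input width scaled by the ratio (at least 1 channel).
        self.hidden_c = max(1, round(in_c * expand_ratio))
        h = self.hidden_c
        # A residual shortcut is only possible when the widths match.
        self.use_residual = in_c == out_c
        self.block = nn.Sequential(
            # 1x1 pointwise conv: in_c -> hidden channels.
            nn.Conv2d(in_c, h, kernel_size=1, bias=False),
            nn.BatchNorm2d(h),
            nn.ReLU6(inplace=True),
            # 3x3 depthwise conv: one filter per hidden channel.
            nn.Conv2d(h, h, kernel_size=3, padding=1, groups=h, bias=False),
            nn.BatchNorm2d(h),
            nn.ReLU6(inplace=True),
            # 1x1 linear projection to out_c (no activation afterwards).
            nn.Conv2d(h, out_c, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_c),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.block(x)
        return x + out if self.use_residual else out


block = InvertedResidualSketch(in_c=40, out_c=20, expand_ratio=0.15)
y = block(torch.randn(2, 40, 32, 32))
print(block.hidden_c, tuple(y.shape))  # 6 (2, 20, 32, 32)
```

With 40 input channels and a ratio of 0.15, the depthwise layer works on only 6 channels, which is where a small ratio saves computation.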
- forward(up_in, x_in)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
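The upsample-and-concatenate step named in the class description can be sketched in plain PyTorch. DecoderBlockSketch is hypothetical: it assumes a 2x upsampling factor, and it uses a 1x1 convolution after concatenation as a stand-in for whatever fusion layer the real DecoderBlock applies.

```python
import torch
from torch import nn


class DecoderBlockSketch(nn.Module):
    """Hypothetical decoder step: upsample, concatenate, then fuse (not mednet's class)."""

    def __init__(self, up_in_c: int, x_in_c: int, upsamplemode: str = "bilinear"):
        super().__init__()
        # align_corners may only be set for interpolating modes such as "bilinear".
        non_interp = ("nearest", "nearest-exact", "area")
        align = None if upsamplemode in non_interp else False
        self.upsample = nn.Upsample(scale_factor=2, mode=upsamplemode, align_corners=align)
        # Stand-in fusion: a 1x1 conv halving the concatenated channel count.
        cat_c = up_in_c + x_in_c
        self.fuse = nn.Conv2d(cat_c, cat_c // 2, kernel_size=1)

    def forward(self, up_in: torch.Tensor, x_in: torch.Tensor) -> torch.Tensor:
        # Double the spatial size of the coarser decoder input...
        up = self.upsample(up_in)
        # ...so it matches the encoder skip features, then stack along channels.
        cat = torch.cat([up, x_in], dim=1)
        return self.fuse(cat)


dec = DecoderBlockSketch(up_in_c=96, x_in_c=32)
up_in = torch.randn(1, 96, 16, 16)  # coarse decoder input
x_in = torch.randn(1, 32, 32, 32)   # encoder skip features at twice the resolution
print(tuple(dec(up_in, x_in).shape))  # (1, 64, 32, 32)
```

The skip tensor x_in must already have twice the spatial size of up_in; otherwise torch.cat fails on mismatched shapes.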
- class mednet.models.segment.m2unet.LastDecoderBlock(x_in_c, upsamplemode='bilinear', expand_ratio=0.15)
Bases: torch.nn.Module
Last decoder block.
- Parameters:
x_in_c – Number of concatenated (cat) channels.
upsamplemode – Interpolation mode used for upsampling.
expand_ratio – Channel expansion ratio used inside the block.
- forward(up_in, x_in)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class mednet.models.segment.m2unet.M2UNetHead(in_channels_list=None, upsamplemode='bilinear', expand_ratio=0.15)
Bases: torch.nn.Module
M2U-Net head module.
- Parameters:
in_channels_list – Number of channels of each feature map returned by the backbone.
upsamplemode – Interpolation mode used for upsampling.
expand_ratio – Channel expansion ratio used inside the decoder blocks.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
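in_channels_list gives one width per backbone feature map. A U-Net style head that walks such a list from the deepest map outwards can be sketched as follows. HeadSketch and _UpCat are hypothetical names; the widths [16, 24, 32, 96], the halving of resolution between consecutive maps, the 1x1 fusion convolutions and the final 2x upsample are all assumptions chosen for illustration, not mednet's defaults or implementation.

```python
import torch
from torch import nn


class _UpCat(nn.Module):
    """Upsample ``up_in`` by 2, concatenate with ``x_in``, fuse with a 1x1 conv."""

    def __init__(self, up_in_c: int, x_in_c: int, out_c: int):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.fuse = nn.Conv2d(up_in_c + x_in_c, out_c, kernel_size=1)

    def forward(self, up_in: torch.Tensor, x_in: torch.Tensor) -> torch.Tensor:
        return self.fuse(torch.cat([self.upsample(up_in), x_in], dim=1))


class HeadSketch(nn.Module):
    """Hypothetical head over backbone feature maps ordered shallow -> deep.

    Each map is assumed to be half the spatial size of the previous one.
    """

    def __init__(self, in_channels_list: list[int], num_classes: int = 1):
        super().__init__()
        blocks = []
        up_c = in_channels_list[-1]
        # Walk from the deepest map towards the shallowest one.
        for skip_c in reversed(in_channels_list[:-1]):
            blocks.append(_UpCat(up_c, skip_c, out_c=skip_c))
            up_c = skip_c
        self.blocks = nn.ModuleList(blocks)
        # Final 2x upsample, then per-pixel class logits.
        self.final_up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.classifier = nn.Conv2d(up_c, num_classes, kernel_size=1)

    def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
        x = feats[-1]
        for block, skip in zip(self.blocks, reversed(feats[:-1])):
            x = block(x, skip)
        return self.classifier(self.final_up(x))


channels = [16, 24, 32, 96]  # assumed example widths, not mednet's defaults
feats = [torch.randn(1, c, 64 // 2**i, 64 // 2**i) for i, c in enumerate(channels)]
logits = HeadSketch(channels)(feats)
print(tuple(logits.shape))  # (1, 1, 128, 128)
```

One decoder step is created per skip connection, so four feature maps yield three upsample-and-concatenate blocks.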
- class mednet.models.segment.m2unet.M2Unet(loss_type=<class 'mednet.models.segment.losses.SoftJaccardAndBCEWithLogitsLoss'>, loss_arguments={}, optimizer_type=<class 'torch.optim.adam.Adam'>, optimizer_arguments={}, scheduler_type=None, scheduler_arguments={}, model_transforms=[], augmentation_transforms=[], num_classes=1, pretrained=False)
Bases: Model
Mobile2 UNet network architecture, from [LAIBACHER-2018].
- Parameters:
loss_type – The loss to be used for training and evaluation.
Warning
The loss should be set to always return batch averages (as opposed to the batch sum), since the logging system expects it.
optimizer_type (type[Optimizer]) – The type of optimizer to use for training.
optimizer_arguments (dict[str, Any]) – Arguments to the optimizer, passed after params.
scheduler_type (type[LRScheduler] | None) – The type of scheduler to use for training.
scheduler_arguments (dict[str, Any]) – Arguments to the scheduler, passed after params.
model_transforms (Sequence[Callable[[Tensor], Tensor]]) – An optional sequence of torch modules containing transforms to be applied on the input before it is fed into the network.
augmentation_transforms (Sequence[Callable[[Tensor], Tensor]]) – An optional sequence of torch modules containing data augmentation transforms to be applied on the input before it is fed into the network.
num_classes (int) – Number of outputs (classes) for this model.
pretrained (bool) – If True, will use ImageNet-pretrained MobileNetV2 weights for the encoder.
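The warning about batch averages can be checked with a built-in PyTorch loss. The sketch below uses torch.nn.BCEWithLogitsLoss as a stand-in (the default loss_type, SoftJaccardAndBCEWithLogitsLoss, is not reproduced here) to show that reduction="mean" is unaffected by batch size, while reduction="sum" grows with it.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(4, 1, 8, 8)
target = (torch.rand(4, 1, 8, 8) > 0.5).float()

mean_loss = nn.BCEWithLogitsLoss(reduction="mean")  # batch average (PyTorch default)
sum_loss = nn.BCEWithLogitsLoss(reduction="sum")    # batch sum

# Doubling the batch by repetition leaves the mean unchanged but doubles the sum.
big_logits, big_target = logits.repeat(2, 1, 1, 1), target.repeat(2, 1, 1, 1)
print(torch.allclose(mean_loss(logits, target), mean_loss(big_logits, big_target)))  # True
print(torch.allclose(2 * sum_loss(logits, target), sum_loss(big_logits, big_target)))  # True
```

Because a summed loss scales with the number of samples, values logged for batches of different sizes (for example, a smaller last batch) would not be comparable.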
- forward(x)
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- set_normalizer(dataloader)
Initialize the normalizer for the current model.
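Computing a mean and standard deviation from a dataloader can be sketched in plain PyTorch. channel_mean_std is a hypothetical helper, not part of mednet; it assumes per-channel statistics over all pixels and that each batch yields the image tensor first. mednet's actual normalizer may compute or store these values differently.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(loader: DataLoader) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: per-channel mean and std over every pixel in ``loader``.

    Assumes each batch is a sequence whose first item is an image tensor
    shaped (N, C, H, W).
    """
    total = sq_total = None
    count = 0
    for batch in loader:
        images = batch[0].double()  # float64 keeps the running sums accurate
        # Reduce over batch and spatial dims, keeping one value per channel.
        s = images.sum(dim=(0, 2, 3))
        sq = (images**2).sum(dim=(0, 2, 3))
        total = s if total is None else total + s
        sq_total = sq if sq_total is None else sq_total + sq
        count += images.shape[0] * images.shape[2] * images.shape[3]
    mean = total / count
    # Population std via E[x^2] - E[x]^2; clamp guards against tiny negatives.
    std = (sq_total / count - mean**2).clamp(min=0).sqrt()
    return mean.float(), std.float()


images = torch.rand(10, 3, 16, 16)
loader = DataLoader(TensorDataset(images), batch_size=4)
mean, std = channel_mean_std(loader)
print(mean.shape, std.shape)  # torch.Size([3]) torch.Size([3])
```

The one-pass E[x^2] - E[x]^2 formula is used for brevity; for data with large means, a two-pass or Welford-style update is numerically safer.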
This function is a no-op if pretrained = True (the normalizer is set to ImageNet statistics during construction).
- Parameters:
dataloader (DataLoader) – A torch DataLoader from which to compute the mean and std. It is not used if the model is pretrained.
- Return type: