mednet.models.segment.driu_bn
DRIU network architecture, with added batch-normalization.
Classes

ConcatFuseBlock – Takes in four feature maps with 16 channels each, concatenates them and applies a 1x1 convolution with 1 output channel.
DRIUBNHead – DRIU with Batch-Normalization head module.
DRIUBN – DRIU network architecture, with added batch-normalization.
- class mednet.models.segment.driu_bn.ConcatFuseBlock[source]
Bases: Module
Takes in four feature maps with 16 channels each, concatenates them and applies a 1x1 convolution with 1 output channel.
- forward(x1, x2, x3, x4)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
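The fusion described above can be sketched in plain PyTorch. This is an illustrative reconstruction from the docstring, not the library's actual implementation: four 16-channel maps are concatenated into 64 channels, then a 1x1 convolution reduces them to a single output channel.

```python
import torch
import torch.nn as nn


class ConcatFuseBlockSketch(nn.Module):
    """Sketch of the documented behavior: concatenate four 16-channel
    feature maps and fuse them with a 1x1 convolution (1 output channel)."""

    def __init__(self):
        super().__init__()
        # 4 maps x 16 channels = 64 input channels, 1 output channel
        self.conv = nn.Conv2d(4 * 16, 1, kernel_size=1)

    def forward(self, x1, x2, x3, x4):
        x = torch.cat([x1, x2, x3, x4], dim=1)  # concatenate along channels
        return self.conv(x)


block = ConcatFuseBlockSketch()
maps = [torch.randn(2, 16, 32, 32) for _ in range(4)]
out = block(*maps)
print(out.shape)  # torch.Size([2, 1, 32, 32])
```

Note that, per the warning above, the block is invoked as `block(...)` rather than `block.forward(...)`, so registered hooks are run.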
- class mednet.models.segment.driu_bn.DRIUBNHead(in_channels_list)[source]
Bases: Module
DRIU with Batch-Normalization head module.
Based on the paper by [MPTAVG16].
- Parameters:
in_channels_list (list) – Number of channels for each feature map returned from the backbone.
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
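To make the role of `in_channels_list` concrete, here is a hypothetical sketch of a DRIU-style head. The internals below (3x3 side convolutions, batch-norm, bilinear upsampling to the highest resolution, then fusion) are assumptions based on the DRIU design, not the actual `DRIUBNHead` code; only the constructor argument comes from the documentation above.

```python
import torch
import torch.nn as nn


class DRIUHeadSketch(nn.Module):
    """Hypothetical DRIU-style head: each backbone feature map is reduced to
    16 channels (with batch-norm), upsampled to a common resolution, then
    fused with a 1x1 convolution into a single-channel map."""

    def __init__(self, in_channels_list):
        super().__init__()
        # one 16-channel side branch per backbone feature map
        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(c, 16, kernel_size=3, padding=1),
                nn.BatchNorm2d(16),
            )
            for c in in_channels_list
        )
        self.fuse = nn.Conv2d(16 * len(in_channels_list), 1, kernel_size=1)

    def forward(self, xs):
        target = xs[0].shape[-2:]  # upsample all maps to the first map's size
        ys = [
            nn.functional.interpolate(
                branch(x), size=target, mode="bilinear", align_corners=False
            )
            for branch, x in zip(self.branches, xs)
        ]
        return self.fuse(torch.cat(ys, dim=1))


head = DRIUHeadSketch([64, 128, 256, 512])
feats = [
    torch.randn(1, c, 64 // (2 ** i), 64 // (2 ** i))
    for i, c in enumerate([64, 128, 256, 512])
]
print(head(feats).shape)  # torch.Size([1, 1, 64, 64])
```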
- class mednet.models.segment.driu_bn.DRIUBN(loss_type=<class 'mednet.models.losses.SoftJaccardAndBCEWithLogitsLoss'>, loss_arguments=None, optimizer_type=<class 'torch.optim.adam.Adam'>, optimizer_arguments=None, scheduler_type=None, scheduler_arguments=None, model_transforms=None, augmentation_transforms=None, pretrained=False)[source]
Bases: Model
DRIU network architecture, with added batch-normalization.
- Parameters:
loss_type – The loss to be used for training and evaluation.
Warning
The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so.
loss_arguments (dict[str, Any] | None) – Arguments to the loss.
optimizer_type (type[Optimizer]) – The type of optimizer to use for training.
optimizer_arguments (dict[str, Any] | None) – Arguments to the optimizer after params.
scheduler_type (type[LRScheduler] | None) – The type of scheduler to use for training.
scheduler_arguments (dict[str, Any] | None) – Arguments to the scheduler after params.
model_transforms (Optional[Sequence[Callable[[Tensor], Tensor]]]) – An optional sequence of torch modules containing transforms to be applied on the input before it is fed into the network.
augmentation_transforms (Optional[Sequence[Callable[[Tensor], Tensor]]]) – An optional sequence of torch modules containing augmentation transforms to be applied on the input before it is fed into the network.
pretrained (bool) – If True, will use VGG16 pretrained weights.
- forward(x)[source]
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output