class MobileNetV24Segmentation(torchvision.models.mobilenetv2.MobileNetV2):
    """Adaptation of `MobileNetV2 pytorch model <mobilenetv2-pytorch_>`_ to
    U-Net style network for segmentation.

    This version of MobileNetV2 is slightly modified so it can be used through
    torchvision's API. It outputs intermediate features which are normally
    not output by the base MobileNetV2 implementation, but are required for
    segmentation operations.

    Parameters
    ----------
    *args
        Arguments to be passed to the parent MobileNetV2 model.
    **kwargs
        Keyword arguments to be passed to the parent MobileNetV2 model.

        * ``return_features`` (:py:class:`list`): An optional list of integers
          indicating the feature layers (indices into ``self.features``) to be
          returned from the original module. If omitted or ``None``, no
          intermediate feature maps are returned.
    """

    def __init__(self, *args, **kwargs):
        # ``return_features`` is not a MobileNetV2 constructor argument, so it
        # must be removed before forwarding ``kwargs`` to the parent. It is
        # optional, hence the ``None`` default: a bare ``pop`` would raise
        # ``KeyError`` whenever the caller omits it.
        return_features = kwargs.pop("return_features", None)
        super().__init__(*args, **kwargs)
        # Normalise ``None`` to an empty list so ``forward`` can always use
        # the ``in`` operator on it.
        self._return_features = (
            [] if return_features is None else return_features
        )

    def forward(self, x):
        """Run the feature extractor and collect intermediate outputs.

        Parameters
        ----------
        x
            Input tensor. Presumably ``(N, C, H, W)`` -- TODO confirm; only
            ``x.shape[2:4]`` is read directly here.

        Returns
        -------
        list
            ``[x.shape[2:4], x, *selected]``, where ``selected`` holds the
            output of ``self.features[i]`` for every layer index ``i`` listed
            in ``return_features``, in layer order.
        """
        outputs = []
        # hw of input, needed for DRIU and HED
        outputs.append(x.shape[2:4])
        outputs.append(x)
        for index, m in enumerate(self.features):
            x = m(x)
            # extract layers
            if index in self._return_features:
                outputs.append(x)
        return outputs
def mobilenet_v2_for_segmentation(pretrained=False, progress=True, **kwargs):
    """Create MobileNetV2 model for segmentation task.

    Parameters
    ----------
    pretrained
        If True, uses MobileNetV2 pretrained weights.
    progress
        If True, shows a progress bar when downloading the pretrained weights.
    **kwargs
        Keyword arguments to be passed to the parent MobileNetV2 model.

        * ``return_features`` (:py:class:`list`): An optional list of integers
          indicating the feature layers to be returned from the original
          module. When given, ``model.features`` is truncated right after the
          highest requested index, since later layers never contribute to
          the output. An empty list truncates all feature layers.

    Returns
    -------
    MobileNetV24Segmentation
        Instance of the MobileNetV2 model for segmentation.
    """
    model = MobileNetV24Segmentation(**kwargs)

    if pretrained:
        # Load weights before any pruning so every key of the full
        # MobileNetV2 state dict (classifier included) has a matching module.
        state_dict = load_state_dict_from_url(
            torchvision.models.mobilenetv2.MobileNet_V2_Weights.DEFAULT.url,
            progress=progress,
        )
        model.load_state_dict(state_dict)

    # erase MobileNetV2 head (for classification), not used for segmentation
    delattr(model, "classifier")

    return_features = kwargs.get("return_features")
    if return_features is not None:
        # Keep layers up to and including the deepest requested one.
        # ``default=-1`` makes an empty list keep zero layers instead of
        # raising ``ValueError`` from ``max()`` on an empty sequence.
        last = max(return_features, default=-1)
        model.features = model.features[: last + 1]

    return model