class VGG4Segmentation(torchvision.models.vgg.VGG):
    """Adaptation of `VGG pytorch model <vgg-pytorch_>`_ functionality to
    U-Net style network for segmentation.

    This version of VGG is slightly modified so it can be used through
    torchvision's API.  It outputs intermediate features which are normally
    not output by the base VGG implementation, but are required for
    segmentation operations.


    Parameters
    ----------

    *args
        Arguments to be passed to the parent VGG model.

    **kwargs
        Keyword arguments to be passed to the parent VGG model.

        * ``return_features`` (:py:class:`list`): An optional list of
          integers indicating the feature layers to be returned from the
          original module.  When omitted (or ``None``), no intermediate
          feature maps are selected and :py:meth:`forward` only returns the
          input's spatial size.
    """

    def __init__(self, *args, **kwargs):
        # ``return_features`` must be removed from ``kwargs`` before calling
        # the parent constructor, which does not accept it.  It is documented
        # as optional, so fall back to an empty selection instead of raising
        # ``KeyError`` when the caller does not provide it.
        return_features = kwargs.pop("return_features", None)
        super().__init__(*args, **kwargs)
        self._return_features = [] if return_features is None else return_features

    def forward(self, x):
        """Run ``self.features`` and collect the selected intermediate maps.

        Only ``self.features`` is executed; the parent's pooling and
        classification stages are not used here.

        Parameters
        ----------

        x
            Input tensor.  Its dimensions ``2:4`` are recorded as the first
            output (presumably ``(N, C, H, W)`` input, i.e. height and width
            -- TODO confirm).


        Returns
        -------

        list
            ``[x.shape[2:4], feature, feature, ...]`` where each feature is
            the output of a ``self.features`` layer whose index appears in
            ``return_features``, in layer order.
        """
        # "Hardwire" the input size as the first entry, presumably so the
        # decoder can restore the original resolution -- confirm with callers.
        outputs = [x.shape[2:4]]
        for index, module in enumerate(self.features):
            x = module(x)
            # Keep the activation only for the requested layer indices.
            if index in self._return_features:
                outputs.append(x)
        return outputs
def _make_vgg16_type_d_for_segmentation(pretrained, batch_norm, progress, **kwargs):
    """Build a :py:class:`VGG4Segmentation` with torchvision's VGG16 layout.

    The feature layers are generated from torchvision's configuration
    ``"D"`` (the VGG16 layout), optionally with batch normalization.  The
    classification head (``classifier`` and ``avgpool``) is removed before
    returning, since only the feature extractor is needed for segmentation.

    Parameters
    ----------

    pretrained
        If True, skip the parent's weight initialization and load the
        ``DEFAULT`` torchvision checkpoint (``VGG16_BN_Weights`` when
        ``batch_norm`` is set, ``VGG16_Weights`` otherwise).

    batch_norm
        Whether the generated feature layers include batch normalization.

    progress
        If True, shows a progress bar when downloading weights.

    **kwargs
        Forwarded to :py:class:`VGG4Segmentation` (e.g. ``return_features``).


    Returns
    -------

    The model instance, without its ``classifier`` and ``avgpool`` modules.
    """
    vgg = torchvision.models.vgg

    if pretrained:
        # Initial weights would be overwritten by the checkpoint below.
        kwargs["init_weights"] = False

    feature_layers = vgg.make_layers(vgg.cfgs["D"], batch_norm=batch_norm)
    model = VGG4Segmentation(feature_layers, **kwargs)

    if pretrained:
        # The checkpoint is loaded while the head still exists, so its keys
        # match the full VGG module.
        weights_enum = vgg.VGG16_BN_Weights if batch_norm else vgg.VGG16_Weights
        checkpoint = load_state_dict_from_url(weights_enum.DEFAULT.url, progress=progress)
        model.load_state_dict(checkpoint)

    # Drop the classification head: it is not used for segmentation.
    del model.classifier
    del model.avgpool
    return model
def vgg16_for_segmentation(pretrained=False, progress=True, **kwargs):
    """Create an instance of VGG16 (without batch normalization).

    Parameters
    ----------

    pretrained
        If True, uses the VGG16 pretrained weights.

    progress
        If True, shows a progress bar when downloading weights.

    **kwargs
        Keyword arguments to be passed to the parent VGG model.

        * ``return_features`` (:py:class:`list`): An optional list of
          integers indicating the feature layers to be returned from the
          original module.


    Returns
    -------

    Instance of VGG16.
    """
    # Plain VGG16: configuration "D" without batch normalization layers.
    use_batch_norm = False
    return _make_vgg16_type_d_for_segmentation(
        pretrained=pretrained,
        batch_norm=use_batch_norm,
        progress=progress,
        **kwargs,
    )
def vgg16_bn_for_segmentation(pretrained=False, progress=True, **kwargs):
    """Create an instance of VGG16 with batch normalization.

    Parameters
    ----------

    pretrained
        If True, uses the VGG16 (with batch normalization) pretrained
        weights.

    progress
        If True, shows a progress bar when downloading weights.

    **kwargs
        Keyword arguments to be passed to the parent VGG model.

        * ``return_features`` (:py:class:`list`): An optional list of
          integers indicating the feature layers to be returned from the
          original module.


    Returns
    -------

    Instance of VGG16 with batch normalization.
    """
    # Same VGG16 layout (configuration "D"), with batch normalization layers.
    return _make_vgg16_type_d_for_segmentation(
        pretrained=pretrained,
        batch_norm=True,
        progress=progress,
        **kwargs,
    )