def crop_image_to_mask(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Crop an image to the bounding box of the nonzero pixels of a mask.

    The crop is the smallest axis-aligned rectangle that contains every
    nonzero pixel of ``mask``; it is not necessarily square.

    Parameters
    ----------
    img
        The image to crop, of shape ``[..., H, W]`` (typically
        channels x height x width).
    mask
        The boolean mask to use for cropping, of shape ``[..., H, W]``
        (e.g. ``H x W`` or ``1 x H x W``). If the mask has leading
        dimensions, a pixel counts as set when it is set in any of them.

    Returns
    -------
        The cropped image (a slice of ``img``) of shape ``[..., h', w']``.

    Raises
    ------
    ValueError
        If the spatial sizes of ``img`` and ``mask`` differ, or if ``mask``
        contains no nonzero pixel.
    """
    if img.shape[-2:] != mask.shape[-2:]:
        raise ValueError(
            f"Image and mask must have the same size: {img.shape[-2:]} != {mask.shape[-2:]}"
        )
    # Checked before any reshaping so zero-sized masks also get a clear error
    # instead of an IndexError from indexing an empty ``nonzero`` result.
    if not mask.any():
        raise ValueError("Mask is empty: cannot crop to a mask without nonzero pixels")

    h, w = img.shape[-2:]
    # Merge any leading dimensions so both H x W and 1 x H x W masks work.
    mask_2d = mask.reshape(-1, h, w).any(dim=0)
    rows = torch.nonzero(mask_2d.any(dim=1)).flatten()
    cols = torch.nonzero(mask_2d.any(dim=0)).flatten()

    # Slice bounds are half-open, hence the +1 on the last set row/column.
    top, bottom = int(rows[0]), int(rows[-1]) + 1
    left, right = int(cols[0]), int(cols[-1]) + 1
    return img[..., top:bottom, left:right]
def square_center_pad(img: torch.Tensor, size: typing.Any) -> torch.Tensor:
    """Return a squared version of the image, centered on a canvas padded with zeros.

    The shorter side is padded so that height and width both equal the
    longer side. When the required padding is odd, the extra pixel goes to
    the bottom/right.

    Parameters
    ----------
    img
        The tensor to be transformed. Expected to be in the form:
        ``[..., [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
        It is handed unchanged to ``torchvision.transforms.v2.functional.pad``;
        ``SquareCenterPad`` also passes ``BoundingBoxes`` through here.
    size
        ``(height, width)`` of the image. It is passed separately from ``img``
        because for bounding boxes it is the canvas size, not the tensor shape.

    Returns
    -------
        Transformed tensor, guaranteed to be square (ie. equal height and width).
    """
    height, width = size
    # Builtin max keeps plain ints (numpy.max would yield numpy.int64, which
    # then leaks into torchvision's padding arithmetic and output metadata).
    side = max(height, width)

    left = (side - width) // 2
    top = (side - height) // 2
    right = side - width - left
    bottom = side - height - top

    return torchvision.transforms.v2.functional.pad(
        img,
        [left, top, right, bottom],
        0,
        "constant",
    )
class SquareCenterPad(torchvision.transforms.v2.Transform):
    """Pad inputs onto a square canvas with their content centered.

    Plain tensors, ``Image`` and ``Mask`` inputs are squared based on their
    trailing two (spatial) dimensions. ``BoundingBoxes`` are squared based on
    their ``canvas_size``. The padding itself, with zeros, is delegated to
    ``square_center_pad``. Inputs of any other type are returned as-is.
    """

    def __init__(self):
        super().__init__()

    @singledispatchmethod
    def _transform(self, inpt: typing.Any, params: dict[str, typing.Any]) -> typing.Any:
        # Fallback for types without a dedicated handler: leave them untouched.
        return inpt

    @_transform.register(torch.Tensor)
    @_transform.register(torchvision.tv_tensors.Image)
    @_transform.register(torchvision.tv_tensors.Mask)
    def _pad_image_like(
        self,
        inpt: torch.Tensor | torchvision.tv_tensors.Image | torchvision.tv_tensors.Mask,
        params: dict[str, typing.Any],
    ) -> torch.Tensor | torchvision.tv_tensors.Image | torchvision.tv_tensors.Mask:
        # For tensor-like inputs the spatial size is the last two dimensions.
        spatial_size = inpt.shape[-2:]
        return square_center_pad(inpt, spatial_size)

    @_transform.register(torchvision.tv_tensors.BoundingBoxes)
    def _pad_bounding_boxes(
        self,
        inpt: torchvision.tv_tensors.BoundingBoxes,
        params: dict[str, typing.Any],
    ) -> torchvision.tv_tensors.BoundingBoxes:
        # Boxes have no spatial dims of their own; use the canvas they live on.
        canvas = inpt.canvas_size
        return square_center_pad(inpt, canvas)