def make_z_normalizer(
    dataloader: torch.utils.data.DataLoader,
) -> torchvision.transforms.Normalize:
    """Compute per-channel mean and standard deviation from a dataloader.

    This function iterates once over a dataloader and computes, for every
    image channel, the mean and standard deviation of all pixel values of all
    images.  It works for both monochromatic and color inputs with 2, 3 or
    more color planes.

    The standard deviation is pooled over the whole dataset: it accounts for
    the variation within each image *and* the variation between images, so
    the returned normalizer maps the data to zero mean and unit variance per
    channel.  Statistics are accumulated in ``float64`` on the CPU and merged
    batch by batch with a numerically stable parallel update.

    Parameters
    ----------
    dataloader
        A torch Dataloader from which to compute the mean and std.  Each
        batch must be a mapping whose ``"image"`` entry is a tensor shaped
        ``(N, C, ...)`` (batch dimension first, channels second).  Integer
        image tensors are accepted.

    Returns
    -------
        An initialized normalizer.  Its ``mean`` and ``std`` have the dtype
        of the first non-empty batch of images if that is floating point, or
        the torch default floating point dtype otherwise.

    Raises
    ------
    ValueError
        If the dataloader yields no pixels at all.
    """
    count = 0  # number of pixels merged so far (same for every channel)
    mean = None  # running per-channel mean, float64
    m2 = None  # running per-channel sum of squared deviations, float64
    out_dtype = torch.get_default_dtype()

    for batch in tqdm.tqdm(dataloader, unit="batch"):
        images = batch["image"]

        # (N, C, ...) -> (C, N * pixels_per_image).  ``reshape`` (unlike
        # ``view``) also copes with non-contiguous batches, and float64
        # makes integer images usable and keeps the accumulation precise.
        pixels = (
            images.detach()
            .to(device="cpu", dtype=torch.float64)
            .transpose(0, 1)
            .reshape(images.size(1), -1)
        )
        batch_count = pixels.size(1)
        if batch_count == 0:
            continue  # empty batch: nothing to merge (its mean would be NaN)

        batch_mean = pixels.mean(dim=1)
        batch_m2 = ((pixels - batch_mean[:, None]) ** 2).sum(dim=1)

        if mean is None:
            # First non-empty batch: report results in its dtype when it is
            # floating point.
            if images.is_floating_point():
                out_dtype = images.dtype
            mean, m2, count = batch_mean, batch_m2, batch_count
            continue

        # Parallel (Chan et al.) update: merge this batch's statistics into
        # the running ones without a cancellation-prone sum of squares.
        total = count + batch_count
        delta = batch_mean - mean
        mean = mean + delta * (batch_count / total)
        m2 = m2 + batch_m2 + delta**2 * (count * batch_count / total)
        count = total

    if mean is None:
        raise ValueError(
            "cannot compute normalization statistics: "
            "the dataloader yielded no pixels"
        )

    # Unbiased estimator, matching the default of ``torch.Tensor.var``.
    std = torch.sqrt(m2 / (count - 1))

    return torchvision.transforms.Normalize(
        mean.to(out_dtype), std.to(out_dtype)
    )
def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
    """Build a normalizer pre-loaded with torchvision's stock ImageNet statistics.

    The fixed per-channel mean and standard deviation values are wrapped in a
    :py:class:`torchvision.transforms.Normalize` module.  Exactly three values
    are supplied for each statistic, so this normalizer only applies to
    **RGB (color) images**.

    Returns
    -------
        An initialized normalizer.
    """
    imagenet_mean = (0.485, 0.456, 0.406)  # one value per RGB channel
    imagenet_std = (0.229, 0.224, 0.225)  # one value per RGB channel
    return torchvision.transforms.Normalize(mean=imagenet_mean, std=imagenet_std)