# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Image transformations for our pipelines.
Differences between methods here and those from
:py:mod:`torchvision.transforms` is that these support multiple
simultaneous image inputs, which are required to feed segmentation
networks (e.g. image and labels or masks). We also take care of data
augmentations, in which random flipping and rotation needs to be applied
across all input images, but color jittering, for example, only on the
input image.
"""
import functools
import logging
import typing
import numpy.random
import numpy.typing
import torch
from scipy.ndimage import gaussian_filter, map_coordinates
logger = logging.getLogger(__name__)
def _elastic_deformation_on_image(
img: torch.Tensor,
rng: numpy.random.Generator,
alpha: float = 1000.0,
sigma: float = 30.0,
spline_order: int = 1,
mode: str = "nearest",
p: float = 1.0,
) -> torch.Tensor:
"""Perform elastic deformation on an image.
This implementation is based on 2 scipy functions
(:py:func:`scipy.ndimage.gaussian_filter` and
:py:func:`scipy.ndimage.map_coordinates`). It is very inefficient since it
requires data to be moved off the current running device and then back.
Parameters
----------
img
The input image to apply elastic deformation to. This image should
always have this shape: ``[C, H, W]``. It should always represent a
tensor on the CPU.
rng
The numpy random number generator to use.
alpha
A multiplier for the gaussian filter outputs.
sigma
Standard deviation for Gaussian kernel.
spline_order
The order of the spline interpolation, default is 1. The order has to
be in the range 0-5.
mode
The mode parameter determines how the input array is extended beyond
its boundaries.
p
Probability that this transformation will be applied. Meaningful when
using it as a data augmentation technique.
Returns
-------
tensor
The image with elastic deformation applied, as a tensor on the CPU.
"""
if rng.random() < p:
assert img.ndim == 3, (
f"This filter accepts only images with 3 dimensions, "
f"however I got an image with {img.ndim} dimensions."
)
# Input tensor is of shape C x H x W
img_shape = img.shape[1:]
dx = alpha * typing.cast(
numpy.typing.NDArray[numpy.float64],
gaussian_filter(
(rng.random(size=(img_shape[0], img_shape[1])) * 2 - 1),
sigma,
mode="constant",
cval=0.0,
),
)
dy = alpha * typing.cast(
numpy.typing.NDArray[numpy.float64],
gaussian_filter(
(rng.random(size=(img_shape[0], img_shape[1])) * 2 - 1),
sigma,
mode="constant",
cval=0.0,
),
)
x, y = numpy.meshgrid(
numpy.arange(img_shape[0]),
numpy.arange(img_shape[1]),
indexing="ij",
)
indices = [
numpy.reshape(x + dx, (-1, 1)),
numpy.reshape(y + dy, (-1, 1)),
]
# may copy, if img is not on CPU originally
img_numpy = img.numpy()
output = numpy.zeros_like(img_numpy)
for i in range(img.shape[0]):
output[i, :, :] = torch.tensor(
map_coordinates(
img_numpy[i, :, :],
indices,
order=spline_order,
mode=mode,
).reshape(img_shape),
)
# wraps numpy array as tensor, move to destination device if need-be
return torch.as_tensor(output)
return img
def _elastic_deformation_on_batch(
batch: torch.Tensor,
rng: numpy.random.Generator,
alpha: float = 1000.0,
sigma: float = 30.0,
spline_order: int = 1,
mode: str = "nearest",
p: float = 1.0,
) -> torch.Tensor:
"""Perform elastic deformation on a batch of images.
This implementation is based on 2 scipy functions
(:py:func:`scipy.ndimage.gaussian_filter` and
:py:func:`scipy.ndimage.map_coordinates`). It is very inefficient since it
requires data to be moved off the current running device and then back.
Parameters
----------
batch
The batch to apply elastic deformation to.
rng
The numpy random number generator to use.
alpha
A multiplier for the gaussian filter outputs.
sigma
Standard deviation for Gaussian kernel.
spline_order
The order of the spline interpolation, default is 1. The order has to
be in the range 0-5.
mode
The mode parameter determines how the input array is extended beyond
its boundaries.
p
Probability that this transformation will be applied. Meaningful when
using it as a data augmentation technique.
Returns
-------
tensor
A batch of images with elastic deformation applied, as a tensor on the CPU.
"""
# transforms our custom functions into simpler callables
partial = functools.partial(
_elastic_deformation_on_image,
rng=rng,
alpha=alpha,
sigma=sigma,
spline_order=spline_order,
mode=mode,
p=p,
)
# if a mp pool is available, do it in parallel
augmented_images: typing.Any
augmented_images = map(partial, batch.cpu())
return torch.stack(list(augmented_images))