# Source code for mednet.engine.classify.saliency.viewer

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Engine and functions for visualization of saliency maps."""

import logging
import pathlib
import typing

import lightning.pytorch
import numpy
import numpy.typing
import PIL.Image
import PIL.ImageColor
import PIL.ImageDraw
import torchvision.transforms.v2.functional
from torchvision import tv_tensors
from tqdm import tqdm

logger = logging.getLogger(__name__)


def _overlay_saliency_map(
    image: PIL.Image.Image,
    saliencies: numpy.typing.NDArray[numpy.double],
    colormap: typing.Literal[  # we accept any "Sequential" colormap from mpl
        "viridis",
        "plasma",
        "inferno",
        "magma",
        "cividis",
        "Greys",
        "Purples",
        "Blues",
        "Greens",
        "Oranges",
        "Reds",
        "YlOrBr",
        "YlOrRd",
        "OrRd",
        "PuRd",
        "RdPu",
        "BuPu",
        "GnBu",
        "PuBu",
        "YlGnBu",
        "PuBuGn",
        "BuGn",
        "YlGn",
    ],
    image_weight: float,
) -> PIL.Image.Image:
    """Create an overlaid representation of the saliency map on the original
    image.

    This is a slightly modified version of the show_cam_on_image implementation in:
    https://github.com/jacobgil/pytorch-grad-cam, but uses matplotlib instead
    of opencv.

    Pixels for which the saliency map is exactly zero are copied from the
    input image unchanged; every other pixel is a weighted blend of the input
    image and the colour-mapped saliency value.

    Parameters
    ----------
    image
        The input image that will be overlaid with the saliency map.  Its
        first two dimensions must match the shape of ``saliencies``.
        NOTE(review): the blend uses three colour channels and the output is
        built in "RGB" mode, so an RGB input is presumably expected - confirm
        against callers.
    saliencies
        The saliency map that will be overlaid on the (raw) image.  Its
        maximum value must not exceed 1.
    colormap
        The name of the (matplotlib) colormap to be used.
    image_weight
        The final result is ``image_weight * image + (1-image_weight) *
        saliency_map``.  Must lie strictly between 0 and 1.

    Returns
    -------
    PIL.Image.Image
        A modified version of the input ``image`` with the overlaid saliency
        map.

    Raises
    ------
    AssertionError
        If the image and saliency map shapes differ, if the saliency map
        maximum is greater than 1, or if ``image_weight`` is not strictly
        between 0 and 1.
    """
    import matplotlib

    # Work in floating point, with pixel values rescaled to [0, 1], so that
    # the image can be blended directly with the colormap output.
    image_array = numpy.array(image, dtype=numpy.float32) / 255.0

    assert image_array.shape[:2] == saliencies.shape, (
        f"The shape of the saliency map ({saliencies.shape}) is different "
        f"from the shape of the input image ({image_array.shape[:2]})."
    )

    assert saliencies.max() <= 1, (
        f"The input saliency map should be in the range [0, 1] (max={saliencies.max()})"
    )

    # The bounds are exclusive: the message states the open interval so that
    # it agrees with the condition actually being enforced.
    assert 0 < image_weight < 1, (
        f"image_weight should be in the open interval (0, 1), but got "
        f"{image_weight}"
    )

    heatmap = matplotlib.colormaps[colormap](saliencies)

    # For pixels where the mask is zero, the original image pixels are being
    # used without a mask.  Only the first three channels of the colormap
    # output are blended (its last channel is discarded).
    result = numpy.where(
        saliencies[..., numpy.newaxis] == 0,
        image_array,
        (image_weight * image_array) + ((1 - image_weight) * heatmap[:, :, :3]),
    )

    return PIL.Image.fromarray((result * 255).astype(numpy.uint8), "RGB")


def _process_sample(
    raw_data: numpy.typing.NDArray[numpy.double],
    saliencies: numpy.typing.NDArray[numpy.double],
    ground_truth: tv_tensors.BoundingBoxes,
) -> PIL.Image.Image:
    """Build the final visualization for a single sample.

    The raw data is converted to an RGB image, blended with the saliency map
    (``plasma`` colormap, equal weights for image and heatmap) and, if
    ground-truth annotations are provided, green bounding boxes are drawn on
    top of the result.

    Parameters
    ----------
    raw_data
        The raw data representing the input sample that will be overlaid with
        saliency maps and annotations.
    saliencies
        The saliency map recovered from the model, that will be imprinted on
        the raw_data.
    ground_truth
        Ground-truth annotations to imprint on the final image, or ``None`` to
        skip drawing them.  Only the first entry (``ground_truth[0]``) is
        drawn.

    Returns
    -------
    PIL.Image.Image
        An image with the original raw data overlaid with the saliency map
        and, optionally, the ground-truth bounding boxes.
    """

    functional = torchvision.transforms.v2.functional

    # The saliency map and the annotations are drawn in colour, so the sample
    # is converted to a colour image before anything else happens.
    colour_image = functional.to_pil_image(raw_data).convert("RGB")

    overlaid = _overlay_saliency_map(
        colour_image,
        saliencies,
        colormap="plasma",
        image_weight=0.5,
    )

    # Without annotations, the overlaid image is already the final result.
    if ground_truth is None:
        return overlaid

    with_boxes = torchvision.utils.draw_bounding_boxes(
        tv_tensors.Image(overlaid),
        ground_truth[0],
        colors="green",
        width=2,
    )
    return functional.to_pil_image(with_boxes)


def run(
    datamodule: lightning.pytorch.LightningDataModule,
    input_folder: pathlib.Path,
    target_label: int,
    output_folder: pathlib.Path,
    show_groundtruth: bool,
    threshold: float,
):
    """Overlay saliency maps on CXR to output final images with heatmaps.

    For every sample of every prediction dataloader whose target equals
    ``target_label``, the matching saliency map is loaded from
    ``input_folder``, thresholded, overlaid on the sample image and saved as a
    PNG file under ``output_folder``, using the same relative path as the
    sample name.

    Parameters
    ----------
    datamodule
        The Lightning DataModule to iterate on.
    input_folder
        Directory in which the saliency maps are stored for a specific
        visualization type.
    target_label
        The label to target for evaluating interpretability metrics. Samples
        containing any other label are ignored.
    output_folder
        Directory in which the resulting visualizations will be saved.
    show_groundtruth
        If set, imprint ground truth labels over the original image and
        saliency maps.
    threshold
        Saliency values below ``threshold`` times the maximum value of the
        map are set to zero; all other values are kept. The value proposed on
        :cite:p:`wang_score-cam_2020` is 0.2. Use this value if unsure.
    """
    for dataset_name, loader in datamodule.predict_dataloader().items():
        logger.info(
            f"Generating visualizations for samples (target = {target_label}) "
            f"at dataset `{dataset_name}`..."
        )

        progress = tqdm(loader, desc="batches", leave=False, disable=None)
        for batch in progress:
            # WARNING: the indexing below assumes a batch size of 1 and will
            # break otherwise.
            name = str(batch["name"][0])
            label = int(batch["target"].item())
            data = batch["image"][0]

            # No visualisation was generated for samples with other labels.
            if label != target_label:
                continue

            relative_path = pathlib.Path(name)

            saliencies = numpy.load(input_folder / relative_path.with_suffix(".npy"))

            # Zero-out (in place) everything below the relative cutoff.
            cutoff = threshold * saliencies.max()
            saliencies[saliencies < cutoff] = 0

            # TODO: This is very specific to the TBX11k system for labelling
            # regions of interest. We need to abstract from this to support
            # more datasets and other ways to annotate.
            ground_truth = batch.get("bboxes", None) if show_groundtruth else None

            # Fully process this entry.
            image = _process_sample(data, saliencies, ground_truth)

            # Save the result, creating intermediate directories on demand.
            destination = output_folder / relative_path.with_suffix(".png")
            destination.parent.mkdir(parents=True, exist_ok=True)
            image.save(destination)