# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Engine and functions for score completeness analysis."""
import functools
import logging
import multiprocessing
import typing
import lightning.pytorch
import numpy as np
import torch
import tqdm
from pytorch_grad_cam.metrics.road import (
ROADLeastRelevantFirstAverage,
ROADMostRelevantFirstAverage,
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from ....data.typing import Sample
from ....engine.device import DeviceManager
from ....models.classify.typing import SaliencyMapAlgorithm
logger = logging.getLogger(__name__)
[docs]
class SigmoidClassifierOutputTarget(torch.nn.Module):
    """Target that selects the sigmoid of a single model output.

    Instances are callables: they apply :py:func:`torch.sigmoid` to the
    model output they receive and return the entry (or column) matching the
    configured category.

    Parameters
    ----------
    category
        Index of the model output whose sigmoid value is returned.
    """

    def __init__(self, category):
        # Initialise the torch.nn.Module machinery first; without it, helpers
        # such as ``repr()``, ``state_dict()`` or ``to()`` raise
        # ``AttributeError`` on the instance.
        super().__init__()
        self.category = category

    def __call__(self, model_output):
        """Return the sigmoid of the selected output.

        Parameters
        ----------
        model_output
            Tensor with the model output.  It is indexed directly when it
            has a single dimension, and along its second axis otherwise.

        Returns
        -------
            The sigmoid of ``model_output`` at the configured category.
        """
        sigmoid_output = torch.sigmoid(model_output)
        if len(sigmoid_output.shape) == 1:
            # single (un-batched) output vector
            return sigmoid_output[self.category]
        # first axis is kept, the category is picked on the second one
        return sigmoid_output[:, self.category]
def _calculate_road_scores(
    model: lightning.pytorch.LightningModule,
    images: torch.Tensor,
    output_num: int,
    saliency_map_callable: typing.Callable,
    percentiles: typing.Sequence[int],
) -> tuple[float, float, float]:
    """Calculate average ROAD scores over a set of removal percentiles.

    A saliency map is first generated for the requested output.  The
    most-relevant-first (MoRF) and least-relevant-first (LeRF) averaged ROAD
    metrics are then evaluated on that map, for the removal percentiles
    provided by the caller.

    Parameters
    ----------
    model
        Neural network model (e.g. pasa).
    images
        A batch of input images to use for evaluating the ROAD scores.
        Only batches with a single image are supported, as each metric
        result is converted to a single scalar.
    output_num
        Target output neuron to take into consideration when evaluating the
        saliency maps and calculating ROAD scores.
    saliency_map_callable
        A callable saliency-map generator from grad-cam.
    percentiles
        A sequence of percentiles (percent x100) integer values indicating
        the proportion of pixels to perturb in the original image to
        calculate both MoRF and LeRF scores.

    Returns
    -------
    tuple[float, float, float]
        A 3-tuple containing floating point numbers representing the
        most-relevant-first average score (``morf``), least-relevant-first
        average score (``lerf``) and the combined value
        (``(lerf-morf)/2``).
    """
    # The saliency map is computed against the plain classifier output...
    cams = saliency_map_callable(
        input_tensor=images,
        targets=[ClassifierOutputTarget(output_num)],
    )

    morf_metric, lerf_metric = (
        metric_type(percentiles=percentiles)
        for metric_type in (
            ROADMostRelevantFirstAverage,
            ROADLeastRelevantFirstAverage,
        )
    )

    # ... whereas the ROAD metrics look at the sigmoid of that same output.
    road_inputs = dict(
        input_tensor=images,
        cams=cams,
        model=model,
        targets=[SigmoidClassifierOutputTarget(output_num)],
    )

    # Calculate ROAD scores for all percentiles and average - this is NOT the
    # current processing bottleneck.  If you want to optimise anything, look
    # at the evaluation of the perturbation using scipy.sparse at the
    # NoisyLinearImputer, part of the grad-cam package (submodule
    # ``metrics.road``).
    morf = float(morf_metric(**road_inputs).item())
    lerf = float(lerf_metric(**road_inputs).item())

    return (morf, lerf, (lerf - morf) / 2.0)
def _process_sample(
    sample: Sample,
    model: lightning.pytorch.LightningModule,
    device: torch.device,
    saliency_map_callable: typing.Callable,
    target_class: typing.Literal["highest", "all"],
    positive_only: bool,
    percentiles: typing.Sequence[int],
) -> list:
    """Process a single sample.

    Helper function to :py:func:`run` to be used in multiprocessing contexts.

    Parameters
    ----------
    sample
        The Sample to process.  Its ``"name"``, ``"target"`` and ``"image"``
        entries are read; the target must hold a single element.
    model
        Neural network model (e.g. pasa).
    device
        The device to process samples on.  Currently not used by this
        function: the image is forwarded as found in ``sample``.
    saliency_map_callable
        A callable saliency-map generator from grad-cam.
    target_class
        Class to target for saliency estimation.  Can be set to "all" or
        "highest".  With "highest", only the output with the highest
        activation is evaluated.  Only considered if the model has more than
        one output.
    positive_only
        If set, and the model chosen has a single output (binary), then
        saliency maps will only be generated for samples of the positive
        class.
    percentiles
        A sequence of percentiles (percent x100) integer values indicating
        the proportion of pixels to perturb in the original image to
        calculate both MoRF and LeRF scores.

    Returns
    -------
    list
        A list containing the following items for a particular sample:

        * The name of the sample.
        * The label.
        * An index (plain ``int``) to the evaluated model output.
        * The computed ROAD scores (``morf``, ``lerf`` and combined).

        When the sample is skipped (``positive_only`` on a single-output
        model and a label of zero), only the first two items are returned.
    """
    name: str = sample["name"][0]
    label: int = int(sample["target"].item())
    image = sample["image"]

    # in binary classification systems, negative labels may be skipped
    if positive_only and (model.num_classes == 1) and (label == 0):
        return [name, label]

    # chooses target outputs to generate saliency maps for
    if model.num_classes > 1:  # type: ignore
        if target_class == "all":
            # NOTE(review): as written, this loop returns on its first
            # iteration, so a single output is evaluated even though "all"
            # was requested.  Reporting every output would change the shape
            # of the returned list - confirm the intended behaviour with
            # callers before changing it.
            for output_num in range(model.num_classes):  # type: ignore
                results = _calculate_road_scores(
                    model,
                    image,
                    output_num,
                    saliency_map_callable,
                    percentiles,
                )
                return [name, label, output_num, *results]
        else:
            # we will figure out the output with the highest value and
            # evaluate the saliency mapping technique over it.
            outputs = saliency_map_callable.activations_and_grads(image)  # type: ignore
            output_nums = np.argmax(outputs.cpu().data.numpy(), axis=-1)
            assert len(output_nums) == 1

            # convert the NumPy integer scalar to a built-in ``int`` so this
            # branch returns the same types as the others (and the row stays
            # serialisable with the standard library)
            highest = int(output_nums[0])
            results = _calculate_road_scores(
                model,
                image,
                highest,
                saliency_map_callable,
                percentiles,
            )
            return [name, label, highest, *results]

    # default route for binary classification
    results = _calculate_road_scores(
        model,
        image,
        0,
        saliency_map_callable,
        percentiles,
    )
    return [name, label, 0, *results]
[docs]
def run(
    model: lightning.pytorch.LightningModule,
    datamodule: lightning.pytorch.LightningDataModule,
    device_manager: DeviceManager,
    saliency_map_algorithm: SaliencyMapAlgorithm,
    target_class: typing.Literal["highest", "all"],
    positive_only: bool,
    percentiles: typing.Sequence[int],
    parallel: int,
    only_dataset: str | None,
) -> dict[str, list[typing.Any]]:
    """Evaluate ROAD scores for all samples in a DataModule.

    The ROAD algorithm was first described in :cite:p:`rong_consistent_2022`.
    It estimates explainability (in the completeness sense) of saliency maps
    by substituting relevant pixels in the input image by a local average,
    re-running prediction on the altered image, and measuring changes in the
    output classification score when said perturbations are in place.  By
    substituting the most or least relevant pixels with surrounding averages,
    the ROAD algorithm estimates the importance of such elements in the
    produced saliency map.  As of 2023, this measurement technique is
    considered to be one of the state-of-the-art metrics of explainability.

    This function returns a dictionary containing most-relevant-first (remove
    a percentile of the most relevant pixels), least-relevant-first (remove a
    percentile of the least relevant pixels), and combined ROAD evaluations
    per sample for a particular saliency mapping algorithm.

    Parameters
    ----------
    model
        Neural network model (e.g. pasa).
    datamodule
        The lightning DataModule to iterate on.
    device_manager
        An internal device representation.  Its device type is checked
        against the multiprocessing setup, and its torch device is handed
        over to the sample processing function.
    saliency_map_algorithm
        The algorithm for saliency map estimation to use.
    target_class
        (Use only with multi-label models) Which class to target for CAM
        calculation.  Can be set to "all" or "highest".  With "highest",
        only saliency maps for the class with the highest activation will be
        generated.
    positive_only
        If set, saliency maps will only be generated for positive samples
        (ie. label == 1 in a binary classification task).  This option is
        ignored on a multi-class output model.
    percentiles
        A sequence of percentiles (percent x100) integer values indicating
        the proportion of pixels to perturb in the original image to
        calculate both MoRF and LeRF scores.
    parallel
        Use multiprocessing for data processing: if set to -1, disables
        multiprocessing.  Set to 0 to enable as many data processing
        instances as processing cores available in the system.  Set to >= 1
        to enable that many multiprocessing instances for data processing.
    only_dataset
        If set, will only run this code for the named dataset on the provided
        datamodule, skipping any other datasets.

    Returns
    -------
    dict[str, list[typing.Any]]
        A dictionary where keys are dataset names in the provided
        DataModule, and values are lists containing sample information
        alongside metrics calculated:

        * Sample name
        * Sample target class
        * The model output number used for the ROAD analysis (0, for binary
          classifiers as there is typically only one output).
        * ``morf``: ROAD most-relevant-first average over the requested
          ``percentiles`` (a.k.a. AOPC-MoRF).
        * ``lerf``: ROAD least-relevant-first average over the requested
          ``percentiles`` (a.k.a. AOPC-LeRF).
        * combined: Average ROAD combined score by evaluating
          ``(lerf-morf)/2`` (a.k.a. AOPC-Combined).

    Raises
    ------
    ValueError
        If the "fullgrad" algorithm is requested for a Pasa model.
    TypeError
        If the model is neither a Pasa nor a Densenet model.
    RuntimeError
        If a GPU device ("cuda" or "mps") is combined with ``parallel`` set
        to zero or to more than one.
    """
    from ....models.classify.densenet import Densenet
    from ....models.classify.pasa import Pasa
    from .generator import _create_saliency_map_callable

    # reject unsupported models up-front, then pick the layer to inspect
    if not isinstance(model, (Pasa, Densenet)):
        raise TypeError(f"Model of type `{type(model)}` is not yet supported.")

    if isinstance(model, Pasa):
        if saliency_map_algorithm == "fullgrad":
            raise ValueError(
                "Fullgrad saliency map algorithm is not supported for the Pasa model.",
            )
        target_layers = [model.fc14]  # Last non-1x1 Conv2d layer
    else:
        target_layers = [
            model.model.features.denseblock4.denselayer16.conv2,  # type: ignore
        ]

    # a GPU cannot be shared by several worker processes
    if device_manager.device_type in ("cuda", "mps"):
        if parallel == 0 or parallel > 1:
            raise RuntimeError(
                f"The number of multiprocessing instances is set to {parallel} and "
                f"you asked to use a GPU (device = `{device_manager.device_type}`"
                f"). The current implementation can only handle a single GPU. "
                f"Either disable GPU usage, set the number of "
                f"multiprocessing instances to one, or disable multiprocessing "
                "entirely (ie. set it to -1).",
            )

    # resolves the target device and puts the model in evaluation mode
    device = device_manager.torch_device()
    model.eval()

    saliency_map_callable = _create_saliency_map_callable(
        saliency_map_algorithm,
        model,
        target_layers,  # type: ignore
    )

    # worker function: only the sample remains to be provided
    worker = functools.partial(
        _process_sample,
        model=model,
        device=device,
        saliency_map_callable=saliency_map_callable,
        target_class=target_class,
        positive_only=positive_only,
        percentiles=percentiles,
    )

    scores: dict[str, list[typing.Any]] = {}

    for dataset_name, loader in datamodule.predict_dataloader().items():
        if only_dataset is not None and dataset_name != only_dataset:
            logger.warning(
                f"Skipping processing for dataset `{dataset_name}` following user request..."
            )
            continue

        if parallel < 0:
            # sequential processing, within the current process
            logger.info(
                f"Computing ROAD scores for dataset `{dataset_name}` in the current "
                f"process context...",
            )
            progress = tqdm.tqdm(
                loader,
                desc="samples",
                leave=False,
                disable=None,
            )
            scores[dataset_name] = [worker(sample) for sample in progress]
            continue

        # zero means "use every core available"
        pool_size = parallel or multiprocessing.cpu_count()
        logger.info(
            f"Computing ROAD scores for dataset `{dataset_name}` using {pool_size} "
            f"processes...",
        )
        with multiprocessing.Pool(pool_size) as pool:
            scores[dataset_name] = list(
                tqdm.tqdm(pool.imap(worker, loader), total=len(loader)),
            )

    return scores