# Source code for mednet.engine.detect.evaluator

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines functionality for the evaluation of object detection predictions."""

import collections
import logging
import typing

import matplotlib.axes
import matplotlib.figure
import numpy
import numpy.typing
import tabulate
import torch
import torchvision.ops
from matplotlib import pyplot as plt

from ...models.detect.typing import Prediction

logger = logging.getLogger(__name__)


def _compute_iou_from_predictions(
    predictions: typing.Sequence[Prediction],
) -> list[list[tuple[int, float, int, float]]]:
    """Calculate the IOU for each **detected** bounding-box in predictions.

    This function will calculate the IOU (intersection over union) metric for each
    detected bounding-box (as in output by a model) in the prediction dataset.  It will
    then return a list of tuples, each matching a prediction, indicating the matched
    target, the IOU, the class and the model score.

    Detections are attributed to targets greedily, from the highest to the lowest
    model score.  Only targets carrying the same label as the detection are
    candidates, and a target that has been attributed to a detection is no longer
    available to the detections that follow.

    Parameters
    ----------
    predictions
        A list of predictions to consider for measurement.  Each prediction is
        unpacked as ``(name, targets, detections)``; element 0 of a target or
        detection is used as its bounding box, element 1 as its label, and element
        2 of a detection as its model score.

    Returns
    -------
        A list with one entry per sample, in the order of ``predictions``.  Each
        entry is a list with one tuple per **detected** bounding-box, ordered from
        the highest to the lowest model score. Each tuple contains the index of the
        matching target bounding box, the IOU between the target and said detected
        bounding-box, the class of the said target/detected bounding box, and finally
        the model score.

        In case there is no match for a particular detected bounding-box, the output
        table matching this would contain ``(-1, 0.0, <class>, 0.0)``.  This model
        output can be accounted as a "misdetection".

        A sample without detections yields an empty list (and a warning is logged).
        A sample with detections but without targets yields only misdetections.
    """

    retval: list[list[tuple[int, float, int, float]]] = []

    for sample in predictions:
        name, targets, detections = sample

        if not detections:
            logger.warning("No detections for sample `%s` were found.", name)
            retval.append([])
            continue

        # the order of attributions need to go from highest to lowest score
        attribution_order = numpy.flip(numpy.argsort([k[2] for k in detections]))

        if not targets:
            # without ground-truth there is nothing to match against (and an IOU
            # matrix cannot be built from an empty box list): every detection is a
            # misdetection
            retval.append(
                [(-1, 0.0, detections[k][1], 0.0) for k in attribution_order]
            )
            continue

        # calculates IOU of all targets (rows) against all detections (columns)
        iou = torchvision.ops.box_iou(
            torch.tensor([k[0] for k in targets]),
            torch.tensor([k[0] for k in detections]),
        ).numpy()

        # we are only interested in the positions in which targets and detections have
        # matching labels - everything else can be set to zero on the IOU matrix
        iou *= numpy.equal.outer(
            [k[1] for k in targets], [k[1] for k in detections]
        ).astype(int)

        attributions: list[tuple[int, float, int, float]] = []
        for detection_index in attribution_order:
            max_iou_arg = iou[:, detection_index].argmax()
            if iou[max_iou_arg, detection_index] > 0.0:  # match
                attributions.append(
                    (
                        max_iou_arg,
                        iou[max_iou_arg, detection_index],
                        detections[detection_index][1],
                        detections[detection_index][2],
                    )
                )
                iou[max_iou_arg] = 0.0  # this (ground-truth) target has been attributed
            else:  # no match
                attributions.append((-1, 0.0, detections[detection_index][1], 0.0))

        retval.append(attributions)

    return retval


def _mean_or_zero(values: typing.Sequence[float]) -> numpy.floating:
    """Return the mean of ``values``, or zero when there is nothing to average."""
    if not values:
        # numpy.mean() of an empty sequence is NaN *and* emits a RuntimeWarning; the
        # explicit test gives the same zero result without the noise
        return numpy.float64(0.0)
    return numpy.nan_to_num(numpy.mean(values))


def run(
    predictions: typing.Sequence[Prediction],
    binning: str | int,
    iou_threshold: float | None = None,
) -> dict[str, typing.Any]:
    """Calculate IOU-based measures for multilabel object detection predictions.

    Parameters
    ----------
    predictions
        A list of predictions to consider for measurement.
    binning
        The binning algorithm to use for computing the bin widths and distribution
        for histograms.  Choose from algorithms supported by
        :py:func:`numpy.histogram`.
    iou_threshold
        IOU threshold by which we consider successful object detection.  Detections
        whose IOU is below this value are discarded before any measure is computed.
        If set to ``None``, then apply no thresholding.

    Returns
    -------
        A dictionary containing the performance summary on the specified threshold,
        with the following keys:

        * ``num-samples``: number of entries in ``predictions``
        * ``num-targets``: total number of target bounding-boxes
        * ``num-detections``: number of detections retained after thresholding
        * ``mean-iou``: mean IOU over retained detections that matched a target
          (zero if there are none)
        * ``iou-histogram``: dictionary with keys ``hist`` and ``bin_edges``, as
          returned by :py:func:`numpy.histogram`, computed over the range [0, 1]
        * ``per-class``: for each class found among the targets, a dictionary with
          keys ``mean-iou``, ``num-targets`` and ``num-detections``
        * ``per-class-iou-histogram``: for each class found among the targets, a
          dictionary shaped like ``iou-histogram``
    """

    detailed_iou = _compute_iou_from_predictions(predictions)

    if iou_threshold is None:
        filtered_iou = detailed_iou
    else:
        filtered_iou = [[k for k in j if k[1] >= iou_threshold] for j in detailed_iou]

    # one (target-index, iou, class, score) tuple per retained detection
    attributions = [k for j in filtered_iou for k in j]

    iou_histogram = dict(
        zip(
            ("hist", "bin_edges"),
            numpy.histogram([k[1] for k in attributions], bins=binning, range=(0, 1)),
        )
    )

    # the mean-iou only accounts the IoU for matches - non-matches are ignored on this
    # metric (test: k[0] >= 0)
    mean_iou = _mean_or_zero([k[1] for k in attributions if k[0] >= 0])

    # classes are defined by what exists in the ground-truth
    target_labels = [k[1] for j in predictions for k in j[1]]
    classes = sorted(set(target_labels))

    num_targets_per_class = collections.Counter(target_labels)
    num_detections_per_class = collections.Counter(k[2] for k in attributions)

    # groups IOU values by class in a single pass: histograms consider every retained
    # detection, whereas means consider matches only
    iou_per_class: dict[typing.Any, list[float]] = collections.defaultdict(list)
    matched_iou_per_class: dict[typing.Any, list[float]] = collections.defaultdict(
        list
    )
    for target_index, iou, label, _ in attributions:
        iou_per_class[label].append(iou)
        if target_index >= 0:
            matched_iou_per_class[label].append(iou)

    mean_iou_per_class = {
        cl: _mean_or_zero(matched_iou_per_class.get(cl, [])) for cl in classes
    }

    per_class_histogram = {
        cl: dict(
            zip(
                ("hist", "bin_edges"),
                numpy.histogram(
                    iou_per_class.get(cl, []), bins=binning, range=(0, 1)
                ),
            )
        )
        for cl in classes
    }

    return {
        "num-samples": len(predictions),
        "num-targets": sum(len(k[1]) for k in predictions),
        "num-detections": len(attributions),
        "mean-iou": mean_iou,
        "iou-histogram": iou_histogram,
        "per-class": {
            cl: {
                "mean-iou": mean_iou_per_class[cl],
                "num-targets": num_targets_per_class[cl],
                "num-detections": num_detections_per_class[cl],
            }
            for cl in classes
        },
        "per-class-iou-histogram": per_class_histogram,
    }
def make_table(
    data: typing.Mapping[str, typing.Mapping[str, typing.Any]],
    fmt: str,
) -> str:
    """Tabulate summaries from multiple splits.

    This function can properly tabulate the various summaries produced for all the
    splits in a prediction database.  Histogram and per-class entries
    (``iou-histogram``, ``per-class`` and ``per-class-iou-histogram``) are left out
    of the table.

    Parameters
    ----------
    data
        A mapping from split names to the summary data collected for each split.
        The columns of the table are taken from the first split; all other splits
        must provide the same keys.
    fmt
        One of the formats supported by `python-tabulate
        <https://pypi.org/project/tabulate/>`_.

    Returns
    -------
    str
        A string containing the tabulated information, or an empty string if
        ``data`` is empty.
    """

    if not data:
        # nothing to tabulate: there is no split from which to derive the columns
        return ""

    excluded = ("iou-histogram", "per-class", "per-class-iou-histogram")

    # keeps only the entries that fit in a table cell
    table_data = {
        split: {key: value for key, value in summary.items() if key not in excluded}
        for split, summary in data.items()
    }

    # the first split dictates which columns are shown, and in which order
    columns = list(next(iter(table_data.values())))

    # each row starts with the name of the split it summarizes
    table = [
        [split, *(summary[column] for column in columns)]
        for split, summary in table_data.items()
    ]

    return tabulate.tabulate(
        table, ["subset", *columns], tablefmt=fmt, floatfmt=".3f"
    )
def _iou_histogram_plot(
    histograms: dict[str, numpy.typing.NDArray]
    | dict[int, dict[str, numpy.typing.NDArray]],
    title: str,
    threshold: float | None,
) -> matplotlib.figure.Figure:
    """Draw one histogram, or one histogram per class, as a bar plot.

    Parameters
    ----------
    histograms
        Either a single histogram, or a dictionary of histograms keyed by (integer)
        class.  A histogram is a dictionary containing the keys ``hist`` and
        ``bin_edges`` as returned by :py:func:`numpy.histogram`.  The type of the
        first key decides which of the two layouts is assumed.
    title
        Title of the plot.
    threshold
        Shows where the threshold is in the figure.  If set to ``None``, then does
        not show the threshold line.

    Returns
    -------
    matplotlib.figure.Figure
        A single (matplotlib) plot containing the distribution(s), ready to be saved
        to disk or displayed.
    """

    from matplotlib.ticker import MaxNLocator

    fig, ax = plt.subplots(1, 1)
    assert isinstance(fig, matplotlib.figure.Figure)
    ax = typing.cast(matplotlib.axes.Axes, ax)  # gets editor to behave

    # General styling
    # NOTE(review): the bars appear to hold IOU values (see the "IOU threshold"
    # legend entry below), yet the x-axis is labelled "Score" - confirm intent
    ax.set_xlim((0, 1))
    ax.set_title(title)
    ax.set_xlabel("Score")
    ax.set_ylabel("Count")

    # Ticks are drawn on the left and bottom spines only
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
    ax.get_yaxis().set_major_locator(MaxNLocator(integer=True))

    # Horizontal grid lines only
    ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    ax.get_xaxis().grid(False)

    # Normalizes the input into (histogram, extra-bar-options) pairs, so that a
    # single loop can draw both layouts
    per_class = isinstance(next(iter(histograms.keys())), int)
    series: list[tuple[dict[str, numpy.typing.NDArray], dict[str, str]]]
    if per_class:
        by_class = typing.cast(dict[int, dict[str, numpy.typing.NDArray]], histograms)
        # per-class diagram: each set of bars is labelled by its class
        series = [(by_class[cl], {"label": str(cl)}) for cl in by_class.keys()]
    else:
        # single histogram: bars carry no label
        series = [(typing.cast(dict[str, numpy.typing.NDArray], histograms), {})]

    tallest = 0
    for entry, extra in series:
        counts = entry["hist"]
        edges = entry["bin_edges"]
        ax.bar(
            (edges[:-1] + edges[1:]) / 2,
            counts,
            align="center",
            width=0.7 * (edges[1] - edges[0]),
            alpha=0.7,
            **extra,
        )
        tallest = max(tallest, counts.max())

    # a legend is needed as soon as something on the plot carries a label
    show_legend = per_class

    # Detach axes from the plot
    ax.spines["left"].set_position(("data", -0.015))
    ax.spines["bottom"].set_position(("data", -0.015 * tallest))

    if threshold is not None:
        # Adds threshold line (dotted red)
        ax.axvline(
            threshold,  # type: ignore
            color="red",
            lw=2,
            alpha=0.75,
            ls="dotted",
            label="IOU threshold",
        )
        show_legend = True

    if show_legend:
        ax.legend(
            fancybox=True,
            framealpha=0.7,
        )

    # Makes sure the figure occupies most of the possible space
    fig.tight_layout()

    return fig
[docs] def make_plots( results: dict[str, dict[str, typing.Any]], iou_threshold: float | None = None ) -> list: """Create plots for all curves and score distributions in ``results``. Parameters ---------- results Evaluation data as returned by :py:func:`run`. iou_threshold IOU threshold by which we consider successful object detection. If set, it is shown on plots. Returns ------- A list of figures to record to file """ retval = [] # score plots for split_name, data in results.items(): retval.append( _iou_histogram_plot( data["iou-histogram"], f"IOU distribution (split: {split_name})", threshold=iou_threshold, ) ) if len(data["per-class-iou-histogram"]) > 1: retval.append( _iou_histogram_plot( data["per-class-iou-histogram"], f"IOU distribution per class (split: {split_name})", threshold=iou_threshold, ) ) return retval