Source code for mednet.engine.classify.evaluator

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines functionality for the evaluation of classification predictions."""

import logging
import typing
from collections.abc import Iterable

import credible.bayesian.metrics
import credible.curves
import credible.plot
import matplotlib.axes
import matplotlib.figure
import numpy
import numpy.typing
import sklearn.metrics
import tabulate
from matplotlib import pyplot as plt
from sklearn.preprocessing import LabelBinarizer

from ...models.classify.typing import Prediction

logger = logging.getLogger(__name__)


def eer_threshold(predictions: Iterable[Prediction]) -> float:
    """Calculate the (approximate) threshold leading to the equal error rate.

    For multi-label problems, calculate the EER threshold in the "micro" sense
    by first rasterizing all scores and labels (with :py:func:`numpy.ravel`),
    and then using this (large) 1D vector like in a binary classifier.

    Parameters
    ----------
    predictions
        An iterable of multiple
        :py:data:`.models.classify.typing.Prediction`'s.  It is consumed only
        once, so single-pass iterables (e.g. generators) are supported.

    Returns
    -------
    float
        The EER threshold value.
    """

    from scipy.interpolate import interp1d
    from scipy.optimize import brentq

    # Materialize once: ``predictions`` is only promised to be an Iterable, and
    # a generator would be exhausted after the first comprehension, leaving
    # the label vector empty.
    samples = list(predictions)
    y_scores = numpy.array([k[2] for k in samples]).ravel()
    y_labels = numpy.array([k[1] for k in samples]).ravel()

    fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_labels, y_scores)

    # Build the ROC interpolator once instead of on every root-finder step.
    roc = interp1d(fpr, tpr)

    # The EER operating point is where FNR (1 - TPR) equals FPR, i.e. the root
    # of ``1 - x - TPR(x)`` for x (the FPR) in [0, 1].
    eer = brentq(lambda x: 1.0 - x - roc(x), 0.0, 1.0)

    # Map the FPR at the EER back to a score threshold.
    return float(interp1d(fpr, thresholds)(eer))
def _get_centered_maxf1(
    f1_scores: numpy.typing.NDArray,
    thresholds: numpy.typing.NDArray,
) -> tuple[float, float]:
    """Pick the maximum F1-score and a representative threshold for it.

    When the maximum F1-score is reached at several positions, the threshold
    located at the (rounded) mean of those positions is selected.

    Parameters
    ----------
    f1_scores
        1D array of f1 scores.
    thresholds
        1D array of thresholds.

    Returns
    -------
    tuple(float, float)
        A tuple with the maximum F1-score and the "centered" threshold.
    """

    best = f1_scores.max()

    # positions at which the best score is attained
    hits = numpy.nonzero(f1_scores == best)[0]

    # a single hit is used directly; ties are resolved by taking the position
    # at the rounded mean of all tied positions
    chosen = hits[0] if len(hits) <= 1 else int(round(numpy.mean(hits)))

    return best, thresholds[chosen]
def maxf1_threshold(predictions: Iterable[Prediction]) -> float:
    """Calculate the threshold leading to the maximum F1-score on a precision-
    recall curve.

    For multi-label problems, calculate the maximum F1-score threshold in the
    "micro" sense by first rasterizing all scores and labels (with
    :py:func:`numpy.ravel`), and then using this (large) 1D vector like in a
    binary classifier.

    Parameters
    ----------
    predictions
        An iterable of multiple
        :py:data:`.models.classify.typing.Prediction`'s.  It is consumed only
        once, so single-pass iterables (e.g. generators) are supported.

    Returns
    -------
    float
        The threshold value leading to the maximum F1-score on the provided set
        of predictions.
    """

    # Materialize once: ``predictions`` is only promised to be an Iterable, and
    # a generator would be exhausted after the first comprehension, leaving
    # the label vector empty.
    samples = list(predictions)
    y_scores = numpy.array([k[2] for k in samples]).ravel()
    y_labels = numpy.array([k[1] for k in samples]).ravel()

    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
        y_labels,
        y_scores,
    )

    # F1 = 2PR / (P + R); positions where P + R == 0 are defined as F1 = 0
    # instead of producing a division-by-zero warning/NaN.
    numerator = 2 * recall * precision
    denom = recall + precision
    f1_scores = numpy.divide(
        numerator,
        denom,
        out=numpy.zeros_like(denom),
        where=(denom != 0),
    )

    _, best_threshold = _get_centered_maxf1(f1_scores, thresholds)

    # Return a built-in float (as annotated, and as ``eer_threshold`` does)
    # rather than a numpy scalar whose dtype follows the input scores.
    return float(best_threshold)
def run_single(
    name: str,
    predictions: typing.Sequence[Prediction],
    binning: str | int,
    rng: numpy.random.Generator,
    threshold_a_priori: float | None = None,
    credible_regions: bool = False,
) -> dict[str, typing.Any]:
    """Evaluate binary or multilabel predictions of a single split.

    For multi-label problems, the metrics are calculated in the "micro" sense:
    all scores and labels are first rasterized (with :py:func:`numpy.ravel`),
    and the resulting (large) 1D vectors are handled like the output of a
    binary classifier.

    Parameters
    ----------
    name
        The name of the split being evaluated (only used in log messages).
    predictions
        A list of predictions to consider for measurement.
    binning
        The binning algorithm to use for computing the bin widths and
        distribution for histograms.  Choose from algorithms supported by
        :py:func:`numpy.histogram`.
    rng
        An initialized numpy random number generator.  It is only forwarded to
        the credible F1-score estimation.
    threshold_a_priori
        A threshold to use, evaluated *a priori*, if must report single
        values.  If this value is not provided, an *a posteriori* threshold
        (maximum F1-score) is calculated on the input scores.  This is a
        biased estimator.
    credible_regions
        If set to ``True``, then returns also credible intervals via
        :py:mod:`credible.bayesian.metrics`.  Notice the evaluation of ROC-AUC
        and Average Precision confidence margins can be rather slow for larger
        datasets.

    Returns
    -------
    dict[str, typing.Any]
        A dictionary containing the performance summary on the specified
        threshold, general performance curves (under the key ``curves``),
        and score histograms (under the key ``score-histograms``).
    """

    scores = numpy.array([p[2] for p in predictions]).ravel()
    label_matrix = numpy.array([p[1] for p in predictions])

    # the label matrix must be 2D: (samples, classes)
    num_samples, num_classes = label_matrix.shape
    labels = label_matrix.ravel()

    # the smallest/largest label values act as negative/positive classes
    neg_label = labels.min()
    pos_label = labels.max()

    if threshold_a_priori is None:
        use_threshold = maxf1_threshold(predictions)
        logger.warning(
            f"User did not pass an *a priori* threshold for the evaluation "
            f"of split `{name}`. Using threshold a posteriori (biased) with value "
            f"`{use_threshold:.4f}`",
        )
    else:
        use_threshold = threshold_a_priori

    # hard decisions at the selected threshold
    decisions = numpy.where(scores >= use_threshold, pos_label, neg_label)

    # NOTE: insertion order matters, it defines the column order of tables
    summary = {
        "num_samples": num_samples,
        "num_classes": num_classes,
        "threshold": use_threshold,
        "threshold_a_posteriori": (threshold_a_priori is None),
        "precision": sklearn.metrics.precision_score(
            labels, decisions, pos_label=pos_label
        ),
        "recall": sklearn.metrics.recall_score(
            labels, decisions, pos_label=pos_label
        ),
        "f1": sklearn.metrics.f1_score(labels, decisions, pos_label=pos_label),
        "average_precision": sklearn.metrics.average_precision_score(
            labels, scores, pos_label=pos_label
        ),
        # specificity is the recall of the negative class
        "specificity": sklearn.metrics.recall_score(
            labels, decisions, pos_label=neg_label
        ),
        "roc_auc": sklearn.metrics.roc_auc_score(labels, scores),
        "accuracy": sklearn.metrics.accuracy_score(labels, decisions),
    }

    if credible_regions:
        logger.info(
            f"Computing credible regions for metrics on split `{name}` "
            f"(samples = {len(predictions)}) - "
            f"note this can be slow on very large datasets..."
        )

        bayes = credible.bayesian.metrics
        f1 = bayes.f1_score(labels, decisions, rng=rng)
        roc_auc = bayes.roc_auc_score(labels, scores)
        precision = bayes.precision_score(labels, decisions)
        recall = bayes.recall_score(labels, decisions)
        average_precision = bayes.average_precision_score(labels, scores)
        specificity = bayes.specificity_score(labels, decisions)
        accuracy = bayes.accuracy_score(labels, decisions)

        # threshold-based metrics are stored as 4 values, curve-based ones
        # (average precision and ROC-AUC) as 3 values
        posterior = ("mean", "mode", "lo", "hi")
        interval = ("exact", "lo", "hi")
        layout = (
            ("precision", precision, posterior),
            ("recall", recall, posterior),
            ("f1", f1, posterior),
            ("average_precision", average_precision, interval),
            ("specificity", specificity, posterior),
            ("roc_auc", roc_auc, interval),
            ("accuracy", accuracy, posterior),
        )
        for prefix, values, suffixes in layout:
            for position, suffix in enumerate(suffixes):
                summary[f"{prefix}_{suffix}"] = values[position]

    # curves: ROC and precision recall
    roc_curve = sklearn.metrics.roc_curve(labels, scores, pos_label=pos_label)
    pr_curve = sklearn.metrics.precision_recall_curve(
        labels, scores, pos_label=pos_label
    )
    summary["curves"] = {
        "roc": dict(zip(("fpr", "tpr", "thresholds"), roc_curve)),
        "precision_recall": dict(
            zip(("precision", "recall", "thresholds"), pr_curve)
        ),
    }

    # score histograms
    # what works: <integer>, doane*, scott, stone, rice*, sturges*, sqrt
    # what does not work: auto, fd
    def _histogram_of(label_value: typing.Any) -> dict[str, typing.Any]:
        # histogram of the scores of all samples carrying ``label_value``
        return dict(
            zip(
                ("hist", "bin_edges"),
                numpy.histogram(
                    scores[labels == label_value],
                    bins=binning,
                    range=(0, 1),
                ),
            ),
        )

    summary["score-histograms"] = {
        "positives": _histogram_of(pos_label),
        "negatives": _histogram_of(neg_label),
    }

    return summary
def make_table(
    data: typing.Mapping[str, typing.Mapping[str, typing.Any]],
    fmt: str,
) -> str:
    """Tabulate summaries from multiple splits.

    Only scalar entries of each summary are tabulated: curves, histograms,
    confusion matrices, credible-region bounds and per-class/micro-averaged
    values are left out.  Column headers are taken from the first split.

    Parameters
    ----------
    data
        A mapping from split names to the summary collected for that split.
    fmt
        One of the formats supported by `python-tabulate
        <https://pypi.org/project/tabulate/>`_.

    Returns
    -------
    str
        A string containing the tabulated information.
    """

    skipped_keys = ("curves", "score-histograms", "confusion_matrix")
    skipped_suffixes = (
        "_mean",
        "_mode",
        "_hi",
        "_lo",
        "_exact",
        "_per_class",
        "_micro",
    )

    def _is_column(key: str) -> bool:
        # True for the entries that should show up as table columns
        if key in skipped_keys:
            return False
        return not key.endswith(skipped_suffixes)

    # keep only what can be displayed on a table cell
    table_data = {
        split: {key: value for key, value in summary.items() if _is_column(key)}
        for split, summary in data.items()
    }

    # the first split defines the columns (and their order)
    first_row = next(iter(table_data.values()))
    columns = list(first_row.keys())

    # one row per split, prefixed by the split name
    rows = [
        [split] + [row[column] for column in columns]
        for split, row in table_data.items()
    ]

    return tabulate.tabulate(
        rows, ["subset"] + columns, tablefmt=fmt, floatfmt=".3f"
    )
def _score_plot(
    histograms: dict[str, dict[str, numpy.typing.NDArray]],
    title: str,
    threshold: float | None,
) -> matplotlib.figure.Figure:
    """Draw the score distributions of all provided histograms on one figure.

    Parameters
    ----------
    histograms
        A dictionary containing all histograms that should be inserted into
        the plot.  Each histogram should itself be setup as another dictionary
        containing the keys ``hist`` and ``bin_edges`` as returned by
        :py:func:`numpy.histogram`.
    title
        Title of the plot.
    threshold
        Shows where the threshold is in the figure.  If set to ``None``, then
        does not show the threshold line.

    Returns
    -------
    matplotlib.figure.Figure
        A single (matplotlib) plot containing the score distribution, ready to
        be saved to disk or displayed.
    """

    from matplotlib.ticker import MaxNLocator

    fig, ax = plt.subplots(1, 1)
    assert isinstance(fig, matplotlib.figure.Figure)
    ax = typing.cast(matplotlib.axes.Axes, ax)  # gets editor to behave

    # -- plot "style" --
    ax.set_xlim((0, 1))
    ax.set_title(title)
    ax.set_xlabel("Score")
    ax.set_ylabel("Count")

    # ticks only on the left and bottom spines
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    x_axis = ax.get_xaxis()
    y_axis = ax.get_yaxis()
    x_axis.tick_bottom()
    y_axis.tick_left()
    y_axis.set_major_locator(MaxNLocator(integer=True))

    # horizontal grid lines only
    ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    x_axis.grid(False)

    # -- one bar series per histogram --
    tallest = 0
    for label, histogram in histograms.items():
        counts = histogram["hist"]
        edges = histogram["bin_edges"]
        ax.bar(
            (edges[:-1] + edges[1:]) / 2,  # bin centers
            counts,
            align="center",
            width=0.7 * (edges[1] - edges[0]),
            label=label,
            alpha=0.7,
        )
        tallest = max(tallest, counts.max())

    # detach the axes from the plot; the vertical offset is scaled by the
    # tallest bar
    ax.spines["left"].set_position(("data", -0.015))
    ax.spines["bottom"].set_position(("data", -0.015 * tallest))

    if threshold is not None:
        # threshold line (dotted red)
        ax.axvline(
            threshold,  # type: ignore
            color="red",
            lw=2,
            alpha=0.75,
            ls="dotted",
            label="threshold",
        )

    ax.legend(fancybox=True, framealpha=0.7)

    # makes sure the figure occupies most of the possible space
    fig.tight_layout()

    return fig
def make_plots_single(results: dict[str, dict[str, typing.Any]]) -> list:
    """Create plots for all curves and score distributions in ``results``.

    The returned list contains, in this order: one ROC plot and one
    precision-recall plot (each with one curve per split), followed by one
    score-distribution plot per split.

    Parameters
    ----------
    results
        Evaluation data as returned by :py:func:`run_single`, organized as a
        dictionary mapping split names to their summaries.

    Returns
    -------
        A list of figures to record to file
    """

    retval = []

    with credible.plot.tight_layout(
        ("False Positive Rate", "True Positive Rate"), "ROC"
    ) as (fig, ax):
        for split_name, data in results.items():
            roc = data["curves"]["roc"]
            _auroc = credible.curves.area_under_the_curve(
                (roc["fpr"], roc["tpr"]),
            )
            ax.plot(
                roc["fpr"],
                roc["tpr"],
                label=f"{split_name} (AUC: {_auroc:.2f})",
            )

        ax.legend(loc="best", fancybox=True, framealpha=0.7)
        retval.append(fig)

    with credible.plot.tight_layout_f1iso(
        ("Recall", "Precision"), "Precison-Recall"
    ) as (fig, ax):
        for split_name, data in results.items():
            pr_curve = data["curves"]["precision_recall"]
            # The average precision must be computed from the curve being
            # plotted.  ``data["precision"]`` and ``data["recall"]`` are the
            # scalar values at the selected threshold, not curve arrays.
            _ap = credible.curves.average_metric(
                (pr_curve["precision"], pr_curve["recall"]),
            )
            ax.plot(
                pr_curve["recall"],
                pr_curve["precision"],
                label=f"{split_name} (AP: {_ap:.2f})",
            )

        ax.legend(loc="best", fancybox=True, framealpha=0.7)
        retval.append(fig)

    # score plots, one per split
    for split_name, data in results.items():
        score_fig = _score_plot(
            data["score-histograms"],
            f"Score distribution (split: {split_name})",
            data["threshold"],
        )
        retval.append(score_fig)

    return retval
def run_multi(
    predictions: typing.Sequence[Prediction],
) -> dict[str, typing.Any]:
    """Evaluate multiclass predictions of a single split.

    It computes the recall, precision, f1, auc and aupr for each single class
    and also with both macro and micro averaging.  ROC curves are all computed
    following the OvR scheme macro-averaged.  Hard decisions are taken as the
    class with the highest score.

    Parameters
    ----------
    predictions
        A list of predictions to consider for measurement.

    Returns
    -------
    dict[str, typing.Any]
        A dictionary containing the performance summary, general performance
        curves (under the key ``curves``), and confusion matrices (under the
        key ``confusion_matrix``).
    """

    probabilities = numpy.array([p[2] for p in predictions])
    targets = numpy.array([p[1] for p in predictions])

    num_samples = targets.size
    # the number of classes is derived from the largest label value found
    num_classes = numpy.max(targets) + 1

    targets = targets.ravel()
    decisions = numpy.argmax(probabilities, axis=1)

    # (key suffix, value of sklearn's ``average`` parameter)
    averagings = (("per_class", None), ("macro", "macro"), ("micro", "micro"))

    # NOTE: insertion order matters, it defines the column order of tables
    summary: dict[str, typing.Any] = {
        "num_samples": num_samples,
        "num_classes": num_classes,
    }

    # metrics evaluated on hard decisions
    decision_metrics = (
        ("precision", sklearn.metrics.precision_score),
        ("recall", sklearn.metrics.recall_score),
        ("f1", sklearn.metrics.f1_score),
    )
    for prefix, metric in decision_metrics:
        for suffix, average in averagings:
            summary[f"{prefix}_{suffix}"] = metric(
                targets, decisions, average=average
            )

    # metrics evaluated on the per-class scores
    for suffix, average in averagings:
        summary[f"average_precision_{suffix}"] = (
            sklearn.metrics.average_precision_score(
                targets, probabilities, average=average
            )
        )
    for suffix, average in averagings:
        summary[f"roc_auc_{suffix}"] = sklearn.metrics.roc_auc_score(
            targets, probabilities, average=average, multi_class="ovr"
        )

    summary["accuracy"] = sklearn.metrics.accuracy_score(targets, decisions)
    # I would like to add the MCP and IMCP scores

    # curves: ROC and precision recall.  Placeholders are registered first so
    # that all ROC entries precede all precision-recall entries.
    curves: dict[str, typing.Any] = {
        f"roc_{class_id}": dict() for class_id in range(num_classes)
    }
    curves.update(
        {f"precision_recall_{class_id}": dict() for class_id in range(num_classes)}
    )
    summary["curves"] = curves

    # Binarize the labels
    onehot = LabelBinarizer().fit_transform(targets)

    # Store one roc and precision recall curve for each class
    # In the meanwhile compute the macro ROC curve
    fpr_grid_macro = numpy.linspace(0.0, 1.0, 1000)
    mean_tpr = numpy.zeros_like(fpr_grid_macro)

    for class_id in range(num_classes):
        class_targets = onehot[:, class_id]  # type: ignore
        class_scores = probabilities[:, class_id]

        roc = dict(
            zip(
                ("fpr", "tpr", "thresholds"),
                sklearn.metrics.roc_curve(class_targets, class_scores),
            ),
        )
        curves[f"roc_{class_id}"] = roc

        curves[f"precision_recall_{class_id}"] = dict(
            zip(
                ("precision", "recall", "thresholds"),
                sklearn.metrics.precision_recall_curve(
                    class_targets, class_scores
                ),
            ),
        )

        # linear interpolation of this class' ROC on the common FPR grid
        mean_tpr += numpy.interp(fpr_grid_macro, roc["fpr"], roc["tpr"])

    # Average the accumulated TPRs to obtain the macro ROC curve
    mean_tpr /= num_classes

    curves["roc_fpr_macro"] = fpr_grid_macro  # type: ignore
    curves["roc_tpr_macro"] = mean_tpr

    # Confusion matrix
    summary["confusion_matrix"] = sklearn.metrics.confusion_matrix(
        targets, decisions
    )

    return summary
def make_plots_multi(results: dict[str, dict[str, typing.Any]]) -> list:
    """Create plots for all curves and confusion matrices in ``results``.

    The returned list contains, in this order: one macro-averaged OvR ROC plot
    (one curve per split), one per-class OvR ROC plot for each split, one
    per-class precision-recall plot for each split and, finally, one confusion
    matrix plot for each split.

    Parameters
    ----------
    results
        Evaluation data as returned by :py:func:`run_multi`, organized as a
        dictionary mapping split names to their summaries.

    Returns
    -------
        A list of figures to record to file
    """

    roc_axes = ("False Positive Rate", "True Positive Rate")
    legend_style = dict(loc="best", fancybox=True, framealpha=0.7)

    retval = []

    # ROC macro averaged (OvR). One ROC for each split
    with credible.plot.tight_layout(
        roc_axes, "Macro-averaged One-vs-Rest ROC"
    ) as (fig, ax):
        for split_name, data in results.items():
            macro_fpr = data["curves"]["roc_fpr_macro"]
            macro_tpr = data["curves"]["roc_tpr_macro"]
            _auroc = credible.curves.area_under_the_curve((macro_fpr, macro_tpr))
            ax.plot(
                macro_fpr,
                macro_tpr,
                label=f"{split_name} (AUC: {_auroc:.2f})",
            )

        ax.legend(**legend_style)
        retval.append(fig)

    # Plot all OvR ROC curves per class together. Each split has its own plot
    # with num_classes ROC curves
    for split_name, data in results.items():
        with credible.plot.tight_layout(
            roc_axes,
            f"{split_name} - OvR ROC for each class",
        ) as (fig, ax):
            for class_id in range(data["num_classes"]):
                roc = data["curves"][f"roc_{class_id}"]
                _auroc = credible.curves.area_under_the_curve(
                    (roc["fpr"], roc["tpr"]),
                )
                ax.plot(
                    roc["fpr"],
                    roc["tpr"],
                    label=f"Class {class_id} (AUC: {_auroc:.2f})",
                )

            ax.legend(**legend_style)
            retval.append(fig)

    # Plot all Precision-Recall curves per class together.
    for split_name, data in results.items():
        with credible.plot.tight_layout(
            ("Recall", "Precision"),
            f"{split_name} - Precision-Recall curve for each class",
        ) as (fig, ax):
            for class_id in range(data["num_classes"]):
                pr_curve = data["curves"][f"precision_recall_{class_id}"]
                _ap = credible.curves.average_metric(
                    (pr_curve["precision"], pr_curve["recall"]),
                )
                ax.plot(
                    pr_curve["recall"],
                    pr_curve["precision"],
                    label=f"Class {class_id} (AP: {_ap:.2f})",
                )

            ax.legend(**legend_style)
            retval.append(fig)

    # Plot the confusion matrices. One for each split
    for split_name, data in results.items():
        fig, ax = plt.subplots()
        display = sklearn.metrics.ConfusionMatrixDisplay(
            confusion_matrix=data["confusion_matrix"]
        )
        display.plot(ax=ax, cmap=plt.colormaps.get_cmap("Blues"))
        ax.set_title(f"Confusion Matrix - {split_name}")
        retval.append(fig)

    return retval
def run(
    name: str,
    predictions: typing.Sequence[Prediction],
    binning: str | int,
    rng: numpy.random.Generator,
    threshold_a_priori: float | None = None,
    credible_regions: bool = False,
) -> dict[str, typing.Any]:
    """Evaluate binary, multilabel or multiclass predictions of a split.

    The problem type is automatically detected from the first prediction, and
    the arguments are routed to the pertinent function
    (:py:func:`run_single` for binary and multilabel, :py:func:`run_multi` for
    multiclass classification problems).

    For multi-label problems, the metrics are calculated in the "micro" sense
    by first rasterizing all scores and labels (with :py:func:`numpy.ravel`),
    and then using this (large) 1D vector like in a binary classifier.

    Parameters
    ----------
    name
        The name of the split being evaluated.
    predictions
        A list of predictions to consider for measurement.
    binning
        The binning algorithm to use for computing the bin widths and
        distribution for histograms.  Choose from algorithms supported by
        :py:func:`numpy.histogram`.  Only used for binary and multilabel
        problems.
    rng
        An initialized numpy random number generator.  Only used for binary
        and multilabel problems.
    threshold_a_priori
        A threshold to use, evaluated *a priori*, if must report single
        values.  If this value is not provided, an *a posteriori* threshold is
        calculated on the input scores.  This is a biased estimator.  Only
        used for binary and multilabel problems.
    credible_regions
        If set to ``True``, then returns also credible intervals via
        :py:mod:`credible.bayesian.metrics`.  Notice the evaluation of ROC-AUC
        and Average Precision confidence margins can be rather slow for larger
        datasets.  Only used for binary and multilabel problems.

    Returns
    -------
    dict[str, typing.Any]
        For **binary** and **multilabel** classification problem: a dictionary
        containing the performance summary on the specified threshold, general
        performance curves (under the key ``curves``), and score histograms
        (under the key ``score-histograms``).

        For **multiclass** classification problem: a dictionary containing the
        performance summary, general performance curves (under the key
        ``curves``), and confusion matrices (under the key
        ``confusion_matrix``).
    """

    first = predictions[0]
    num_targets = len(first[1])
    num_outputs = len(first[2])

    # Binary and multilabel problems have targets and predictions of the same
    # length: 1 for binary and C for multilabel.  On multiclass problems, the
    # prediction has C values, but the target is a single value in [0, C-1].
    if num_outputs != num_targets:
        return run_multi(predictions=predictions)

    return run_single(
        name=name,
        predictions=predictions,
        binning=binning,
        rng=rng,
        threshold_a_priori=threshold_a_priori,
        credible_regions=credible_regions,
    )
def make_plots(results: dict[str, dict[str, typing.Any]]) -> list:
    """Create plots for all curves and score distributions in ``results``.

    Summaries of binary/multilabel problems are recognized by the presence of
    the ``score-histograms`` key.  If at least one split has it, plotting is
    delegated to :py:func:`make_plots_single`, otherwise to
    :py:func:`make_plots_multi`.

    Parameters
    ----------
    results
        Evaluation data as returned by :py:func:`run`, organized as a
        dictionary mapping split names to their summaries.

    Returns
    -------
        A list of figures to record to file
    """

    has_score_histograms = any(
        "score-histograms" in summary for summary in results.values()
    )

    plotter = make_plots_single if has_score_histograms else make_plots_multi
    return plotter(results=results)