# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines functionality for the evaluation of predictions."""
import logging
import typing
from collections.abc import Iterable
import credible.bayesian.metrics
import credible.curves
import credible.plot
import matplotlib.axes
import matplotlib.figure
import numpy
import numpy.typing
import sklearn.metrics
import tabulate
from matplotlib import pyplot as plt
from ...models.classify.typing import BinaryPrediction
logger = logging.getLogger(__name__)
# [docs]  (Sphinx "viewcode" link residue from HTML extraction; not part of the code)
def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float:
    """Calculate the (approximate) threshold leading to the equal error rate.

    The ROC curve of the predictions is linearly interpolated, and the point
    where the false-positive rate equals the false-negative rate
    (``1 - tpr``) is located with a root finder.  The threshold is then
    interpolated at that point.

    Parameters
    ----------
    predictions
        An iterable of multiple
        :py:data:`.models.classify.typing.BinaryPrediction`'s.  Element ``1``
        of each prediction is used as the label and element ``2`` as the
        score.  The iterable is consumed exactly once, so one-shot iterators
        (e.g. generators) are supported.

    Returns
    -------
    float
        The EER threshold value.
    """
    from scipy.interpolate import interp1d
    from scipy.optimize import brentq

    # ``predictions`` is only guaranteed to be an Iterable: materialize it so
    # that reading scores and labels does not exhaust a one-shot iterator.
    samples = list(predictions)
    y_scores = [k[2] for k in samples]
    y_labels = [k[1] for k in samples]

    fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_labels, y_scores)

    # Build the interpolator once instead of on every root-finder evaluation.
    tpr_at = interp1d(fpr, tpr)

    # Root of ``(1 - fpr) - tpr(fpr)``: the false-positive rate at which the
    # false-negative rate (1 - tpr) equals the false-positive rate.
    eer = brentq(lambda x: 1.0 - x - tpr_at(x), 0.0, 1.0)

    return float(interp1d(fpr, thresholds)(eer))
def _get_centered_maxf1(
    f1_scores: numpy.typing.NDArray,
    thresholds: numpy.typing.NDArray,
) -> tuple[float, float]:
    """Pick the maximum F1-score and a "centered" threshold for it.

    When several positions of ``f1_scores`` share the maximum value, the
    threshold returned is the one located at the (rounded) mean of those
    positions.  Otherwise, the threshold at the single maximum is returned.

    Parameters
    ----------
    f1_scores
        1D array of f1 scores.
    thresholds
        1D array of thresholds.

    Returns
    -------
    tuple(float, float)
        A tuple with the maximum F1-score and the "centered" threshold.
    """
    best = f1_scores.max()
    hits = numpy.nonzero(f1_scores == best)[0]

    if len(hits) <= 1:
        # a single position holds the maximum: use it directly
        center = hits[0]
    else:
        # several positions tie: take the one at the rounded mean index
        center = int(round(hits.mean()))

    return best, thresholds[center]
# [docs]  (Sphinx "viewcode" link residue from HTML extraction; not part of the code)
def maxf1_threshold(predictions: Iterable[BinaryPrediction]) -> float:
    """Calculate the threshold leading to the maximum F1-score on a precision-
    recall curve.

    Parameters
    ----------
    predictions
        An iterable of multiple
        :py:data:`.models.classify.typing.BinaryPrediction`'s.  Element ``1``
        of each prediction is used as the label and element ``2`` as the
        score.  The iterable is consumed exactly once, so one-shot iterators
        (e.g. generators) are supported.

    Returns
    -------
    float
        The threshold value leading to the maximum F1-score on the provided set
        of predictions.  If several thresholds reach the same maximum, the
        "centered" one is returned (see :py:func:`_get_centered_maxf1`).
    """
    # ``predictions`` is only guaranteed to be an Iterable: materialize it so
    # that reading scores and labels does not exhaust a one-shot iterator.
    samples = list(predictions)
    y_scores = [k[2] for k in samples]
    y_labels = [k[1] for k in samples]

    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
        y_labels,
        y_scores,
    )

    # F1 = 2PR / (P + R); positions where P + R == 0 are defined as 0 instead
    # of triggering a division by zero.
    numerator = 2 * recall * precision
    denom = recall + precision
    f1_scores = numpy.divide(
        numerator,
        denom,
        out=numpy.zeros_like(denom),
        where=(denom != 0),
    )

    # Local name deliberately differs from the function name (no shadowing).
    _, best_threshold = _get_centered_maxf1(f1_scores, thresholds)

    # Return a built-in float, as advertised by the signature.
    return float(best_threshold)
# [docs]  (Sphinx "viewcode" link residue from HTML extraction; not part of the code)
def run_binary(
    name: str,
    predictions: Iterable[BinaryPrediction],
    binning: str | int,
    threshold_a_priori: float | None = None,
) -> dict[str, typing.Any]:
    """Calculate performance measures for binary classification predictions.

    Parameters
    ----------
    name
        The name of the subset (split) being evaluated.  It is only used in
        log messages.
    predictions
        An iterable of predictions to consider for measurement.  Element
        ``1`` of each prediction is used as the label and element ``2`` as
        the score.  The iterable is materialized once at the start, so
        one-shot iterators (e.g. generators) are supported.
    binning
        The binning algorithm to use for computing the bin widths and
        distribution for histograms.  Choose from algorithms supported by
        :py:func:`numpy.histogram`.
    threshold_a_priori
        A threshold to use, evaluated *a priori*, if must report single values.
        If this value is not provided, an *a posteriori* threshold is calculated
        on the input scores (maximum F1-score).  This is a biased estimator.

    Returns
    -------
    dict[str, typing.Any]
        A dictionary containing the performance summary on the specified
        threshold (point measures plus their ``*_mean``, ``*_mode``,
        ``*_lo``, ``*_hi`` or ``*_exact`` companions), general performance
        curves (under the key ``curves``), and score histograms (under the
        key ``score-histograms``).
    """
    # ``predictions`` is only guaranteed to be an Iterable, but it is
    # traversed several times below (scores, labels, a posteriori threshold):
    # materialize it once so one-shot iterators are not silently exhausted.
    predictions = list(predictions)

    y_scores = numpy.array([k[2] for k in predictions])  # likelihoods
    y_labels = numpy.array([k[1] for k in predictions])  # integers

    # the smallest label value is treated as "negative", the largest as
    # "positive"
    neg_label = y_labels.min()
    pos_label = y_labels.max()

    use_threshold = threshold_a_priori
    if use_threshold is None:
        use_threshold = maxf1_threshold(predictions)
        logger.warning(
            f"User did not pass an *a priori* threshold for the evaluation "
            f"of split `{name}`. Using threshold a posteriori (biased) with value "
            f"`{use_threshold:.4f}`",
        )

    # scores at or above the threshold are classified as positives
    y_predictions = numpy.where(y_scores >= use_threshold, pos_label, neg_label)

    # point measures on threshold, and confidence intervals
    f1 = credible.bayesian.metrics.f1_score(y_labels, y_predictions)
    roc_auc = credible.bayesian.metrics.roc_auc_score(y_labels, y_scores)
    precision = credible.bayesian.metrics.precision_score(y_labels, y_predictions)
    recall = credible.bayesian.metrics.recall_score(y_labels, y_predictions)
    average_precision = credible.bayesian.metrics.average_precision_score(
        y_labels, y_scores
    )
    specificity = credible.bayesian.metrics.specificity_score(y_labels, y_predictions)
    accuracy = credible.bayesian.metrics.accuracy_score(y_labels, y_predictions)

    # NOTE: key insertion order matters, it defines the column order used
    # when these summaries are tabulated.
    summary = dict(
        num_samples=len(y_labels),
        threshold=use_threshold,
        threshold_a_posteriori=(threshold_a_priori is None),
        precision=sklearn.metrics.precision_score(
            y_labels, y_predictions, pos_label=pos_label
        ),
        precision_mean=precision[0],
        precision_mode=precision[1],
        precision_lo=precision[2],
        precision_hi=precision[3],
        recall=sklearn.metrics.recall_score(
            y_labels, y_predictions, pos_label=pos_label
        ),
        recall_mean=recall[0],
        recall_mode=recall[1],
        recall_lo=recall[2],
        recall_hi=recall[3],
        f1=sklearn.metrics.f1_score(y_labels, y_predictions, pos_label=pos_label),
        f1_mean=f1[0],
        f1_mode=f1[1],
        f1_lo=f1[2],
        f1_hi=f1[3],
        average_precision=sklearn.metrics.average_precision_score(
            y_labels, y_scores, pos_label=pos_label
        ),
        average_precision_exact=average_precision[0],
        average_precision_lo=average_precision[1],
        average_precision_hi=average_precision[2],
        # specificity is the recall of the negative class
        specificity=sklearn.metrics.recall_score(
            y_labels, y_predictions, pos_label=neg_label
        ),
        specificity_mean=specificity[0],
        specificity_mode=specificity[1],
        specificity_lo=specificity[2],
        specificity_hi=specificity[3],
        roc_auc=sklearn.metrics.roc_auc_score(y_labels, y_scores),
        roc_auc_exact=roc_auc[0],
        roc_auc_lo=roc_auc[1],
        roc_auc_hi=roc_auc[2],
        accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions),
        accuracy_mean=accuracy[0],
        accuracy_mode=accuracy[1],
        accuracy_lo=accuracy[2],
        accuracy_hi=accuracy[3],
    )

    # curves: ROC and precision recall
    summary["curves"] = dict(
        roc=dict(
            zip(
                ("fpr", "tpr", "thresholds"),
                sklearn.metrics.roc_curve(
                    y_labels,
                    y_scores,
                    pos_label=pos_label,
                ),
            ),
        ),
        precision_recall=dict(
            zip(
                ("precision", "recall", "thresholds"),
                sklearn.metrics.precision_recall_curve(
                    y_labels,
                    y_scores,
                    pos_label=pos_label,
                ),
            ),
        ),
    )

    # score histograms
    # what works: <integer>, doane*, scott, stone, rice*, sturges*, sqrt
    # what does not work: auto, fd
    summary["score-histograms"] = dict(
        positives=dict(
            zip(
                ("hist", "bin_edges"),
                numpy.histogram(
                    y_scores[y_labels == pos_label],
                    bins=binning,
                    range=(0, 1),
                ),
            ),
        ),
        negatives=dict(
            zip(
                ("hist", "bin_edges"),
                numpy.histogram(
                    y_scores[y_labels == neg_label],
                    bins=binning,
                    range=(0, 1),
                ),
            ),
        ),
    )

    return summary
# [docs]  (Sphinx "viewcode" link residue from HTML extraction; not part of the code)
def make_table(
    data: typing.Mapping[str, typing.Mapping[str, typing.Any]],
    fmt: str,
) -> str:
    """Tabulate summaries from multiple splits.

    Each entry of ``data`` becomes one row of the table.  Non-scalar entries
    (``curves`` and ``score-histograms``) and the companion statistics of
    each measure (keys ending in ``_mean``, ``_mode``, ``_hi``, ``_lo`` or
    ``_exact``) are left out.  Column headers are taken from the first
    summary in ``data``.

    Parameters
    ----------
    data
        A mapping from subset (split) names to the summary collected for that
        subset.
    fmt
        One of the formats supported by `python-tabulate
        <https://pypi.org/project/tabulate/>`_.

    Returns
    -------
    str
        A string containing the tabulated information.
    """
    hidden_suffixes = ("_mean", "_mode", "_hi", "_lo", "_exact")

    def _is_tabulated(key: str) -> bool:
        # True for the keys that should show up as a table column
        if key in ("curves", "score-histograms"):
            return False
        return not key.endswith(hidden_suffixes)

    # keep only the entries that are meant to be displayed
    filtered = {
        split: {
            key: value for key, value in summary.items() if _is_tabulated(key)
        }
        for split, summary in data.items()
    }

    # the first summary dictates which columns exist, and their order
    columns = list(next(iter(filtered.values())).keys())

    # one row per subset, prefixed by the subset name
    rows = [
        [split] + [summary[column] for column in columns]
        for split, summary in filtered.items()
    ]

    return tabulate.tabulate(
        rows,
        ["subset"] + columns,
        tablefmt=fmt,
        floatfmt=".3f",
    )
def _score_plot(
    histograms: dict[str, dict[str, numpy.typing.NDArray]],
    title: str,
    threshold: float | None,
) -> matplotlib.figure.Figure:
    """Draw the score histograms of all given systems on a single figure.

    Parameters
    ----------
    histograms
        A dictionary containing all histograms that should be inserted into the
        plot.  Each histogram should itself be setup as another dictionary
        containing the keys ``hist`` and ``bin_edges`` as returned by
        :py:func:`numpy.histogram`.  The dictionary keys are used as legend
        labels.
    title
        Title of the plot.
    threshold
        Shows where the threshold is in the figure.  If set to ``None``, then
        does not show the threshold line.

    Returns
    -------
    matplotlib.figure.Figure
        A single (matplotlib) plot containing the score distribution, ready to
        be saved to disk or displayed.
    """
    from matplotlib.ticker import MaxNLocator

    fig, ax = plt.subplots(1, 1)
    assert isinstance(fig, matplotlib.figure.Figure)
    ax = typing.cast(matplotlib.axes.Axes, ax)  # gets editor to behave

    # General look of the plot
    ax.set_xlim((0, 1))
    ax.set_title(title)
    ax.set_xlabel("Score")
    ax.set_ylabel("Count")

    # Keep only the left and bottom spines, with ticks on those sides
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)
    x_axis = ax.get_xaxis()
    y_axis = ax.get_yaxis()
    x_axis.tick_bottom()
    y_axis.tick_left()
    y_axis.set_major_locator(MaxNLocator(integer=True))

    # Light dashed grid, horizontal lines only
    ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    x_axis.grid(False)

    # One semi-transparent bar series per histogram; track the tallest bar
    tallest = 0
    for label, entry in histograms.items():
        counts = entry["hist"]
        edges = entry["bin_edges"]
        bar_width = 0.7 * (edges[1] - edges[0])
        bar_centers = (edges[:-1] + edges[1:]) / 2
        ax.bar(
            bar_centers,
            counts,
            align="center",
            width=bar_width,
            label=label,
            alpha=0.7,
        )
        tallest = max(tallest, counts.max())

    # Detach axes from the plot (offset proportional to the tallest bar)
    ax.spines["left"].set_position(("data", -0.015))
    ax.spines["bottom"].set_position(("data", -0.015 * tallest))

    if threshold is not None:
        # Threshold marker: dotted red vertical line
        ax.axvline(
            threshold,  # type: ignore
            color="red",
            lw=2,
            alpha=0.75,
            ls="dotted",
            label="threshold",
        )

    # Legend with a slightly transparent, rounded box
    ax.legend(fancybox=True, framealpha=0.7)

    # Makes sure the figure occupies most of the possible space
    fig.tight_layout()

    return fig
# [docs]  (Sphinx "viewcode" link residue from HTML extraction; not part of the code)
def make_plots(results: dict[str, dict[str, typing.Any]]) -> list:
    """Create plots for all curves and score distributions in ``results``.

    Three kinds of figures are produced, in this order: one ROC plot with a
    curve per split, one precision-recall plot with a curve per split, and
    then one score-distribution plot for each split.

    Parameters
    ----------
    results
        A dictionary mapping split names to evaluation data as returned by
        :py:func:`run_binary`.

    Returns
    -------
        A list of figures to record to file.
    """
    retval = []

    with credible.plot.tight_layout(
        ("False Positive Rate", "True Positive Rate"), "ROC"
    ) as (fig, ax):
        for split_name, data in results.items():
            roc = data["curves"]["roc"]
            _auroc = credible.curves.area_under_the_curve(
                (roc["fpr"], roc["tpr"]),
            )
            ax.plot(
                roc["fpr"],
                roc["tpr"],
                label=f"{split_name} (AUC: {_auroc:.2f})",
            )

        ax.legend(loc="best", fancybox=True, framealpha=0.7)
        retval.append(fig)

    with credible.plot.tight_layout_f1iso(
        ("Recall", "Precision"), "Precision-Recall"
    ) as (fig, ax):
        for split_name, data in results.items():
            pr_curve = data["curves"]["precision_recall"]
            # Average over the *curve* arrays being plotted.  The top-level
            # ``data["precision"]`` / ``data["recall"]`` entries are scalar
            # point measures at the chosen threshold, not curves.
            # NOTE(review): coordinate order (precision, recall) is kept from
            # the original call -- confirm against credible's API.
            _ap = credible.curves.average_metric(
                (pr_curve["precision"], pr_curve["recall"]),
            )
            ax.plot(
                pr_curve["recall"],
                pr_curve["precision"],
                label=f"{split_name} (AP: {_ap:.2f})",
            )

        ax.legend(loc="best", fancybox=True, framealpha=0.7)
        retval.append(fig)

    # score plots: one figure per split, with its threshold marked
    for split_name, data in results.items():
        score_fig = _score_plot(
            data["score-histograms"],
            f"Score distribution (split: {split_name})",
            data["threshold"],
        )
        retval.append(score_fig)

    return retval