# Source code for mednet.utils.checkpointer

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import pathlib
import re
import typing
from collections.abc import Callable

logger = logging.getLogger(__name__)

CheckpointAliasType = str | Callable[[str, str], str]
"""Definition of a Checkpoint alias type to make it
flexible to user defined metric to monitor."""

CHECKPOINT_ALIASES: dict[str, CheckpointAliasType] = {
    "best": lambda metric, mode: (
        f"model-at-{'highest' if mode == 'max' else 'lowest'}-validation-{metric}"
        "-{epoch}"
    ),
    "periodic": "model-at-{epoch}",
}
"""Standard paths where checkpoints may be (if produced with this
framework)."""

CHECKPOINT_EXTENSION = ".ckpt"


def _get_checkpoint_from_alias(
    path: pathlib.Path,
    alias: typing.Literal["best", "periodic"],
    metric: str = "loss",
    mode: typing.Literal["min", "max"] = "min",
) -> pathlib.Path:
    """Get an existing checkpoint file path.

    The alias is first resolved into a filename stem through
    ``CHECKPOINT_ALIASES`` (calling it with ``metric`` and ``mode`` when the
    alias value is callable), and the checkpoint extension is appended.  If a
    file with exactly that name exists inside ``path``, it is returned.

    Otherwise, the ``{epoch}`` placeholder in the filename is treated as a
    wildcard matching ``epoch``, a separator (``=``, ``-`` or ``_``) and a
    decimal number, which is parsed as the epoch number.  The rest of the
    filename must match literally.  Among all matching files, the one with the
    highest epoch number is returned (the first one listed wins on ties).

    Parameters
    ----------
    path
        Folder that may contain checkpoint files.
    alias
        Can be one of "best" or "periodic".
    metric
        Name of the metric used for monitoring and saving the best checkpoint
        (default: "loss").  Only used when the alias value is callable.
    mode
        One of {"min", "max"}.  Only used when the alias value is callable.

    Returns
    -------
    pathlib.Path
        Path to the requested checkpoint file, which exists on disk.

    Raises
    ------
    FileNotFoundError
        In case it cannot find any file on the provided path matching the given
        specifications (also propagated by :py:meth:`pathlib.Path.iterdir` if
        ``path`` does not exist).
    """
    alias_value = CHECKPOINT_ALIASES[alias]
    ckp_alias = alias_value(metric, mode) if callable(alias_value) else alias_value
    template = path / (ckp_alias + CHECKPOINT_EXTENSION)

    if template.exists():
        return template

    # Otherwise we are looking for a template, in which case we must pick the
    # latest epoch.  Without a placeholder in the filename there is nothing to
    # search for (explicit check instead of ``assert``, which ``-O`` strips).
    placeholder = "{epoch}"
    if placeholder not in template.name:
        raise FileNotFoundError(
            f"A file matching `{str(template)}` specifications was not found",
        )

    # Escape the literal parts so that "." in the extension, or regex
    # metacharacters in metric names, only match themselves.
    prefix, _, suffix = template.name.partition(placeholder)
    pattern = re.compile(
        re.escape(prefix) + r"epoch[=_-](?P<epoch>\d+)" + re.escape(suffix),
    )

    latest: tuple[int, pathlib.Path] | None = None
    for candidate in template.parent.iterdir():
        # fullmatch: reject names with trailing junk (e.g. "*.ckpt.bak").
        match = pattern.fullmatch(candidate.name)
        if match is None:
            continue
        epoch = int(match.group("epoch"))
        if latest is None or epoch > latest[0]:
            latest = (epoch, candidate)

    if latest is not None:
        # Return the entry actually found on disk rather than rebuilding its
        # name, so zero-padded epochs (e.g. "epoch=005") resolve correctly.
        return latest[1]

    raise FileNotFoundError(
        f"A file matching `{str(template)}` specifications was not found",
    )


def get_checkpoint_to_resume_training(
    path: pathlib.Path,
) -> pathlib.Path:
    """Return the checkpoint file path to resume training from.

    Only the "periodic" checkpoint alias is searched; when several periodic
    checkpoints exist, the helper ``_get_checkpoint_from_alias`` selects the
    one with the highest epoch number.

    Parameters
    ----------
    path
        The base directory containing the "periodic" checkpoint(s) to resume
        the training session from.

    Returns
    -------
    pathlib.Path
        Path to a checkpoint file that exists on disk.

    Raises
    ------
    FileNotFoundError
        If no periodic checkpoint can be found in the provided directory.
    """
    return _get_checkpoint_from_alias(path, "periodic")
def get_checkpoint_to_run_inference(
    path: pathlib.Path,
    metric: str = "loss",
    mode: typing.Literal["min", "max"] = "min",
) -> pathlib.Path:
    """Return the best checkpoint file path to run inference with.

    The "best" checkpoint (saved at the highest or lowest value of the
    monitored validation ``metric``, depending on ``mode``) is searched first.
    If it cannot be found, an error is logged and the "periodic" checkpoint
    is searched instead.

    Parameters
    ----------
    path
        The base directory containing the "best" or "periodic" checkpoint to
        run inference with.
    metric
        Name of the metric used for monitoring and saving the best checkpoint
        (default: "loss").
    mode
        One of {"min", "max"} (default: "min").  ``"max"`` looks for the
        checkpoint at the highest metric value, ``"min"`` at the lowest.

    Returns
    -------
    pathlib.Path
        Path to a checkpoint file that exists on disk.

    Raises
    ------
    FileNotFoundError
        If neither the best nor a periodic checkpoint can be found in the
        provided directory.
    """
    try:
        return _get_checkpoint_from_alias(path, "best", metric, mode)
    except FileNotFoundError:
        # Lazy %-formatting: the message is only built if the record is emitted.
        logger.error(
            "Did not find %s-validation-%s model to run inference from. "
            "Trying to search for the last periodically saved model...",
            "highest" if mode == "max" else "lowest",
            metric,
        )
        # Kept inside the handler so a second failure chains to the first.
        return _get_checkpoint_from_alias(path, "periodic")