# Source code for mednet.scripts.utils

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Utilities for command-line scripts."""

import json
import logging
import pathlib
import re
import shutil
import typing

import compact_json
import lightning.pytorch
import lightning.pytorch.callbacks
import numpy
import torch.nn
from click import BadParameter
from pydantic import TypeAdapter
from pydantic.types import StringConstraints

from ..engine.device import SupportedPytorchDevice

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Recursive alias: the quoted "JSONable" entries are forward references to this
# same alias, so mappings (with string keys) and sequences may nest other
# JSON-encodable values inside them.
JSONable: typing.TypeAlias = (
    typing.Mapping[str, "JSONable"]
    | typing.Sequence["JSONable"]
    | str
    | int
    | float
    | bool
    | None
)
"""Defines types that can be encoded in a JSON string."""

# A plain ``str`` annotated with pydantic string constraints.  The constraints
# only take effect when the value goes through pydantic validation (as done
# by ``parse_checkpoint_metric`` in this module, via a ``TypeAdapter``).
CheckpointMetricType: typing.TypeAlias = typing.Annotated[
    str,
    StringConstraints(
        # ask pydantic to trim leading/trailing whitespace from the value
        strip_whitespace=True,
        # 'min' or 'max', a slash, then at least one more character
        pattern=r"^(min|max)/.+$",
    ),
]
"""
Defines a type for the metric used to track and save the best
checkpoint of a model. This type represents a constrained string
in the format 'mode/metric', where:
- 'mode' is either 'min' or 'max', indicating the optimization direction;
- 'metric' is a non-empty string specifying the name of the evaluation
metric (e.g., 'loss', 'auc').
"""


def parse_checkpoint_metric(value: str) -> tuple[str, typing.Literal["min", "max"]]:
    """Check a ``mode/metric`` string and split it into its two parts.

    Parameters
    ----------
    value
        The string to validate against ``CheckpointMetricType`` and then
        split.

    Returns
    -------
        The name of the metric used for saving the best checkpoint, followed
        by the modality (``"min"`` or ``"max"``), in this exact order.

    Raises
    ------
    click.BadParameter
        If ``value`` does not validate as a ``CheckpointMetricType``.
    """
    checker = TypeAdapter(CheckpointMetricType)

    try:
        checked = checker.validate_python(value)
    except Exception as e:
        # re-raise as a click error, keeping the original cause attached
        raise BadParameter(
            f"Invalid format: '{value.strip()}'. Must match 'min/<metric>' or 'max/<metric>'."
        ) from e

    # Only the first slash separates the mode from the metric: the metric name
    # itself may contain further slashes.
    mode, _, metric = checked.partition("/")
    return metric, mode


def model_summary(
    model: torch.nn.Module,
) -> dict[str, int | list[tuple[str, str, int]]]:
    """Summarise the layers and the parameter count of a model.

    Parameters
    ----------
    model
        Instance of the model to summarise.

    Returns
    -------
        A dictionary with two entries: ``model_summary``, a list of tuples
        combining, per layer, the name, type and number of parameters reported
        by the lightning model summary; and ``model_size``, the total number
        of parameters reported by that same summary.
    """
    summary = lightning.pytorch.utilities.model_summary.ModelSummary(  # type: ignore
        model,
    )
    per_layer = list(
        zip(summary.layer_names, summary.layer_types, summary.param_nums)
    )
    return {
        "model_summary": per_layer,
        "model_size": summary.total_parameters,
    }


def device_properties(
    device_type: SupportedPytorchDevice,
) -> dict[str, int | float | str]:
    """Gather static hardware properties for the selected compute device.

    Parameters
    ----------
    device_type
        The type of compute device we are using.

    Returns
    -------
        The CPU constants of the current machine, extended with the CUDA or
        MPS constants when ``device_type`` is ``"cuda"`` or ``"mps"`` and the
        corresponding probe returns something other than ``None``.
    """
    from ..utils.resources import cpu_constants, cuda_constants, mps_constants

    properties: dict[str, int | float | str] = {}
    # CPU information is always reported, whatever the device in use
    properties.update(cpu_constants())

    if device_type == "cuda":
        extra = cuda_constants()
    elif device_type == "mps":
        extra = mps_constants()
    else:
        # "cpu" and any unrecognised device add nothing further
        extra = None

    if extra is not None:
        properties.update(extra)

    return properties


def _installed_dependencies(package_name: str) -> dict[str, str]:
    """Collect the installed versions of a package's declared dependencies.

    Parameters
    ----------
    package_name
        Name of the distribution whose declared requirements are inspected.

    Returns
    -------
        A dictionary mapping the name of each declared dependency that can be
        located in the current environment to its installed version string.
        Dependencies that cannot be located, or whose metadata cannot be read,
        are left out.
    """
    import importlib.metadata
    import importlib.util

    def _canonical(name: str) -> str:
        # Distribution names compare case-insensitively, and runs of "-", "_"
        # and "." are interchangeable (e.g. "PyYAML" == "pyyaml").
        return re.sub(r"[-_.]+", "-", name).lower()

    requires = importlib.metadata.requires(package_name) or []

    # Keeps only the distribution name of each requirement string, dropping
    # extras ("[...]"), version specifiers and environment markers.
    names = {
        re.split(r"[\s=~!<>;\[(]", k.strip(), maxsplit=1)[0] for k in requires
    }

    # Maps (normalised) distribution names to an importable top-level module,
    # since both names may differ (e.g. "scikit-learn" provides "sklearn").
    modules = {
        _canonical(dists[0]): module
        for module, dists in importlib.metadata.packages_distributions().items()
    }

    versions: dict[str, str] = {}
    for name in sorted(names):
        module = modules.get(_canonical(name), name)
        try:
            if importlib.util.find_spec(module) is not None:  # is installed
                versions[name] = importlib.metadata.version(name)
        except (ImportError, ValueError) as e:
            # best effort: a dependency that cannot be resolved (missing parent
            # package, empty name, no distribution metadata) is just skipped
            logger.debug(f"Skipping dependency `{name}`: {e}")

    return versions


def execution_metadata() -> dict[str, int | float | str | dict[str, str] | list[str]]:
    """Produce metadata concerning the running script, in the form of a
    dictionary.

    This function returns potentially useful metadata concerning program
    execution.  It contains a certain number of preset variables.

    Returns
    -------
        A dictionary that contains the following fields:

        * ``datetime``: date and time in ISO8601 format (e.g.
          ``2024-02-23T18:38:09+01:00``)
        * ``package-name``: current package name (e.g. ``mednet``)
        * ``package-version``: current package version (e.g. ``1.0.0b0``)
        * ``python``: a dictionary with the ``version`` and the executable
          ``path`` of the running Python interpreter
        * ``dependencies``: a dictionary mapping the declared dependencies of
          the package that are installed to their versions
        * ``user``: username (e.g. ``johndoe``)
        * ``conda-env``: the name of the current conda environment, or an
          empty string if not set
        * ``path``: current path when executing the command
        * ``command-line``: the command-line that is being run
        * ``hostname``: machine hostname (e.g. ``localhost``)
        * ``platform``: machine platform (e.g. ``darwin``)
        * ``accelerators``: list of acceleration backends available (e.g.
          ``cuda``)
    """
    import datetime
    import getpass
    import importlib.metadata
    import os
    import platform
    import sys

    # quotes arguments containing spaces, so the command-line stays readable
    args: list[str] = [f"'{k}'" if " " in k else k for k in sys.argv]

    # current date time, in ISO8601 format
    current_datetime = datetime.datetime.now().astimezone().isoformat()

    # collects dependency information
    package_name = __package__.split(".")[0] if __package__ is not None else "unknown"
    dependencies = _installed_dependencies(package_name)

    # checks if the current version corresponds to a dirty (uncommitted) change
    # set, issues a warning to the user
    current_version = importlib.metadata.version(package_name)
    try:
        import versioningit

        actual_version = versioningit.get_version(".")
        if current_version != actual_version:
            logger.warning(
                f"Version mismatch between current version set "
                f"({current_version}) and actual version returned by "
                f"versioningit ({actual_version}). This typically happens "
                f"when you commit changes locally and do not re-install the "
                f"package. Run `pixi update {package_name}`, `pip install -e .` "
                f"or equivalent to fix this.",
            )
    except Exception as e:
        # deliberate best effort: not in a git repo? versioningit missing?
        logger.debug(f"Error {e}")

    # checks if any acceleration device is present in the current platform
    accelerators = [f"cpu ({torch.backends.cpu.get_cpu_capability()})"]

    if torch.cuda.is_available() and torch.backends.cuda.is_built():
        accelerators.append("cuda")

    if torch.backends.cudnn.is_available():
        accelerators.append("cudnn")

    if torch.backends.mps.is_available():
        accelerators.append("mps")

    if torch.backends.mkl.is_available():
        accelerators.append("mkl")

    if torch.backends.mkldnn.is_available():
        accelerators.append("mkldnn")

    if torch.backends.openmp.is_available():
        accelerators.append("openmp")

    python = {
        "version": ".".join([str(k) for k in sys.version_info[:3]]),
        "path": sys.executable,
    }

    return {
        "datetime": current_datetime,
        "package-name": package_name,
        "package-version": current_version,
        "python": python,
        "dependencies": dependencies,
        "user": getpass.getuser(),
        "conda-env": os.environ.get("CONDA_DEFAULT_ENV", ""),
        "path": os.path.realpath(os.curdir),
        "command-line": " ".join(args),
        "hostname": platform.node(),
        "platform": sys.platform,
        "accelerators": accelerators,
    }


class NumpyJSONEncoder(json.JSONEncoder):
    """JSON encoder that also understands Numpy arrays and Numpy scalars."""

    def default(self, o: typing.Any) -> typing.Any:
        """Convert Numpy objects into their plain Python equivalents.

        Arrays are converted with ``tolist()`` and Numpy scalars with
        ``item()``.  Anything else is handed over to the base class.

        Parameters
        ----------
        o
            Input object to be JSON serialized.

        Returns
        -------
            A serializable representation of object ``o``.
        """
        if isinstance(o, numpy.ndarray):
            convert = o.tolist
        elif isinstance(o, numpy.generic):
            convert = o.item
        else:
            convert = None

        if convert is not None:
            try:
                return convert()
            except TypeError:
                # fall through, so the base class reports the failure
                pass

        # Let the base class default method raise the TypeError
        return super().default(o)


def save_json_with_backup(path: pathlib.Path, data: JSONable) -> None:
    """Write JSON data to a file, creating directories and keeping a backup.

    The directory leading to ``path`` is created if it does not exist yet.  If
    a file already exists at ``path``, it is first copied next to itself, with
    a ``~`` appended to its name, before the new contents are written.

    Parameters
    ----------
    path
        The full path where to save the JSON data.
    data
        The data to save on the JSON file.
    """
    formatter = compact_json.Formatter()
    formatter.indent_spaces = 2  # only 2 indent spaces for further levels
    formatter.max_inline_complexity = 2  # controls how much nesting can happen
    # controls the maximum line width (has priority over nesting)
    formatter.max_inline_length = 88
    formatter.omit_trailing_whitespace = True  # remove any trailing whitespaces

    destination_dir = path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)

    if path.exists():
        shutil.copy(path, destination_dir / (path.name + "~"))

    # Round-trips through the standard JSON machinery, so that Numpy objects
    # are converted by ``NumpyJSONEncoder`` before the formatter sees the data.
    serialised = json.dumps(data, indent=2, cls=NumpyJSONEncoder)
    formatter.dump(json.loads(serialised), str(path))


[docs] def save_json_metadata( output_file: pathlib.Path, **kwargs: typing.Any, ) -> None: # numpydoc ignore=PR01 """Save prediction hyperparameters into a .json file.""" from ..data.datamodule import ConcatDataModule from ..engine.device import DeviceManager from ..models.model import Model from .utils import ( device_properties, execution_metadata, model_summary, save_json_with_backup, ) json_data: dict[str, typing.Any] = execution_metadata() for key, value in kwargs.items(): match value: case ConcatDataModule(): json_data["database_name"] = value.database_name json_data["database_split"] = value.split_name case Model(): json_data["model"] = f"{type(value).__module__}.{type(value).__name__}" json_data.update(model_summary(value)) case pathlib.Path(): json_data[key] = str(value) case DeviceManager(): json_data.update(device_properties(value.device_type)) case list() if key == "augmentations": if len(value) != 0: json_data[key] = [f"{type(k).__module__}.{str(k)}" for k in value] else: json_data[key] = [] case _: json_data[key] = value json_data = {k.replace("_", "-"): v for k, v in json_data.items()} logger.info(f"Writing run metadata at `{output_file}`...") save_json_with_backup(output_file, json_data)
def get_ckpt_metric_mode(
    train_metadata_file: pathlib.Path,
    default_metric: str = "loss",
    default_mode: typing.Literal["min", "max"] = "min",
) -> tuple[str, typing.Literal["min", "max"]]:
    """Read, from the training metadata, which metric and modality were used
    to save the best checkpoint of the model.

    Parameters
    ----------
    train_metadata_file
        Path of the train.meta.json file.
    default_metric
        The metric name to return when the JSON file has no
        ``checkpoint-metric`` entry.  The default value is set to "loss".
    default_mode
        The modality of evaluation to return when the JSON file has no
        ``checkpoint-mode`` entry.  The default value is set to "min".

    Returns
    -------
        The name of the metric used for saving the best checkpoint and the
        modality {'min', 'max'} in this exact order.
    """
    with train_metadata_file.open("r") as stream:
        contents = json.load(stream)

    metric = default_metric
    if "checkpoint-metric" in contents:
        metric = contents["checkpoint-metric"]

    mode = default_mode
    if "checkpoint-mode" in contents:
        mode = contents["checkpoint-mode"]

    return metric, mode