# Source code for mednet.scripts.utils

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Utilities for command-line scripts."""

import json
import logging
import pathlib
import re
import shutil
import typing

import lightning.pytorch
import lightning.pytorch.callbacks
import numpy
import torch.nn

from ..engine.device import SupportedPytorchDevice

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Recursive alias: a JSON value is either a mapping from strings to JSON
# values, a sequence of JSON values, or one of the scalar types listed below.
# The self-references are quoted because the name ``JSONable`` is not bound yet
# while the right-hand side is being evaluated.
JSONable: typing.TypeAlias = (
    typing.Mapping[str, "JSONable"]
    | typing.Sequence["JSONable"]
    | str
    | int
    | float
    | bool
    | None
)
"""Defines types that can be encoded in a JSON string."""


def model_summary(
    model: torch.nn.Module,
) -> dict[str, int | list[tuple[str, str, int]]]:
    """Summarize the layers and the size of a model.

    Parameters
    ----------
    model
        Instance of the model to summarize.

    Returns
    -------
        A dictionary with two entries: ``model_summary``, a list with one
        ``(layer name, layer type, number of parameters)`` tuple per layer
        reported by lightning's ``ModelSummary``, and ``model_size``, the
        total number of parameters of the model.
    """
    summary = lightning.pytorch.utilities.model_summary.ModelSummary(model)  # type: ignore

    per_layer = list(
        zip(summary.layer_names, summary.layer_types, summary.param_nums),
    )
    return {
        "model_summary": per_layer,
        "model_size": summary.total_parameters,
    }
def device_properties(
    device_type: SupportedPytorchDevice,
) -> dict[str, int | float | str]:
    """Collect the hardware properties relevant to a compute device.

    CPU constants are always included. Constants specific to ``cuda`` or
    ``mps`` are merged in when that device type is requested and the
    corresponding lookup returns something other than ``None``.

    Parameters
    ----------
    device_type
        The type of compute device we are using.

    Returns
    -------
        Static properties of the current machine.
    """
    from ..utils.resources import cpu_constants, cuda_constants, mps_constants

    properties: dict[str, int | float | str] = {}
    properties.update(cpu_constants())

    # any device type other than "cuda" or "mps" (including "cpu") contributes
    # nothing beyond the CPU constants
    if device_type == "cuda":
        extra = cuda_constants()
    elif device_type == "mps":
        extra = mps_constants()
    else:
        extra = None

    if extra is not None:
        properties.update(extra)

    return properties
def execution_metadata() -> dict[str, int | float | str | dict[str, str] | list[str]]:
    """Produce metadata concerning the running script, in the form of a dictionary.

    This function returns potentially useful metadata concerning program
    execution. It contains a certain number of preset variables.

    If no installed distribution can be found for the current package (e.g.
    this module is not being run as part of an installed package), the package
    version is reported as ``unknown``, the dependency list is empty and the
    version consistency check is skipped, instead of raising an exception.

    Returns
    -------
        A dictionary that contains the following fields:

        * ``datetime``: date and time in ISO8601 format (e.g.
          ``2024-02-23T18:38:09+01:00``)
        * ``package-name``: current package name (e.g. ``mednet``), or
          ``unknown`` if it cannot be determined
        * ``package-version``: current package version (e.g. ``1.0.0b0``), or
          ``unknown`` if the package is not installed
        * ``python``: a dictionary with the ``version`` and the ``path`` of the
          running Python interpreter
        * ``dependencies``: a dictionary mapping the name of each installed
          direct dependency of the package to its version
        * ``user``: username (e.g. ``johndoe``)
        * ``conda-env``: the name of the current conda environment, or an
          empty string if ``CONDA_DEFAULT_ENV`` is not set
        * ``path``: current path when executing the command
        * ``command-line``: the command-line that is being run
        * ``hostname``: machine hostname (e.g. ``localhost``)
        * ``platform``: machine platform (e.g. ``darwin``)
        * ``accelerators``: list of acceleration devices and libraries
          available (e.g. ``cuda``)
    """

    import datetime
    import getpass
    import importlib.metadata
    import importlib.util
    import os
    import platform
    import sys

    # single-quotes arguments containing spaces so the recorded command-line
    # remains unambiguous
    args = [f"'{k}'" if " " in k else k for k in sys.argv]

    # current date time, in ISO8601 format
    current_datetime = datetime.datetime.now().astimezone().isoformat()

    # top-level package this module belongs to; ``__package__`` is ``None`` or
    # empty when the module is not imported as part of a package
    package_name = (__package__ or "").split(".")[0] or "unknown"

    # collects dependency information; a missing distribution must not prevent
    # the remaining metadata from being produced
    try:
        requires = importlib.metadata.requires(package_name) or []
        current_version = importlib.metadata.version(package_name)
        package_found = True
    except importlib.metadata.PackageNotFoundError:
        logger.debug(f"No installed distribution found for `{package_name}`")
        requires = []
        current_version = "unknown"
        package_found = False

    # keeps only the distribution name of each requirement string, dropping
    # extras (``[...]``), version specifiers, direct references (``@``) and
    # environment markers (``;``)
    dependence_names = {
        re.split(r"[\s\[(=~!<>;@]", k, maxsplit=1)[0] for k in requires
    }

    # maps each distribution name to an importable top-level module it provides
    installed = {
        v[0]: k for k, v in importlib.metadata.packages_distributions().items()
    }

    dependencies: dict[str, str] = {}
    for name in sorted(dependence_names):
        try:
            # only reports dependencies that are actually importable
            if importlib.util.find_spec(installed.get(name, name)) is not None:
                dependencies[name] = importlib.metadata.version(name)  # as str
        except (ImportError, ValueError):
            # find_spec() raises ModuleNotFoundError/ValueError for names it
            # cannot resolve, and version() raises PackageNotFoundError (an
            # ImportError) when metadata is missing: treat as "not installed"
            continue

    # checks if the current version corresponds to a dirty (uncommitted) change
    # set, issues a warning to the user
    if package_found:
        try:
            import versioningit

            actual_version = versioningit.get_version(".")
            if current_version != actual_version:
                logger.warning(
                    f"Version mismatch between current version set "
                    f"({current_version}) and actual version returned by "
                    f"versioningit ({actual_version}). This typically happens "
                    f"when you commit changes locally and do not re-install the "
                    f"package. Run `pixi install`, `pip install -e .` or equivalent "
                    f"to fix this.",
                )
        except Exception as e:
            # deliberately broad: this check is best-effort (versioningit may
            # be missing, or we may not be in a git repo) and must never stop
            # the metadata from being produced
            logger.debug(f"Error {e}")

    # checks if any acceleration device is present in the current platform
    accelerators = [f"cpu ({torch.backends.cpu.get_cpu_capability()})"]

    if torch.cuda.is_available() and torch.backends.cuda.is_built():
        accelerators.append("cuda")

    if torch.backends.cudnn.is_available():
        accelerators.append("cudnn")

    if torch.backends.mps.is_available():
        accelerators.append("mps")

    if torch.backends.mkl.is_available():
        accelerators.append("mkl")

    if torch.backends.mkldnn.is_available():
        accelerators.append("mkldnn")

    if torch.backends.openmp.is_available():
        accelerators.append("openmp")

    python = {
        "version": ".".join([str(k) for k in sys.version_info[:3]]),
        "path": sys.executable,
    }

    return {
        "datetime": current_datetime,
        "package-name": package_name,
        "package-version": current_version,
        "python": python,
        "dependencies": dependencies,
        "user": getpass.getuser(),
        "conda-env": os.environ.get("CONDA_DEFAULT_ENV", ""),
        "path": os.path.realpath(os.curdir),
        "command-line": " ".join(args),
        "hostname": platform.node(),
        "platform": sys.platform,
        "accelerators": accelerators,
    }
class NumpyJSONEncoder(json.JSONEncoder):
    """Extends the standard JSON encoder to support Numpy arrays and scalars."""

    def default(self, o: typing.Any) -> typing.Any:
        """Convert Numpy arrays and scalars into JSON-serializable objects.

        Arrays are converted into (nested) lists. Numpy integer, floating-point
        and boolean scalars are converted into the matching Python built-in
        type. Any other object is handed over to the base class.

        Parameters
        ----------
        o
            Input object to be JSON serialized.

        Returns
        -------
            A serializable representation of object ``o``.

        Raises
        ------
        TypeError
            Raised by the base class if ``o`` is not of a supported type.
        """

        if isinstance(o, numpy.ndarray):
            try:
                return o.tolist()
            except TypeError:
                # falls through to the base class, which raises the TypeError
                pass

        # Numpy scalars (e.g. ``numpy.float32`` or ``numpy.int64``) are not
        # instances of the Python built-in types the base encoder understands
        if isinstance(o, numpy.bool_):
            return bool(o)

        if isinstance(o, numpy.integer):
            return int(o)

        if isinstance(o, numpy.floating):
            return float(o)

        # Let the base class default method raise the TypeError
        return super().default(o)
def save_json_with_backup(path: pathlib.Path, data: JSONable) -> None:
    """Write JSON data to a file, backing up any previous version of it.

    The directory that should contain the file is created if it does not
    exist yet. If a file already exists at ``path``, it is first copied
    alongside itself under the same name with a trailing ``~``, and only then
    overwritten with the new contents.

    Parameters
    ----------
    path
        The full path where to save the JSON data.
    data
        The data to save on the JSON file.
    """

    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    if path.exists():
        # keeps a copy of the previous contents before overwriting them
        shutil.copy(path, target_dir / f"{path.name}~")

    with path.open("w") as stream:
        json.dump(data, stream, indent=2, cls=NumpyJSONEncoder)
[docs] def save_json_metadata( output_file: pathlib.Path, **kwargs: typing.Any, ) -> None: # numpydoc ignore=PR01 """Save prediction hyperparameters into a .json file.""" from ..data.datamodule import ConcatDataModule from ..engine.device import DeviceManager from ..models.model import Model from .utils import ( device_properties, execution_metadata, model_summary, save_json_with_backup, ) json_data: dict[str, typing.Any] = execution_metadata() for key, value in kwargs.items(): match value: case ConcatDataModule(): json_data["database_name"] = value.database_name json_data["database_split"] = value.split_name case Model(): json_data["model"] = f"{type(value).__module__}.{type(value).__name__}" json_data.update(model_summary(value)) case pathlib.Path(): json_data[key] = str(value) case DeviceManager(): json_data.update(device_properties(value.device_type)) case list() if key == "augmentations": if len(value) != 0: json_data[key] = [f"{type(k).__module__}.{str(k)}" for k in value] else: json_data[key] = [] case _: json_data[key] = value json_data = {k.replace("_", "-"): v for k, v in json_data.items()} logger.info(f"Writing run metadata at `{output_file}`...") save_json_with_backup(output_file, json_data)