# Source code for mednet.scripts.utils
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Utilities for command-line scripts."""
import json
import logging
import pathlib
import re
import shutil
import typing
import lightning.pytorch
import lightning.pytorch.callbacks
import numpy
import torch.nn
from ..engine.device import SupportedPytorchDevice
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Recursive type alias: a JSON-encodable value is either a mapping from strings
# to JSON-encodable values, a sequence of JSON-encodable values, or one of the
# scalar types listed below.  The self-references are written as strings
# because the name ``JSONable`` is not bound yet while the right-hand side is
# being evaluated.
JSONable: typing.TypeAlias = (
typing.Mapping[str, "JSONable"]
| typing.Sequence["JSONable"]
| str
| int
| float
| bool
| None
)
"""Defines types that can be encoded in a JSON string."""
[docs]
def model_summary(
    model: torch.nn.Module,
) -> dict[str, int | list[tuple[str, str, int]]]:
    """Summarize the layers and the number of parameters of a model.

    Nothing is written to disk by this function: the summary is handed back
    to the caller as a dictionary.

    Parameters
    ----------
    model
        Instance of the model to summarize.

    Returns
    -------
    dict[str, int | list[tuple[str, str, int]]]
        A dictionary with two entries:

        * ``model_summary``: a list with one ``(layer name, layer type,
          number of parameters)`` tuple per layer reported by lightning's
          ``ModelSummary``
        * ``model_size``: the total number of parameters reported by that
          same summary
    """
    # Imported explicitly: the module-level ``import lightning.pytorch`` only
    # makes this sub-package reachable if lightning imports it on its own.
    import lightning.pytorch.utilities.model_summary

    summary = lightning.pytorch.utilities.model_summary.ModelSummary(  # type: ignore
        model,
    )
    return {
        "model_summary": list(
            zip(summary.layer_names, summary.layer_types, summary.param_nums)
        ),
        "model_size": summary.total_parameters,
    }
[docs]
def device_properties(
    device_type: SupportedPytorchDevice,
) -> dict[str, int | float | str]:
    """Collect static hardware properties for the selected compute device.

    CPU constants are always reported.  When ``device_type`` is ``"cuda"`` or
    ``"mps"``, the constants of that accelerator are merged in as well,
    provided the corresponding helper returns something other than ``None``.
    Any other device type yields the CPU constants only.

    Parameters
    ----------
    device_type
        The type of compute device we are using.

    Returns
    -------
        Static properties of the current machine.
    """
    from ..utils.resources import cpu_constants, cuda_constants, mps_constants

    properties: dict[str, int | float | str] = {}
    properties.update(cpu_constants())

    # Only query the accelerator that was actually requested.
    if device_type == "cuda":
        accelerator_constants = cuda_constants()
    elif device_type == "mps":
        accelerator_constants = mps_constants()
    else:
        accelerator_constants = None

    if accelerator_constants is not None:
        properties.update(accelerator_constants)

    return properties
[docs]
def execution_metadata() -> dict[str, int | float | str | dict[str, str] | list[str]]:
"""Produce metadata concerning the running script, in the form of a
dictionary.
This function returns potentially useful metadata concerning program
execution. It contains a certain number of preset variables.
Returns
-------
A dictionary that contains the following fields:
* ``package-name``: current package name (e.g. ``mednet``)
* ``package-version``: current package version (e.g. ``1.0.0b0``)
* ``datetime``: date and time in ISO8601 format (e.g. ``2024-02-23T18:38:09+01:00``)
* ``user``: username (e.g. ``johndoe``)
* ``conda-env``: if set, the name of the current conda environment
* ``path``: current path when executing the command
* ``command-line``: the command-line that is being run
* ``hostname``: machine hostname (e.g. ``localhost``)
* ``platform``: machine platform (e.g. ``darwin``)
* ``accelerator``: acceleration devices available (e.g. ``cuda``)
"""
import datetime
import importlib.metadata
import importlib.util
import os
import sys
args: list[str] = []
for k in sys.argv:
if " " in k:
args.append(f"'{k}'")
else:
args.append(k)
# current date time, in ISO8610 format
current_datetime = datetime.datetime.now().astimezone().isoformat()
# collects dependency information
package_name = __package__.split(".")[0] if __package__ is not None else "unknown"
requires = importlib.metadata.requires(package_name) or []
dependence_names = [re.split(r"(\=|~|!|>|<|;|\s)+", k)[0] for k in requires]
installed = {
v[0]: k for k, v in importlib.metadata.packages_distributions().items()
}
dependencies = {
k: importlib.metadata.version(k) # version number as str
for k in sorted(dependence_names)
if importlib.util.find_spec(k if k not in installed else installed[k])
is not None # if is installed
}
# checks if the current version corresponds to a dirty (uncommitted) change
# set, issues a warning to the user
current_version = importlib.metadata.version(package_name)
try:
import versioningit
actual_version = versioningit.get_version(".")
if current_version != actual_version:
logger.warning(
f"Version mismatch between current version set "
f"({current_version}) and actual version returned by "
f"versioningit ({actual_version}). This typically happens "
f"when you commit changes locally and do not re-install the "
f"package. Run `pixi update {package_name}`, `pip install -e .` "
f"or equivalent to fix this.",
)
except Exception as e:
# not in a git repo?
logger.debug(f"Error {e}")
pass
# checks if any acceleration device is present in the current platform
accelerators = [f"cpu ({torch.backends.cpu.get_cpu_capability()})"]
if torch.cuda.is_available() and torch.backends.cuda.is_built():
accelerators.append("cuda")
if torch.backends.cudnn.is_available():
accelerators.append("cudnn")
if torch.backends.mps.is_available():
accelerators.append("mps")
if torch.backends.mkl.is_available():
accelerators.append("mkl")
if torch.backends.mkldnn.is_available():
accelerators.append("mkldnn")
if torch.backends.openmp.is_available():
accelerators.append("openmp")
python = {
"version": ".".join([str(k) for k in sys.version_info[:3]]),
"path": sys.executable,
}
return {
"datetime": current_datetime,
"package-name": package_name,
"package-version": current_version,
"python": python,
"dependencies": dependencies,
"user": __import__("getpass").getuser(),
"conda-env": os.environ.get("CONDA_DEFAULT_ENV", ""),
"path": os.path.realpath(os.curdir),
"command-line": " ".join(args),
"hostname": __import__("platform").node(),
"platform": sys.platform,
"accelerators": accelerators,
}
[docs]
class NumpyJSONEncoder(json.JSONEncoder):
    """Extends the standard JSON encoder to support Numpy arrays and scalars."""

    def default(self, o: typing.Any) -> typing.Any:
        """Convert Numpy objects into JSON-serializable Python objects.

        Numpy arrays are converted into (nested) lists and Numpy scalars
        (e.g. ``numpy.int64``, ``numpy.float32``, ``numpy.bool_``) into the
        matching Python scalar.

        Parameters
        ----------
        o
            Input object to be JSON serialized.

        Returns
        -------
            A serializable representation of object ``o``.

        Raises
        ------
        TypeError
            If ``o`` is not a type this encoder knows how to convert (raised
            by the base class).
        """
        if isinstance(o, numpy.ndarray):
            return o.tolist()

        if isinstance(o, numpy.generic):
            value = o.item()
            # Some scalar types have no Python equivalent and convert to
            # themselves: do not return those, or the encoder would call us
            # again with the same kind of object, forever.
            if not isinstance(value, numpy.generic):
                return value

        # Let the base class default method raise the TypeError
        return super().default(o)
[docs]
def save_json_with_backup(path: pathlib.Path, data: JSONable) -> None:
    """Write JSON data to a file, backing up any file already there.

    The directory that should contain the file is created first, if needed.
    When a file already exists at ``path``, it is copied next to it under the
    same name followed by a ``~`` before being overwritten with the new
    contents.

    Parameters
    ----------
    path
        The full path where to save the JSON data.
    data
        The data to save on the JSON file.
    """
    destination_dir = path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)

    # Keep the previous contents around before they get overwritten.
    if path.exists():
        shutil.copy(path, destination_dir / f"{path.name}~")

    with path.open("w") as stream:
        json.dump(data, stream, indent=2, cls=NumpyJSONEncoder)
[docs]
def save_json_metadata(
output_file: pathlib.Path,
**kwargs: typing.Any,
) -> None: # numpydoc ignore=PR01
"""Save prediction hyperparameters into a .json file."""
from ..data.datamodule import ConcatDataModule
from ..engine.device import DeviceManager
from ..models.model import Model
from .utils import (
device_properties,
execution_metadata,
model_summary,
save_json_with_backup,
)
json_data: dict[str, typing.Any] = execution_metadata()
for key, value in kwargs.items():
match value:
case ConcatDataModule():
json_data["database_name"] = value.database_name
json_data["database_split"] = value.split_name
case Model():
json_data["model"] = f"{type(value).__module__}.{type(value).__name__}"
json_data.update(model_summary(value))
case pathlib.Path():
json_data[key] = str(value)
case DeviceManager():
json_data.update(device_properties(value.device_type))
case list() if key == "augmentations":
if len(value) != 0:
json_data[key] = [f"{type(k).__module__}.{str(k)}" for k in value]
else:
json_data[key] = []
case _:
json_data[key] = value
json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
logger.info(f"Writing run metadata at `{output_file}`...")
save_json_with_backup(output_file, json_data)