Source code for mednet.engine.device
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Support for switching execution devices (GPU vs CPU)."""
import logging
import os
import typing
import torch
import torch.backends
logger = logging.getLogger(__name__)
SupportedPytorchDevice: typing.TypeAlias = typing.Literal[
"cpu",
"cuda",
"mps",
]
"""List of supported pytorch devices by this library."""
def _split_int_list(s: str) -> list[int]:
"""Split a list of integers encoded in a string (e.g. "1,2,3") into a
Python list of integers (e.g. ``[1, 2, 3]``).
Parameters
----------
s
A list of integers encoded in a string.
Returns
-------
list[int]
A Python list of integers.
"""
return [int(k.strip()) for k in s.split(",")]
[docs]
class DeviceManager:
r"""Manage Lightning Accelerator and Pytorch Devices.
It takes the user input, in the form of a string defined by
``[\S+][:\d[,\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``), and can
translate to the right incarnation of Pytorch devices or Lightning
Accelerators to interface with the various frameworks.
Instances of this class also manage the environment variable
``$CUDA_VISIBLE_DEVICES`` if necessary.
Parameters
----------
name
The name of the device to use, in the form of a string defined by
``[\S+][:\d[,\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``). In
the specific case of ``cuda``, one can also specify a device to use
either by adding ``:N``, where N is the zero-indexed board number on
the computer, or by setting the environment variable
``$CUDA_VISIBLE_DEVICES`` with the devices that are usable by the
current process.
"""
def __init__(self, name: SupportedPytorchDevice):
parts = name.split(":", 1)
# make device type of the right Python type
if parts[0] not in typing.get_args(SupportedPytorchDevice):
raise ValueError(f"Unsupported device-type `{parts[0]}`")
self.device_type: SupportedPytorchDevice = typing.cast(
SupportedPytorchDevice,
parts[0],
)
self.device_ids: list[int] = []
if len(parts) > 1:
self.device_ids = _split_int_list(parts[1])
if self.device_type == "cuda":
visible_env = os.environ.get("CUDA_VISIBLE_DEVICES")
if visible_env:
visible = _split_int_list(visible_env)
if self.device_ids and visible != self.device_ids:
logger.warning(
f"${{CUDA_VISIBLE_DEVICES}}={visible} and name={name} "
f"- overriding environment with value set on `name`",
)
else:
self.device_ids = visible
# make sure that it is consistent with the environment
if self.device_ids:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
[str(k) for k in self.device_ids],
)
if self.device_type not in typing.get_args(SupportedPytorchDevice):
raise RuntimeError(
f"Unsupported device type `{self.device_type}`. "
f"Supported devices types are "
f"`{', '.join(typing.get_args(SupportedPytorchDevice))}`",
)
if self.device_ids and self.device_type in ("cpu", "mps"):
logger.warning(
f"Cannot pin device ids if using cpu or mps backend. "
f"Setting `name` to {name} is non-sensical. Ignoring...",
)
# check if the device_type that was set has support compiled in
if self.device_type == "cuda":
assert hasattr(torch, "cuda") and torch.cuda.is_available(), (
f"User asked for device = `{name}`, but CUDA support is "
f"not compiled into pytorch!"
)
if self.device_type == "mps":
assert (
hasattr(torch.backends, "mps") and torch.backends.mps.is_available() # type:ignore
), (
f"User asked for device = `{name}`, but MPS support is "
f"not compiled into pytorch!"
)
[docs]
def torch_device(self) -> torch.device:
"""Return a representation of the torch device to use by default.
.. warning::
If a list of devices is set, then this method only returns the first
device. This may impact Nvidia GPU logging in the case multiple
GPU cards are used.
Returns
-------
torch.device
The **first** torch device (if a list of ids is set).
"""
if self.device_type in ("cpu", "mps"):
return torch.device(self.device_type)
if self.device_type == "cuda":
if not self.device_ids:
return torch.device(self.device_type)
return torch.device(self.device_type, self.device_ids[0])
# if you get to this point, this is an unexpected RuntimeError
raise RuntimeError(
f"Unexpected device type {self.device_type} lacks support",
)
[docs]
def lightning_accelerator(self) -> tuple[str, int | list[int] | str]:
"""Return the lightning accelerator setup.
Returns
-------
accelerator
The lightning accelerator to use.
devices
The lightning devices to use.
"""
devices: int | list[int] | str = self.device_ids
if not devices:
devices = "auto"
elif self.device_type == "mps":
devices = 1
return self.device_type, devices