[docs]classDeviceManager:r"""Manage Lightning Accelerator and Pytorch Devices. It takes the user input, in the form of a string defined by ``[\S+][:\d[,\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``), and can translate to the right incarnation of Pytorch devices or Lightning Accelerators to interface with the various frameworks. Instances of this class also manage the environment variable ``$CUDA_VISIBLE_DEVICES`` if necessary. Parameters ---------- name The name of the device to use, in the form of a string defined by ``[\S+][:\d[,\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``). In the specific case of ``cuda``, one can also specify a device to use either by adding ``:N``, where N is the zero-indexed board number on the computer, or by setting the environment variable ``$CUDA_VISIBLE_DEVICES`` with the devices that are usable by the current process. """def__init__(self,name:SupportedPytorchDevice):parts=name.split(":",1)# make device type of the right Python typeifparts[0]notintyping.get_args(SupportedPytorchDevice):raiseValueError(f"Unsupported device-type `{parts[0]}`")self.device_type:SupportedPytorchDevice=typing.cast(SupportedPytorchDevice,parts[0],)self.device_ids:list[int]=[]iflen(parts)>1:self.device_ids=_split_int_list(parts[1])ifself.device_type=="cuda":visible_env=os.environ.get("CUDA_VISIBLE_DEVICES")ifvisible_env:visible=_split_int_list(visible_env)ifself.device_idsandvisible!=self.device_ids:logger.warning(f"${{CUDA_VISIBLE_DEVICES}}={visible} and name={name} "f"- overriding environment with value set on `name`",)else:self.device_ids=visible# make sure that it is consistent with the environmentifself.device_ids:os.environ["CUDA_VISIBLE_DEVICES"]=",".join([str(k)forkinself.device_ids],)ifself.device_typenotintyping.get_args(SupportedPytorchDevice):raiseRuntimeError(f"Unsupported device type `{self.device_type}`. "f"Supported devices types are "f"`{', '.join(typing.get_args(SupportedPytorchDevice))}`",)ifself.device_idsandself.device_typein("cpu","mps"):logger.warning(f"Cannot pin device ids if using cpu or mps backend. "f"Setting `name` to {name} is non-sensical. Ignoring...",)# check if the device_type that was set has support compiled inifself.device_type=="cuda":asserthasattr(torch,"cuda")andtorch.cuda.is_available(),(f"User asked for device = `{name}`, but CUDA support is "f"not compiled into pytorch!")ifself.device_type=="mps":assert(hasattr(torch.backends,"mps")andtorch.backends.mps.is_available()# type:ignore),(f"User asked for device = `{name}`, but MPS support is "f"not compiled into pytorch!")
[docs]deftorch_device(self)->torch.device:"""Return a representation of the torch device to use by default. .. warning:: If a list of devices is set, then this method only returns the first device. This may impact Nvidia GPU logging in the case multiple GPU cards are used. Returns ------- torch.device The **first** torch device (if a list of ids is set). """ifself.device_typein("cpu","mps"):returntorch.device(self.device_type)ifself.device_type=="cuda":ifnotself.device_ids:returntorch.device(self.device_type)returntorch.device(self.device_type,self.device_ids[0])# if you get to this point, this is an unexpected RuntimeErrorraiseRuntimeError(f"Unexpected device type {self.device_type} lacks support",)
[docs]deflightning_accelerator(self)->tuple[str,int|list[int]|str]:"""Return the lightning accelerator setup. Returns ------- accelerator The lightning accelerator to use. devices The lightning devices to use. """devices:int|list[int]|str=self.device_idsifnotdevices:devices="auto"elifself.device_type=="mps":devices=1returnself.device_type,devices