Skip to content

Commit

Permalink
Add docstring to JAX config's parse_device function
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 13, 2024
1 parent 582a668 commit 44ee0c1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/config/frameworks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ API

The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise

.. autofunction:: skrl.config.jax.parse_device

.. py:data:: skrl.config.jax.backend
:type: str
:value: "numpy"
Expand Down
16 changes: 14 additions & 2 deletions skrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,26 @@ def __init__(self) -> None:
local_device_ids=self._local_rank)

@staticmethod
def parse_device(device):
def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
"""Parse the input device and return a :py:class:`~jax.Device` instance.
.. hint::
This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``).
:param device: Device specification. If the specified device is ``None`` ot it cannot be resolved,
the default available device will be returned instead.
:return: JAX Device.
"""
import jax

if isinstance(device, str):
device_type, device_index = f"{device}:0".split(':')[:2]
try:
return jax.devices(device_type)[int(device_index)]
except (RuntimeError, IndexError):
except (RuntimeError, IndexError) as e:
logger.info(f"Invalid device specification ({device}): {e}")
device = None
if device is None:
return jax.devices()[0]
Expand Down

0 comments on commit 44ee0c1

Please sign in to comment.