diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml
index 16b5e203..40ce5f7e 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yaml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yaml
@@ -30,6 +30,7 @@ body:
       description: The skrl version can be obtained with the command `pip show skrl`.
       options:
         - ---
+        - 1.3.0
         - 1.2.0
         - 1.1.0
         - 1.0.0
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 573283a8..a9f5d769 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,26 @@
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
+## [1.4.0] - Unreleased
+### Added
+- Utilities to operate on Gymnasium spaces (`Box`, `Discrete`, `MultiDiscrete`, `Tuple` and `Dict`)
+- `parse_device` static method in ML framework configuration for JAX
+
+### Changed
+- Call agent's `pre_interaction` method during evaluation
+- Use spaces utilities to process states, observations and actions for all the library components
+- Update model instantiator definitions to process supported fundamental and composite Gymnasium spaces
+- Make flattened tensor storage in memory the default option (reverting the change introduced in version 1.3.0)
+- Drop support for PyTorch versions prior to 1.10 (the previous supported version was 1.9)
+
+### Fixed
+- Move the batch sampling inside the gradient step loop for DQN, DDQN, DDPG (RNN), TD3 (RNN), SAC and SAC (RNN)
+
+### Removed
+- Remove OpenAI Gym (`gym`) from dependencies and source code. **skrl** continues to support gym environments,
+  it is simply no longer installed as part of the library. If it is needed, install it manually.
+  Any gym-based environment wrapper must use the `convert_gym_space` space utility in order to operate with the library
+
 ## [1.3.0] - 2024-09-11
 ### Added
 - Distributed multi-GPU and multi-node learning (JAX implementation)
@@ -70,7 +90,7 @@ Summary of the most relevant features:
 ## [1.0.0-rc.2] - 2023-08-11
 ### Added
 - Get truncation from `time_outs` info in Isaac Gym, Isaac Orbit and Omniverse Isaac Gym environments
-- Time-limit (truncation) boostrapping in on-policy actor-critic agents
+- Time-limit (truncation) bootstrapping in on-policy actor-critic agents
 - Model instantiators `initial_log_std` parameter to set the log standard deviation's initial value
 
 ### Changed (breaking changes)
@@ -84,7 +104,7 @@ Summary of the most relevant features:
   - `from skrl.envs.loaders.jax import load_omniverse_isaacgym_env`
 
 ### Changed
-- Drop support for versions prior to PyTorch 1.9 (1.8.0 and 1.8.1)
+- Drop support for PyTorch versions prior to 1.9 (the previous supported version was 1.8)
 
 ## [1.0.0-rc.1] - 2023-07-25
 ### Added
@@ -177,7 +197,7 @@ to allow storing samples in memories during evaluation
 - Parameter `role` to model methods
 - Wrapper compatibility with the new OpenAI Gym environment API
 - Internal library colored logger
-- Migrate checkpoints/models from other RL libraries to skrl models/agents
+- Migrate checkpoints/models from other RL libraries to **skrl** models/agents
 - Configuration parameter `store_separately` to agent configuration dict
 - Save/load agent modules (models, optimizers, preprocessors)
 - Set random seed and configure deterministic behavior for reproducibility
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 1017b999..cbe0703c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -54,7 +54,7 @@ Read the code a little bit and you will understand it at first glance... Also
 ```ini
 function annotation (e.g. typing)
 # insert an empty line
-python libraries and other libraries (e.g. gym, numpy, time, etc.)
+ python libraries and other libraries (e.g. gymnasium, numpy, time, etc.) # insert an empty line machine learning framework modules (e.g. torch, torch.nn) # insert an empty line diff --git a/docs/source/api/agents/ddqn.rst b/docs/source/api/agents/ddqn.rst index 5208192e..095cd967 100644 --- a/docs/source/api/agents/ddqn.rst +++ b/docs/source/api/agents/ddqn.rst @@ -40,10 +40,10 @@ Learning algorithm | | :literal:`_update(...)` -| :green:`# sample a batch from memory` -| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size` | :green:`# gradient steps` | **FOR** each gradient step up to :guilabel:`gradient_steps` **DO** +| :green:`# sample a batch from memory` +| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size` | :green:`# compute target values` | :math:`Q' \leftarrow Q_{\phi_{target}}(s')` | :math:`Q_{_{target}} \leftarrow Q'[\underset{a}{\arg\max} \; Q_\phi(s')] \qquad` :gray:`# the only difference with DQN` diff --git a/docs/source/api/agents/dqn.rst b/docs/source/api/agents/dqn.rst index 74e2ca04..d5d72e7b 100644 --- a/docs/source/api/agents/dqn.rst +++ b/docs/source/api/agents/dqn.rst @@ -40,10 +40,10 @@ Learning algorithm | | :literal:`_update(...)` -| :green:`# sample a batch from memory` -| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size` | :green:`# gradient steps` | **FOR** each gradient step up to :guilabel:`gradient_steps` **DO** +| :green:`# sample a batch from memory` +| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size` | :green:`# compute target values` | :math:`Q' \leftarrow Q_{\phi_{target}}(s')` | :math:`Q_{_{target}} \leftarrow \underset{a}{\max} \; Q' \qquad` :gray:`# the only difference with DDQN` diff --git a/docs/source/api/agents/sac.rst b/docs/source/api/agents/sac.rst index b465acbb..cf1e4265 100644 --- a/docs/source/api/agents/sac.rst +++ b/docs/source/api/agents/sac.rst @@ -34,10 +34,10 @@ Learning algorithm | | :literal:`_update(...)` -| :green:`# sample a batch from memory` -| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size` | :green:`# gradient steps` | **FOR** each gradient step up to :guilabel:`gradient_steps` **DO** +| :green:`# sample a batch from memory` +| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size` | :green:`# compute target values` | :math:`a',\; logp' \leftarrow \pi_\theta(s')` | :math:`Q_{1_{target}} \leftarrow Q_{{\phi 1}_{target}}(s', a')` diff --git a/docs/source/api/config/frameworks.rst b/docs/source/api/config/frameworks.rst index d6b2d3c9..d7d5be71 100644 --- a/docs/source/api/config/frameworks.rst +++ b/docs/source/api/config/frameworks.rst @@ -86,6 +86,8 @@ API The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise +.. autofunction:: skrl.config.jax.parse_device + .. 
py:data:: skrl.config.jax.backend :type: str :value: "numpy" diff --git a/docs/source/api/utils.rst b/docs/source/api/utils.rst index 600adc30..feaf638c 100644 --- a/docs/source/api/utils.rst +++ b/docs/source/api/utils.rst @@ -6,6 +6,7 @@ Utils and configurations ML frameworks configuration Random seed + Spaces Model instantiators Runner Distributed runs @@ -39,6 +40,9 @@ A set of utilities and configurations for managing an RL setup is provided as pa * - :doc:`Random seed ` - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\blacksquare` + * - :doc:`Spaces ` + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\blacksquare` * - :doc:`Model instantiators ` - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\blacksquare` diff --git a/docs/source/api/utils/spaces.rst b/docs/source/api/utils/spaces.rst new file mode 100644 index 00000000..76a86fd3 --- /dev/null +++ b/docs/source/api/utils/spaces.rst @@ -0,0 +1,86 @@ +Spaces +====== + +Utilities to operate on Gymnasium `spaces `_. + +.. raw:: html + +

+
+Overview
+--------
+
+The utilities described in this section support the following Gymnasium spaces:
+
+.. list-table::
+    :header-rows: 1
+
+    * - Type
+      - Supported spaces
+    * - Fundamental
+      - :py:class:`~gymnasium.spaces.Box`, :py:class:`~gymnasium.spaces.Discrete`, and :py:class:`~gymnasium.spaces.MultiDiscrete`
+    * - Composite
+      - :py:class:`~gymnasium.spaces.Dict` and :py:class:`~gymnasium.spaces.Tuple`
+
+The following table provides a snapshot of the space sample conversion functions:
+
+.. list-table::
+    :header-rows: 1
+
+    * - Input
+      - Function
+      - Output
+    * - Space (NumPy / int)
+      - :py:func:`~skrl.utils.spaces.torch.tensorize_space`
+      - Space (PyTorch / JAX)
+    * - Space (PyTorch / JAX)
+      - :py:func:`~skrl.utils.spaces.torch.untensorize_space`
+      - Space (NumPy / int)
+    * - Space (PyTorch / JAX)
+      - :py:func:`~skrl.utils.spaces.torch.flatten_tensorized_space`
+      - PyTorch tensor / JAX array
+    * - PyTorch tensor / JAX array
+      - :py:func:`~skrl.utils.spaces.torch.unflatten_tensorized_space`
+      - Space (PyTorch / JAX)
+
+.. raw:: html
+
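+    <br><hr>
+
+Example
+-------
+
+The following is a minimal sketch of a typical round trip (the composite space and its sub-spaces
+are illustrative; the exact signatures are documented in the API sections below): a raw NumPy/int
+sample is tensorized, flattened (e.g. for storage in a memory), and recovered again.
+
+.. code-block:: python
+
+    import gymnasium
+
+    from skrl.utils.spaces.torch import (
+        compute_space_size,
+        flatten_tensorized_space,
+        tensorize_space,
+        unflatten_tensorized_space,
+    )
+
+    # composite space mixing fundamental sub-spaces
+    space = gymnasium.spaces.Dict({
+        "sensor": gymnasium.spaces.Box(low=-1, high=1, shape=(8,)),
+        "command": gymnasium.spaces.Discrete(4),
+    })
+
+    # NumPy/int sample -> PyTorch tensors (with a leading batch dimension)
+    tensorized = tensorize_space(space, space.sample())
+
+    # tensorized sample -> a single flat tensor (a Discrete value occupies 1 element)
+    flattened = flatten_tensorized_space(tensorized)
+    assert flattened.shape[-1] == compute_space_size(space, occupied_size=True)
+
+    # flat tensor -> original (tensorized) structure
+    restored = unflatten_tensorized_space(space, flattened)
+
+The JAX utilities mirror this API, operating on JAX arrays (or NumPy arrays, depending on the
+configured backend) instead of PyTorch tensors.
+
+.. raw:: html
+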
+ +API (PyTorch) +------------- + +.. autofunction:: skrl.utils.spaces.torch.compute_space_size + +.. autofunction:: skrl.utils.spaces.torch.convert_gym_space + +.. autofunction:: skrl.utils.spaces.torch.flatten_tensorized_space + +.. autofunction:: skrl.utils.spaces.torch.sample_space + +.. autofunction:: skrl.utils.spaces.torch.tensorize_space + +.. autofunction:: skrl.utils.spaces.torch.unflatten_tensorized_space + +.. autofunction:: skrl.utils.spaces.torch.untensorize_space + +.. raw:: html + +
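+    <br><hr>
+
+Since OpenAI Gym (``gym``) is no longer a dependency, gym-based environment wrappers convert their
+spaces with :py:func:`~skrl.utils.spaces.torch.convert_gym_space`. A minimal sketch (it assumes
+``gym`` has been installed manually):
+
+.. code-block:: python
+
+    import gym  # no longer installed together with skrl
+
+    from skrl.utils.spaces.torch import convert_gym_space
+
+    gym_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,))
+    space = convert_gym_space(gym_space)  # equivalent gymnasium.spaces.Box
+
+.. raw:: html
+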
+ +API (JAX) +--------- + +.. autofunction:: skrl.utils.spaces.jax.compute_space_size + +.. autofunction:: skrl.utils.spaces.jax.convert_gym_space + +.. autofunction:: skrl.utils.spaces.jax.flatten_tensorized_space + +.. autofunction:: skrl.utils.spaces.jax.sample_space + +.. autofunction:: skrl.utils.spaces.jax.tensorize_space + +.. autofunction:: skrl.utils.spaces.jax.unflatten_tensorized_space + +.. autofunction:: skrl.utils.spaces.jax.untensorize_space diff --git a/docs/source/index.rst b/docs/source/index.rst index d024c64d..9e5f9053 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -197,6 +197,7 @@ Utils and configurations * :doc:`ML frameworks ` configuration * :doc:`Random seed ` + * :doc:`Spaces ` * :doc:`Model instantiators ` * :doc:`Runner ` * :doc:`Distributed runs ` diff --git a/docs/source/intro/installation.rst b/docs/source/intro/installation.rst index 1c8fd3fc..dd04e765 100644 --- a/docs/source/intro/installation.rst +++ b/docs/source/intro/installation.rst @@ -12,10 +12,10 @@ In this section, you will find the steps to install the library, troubleshoot kn **skrl** requires Python 3.6 or higher and the following libraries (they will be installed automatically): - * `gym `_ / `gymnasium `_ - * `tqdm `_ + * `gymnasium `_ * `packaging `_ * `tensorboard `_ + * `tqdm `_ Machine learning (ML) framework ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -25,7 +25,7 @@ According to the specific ML frameworks, the following libraries are required: PyTorch """"""" - * `torch `_ 1.9.0 or higher + * `torch `_ 1.10.0 or higher JAX """ diff --git a/docs/source/snippets/agent.py b/docs/source/snippets/agent.py index b0f0186f..f5613645 100644 --- a/docs/source/snippets/agent.py +++ b/docs/source/snippets/agent.py @@ -1,7 +1,7 @@ # [start-agent-base-class-torch] from typing import Union, Tuple, Dict, Any, Optional -import gym, gymnasium +import gymnasium import copy import torch @@ -33,8 +33,8 @@ class CUSTOM(Agent): def __init__(self, models: Dict[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Custom agent @@ -46,9 +46,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: None) - :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional :param action_space: Action space or shape (default: None) - :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional :param device: Device on which a torch tensor is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda:0"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -179,7 +179,7 @@ def _update(self, timestep: int, timesteps: int) -> None: # [start-agent-base-class-jax] from typing import Union, Tuple, Dict, Any, Optional -import gym, gymnasium +import gymnasium import copy import jaxlib @@ -213,8 +213,8 @@ class CUSTOM(Agent): def __init__(self, models: Dict[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jaxlib.xla_extension.Device]] = None, cfg: Optional[dict] = None) -> None: """Custom agent @@ -226,9 +226,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None :param observation_space: Observation/state space or shape (default: None) - :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional :param action_space: Action space or shape (default: None) - :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional :param device: Device on which a jax array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda:0"`` if available or ``"cpu"`` :type device: str or jaxlib.xla_extension.Device, optional diff --git a/docs/source/snippets/model_mixin.py b/docs/source/snippets/model_mixin.py index f5b0872d..28ad4ce2 100644 --- a/docs/source/snippets/model_mixin.py +++ b/docs/source/snippets/model_mixin.py @@ -1,7 +1,7 @@ # [start-model-torch] from typing import Optional, Union, Mapping, Sequence, Tuple, Any -import gym, gymnasium +import gymnasium import torch @@ -10,17 +10,17 @@ class CustomModel(Model): def __init__(self, - observation_space: Union[int, Sequence[int], gym.Space, gymnasium.Space], - action_space: Union[int, Sequence[int], gym.Space, gymnasium.Space], + observation_space: Union[int, Sequence[int], gymnasium.Space], + action_space: Union[int, Sequence[int], gymnasium.Space], device: Optional[Union[str, torch.device]] = None) -> None: """Custom model :param observation_space: Observation/state space or shape. The ``num_observations`` property will contain the size of that space - :type observation_space: int, sequence of int, gym.Space, gymnasium.Space + :type observation_space: int, sequence of int, gymnasium.Space :param action_space: Action space or shape. The ``num_actions`` property will contain the size of that space - :type action_space: int, sequence of int, gym.Space, gymnasium.Space + :type action_space: int, sequence of int, gymnasium.Space :param device: Device on which a torch tensor is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda:0"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -58,7 +58,7 @@ def act(self, # [start-model-jax] from typing import Optional, Union, Mapping, Tuple, Any -import gym, gymnasium +import gymnasium import flax import jaxlib @@ -69,8 +69,8 @@ def act(self, class CustomModel(Model): def __init__(self, - observation_space: Union[int, Sequence[int], gym.Space, gymnasium.Space], - action_space: Union[int, Sequence[int], gym.Space, gymnasium.Space], + observation_space: Union[int, Sequence[int], gymnasium.Space], + action_space: Union[int, Sequence[int], gymnasium.Space], device: Optional[Union[str, jaxlib.xla_extension.Device]] = None, parent: Optional[Any] = None, name: Optional[str] = None) -> None: @@ -78,10 +78,10 @@ def __init__(self, :param observation_space: Observation/state space or shape. The ``num_observations`` property will contain the size of that space - :type observation_space: int, sequence of int, gym.Space, gymnasium.Space + :type observation_space: int, sequence of int, gymnasium.Space :param action_space: Action space or shape. The ``num_actions`` property will contain the size of that space - :type action_space: int, sequence of int, gym.Space, gymnasium.Space + :type action_space: int, sequence of int, gymnasium.Space :param device: Device on which a jax array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda:0"`` if available or ``"cpu"`` :type device: str or jaxlib.xla_extension.Device, optional diff --git a/docs/source/snippets/multi_agent.py b/docs/source/snippets/multi_agent.py index 175f639f..0644a2d6 100644 --- a/docs/source/snippets/multi_agent.py +++ b/docs/source/snippets/multi_agent.py @@ -1,7 +1,7 @@ # [start-multi-agent-base-class-torch] from typing import Union, Dict, Any, Optional, Sequence, Mapping -import gym, gymnasium +import gymnasium import copy import torch @@ -34,8 +34,8 @@ def __init__(self, possible_agents: Sequence[str], models: Dict[str, Model], memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None, - action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None, + observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Custom multi-agent @@ -48,9 +48,9 @@ def __init__(self, :param memories: Memories to storage the transitions. :type memories: dictionary of skrl.memory.torch.Memory, optional :param observation_spaces: Observation/state spaces or shapes (default: ``None``) - :type observation_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type observation_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param action_spaces: Action spaces or shapes (default: ``None``) - :type action_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type action_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param device: Device on which a torch tensor is or will be allocated (default: ``None``). 
                       If None, the device will be either ``"cuda:0"`` if available or ``"cpu"``
         :type device: str or torch.device, optional
@@ -182,7 +182,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
 # [start-multi-agent-base-class-jax]
 from typing import Union, Dict, Any, Optional, Sequence, Mapping
 
-import gym, gymnasium
+import gymnasium
 import copy
 
 import jaxlib
@@ -217,8 +217,8 @@ def __init__(self,
                  possible_agents: Sequence[str],
                  models: Dict[str, Model],
                  memories: Optional[Mapping[str, Memory]] = None,
-                 observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None,
-                 action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None,
+                 observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None,
+                 action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None,
                  device: Optional[Union[str, jaxlib.xla_extension.Device]] = None,
                  cfg: Optional[dict] = None) -> None:
         """Custom multi-agent
@@ -231,9 +231,9 @@ def __init__(self,
         :param memories: Memories to storage the transitions.
         :type memories: dictionary of skrl.memory.torch.Memory, optional
         :param observation_spaces: Observation/state spaces or shapes (default: ``None``)
-        :type observation_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional
+        :type observation_spaces: dictionary of int, sequence of int or gymnasium.Space, optional
         :param action_spaces: Action spaces or shapes (default: ``None``)
-        :type action_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional
+        :type action_spaces: dictionary of int, sequence of int or gymnasium.Space, optional
         :param device: Device on which a jax array is or will be allocated (default: ``None``).
                        If None, the device will be either ``"cuda:0"`` if available or ``"cpu"``
         :type device: str or jaxlib.xla_extension.Device, optional
diff --git a/pyproject.toml b/pyproject.toml
index 408d1c48..d3d04cb4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "skrl"
-version = "1.3.0"
+version = "1.4.0"
 description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
 readme = "README.md"
 requires-python = ">=3.6"
@@ -22,15 +22,14 @@ classifiers = [
 ]
 # dependencies / optional-dependencies
 dependencies = [
-    "gym",
     "gymnasium",
-    "tqdm",
     "packaging",
     "tensorboard",
+    "tqdm",
 ]
 
 [project.optional-dependencies]
 torch = [
-    "torch>=1.9",
+    "torch>=1.10",
 ]
 jax = [
     "jax>=0.4.3",
diff --git a/skrl/__init__.py b/skrl/__init__.py
index 9da176a8..7144c3eb 100644
--- a/skrl/__init__.py
+++ b/skrl/__init__.py
@@ -140,6 +140,31 @@ def __init__(self) -> None:
                                             process_id=self._rank,
                                             local_device_ids=self._local_rank)
 
+    @staticmethod
+    def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
+        """Parse the input device and return a :py:class:`~jax.Device` instance.
+
+        .. hint::
+
+            This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``).
+
+        :param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
+                       the default available device will be returned instead.
+
+        :return: JAX Device.
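+
+        Example (the shown output is illustrative; the resolved device depends on the machine)::
+
+            >>> from skrl import config
+            >>> config.jax.parse_device("cuda:0")
+            CudaDevice(id=0)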
+ """ + import jax + + if isinstance(device, jax.Device): + return device + elif isinstance(device, str): + device_type, device_index = f"{device}:0".split(':')[:2] + try: + return jax.devices(device_type)[int(device_index)] + except (RuntimeError, IndexError) as e: + logger.warning(f"Invalid device specification ({device}): {e}") + return jax.devices()[0] + @property def device(self) -> "jax.Device": """Default device @@ -147,18 +172,7 @@ def device(self) -> "jax.Device": The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise """ - try: - import jax - if type(self._device) == str: - device_type, device_index = f"{self._device}:0".split(':')[:2] - try: - self._device = jax.devices(device_type)[int(device_index)] - except (RuntimeError, IndexError): - self._device = None - if self._device is None: - self._device = jax.devices()[0] - except ImportError: - pass + self._device = self.parse_device(self._device) return self._device @device.setter diff --git a/skrl/agents/jax/a2c/a2c.py b/skrl/agents/jax/a2c/a2c.py index 81dc6ef1..16b533a6 100644 --- a/skrl/agents/jax/a2c/a2c.py +++ b/skrl/agents/jax/a2c/a2c.py @@ -2,7 +2,6 @@ import copy import functools -import gym import gymnasium import jax @@ -172,8 +171,8 @@ class A2C(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Advantage Actor Critic (A2C) @@ -187,9 +186,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional diff --git a/skrl/agents/jax/base.py b/skrl/agents/jax/base.py index 319ab2c3..1d9ddc8f 100644 --- a/skrl/agents/jax/base.py +++ b/skrl/agents/jax/base.py @@ -5,7 +5,6 @@ import datetime import os import pickle -import gym import gymnasium import flax @@ -21,8 +20,8 @@ class Agent: def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Base class that represent a RL agent @@ -34,9 +33,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional diff --git a/skrl/agents/jax/cem/cem.py b/skrl/agents/jax/cem/cem.py index a597cc0e..f47f5b1f 100644 --- a/skrl/agents/jax/cem/cem.py +++ b/skrl/agents/jax/cem/cem.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Tuple, Union import copy -import gym import gymnasium import jax @@ -54,8 +53,8 @@ class CEM(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Cross-Entropy Method (CEM) @@ -69,9 +68,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional diff --git a/skrl/agents/jax/ddpg/ddpg.py b/skrl/agents/jax/ddpg/ddpg.py index a7865aeb..efec21d0 100644 --- a/skrl/agents/jax/ddpg/ddpg.py +++ b/skrl/agents/jax/ddpg/ddpg.py @@ -2,7 +2,6 @@ import copy import functools -import gym import gymnasium import jax @@ -114,8 +113,8 @@ class DDPG(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Deep Deterministic Policy Gradient (DDPG) @@ -129,9 +128,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional diff --git a/skrl/agents/jax/dqn/ddqn.py b/skrl/agents/jax/dqn/ddqn.py index f11ebe9b..76868e68 100644 --- a/skrl/agents/jax/dqn/ddqn.py +++ b/skrl/agents/jax/dqn/ddqn.py @@ -2,7 +2,6 @@ import copy import functools -import gym import gymnasium import jax @@ -92,8 +91,8 @@ class DDQN(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Double Deep Q-Network (DDQN) @@ -107,9 +106,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional @@ -337,13 +336,14 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0] # gradient steps for gradient_step in range(self._gradient_steps): + # sample a batch from memory + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ + self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0] + sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) diff --git a/skrl/agents/jax/dqn/dqn.py b/skrl/agents/jax/dqn/dqn.py index e75247d7..30629c09 100644 --- a/skrl/agents/jax/dqn/dqn.py +++ b/skrl/agents/jax/dqn/dqn.py @@ -2,7 +2,6 @@ import copy import functools -import gym import gymnasium import jax @@ -89,8 +88,8 @@ class DQN(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Deep Q-Network (DQN) @@ -104,9 +103,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional @@ -334,13 +333,14 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0] # gradient steps for gradient_step in range(self._gradient_steps): + # sample a batch from memory + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ + self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0] + sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) diff --git a/skrl/agents/jax/ppo/ppo.py b/skrl/agents/jax/ppo/ppo.py index cfa7c7b7..7fde2472 100644 --- a/skrl/agents/jax/ppo/ppo.py +++ b/skrl/agents/jax/ppo/ppo.py @@ -2,7 +2,6 @@ import copy import functools -import gym import gymnasium import jax @@ -191,8 +190,8 @@ class PPO(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Proximal Policy Optimization (PPO) @@ -206,9 +205,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional diff --git a/skrl/agents/jax/rpo/rpo.py b/skrl/agents/jax/rpo/rpo.py index 1520081b..c0373627 100644 --- a/skrl/agents/jax/rpo/rpo.py +++ b/skrl/agents/jax/rpo/rpo.py @@ -2,7 +2,6 @@ import copy import functools -import gym import gymnasium import jax @@ -194,8 +193,8 @@ class RPO(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Robust Policy Optimization (RPO) @@ -209,9 +208,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional diff --git a/skrl/agents/jax/sac/sac.py b/skrl/agents/jax/sac/sac.py index 8f77da3b..7656c26d 100644 --- a/skrl/agents/jax/sac/sac.py +++ b/skrl/agents/jax/sac/sac.py @@ -2,7 +2,6 @@ import copy import functools -import gym import gymnasium import flax @@ -125,8 +124,8 @@ class SAC(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Soft Actor-Critic (SAC) @@ -140,9 +139,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional @@ -213,9 +212,9 @@ def __init__(self, if self._learn_entropy: self._target_entropy = self.cfg["target_entropy"] if self._target_entropy is None: - if issubclass(type(self.action_space), gym.spaces.Box) or issubclass(type(self.action_space), gymnasium.spaces.Box): + if issubclass(type(self.action_space), gymnasium.spaces.Box): self._target_entropy = -np.prod(self.action_space.shape).astype(np.float32) - elif issubclass(type(self.action_space), gym.spaces.Discrete) or issubclass(type(self.action_space), gymnasium.spaces.Discrete): + elif issubclass(type(self.action_space), gymnasium.spaces.Discrete): self._target_entropy = -self.action_space.n else: self._target_entropy = 0 @@ -397,13 +396,14 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] # gradient steps for gradient_step in range(self._gradient_steps): + # sample a batch from memory + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ + self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py index 6a1c909c..23f4885a 100644 --- a/skrl/agents/jax/td3/td3.py +++ b/skrl/agents/jax/td3/td3.py @@ -2,7 +2,6 @@ import copy import functools -import gym import gymnasium import jax @@ -132,8 +131,8 @@ class TD3(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Twin Delayed DDPG (TD3) @@ -147,9 +146,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional @@ -436,6 +435,7 @@ def _update(self, timestep: int, timesteps: int) -> None: # gradient steps for gradient_step in range(self._gradient_steps): + # sample a batch from memory sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py index d5b80687..9c9cf9b6 100644 --- a/skrl/agents/torch/a2c/a2c.py +++ b/skrl/agents/torch/a2c/a2c.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import torch @@ -62,8 +61,8 @@ class A2C(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Advantage Actor Critic (A2C) @@ -77,9 +76,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py index b21f39e7..97cc93e1 100644 --- a/skrl/agents/torch/a2c/a2c_rnn.py +++ b/skrl/agents/torch/a2c/a2c_rnn.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import torch @@ -62,8 +61,8 @@ class A2C_RNN(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Advantage Actor Critic (A2C) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) 
@@ -77,9 +76,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py index 3db6f5e0..181e5ac6 100644 --- a/skrl/agents/torch/amp/amp.py +++ b/skrl/agents/torch/amp/amp.py @@ -3,7 +3,6 @@ import copy import itertools import math -import gym import gymnasium import torch @@ -79,11 +78,11 @@ class AMP(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None, - amp_observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + amp_observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, motion_dataset: Optional[Memory] = None, reply_buffer: Optional[Memory] = None, collect_reference_motions: Optional[Callable[[int], torch.Tensor]] = None, @@ -102,16 +101,16 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional :param cfg: Configuration dictionary :type cfg: dict :param amp_observation_space: AMP observation/state space or shape (default: ``None``) - :type amp_observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None + :type amp_observation_space: int, tuple or list of int, gymnasium.Space or None :param motion_dataset: Reference motion dataset: M (default: ``None``) :type motion_dataset: skrl.memory.torch.Memory or None :param reply_buffer: Reply buffer for preventing discriminator overfitting: B (default: ``None``) diff --git a/skrl/agents/torch/base.py b/skrl/agents/torch/base.py index aa94dc1a..d7c29e19 100644 --- a/skrl/agents/torch/base.py +++ b/skrl/agents/torch/base.py @@ -4,7 +4,6 @@ import copy import datetime import os -import gym import gymnasium from packaging import version @@ -21,8 +20,8 @@ class Agent: def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Base class that represent a RL agent @@ -34,9 +33,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/cem/cem.py b/skrl/agents/torch/cem/cem.py index 06f5ba24..864b6b20 100644 --- a/skrl/agents/torch/cem/cem.py +++ b/skrl/agents/torch/cem/cem.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Tuple, Union import copy -import gym import gymnasium import torch @@ -51,8 +50,8 @@ class CEM(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Cross-Entropy Method (CEM) @@ -66,9 +65,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py index 88f6ccc9..1e3d2690 100644 --- a/skrl/agents/torch/ddpg/ddpg.py +++ b/skrl/agents/torch/ddpg/ddpg.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Tuple, Union import copy -import gym import gymnasium import torch @@ -63,8 +62,8 @@ class DDPG(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Deep Deterministic Policy Gradient (DDPG) @@ -78,9 +77,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be 
allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py index 436184c1..36a98fee 100644 --- a/skrl/agents/torch/ddpg/ddpg_rnn.py +++ b/skrl/agents/torch/ddpg/ddpg_rnn.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Tuple, Union import copy -import gym import gymnasium import torch @@ -63,8 +62,8 @@ class DDPG_RNN(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Deep Deterministic Policy Gradient (DDPG) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) @@ -78,9 +77,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -361,18 +360,19 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0] - - rnn_policy = {} - if self._rnn: - sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0] - rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} # gradient steps for gradient_step in range(self._gradient_steps): + # sample a batch from memory + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ + self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0] + + rnn_policy = {} + if self._rnn: + sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0] + rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} + sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py index 8d181ac8..d7e93886 100644 --- a/skrl/agents/torch/dqn/ddqn.py +++ b/skrl/agents/torch/dqn/ddqn.py @@ -2,7 +2,6 @@ import copy import math -import gym import gymnasium import torch @@ -62,8 +61,8 @@ class DDQN(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Double Deep Q-Network (DDQN) @@ -77,9 +76,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -284,13 +283,14 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0] # gradient steps for gradient_step in range(self._gradient_steps): + # sample a batch from memory + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ + self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0] + sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py index c7f4f709..03ffa320 100644 --- a/skrl/agents/torch/dqn/dqn.py +++ b/skrl/agents/torch/dqn/dqn.py @@ -2,7 +2,6 @@ import copy import math -import gym import gymnasium import torch @@ -62,8 +61,8 @@ class DQN(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Deep Q-Network (DQN) @@ -77,9 +76,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -284,13 +283,14 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0] # gradient steps for gradient_step in range(self._gradient_steps): + # sample a batch from memory + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ + self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0] + sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index 3143505f..21124fdf 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import torch @@ -50,6 +49,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward "time_limit_bootstrap": False, # bootstrap at timeout termination (episode truncation) + "mixed_precision": False, # mixed torch.float32 and torch.float16 precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -69,8 +70,8 @@ class PPO(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Proximal Policy Optimization (PPO) @@ -84,9 +85,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -151,6 +152,12 @@ def __init__(self, self._rewards_shaper = self.cfg["rewards_shaper"] self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self._scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizer and learning rate scheduler if self.policy is not None and self.value is not None: if self.policy is self.value: @@ -219,8 +226,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") # sample stochastic actions - actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") - self._current_log_prob = log_prob + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") + self._current_log_prob = log_prob return actions, log_prob, outputs @@ -265,8 +273,9 @@ def record_transition, rewards = self._rewards_shaper(rewards, timestep, timesteps) # compute values - values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value") - values = self._value_preprocessor(values, inverse=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value") + values = self._value_preprocessor(values, inverse=True) # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: @@ -388,55 +397,62 @@ def compute_gae(rewards: torch.Tensor, # mini-batches loop for sampled_states, sampled_actions, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages in sampled_batches: - sampled_states = self._state_preprocessor(sampled_states, train=not epoch) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - _, next_log_prob, _ = self.policy.act({"states": sampled_states, "taken_actions": sampled_actions}, role="policy") + sampled_states = self._state_preprocessor(sampled_states, train=not epoch) - # compute approximate KL divergence - with torch.no_grad(): - ratio = next_log_prob - sampled_log_prob - kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() - kl_divergences.append(kl_divergence) + _, next_log_prob, _ = self.policy.act({"states": sampled_states, "taken_actions": sampled_actions}, role="policy") - # early stopping with KL divergence - if self._kl_threshold and kl_divergence > self._kl_threshold: - break + # compute approximate KL divergence + with torch.no_grad(): + ratio = next_log_prob - sampled_log_prob + kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() + kl_divergences.append(kl_divergence) - # compute entropy loss - if self._entropy_loss_scale: - entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() - else: - entropy_loss = 0 + # early stopping with KL divergence + if self._kl_threshold and kl_divergence > self._kl_threshold: + break - # compute policy loss - ratio = torch.exp(next_log_prob - sampled_log_prob) - surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip) + # compute entropy 
loss + if self._entropy_loss_scale: + entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() + else: + entropy_loss = 0 + + # compute policy loss + ratio = torch.exp(next_log_prob - sampled_log_prob) + surrogate = sampled_advantages * ratio + surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip) - policy_loss = -torch.min(surrogate, surrogate_clipped).mean() + policy_loss = -torch.min(surrogate, surrogate_clipped).mean() - # compute value loss - predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value") + # compute value loss + predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value") - if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip(predicted_values - sampled_values, - min=-self._value_clip, - max=self._value_clip) - value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) + if self._clip_predicted_values: + predicted_values = sampled_values + torch.clip(predicted_values - sampled_values, + min=-self._value_clip, + max=self._value_clip) + value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) # optimization step self.optimizer.zero_grad() - (policy_loss + entropy_loss + value_loss).backward() + self._scaler.scale(policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() if self.policy is not self.value: self.value.reduce_parameters() + if self._grad_norm_clip > 0: + self._scaler.unscale_(self.optimizer) if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip) - self.optimizer.step() + + self._scaler.step(self.optimizer) + self._scaler.update() # update cumulative losses cumulative_policy_loss += policy_loss.item() diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py index 48a488ff..59d8cbe2 100644 --- a/skrl/agents/torch/ppo/ppo_rnn.py +++ b/skrl/agents/torch/ppo/ppo_rnn.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import torch @@ -69,8 +68,8 @@ class PPO_RNN(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Proximal Policy Optimization (PPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) 
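The ppo.py hunks above wire in automatic mixed precision: forward passes run under torch.autocast while a GradScaler scales the loss, unscales gradients before clipping, and conditionally steps the optimizer. Below is a minimal, self-contained sketch of that same standard PyTorch pattern; the linear model, random data and clip value are placeholders for illustration, not skrl objects.

import torch
import torch.nn as nn

device_type = "cuda" if torch.cuda.is_available() else "cpu"
mixed_precision = device_type == "cuda"   # float16 autocast pays off on GPU only

model = nn.Linear(8, 1).to(device_type)
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler(enabled=mixed_precision)

x = torch.randn(32, 8, device=device_type)
target = torch.randn(32, 1, device=device_type)

# forward pass in reduced precision
with torch.autocast(device_type=device_type, enabled=mixed_precision):
    loss = nn.functional.mse_loss(model(x), target)

optimizer.zero_grad()
scaler.scale(loss).backward()  # scale the loss so float16 gradients do not underflow
scaler.unscale_(optimizer)     # unscale in place so clipping sees true gradient norms
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)         # the step is skipped if non-finite gradients appear
scaler.update()

When mixed_precision is False the scaler and autocast context are no-ops, so the code path stays identical on CPU, matching how the agent keeps a single _update implementation.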
@@ -84,9 +83,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/q_learning/q_learning.py b/skrl/agents/torch/q_learning/q_learning.py index 90da91cc..f9dd3442 100644 --- a/skrl/agents/torch/q_learning/q_learning.py +++ b/skrl/agents/torch/q_learning/q_learning.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Tuple, Union import copy -import gym import gymnasium import torch @@ -41,8 +40,8 @@ class Q_LEARNING(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Q-learning @@ -56,9 +55,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py index a0024d13..1cc4a18d 100644 --- a/skrl/agents/torch/rpo/rpo.py +++ b/skrl/agents/torch/rpo/rpo.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import torch @@ -70,8 +69,8 @@ class RPO(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Robust Policy Optimization (RPO) @@ -85,9 +84,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py index efda1fed..5f1ee485 100644 --- a/skrl/agents/torch/rpo/rpo_rnn.py +++ b/skrl/agents/torch/rpo/rpo_rnn.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import torch @@ -70,8 +69,8 @@ class RPO_RNN(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Robust Policy Optimization (RPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) 
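A recurring change across these agent hunks: every gym.Space branch is dropped from the annotations, so constructors now accept gymnasium spaces (or plain ints/tuples) only. A short sketch of the call-site implication, using illustrative CartPole-like spaces; the commented-out conversion relies on the convert_gym_space utility named in the changelog, and the single-argument call shape is an assumption drawn from the wrapper hunks later in this diff.

import gymnasium

# gymnasium spaces (or ints / tuples of ints) are what agents accept now
observation_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(4,))
action_space = gymnasium.spaces.Discrete(2)

# a space taken from a legacy gym environment must be converted first,
# e.g. (assumed call shape, see the wrapper hunks below):
# from skrl.utils.spaces.torch import convert_gym_space
# observation_space = convert_gym_space(legacy_env.observation_space)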
@@ -85,9 +84,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py index dc8678a6..78f4a556 100644 --- a/skrl/agents/torch/sac/sac.py +++ b/skrl/agents/torch/sac/sac.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import numpy as np @@ -63,8 +62,8 @@ class SAC(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Soft Actor-Critic (SAC) @@ -78,9 +77,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -159,9 +158,9 @@ def __init__(self, if self._learn_entropy: self._target_entropy = self.cfg["target_entropy"] if self._target_entropy is None: - if issubclass(type(self.action_space), gym.spaces.Box) or issubclass(type(self.action_space), gymnasium.spaces.Box): + if issubclass(type(self.action_space), gymnasium.spaces.Box): self._target_entropy = -np.prod(self.action_space.shape).astype(np.float32) - elif issubclass(type(self.action_space), gym.spaces.Discrete) or issubclass(type(self.action_space), gymnasium.spaces.Discrete): + elif issubclass(type(self.action_space), gymnasium.spaces.Discrete): self._target_entropy = -self.action_space.n else: self._target_entropy = 0 @@ -308,13 +307,14 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] # gradient steps for gradient_step in range(self._gradient_steps): + # sample a batch from memory + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ + self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py index 6553d7aa..4d8764b4 100644 --- a/skrl/agents/torch/sac/sac_rnn.py +++ b/skrl/agents/torch/sac/sac_rnn.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import numpy as np @@ -63,8 +62,8 @@ class SAC_RNN(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Soft Actor-Critic (SAC) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) @@ -78,9 +77,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -159,9 +158,9 @@ def __init__(self, if self._learn_entropy: self._target_entropy = self.cfg["target_entropy"] if self._target_entropy is None: - if issubclass(type(self.action_space), gym.spaces.Box) or issubclass(type(self.action_space), gymnasium.spaces.Box): + if issubclass(type(self.action_space), gymnasium.spaces.Box): self._target_entropy = -np.prod(self.action_space.shape).astype(np.float32) - elif issubclass(type(self.action_space), gym.spaces.Discrete) or issubclass(type(self.action_space), gymnasium.spaces.Discrete): + elif issubclass(type(self.action_space), gymnasium.spaces.Discrete): self._target_entropy = -self.action_space.n else: self._target_entropy = 0 @@ -345,18 +344,19 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0] - - rnn_policy = {} - if self._rnn: - sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0] - rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} # gradient steps for gradient_step in range(self._gradient_steps): + # sample a batch from memory + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ + self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0] + + rnn_policy = {} + if self._rnn: + sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0] + rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} + sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) diff --git a/skrl/agents/torch/sarsa/sarsa.py b/skrl/agents/torch/sarsa/sarsa.py index 5abc260a..9f27bc3a 100644 --- a/skrl/agents/torch/sarsa/sarsa.py +++ b/skrl/agents/torch/sarsa/sarsa.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Tuple, Union import copy -import gym import gymnasium import torch @@ -41,8 +40,8 @@ class SARSA(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """State Action Reward State Action (SARSA) @@ -56,9 +55,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or 
shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py index abbcc467..2b791994 100644 --- a/skrl/agents/torch/td3/td3.py +++ b/skrl/agents/torch/td3/td3.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import torch @@ -68,8 +67,8 @@ class TD3(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Twin Delayed DDPG (TD3) @@ -83,9 +82,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py index abd6a922..aeb2fd78 100644 --- a/skrl/agents/torch/td3/td3_rnn.py +++ b/skrl/agents/torch/td3/td3_rnn.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import torch @@ -68,8 +67,8 @@ class TD3_RNN(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Twin Delayed DDPG (TD3) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) 
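The _update hunks for DQN, DDQN, DDPG (RNN) and SAC (RNN) above, and TD3 (RNN) just below, all apply the same fix: the replay-memory sample moves inside the gradient-step loop, so each step optimizes on an independent mini-batch instead of reusing one batch gradient_steps times. A toy before/after sketch with plain-Python stand-ins for skrl's Memory API:

import random

memory = list(range(1000))        # stand-in for stored transitions
gradient_steps, batch_size = 4, 32

# before: a single batch, reused for every gradient step
batch = random.sample(memory, batch_size)
for _ in range(gradient_steps):
    pass                          # ... optimize on the same `batch` each time ...

# after (as in the hunks above): a fresh batch per gradient step
for _ in range(gradient_steps):
    batch = random.sample(memory, batch_size)
    # ... compute targets and losses, then optimize on `batch` ...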
@@ -83,9 +82,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -383,18 +382,19 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0] - - rnn_policy = {} - if self._rnn: - sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0] - rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} # gradient steps for gradient_step in range(self._gradient_steps): + # sample a batch from memory + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ + self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0] + + rnn_policy = {} + if self._rnn: + sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0] + rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} + sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py index 96565bef..32a3b34c 100644 --- a/skrl/agents/torch/trpo/trpo.py +++ b/skrl/agents/torch/trpo/trpo.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Tuple, Union import copy -import gym import gymnasium import torch @@ -68,8 +67,8 @@ class TRPO(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Trust Region Policy Optimization (TRPO) @@ -83,9 +82,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or 
list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/agents/torch/trpo/trpo_rnn.py b/skrl/agents/torch/trpo/trpo_rnn.py index 4b3ad05a..3599223f 100644 --- a/skrl/agents/torch/trpo/trpo_rnn.py +++ b/skrl/agents/torch/trpo/trpo_rnn.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Tuple, Union import copy -import gym import gymnasium import torch @@ -68,8 +67,8 @@ class TRPO_RNN(Agent): def __init__(self, models: Mapping[str, Model], memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Trust Region Policy Optimization (TRPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) @@ -83,9 +82,9 @@ def __init__(self, for the rest only the environment transitions will be added :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None :param observation_space: Observation/state space or shape (default: ``None``) - :type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional :param action_space: Action space or shape (default: ``None``) - :type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of int, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/envs/loaders/torch/isaaclab_envs.py b/skrl/envs/loaders/torch/isaaclab_envs.py index fe5c47bb..643ddf26 100644 --- a/skrl/envs/loaders/torch/isaaclab_envs.py +++ b/skrl/envs/loaders/torch/isaaclab_envs.py @@ -1,6 +1,5 @@ from typing import Optional, Sequence -import os import sys from skrl import logger @@ -61,11 +60,11 @@ def load_isaaclab_env(task_name: str = "", :raises ValueError: The task name has not been defined, neither by the function parameter nor by the command line arguments :return: Isaac Lab environment - :rtype: gym.Env + :rtype: gymnasium.Env """ import argparse import atexit - import gymnasium as gym + import gymnasium # check task from command line arguments defined = False @@ -154,6 +153,6 @@ def close_the_simulator(): pass # load environment - env = gym.make(args.task, cfg=cfg, render_mode="rgb_array" if args.video else None) + env = gymnasium.make(args.task, cfg=cfg, render_mode="rgb_array" if args.video else None) return env diff --git a/skrl/envs/wrappers/jax/__init__.py b/skrl/envs/wrappers/jax/__init__.py index 511a4c09..48250a98 100644 --- a/skrl/envs/wrappers/jax/__init__.py +++ b/skrl/envs/wrappers/jax/__init__.py @@ -28,7 +28,7 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra >>> env = wrap_env(env) :param env: The environment to be wrapped - :type env: gym.Env, gymnasium.Env, dm_env.Environment or VecTask + :type env: Any :param wrapper: The type of wrapper to use (default: ``"auto"``). If ``"auto"``, the wrapper will be automatically selected based on the environment class. The supported wrappers are described in the following table: diff --git a/skrl/envs/wrappers/jax/base.py b/skrl/envs/wrappers/jax/base.py index 3920345c..a7a2bc2d 100644 --- a/skrl/envs/wrappers/jax/base.py +++ b/skrl/envs/wrappers/jax/base.py @@ -1,6 +1,6 @@ from typing import Any, Mapping, Sequence, Tuple, Union -import gym +import gymnasium import jax import numpy as np @@ -132,7 +132,7 @@ def num_agents(self) -> int: return self._unwrapped.num_agents if hasattr(self._unwrapped, "num_agents") else 1 @property - def state_space(self) -> Union[gym.Space, None]: + def state_space(self) -> Union[gymnasium.Space, None]: """State space If the wrapped environment does not have the ``state_space`` property, ``None`` will be returned @@ -140,13 +140,13 @@ def state_space(self) -> Union[gym.Space, None]: return self._unwrapped.state_space if hasattr(self._unwrapped, "state_space") else None @property - def observation_space(self) -> gym.Space: + def observation_space(self) -> gymnasium.Space: """Observation space """ return self._unwrapped.observation_space @property - def action_space(self) -> gym.Space: + def action_space(self) -> gymnasium.Space: """Action space """ return self._unwrapped.action_space @@ -307,7 +307,7 @@ def possible_agents(self) -> Sequence[str]: return self._unwrapped.possible_agents @property - def state_spaces(self) -> Mapping[str, gym.Space]: + def state_spaces(self) -> Mapping[str, gymnasium.Space]: """State spaces Since the state space is a global view of the environment (and therefore the same for all the agents), @@ -318,18 +318,18 @@ def state_spaces(self) -> Mapping[str, gym.Space]: return {agent: space for agent in self.possible_agents} @property - def observation_spaces(self) -> Mapping[str, gym.Space]: + def observation_spaces(self) -> Mapping[str, gymnasium.Space]: """Observation spaces """ 
return self._unwrapped.observation_spaces @property - def action_spaces(self) -> Mapping[str, gym.Space]: + def action_spaces(self) -> Mapping[str, gymnasium.Space]: """Action spaces """ return self._unwrapped.action_spaces - def state_space(self, agent: str) -> gym.Space: + def state_space(self, agent: str) -> gymnasium.Space: """State space Since the state space is a global view of the environment (and therefore the same for all the agents), @@ -340,28 +340,28 @@ def state_space(self, agent: str) -> gym.Space: :type agent: str :return: The state space for the specified agent - :rtype: gym.Space + :rtype: gymnasium.Space """ return self.state_spaces[agent] - def observation_space(self, agent: str) -> gym.Space: + def observation_space(self, agent: str) -> gymnasium.Space: """Observation space :param agent: Name of the agent :type agent: str :return: The observation space for the specified agent - :rtype: gym.Space + :rtype: gymnasium.Space """ return self.observation_spaces[agent] - def action_space(self, agent: str) -> gym.Space: + def action_space(self, agent: str) -> gymnasium.Space: """Action space :param agent: Name of the agent :type agent: str :return: The action space for the specified agent - :rtype: gym.Space + :rtype: gymnasium.Space """ return self.action_spaces[agent] diff --git a/skrl/envs/wrappers/jax/bidexhands_envs.py b/skrl/envs/wrappers/jax/bidexhands_envs.py index eadae5e9..e292047c 100644 --- a/skrl/envs/wrappers/jax/bidexhands_envs.py +++ b/skrl/envs/wrappers/jax/bidexhands_envs.py @@ -1,6 +1,6 @@ from typing import Any, Mapping, Sequence, Tuple, Union -import gym +import gymnasium import jax import jax.dlpack as jax_dlpack @@ -14,6 +14,7 @@ pass # TODO: show warning message from skrl.envs.wrappers.jax.base import MultiAgentEnvWrapper +from skrl.utils.spaces.jax import convert_gym_space # ML frameworks conversion utilities @@ -62,26 +63,26 @@ def possible_agents(self) -> Sequence[str]: return [f"agent_{i}" for i in range(self.num_agents)] @property - def state_spaces(self) -> Mapping[str, gym.Space]: + def state_spaces(self) -> Mapping[str, gymnasium.Space]: """State spaces Since the state space is a global view of the environment (and therefore the same for all the agents), this property returns a dictionary (for consistency with the other space-related properties) with the same space for all the agents """ - return {uid: space for uid, space in zip(self.possible_agents, self._env.share_observation_space)} + return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.share_observation_space)} @property - def observation_spaces(self) -> Mapping[str, gym.Space]: + def observation_spaces(self) -> Mapping[str, gymnasium.Space]: """Observation spaces """ - return {uid: space for uid, space in zip(self.possible_agents, self._env.observation_space)} + return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.observation_space)} @property - def action_spaces(self) -> Mapping[str, gym.Space]: + def action_spaces(self) -> Mapping[str, gymnasium.Space]: """Action spaces """ - return {uid: space for uid, space in zip(self.possible_agents, self._env.action_space)} + return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.action_space)} def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], diff --git a/skrl/envs/wrappers/jax/brax_envs.py 
b/skrl/envs/wrappers/jax/brax_envs.py index 1e052764..8fa10a29 100644 --- a/skrl/envs/wrappers/jax/brax_envs.py +++ b/skrl/envs/wrappers/jax/brax_envs.py @@ -8,6 +8,12 @@ from skrl import logger from skrl.envs.wrappers.jax.base import Wrapper +from skrl.utils.spaces.jax import ( + convert_gym_space, + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space +) class BraxWrapper(Wrapper): @@ -28,15 +34,13 @@ def __init__(self, env: Any) -> None: def observation_space(self) -> gymnasium.Space: """Observation space """ - limit = np.inf * np.ones(self._unwrapped.observation_space.shape[1:], dtype='float32') - return gymnasium.spaces.Box(-limit, limit, dtype='float32') + return convert_gym_space(self._unwrapped.observation_space, squeeze_batch_dimension=True) @property def action_space(self) -> gymnasium.Space: """Action space """ - limit = np.inf * np.ones(self._unwrapped.action_space.shape[1:], dtype='float32') - return gymnasium.spaces.Box(-limit, limit, dtype='float32') + return convert_gym_space(self._unwrapped.action_space, squeeze_batch_dimension=True) def step(self, actions: Union[np.ndarray, jax.Array]) -> \ Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], @@ -49,7 +53,8 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ :return: Observation, reward, terminated, truncated, info :rtype: tuple of np.ndarray or jax.Array and any other info """ - observation, reward, terminated, info = self._env.step(actions) + observation, reward, terminated, info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) truncated = jnp.zeros_like(terminated) if not self._jax: observation = np.asarray(jax.device_get(observation)) @@ -65,6 +70,7 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: :rtype: np.ndarray or jax.Array and any other info """ observation = self._env.reset() + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) if not self._jax: observation = np.asarray(jax.device_get(observation)) return observation, {} diff --git a/skrl/envs/wrappers/jax/gym_envs.py b/skrl/envs/wrappers/jax/gym_envs.py index ce006493..412a6435 100644 --- a/skrl/envs/wrappers/jax/gym_envs.py +++ b/skrl/envs/wrappers/jax/gym_envs.py @@ -1,6 +1,6 @@ -from typing import Any, Optional, Tuple, Union +from typing import Any, Tuple, Union -import gym +import gymnasium from packaging import version import jax @@ -8,6 +8,13 @@ from skrl import logger from skrl.envs.wrappers.jax.base import Wrapper +from skrl.utils.spaces.jax import ( + convert_gym_space, + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space, + untensorize_space +) class GymWrapper(Wrapper): @@ -25,6 +32,7 @@ def __init__(self, env: Any) -> None: except AttributeError: np.bool8 = np.bool + import gym self._vectorized = False try: if isinstance(env, gym.vector.VectorEnv): @@ -40,80 +48,20 @@ def __init__(self, env: Any) -> None: logger.warning(f"Using a deprecated version of OpenAI Gym's API: {gym.__version__}") @property - def observation_space(self) -> gym.Space: + def observation_space(self) -> gymnasium.Space: """Observation space """ if self._vectorized: - return self._env.single_observation_space - return self._env.observation_space + return convert_gym_space(self._env.single_observation_space) + return convert_gym_space(self._env.observation_space) @property - def action_space(self) -> gym.Space: + 
def action_space(self) -> gymnasium.Space: """Action space """ if self._vectorized: - return self._env.single_action_space - return self._env.action_space - - def _observation_to_tensor(self, observation: Any, space: Optional[gym.Space] = None) -> np.ndarray: - """Convert the OpenAI Gym observation to a flat tensor - - :param observation: The OpenAI Gym observation to convert to a tensor - :type observation: Any supported OpenAI Gym observation space - - :raises: ValueError if the observation space type is not supported - - :return: The observation as a flat tensor - :rtype: np.ndarray - """ - observation_space = self._env.observation_space if self._vectorized else self.observation_space - space = space if space is not None else observation_space - - if self._vectorized and isinstance(space, gym.spaces.MultiDiscrete): - return observation.reshape(self.num_envs, -1).astype(np.int32) - elif isinstance(observation, int): - return np.array(observation, dtype=np.int32).reshape(self.num_envs, -1) - elif isinstance(observation, np.ndarray): - return observation.reshape(self.num_envs, -1).astype(np.float32) - elif isinstance(space, gym.spaces.Discrete): - return np.array(observation, dtype=np.float32).reshape(self.num_envs, -1) - elif isinstance(space, gym.spaces.Box): - return observation.reshape(self.num_envs, -1).astype(np.float32) - elif isinstance(space, gym.spaces.Dict): - tmp = np.concatenate([self._observation_to_tensor(observation[k], space[k]) \ - for k in sorted(space.keys())], axis=-1).reshape(self.num_envs, -1) - return tmp - else: - raise ValueError(f"Observation space type {type(space)} not supported. Please report this issue") - - def _tensor_to_action(self, actions: np.ndarray) -> Any: - """Convert the action to the OpenAI Gym expected format - - :param actions: The actions to perform - :type actions: np.ndarray - - :raise ValueError: If the action space type is not supported - - :return: The action in the OpenAI Gym format - :rtype: Any supported OpenAI Gym action space - """ - space = self._env.action_space if self._vectorized else self.action_space - - if self._vectorized: - if isinstance(space, gym.spaces.MultiDiscrete): - return actions.astype(space.dtype).reshape(space.shape) - elif isinstance(space, gym.spaces.Tuple): - if isinstance(space[0], gym.spaces.Box): - return actions.astype(space[0].dtype).reshape(space.shape) - elif isinstance(space[0], gym.spaces.Discrete): - return actions.astype(space[0].dtype).reshape(-1) - if isinstance(space, gym.spaces.Discrete): - return actions.item() - elif isinstance(space, gym.spaces.MultiDiscrete): - return actions.astype(space.dtype).reshape(space.shape) - elif isinstance(space, gym.spaces.Box): - return actions.astype(space.dtype).reshape(space.shape) - raise ValueError(f"Action space type {type(space)} not supported. 
Please report this issue") + return convert_gym_space(self._env.single_action_space) + return convert_gym_space(self._env.action_space) def step(self, actions: Union[np.ndarray, jax.Array]) -> \ Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], @@ -128,8 +76,12 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ """ if self._jax or isinstance(actions, jax.Array): actions = np.asarray(jax.device_get(actions)) + actions = untensorize_space(self.action_space, + unflatten_tensorized_space(self.action_space, actions), + squeeze_batch_dimension=not self._vectorized) + if self._deprecated_api: - observation, reward, terminated, info = self._env.step(self._tensor_to_action(actions)) + observation, reward, terminated, info = self._env.step(actions) # truncated: https://gymnasium.farama.org/tutorials/handling_time_limits if type(info) is list: truncated = np.array([d.get("TimeLimit.truncated", False) for d in info], dtype=terminated.dtype) @@ -139,10 +91,10 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ if truncated: terminated = False else: - observation, reward, terminated, truncated, info = self._env.step(self._tensor_to_action(actions)) + observation, reward, terminated, truncated, info = self._env.step(actions) # convert response to numpy or jax - observation = self._observation_to_tensor(observation) + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) reward = np.array(reward, dtype=np.float32).reshape(self.num_envs, -1) terminated = np.array(terminated, dtype=np.int8).reshape(self.num_envs, -1) truncated = np.array(truncated, dtype=np.int8).reshape(self.num_envs, -1) @@ -173,7 +125,7 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: self._info = {} else: observation, self._info = self._env.reset() - self._observation = self._observation_to_tensor(observation) + self._observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) if self._jax: self._observation = jax.device_put(self._observation, device=self.device) self._reset_once = False @@ -186,7 +138,7 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: observation, info = self._env.reset() # convert response to numpy or jax - observation = self._observation_to_tensor(observation) + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) if self._jax: observation = jax.device_put(observation, device=self.device) return observation, info diff --git a/skrl/envs/wrappers/jax/gymnasium_envs.py b/skrl/envs/wrappers/jax/gymnasium_envs.py index 458ff792..b8edd4fe 100644 --- a/skrl/envs/wrappers/jax/gymnasium_envs.py +++ b/skrl/envs/wrappers/jax/gymnasium_envs.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple, Union +from typing import Any, Tuple, Union import gymnasium @@ -7,6 +7,12 @@ from skrl import logger from skrl.envs.wrappers.jax.base import Wrapper +from skrl.utils.spaces.jax import ( + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space, + untensorize_space +) class GymnasiumWrapper(Wrapper): @@ -44,66 +50,6 @@ def action_space(self) -> gymnasium.Space: return self._env.single_action_space return self._env.action_space - def _observation_to_tensor(self, observation: Any, space: Optional[gymnasium.Space] = None) -> np.ndarray: - """Convert the Gymnasium observation to a flat tensor - - :param observation: The Gymnasium observation to convert to a tensor 
- :type observation: Any supported Gymnasium observation space - - :raises: ValueError if the observation space type is not supported - - :return: The observation as a flat tensor - :rtype: np.ndarray - """ - observation_space = self._env.observation_space if self._vectorized else self.observation_space - space = space if space is not None else observation_space - - if self._vectorized and isinstance(space, gymnasium.spaces.MultiDiscrete): - return observation.reshape(self.num_envs, -1).astype(np.int32) - elif isinstance(observation, int): - return np.array(observation, dtype=np.int32).reshape(self.num_envs, -1) - elif isinstance(observation, np.ndarray): - return observation.reshape(self.num_envs, -1).astype(np.float32) - elif isinstance(space, gymnasium.spaces.Discrete): - return np.array(observation, dtype=np.float32).reshape(self.num_envs, -1) - elif isinstance(space, gymnasium.spaces.Box): - return observation.reshape(self.num_envs, -1).astype(np.float32) - elif isinstance(space, gymnasium.spaces.Dict): - tmp = np.concatenate([self._observation_to_tensor(observation[k], space[k]) \ - for k in sorted(space.keys())], axis=-1).reshape(self.num_envs, -1) - return tmp - else: - raise ValueError(f"Observation space type {type(space)} not supported. Please report this issue") - - def _tensor_to_action(self, actions: np.ndarray) -> Any: - """Convert the action to the Gymnasium expected format - - :param actions: The actions to perform - :type actions: np.ndarray - - :raise ValueError: If the action space type is not supported - - :return: The action in the Gymnasium format - :rtype: Any supported Gymnasium action space - """ - space = self._env.action_space if self._vectorized else self.action_space - - if self._vectorized: - if isinstance(space, gymnasium.spaces.MultiDiscrete): - return actions.astype(space.dtype).reshape(space.shape) - elif isinstance(space, gymnasium.spaces.Tuple): - if isinstance(space[0], gymnasium.spaces.Box): - return actions.astype(space[0].dtype).reshape(space.shape) - elif isinstance(space[0], gymnasium.spaces.Discrete): - return actions.astype(space[0].dtype).reshape(-1) - if isinstance(space, gymnasium.spaces.Discrete): - return actions.item() - elif isinstance(space, gymnasium.spaces.MultiDiscrete): - return actions.astype(space.dtype).reshape(space.shape) - elif isinstance(space, gymnasium.spaces.Box): - return actions.astype(space.dtype).reshape(space.shape) - raise ValueError(f"Action space type {type(space)} not supported. 
Please report this issue") - def step(self, actions: Union[np.ndarray, jax.Array]) -> \ Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: @@ -117,10 +63,14 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ """ if self._jax or isinstance(actions, jax.Array): actions = np.asarray(jax.device_get(actions)) - observation, reward, terminated, truncated, info = self._env.step(self._tensor_to_action(actions)) + actions = untensorize_space(self.action_space, + unflatten_tensorized_space(self.action_space, actions), + squeeze_batch_dimension=not self._vectorized) + + observation, reward, terminated, truncated, info = self._env.step(actions) # convert response to numpy or jax - observation = self._observation_to_tensor(observation) + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) reward = np.array(reward, dtype=np.float32).reshape(self.num_envs, -1) terminated = np.array(terminated, dtype=np.int8).reshape(self.num_envs, -1) truncated = np.array(truncated, dtype=np.int8).reshape(self.num_envs, -1) @@ -147,7 +97,7 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: if self._vectorized: if self._reset_once: observation, self._info = self._env.reset() - self._observation = self._observation_to_tensor(observation) + self._observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) if self._jax: self._observation = jax.device_put(self._observation, device=self.device) self._reset_once = False @@ -156,7 +106,7 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: observation, info = self._env.reset() # convert response to numpy or jax - observation = self._observation_to_tensor(observation) + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) if self._jax: observation = jax.device_put(observation, device=self.device) return observation, info diff --git a/skrl/envs/wrappers/jax/isaacgym_envs.py b/skrl/envs/wrappers/jax/isaacgym_envs.py index 584a1a8d..5136de8a 100644 --- a/skrl/envs/wrappers/jax/isaacgym_envs.py +++ b/skrl/envs/wrappers/jax/isaacgym_envs.py @@ -1,6 +1,6 @@ from typing import Any, Tuple, Union -import gym +import gymnasium import jax import jax.dlpack as jax_dlpack @@ -12,6 +12,13 @@ import torch.utils.dlpack as torch_dlpack except: pass # TODO: show warning message +else: + from skrl.utils.spaces.torch import ( + convert_gym_space, + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space + ) from skrl import logger from skrl.envs.wrappers.jax.base import Wrapper @@ -47,6 +54,18 @@ def __init__(self, env: Any) -> None: self._observations = None self._info = {} + @property + def observation_space(self) -> gymnasium.Space: + """Observation space + """ + return convert_gym_space(self._unwrapped.observation_space) + + @property + def action_space(self) -> gymnasium.Space: + """Action space + """ + return convert_gym_space(self._unwrapped.action_space) + def step(self, actions: Union[np.ndarray, jax.Array]) -> \ Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: @@ -61,12 +80,14 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ actions = _jax2torch(actions, self._env.device, self._jax) with torch.no_grad(): - self._observations, reward, terminated, self._info = 
self._env.step(actions) + observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations)) terminated = terminated.to(dtype=torch.int8) truncated = self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated) - return _torch2jax(self._observations, self._jax), \ + self._observations = _torch2jax(observations, self._jax) + return self._observations, \ _torch2jax(reward.view(-1, 1), self._jax), \ _torch2jax(terminated.view(-1, 1), self._jax), \ _torch2jax(truncated.view(-1, 1), self._jax), \ @@ -79,9 +100,11 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: :rtype: np.ndarray or jax.Array and any other info """ if self._reset_once: - self._observations = self._env.reset() + observations = self._env.reset() + observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations)) + self._observations = _torch2jax(observations, self._jax) self._reset_once = False - return _torch2jax(self._observations, self._jax), self._info + return self._observations, self._info def render(self, *args, **kwargs) -> None: """Render the environment @@ -108,12 +131,24 @@ def __init__(self, env: Any) -> None: self._info = {} @property - def state_space(self) -> Union[gym.Space, None]: + def observation_space(self) -> gymnasium.Space: + """Observation space + """ + return convert_gym_space(self._unwrapped.observation_space) + + @property + def action_space(self) -> gymnasium.Space: + """Action space + """ + return convert_gym_space(self._unwrapped.action_space) + + @property + def state_space(self) -> Union[gymnasium.Space, None]: """State space """ try: if self.num_states: - return self._unwrapped.state_space + return convert_gym_space(self._unwrapped.state_space) except: pass return None @@ -132,12 +167,14 @@ def step(self, actions: Union[np.ndarray, jax.Array]) ->\ actions = _jax2torch(actions, self._env.device, self._jax) with torch.no_grad(): - self._observations, reward, terminated, self._info = self._env.step(actions) + observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) terminated = terminated.to(dtype=torch.int8) truncated = self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated) - return _torch2jax(self._observations["obs"], self._jax), \ + self._observations = _torch2jax(observations, self._jax) + return self._observations, \ _torch2jax(reward.view(-1, 1), self._jax), \ _torch2jax(terminated.view(-1, 1), self._jax), \ _torch2jax(truncated.view(-1, 1), self._jax), \ @@ -150,9 +187,11 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: :rtype: np.ndarray or jax.Array and any other info """ if self._reset_once: - self._observations = self._env.reset() + observations = self._env.reset() + observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) + self._observations = _torch2jax(observations, self._jax) self._reset_once = False - return _torch2jax(self._observations["obs"], self._jax), self._info + return self._observations, self._info def render(self, *args, **kwargs) -> None: """Render the environment diff --git a/skrl/envs/wrappers/jax/isaaclab_envs.py b/skrl/envs/wrappers/jax/isaaclab_envs.py index 
f02dbc2e..6ce151e7 100644 --- a/skrl/envs/wrappers/jax/isaaclab_envs.py +++ b/skrl/envs/wrappers/jax/isaaclab_envs.py @@ -12,6 +12,8 @@ import torch.utils.dlpack as torch_dlpack except: pass # TODO: show warning message +else: + from skrl.utils.spaces.torch import flatten_tensorized_space, tensorize_space, unflatten_tensorized_space from skrl import logger from skrl.envs.wrappers.jax.base import MultiAgentEnvWrapper, Wrapper @@ -91,14 +93,17 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ :rtype: tuple of np.ndarray or jax.Array and any other info """ actions = _jax2torch(actions, self._env_device, self._jax) + actions = unflatten_tensorized_space(self.action_space, actions) with torch.no_grad(): - self._observations, reward, terminated, truncated, self._info = self._env.step(actions) + observations, reward, terminated, truncated, self._info = self._env.step(actions) + observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["policy"])) terminated = terminated.to(dtype=torch.int8) truncated = truncated.to(dtype=torch.int8) - return _torch2jax(self._observations["policy"], self._jax), \ + self._observations = _torch2jax(observations, self._jax) + return self._observations, \ _torch2jax(reward.view(-1, 1), self._jax), \ _torch2jax(terminated.view(-1, 1), self._jax), \ _torch2jax(truncated.view(-1, 1), self._jax), \ @@ -111,9 +116,11 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: :rtype: np.ndarray or jax.Array and any other info """ if self._reset_once: - self._observations, self._info = self._env.reset() + observations, self._info = self._env.reset() + observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["policy"])) + self._observations = _torch2jax(observations, self._jax) self._reset_once = False - return _torch2jax(self._observations["policy"], self._jax), self._info + return self._observations, self._info def render(self, *args, **kwargs) -> None: """Render the environment @@ -152,9 +159,11 @@ def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ :rtype: tuple of dictionaries of np.ndarray or jax.Array and any other info """ actions = {uid: _jax2torch(value, self._env_device, self._jax) for uid, value in actions.items()} + actions = {k: unflatten_tensorized_space(self.action_spaces[k], v) for k, v in actions.items()} with torch.no_grad(): observations, rewards, terminated, truncated, self._info = self._env.step(actions) + observations = {k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) for k, v in observations.items()} self._observations = {uid: _torch2jax(value, self._jax) for uid, value in observations.items()} return self._observations, \ @@ -171,6 +180,7 @@ def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str """ if self._reset_once: observations, self._info = self._env.reset() + observations = {k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) for k, v in observations.items()} self._observations = {uid: _torch2jax(value, self._jax) for uid, value in observations.items()} self._reset_once = False return self._observations, self._info @@ -182,7 +192,10 @@ def state(self) -> Union[np.ndarray, jax.Array, None]: :rtype: np.ndarray, jax.Array or None """ state = self._env.state() - return None if state is None else _torch2jax(state, self._jax) + if state is not None: + state = flatten_tensorized_space(tensorize_space(next(iter(self.state_spaces.values())), state)) + return 
_torch2jax(state, self._jax) + return state def render(self, *args, **kwargs) -> None: """Render the environment diff --git a/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py b/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py index e7f4feed..1ecde802 100644 --- a/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py +++ b/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py @@ -10,6 +10,8 @@ import torch.utils.dlpack as torch_dlpack except: pass # TODO: show warning message +else: + from skrl.utils.spaces.torch import flatten_tensorized_space, tensorize_space, unflatten_tensorized_space from skrl import logger from skrl.envs.wrappers.jax.base import Wrapper @@ -44,6 +46,7 @@ def __init__(self, env: Any) -> None: self._env_device = torch.device(self._unwrapped.device) self._reset_once = True self._observations = None + self._info = {} def run(self, trainer: Optional["omni.isaac.gym.vec_env.vec_env_mt.TrainerMT"] = None) -> None: """Run the simulation in the main thread @@ -69,16 +72,18 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ actions = _jax2torch(actions, self._env_device, self._jax) with torch.no_grad(): - self._observations, reward, terminated, info = self._env.step(actions) + observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) terminated = terminated.to(dtype=torch.int8) - truncated = info["time_outs"].to(dtype=torch.int8) if "time_outs" in info else torch.zeros_like(terminated) + truncated = self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated) - return _torch2jax(self._observations["obs"], self._jax), \ + self._observations = _torch2jax(observations, self._jax) + return self._observations, \ _torch2jax(reward.view(-1, 1), self._jax), \ _torch2jax(terminated.view(-1, 1), self._jax), \ _torch2jax(truncated.view(-1, 1), self._jax), \ - info + self._info def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: """Reset the environment @@ -87,9 +92,11 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: :rtype: np.ndarray or jax.Array and any other info """ if self._reset_once: - self._observations = self._env.reset() + observations = self._env.reset() + observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) + self._observations = _torch2jax(observations, self._jax) self._reset_once = False - return _torch2jax(self._observations["obs"], self._jax), {} + return self._observations, self._info def render(self, *args, **kwargs) -> None: """Render the environment diff --git a/skrl/envs/wrappers/jax/pettingzoo_envs.py b/skrl/envs/wrappers/jax/pettingzoo_envs.py index 1e524c9a..180e0209 100644 --- a/skrl/envs/wrappers/jax/pettingzoo_envs.py +++ b/skrl/envs/wrappers/jax/pettingzoo_envs.py @@ -1,12 +1,17 @@ from typing import Any, Mapping, Tuple, Union import collections -import gymnasium import jax import numpy as np from skrl.envs.wrappers.jax.base import MultiAgentEnvWrapper +from skrl.utils.spaces.jax import ( + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space, + untensorize_space +) class PettingZooWrapper(MultiAgentEnvWrapper): @@ -18,49 +23,6 @@ def __init__(self, env: Any) -> None: """ super().__init__(env) - def _observation_to_tensor(self, observation: Any, space: gymnasium.Space) -> np.ndarray: - """Convert the Gymnasium observation to a flat tensor - - :param 
observation: The Gymnasium observation to convert to a tensor - :type observation: Any supported Gymnasium observation space - - :raises: ValueError if the observation space type is not supported - - :return: The observation as a flat tensor - :rtype: np.ndarray - """ - if isinstance(observation, int): - return np.array(observation, dtype=np.int32).view(self.num_envs, -1) - elif isinstance(observation, np.ndarray): - return observation.reshape(self.num_envs, -1).astype(np.float32) - elif isinstance(space, gymnasium.spaces.Discrete): - return np.array(observation, dtype=np.float32).reshape(self.num_envs, -1) - elif isinstance(space, gymnasium.spaces.Box): - return observation.reshape(self.num_envs, -1).astype(np.float32) - elif isinstance(space, gymnasium.spaces.Dict): - tmp = np.concatenate([self._observation_to_tensor(observation[k], space[k]) \ - for k in sorted(space.keys())], axis=-1).view(self.num_envs, -1) - return tmp - else: - raise ValueError(f"Observation space type {type(space)} not supported. Please report this issue") - - def _tensor_to_action(self, actions: np.ndarray, space: gymnasium.Space) -> Any: - """Convert the action to the Gymnasium expected format - - :param actions: The actions to perform - :type actions: np.ndarray - - :raise ValueError: If the action space type is not supported - - :return: The action in the Gymnasium format - :rtype: Any supported Gymnasium action space - """ - if isinstance(space, gymnasium.spaces.Discrete): - return actions.item() - elif isinstance(space, gymnasium.spaces.Box): - return actions.astype(space.dtype).reshape(space.shape) - raise ValueError(f"Action space type {type(space)} not supported. Please report this issue") - def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], @@ -75,11 +37,11 @@ def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ """ if self._jax: actions = jax.device_get(actions) - actions = {uid: self._tensor_to_action(action, self.action_space(uid)) for uid, action in actions.items()} + actions = {uid: untensorize_space(self.action_spaces[uid], unflatten_tensorized_space(self.action_spaces[uid], action)) for uid, action in actions.items()} observations, rewards, terminated, truncated, infos = self._env.step(actions) # convert response to numpy or jax - observations = {uid: self._observation_to_tensor(value, self.observation_space(uid)) for uid, value in observations.items()} + observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, self.device, False), False) for uid, value in observations.items()} rewards = {uid: np.array(value, dtype=np.float32).reshape(self.num_envs, -1) for uid, value in rewards.items()} terminated = {uid: np.array(value, dtype=np.int8).reshape(self.num_envs, -1) for uid, value in terminated.items()} truncated = {uid: np.array(value, dtype=np.int8).reshape(self.num_envs, -1) for uid, value in truncated.items()} @@ -96,7 +58,7 @@ def state(self) -> Union[np.ndarray, jax.Array]: :return: State :rtype: np.ndarray or jax.Array """ - state = self._observation_to_tensor(self._env.state(), next(iter(self.state_spaces.values()))) + state = flatten_tensorized_space(tensorize_space(next(iter(self.state_spaces.values())), self._env.state(), self.device, False), False) if self._jax: state = jax.device_put(state, device=self.device) return state @@ 
-115,7 +77,7 @@ def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str observations, infos = outputs # convert response to numpy or jax - observations = {uid: self._observation_to_tensor(value, self.observation_space(uid)) for uid, value in observations.items()} + observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, self.device, False), False) for uid, value in observations.items()} if self._jax: observations = {uid: jax.device_put(value, device=self.device) for uid, value in observations.items()} return observations, infos diff --git a/skrl/envs/wrappers/torch/__init__.py b/skrl/envs/wrappers/torch/__init__.py index ce78eb87..4983ddf2 100644 --- a/skrl/envs/wrappers/torch/__init__.py +++ b/skrl/envs/wrappers/torch/__init__.py @@ -30,7 +30,7 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra >>> env = wrap_env(env) :param env: The environment to be wrapped - :type env: gym.Env, gymnasium.Env, dm_env.Environment or VecTask + :type env: Any :param wrapper: The type of wrapper to use (default: ``"auto"``). If ``"auto"``, the wrapper will be automatically selected based on the environment class. The supported wrappers are described in the following table: diff --git a/skrl/envs/wrappers/torch/base.py b/skrl/envs/wrappers/torch/base.py index 1ebe9923..7b6893e7 100644 --- a/skrl/envs/wrappers/torch/base.py +++ b/skrl/envs/wrappers/torch/base.py @@ -1,6 +1,6 @@ from typing import Any, Mapping, Sequence, Tuple, Union -import gym +import gymnasium import torch @@ -117,7 +117,7 @@ def num_agents(self) -> int: return self._unwrapped.num_agents if hasattr(self._unwrapped, "num_agents") else 1 @property - def state_space(self) -> Union[gym.Space, None]: + def state_space(self) -> Union[gymnasium.Space, None]: """State space If the wrapped environment does not have the ``state_space`` property, ``None`` will be returned @@ -125,13 +125,13 @@ def state_space(self) -> Union[gym.Space, None]: return self._unwrapped.state_space if hasattr(self._unwrapped, "state_space") else None @property - def observation_space(self) -> gym.Space: + def observation_space(self) -> gymnasium.Space: """Observation space """ return self._unwrapped.observation_space @property - def action_space(self) -> gym.Space: + def action_space(self) -> gymnasium.Space: """Action space """ return self._unwrapped.action_space @@ -281,7 +281,7 @@ def possible_agents(self) -> Sequence[str]: return self._unwrapped.possible_agents @property - def state_spaces(self) -> Mapping[str, gym.Space]: + def state_spaces(self) -> Mapping[str, gymnasium.Space]: """State spaces Since the state space is a global view of the environment (and therefore the same for all the agents), @@ -292,18 +292,18 @@ def state_spaces(self) -> Mapping[str, gym.Space]: return {agent: space for agent in self.possible_agents} @property - def observation_spaces(self) -> Mapping[str, gym.Space]: + def observation_spaces(self) -> Mapping[str, gymnasium.Space]: """Observation spaces """ return self._unwrapped.observation_spaces @property - def action_spaces(self) -> Mapping[str, gym.Space]: + def action_spaces(self) -> Mapping[str, gymnasium.Space]: """Action spaces """ return self._unwrapped.action_spaces - def state_space(self, agent: str) -> gym.Space: + def state_space(self, agent: str) -> gymnasium.Space: """State space Since the state space is a global view of the environment (and therefore the same for all the agents), @@ -314,28 +314,28 @@ def state_space(self, agent: 
str) -> gym.Space: :type agent: str :return: The state space for the specified agent - :rtype: gym.Space + :rtype: gymnasium.Space """ return self.state_spaces[agent] - def observation_space(self, agent: str) -> gym.Space: + def observation_space(self, agent: str) -> gymnasium.Space: """Observation space :param agent: Name of the agent :type agent: str :return: The observation space for the specified agent - :rtype: gym.Space + :rtype: gymnasium.Space """ return self.observation_spaces[agent] - def action_space(self, agent: str) -> gym.Space: + def action_space(self, agent: str) -> gymnasium.Space: """Action space :param agent: Name of the agent :type agent: str :return: The action space for the specified agent - :rtype: gym.Space + :rtype: gymnasium.Space """ return self.action_spaces[agent] diff --git a/skrl/envs/wrappers/torch/bidexhands_envs.py b/skrl/envs/wrappers/torch/bidexhands_envs.py index 0c5d2a98..827c181f 100644 --- a/skrl/envs/wrappers/torch/bidexhands_envs.py +++ b/skrl/envs/wrappers/torch/bidexhands_envs.py @@ -1,10 +1,11 @@ from typing import Any, Mapping, Sequence, Tuple -import gym +import gymnasium import torch from skrl.envs.wrappers.torch.base import MultiAgentEnvWrapper +from skrl.utils.spaces.torch import convert_gym_space class BiDexHandsWrapper(MultiAgentEnvWrapper): @@ -38,26 +39,26 @@ def possible_agents(self) -> Sequence[str]: return [f"agent_{i}" for i in range(self.num_agents)] @property - def state_spaces(self) -> Mapping[str, gym.Space]: + def state_spaces(self) -> Mapping[str, gymnasium.Space]: """State spaces Since the state space is a global view of the environment (and therefore the same for all the agents), this property returns a dictionary (for consistency with the other space-related properties) with the same space for all the agents """ - return {uid: space for uid, space in zip(self.possible_agents, self._env.share_observation_space)} + return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.share_observation_space)} @property - def observation_spaces(self) -> Mapping[str, gym.Space]: + def observation_spaces(self) -> Mapping[str, gymnasium.Space]: """Observation spaces """ - return {uid: space for uid, space in zip(self.possible_agents, self._env.observation_space)} + return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.observation_space)} @property - def action_spaces(self) -> Mapping[str, gym.Space]: + def action_spaces(self) -> Mapping[str, gymnasium.Space]: """Action spaces """ - return {uid: space for uid, space in zip(self.possible_agents, self._env.action_space)} + return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.action_space)} def step(self, actions: Mapping[str, torch.Tensor]) -> \ Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], diff --git a/skrl/envs/wrappers/torch/brax_envs.py b/skrl/envs/wrappers/torch/brax_envs.py index 5f50e8c5..775994f9 100644 --- a/skrl/envs/wrappers/torch/brax_envs.py +++ b/skrl/envs/wrappers/torch/brax_envs.py @@ -2,11 +2,16 @@ import gymnasium -import numpy as np import torch from skrl import logger from skrl.envs.wrappers.torch.base import Wrapper +from skrl.utils.spaces.torch import ( + convert_gym_space, + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space +) class BraxWrapper(Wrapper): @@ -29,15 +34,13 @@ def __init__(self, env: Any) -> None: def observation_space(self) -> gymnasium.Space: """Observation space """ - limit = np.inf * 
np.ones(self._unwrapped.observation_space.shape[1:], dtype='float32') - return gymnasium.spaces.Box(-limit, limit, dtype='float32') + return convert_gym_space(self._unwrapped.observation_space, squeeze_batch_dimension=True) @property def action_space(self) -> gymnasium.Space: """Action space """ - limit = np.inf * np.ones(self._unwrapped.action_space.shape[1:], dtype='float32') - return gymnasium.spaces.Box(-limit, limit, dtype='float32') + return convert_gym_space(self._unwrapped.action_space, squeeze_batch_dimension=True) def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: """Perform a step in the environment @@ -48,7 +51,8 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - observation, reward, terminated, info = self._env.step(actions) + observation, reward, terminated, info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation)) truncated = torch.zeros_like(terminated) return observation, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), info @@ -59,6 +63,7 @@ def reset(self) -> Tuple[torch.Tensor, Any]: :rtype: torch.Tensor and any other info """ observation = self._env.reset() + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation)) return observation, {} def render(self, *args, **kwargs) -> None: diff --git a/skrl/envs/wrappers/torch/deepmind_envs.py b/skrl/envs/wrappers/torch/deepmind_envs.py index 4cf978e9..1777703a 100644 --- a/skrl/envs/wrappers/torch/deepmind_envs.py +++ b/skrl/envs/wrappers/torch/deepmind_envs.py @@ -1,13 +1,19 @@ -from typing import Any, Optional, Tuple +from typing import Any, Tuple import collections -import gym +import gymnasium import numpy as np import torch from skrl import logger from skrl.envs.wrappers.torch.base import Wrapper +from skrl.utils.spaces.torch import ( + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space, + untensorize_space +) class DeepMindWrapper(Wrapper): @@ -23,88 +29,45 @@ def __init__(self, env: Any) -> None: self._specs = specs @property - def observation_space(self) -> gym.Space: + def observation_space(self) -> gymnasium.Space: """Observation space """ return self._spec_to_space(self._env.observation_spec()) @property - def action_space(self) -> gym.Space: + def action_space(self) -> gymnasium.Space: """Action space """ return self._spec_to_space(self._env.action_spec()) - def _spec_to_space(self, spec: Any) -> gym.Space: - """Convert the DeepMind spec to a Gym space + def _spec_to_space(self, spec: Any) -> gymnasium.Space: + """Convert the DeepMind spec to a gymnasium space :param spec: The DeepMind spec to convert :type spec: Any supported DeepMind spec :raises: ValueError if the spec type is not supported - :return: The Gym space - :rtype: gym.Space + :return: The gymnasium space + :rtype: gymnasium.Space """ if isinstance(spec, self._specs.DiscreteArray): - return gym.spaces.Discrete(spec.num_values) + return gymnasium.spaces.Discrete(spec.num_values) elif isinstance(spec, self._specs.BoundedArray): - return gym.spaces.Box(shape=spec.shape, + return gymnasium.spaces.Box(shape=spec.shape, dtype=spec.dtype, low=spec.minimum if spec.minimum.ndim else np.full(spec.shape, spec.minimum), high=spec.maximum if spec.maximum.ndim else 
np.full(spec.shape, spec.maximum)) elif isinstance(spec, self._specs.Array): - return gym.spaces.Box(shape=spec.shape, + return gymnasium.spaces.Box(shape=spec.shape, dtype=spec.dtype, low=np.full(spec.shape, float("-inf")), high=np.full(spec.shape, float("inf"))) elif isinstance(spec, collections.OrderedDict): - return gym.spaces.Dict({k: self._spec_to_space(v) for k, v in spec.items()}) + return gymnasium.spaces.Dict({k: self._spec_to_space(v) for k, v in spec.items()}) else: raise ValueError(f"Spec type {type(spec)} not supported. Please report this issue") - def _observation_to_tensor(self, observation: Any, spec: Optional[Any] = None) -> torch.Tensor: - """Convert the DeepMind observation to a flat tensor - - :param observation: The DeepMind observation to convert to a tensor - :type observation: Any supported DeepMind observation - - :raises: ValueError if the observation spec type is not supported - - :return: The observation as a flat tensor - :rtype: torch.Tensor - """ - spec = spec if spec is not None else self._env.observation_spec() - - if isinstance(spec, self._specs.DiscreteArray): - return torch.tensor(observation, device=self.device, dtype=torch.float32).reshape(self.num_envs, -1) - elif isinstance(spec, self._specs.Array): # includes BoundedArray - return torch.tensor(observation, device=self.device, dtype=torch.float32).reshape(self.num_envs, -1) - elif isinstance(spec, collections.OrderedDict): - return torch.cat([self._observation_to_tensor(observation[k], spec[k]) \ - for k in sorted(spec.keys())], dim=-1).reshape(self.num_envs, -1) - else: - raise ValueError(f"Observation spec type {type(spec)} not supported. Please report this issue") - - def _tensor_to_action(self, actions: torch.Tensor) -> Any: - """Convert the action to the DeepMind expected format - - :param actions: The actions to perform - :type actions: torch.Tensor - - :raise ValueError: If the action space type is not supported - - :return: The action in the DeepMind expected format - :rtype: Any supported DeepMind action - """ - spec = self._env.action_spec() - - if isinstance(spec, self._specs.DiscreteArray): - return np.array(actions.item(), dtype=spec.dtype) - elif isinstance(spec, self._specs.Array): # includes BoundedArray - return np.array(actions.cpu().numpy(), dtype=spec.dtype).reshape(spec.shape) - else: - raise ValueError(f"Action spec type {type(spec)} not supported. 
Please report this issue") - def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: """Perform a step in the environment @@ -114,16 +77,17 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - timestep = self._env.step(self._tensor_to_action(actions)) + actions = untensorize_space(self.action_space, unflatten_tensorized_space(self.action_space, actions)) + timestep = self._env.step(actions) - observation = timestep.observation + observation = flatten_tensorized_space(tensorize_space(self.observation_space, timestep.observation, self.device)) reward = timestep.reward if timestep.reward is not None else 0 terminated = timestep.last() truncated = False info = {} # convert response to torch - return self._observation_to_tensor(observation), \ + return observation, \ torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1), \ torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), \ torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), \ @@ -136,7 +100,8 @@ def reset(self) -> Tuple[torch.Tensor, Any]: :rtype: torch.Tensor """ timestep = self._env.reset() - return self._observation_to_tensor(timestep.observation), {} + observation = flatten_tensorized_space(tensorize_space(self.observation_space, timestep.observation, self.device)) + return observation, {} def render(self, *args, **kwargs) -> np.ndarray: """Render the environment diff --git a/skrl/envs/wrappers/torch/gym_envs.py b/skrl/envs/wrappers/torch/gym_envs.py index 4e326818..a432d211 100644 --- a/skrl/envs/wrappers/torch/gym_envs.py +++ b/skrl/envs/wrappers/torch/gym_envs.py @@ -1,6 +1,6 @@ -from typing import Any, Optional, Tuple +from typing import Any, Tuple -import gym +import gymnasium from packaging import version import numpy as np @@ -8,6 +8,13 @@ from skrl import logger from skrl.envs.wrappers.torch.base import Wrapper +from skrl.utils.spaces.torch import ( + convert_gym_space, + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space, + untensorize_space +) class GymWrapper(Wrapper): @@ -25,6 +32,7 @@ def __init__(self, env: Any) -> None: except AttributeError: np.bool8 = np.bool + import gym self._vectorized = False try: if isinstance(env, gym.vector.VectorEnv): @@ -40,80 +48,20 @@ def __init__(self, env: Any) -> None: logger.warning(f"Using a deprecated version of OpenAI Gym's API: {gym.__version__}") @property - def observation_space(self) -> gym.Space: + def observation_space(self) -> gymnasium.Space: """Observation space """ if self._vectorized: - return self._env.single_observation_space - return self._env.observation_space + return convert_gym_space(self._env.single_observation_space) + return convert_gym_space(self._env.observation_space) @property - def action_space(self) -> gym.Space: + def action_space(self) -> gymnasium.Space: """Action space """ if self._vectorized: - return self._env.single_action_space - return self._env.action_space - - def _observation_to_tensor(self, observation: Any, space: Optional[gym.Space] = None) -> torch.Tensor: - """Convert the OpenAI Gym observation to a flat tensor - - :param observation: The OpenAI Gym observation to convert to a tensor - :type observation: Any supported OpenAI Gym observation space - - :raises: ValueError if the observation space type is not supported - - 
:return: The observation as a flat tensor - :rtype: torch.Tensor - """ - observation_space = self._env.observation_space if self._vectorized else self.observation_space - space = space if space is not None else observation_space - - if self._vectorized and isinstance(space, gym.spaces.MultiDiscrete): - return torch.tensor(observation, device=self.device, dtype=torch.int64).view(self.num_envs, -1) - elif isinstance(observation, int): - return torch.tensor(observation, device=self.device, dtype=torch.int64).view(self.num_envs, -1) - elif isinstance(observation, np.ndarray): - return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gym.spaces.Discrete): - return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gym.spaces.Box): - return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gym.spaces.Dict): - tmp = torch.cat([self._observation_to_tensor(observation[k], space[k]) \ - for k in sorted(space.keys())], dim=-1).view(self.num_envs, -1) - return tmp - else: - raise ValueError(f"Observation space type {type(space)} not supported. Please report this issue") - - def _tensor_to_action(self, actions: torch.Tensor) -> Any: - """Convert the action to the OpenAI Gym expected format - - :param actions: The actions to perform - :type actions: torch.Tensor - - :raise ValueError: If the action space type is not supported - - :return: The action in the OpenAI Gym format - :rtype: Any supported OpenAI Gym action space - """ - space = self._env.action_space if self._vectorized else self.action_space - - if self._vectorized: - if isinstance(space, gym.spaces.MultiDiscrete): - return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape) - elif isinstance(space, gym.spaces.Tuple): - if isinstance(space[0], gym.spaces.Box): - return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(space.shape) - elif isinstance(space[0], gym.spaces.Discrete): - return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(-1) - if isinstance(space, gym.spaces.Discrete): - return actions.item() - elif isinstance(space, gym.spaces.MultiDiscrete): - return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape) - elif isinstance(space, gym.spaces.Box): - return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape) - raise ValueError(f"Action space type {type(space)} not supported. 
Please report this issue") + return convert_gym_space(self._env.single_action_space) + return convert_gym_space(self._env.action_space) def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: """Perform a step in the environment @@ -124,8 +72,12 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ + actions = untensorize_space(self.action_space, + unflatten_tensorized_space(self.action_space, actions), + squeeze_batch_dimension=not self._vectorized) + if self._deprecated_api: - observation, reward, terminated, info = self._env.step(self._tensor_to_action(actions)) + observation, reward, terminated, info = self._env.step(actions) # truncated: https://gymnasium.farama.org/tutorials/handling_time_limits if type(info) is list: truncated = np.array([d.get("TimeLimit.truncated", False) for d in info], dtype=terminated.dtype) @@ -135,10 +87,10 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch if truncated: terminated = False else: - observation, reward, terminated, truncated, info = self._env.step(self._tensor_to_action(actions)) + observation, reward, terminated, truncated, info = self._env.step(actions) # convert response to torch - observation = self._observation_to_tensor(observation) + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) reward = torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1) terminated = torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1) truncated = torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1) @@ -164,7 +116,7 @@ def reset(self) -> Tuple[torch.Tensor, Any]: self._info = {} else: observation, self._info = self._env.reset() - self._observation = self._observation_to_tensor(observation) + self._observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) self._reset_once = False return self._observation, self._info @@ -173,7 +125,8 @@ def reset(self) -> Tuple[torch.Tensor, Any]: info = {} else: observation, info = self._env.reset() - return self._observation_to_tensor(observation), info + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) + return observation, info def render(self, *args, **kwargs) -> Any: """Render the environment diff --git a/skrl/envs/wrappers/torch/gymnasium_envs.py b/skrl/envs/wrappers/torch/gymnasium_envs.py index e33ca603..1de58993 100644 --- a/skrl/envs/wrappers/torch/gymnasium_envs.py +++ b/skrl/envs/wrappers/torch/gymnasium_envs.py @@ -1,12 +1,17 @@ -from typing import Any, Optional, Tuple +from typing import Any, Tuple import gymnasium -import numpy as np import torch from skrl import logger from skrl.envs.wrappers.torch.base import Wrapper +from skrl.utils.spaces.torch import ( + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space, + untensorize_space +) class GymnasiumWrapper(Wrapper): @@ -44,66 +49,6 @@ def action_space(self) -> gymnasium.Space: return self._env.single_action_space return self._env.action_space - def _observation_to_tensor(self, observation: Any, space: Optional[gymnasium.Space] = None) -> torch.Tensor: - """Convert the Gymnasium observation to a flat tensor - - :param observation: The Gymnasium observation to convert 
to a tensor - :type observation: Any supported Gymnasium observation space - - :raises: ValueError if the observation space type is not supported - - :return: The observation as a flat tensor - :rtype: torch.Tensor - """ - observation_space = self._env.observation_space if self._vectorized else self.observation_space - space = space if space is not None else observation_space - - if self._vectorized and isinstance(space, gymnasium.spaces.MultiDiscrete): - return torch.tensor(observation, device=self.device, dtype=torch.int64).view(self.num_envs, -1) - elif isinstance(observation, int): - return torch.tensor(observation, device=self.device, dtype=torch.int64).view(self.num_envs, -1) - elif isinstance(observation, np.ndarray): - return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gymnasium.spaces.Discrete): - return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gymnasium.spaces.Box): - return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gymnasium.spaces.Dict): - tmp = torch.cat([self._observation_to_tensor(observation[k], space[k]) \ - for k in sorted(space.keys())], dim=-1).view(self.num_envs, -1) - return tmp - else: - raise ValueError(f"Observation space type {type(space)} not supported. Please report this issue") - - def _tensor_to_action(self, actions: torch.Tensor) -> Any: - """Convert the action to the Gymnasium expected format - - :param actions: The actions to perform - :type actions: torch.Tensor - - :raise ValueError: If the action space type is not supported - - :return: The action in the Gymnasium format - :rtype: Any supported Gymnasium action space - """ - space = self._env.action_space if self._vectorized else self.action_space - - if self._vectorized: - if isinstance(space, gymnasium.spaces.MultiDiscrete): - return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape) - elif isinstance(space, gymnasium.spaces.Tuple): - if isinstance(space[0], gymnasium.spaces.Box): - return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(space.shape) - elif isinstance(space[0], gymnasium.spaces.Discrete): - return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(-1) - if isinstance(space, gymnasium.spaces.Discrete): - return actions.item() - elif isinstance(space, gymnasium.spaces.MultiDiscrete): - return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape) - elif isinstance(space, gymnasium.spaces.Box): - return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape) - raise ValueError(f"Action space type {type(space)} not supported. 
Please report this issue") - def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: """Perform a step in the environment @@ -113,10 +58,14 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - observation, reward, terminated, truncated, info = self._env.step(self._tensor_to_action(actions)) + actions = untensorize_space(self.action_space, + unflatten_tensorized_space(self.action_space, actions), + squeeze_batch_dimension=not self._vectorized) + + observation, reward, terminated, truncated, info = self._env.step(actions) # convert response to torch - observation = self._observation_to_tensor(observation) + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) reward = torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1) terminated = torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1) truncated = torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1) @@ -138,12 +87,13 @@ def reset(self) -> Tuple[torch.Tensor, Any]: if self._vectorized: if self._reset_once: observation, self._info = self._env.reset() - self._observation = self._observation_to_tensor(observation) + self._observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) self._reset_once = False return self._observation, self._info observation, info = self._env.reset() - return self._observation_to_tensor(observation), info + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) + return observation, info def render(self, *args, **kwargs) -> Any: """Render the environment diff --git a/skrl/envs/wrappers/torch/isaacgym_envs.py b/skrl/envs/wrappers/torch/isaacgym_envs.py index 21411a85..0d5a3cd2 100644 --- a/skrl/envs/wrappers/torch/isaacgym_envs.py +++ b/skrl/envs/wrappers/torch/isaacgym_envs.py @@ -1,10 +1,16 @@ from typing import Any, Tuple, Union -import gym +import gymnasium import torch from skrl.envs.wrappers.torch.base import Wrapper +from skrl.utils.spaces.torch import ( + convert_gym_space, + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space +) class IsaacGymPreview2Wrapper(Wrapper): @@ -20,6 +26,18 @@ def __init__(self, env: Any) -> None: self._observations = None self._info = {} + @property + def observation_space(self) -> gymnasium.Space: + """Observation space + """ + return convert_gym_space(self._unwrapped.observation_space) + + @property + def action_space(self) -> gymnasium.Space: + """Action space + """ + return convert_gym_space(self._unwrapped.action_space) + def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: """Perform a step in the environment @@ -29,7 +47,8 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - self._observations, reward, terminated, self._info = self._env.step(actions) + observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations)) truncated = self._info["time_outs"] if "time_outs" 
in self._info else torch.zeros_like(terminated) return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info @@ -40,7 +59,8 @@ def reset(self) -> Tuple[torch.Tensor, Any]: :rtype: torch.Tensor and any other info """ if self._reset_once: - self._observations = self._env.reset() + observations = self._env.reset() + self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations)) self._reset_once = False return self._observations, self._info @@ -69,12 +89,24 @@ def __init__(self, env: Any) -> None: self._info = {} @property - def state_space(self) -> Union[gym.Space, None]: + def observation_space(self) -> gymnasium.Space: + """Observation space + """ + return convert_gym_space(self._unwrapped.observation_space) + + @property + def action_space(self) -> gymnasium.Space: + """Action space + """ + return convert_gym_space(self._unwrapped.action_space) + + @property + def state_space(self) -> Union[gymnasium.Space, None]: """State space """ try: if self.num_states: - return self._unwrapped.state_space + return convert_gym_space(self._unwrapped.state_space) except: pass return None @@ -88,9 +120,10 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - self._observations, reward, terminated, self._info = self._env.step(actions) + observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated) - return self._observations["obs"], reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info + return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info def reset(self) -> Tuple[torch.Tensor, Any]: """Reset the environment @@ -99,9 +132,10 @@ def reset(self) -> Tuple[torch.Tensor, Any]: :rtype: torch.Tensor and any other info """ if self._reset_once: - self._observations = self._env.reset() + observations = self._env.reset() + self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) self._reset_once = False - return self._observations["obs"], self._info + return self._observations, self._info def render(self, *args, **kwargs) -> None: """Render the environment diff --git a/skrl/envs/wrappers/torch/isaaclab_envs.py b/skrl/envs/wrappers/torch/isaaclab_envs.py index fdb9808a..c3ed2589 100644 --- a/skrl/envs/wrappers/torch/isaaclab_envs.py +++ b/skrl/envs/wrappers/torch/isaaclab_envs.py @@ -5,6 +5,7 @@ import torch from skrl.envs.wrappers.torch.base import MultiAgentEnvWrapper, Wrapper +from skrl.utils.spaces.torch import flatten_tensorized_space, tensorize_space, unflatten_tensorized_space class IsaacLabWrapper(Wrapper): @@ -60,8 +61,10 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - self._observations, reward, terminated, truncated, self._info = self._env.step(actions) - return self._observations["policy"], reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info + actions = unflatten_tensorized_space(self.action_space, actions) + observations, 
reward, terminated, truncated, self._info = self._env.step(actions) + self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["policy"])) + return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info def reset(self) -> Tuple[torch.Tensor, Any]: """Reset the environment @@ -70,9 +73,10 @@ :rtype: torch.Tensor and any other info """ if self._reset_once: - self._observations, self._info = self._env.reset() + observations, self._info = self._env.reset() + self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["policy"])) self._reset_once = False - return self._observations["policy"], self._info + return self._observations, self._info def render(self, *args, **kwargs) -> None: """Render the environment @@ -109,7 +113,9 @@ def step(self, actions: Mapping[str, torch.Tensor]) -> \ :return: Observation, reward, terminated, truncated, info :rtype: tuple of dictionaries of torch.Tensor and any other info """ - self._observations, rewards, terminated, truncated, self._info = self._env.step(actions) + actions = {k: unflatten_tensorized_space(self.action_spaces[k], v) for k, v in actions.items()} + observations, rewards, terminated, truncated, self._info = self._env.step(actions) + self._observations = {k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) for k, v in observations.items()} return self._observations, \ {k: v.view(-1, 1) for k, v in rewards.items()}, \ {k: v.view(-1, 1) for k, v in terminated.items()}, \ @@ -123,7 +129,8 @@ def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]: :rtype: torch.Tensor and any other info """ if self._reset_once: - self._observations, self._info = self._env.reset() + observations, self._info = self._env.reset() + self._observations = {k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) for k, v in observations.items()} self._reset_once = False return self._observations, self._info @@ -133,7 +140,10 @@ def state(self) -> torch.Tensor: :return: State :rtype: torch.Tensor """ - return self._env.state() + state = self._env.state() + if state is not None: + return flatten_tensorized_space(tensorize_space(next(iter(self.state_spaces.values())), state)) + return state def render(self, *args, **kwargs) -> None: """Render the environment diff --git a/skrl/envs/wrappers/torch/omniverse_isaacgym_envs.py b/skrl/envs/wrappers/torch/omniverse_isaacgym_envs.py index c96a733c..88f11a18 100644 --- a/skrl/envs/wrappers/torch/omniverse_isaacgym_envs.py +++ b/skrl/envs/wrappers/torch/omniverse_isaacgym_envs.py @@ -3,6 +3,7 @@ import torch from skrl.envs.wrappers.torch.base import Wrapper +from skrl.utils.spaces.torch import flatten_tensorized_space, tensorize_space, unflatten_tensorized_space class OmniverseIsaacGymWrapper(Wrapper): @@ -16,6 +17,7 @@ def __init__(self, env: Any) -> None: self._reset_once = True self._observations = None + self._info = {} def run(self, trainer: Optional["omni.isaac.gym.vec_env.vec_env_mt.TrainerMT"] = None) -> None: """Run the simulation in the main thread @@ -36,9 +38,10 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - self._observations, reward, terminated, info = self._env.step(actions) - truncated = info["time_outs"] if "time_outs" in info else
torch.zeros_like(terminated) - return self._observations["obs"], reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), info + observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) + truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated) + return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info def reset(self) -> Tuple[torch.Tensor, Any]: """Reset the environment @@ -47,9 +50,10 @@ def reset(self) -> Tuple[torch.Tensor, Any]: :rtype: torch.Tensor and any other info """ if self._reset_once: - self._observations = self._env.reset() + observations = self._env.reset() + self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) self._reset_once = False - return self._observations["obs"], {} + return self._observations, self._info def render(self, *args, **kwargs) -> None: """Render the environment diff --git a/skrl/envs/wrappers/torch/pettingzoo_envs.py b/skrl/envs/wrappers/torch/pettingzoo_envs.py index ec950400..7b55785c 100644 --- a/skrl/envs/wrappers/torch/pettingzoo_envs.py +++ b/skrl/envs/wrappers/torch/pettingzoo_envs.py @@ -1,12 +1,16 @@ from typing import Any, Mapping, Tuple import collections -import gymnasium -import numpy as np import torch from skrl.envs.wrappers.torch.base import MultiAgentEnvWrapper +from skrl.utils.spaces.torch import ( + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space, + untensorize_space +) class PettingZooWrapper(MultiAgentEnvWrapper): @@ -18,49 +22,6 @@ def __init__(self, env: Any) -> None: """ super().__init__(env) - def _observation_to_tensor(self, observation: Any, space: gymnasium.Space) -> torch.Tensor: - """Convert the Gymnasium observation to a flat tensor - - :param observation: The Gymnasium observation to convert to a tensor - :type observation: Any supported Gymnasium observation space - - :raises: ValueError if the observation space type is not supported - - :return: The observation as a flat tensor - :rtype: torch.Tensor - """ - if isinstance(observation, int): - return torch.tensor(observation, device=self.device, dtype=torch.int64).view(self.num_envs, -1) - elif isinstance(observation, np.ndarray): - return torch.tensor(np.ascontiguousarray(observation), device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gymnasium.spaces.Discrete): - return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gymnasium.spaces.Box): - return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gymnasium.spaces.Dict): - tmp = torch.cat([self._observation_to_tensor(observation[k], space[k]) \ - for k in sorted(space.keys())], dim=-1).view(self.num_envs, -1) - return tmp - else: - raise ValueError(f"Observation space type {type(space)} not supported. 
Please report this issue") - - def _tensor_to_action(self, actions: torch.Tensor, space: gymnasium.Space) -> Any: - """Convert the action to the Gymnasium expected format - - :param actions: The actions to perform - :type actions: torch.Tensor - - :raise ValueError: If the action space type is not supported - - :return: The action in the Gymnasium format - :rtype: Any supported Gymnasium action space - """ - if isinstance(space, gymnasium.spaces.Discrete): - return actions.item() - elif isinstance(space, gymnasium.spaces.Box): - return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape) - raise ValueError(f"Action space type {type(space)} not supported. Please report this issue") - def step(self, actions: Mapping[str, torch.Tensor]) -> \ Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], Mapping[str, Any]]: @@ -72,11 +33,11 @@ def step(self, actions: Mapping[str, torch.Tensor]) -> \ :return: Observation, reward, terminated, truncated, info :rtype: tuple of dictionaries torch.Tensor and any other info """ - actions = {uid: self._tensor_to_action(action, self.action_space(uid)) for uid, action in actions.items()} + actions = {uid: untensorize_space(self.action_spaces[uid], unflatten_tensorized_space(self.action_spaces[uid], action)) for uid, action in actions.items()} observations, rewards, terminated, truncated, infos = self._env.step(actions) # convert response to torch - observations = {uid: self._observation_to_tensor(value, self.observation_space(uid)) for uid, value in observations.items()} + observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, device=self.device)) for uid, value in observations.items()} rewards = {uid: torch.tensor(value, device=self.device, dtype=torch.float32).view(self.num_envs, -1) for uid, value in rewards.items()} terminated = {uid: torch.tensor(value, device=self.device, dtype=torch.bool).view(self.num_envs, -1) for uid, value in terminated.items()} truncated = {uid: torch.tensor(value, device=self.device, dtype=torch.bool).view(self.num_envs, -1) for uid, value in truncated.items()} @@ -88,7 +49,7 @@ def state(self) -> torch.Tensor: :return: State :rtype: torch.Tensor """ - return self._observation_to_tensor(self._env.state(), next(iter(self.state_spaces.values()))) + return flatten_tensorized_space(tensorize_space(next(iter(self.state_spaces.values())), self._env.state(), device=self.device)) def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]: """Reset the environment @@ -104,7 +65,7 @@ def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]: observations, infos = outputs # convert response to torch - observations = {uid: self._observation_to_tensor(value, self.observation_space(uid)) for uid, value in observations.items()} + observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, device=self.device)) for uid, value in observations.items()} return observations, infos def render(self, *args, **kwargs) -> Any: diff --git a/skrl/envs/wrappers/torch/robosuite_envs.py b/skrl/envs/wrappers/torch/robosuite_envs.py index be4d9877..d07b438b 100644 --- a/skrl/envs/wrappers/torch/robosuite_envs.py +++ b/skrl/envs/wrappers/torch/robosuite_envs.py @@ -1,12 +1,13 @@ from typing import Any, Optional, Tuple import collections -import gym +import gymnasium import numpy as np import torch from skrl.envs.wrappers.torch.base import Wrapper +from skrl.utils.spaces.torch 
import convert_gym_space class RobosuiteWrapper(Wrapper): @@ -23,26 +24,26 @@ def __init__(self, env: Any) -> None: self._action_space = self._spec_to_space(self._env.action_spec) @property - def state_space(self) -> gym.Space: + def state_space(self) -> gymnasium.Space: """State space An alias for the ``observation_space`` property """ - return self._observation_space + return convert_gym_space(self._observation_space) @property - def observation_space(self) -> gym.Space: + def observation_space(self) -> gymnasium.Space: """Observation space """ - return self._observation_space + return convert_gym_space(self._observation_space) @property - def action_space(self) -> gym.Space: + def action_space(self) -> gymnasium.Space: """Action space """ - return self._action_space + return convert_gym_space(self._action_space) - def _spec_to_space(self, spec: Any) -> gym.Space: + def _spec_to_space(self, spec: Any) -> gymnasium.Space: """Convert the robosuite spec to a Gym space :param spec: The robosuite spec to convert @@ -51,20 +52,20 @@ def _spec_to_space(self, spec: Any) -> gym.Space: :raises: ValueError if the spec type is not supported :return: The Gym space - :rtype: gym.Space + :rtype: gymnasium.Space """ if type(spec) is tuple: - return gym.spaces.Box(shape=spec[0].shape, - dtype=np.float32, - low=spec[0], - high=spec[1]) + return gymnasium.spaces.Box(shape=spec[0].shape, + dtype=np.float32, + low=spec[0], + high=spec[1]) elif isinstance(spec, np.ndarray): - return gym.spaces.Box(shape=spec.shape, - dtype=np.float32, - low=np.full(spec.shape, float("-inf")), - high=np.full(spec.shape, float("inf"))) + return gymnasium.spaces.Box(shape=spec.shape, + dtype=np.float32, + low=np.full(spec.shape, float("-inf")), + high=np.full(spec.shape, float("inf"))) elif isinstance(spec, collections.OrderedDict): - return gym.spaces.Dict({k: self._spec_to_space(v) for k, v in spec.items()}) + return gymnasium.spaces.Dict({k: self._spec_to_space(v) for k, v in spec.items()}) else: raise ValueError(f"Spec type {type(spec)} not supported. Please report this issue") diff --git a/skrl/memories/jax/base.py b/skrl/memories/jax/base.py index 9a061287..f6168a48 100644 --- a/skrl/memories/jax/base.py +++ b/skrl/memories/jax/base.py @@ -5,7 +5,6 @@ import functools import operator import os -import gym import gymnasium import jax @@ -13,6 +12,7 @@ import numpy as np from skrl import config +from skrl.utils.spaces.jax import compute_space_size # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @@ -107,49 +107,6 @@ def __len__(self) -> int: """ return self.memory_size * self.num_envs if self.filled else self.memory_index * self.num_envs + self.env_index - def _get_space_size(self, - space: Union[int, Tuple[int], gym.Space, gymnasium.Space], - keep_dimensions: bool = False) -> Union[Tuple, int]: - """Get the size (number of elements) of a space - - :param space: Space or shape from which to obtain the number of elements - :type space: int, tuple or list of integers, gym.Space, or gymnasium.Space - :param keep_dimensions: Whether or not to keep the space dimensions (default: ``False``) - :type keep_dimensions: bool, optional - - :raises ValueError: If the space is not supported - - :return: Size of the space. 
If ``keep_dimensions`` is True, the space size will be a tuple - :rtype: int or tuple of int - """ - if type(space) in [int, float]: - return (int(space),) if keep_dimensions else int(space) - elif type(space) in [tuple, list]: - return tuple(space) if keep_dimensions else np.prod(space) - elif issubclass(type(space), gym.Space): - if issubclass(type(space), gym.spaces.Discrete): - return (1,) if keep_dimensions else 1 - elif issubclass(type(space), gym.spaces.MultiDiscrete): - return space.nvec.shape[0] - elif issubclass(type(space), gym.spaces.Box): - return tuple(space.shape) if keep_dimensions else np.prod(space.shape) - elif issubclass(type(space), gym.spaces.Dict): - if keep_dimensions: - raise ValueError("keep_dimensions=True cannot be used with Dict spaces") - return sum([self._get_space_size(space.spaces[key]) for key in space.spaces]) - elif issubclass(type(space), gymnasium.Space): - if issubclass(type(space), gymnasium.spaces.Discrete): - return (1,) if keep_dimensions else 1 - elif issubclass(type(space), gymnasium.spaces.MultiDiscrete): - return space.nvec.shape[0] - elif issubclass(type(space), gymnasium.spaces.Box): - return tuple(space.shape) if keep_dimensions else np.prod(space.shape) - elif issubclass(type(space), gymnasium.spaces.Dict): - if keep_dimensions: - raise ValueError("keep_dimensions=True cannot be used with Dict spaces") - return sum([self._get_space_size(space.spaces[key]) for key in space.spaces]) - raise ValueError(f"Space type {type(space)} not supported") - def _get_tensors_view(self, name): if self.tensors_keep_dimensions[name]: return self.tensors_view[name] if self._views else self.tensors[name].reshape(-1, *self.tensors_keep_dimensions[name]) @@ -202,9 +159,9 @@ def set_tensor_by_name(self, name: str, tensor: Union[np.ndarray, jax.Array]) -> def create_tensor(self, name: str, - size: Union[int, Tuple[int], gym.Space, gymnasium.Space], + size: Union[int, Tuple[int], gymnasium.Space], dtype: Optional[np.dtype] = None, - keep_dimensions: bool = True) -> bool: + keep_dimensions: bool = False) -> bool: """Create a new internal tensor in memory The tensor will have a 3-components shape (memory size, number of environments, size). @@ -213,8 +170,8 @@ def create_tensor(self, :param name: Tensor name (the name has to follow the python PEP 8 style) :type name: str :param size: Number of elements in the last dimension (effective data size). - The product of the elements will be computed for sequences or gym/gymnasium spaces - :type size: int, tuple or list of integers or gym.Space + The product of the elements will be computed for sequences or gymnasium spaces + :type size: int, tuple or list of integers or gymnasium space :param dtype: Data type (np.dtype) (default: ``None``). 
If None, the global default jax.numpy.float32 data type will be used :type dtype: np.dtype or None, optional @@ -227,7 +184,7 @@ def create_tensor(self, :rtype: bool """ # compute data size - size = self._get_space_size(size, keep_dimensions) + size = compute_space_size(size, occupied_size=True) # check dtype and size if the tensor exists if name in self.tensors: tensor = self.tensors[name] diff --git a/skrl/memories/torch/base.py b/skrl/memories/torch/base.py index 4cf89548..acc5d83a 100644 --- a/skrl/memories/torch/base.py +++ b/skrl/memories/torch/base.py @@ -5,13 +5,14 @@ import functools import operator import os -import gym import gymnasium import numpy as np import torch from torch.utils.data.sampler import BatchSampler +from skrl.utils.spaces.torch import compute_space_size + class Memory: def __init__(self, @@ -80,49 +81,6 @@ def __len__(self) -> int: """ return self.memory_size * self.num_envs if self.filled else self.memory_index * self.num_envs + self.env_index - def _get_space_size(self, - space: Union[int, Tuple[int], gym.Space, gymnasium.Space], - keep_dimensions: bool = False) -> Union[Tuple, int]: - """Get the size (number of elements) of a space - - :param space: Space or shape from which to obtain the number of elements - :type space: int, tuple or list of integers, gym.Space, or gymnasium.Space - :param keep_dimensions: Whether or not to keep the space dimensions (default: ``False``) - :type keep_dimensions: bool, optional - - :raises ValueError: If the space is not supported - - :return: Size of the space. If ``keep_dimensions`` is True, the space size will be a tuple - :rtype: int or tuple of int - """ - if type(space) in [int, float]: - return (int(space),) if keep_dimensions else int(space) - elif type(space) in [tuple, list]: - return tuple(space) if keep_dimensions else np.prod(space) - elif issubclass(type(space), gym.Space): - if issubclass(type(space), gym.spaces.Discrete): - return (1,) if keep_dimensions else 1 - elif issubclass(type(space), gym.spaces.MultiDiscrete): - return space.nvec.shape[0] - elif issubclass(type(space), gym.spaces.Box): - return tuple(space.shape) if keep_dimensions else np.prod(space.shape) - elif issubclass(type(space), gym.spaces.Dict): - if keep_dimensions: - raise ValueError("keep_dimensions=True cannot be used with Dict spaces") - return sum([self._get_space_size(space.spaces[key]) for key in space.spaces]) - elif issubclass(type(space), gymnasium.Space): - if issubclass(type(space), gymnasium.spaces.Discrete): - return (1,) if keep_dimensions else 1 - elif issubclass(type(space), gymnasium.spaces.MultiDiscrete): - return space.nvec.shape[0] - elif issubclass(type(space), gymnasium.spaces.Box): - return tuple(space.shape) if keep_dimensions else np.prod(space.shape) - elif issubclass(type(space), gymnasium.spaces.Dict): - if keep_dimensions: - raise ValueError("keep_dimensions=True cannot be used with Dict spaces") - return sum([self._get_space_size(space.spaces[key]) for key in space.spaces]) - raise ValueError(f"Space type {type(space)} not supported") - def share_memory(self) -> None: """Share the tensors between processes """ @@ -169,9 +127,9 @@ def set_tensor_by_name(self, name: str, tensor: torch.Tensor) -> None: def create_tensor(self, name: str, - size: Union[int, Tuple[int], gym.Space, gymnasium.Space], + size: Union[int, Tuple[int], gymnasium.Space], dtype: Optional[torch.dtype] = None, - keep_dimensions: bool = True) -> bool: + keep_dimensions: bool = False) -> bool: """Create a new internal tensor in memory The tensor 
will have a 3-components shape (memory size, number of environments, size). @@ -180,8 +138,8 @@ def create_tensor, :param name: Tensor name (the name has to follow the python PEP 8 style) :type name: str :param size: Number of elements in the last dimension (effective data size). - The product of the elements will be computed for sequences or gym/gymnasium spaces - :type size: int, tuple or list of integers, gym.Space, or gymnasium.Space + The product of the elements will be computed for sequences or gymnasium spaces + :type size: int, tuple or list of integers or gymnasium space :param dtype: Data type (torch.dtype) (default: ``None``). If None, the global default torch data type will be used :type dtype: torch.dtype or None, optional @@ -194,7 +152,7 @@ def create_tensor, :rtype: bool """ # compute data size - size = self._get_space_size(size, keep_dimensions) + size = compute_space_size(size, occupied_size=True) # check dtype and size if the tensor exists if name in self.tensors: tensor = self.tensors[name] diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py index f6e2f0fc..9e68954a 100644 --- a/skrl/models/jax/base.py +++ b/skrl/models/jax/base.py @@ -1,6 +1,5 @@ from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union -import gym import gymnasium import flax @@ -9,6 +8,7 @@ import numpy as np from skrl import config +from skrl.utils.spaces.jax import compute_space_size, unflatten_tensorized_space @jax.jit @@ -34,13 +34,13 @@ def create(cls, *, apply_fn, params, **kwargs): class Model(flax.linen.Module): - observation_space: Union[int, Sequence[int], gym.Space, gymnasium.Space] - action_space: Union[int, Sequence[int], gym.Space, gymnasium.Space] + observation_space: Union[int, Sequence[int], gymnasium.Space] + action_space: Union[int, Sequence[int], gymnasium.Space] device: Optional[Union[str, jax.Device]] = None def __init__(self, - observation_space: Union[int, Sequence[int], gym.Space, gymnasium.Space], - action_space: Union[int, Sequence[int], gym.Space, gymnasium.Space], + observation_space: Union[int, Sequence[int], gymnasium.Space], + action_space: Union[int, Sequence[int], gymnasium.Space], device: Optional[Union[str, jax.Device]] = None, parent: Optional[Any] = None, name: Optional[str] = None) -> None: @@ -49,17 +49,17 @@ def __init__(self, The following properties are defined: - ``device`` (jax.Device): Device to be used for the computations - - ``observation_space`` (int, sequence of int, gym.Space, gymnasium.Space): Observation/state space - - ``action_space`` (int, sequence of int, gym.Space, gymnasium.Space): Action space + - ``observation_space`` (int, sequence of int, gymnasium.Space): Observation/state space + - ``action_space`` (int, sequence of int, gymnasium.Space): Action space - ``num_observations`` (int): Number of elements in the observation/state space - ``num_actions`` (int): Number of elements in the action space :param observation_space: Observation/state space or shape. The ``num_observations`` property will contain the size of that space - :type observation_space: int, sequence of int, gym.Space, gymnasium.Space + :type observation_space: int, sequence of int, gymnasium.Space :param action_space: Action space or shape. The ``num_actions`` property will contain the size of that space - :type action_space: int, sequence of int, gym.Space, gymnasium.Space + :type action_space: int, sequence of int, gymnasium.Space :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional @@ -100,8 +100,8 @@ def __call__(self, inputs, role): self.observation_space = observation_space self.action_space = action_space - self.num_observations = None if observation_space is None else self._get_space_size(observation_space) - self.num_actions = None if action_space is None else self._get_space_size(action_space) + self.num_observations = None if observation_space is None else compute_space_size(observation_space) + self.num_actions = None if action_space is None else compute_space_size(action_space) self.state_dict: StateDict self.training = False @@ -139,116 +139,20 @@ def init_state_dict(self, with jax.default_device(self.device): self.state_dict = StateDict.create(apply_fn=self.apply, params=self.init(key, inputs, role)) - def _get_space_size(self, - space: Union[int, Sequence[int], gym.Space, gymnasium.Space], - number_of_elements: bool = True) -> int: - """Get the size (number of elements) of a space - - :param space: Space or shape from which to obtain the number of elements - :type space: int, sequence of int, gym.Space, or gymnasium.Space - :param number_of_elements: Whether the number of elements occupied by the space is returned (default: ``True``). - If ``False``, the shape of the space is returned. - It only affects Discrete and MultiDiscrete spaces - :type number_of_elements: bool, optional - - :raises ValueError: If the space is not supported - - :return: Size of the space (number of elements) - :rtype: int - - Example:: - - # from int - >>> model._get_space_size(2) - 2 - - # from sequence of int - >>> model._get_space_size([2, 3]) - 6 - - # Box space - >>> space = gym.spaces.Box(low=-1, high=1, shape=(2, 3)) - >>> model._get_space_size(space) - 6 - - # Discrete space - >>> space = gym.spaces.Discrete(4) - >>> model._get_space_size(space) - 4 - >>> model._get_space_size(space, number_of_elements=False) - 1 - - # MultiDiscrete space - >>> space = gym.spaces.MultiDiscrete([5, 3, 2]) - >>> model._get_space_size(space) - 10 - >>> model._get_space_size(space, number_of_elements=False) - 3 - - # Dict space - >>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)), - ... 
'b': gym.spaces.Discrete(4)}) - >>> model._get_space_size(space) - 10 - >>> model._get_space_size(space, number_of_elements=False) - 7 - """ - size = None - if type(space) in [int, float]: - size = space - elif type(space) in [tuple, list]: - size = np.prod(space) - elif issubclass(type(space), gym.Space): - if issubclass(type(space), gym.spaces.Discrete): - if number_of_elements: - size = space.n - else: - size = 1 - elif issubclass(type(space), gym.spaces.MultiDiscrete): - if number_of_elements: - size = np.sum(space.nvec) - else: - size = space.nvec.shape[0] - elif issubclass(type(space), gym.spaces.Box): - size = np.prod(space.shape) - elif issubclass(type(space), gym.spaces.Dict): - size = sum([self._get_space_size(space.spaces[key], number_of_elements) for key in space.spaces]) - elif issubclass(type(space), gymnasium.Space): - if issubclass(type(space), gymnasium.spaces.Discrete): - if number_of_elements: - size = space.n - else: - size = 1 - elif issubclass(type(space), gymnasium.spaces.MultiDiscrete): - if number_of_elements: - size = np.sum(space.nvec) - else: - size = space.nvec.shape[0] - elif issubclass(type(space), gymnasium.spaces.Box): - size = np.prod(space.shape) - elif issubclass(type(space), gymnasium.spaces.Dict): - size = sum([self._get_space_size(space.spaces[key], number_of_elements) for key in space.spaces]) - if size is None: - raise ValueError(f"Space type {type(space)} not supported") - return int(size) - def tensor_to_space(self, tensor: Union[np.ndarray, jax.Array], - space: Union[gym.Space, gymnasium.Space], + space: gymnasium.Space, start: int = 0) -> Union[Union[np.ndarray, jax.Array], dict]: """Map a flat tensor to a Gym/Gymnasium space - The mapping is done in the following way: + .. warning:: - - Tensors belonging to Discrete spaces are returned without modification - - Tensors belonging to Box spaces are reshaped to the corresponding space shape - keeping the first dimension (number of samples) as they are - - Tensors belonging to Dict spaces are mapped into a dictionary with the same keys as the original space + This method is deprecated in favor of the :py:func:`skrl.utils.spaces.jax.unflatten_tensorized_space` :param tensor: Tensor to map from :type tensor: np.ndarray or jax.Array :param space: Space to map the tensor to - :type space: gym.Space or gymnasium.Space + :type space: gymnasium.Space :param start: Index of the first element of the tensor to map (default: ``0``) :type start: int, optional @@ -259,8 +163,8 @@ def tensor_to_space(self, Example:: - >>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)), - ... 'b': gym.spaces.Discrete(4)}) + >>> space = gymnasium.spaces.Dict({'a': gymnasium.spaces.Box(low=-1, high=1, shape=(2, 3)), + ... 
'b': gymnasium.spaces.Discrete(4)}) >>> tensor = jnp.array([[-0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 2]]) >>> >>> model.tensor_to_space(tensor, space) @@ -268,31 +172,7 @@ def tensor_to_space(self, [ 0.1, 0.2, 0.3]]], dtype=float32), 'b': Array([[2.]], dtype=float32)} """ - if issubclass(type(space), gym.Space): - if issubclass(type(space), gym.spaces.Discrete): - return tensor - elif issubclass(type(space), gym.spaces.Box): - return tensor.reshape(tensor.shape[0], *space.shape) - elif issubclass(type(space), gym.spaces.Dict): - output = {} - for k in sorted(space.keys()): - end = start + self._get_space_size(space[k], number_of_elements=False) - output[k] = self.tensor_to_space(tensor[:, start:end], space[k], end) - start = end - return output - else: - if issubclass(type(space), gymnasium.spaces.Discrete): - return tensor - elif issubclass(type(space), gymnasium.spaces.Box): - return tensor.reshape(tensor.shape[0], *space.shape) - elif issubclass(type(space), gymnasium.spaces.Dict): - output = {} - for k in sorted(space.keys()): - end = start + self._get_space_size(space[k], number_of_elements=False) - output[k] = self.tensor_to_space(tensor[:, start:end], space[k], end) - start = end - return output - raise ValueError(f"Space type {type(space)} not supported") + return unflatten_tensorized_space(space, tensor) def random_act(self, inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], @@ -317,10 +197,10 @@ def random_act(self, :rtype: tuple of np.ndarray or jax.Array, None, and dict """ # discrete action space (Discrete) - if issubclass(type(self.action_space), gym.spaces.Discrete) or issubclass(type(self.action_space), gymnasium.spaces.Discrete): + if isinstance(self.action_space, gymnasium.spaces.Discrete): actions = np.random.randint(self.action_space.n, size=(inputs["states"].shape[0], 1)) # continuous action space (Box) - elif issubclass(type(self.action_space), gym.spaces.Box) or issubclass(type(self.action_space), gymnasium.spaces.Box): + elif isinstance(self.action_space, gymnasium.spaces.Box): actions = np.random.uniform(low=self.action_space.low[0], high=self.action_space.high[0], size=(inputs["states"].shape[0], self.num_actions)) else: raise NotImplementedError(f"Action space type ({type(self.action_space)}) not supported") diff --git a/skrl/models/jax/categorical.py b/skrl/models/jax/categorical.py index c9aa11a2..14ad316a 100644 --- a/skrl/models/jax/categorical.py +++ b/skrl/models/jax/categorical.py @@ -72,8 +72,8 @@ def __init__(self, unnormalized_log_prob: bool = True, role: str = "") -> None: ... x = nn.Dense(self.num_actions)(x) ... return x, {} ... - >>> # given an observation_space: gym.spaces.Box with shape (4,) - >>> # and an action_space: gym.spaces.Discrete with n = 2 + >>> # given an observation_space: gymnasium.spaces.Box with shape (4,) + >>> # and an action_space: gymnasium.spaces.Discrete with n = 2 >>> model = Policy(observation_space, action_space) >>> >>> print(model) diff --git a/skrl/models/jax/deterministic.py b/skrl/models/jax/deterministic.py index 3ebebc63..407251de 100644 --- a/skrl/models/jax/deterministic.py +++ b/skrl/models/jax/deterministic.py @@ -1,6 +1,5 @@ from typing import Any, Mapping, Optional, Tuple, Union -import gym import gymnasium import flax @@ -36,8 +35,8 @@ def __init__(self, clip_actions: bool = False, role: str = "") -> None: ... x = nn.Dense(1)(x) ... return x, {} ... 
- >>> # given an observation_space: gym.spaces.Box with shape (60,) - >>> # and an action_space: gym.spaces.Box with shape (8,) + >>> # given an observation_space: gymnasium.spaces.Box with shape (60,) + >>> # and an action_space: gymnasium.spaces.Box with shape (8,) >>> model = Value(observation_space, action_space) >>> >>> print(model) @@ -50,8 +49,7 @@ def __init__(self, clip_actions: bool = False, role: str = "") -> None: """ if not hasattr(self, "_d_clip_actions"): self._d_clip_actions = {} - self._d_clip_actions[role] = clip_actions and (issubclass(type(self.action_space), gym.Space) or \ - issubclass(type(self.action_space), gymnasium.Space)) + self._d_clip_actions[role] = clip_actions and isinstance(self.action_space, gymnasium.Space) if self._d_clip_actions[role]: self.clip_actions_min = jnp.array(self.action_space.low, dtype=jnp.float32) diff --git a/skrl/models/jax/gaussian.py b/skrl/models/jax/gaussian.py index 53245372..e9783be6 100644 --- a/skrl/models/jax/gaussian.py +++ b/skrl/models/jax/gaussian.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Tuple, Union from functools import partial -import gym import gymnasium import flax @@ -102,8 +101,8 @@ def __init__(self, ... x = nn.elu(self.layer_2(x)) ... return self.layer_3(x), self.log_std_parameter, {} ... - >>> # given an observation_space: gym.spaces.Box with shape (60,) - >>> # and an action_space: gym.spaces.Box with shape (8,) + >>> # given an observation_space: gymnasium.spaces.Box with shape (60,) + >>> # and an action_space: gymnasium.spaces.Box with shape (8,) >>> model = Policy(observation_space, action_space) >>> >>> print(model) @@ -114,8 +113,7 @@ def __init__(self, device = StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0) ) """ - self._clip_actions = clip_actions and (issubclass(type(self.action_space), gym.Space) or \ - issubclass(type(self.action_space), gymnasium.Space)) + self._clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space) if self._clip_actions: self.clip_actions_min = jnp.array(self.action_space.low, dtype=jnp.float32) diff --git a/skrl/models/jax/multicategorical.py b/skrl/models/jax/multicategorical.py index b3a91e84..adb2a95c 100644 --- a/skrl/models/jax/multicategorical.py +++ b/skrl/models/jax/multicategorical.py @@ -78,8 +78,8 @@ def __init__(self, unnormalized_log_prob: bool = True, reduction: str = "sum", r ... x = nn.Dense(self.num_actions)(x) ... return x, {} ... 
- >>> # given an observation_space: gym.spaces.Box with shape (4,) - >>> # and an action_space: gym.spaces.MultiDiscrete with nvec = [3, 2] + >>> # given an observation_space: gymnasium.spaces.Box with shape (4,) + >>> # and an action_space: gymnasium.spaces.MultiDiscrete with nvec = [3, 2] >>> model = Policy(observation_space, action_space) >>> >>> print(model) diff --git a/skrl/models/torch/base.py b/skrl/models/torch/base.py index 79ed091e..39e746b4 100644 --- a/skrl/models/torch/base.py +++ b/skrl/models/torch/base.py @@ -1,37 +1,36 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union import collections -import gym import gymnasium from packaging import version -import numpy as np import torch from skrl import config, logger +from skrl.utils.spaces.torch import compute_space_size, unflatten_tensorized_space class Model(torch.nn.Module): def __init__(self, - observation_space: Union[int, Sequence[int], gym.Space, gymnasium.Space], - action_space: Union[int, Sequence[int], gym.Space, gymnasium.Space], + observation_space: Union[int, Sequence[int], gymnasium.Space], + action_space: Union[int, Sequence[int], gymnasium.Space], device: Optional[Union[str, torch.device]] = None) -> None: """Base class representing a function approximator The following properties are defined: - ``device`` (torch.device): Device to be used for the computations - - ``observation_space`` (int, sequence of int, gym.Space, gymnasium.Space): Observation/state space - - ``action_space`` (int, sequence of int, gym.Space, gymnasium.Space): Action space + - ``observation_space`` (int, sequence of int, gymnasium.Space): Observation/state space + - ``action_space`` (int, sequence of int, gymnasium.Space): Action space - ``num_observations`` (int): Number of elements in the observation/state space - ``num_actions`` (int): Number of elements in the action space :param observation_space: Observation/state space or shape. The ``num_observations`` property will contain the size of that space - :type observation_space: int, sequence of int, gym.Space, gymnasium.Space + :type observation_space: int, sequence of int, gymnasium.Space :param action_space: Action space or shape. The ``num_actions`` property will contain the size of that space - :type action_space: int, sequence of int, gym.Space, gymnasium.Space + :type action_space: int, sequence of int, gymnasium.Space :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -59,121 +58,25 @@ def act(self, inputs, role=""): self.observation_space = observation_space self.action_space = action_space - self.num_observations = None if observation_space is None else self._get_space_size(observation_space) - self.num_actions = None if action_space is None else self._get_space_size(action_space) + self.num_observations = None if observation_space is None else compute_space_size(observation_space) + self.num_actions = None if action_space is None else compute_space_size(action_space) self._random_distribution = None - def _get_space_size(self, - space: Union[int, Sequence[int], gym.Space, gymnasium.Space], - number_of_elements: bool = True) -> int: - """Get the size (number of elements) of a space - - :param space: Space or shape from which to obtain the number of elements - :type space: int, sequence of int, gym.Space, or gymnasium.Space - :param number_of_elements: Whether the number of elements occupied by the space is returned (default: ``True``). - If ``False``, the shape of the space is returned. - It only affects Discrete and MultiDiscrete spaces - :type number_of_elements: bool, optional - - :raises ValueError: If the space is not supported - - :return: Size of the space (number of elements) - :rtype: int - - Example:: - - # from int - >>> model._get_space_size(2) - 2 - - # from sequence of int - >>> model._get_space_size([2, 3]) - 6 - - # Box space - >>> space = gym.spaces.Box(low=-1, high=1, shape=(2, 3)) - >>> model._get_space_size(space) - 6 - - # Discrete space - >>> space = gym.spaces.Discrete(4) - >>> model._get_space_size(space) - 4 - >>> model._get_space_size(space, number_of_elements=False) - 1 - - # MultiDiscrete space - >>> space = gym.spaces.MultiDiscrete([5, 3, 2]) - >>> model._get_space_size(space) - 10 - >>> model._get_space_size(space, number_of_elements=False) - 3 - - # Dict space - >>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)), - ... 
'b': gym.spaces.Discrete(4)}) - >>> model._get_space_size(space) - 10 - >>> model._get_space_size(space, number_of_elements=False) - 7 - """ - size = None - if type(space) in [int, float]: - size = space - elif type(space) in [tuple, list]: - size = np.prod(space) - elif issubclass(type(space), gym.Space): - if issubclass(type(space), gym.spaces.Discrete): - if number_of_elements: - size = space.n - else: - size = 1 - elif issubclass(type(space), gym.spaces.MultiDiscrete): - if number_of_elements: - size = np.sum(space.nvec) - else: - size = space.nvec.shape[0] - elif issubclass(type(space), gym.spaces.Box): - size = np.prod(space.shape) - elif issubclass(type(space), gym.spaces.Dict): - size = sum([self._get_space_size(space.spaces[key], number_of_elements) for key in space.spaces]) - elif issubclass(type(space), gymnasium.Space): - if issubclass(type(space), gymnasium.spaces.Discrete): - if number_of_elements: - size = space.n - else: - size = 1 - elif issubclass(type(space), gymnasium.spaces.MultiDiscrete): - if number_of_elements: - size = np.sum(space.nvec) - else: - size = space.nvec.shape[0] - elif issubclass(type(space), gymnasium.spaces.Box): - size = np.prod(space.shape) - elif issubclass(type(space), gymnasium.spaces.Dict): - size = sum([self._get_space_size(space.spaces[key], number_of_elements) for key in space.spaces]) - if size is None: - raise ValueError(f"Space type {type(space)} not supported") - return int(size) - def tensor_to_space(self, tensor: torch.Tensor, - space: Union[gym.Space, gymnasium.Space], + space: gymnasium.Space, start: int = 0) -> Union[torch.Tensor, dict]: """Map a flat tensor to a Gym/Gymnasium space - The mapping is done in the following way: + .. warning:: - - Tensors belonging to Discrete spaces are returned without modification - - Tensors belonging to Box spaces are reshaped to the corresponding space shape - keeping the first dimension (number of samples) as they are - - Tensors belonging to Dict spaces are mapped into a dictionary with the same keys as the original space + This method is deprecated in favor of the :py:func:`skrl.utils.spaces.torch.unflatten_tensorized_space` :param tensor: Tensor to map from :type tensor: torch.Tensor :param space: Space to map the tensor to - :type space: gym.Space or gymnasium.Space + :type space: gymnasium.Space :param start: Index of the first element of the tensor to map (default: ``0``) :type start: int, optional @@ -184,8 +87,8 @@ def tensor_to_space(self, Example:: - >>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)), - ... 'b': gym.spaces.Discrete(4)}) + >>> space = gymnasium.spaces.Dict({'a': gymnasium.spaces.Box(low=-1, high=1, shape=(2, 3)), + ... 
'b': gymnasium.spaces.Discrete(4)}) >>> tensor = torch.tensor([[-0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 2]]) >>> >>> model.tensor_to_space(tensor, space) @@ -193,31 +96,7 @@ def tensor_to_space(self, [ 0.1000, 0.2000, 0.3000]]]), 'b': tensor([[2.]])} """ - if issubclass(type(space), gym.Space): - if issubclass(type(space), gym.spaces.Discrete): - return tensor - elif issubclass(type(space), gym.spaces.Box): - return tensor.view(tensor.shape[0], *space.shape) - elif issubclass(type(space), gym.spaces.Dict): - output = {} - for k in sorted(space.keys()): - end = start + self._get_space_size(space[k], number_of_elements=False) - output[k] = self.tensor_to_space(tensor[:, start:end], space[k], end) - start = end - return output - else: - if issubclass(type(space), gymnasium.spaces.Discrete): - return tensor - elif issubclass(type(space), gymnasium.spaces.Box): - return tensor.view(tensor.shape[0], *space.shape) - elif issubclass(type(space), gymnasium.spaces.Dict): - output = {} - for k in sorted(space.keys()): - end = start + self._get_space_size(space[k], number_of_elements=False) - output[k] = self.tensor_to_space(tensor[:, start:end], space[k], end) - start = end - return output - raise ValueError(f"Space type {type(space)} not supported") + return unflatten_tensorized_space(space, tensor) def random_act(self, inputs: Mapping[str, Union[torch.Tensor, Any]], @@ -238,10 +117,10 @@ def random_act(self, :rtype: tuple of torch.Tensor, None, and dict """ # discrete action space (Discrete) - if issubclass(type(self.action_space), gym.spaces.Discrete) or issubclass(type(self.action_space), gymnasium.spaces.Discrete): + if isinstance(self.action_space, gymnasium.spaces.Discrete): return torch.randint(self.action_space.n, (inputs["states"].shape[0], 1), device=self.device), None, {} # continuous action space (Box) - elif issubclass(type(self.action_space), gym.spaces.Box) or issubclass(type(self.action_space), gymnasium.spaces.Box): + elif isinstance(self.action_space, gymnasium.spaces.Box): if self._random_distribution is None: self._random_distribution = torch.distributions.uniform.Uniform( low=torch.tensor(self.action_space.low[0], device=self.device, dtype=torch.float32), diff --git a/skrl/models/torch/categorical.py b/skrl/models/torch/categorical.py index 6b338ca5..7f52f1b3 100644 --- a/skrl/models/torch/categorical.py +++ b/skrl/models/torch/categorical.py @@ -37,8 +37,8 @@ def __init__(self, unnormalized_log_prob: bool = True, role: str = "") -> None: ... def compute(self, inputs, role): ... return self.net(inputs["states"]), {} ... - >>> # given an observation_space: gym.spaces.Box with shape (4,) - >>> # and an action_space: gym.spaces.Discrete with n = 2 + >>> # given an observation_space: gymnasium.spaces.Box with shape (4,) + >>> # and an action_space: gymnasium.spaces.Discrete with n = 2 >>> model = Policy(observation_space, action_space) >>> >>> print(model) diff --git a/skrl/models/torch/deterministic.py b/skrl/models/torch/deterministic.py index af6cdce5..8dd52d45 100644 --- a/skrl/models/torch/deterministic.py +++ b/skrl/models/torch/deterministic.py @@ -1,6 +1,5 @@ from typing import Any, Mapping, Tuple, Union -import gym import gymnasium import torch @@ -36,8 +35,8 @@ def __init__(self, clip_actions: bool = False, role: str = "") -> None: ... def compute(self, inputs, role): ... return self.net(inputs["states"]), {} ... 
- >>> # given an observation_space: gym.spaces.Box with shape (60,) - >>> # and an action_space: gym.spaces.Box with shape (8,) + >>> # given an observation_space: gymnasium.spaces.Box with shape (60,) + >>> # and an action_space: gymnasium.spaces.Box with shape (8,) >>> model = Value(observation_space, action_space) >>> >>> print(model) @@ -51,8 +50,7 @@ def __init__(self, clip_actions: bool = False, role: str = "") -> None: ) ) """ - self._clip_actions = clip_actions and (issubclass(type(self.action_space), gym.Space) or \ - issubclass(type(self.action_space), gymnasium.Space)) + self._clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space) if self._clip_actions: self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32) diff --git a/skrl/models/torch/gaussian.py b/skrl/models/torch/gaussian.py index a9721b63..6b569cca 100644 --- a/skrl/models/torch/gaussian.py +++ b/skrl/models/torch/gaussian.py @@ -1,6 +1,5 @@ from typing import Any, Mapping, Tuple, Union -import gym import gymnasium import torch @@ -57,8 +56,8 @@ def __init__(self, ... def compute(self, inputs, role): ... return self.net(inputs["states"]), self.log_std_parameter, {} ... - >>> # given an observation_space: gym.spaces.Box with shape (60,) - >>> # and an action_space: gym.spaces.Box with shape (8,) + >>> # given an observation_space: gymnasium.spaces.Box with shape (60,) + >>> # and an action_space: gymnasium.spaces.Box with shape (8,) >>> model = Policy(observation_space, action_space) >>> >>> print(model) @@ -72,8 +71,7 @@ def __init__(self, ) ) """ - self._clip_actions = clip_actions and (issubclass(type(self.action_space), gym.Space) or \ - issubclass(type(self.action_space), gymnasium.Space)) + self._clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space) if self._clip_actions: self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32) diff --git a/skrl/models/torch/multicategorical.py b/skrl/models/torch/multicategorical.py index 2c749862..3ed95e8b 100644 --- a/skrl/models/torch/multicategorical.py +++ b/skrl/models/torch/multicategorical.py @@ -43,8 +43,8 @@ def __init__(self, unnormalized_log_prob: bool = True, reduction: str = "sum", r ... def compute(self, inputs, role): ... return self.net(inputs["states"]), {} ... - >>> # given an observation_space: gym.spaces.Box with shape (4,) - >>> # and an action_space: gym.spaces.MultiDiscrete with nvec = [3, 2] + >>> # given an observation_space: gymnasium.spaces.Box with shape (4,) + >>> # and an action_space: gymnasium.spaces.MultiDiscrete with nvec = [3, 2] >>> model = Policy(observation_space, action_space) >>> >>> print(model) diff --git a/skrl/models/torch/multivariate_gaussian.py b/skrl/models/torch/multivariate_gaussian.py index 0f43aadc..9a66041c 100644 --- a/skrl/models/torch/multivariate_gaussian.py +++ b/skrl/models/torch/multivariate_gaussian.py @@ -1,6 +1,5 @@ from typing import Any, Mapping, Tuple, Union -import gym import gymnasium import torch @@ -50,8 +49,8 @@ def __init__(self, ... def compute(self, inputs, role): ... return self.net(inputs["states"]), self.log_std_parameter, {} ... 
- >>> # given an observation_space: gym.spaces.Box with shape (60,) - >>> # and an action_space: gym.spaces.Box with shape (8,) + >>> # given an observation_space: gymnasium.spaces.Box with shape (60,) + >>> # and an action_space: gymnasium.spaces.Box with shape (8,) >>> model = Policy(observation_space, action_space) >>> >>> print(model) @@ -65,8 +64,7 @@ def __init__(self, ) ) """ - self._clip_actions = clip_actions and (issubclass(type(self.action_space), gym.Space) or \ - issubclass(type(self.action_space), gymnasium.Space)) + self._clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space) if self._clip_actions: self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32) diff --git a/skrl/models/torch/tabular.py b/skrl/models/torch/tabular.py index 371c6c20..58afa805 100644 --- a/skrl/models/torch/tabular.py +++ b/skrl/models/torch/tabular.py @@ -35,8 +35,8 @@ def __init__(self, num_envs: int = 1, role: str = "") -> None: ... dim=-1, keepdim=True).view(-1,1) ... return actions, {} ... - >>> # given an observation_space: gym.spaces.Discrete with n=100 - >>> # and an action_space: gym.spaces.Discrete with n=5 + >>> # given an observation_space: gymnasium.spaces.Discrete with n=100 + >>> # and an action_space: gymnasium.spaces.Discrete with n=5 >>> model = GreedyPolicy(observation_space, action_space, num_envs=1) >>> >>> print(model) diff --git a/skrl/multi_agents/jax/base.py b/skrl/multi_agents/jax/base.py index 57a484cf..99f9e055 100644 --- a/skrl/multi_agents/jax/base.py +++ b/skrl/multi_agents/jax/base.py @@ -5,7 +5,6 @@ import datetime import os import pickle -import gym import gymnasium import flax @@ -22,8 +21,8 @@ def __init__(self, possible_agents: Sequence[str], models: Mapping[str, Mapping[str, Model]], memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Mapping[str, Union[int, Sequence[int], gym.Space, gymnasium.Space]]] = None, - action_spaces: Optional[Mapping[str, Union[int, Sequence[int], gym.Space, gymnasium.Space]]] = None, + observation_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, + action_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Base class that represent a RL multi-agent @@ -36,9 +35,9 @@ def __init__(self, :param memories: Memories to storage the transitions. :type memories: dictionary of skrl.memory.jax.Memory, optional :param observation_spaces: Observation/state spaces or shapes (default: ``None``) - :type observation_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type observation_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param action_spaces: Action spaces or shapes (default: ``None``) - :type action_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type action_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional diff --git a/skrl/multi_agents/jax/ippo/ippo.py b/skrl/multi_agents/jax/ippo/ippo.py index a495588c..b1f4284c 100644 --- a/skrl/multi_agents/jax/ippo/ippo.py +++ b/skrl/multi_agents/jax/ippo/ippo.py @@ -1,8 +1,6 @@ from typing import Any, Mapping, Optional, Sequence, Union -import copy import functools -import gym import gymnasium import jax @@ -192,8 +190,8 @@ def __init__(self, possible_agents: Sequence[str], models: Mapping[str, Model], memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None, - action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None, + observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None) -> None: """Independent Proximal Policy Optimization (IPPO) @@ -208,9 +206,9 @@ def __init__(self, :param memories: Memories to storage the transitions. :type memories: dictionary of skrl.memory.jax.Memory, optional :param observation_spaces: Observation/state spaces or shapes (default: ``None``) - :type observation_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type observation_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param action_spaces: Action spaces or shapes (default: ``None``) - :type action_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type action_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional diff --git a/skrl/multi_agents/jax/mappo/mappo.py b/skrl/multi_agents/jax/mappo/mappo.py index dfd1c547..108b1306 100644 --- a/skrl/multi_agents/jax/mappo/mappo.py +++ b/skrl/multi_agents/jax/mappo/mappo.py @@ -1,8 +1,6 @@ from typing import Any, Mapping, Optional, Sequence, Union -import copy import functools -import gym import gymnasium import jax @@ -194,11 +192,11 @@ def __init__(self, possible_agents: Sequence[str], models: Mapping[str, Model], memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None, - action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None, + observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, device: Optional[Union[str, jax.Device]] = None, cfg: Optional[dict] = None, - shared_observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None) -> None: + shared_observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None) -> None: """Multi-Agent Proximal Policy Optimization (MAPPO) https://arxiv.org/abs/2103.01955 @@ -211,16 +209,16 @@ def __init__(self, :param memories: Memories to storage the transitions. 
:type memories: dictionary of skrl.memory.jax.Memory, optional :param observation_spaces: Observation/state spaces or shapes (default: ``None``) - :type observation_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type observation_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param action_spaces: Action spaces or shapes (default: ``None``) - :type action_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type action_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional :param cfg: Configuration dictionary :type cfg: dict :param shared_observation_spaces: Shared observation/state space or shape (default: ``None``) - :type shared_observation_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type shared_observation_spaces: dictionary of int, sequence of int or gymnasium.Space, optional """ # _cfg = copy.deepcopy(IPPO_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = MAPPO_DEFAULT_CONFIG diff --git a/skrl/multi_agents/torch/base.py b/skrl/multi_agents/torch/base.py index 8dc797b2..3eecf23f 100644 --- a/skrl/multi_agents/torch/base.py +++ b/skrl/multi_agents/torch/base.py @@ -4,7 +4,6 @@ import copy import datetime import os -import gym import gymnasium from packaging import version @@ -22,8 +21,8 @@ def __init__(self, possible_agents: Sequence[str], models: Mapping[str, Mapping[str, Model]], memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Mapping[str, Union[int, Sequence[int], gym.Space, gymnasium.Space]]] = None, - action_spaces: Optional[Mapping[str, Union[int, Sequence[int], gym.Space, gymnasium.Space]]] = None, + observation_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, + action_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Base class that represent a RL multi-agent @@ -36,9 +35,9 @@ def __init__(self, :param memories: Memories to storage the transitions. :type memories: dictionary of skrl.memory.torch.Memory, optional :param observation_spaces: Observation/state spaces or shapes (default: ``None``) - :type observation_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type observation_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param action_spaces: Action spaces or shapes (default: ``None``) - :type action_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type action_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/multi_agents/torch/ippo/ippo.py b/skrl/multi_agents/torch/ippo/ippo.py index 8fd59478..9ab490d7 100644 --- a/skrl/multi_agents/torch/ippo/ippo.py +++ b/skrl/multi_agents/torch/ippo/ippo.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import torch @@ -70,8 +69,8 @@ def __init__(self, possible_agents: Sequence[str], models: Mapping[str, Model], memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None, - action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None, + observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None) -> None: """Independent Proximal Policy Optimization (IPPO) @@ -86,9 +85,9 @@ def __init__(self, :param memories: Memories to storage the transitions. :type memories: dictionary of skrl.memory.torch.Memory, optional :param observation_spaces: Observation/state spaces or shapes (default: ``None``) - :type observation_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type observation_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param action_spaces: Action spaces or shapes (default: ``None``) - :type action_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type action_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional diff --git a/skrl/multi_agents/torch/mappo/mappo.py b/skrl/multi_agents/torch/mappo/mappo.py index 3a3f45ce..c985b8df 100644 --- a/skrl/multi_agents/torch/mappo/mappo.py +++ b/skrl/multi_agents/torch/mappo/mappo.py @@ -2,7 +2,6 @@ import copy import itertools -import gym import gymnasium import torch @@ -72,11 +71,11 @@ def __init__(self, possible_agents: Sequence[str], models: Mapping[str, Model], memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None, - action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None, + observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, device: Optional[Union[str, torch.device]] = None, cfg: Optional[dict] = None, - shared_observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None) -> None: + shared_observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None) -> None: """Multi-Agent Proximal Policy Optimization (MAPPO) https://arxiv.org/abs/2103.01955 @@ -89,16 +88,16 @@ def __init__(self, :param memories: Memories to storage the transitions. 
:type memories: dictionary of skrl.memory.torch.Memory, optional :param observation_spaces: Observation/state spaces or shapes (default: ``None``) - :type observation_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type observation_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param action_spaces: Action spaces or shapes (default: ``None``) - :type action_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type action_spaces: dictionary of int, sequence of int or gymnasium.Space, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional :param cfg: Configuration dictionary :type cfg: dict :param shared_observation_spaces: Shared observation/state space or shape (default: ``None``) - :type shared_observation_spaces: dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional + :type shared_observation_spaces: dictionary of int, sequence of int or gymnasium.Space, optional """ _cfg = copy.deepcopy(MAPPO_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) diff --git a/skrl/resources/preprocessors/jax/running_standard_scaler.py b/skrl/resources/preprocessors/jax/running_standard_scaler.py index 97d5eb63..3563942e 100644 --- a/skrl/resources/preprocessors/jax/running_standard_scaler.py +++ b/skrl/resources/preprocessors/jax/running_standard_scaler.py @@ -1,6 +1,5 @@ from typing import Mapping, Optional, Tuple, Union -import gym import gymnasium import jax @@ -8,6 +7,7 @@ import numpy as np from skrl import config +from skrl.utils.spaces.jax import compute_space_size # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @@ -60,7 +60,7 @@ def _standardization(running_mean: jax.Array, class RunningStandardScaler: def __init__(self, - size: Union[int, Tuple[int], gym.Space, gymnasium.Space], + size: Union[int, Tuple[int], gymnasium.Space], epsilon: float = 1e-8, clip_threshold: float = 5.0, device: Optional[Union[str, jax.Device]] = None) -> None: @@ -79,7 +79,7 @@ def __init__(self, [0.59656656, 0.45325184]], dtype=float32) :param size: Size of the input space - :type size: int, tuple or list of integers, gym.Space, or gymnasium.Space + :type size: int, tuple or list of integers, or gymnasium.Space :param epsilon: Small number to avoid division by zero (default: ``1e-8``) :type epsilon: float :param clip_threshold: Threshold to clip the data (default: ``5.0``) @@ -100,7 +100,7 @@ def __init__(self, device_type, device_index = f"{device}:0".split(':')[:2] self.device = jax.devices(device_type)[int(device_index)] - size = self._get_space_size(size) + size = compute_space_size(size, occupied_size=True) if self._jax: with jax.default_device(self.device): @@ -140,37 +140,6 @@ def state_dict(self, value: Mapping[str, Union[np.ndarray, jax.Array]]) -> None: np.copyto(self.running_variance, value["running_variance"]) np.copyto(self.current_count, value["current_count"]) - def _get_space_size(self, space: Union[int, Tuple[int], gym.Space, gymnasium.Space]) -> int: - """Get the size (number of elements) of a space - - :param space: Space or shape from which to obtain the number of elements - :type space: int, tuple or list of integers, gym.Space, or gymnasium.Space - - :raises ValueError: If the space is not supported - - :return: Size of the space data - :rtype: Space size (number of elements) - 
""" - if type(space) in [int, float]: - return int(space) - elif type(space) in [tuple, list]: - return np.prod(space) - elif issubclass(type(space), gym.Space): - if issubclass(type(space), gym.spaces.Discrete): - return 1 - elif issubclass(type(space), gym.spaces.Box): - return np.prod(space.shape) - elif issubclass(type(space), gym.spaces.Dict): - return sum([self._get_space_size(space.spaces[key]) for key in space.spaces]) - elif issubclass(type(space), gymnasium.Space): - if issubclass(type(space), gymnasium.spaces.Discrete): - return 1 - elif issubclass(type(space), gymnasium.spaces.Box): - return np.prod(space.shape) - elif issubclass(type(space), gymnasium.spaces.Dict): - return sum([self._get_space_size(space.spaces[key]) for key in space.spaces]) - raise ValueError(f"Space type {type(space)} not supported") - def _parallel_variance(self, input_mean: Union[np.ndarray, jax.Array], input_var: Union[np.ndarray, jax.Array], diff --git a/skrl/resources/preprocessors/torch/running_standard_scaler.py b/skrl/resources/preprocessors/torch/running_standard_scaler.py index 13e8f053..43f5cfda 100644 --- a/skrl/resources/preprocessors/torch/running_standard_scaler.py +++ b/skrl/resources/preprocessors/torch/running_standard_scaler.py @@ -1,16 +1,16 @@ from typing import Optional, Tuple, Union -import gym import gymnasium -import numpy as np import torch import torch.nn as nn +from skrl.utils.spaces.torch import compute_space_size + class RunningStandardScaler(nn.Module): def __init__(self, - size: Union[int, Tuple[int], gym.Space, gymnasium.Space], + size: Union[int, Tuple[int], gymnasium.Space], epsilon: float = 1e-8, clip_threshold: float = 5.0, device: Optional[Union[str, torch.device]] = None) -> None: @@ -29,7 +29,7 @@ def __init__(self, [0.8540, 0.1982]]) :param size: Size of the input space - :type size: int, tuple or list of integers, gym.Space, or gymnasium.Space + :type size: int, tuple or list of integers, or gymnasium.Space :param epsilon: Small number to avoid division by zero (default: ``1e-8``) :type epsilon: float :param clip_threshold: Threshold to clip the data (default: ``5.0``) @@ -47,43 +47,12 @@ def __init__(self, else: self.device = torch.device(device) - size = self._get_space_size(size) + size = compute_space_size(size, occupied_size=True) self.register_buffer("running_mean", torch.zeros(size, dtype=torch.float64, device=self.device)) self.register_buffer("running_variance", torch.ones(size, dtype=torch.float64, device=self.device)) self.register_buffer("current_count", torch.ones((), dtype=torch.float64, device=self.device)) - def _get_space_size(self, space: Union[int, Tuple[int], gym.Space, gymnasium.Space]) -> int: - """Get the size (number of elements) of a space - - :param space: Space or shape from which to obtain the number of elements - :type space: int, tuple or list of integers, gym.Space, or gymnasium.Space - - :raises ValueError: If the space is not supported - - :return: Size of the space data - :rtype: Space size (number of elements) - """ - if type(space) in [int, float]: - return int(space) - elif type(space) in [tuple, list]: - return np.prod(space) - elif issubclass(type(space), gym.Space): - if issubclass(type(space), gym.spaces.Discrete): - return 1 - elif issubclass(type(space), gym.spaces.Box): - return np.prod(space.shape) - elif issubclass(type(space), gym.spaces.Dict): - return sum([self._get_space_size(space.spaces[key]) for key in space.spaces]) - elif issubclass(type(space), gymnasium.Space): - if issubclass(type(space), 
gymnasium.spaces.Discrete): - return 1 - elif issubclass(type(space), gymnasium.spaces.Box): - return np.prod(space.shape) - elif issubclass(type(space), gymnasium.spaces.Dict): - return sum([self._get_space_size(space.spaces[key]) for key in space.spaces]) - raise ValueError(f"Space type {type(space)} not supported") - def _parallel_variance(self, input_mean: torch.Tensor, input_var: torch.Tensor, input_count: int) -> None: """Update internal variables using the parallel algorithm for computing variance diff --git a/skrl/trainers/jax/base.py b/skrl/trainers/jax/base.py index 284c175c..a9205e56 100644 --- a/skrl/trainers/jax/base.py +++ b/skrl/trainers/jax/base.py @@ -58,6 +58,7 @@ def __init__(self, self.headless = self.cfg.get("headless", False) self.disable_progressbar = self.cfg.get("disable_progressbar", False) self.close_environment_at_exit = self.cfg.get("close_environment_at_exit", True) + self.environment_info = self.cfg.get("environment_info", "episode") self.initial_timestep = 0 @@ -172,19 +173,18 @@ def single_agent_train(self) -> None: # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) - # compute actions with contextlib.nullcontext(): + # compute actions actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0] - # step the environments - next_states, rewards, terminated, truncated, infos = self.env.step(actions) + # step the environments + next_states, rewards, terminated, truncated, infos = self.env.step(actions) - # render scene - if not self.headless: - self.env.render() + # render scene + if not self.headless: + self.env.render() - # record the environments' transitions - with contextlib.nullcontext(): + # record the environments' transitions self.agents.record_transition(states=states, actions=actions, rewards=rewards, @@ -226,18 +226,20 @@ def single_agent_eval(self) -> None: for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): - # compute actions + # pre-interaction + self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) + with contextlib.nullcontext(): + # compute actions actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0] - # step the environments - next_states, rewards, terminated, truncated, infos = self.env.step(actions) + # step the environments + next_states, rewards, terminated, truncated, infos = self.env.step(actions) - # render scene - if not self.headless: - self.env.render() + # render scene + if not self.headless: + self.env.render() - with contextlib.nullcontext(): # write data to TensorBoard self.agents.record_transition(states=states, actions=actions, @@ -248,7 +250,9 @@ def single_agent_eval(self) -> None: infos=infos, timestep=timestep, timesteps=self.timesteps) - super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps) + + # post-interaction + super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps) # reset environments if self.env.num_envs > 1: @@ -285,22 +289,21 @@ def multi_agent_train(self) -> None: # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) - # compute actions with contextlib.nullcontext(): + # compute actions actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0] - # step the environments - next_states, rewards, terminated, truncated, infos = self.env.step(actions) - shared_next_states = self.env.state() 
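The per-class `_get_space_size` helpers removed above are superseded by the shared `compute_space_size` utility used in both preprocessors. A minimal gymnasium-only sketch of the equivalent "occupied size" logic, mirroring the removed code (the library utility in `skrl.utils.spaces` is broader, e.g. it also handles `MultiDiscrete` and `Tuple` spaces and, per its `occupied_size` parameter, a non-occupied flattened mode):

    import gymnasium
    import numpy as np

    def occupied_space_size(space) -> int:
        # number of elements a sample occupies once tensorized (a Discrete sample occupies 1)
        if isinstance(space, (int, float)):
            return int(space)
        if isinstance(space, (tuple, list)):
            return int(np.prod(space))
        if isinstance(space, gymnasium.spaces.Discrete):
            return 1
        if isinstance(space, gymnasium.spaces.Box):
            return int(np.prod(space.shape))
        if isinstance(space, gymnasium.spaces.Dict):
            return sum(occupied_space_size(s) for s in space.spaces.values())
        raise ValueError(f"Space type {type(space)} not supported")

With this, a scaler constructed from a `Dict` space sizes its running statistics from the summed occupied sizes of the subspaces, exactly as the removed helpers did.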
- infos["shared_states"] = shared_states - infos["shared_next_states"] = shared_next_states + # step the environments + next_states, rewards, terminated, truncated, infos = self.env.step(actions) + shared_next_states = self.env.state() + infos["shared_states"] = shared_states + infos["shared_next_states"] = shared_next_states - # render scene - if not self.headless: - self.env.render() + # render scene + if not self.headless: + self.env.render() - # record the environments' transitions - with contextlib.nullcontext(): + # record the environments' transitions self.agents.record_transition(states=states, actions=actions, rewards=rewards, @@ -315,13 +318,13 @@ def multi_agent_train(self) -> None: self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps) # reset environments - with contextlib.nullcontext(): - if not self.env.agents: + if not self.env.agents: + with contextlib.nullcontext(): states, infos = self.env.reset() shared_states = self.env.state() - else: - states = next_states - shared_states = shared_next_states + else: + states = next_states + shared_states = shared_next_states def multi_agent_eval(self) -> None: """Evaluate multi-agents @@ -342,21 +345,23 @@ def multi_agent_eval(self) -> None: for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): - # compute actions + # pre-interaction + self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) + with contextlib.nullcontext(): + # compute actions actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0] - # step the environments - next_states, rewards, terminated, truncated, infos = self.env.step(actions) - shared_next_states = self.env.state() - infos["shared_states"] = shared_states - infos["shared_next_states"] = shared_next_states + # step the environments + next_states, rewards, terminated, truncated, infos = self.env.step(actions) + shared_next_states = self.env.state() + infos["shared_states"] = shared_states + infos["shared_next_states"] = shared_next_states - # render scene - if not self.headless: - self.env.render() + # render scene + if not self.headless: + self.env.render() - with contextlib.nullcontext(): # write data to TensorBoard self.agents.record_transition(states=states, actions=actions, @@ -367,12 +372,15 @@ def multi_agent_eval(self) -> None: infos=infos, timestep=timestep, timesteps=self.timesteps) - super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps) - # reset environments - if not self.env.agents: + # post-interaction + super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps) + + # reset environments + if not self.env.agents: + with contextlib.nullcontext(): states, infos = self.env.reset() shared_states = self.env.state() - else: - states = next_states - shared_states = shared_next_states + else: + states = next_states + shared_states = shared_next_states diff --git a/skrl/trainers/jax/sequential.py b/skrl/trainers/jax/sequential.py index 6fbb261c..931030be 100644 --- a/skrl/trainers/jax/sequential.py +++ b/skrl/trainers/jax/sequential.py @@ -18,6 +18,7 @@ "headless": False, # whether to use headless mode (no rendering) "disable_progressbar": False, # whether to disable the progressbar. 
If None, disable on non-TTY "close_environment_at_exit": True, # whether to close the environment on normal program termination + "environment_info": "episode", # key used to get and log environment info } # [end-config-dict-jax] @@ -93,20 +94,19 @@ def train(self) -> None: for agent in self.agents: agent.pre_interaction(timestep=timestep, timesteps=self.timesteps) - # compute actions with contextlib.nullcontext(): + # compute actions actions = jnp.vstack([agent.act(states[scope[0]:scope[1]], timestep=timestep, timesteps=self.timesteps)[0] \ for agent, scope in zip(self.agents, self.agents_scope)]) - # step the environments - next_states, rewards, terminated, truncated, infos = self.env.step(actions) + # step the environments + next_states, rewards, terminated, truncated, infos = self.env.step(actions) - # render scene - if not self.headless: - self.env.render() + # render scene + if not self.headless: + self.env.render() - # record the environments' transitions - with contextlib.nullcontext(): + # record the environments' transitions for agent, scope in zip(self.agents, self.agents_scope): agent.record_transition(states=states[scope[0]:scope[1]], actions=actions[scope[0]:scope[1]], @@ -123,11 +123,11 @@ def train(self) -> None: agent.post_interaction(timestep=timestep, timesteps=self.timesteps) # reset environments - with contextlib.nullcontext(): - if terminated.any() or truncated.any(): + if terminated.any() or truncated.any(): + with contextlib.nullcontext(): states, infos = self.env.reset() - else: - states = next_states + else: + states = next_states def eval(self) -> None: """Evaluate the agents sequentially @@ -161,19 +161,22 @@ def eval(self) -> None: for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): - # compute actions + # pre-interaction + for agent in self.agents: + agent.pre_interaction(timestep=timestep, timesteps=self.timesteps) + with contextlib.nullcontext(): + # compute actions actions = jnp.vstack([agent.act(states[scope[0]:scope[1]], timestep=timestep, timesteps=self.timesteps)[0] \ for agent, scope in zip(self.agents, self.agents_scope)]) - # step the environments - next_states, rewards, terminated, truncated, infos = self.env.step(actions) + # step the environments + next_states, rewards, terminated, truncated, infos = self.env.step(actions) - # render scene - if not self.headless: - self.env.render() + # render scene + if not self.headless: + self.env.render() - with contextlib.nullcontext(): # write data to TensorBoard for agent, scope in zip(self.agents, self.agents_scope): agent.record_transition(states=states[scope[0]:scope[1]], @@ -185,10 +188,14 @@ def eval(self) -> None: infos=infos, timestep=timestep, timesteps=self.timesteps) - super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps) - # reset environments - if terminated.any() or truncated.any(): + # post-interaction + for agent in self.agents: + super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps) + + # reset environments + if terminated.any() or truncated.any(): + with contextlib.nullcontext(): states, infos = self.env.reset() - else: - states = next_states + else: + states = next_states diff --git a/skrl/trainers/jax/step.py b/skrl/trainers/jax/step.py index ae7e5986..2e27b0de 100644 --- a/skrl/trainers/jax/step.py +++ b/skrl/trainers/jax/step.py @@ -20,6 +20,7 @@ "headless": False, # whether to use headless mode (no rendering) "disable_progressbar": False, # whether 
to disable the progressbar. If None, disable on non-TTY "close_environment_at_exit": True, # whether to close the environment on normal program termination + "environment_info": "episode", # key used to get and log environment info } # [end-config-dict-jax] @@ -95,82 +96,56 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar, file=sys.stdout) self._progress.update(n=1) + # hack to simplify code + if self.num_simultaneous_agents == 1: + self.agents = [self.agents] + # set running mode - if self.num_simultaneous_agents > 1: - for agent in self.agents: - agent.set_running_mode("train") - else: - self.agents.set_running_mode("train") + for agent in self.agents: + agent.set_running_mode("train") # reset env if self.states is None: self.states, infos = self.env.reset() - if self.num_simultaneous_agents == 1: - # pre-interaction - self.agents.pre_interaction(timestep=timestep, timesteps=timesteps) - - # compute actions - with contextlib.nullcontext(): - actions = self.agents.act(self.states, timestep=timestep, timesteps=timesteps)[0] - - else: - # pre-interaction - for agent in self.agents: - agent.pre_interaction(timestep=timestep, timesteps=timesteps) + # pre-interaction + for agent in self.agents: + agent.pre_interaction(timestep=timestep, timesteps=timesteps) + with contextlib.nullcontext(): # compute actions - with contextlib.nullcontext(): - actions = jnp.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) + actions = jnp.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ + for agent, scope in zip(self.agents, self.agents_scope)]) - # step the environments - next_states, rewards, terminated, truncated, infos = self.env.step(actions) + # step the environments + next_states, rewards, terminated, truncated, infos = self.env.step(actions) - # render scene - if not self.headless: - self.env.render() + # render scene + if not self.headless: + self.env.render() - if self.num_simultaneous_agents == 1: - # record the environments' transitions - with contextlib.nullcontext(): - self.agents.record_transition(states=self.states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=timesteps) - - # post-interaction - self.agents.post_interaction(timestep=timestep, timesteps=timesteps) - - else: # record the environments' transitions - with contextlib.nullcontext(): - for agent, scope in zip(self.agents, self.agents_scope): - agent.record_transition(states=self.states[scope[0]:scope[1]], - actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=timesteps) - - # post-interaction - for agent in self.agents: - agent.post_interaction(timestep=timestep, timesteps=timesteps) + for agent, scope in zip(self.agents, self.agents_scope): + agent.record_transition(states=self.states[scope[0]:scope[1]], + actions=actions[scope[0]:scope[1]], + rewards=rewards[scope[0]:scope[1]], + next_states=next_states[scope[0]:scope[1]], + terminated=terminated[scope[0]:scope[1]], + truncated=truncated[scope[0]:scope[1]], + infos=infos, + timestep=timestep, + timesteps=timesteps) 
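Across the torch and JAX trainers, these hunks converge on a single per-step skeleton: the agent's `pre_interaction` is now also called during evaluation, the no-grad guard (`torch.no_grad()` in torch, the placeholder `contextlib.nullcontext()` in JAX) wraps only the interaction itself, and single-agent trainers reuse the multi-agent code path by wrapping the agent in a list. A condensed single-agent evaluation step in the torch flavor, assembled as a sketch from these hunks (not a drop-in replacement for any one method):

    import torch

    def evaluation_step(agent, env, states, timestep, timesteps, headless=True):
        # pre-interaction (now also performed during evaluation)
        agent.pre_interaction(timestep=timestep, timesteps=timesteps)
        with torch.no_grad():
            # compute actions
            actions = agent.act(states, timestep=timestep, timesteps=timesteps)[0]
            # step the environments
            next_states, rewards, terminated, truncated, infos = env.step(actions)
            # render scene
            if not headless:
                env.render()
            # write data to TensorBoard
            agent.record_transition(states=states, actions=actions, rewards=rewards,
                                    next_states=next_states, terminated=terminated,
                                    truncated=truncated, infos=infos,
                                    timestep=timestep, timesteps=timesteps)
        # post-interaction: call the base-class implementation so no training update runs
        super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps)
        # reset environments
        if terminated.any() or truncated.any():
            with torch.no_grad():
                states, infos = env.reset()
            return states
        return next_states

The new `environment_info` config key added in these files names the `infos` entry whose scalar values the trainers forward to `track_data` for logging.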
+ + # post-interaction + for agent in self.agents: + agent.post_interaction(timestep=timestep, timesteps=timesteps) # reset environments - with contextlib.nullcontext(): - if terminated.any() or truncated.any(): + if terminated.any() or truncated.any(): + with contextlib.nullcontext(): self.states, infos = self.env.reset() - else: - self.states = next_states + else: + self.states = next_states return next_states, rewards, terminated, truncated, infos @@ -205,66 +180,55 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar, file=sys.stdout) self._progress.update(n=1) + # hack to simplify code + if self.num_simultaneous_agents == 1: + self.agents = [self.agents] + # set running mode - if self.num_simultaneous_agents > 1: - for agent in self.agents: - agent.set_running_mode("eval") - else: - self.agents.set_running_mode("eval") + for agent in self.agents: + agent.set_running_mode("eval") # reset env if self.states is None: self.states, infos = self.env.reset() - with contextlib.nullcontext(): - if self.num_simultaneous_agents == 1: - # compute actions - actions = self.agents.act(self.states, timestep=timestep, timesteps=timesteps)[0] - - else: - # compute actions - actions = jnp.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) - - # step the environments - next_states, rewards, terminated, truncated, infos = self.env.step(actions) - - # render scene - if not self.headless: - self.env.render() + # pre-interaction + for agent in self.agents: + agent.pre_interaction(timestep=timestep, timesteps=timesteps) with contextlib.nullcontext(): - if self.num_simultaneous_agents == 1: - # write data to TensorBoard - self.agents.record_transition(states=self.states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=timesteps) - super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=timesteps) - - else: - # write data to TensorBoard - for agent, scope in zip(self.agents, self.agents_scope): - agent.record_transition(states=self.states[scope[0]:scope[1]], - actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=timesteps) - super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps) - - # reset environments - if terminated.any() or truncated.any(): + # compute actions + actions = jnp.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ + for agent, scope in zip(self.agents, self.agents_scope)]) + + # step the environments + next_states, rewards, terminated, truncated, infos = self.env.step(actions) + + # render scene + if not self.headless: + self.env.render() + + # write data to TensorBoard + for agent, scope in zip(self.agents, self.agents_scope): + agent.record_transition(states=self.states[scope[0]:scope[1]], + actions=actions[scope[0]:scope[1]], + rewards=rewards[scope[0]:scope[1]], + next_states=next_states[scope[0]:scope[1]], + terminated=terminated[scope[0]:scope[1]], + truncated=truncated[scope[0]:scope[1]], + infos=infos, + timestep=timestep, + timesteps=timesteps) + + # post-interaction + for 
agent in self.agents: + super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps) + + # reset environments + if terminated.any() or truncated.any(): + with contextlib.nullcontext(): self.states, infos = self.env.reset() - else: - self.states = next_states + else: + self.states = next_states return next_states, rewards, terminated, truncated, infos diff --git a/skrl/trainers/torch/base.py b/skrl/trainers/torch/base.py index c4d008db..3d3c607e 100644 --- a/skrl/trainers/torch/base.py +++ b/skrl/trainers/torch/base.py @@ -174,8 +174,8 @@ def single_agent_train(self) -> None: # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) - # compute actions with torch.no_grad(): + # compute actions actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0] # step the environments @@ -233,8 +233,11 @@ def single_agent_eval(self) -> None: for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): - # compute actions + # pre-interaction + self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) + with torch.no_grad(): + # compute actions actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0] # step the environments @@ -254,7 +257,6 @@ def single_agent_eval(self) -> None: infos=infos, timestep=timestep, timesteps=self.timesteps) - super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps) # log environment info if self.environment_info in infos: @@ -262,6 +264,9 @@ def single_agent_eval(self) -> None: if isinstance(v, torch.Tensor) and v.numel() == 1: self.agents.track_data(f"Info / {k}", v.item()) + # post-interaction + super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps) + # reset environments if self.env.num_envs > 1: states = next_states @@ -297,8 +302,8 @@ def multi_agent_train(self) -> None: # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) - # compute actions with torch.no_grad(): + # compute actions actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0] # step the environments @@ -332,13 +337,13 @@ def multi_agent_train(self) -> None: self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps) # reset environments - with torch.no_grad(): - if not self.env.agents: + if not self.env.agents: + with torch.no_grad(): states, infos = self.env.reset() shared_states = self.env.state() - else: - states = next_states - shared_states = shared_next_states + else: + states = next_states + shared_states = shared_next_states def multi_agent_eval(self) -> None: """Evaluate multi-agents @@ -359,8 +364,11 @@ def multi_agent_eval(self) -> None: for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): - # compute actions + # pre-interaction + self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) + with torch.no_grad(): + # compute actions actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0] # step the environments @@ -383,7 +391,6 @@ def multi_agent_eval(self) -> None: infos=infos, timestep=timestep, timesteps=self.timesteps) - super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps) # log environment info if self.environment_info in infos: @@ -391,10 +398,14 @@ def multi_agent_eval(self) -> None: if 
isinstance(v, torch.Tensor) and v.numel() == 1: self.agents.track_data(f"Info / {k}", v.item()) - # reset environments - if not self.env.agents: + # post-interaction + super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps) + + # reset environments + if not self.env.agents: + with torch.no_grad(): states, infos = self.env.reset() shared_states = self.env.state() - else: - states = next_states - shared_states = shared_next_states + else: + states = next_states + shared_states = shared_next_states diff --git a/skrl/trainers/torch/parallel.py b/skrl/trainers/torch/parallel.py index 1a086e0e..cc221255 100644 --- a/skrl/trainers/torch/parallel.py +++ b/skrl/trainers/torch/parallel.py @@ -343,6 +343,11 @@ def eval(self) -> None: for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + # pre-interaction + for pipe in producer_pipes: + pipe.send({"task": "pre_interaction", "timestep": timestep, "timesteps": self.timesteps}) + barrier.wait() + # compute actions with torch.no_grad(): for pipe, queue in zip(producer_pipes, queues): @@ -369,18 +374,20 @@ def eval(self) -> None: if not truncated.is_cuda: truncated.share_memory_() - for pipe, queue in zip(producer_pipes, queues): - pipe.send({"task": "eval-record_transition-post_interaction", - "timestep": timestep, - "timesteps": self.timesteps}) - queue.put(rewards) - queue.put(next_states) - queue.put(terminated) - queue.put(truncated) - queue.put(infos) - barrier.wait() + # post-interaction + for pipe, queue in zip(producer_pipes, queues): + pipe.send({"task": "eval-record_transition-post_interaction", + "timestep": timestep, + "timesteps": self.timesteps}) + queue.put(rewards) + queue.put(next_states) + queue.put(terminated) + queue.put(truncated) + queue.put(infos) + barrier.wait() - # reset environments + # reset environments + with torch.no_grad(): if terminated.any() or truncated.any(): states, infos = self.env.reset() if not states.is_cuda: diff --git a/skrl/trainers/torch/sequential.py b/skrl/trainers/torch/sequential.py index 67b43ca7..9921faa2 100644 --- a/skrl/trainers/torch/sequential.py +++ b/skrl/trainers/torch/sequential.py @@ -93,8 +93,8 @@ def train(self) -> None: for agent in self.agents: agent.pre_interaction(timestep=timestep, timesteps=self.timesteps) - # compute actions with torch.no_grad(): + # compute actions actions = torch.vstack([agent.act(states[scope[0]:scope[1]], timestep=timestep, timesteps=self.timesteps)[0] \ for agent, scope in zip(self.agents, self.agents_scope)]) @@ -129,11 +129,11 @@ def train(self) -> None: agent.post_interaction(timestep=timestep, timesteps=self.timesteps) # reset environments - with torch.no_grad(): - if terminated.any() or truncated.any(): + if terminated.any() or truncated.any(): + with torch.no_grad(): states, infos = self.env.reset() - else: - states = next_states + else: + states = next_states def eval(self) -> None: """Evaluate the agents sequentially @@ -167,8 +167,12 @@ def eval(self) -> None: for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): - # compute actions + # pre-interaction + for agent in self.agents: + agent.pre_interaction(timestep=timestep, timesteps=self.timesteps) + with torch.no_grad(): + # compute actions actions = torch.vstack([agent.act(states[scope[0]:scope[1]], timestep=timestep, timesteps=self.timesteps)[0] \ for agent, scope in zip(self.agents, self.agents_scope)]) @@ -190,7 +194,6 
@@ def eval(self) -> None: infos=infos, timestep=timestep, timesteps=self.timesteps) - super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps) # log environment info if self.environment_info in infos: @@ -199,8 +202,13 @@ def eval(self) -> None: for agent in self.agents: agent.track_data(f"Info / {k}", v.item()) - # reset environments - if terminated.any() or truncated.any(): + # post-interaction + for agent in self.agents: + super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps) + + # reset environments + if terminated.any() or truncated.any(): + with torch.no_grad(): states, infos = self.env.reset() - else: - states = next_states + else: + states = next_states diff --git a/skrl/trainers/torch/step.py b/skrl/trainers/torch/step.py index 3bfd9acc..7405b23e 100644 --- a/skrl/trainers/torch/step.py +++ b/skrl/trainers/torch/step.py @@ -92,36 +92,27 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar, file=sys.stdout) self._progress.update(n=1) + # hack to simplify code + if self.num_simultaneous_agents == 1: + self.agents = [self.agents] + # set running mode - if self.num_simultaneous_agents > 1: - for agent in self.agents: - agent.set_running_mode("train") - else: - self.agents.set_running_mode("train") + for agent in self.agents: + agent.set_running_mode("train") # reset env if self.states is None: self.states, infos = self.env.reset() - if self.num_simultaneous_agents == 1: - # pre-interaction - self.agents.pre_interaction(timestep=timestep, timesteps=timesteps) - - # compute actions - with torch.no_grad(): - actions = self.agents.act(self.states, timestep=timestep, timesteps=timesteps)[0] - - else: - # pre-interaction - for agent in self.agents: - agent.pre_interaction(timestep=timestep, timesteps=timesteps) + # pre-interaction + for agent in self.agents: + agent.pre_interaction(timestep=timestep, timesteps=timesteps) + with torch.no_grad(): # compute actions - with torch.no_grad(): - actions = torch.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) + actions = torch.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ + for agent, scope in zip(self.agents, self.agents_scope)]) - with torch.no_grad(): # step the environments next_states, rewards, terminated, truncated, infos = self.env.step(actions) @@ -129,59 +120,35 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) if not self.headless: self.env.render() - if self.num_simultaneous_agents == 1: - with torch.no_grad(): - # record the environments' transitions - self.agents.record_transition(states=self.states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=timesteps) - - # log environment info - if self.environment_info in infos: - for k, v in infos[self.environment_info].items(): - if isinstance(v, torch.Tensor) and v.numel() == 1: - self.agents.track_data(f"Info / {k}", v.item()) - - # post-interaction - self.agents.post_interaction(timestep=timestep, timesteps=timesteps) - - else: - with torch.no_grad(): - # record the environments' transitions - for agent, scope in zip(self.agents, self.agents_scope): - agent.record_transition(states=self.states[scope[0]:scope[1]], - 
actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=timesteps) - - # log environment info - if self.environment_info in infos: - for k, v in infos[self.environment_info].items(): - if isinstance(v, torch.Tensor) and v.numel() == 1: - for agent in self.agents: - agent.track_data(f"Info / {k}", v.item()) - - # post-interaction - for agent in self.agents: - agent.post_interaction(timestep=timestep, timesteps=timesteps) + # record the environments' transitions + for agent, scope in zip(self.agents, self.agents_scope): + agent.record_transition(states=self.states[scope[0]:scope[1]], + actions=actions[scope[0]:scope[1]], + rewards=rewards[scope[0]:scope[1]], + next_states=next_states[scope[0]:scope[1]], + terminated=terminated[scope[0]:scope[1]], + truncated=truncated[scope[0]:scope[1]], + infos=infos, + timestep=timestep, + timesteps=timesteps) + + # log environment info + if self.environment_info in infos: + for k, v in infos[self.environment_info].items(): + if isinstance(v, torch.Tensor) and v.numel() == 1: + for agent in self.agents: + agent.track_data(f"Info / {k}", v.item()) + + # post-interaction + for agent in self.agents: + agent.post_interaction(timestep=timestep, timesteps=timesteps) # reset environments - with torch.no_grad(): - if terminated.any() or truncated.any(): + if terminated.any() or truncated.any(): + with torch.no_grad(): self.states, infos = self.env.reset() - else: - self.states = next_states + else: + self.states = next_states return next_states, rewards, terminated, truncated, infos @@ -215,26 +182,26 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar, file=sys.stdout) self._progress.update(n=1) + # hack to simplify code + if self.num_simultaneous_agents == 1: + self.agents = [self.agents] + # set running mode - if self.num_simultaneous_agents > 1: - for agent in self.agents: - agent.set_running_mode("eval") - else: - self.agents.set_running_mode("eval") + for agent in self.agents: + agent.set_running_mode("eval") # reset env if self.states is None: self.states, infos = self.env.reset() - with torch.no_grad(): - if self.num_simultaneous_agents == 1: - # compute actions - actions = self.agents.act(self.states, timestep=timestep, timesteps=timesteps)[0] + # pre-interaction + for agent in self.agents: + agent.pre_interaction(timestep=timestep, timesteps=timesteps) - else: - # compute actions - actions = torch.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) + with torch.no_grad(): + # compute actions + actions = torch.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ + for agent, scope in zip(self.agents, self.agents_scope)]) # step the environments next_states, rewards, terminated, truncated, infos = self.env.step(actions) @@ -243,50 +210,34 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) if not self.headless: self.env.render() - if self.num_simultaneous_agents == 1: - # write data to TensorBoard - self.agents.record_transition(states=self.states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - 
timestep=timestep, - timesteps=timesteps) - super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=timesteps) - - # log environment info - if self.environment_info in infos: - for k, v in infos[self.environment_info].items(): - if isinstance(v, torch.Tensor) and v.numel() == 1: - self.agents.track_data(f"Info / {k}", v.item()) - - else: - # write data to TensorBoard - for agent, scope in zip(self.agents, self.agents_scope): - agent.record_transition(states=self.states[scope[0]:scope[1]], - actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=timesteps) - super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps) - - # log environment info - if self.environment_info in infos: - for k, v in infos[self.environment_info].items(): - if isinstance(v, torch.Tensor) and v.numel() == 1: - for agent in self.agents: - agent.track_data(f"Info / {k}", v.item()) - - # reset environments - if terminated.any() or truncated.any(): + # write data to TensorBoard + for agent, scope in zip(self.agents, self.agents_scope): + agent.record_transition(states=self.states[scope[0]:scope[1]], + actions=actions[scope[0]:scope[1]], + rewards=rewards[scope[0]:scope[1]], + next_states=next_states[scope[0]:scope[1]], + terminated=terminated[scope[0]:scope[1]], + truncated=truncated[scope[0]:scope[1]], + infos=infos, + timestep=timestep, + timesteps=timesteps) + + # log environment info + if self.environment_info in infos: + for k, v in infos[self.environment_info].items(): + if isinstance(v, torch.Tensor) and v.numel() == 1: + for agent in self.agents: + agent.track_data(f"Info / {k}", v.item()) + + # post-interaction + for agent in self.agents: + super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps) + + # reset environments + if terminated.any() or truncated.any(): + with torch.no_grad(): self.states, infos = self.env.reset() - else: - self.states = next_states + else: + self.states = next_states return next_states, rewards, terminated, truncated, infos diff --git a/skrl/utils/model_instantiators/jax/categorical.py b/skrl/utils/model_instantiators/jax/categorical.py index ab37250c..844b4aa2 100644 --- a/skrl/utils/model_instantiators/jax/categorical.py +++ b/skrl/utils/model_instantiators/jax/categorical.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union import textwrap -import gym import gymnasium import flax.linen as nn # noqa @@ -11,10 +10,11 @@ from skrl.models.jax import CategoricalMixin # noqa from skrl.models.jax import Model # noqa from skrl.utils.model_instantiators.jax.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa -def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, +def categorical_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, unnormalized_log_prob: bool = True, network: Sequence[Mapping[str, Any]] = [], @@ -26,10 +26,10 @@ def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Spa :param 
observation_space: Observation/state space or shape (default: None). If it is not None, the num_observations property will contain the size of that space - :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional :param action_space: Action space or shape (default: None). If it is not None, the num_actions property will contain the size of that space - :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional @@ -84,6 +84,8 @@ def setup(self): {networks} def __call__(self, inputs, role): + states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) + taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) {forward} return output, {{}} """ diff --git a/skrl/utils/model_instantiators/jax/common.py b/skrl/utils/model_instantiators/jax/common.py index 1efce173..52d79e2b 100644 --- a/skrl/utils/model_instantiators/jax/common.py +++ b/skrl/utils/model_instantiators/jax/common.py @@ -52,7 +52,7 @@ def visit_Call(self, node: ast.Call): node.func = ast.Attribute(value=ast.Name("jnp"), attr="concatenate") node.keywords = [ast.keyword(arg="axis", value=ast.Constant(value=-1))] # operation: permute - if node.func.id == "permute": + elif node.func.id == "permute": node.func = ast.Attribute(value=ast.Name("jnp"), attr="permute_dims") return node @@ -61,11 +61,11 @@ def visit_Call(self, node: ast.Call): NodeTransformer().visit(tree) source = ast.unparse(tree) # enum substitutions - source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", 'jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)') - source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", 'jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)') - source = source.replace("Shape.STATES", "STATES").replace("STATES", 'inputs["states"]') - source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", 'inputs["states"]') - source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", 'inputs["taken_actions"]') + source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", "jnp.concatenate([states, taken_actions], axis=-1)") + source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", "jnp.concatenate([states, taken_actions], axis=-1)") + source = source.replace("Shape.STATES", "STATES").replace("STATES", "states") + source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", "states") + source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", "taken_actions") return source def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequence[str]], Sequence[str], int]: diff --git a/skrl/utils/model_instantiators/jax/deterministic.py b/skrl/utils/model_instantiators/jax/deterministic.py index effe602b..be7f9bd9 100644 --- a/skrl/utils/model_instantiators/jax/deterministic.py +++ b/skrl/utils/model_instantiators/jax/deterministic.py @@ -1,7 +1,6 @@ 
from typing import Any, Mapping, Optional, Sequence, Tuple, Union import textwrap -import gym import gymnasium import flax.linen as nn # noqa @@ -11,10 +10,11 @@ from skrl.models.jax import DeterministicMixin # noqa from skrl.models.jax import Model # noqa from skrl.utils.model_instantiators.jax.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa -def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, +def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, clip_actions: bool = False, network: Sequence[Mapping[str, Any]] = [], @@ -26,10 +26,10 @@ def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.S :param observation_space: Observation/state space or shape (default: None). If it is not None, the num_observations property will contain the size of that space - :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional :param action_space: Action space or shape (default: None). If it is not None, the num_actions property will contain the size of that space - :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional @@ -81,6 +81,8 @@ def setup(self): {networks} def __call__(self, inputs, role): + states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) + taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) {forward} return output, {{}} """ diff --git a/skrl/utils/model_instantiators/jax/gaussian.py b/skrl/utils/model_instantiators/jax/gaussian.py index 2965ac16..865e7ff3 100644 --- a/skrl/utils/model_instantiators/jax/gaussian.py +++ b/skrl/utils/model_instantiators/jax/gaussian.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union import textwrap -import gym import gymnasium import flax.linen as nn # noqa @@ -11,10 +10,11 @@ from skrl.models.jax import GaussianMixin # noqa from skrl.models.jax import Model # noqa from skrl.utils.model_instantiators.jax.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa -def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, +def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, jax.Device]] = None, clip_actions: bool = False, clip_log_std: bool = True, @@ -30,10 +30,10 @@ def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, :param observation_space: Observation/state space or shape (default: None). 
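The `if` to `elif` change in both `common.py` node transformers (JAX above, torch below) fixes a latent crash rather than a style issue: once the `concatenate` branch rewrites `node.func` from an `ast.Name` to an `ast.Attribute`, re-evaluating `node.func.id` in a second independent `if` would raise `AttributeError`, because `ast.Attribute` has no `id` field. A standard-library illustration:

    import ast

    node = ast.parse("concatenate(x, y)", mode="eval").body  # an ast.Call whose func is an ast.Name
    node.func = ast.Attribute(value=ast.Name("jnp"), attr="concatenate", ctx=ast.Load())
    isinstance(node.func, ast.Attribute)  # True: node.func.id would now raise AttributeError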
If it is not None, the num_observations property will contain the size of that space - :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional :param action_space: Action space or shape (default: None). If it is not None, the num_actions property will contain the size of that space - :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or jax.Device, optional @@ -95,6 +95,8 @@ def setup(self): self.log_std_parameter = self.param("log_std_parameter", lambda _: {initial_log_std} * jnp.ones({output["size"]})) def __call__(self, inputs, role): + states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) + taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) {forward} return output, self.log_std_parameter, {{}} """ diff --git a/skrl/utils/model_instantiators/torch/categorical.py b/skrl/utils/model_instantiators/torch/categorical.py index 3be54bb0..bb930211 100644 --- a/skrl/utils/model_instantiators/torch/categorical.py +++ b/skrl/utils/model_instantiators/torch/categorical.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union import textwrap -import gym import gymnasium import torch @@ -10,10 +9,11 @@ from skrl.models.torch import CategoricalMixin # noqa from skrl.models.torch import Model from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa -def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, +def categorical_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, unnormalized_log_prob: bool = True, network: Sequence[Mapping[str, Any]] = [], @@ -25,10 +25,10 @@ def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Spa :param observation_space: Observation/state space or shape (default: None). If it is not None, the num_observations property will contain the size of that space - :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional :param action_space: Action space or shape (default: None). If it is not None, the num_actions property will contain the size of that space - :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -82,6 +82,8 @@ def __init__(self, observation_space, action_space, device, unnormalized_log_pro {networks} def compute(self, inputs, role=""): + states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) + taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) {forward} return output, {{}} """ diff --git a/skrl/utils/model_instantiators/torch/common.py b/skrl/utils/model_instantiators/torch/common.py index 331427b9..cde885d3 100644 --- a/skrl/utils/model_instantiators/torch/common.py +++ b/skrl/utils/model_instantiators/torch/common.py @@ -53,7 +53,7 @@ def visit_Call(self, node: ast.Call): node.func = ast.Attribute(value=ast.Name("torch"), attr="cat") node.keywords = [ast.keyword(arg="dim", value=ast.Constant(value=1))] # operation: permute - if node.func.id == "permute": + elif node.func.id == "permute": node.func = ast.Attribute(value=ast.Name("torch"), attr="permute") return node @@ -62,11 +62,11 @@ def visit_Call(self, node: ast.Call): NodeTransformer().visit(tree) source = ast.unparse(tree) # enum substitutions - source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", 'torch.cat((inputs["states"], inputs["taken_actions"]), dim=1)') - source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", 'torch.cat((inputs["states"], inputs["taken_actions"]), dim=1)') - source = source.replace("Shape.STATES", "STATES").replace("STATES", 'inputs["states"]') - source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", 'inputs["states"]') - source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", 'inputs["taken_actions"]') + source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", "torch.cat([states, taken_actions], dim=1)") + source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", "torch.cat([states, taken_actions], dim=1)") + source = source.replace("Shape.STATES", "STATES").replace("STATES", "states") + source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", "states") + source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", "taken_actions") return source def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequence[str]], Sequence[str], int]: diff --git a/skrl/utils/model_instantiators/torch/deterministic.py b/skrl/utils/model_instantiators/torch/deterministic.py index ca67c9b6..440223cd 100644 --- a/skrl/utils/model_instantiators/torch/deterministic.py +++ b/skrl/utils/model_instantiators/torch/deterministic.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union import textwrap -import gym import gymnasium import torch @@ -10,10 +9,11 @@ from skrl.models.torch import DeterministicMixin # noqa from skrl.models.torch import Model from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa -def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, +def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: 
Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, clip_actions: bool = False, network: Sequence[Mapping[str, Any]] = [], @@ -25,10 +25,10 @@ def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.S :param observation_space: Observation/state space or shape (default: None). If it is not None, the num_observations property will contain the size of that space - :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional :param action_space: Action space or shape (default: None). If it is not None, the num_actions property will contain the size of that space - :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -79,6 +79,8 @@ def __init__(self, observation_space, action_space, device, clip_actions): {networks} def compute(self, inputs, role=""): + states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) + taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) {forward} return output, {{}} """ diff --git a/skrl/utils/model_instantiators/torch/gaussian.py b/skrl/utils/model_instantiators/torch/gaussian.py index b3e5cfce..b806cab0 100644 --- a/skrl/utils/model_instantiators/torch/gaussian.py +++ b/skrl/utils/model_instantiators/torch/gaussian.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union import textwrap -import gym import gymnasium import torch @@ -10,10 +9,11 @@ from skrl.models.torch import GaussianMixin # noqa from skrl.models.torch import Model from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa -def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, +def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, clip_actions: bool = False, clip_log_std: bool = True, @@ -29,10 +29,10 @@ def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, :param observation_space: Observation/state space or shape (default: None). If it is not None, the num_observations property will contain the size of that space - :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional :param action_space: Action space or shape (default: None). If it is not None, the num_actions property will contain the size of that space - :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -93,6 +93,8 @@ def __init__(self, observation_space, action_space, device, clip_actions, self.log_std_parameter = nn.Parameter({initial_log_std} * torch.ones({output["size"]})) def compute(self, inputs, role=""): + states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) + taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) {forward} return output, self.log_std_parameter, {{}} """ diff --git a/skrl/utils/model_instantiators/torch/multivariate_gaussian.py b/skrl/utils/model_instantiators/torch/multivariate_gaussian.py index e6d46f00..b7172cc9 100644 --- a/skrl/utils/model_instantiators/torch/multivariate_gaussian.py +++ b/skrl/utils/model_instantiators/torch/multivariate_gaussian.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union import textwrap -import gym import gymnasium import torch @@ -10,10 +9,11 @@ from skrl.models.torch import MultivariateGaussianMixin # noqa from skrl.models.torch import Model from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa -def multivariate_gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, +def multivariate_gaussian_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, clip_actions: bool = False, clip_log_std: bool = True, @@ -29,10 +29,10 @@ def multivariate_gaussian_model(observation_space: Optional[Union[int, Tuple[int :param observation_space: Observation/state space or shape (default: None). If it is not None, the num_observations property will contain the size of that space - :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional :param action_space: Action space or shape (default: None). If it is not None, the num_actions property will contain the size of that space - :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -93,6 +93,8 @@ def __init__(self, observation_space, action_space, device, clip_actions, self.log_std_parameter = nn.Parameter({initial_log_std} * torch.ones({output["size"]})) def compute(self, inputs, role=""): + states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) + taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) {forward} return output, self.log_std_parameter, {{}} """ diff --git a/skrl/utils/model_instantiators/torch/shared.py b/skrl/utils/model_instantiators/torch/shared.py index 861eea93..a30b2efa 100644 --- a/skrl/utils/model_instantiators/torch/shared.py +++ b/skrl/utils/model_instantiators/torch/shared.py @@ -1,7 +1,6 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union import textwrap -import gym import gymnasium import torch @@ -10,10 +9,11 @@ from skrl.models.torch import Model # noqa from skrl.models.torch import DeterministicMixin, GaussianMixin # noqa from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa -def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, +def shared_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, device: Optional[Union[str, torch.device]] = None, structure: str = "", roles: Sequence[str] = [], @@ -24,10 +24,10 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, g :param observation_space: Observation/state space or shape (default: None). If it is not None, the num_observations property will contain the size of that space - :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional :param action_space: Action space or shape (default: None). If it is not None, the num_actions property will contain the size of that space - :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional + :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
If None, the device will be either ``"cuda"`` if available or ``"cpu"`` :type device: str or torch.device, optional @@ -68,6 +68,8 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, g for container in containers_gaussian: networks_common.append(f'self.{container["name"]}_container = {container["sequential"]}') forward_common.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})') + forward_common.insert(0, 'taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))') + forward_common.insert(0, 'states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))') # process output networks_gaussian = [] diff --git a/skrl/utils/runner/jax/runner.py b/skrl/utils/runner/jax/runner.py index 73e17ff3..27ce4f7b 100644 --- a/skrl/utils/runner/jax/runner.py +++ b/skrl/utils/runner/jax/runner.py @@ -14,7 +14,7 @@ from skrl.resources.schedulers.jax import KLAdaptiveLR # noqa from skrl.trainers.jax import SequentialTrainer, Trainer from skrl.utils import set_seed -from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model +from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model class Runner: @@ -35,6 +35,7 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, self._class_mapping = { # model "gaussianmixin": gaussian_model, + "categoricalmixin": categorical_model, "deterministicmixin": deterministic_model, "shared": None, # memory diff --git a/skrl/utils/runner/torch/runner.py b/skrl/utils/runner/torch/runner.py index 7f7d32c4..5e4f6b52 100644 --- a/skrl/utils/runner/torch/runner.py +++ b/skrl/utils/runner/torch/runner.py @@ -14,7 +14,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa from skrl.trainers.torch import SequentialTrainer, Trainer from skrl.utils import set_seed -from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model +from skrl.utils.model_instantiators.torch import categorical_model, deterministic_model, gaussian_model, shared_model class Runner: @@ -35,6 +35,7 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, self._class_mapping = { # model "gaussianmixin": gaussian_model, + "categoricalmixin": categorical_model, "deterministicmixin": deterministic_model, "shared": shared_model, # memory diff --git a/skrl/utils/spaces/__init__.py b/skrl/utils/spaces/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/skrl/utils/spaces/jax/__init__.py b/skrl/utils/spaces/jax/__init__.py new file mode 100644 index 00000000..9e8de363 --- /dev/null +++ b/skrl/utils/spaces/jax/__init__.py @@ -0,0 +1,9 @@ +from skrl.utils.spaces.jax.spaces import ( + compute_space_size, + convert_gym_space, + flatten_tensorized_space, + sample_space, + tensorize_space, + unflatten_tensorized_space, + untensorize_space +) diff --git a/skrl/utils/spaces/jax/spaces.py b/skrl/utils/spaces/jax/spaces.py new file mode 100644 index 00000000..663d1fb7 --- /dev/null +++ b/skrl/utils/spaces/jax/spaces.py @@ -0,0 +1,315 @@ +from typing import Any, Literal, Optional, Sequence, Union + +import gymnasium +from gymnasium import spaces + +import jax +import jax.numpy as jnp +import numpy as np + +from skrl import config + + +def convert_gym_space(space: "gym.Space", squeeze_batch_dimension: bool = False) -> gymnasium.Space: + """Converts a gym space to a gymnasium space. + + :param space: Gym space to convert to. 
+ :param squeeze_batch_dimension: Whether to remove fundamental spaces' first dimension. + It currently affects ``Box`` space only. + + :raises ValueError: The given space is not supported. + + :return: Converted space. + """ + import gym + + if isinstance(space, gym.spaces.Discrete): + return spaces.Discrete(n=space.n) + elif isinstance(space, gym.spaces.Box): + if squeeze_batch_dimension: + return spaces.Box(low=space.low[0], high=space.high[0], shape=space.shape[1:], dtype=space.dtype) + return spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype) + elif isinstance(space, gym.spaces.MultiDiscrete): + return spaces.MultiDiscrete(nvec=space.nvec) + elif isinstance(space, gym.spaces.Tuple): + return spaces.Tuple(spaces=tuple(map(convert_gym_space, space.spaces))) + elif isinstance(space, gym.spaces.Dict): + return spaces.Dict(spaces={k: convert_gym_space(v) for k, v in space.spaces.items()}) + raise ValueError(f"Unsupported space ({space})") + +def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, jax.Device]] = None, _jax: bool = True) -> Any: + """Convert the sample/value items of a given gymnasium space to JAX array. + + :param space: Gymnasium space. + :param x: Sample/value of the given space to tensorize to. + :param device: Device on which a tensor/array is or will be allocated (default: ``None``). + This parameter is used when the space value is not a JAX array (e.g.: NumPy array, number). + :param _jax: Whether the converted value should be a JAX array. It only affects NumPy space values. + + :raises ValueError: The given space or the sample/value type is not supported. + + :return: Sample/value space with items converted to tensors. + """ + if x is None: + return None + # fundamental spaces + # Box + if isinstance(space, spaces.Box): + if isinstance(x, jax.Array): + return x.reshape(-1, *space.shape) + elif isinstance(x, np.ndarray): + if _jax: + return jax.device_put(x.reshape(-1, *space.shape), config.jax.parse_device(device)) + return x.reshape(-1, *space.shape) + else: + raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})") + # Discrete + elif isinstance(space, spaces.Discrete): + if isinstance(x, jax.Array): + return x.reshape(-1, 1) + elif isinstance(x, np.ndarray): + if _jax: + return jax.device_put(x.reshape(-1, 1), config.jax.parse_device(device)) + return x.reshape(-1, 1) + elif isinstance(x, np.number) or type(x) in [int, float]: + if _jax: + return jnp.array([x], device=device, dtype=jnp.int32).reshape(-1, 1) + return np.array([x], dtype=np.int32).reshape(-1, 1) + else: + raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})") + # MultiDiscrete + elif isinstance(space, spaces.MultiDiscrete): + if isinstance(x, jax.Array): + return x.reshape(-1, *space.shape) + elif isinstance(x, np.ndarray): + if _jax: + return jax.device_put(x.reshape(-1, *space.shape), config.jax.parse_device(device)) + return x.reshape(-1, *space.shape) + elif type(x) in [list, tuple]: + if _jax: + return jnp.array(x, device=device, dtype=jnp.int32).reshape(-1, *space.shape) + return np.array(x, dtype=np.int32).reshape(-1, *space.shape) + else: + raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})") + # composite spaces + # Dict + elif isinstance(space, spaces.Dict): + return {k: tensorize_space(s, x[k], device) for k, s in space.items()} + # Tuple + elif isinstance(space, spaces.Tuple): + return tuple([tensorize_space(s, _x, device) for s, _x in zip(space, x)]) + raise 
ValueError(f"Unsupported space ({space})") + +def untensorize_space(space: spaces.Space, x: Any, squeeze_batch_dimension: bool = True) -> Any: + """Convert a tensorized space to a gymnasium space with expected sample/value item types. + + :param space: Gymnasium space. + :param x: Tensorized space (Sample/value space where items are tensors). + :param squeeze_batch_dimension: Whether to remove the batch dimension. If True, only the + sample/value with a batch dimension of size 1 will be affected + + :raises ValueError: The given space or the sample/value type is not supported. + + :return: Sample/value space with expected item types. + """ + if x is None: + return None + # fundamental spaces + # Box + if isinstance(space, spaces.Box): + if isinstance(x, jax.Array): + array = np.asarray(jax.device_get(x), dtype=space.dtype) + if squeeze_batch_dimension and array.shape[0] == 1: + return array.reshape(space.shape) + return array.reshape(-1, *space.shape) + elif isinstance(x, np.ndarray): + if squeeze_batch_dimension and x.shape[0] == 1: + return x.reshape(space.shape) + return x.reshape(-1, *space.shape) + raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})") + # Discrete + elif isinstance(space, spaces.Discrete): + if isinstance(x, jax.Array): + array = np.asarray(jax.device_get(x), dtype=space.dtype) + if squeeze_batch_dimension and array.shape[0] == 1: + return array.item() + return array.reshape(-1, 1) + elif isinstance(x, np.ndarray): + if squeeze_batch_dimension and x.shape[0] == 1: + return x.item() + return x.reshape(-1, 1) + raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})") + # MultiDiscrete + elif isinstance(space, spaces.MultiDiscrete): + if isinstance(x, jax.Array): + array = np.asarray(jax.device_get(x), dtype=space.dtype) + if squeeze_batch_dimension and array.shape[0] == 1: + return array.reshape(space.nvec.shape) + return array.reshape(-1, *space.nvec.shape) + elif isinstance(x, np.ndarray): + if squeeze_batch_dimension and x.shape[0] == 1: + return x.reshape(space.nvec.shape) + return x.reshape(-1, *space.nvec.shape) + raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})") + # composite spaces + # Dict + elif isinstance(space, spaces.Dict): + return {k: untensorize_space(s, x[k], squeeze_batch_dimension) for k, s in space.items()} + # Tuple + elif isinstance(space, spaces.Tuple): + return tuple([untensorize_space(s, _x, squeeze_batch_dimension) for s, _x in zip(space, x)]) + raise ValueError(f"Unsupported space ({space})") + +def flatten_tensorized_space(x: Any, _jax: bool = True) -> jax.Array: + """Flatten a tensorized space. + + :param x: Tensorized space sample/value. + :param _jax: Whether the space should be handled using JAX operations. + It only affects composite spaces. + + :raises ValueError: The given sample/value type is not supported. + + :return: A tensor. The returned tensor will have shape (batch, space size). 
+ """ + # fundamental spaces + # Box / Discrete / MultiDiscrete + if isinstance(x, (jax.Array, np.ndarray)): + return x.reshape(x.shape[0], -1) if x.ndim > 1 else x.reshape(1, -1) + # composite spaces + # Dict + elif isinstance(x, dict): + if _jax: + return jnp.concatenate([flatten_tensorized_space(x[k]) for k in sorted(x.keys())], axis=-1) + return np.concatenate([flatten_tensorized_space(x[k]) for k in sorted(x.keys())], axis=-1) + # Tuple + elif type(x) in [list, tuple]: + if _jax: + return jnp.concatenate([flatten_tensorized_space(_x) for _x in x], axis=-1) + return np.concatenate([flatten_tensorized_space(_x) for _x in x], axis=-1) + raise ValueError(f"Unsupported sample/value type ({type(x)})") + +def unflatten_tensorized_space(space: Union[spaces.Space, Sequence[int], int], x: jax.Array) -> Any: + """Unflatten a tensor to create a tensorized space. + + :param space: Gymnasium space. + :param x: A tensor with shape (batch, space size). + + :raises ValueError: The given space is not supported. + + :return: Tensorized space value. + """ + if x is None: + return None + # fundamental spaces + # Box + if isinstance(space, spaces.Box): + return x.reshape(-1, *space.shape) + # Discrete + elif isinstance(space, spaces.Discrete): + return x.reshape(-1, 1) + # MultiDiscrete + elif isinstance(space, spaces.MultiDiscrete): + return x.reshape(-1, *space.shape) + # composite spaces + # Dict + elif isinstance(space, spaces.Dict): + start = 0 + output = {} + for k in sorted(space.keys()): + end = start + compute_space_size(space[k], occupied_size=True) + output[k] = unflatten_tensorized_space(space[k], x[:, start:end]) + start = end + return output + # Tuple + elif isinstance(space, spaces.Tuple): + start = 0 + output = [] + for s in space: + end = start + compute_space_size(s, occupied_size=True) + output.append(unflatten_tensorized_space(s, x[:, start:end])) + start = end + return output + raise ValueError(f"Unsupported space ({space})") + +def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_size: bool = False) -> int: + """Get the size (number of elements) of a space. + + :param space: Gymnasium space. + :param occupied_size: Whether the number of elements occupied by the space is returned (default: ``False``). + It only affects :py:class:`~gymnasium.spaces.Discrete` (occupied space is 1), + and :py:class:`~gymnasium.spaces.MultiDiscrete` (occupied space is the number of discrete spaces). + + :return: Size of the space (number of elements). + """ + if occupied_size: + # fundamental spaces + # Discrete + if isinstance(space, spaces.Discrete): + return 1 + # MultiDiscrete + elif isinstance(space, spaces.MultiDiscrete): + return space.nvec.shape[0] + # composite spaces + # Dict + elif isinstance(space, spaces.Dict): + return sum([compute_space_size(s, occupied_size) for s in space.values()]) + # Tuple + elif isinstance(space, spaces.Tuple): + return sum([compute_space_size(s, occupied_size) for s in space]) + # non-gymnasium spaces + if type(space) in [int, float]: + return space + elif type(space) in [tuple, list]: + return int(np.prod(space)) + # gymnasium computation + return gymnasium.spaces.flatdim(space) + +def sample_space(space: spaces.Space, batch_size: int = 1, backend: str = Literal["numpy", "jax"], device = None) -> Any: + """Generates a random sample from the specified space. + + :param space: Gymnasium space. + :param batch_size: Size of the sampled batch (default: ``1``). 
+    :param backend: Backend used to construct the fundamental space samples (default: ``"numpy"``).
+    :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
+                   This parameter is used when the backend is ``"jax"``.
+
+    :raises ValueError: The given space or backend is not supported.
+
+    :return: Sample of the space.
+    """
+    # fundamental spaces
+    # Box
+    if isinstance(space, spaces.Box):
+        sample = gymnasium.vector.utils.batch_space(space, batch_size).sample()
+        if backend == "numpy":
+            return np.array(sample).reshape(batch_size, *space.shape)
+        elif backend == "jax":
+            return jnp.array(sample, device=device).reshape(batch_size, *space.shape)
+        else:
+            raise ValueError(f"Unsupported backend type ({backend})")
+    # Discrete
+    elif isinstance(space, spaces.Discrete):
+        sample = gymnasium.vector.utils.batch_space(space, batch_size).sample()
+        if backend == "numpy":
+            return np.array(sample).reshape(batch_size, -1)
+        elif backend == "jax":
+            return jnp.array(sample, device=device).reshape(batch_size, -1)
+        else:
+            raise ValueError(f"Unsupported backend type ({backend})")
+    # MultiDiscrete
+    elif isinstance(space, spaces.MultiDiscrete):
+        sample = gymnasium.vector.utils.batch_space(space, batch_size).sample()
+        if backend == "numpy":
+            return np.array(sample).reshape(batch_size, *space.nvec.shape)
+        elif backend == "jax":
+            return jnp.array(sample, device=device).reshape(batch_size, *space.nvec.shape)
+        else:
+            raise ValueError(f"Unsupported backend type ({backend})")
+    # composite spaces
+    # Dict
+    elif isinstance(space, spaces.Dict):
+        return {k: sample_space(s, batch_size, backend, device) for k, s in space.items()}
+    # Tuple
+    elif isinstance(space, spaces.Tuple):
+        return tuple([sample_space(s, batch_size, backend, device) for s in space])
+    raise ValueError(f"Unsupported space ({space})")
diff --git a/skrl/utils/spaces/torch/__init__.py b/skrl/utils/spaces/torch/__init__.py
new file mode 100644
index 00000000..62413382
--- /dev/null
+++ b/skrl/utils/spaces/torch/__init__.py
@@ -0,0 +1,9 @@
+from skrl.utils.spaces.torch.spaces import (
+    compute_space_size,
+    convert_gym_space,
+    flatten_tensorized_space,
+    sample_space,
+    tensorize_space,
+    unflatten_tensorized_space,
+    untensorize_space
+)
diff --git a/skrl/utils/spaces/torch/spaces.py b/skrl/utils/spaces/torch/spaces.py
new file mode 100644
index 00000000..579bf33c
--- /dev/null
+++ b/skrl/utils/spaces/torch/spaces.py
@@ -0,0 +1,283 @@
+from typing import Any, Literal, Optional, Sequence, Union
+
+import gymnasium
+from gymnasium import spaces
+
+import numpy as np
+import torch
+
+
+def convert_gym_space(space: "gym.Space", squeeze_batch_dimension: bool = False) -> gymnasium.Space:
+    """Converts a gym space to a gymnasium space.
+
+    :param space: Gym space to convert.
+    :param squeeze_batch_dimension: Whether to remove fundamental spaces' first dimension.
+                                    It currently affects ``Box`` space only.
+
+    :raises ValueError: The given space is not supported.
+
+    :return: Converted space.
+ """ + import gym + + if isinstance(space, gym.spaces.Discrete): + return spaces.Discrete(n=space.n) + elif isinstance(space, gym.spaces.Box): + if squeeze_batch_dimension: + return spaces.Box(low=space.low[0], high=space.high[0], shape=space.shape[1:], dtype=space.dtype) + return spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype) + elif isinstance(space, gym.spaces.MultiDiscrete): + return spaces.MultiDiscrete(nvec=space.nvec) + elif isinstance(space, gym.spaces.Tuple): + return spaces.Tuple(spaces=tuple(map(convert_gym_space, space.spaces))) + elif isinstance(space, gym.spaces.Dict): + return spaces.Dict(spaces={k: convert_gym_space(v) for k, v in space.spaces.items()}) + raise ValueError(f"Unsupported space ({space})") + +def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, torch.device]] = None) -> Any: + """Convert the sample/value items of a given gymnasium space to PyTorch tensors. + + :param space: Gymnasium space. + :param x: Sample/value of the given space to tensorize to. + :param device: Device on which a tensor/array is or will be allocated (default: ``None``). + This parameter is used when the space value is not a PyTorch tensor (e.g.: NumPy array, number). + + :raises ValueError: The given space or the sample/value type is not supported. + + :return: Sample/value space with items converted to tensors. + """ + if x is None: + return None + # fundamental spaces + # Box + if isinstance(space, spaces.Box): + if isinstance(x, torch.Tensor): + return x.reshape(-1, *space.shape) + elif isinstance(x, np.ndarray): + return torch.tensor(x, device=device, dtype=torch.float32).reshape(-1, *space.shape) + else: + raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})") + # Discrete + elif isinstance(space, spaces.Discrete): + if isinstance(x, torch.Tensor): + return x.reshape(-1, 1) + elif isinstance(x, np.ndarray): + return torch.tensor(x, device=device, dtype=torch.int32).reshape(-1, 1) + elif isinstance(x, np.number) or type(x) in [int, float]: + return torch.tensor([x], device=device, dtype=torch.int32).reshape(-1, 1) + else: + raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})") + # MultiDiscrete + elif isinstance(space, spaces.MultiDiscrete): + if isinstance(x, torch.Tensor): + return x.reshape(-1, *space.shape) + elif isinstance(x, np.ndarray): + return torch.tensor(x, device=device, dtype=torch.int32).reshape(-1, *space.shape) + elif type(x) in [list, tuple]: + return torch.tensor([x], device=device, dtype=torch.int32).reshape(-1, *space.shape) + else: + raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})") + # composite spaces + # Dict + elif isinstance(space, spaces.Dict): + return {k: tensorize_space(s, x[k], device) for k, s in space.items()} + # Tuple + elif isinstance(space, spaces.Tuple): + return tuple([tensorize_space(s, _x, device) for s, _x in zip(space, x)]) + raise ValueError(f"Unsupported space ({space})") + +def untensorize_space(space: spaces.Space, x: Any, squeeze_batch_dimension: bool = True) -> Any: + """Convert a tensorized space to a gymnasium space with expected sample/value item types. + + :param space: Gymnasium space. + :param x: Tensorized space (Sample/value space where items are tensors). + :param squeeze_batch_dimension: Whether to remove the batch dimension. If True, only the + sample/value with a batch dimension of size 1 will be affected + + :raises ValueError: The given space or the sample/value type is not supported. 
+
+    :return: Sample/value space with expected item types.
+    """
+    if x is None:
+        return None
+    # fundamental spaces
+    # Box
+    if isinstance(space, spaces.Box):
+        if isinstance(x, torch.Tensor):
+            array = np.array(x.cpu().numpy(), dtype=space.dtype)
+            if squeeze_batch_dimension and array.shape[0] == 1:
+                return array.reshape(space.shape)
+            return array.reshape(-1, *space.shape)
+        raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
+    # Discrete
+    elif isinstance(space, spaces.Discrete):
+        if isinstance(x, torch.Tensor):
+            array = np.array(x.cpu().numpy(), dtype=space.dtype)
+            if squeeze_batch_dimension and array.shape[0] == 1:
+                return array.item()
+            return array.reshape(-1, 1)
+        raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
+    # MultiDiscrete
+    elif isinstance(space, spaces.MultiDiscrete):
+        if isinstance(x, torch.Tensor):
+            array = np.array(x.cpu().numpy(), dtype=space.dtype)
+            if squeeze_batch_dimension and array.shape[0] == 1:
+                return array.reshape(space.nvec.shape)
+            return array.reshape(-1, *space.nvec.shape)
+        raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
+    # composite spaces
+    # Dict
+    elif isinstance(space, spaces.Dict):
+        return {k: untensorize_space(s, x[k], squeeze_batch_dimension) for k, s in space.items()}
+    # Tuple
+    elif isinstance(space, spaces.Tuple):
+        return tuple([untensorize_space(s, _x, squeeze_batch_dimension) for s, _x in zip(space, x)])
+    raise ValueError(f"Unsupported space ({space})")
+
+def flatten_tensorized_space(x: Any) -> torch.Tensor:
+    """Flatten a tensorized space.
+
+    :param x: Tensorized space sample/value.
+
+    :raises ValueError: The given sample/value type is not supported.
+
+    :return: A tensor. The returned tensor will have shape (batch, space size).
+    """
+    # fundamental spaces
+    # Box / Discrete / MultiDiscrete
+    if isinstance(x, torch.Tensor):
+        return x.reshape(x.shape[0], -1) if x.ndim > 1 else x.reshape(1, -1)
+    # composite spaces
+    # Dict
+    elif isinstance(x, dict):
+        return torch.cat([flatten_tensorized_space(x[k]) for k in sorted(x.keys())], dim=-1)
+    # Tuple
+    elif type(x) in [list, tuple]:
+        return torch.cat([flatten_tensorized_space(_x) for _x in x], dim=-1)
+    raise ValueError(f"Unsupported sample/value type ({type(x)})")
+
+def unflatten_tensorized_space(space: Union[spaces.Space, Sequence[int], int], x: torch.Tensor) -> Any:
+    """Unflatten a tensor to create a tensorized space.
+
+    :param space: Gymnasium space.
+    :param x: A tensor with shape (batch, space size).
+
+    :raises ValueError: The given space is not supported.
+
+    :return: Tensorized space value.
+ """ + if x is None: + return None + # fundamental spaces + # Box + if isinstance(space, spaces.Box): + return x.reshape(-1, *space.shape) + # Discrete + elif isinstance(space, spaces.Discrete): + return x.reshape(-1, 1) + # MultiDiscrete + elif isinstance(space, spaces.MultiDiscrete): + return x.reshape(-1, *space.shape) + # composite spaces + # Dict + elif isinstance(space, spaces.Dict): + start = 0 + output = {} + for k in sorted(space.keys()): + end = start + compute_space_size(space[k], occupied_size=True) + output[k] = unflatten_tensorized_space(space[k], x[:, start:end]) + start = end + return output + # Tuple + elif isinstance(space, spaces.Tuple): + start = 0 + output = [] + for s in space: + end = start + compute_space_size(s, occupied_size=True) + output.append(unflatten_tensorized_space(s, x[:, start:end])) + start = end + return output + raise ValueError(f"Unsupported space ({space})") + +def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_size: bool = False) -> int: + """Get the size (number of elements) of a space. + + :param space: Gymnasium space. + :param occupied_size: Whether the number of elements occupied by the space is returned (default: ``False``). + It only affects :py:class:`~gymnasium.spaces.Discrete` (occupied space is 1), + and :py:class:`~gymnasium.spaces.MultiDiscrete` (occupied space is the number of discrete spaces). + + :return: Size of the space (number of elements). + """ + if occupied_size: + # fundamental spaces + # Discrete + if isinstance(space, spaces.Discrete): + return 1 + # MultiDiscrete + elif isinstance(space, spaces.MultiDiscrete): + return space.nvec.shape[0] + # composite spaces + # Dict + elif isinstance(space, spaces.Dict): + return sum([compute_space_size(s, occupied_size) for s in space.values()]) + # Tuple + elif isinstance(space, spaces.Tuple): + return sum([compute_space_size(s, occupied_size) for s in space]) + # non-gymnasium spaces + if type(space) in [int, float]: + return space + elif type(space) in [tuple, list]: + return int(np.prod(space)) + # gymnasium computation + return gymnasium.spaces.flatdim(space) + +def sample_space(space: spaces.Space, batch_size: int = 1, backend: str = Literal["numpy", "torch"], device = None) -> Any: + """Generates a random sample from the specified space. + + :param space: Gymnasium space. + :param batch_size: Size of the sampled batch (default: ``1``). + :param backend: Whether backend will be used to construct the fundamental spaces (default: ``"numpy"``). + :param device: Device on which a tensor/array is or will be allocated (default: ``None``). + This parameter is used when the backend is ``"torch"``. + + :raises ValueError: The given space or backend is not supported. 
+ + :return: Sample of the space + """ + # fundamental spaces + # Box + if isinstance(space, spaces.Box): + sample = gymnasium.vector.utils.batch_space(space, batch_size).sample() + if backend == "numpy": + return np.array(sample).reshape(batch_size, *space.shape) + elif backend == "torch": + return torch.tensor(sample, device=device).reshape(batch_size, *space.shape) + else: + raise ValueError(f"Unsupported backend type ({backend})") + # Discrete + elif isinstance(space, spaces.Discrete): + sample = gymnasium.vector.utils.batch_space(space, batch_size).sample() + if backend == "numpy": + return np.array(sample).reshape(batch_size, -1) + elif backend == "torch": + return torch.tensor(sample, device=device).reshape(batch_size, -1) + else: + raise ValueError(f"Unsupported backend type ({backend})") + # MultiDiscrete + elif isinstance(space, spaces.MultiDiscrete): + sample = gymnasium.vector.utils.batch_space(space, batch_size).sample() + if backend == "numpy": + return np.array(sample).reshape(batch_size, *space.nvec.shape) + elif backend == "torch": + return torch.tensor(sample, device=device).reshape(batch_size, *space.nvec.shape) + else: + raise ValueError(f"Unsupported backend type ({backend})") + # composite spaces + # Dict + elif isinstance(space, spaces.Dict): + return {k: sample_space(s, batch_size, backend, device) for k, s in space.items()} + # Tuple + elif isinstance(space, spaces.Tuple): + return tuple([sample_space(s, batch_size, backend, device) for s in space]) + raise ValueError(f"Unsupported space ({space})") diff --git a/tests/jax/test_jax_model_instantiators_definition.py b/tests/jax/test_jax_model_instantiators_definition.py index cf44d9a4..f6992394 100644 --- a/tests/jax/test_jax_model_instantiators_definition.py +++ b/tests/jax/test_jax_model_instantiators_definition.py @@ -39,7 +39,7 @@ def test_parse_input(capsys): assert item not in output, f"'{item}' in '{output}'" # Mixed operation input = 'OBSERVATIONS["joint"] + concatenate([net * ACTIONS[:, -3:]])' - statement = 'inputs["states"]["joint"] + jnp.concatenate([net * inputs["taken_actions"][:, -3:]], axis=-1)' + statement = 'states["joint"] + jnp.concatenate([net * taken_actions[:, -3:]], axis=-1)' output = _parse_input(str(input)) assert output.replace("'", '"') == statement, f"'{output}' != '{statement}'" diff --git a/tests/jax/test_jax_utils_spaces.py b/tests/jax/test_jax_utils_spaces.py new file mode 100644 index 00000000..44d5d983 --- /dev/null +++ b/tests/jax/test_jax_utils_spaces.py @@ -0,0 +1,191 @@ +import hypothesis +import hypothesis.strategies as st + +import gym +import gymnasium + +import jax +import jax.numpy as jnp +import numpy as np + +from skrl.utils.spaces.jax import ( + compute_space_size, + convert_gym_space, + flatten_tensorized_space, + sample_space, + tensorize_space, + unflatten_tensorized_space, + untensorize_space +) + +from ..stategies import gym_space_stategy, gymnasium_space_stategy + + +def _check_backend(x, backend): + if backend == "jax": + assert isinstance(x, jax.Array) + elif backend == "numpy": + assert isinstance(x, np.ndarray) + else: + raise ValueError(f"Invalid backend type: {backend}") + +def check_sampled_space(space, x, n, backend): + if isinstance(space, gymnasium.spaces.Box): + _check_backend(x, backend) + assert x.shape == (n, *space.shape) + elif isinstance(space, gymnasium.spaces.Discrete): + _check_backend(x, backend) + assert x.shape == (n, 1) + elif isinstance(space, gymnasium.spaces.MultiDiscrete): + assert x.shape == (n, *space.nvec.shape) + elif 
isinstance(space, gymnasium.spaces.Dict):
+        list(map(check_sampled_space, space.values(), x.values(), [n] * len(space), [backend] * len(space)))
+    elif isinstance(space, gymnasium.spaces.Tuple):
+        list(map(check_sampled_space, space, x, [n] * len(space), [backend] * len(space)))
+    else:
+        raise ValueError(f"Invalid space type: {type(space)}")
+
+
+@hypothesis.given(space=gymnasium_space_stategy())
+@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
+def test_compute_space_size(capsys, space: gymnasium.spaces.Space):
+    def occupied_size(s):
+        if isinstance(s, gymnasium.spaces.Discrete):
+            return 1
+        elif isinstance(s, gymnasium.spaces.MultiDiscrete):
+            return s.nvec.shape[0]
+        elif isinstance(s, gymnasium.spaces.Dict):
+            return sum([occupied_size(_s) for _s in s.values()])
+        elif isinstance(s, gymnasium.spaces.Tuple):
+            return sum([occupied_size(_s) for _s in s])
+        return gymnasium.spaces.flatdim(s)
+
+    space_size = compute_space_size(space, occupied_size=False)
+    assert space_size == gymnasium.spaces.flatdim(space)
+
+    space_size = compute_space_size(space, occupied_size=True)
+    assert space_size == occupied_size(space)
+
+@hypothesis.given(space=gymnasium_space_stategy())
+@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
+def test_tensorize_space(capsys, space: gymnasium.spaces.Space):
+    def check_tensorized_space(s, x, n):
+        if isinstance(s, gymnasium.spaces.Box):
+            assert isinstance(x, jax.Array) and x.shape == (n, *s.shape)
+        elif isinstance(s, gymnasium.spaces.Discrete):
+            assert isinstance(x, jax.Array) and x.shape == (n, 1)
+        elif isinstance(s, gymnasium.spaces.MultiDiscrete):
+            assert isinstance(x, jax.Array) and x.shape == (n, *s.nvec.shape)
+        elif isinstance(s, gymnasium.spaces.Dict):
+            list(map(check_tensorized_space, s.values(), x.values(), [n] * len(s)))
+        elif isinstance(s, gymnasium.spaces.Tuple):
+            list(map(check_tensorized_space, s, x, [n] * len(s)))
+        else:
+            raise ValueError(f"Invalid space type: {type(s)}")
+
+    tensorized_space = tensorize_space(space, space.sample())
+    check_tensorized_space(space, tensorized_space, 1)
+
+    tensorized_space = tensorize_space(space, tensorized_space)
+    check_tensorized_space(space, tensorized_space, 1)
+
+    sampled_space = sample_space(space, 5, backend="numpy")
+    tensorized_space = tensorize_space(space, sampled_space)
+    check_tensorized_space(space, tensorized_space, 5)
+
+    sampled_space = sample_space(space, 5, backend="jax")
+    tensorized_space = tensorize_space(space, sampled_space)
+    check_tensorized_space(space, tensorized_space, 5)
+
+@hypothesis.given(space=gymnasium_space_stategy())
+@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
+def test_untensorize_space(capsys, space: gymnasium.spaces.Space):
+    def check_untensorized_space(s, x, squeeze_batch_dimension):
+        if isinstance(s, gymnasium.spaces.Box):
+            assert isinstance(x, np.ndarray)
+            assert x.shape == (s.shape if squeeze_batch_dimension else (1, *s.shape))
+        elif isinstance(s, gymnasium.spaces.Discrete):
+            assert isinstance(x, (np.ndarray, int))
+            assert isinstance(x, int) if squeeze_batch_dimension else x.shape == (1, 1)
+        elif isinstance(s, gymnasium.spaces.MultiDiscrete):
+            assert isinstance(x, np.ndarray)
+            assert x.shape == (s.nvec.shape if squeeze_batch_dimension else (1, *s.nvec.shape))
+        elif isinstance(s, gymnasium.spaces.Dict):
+            list(map(check_untensorized_space, s.values(), x.values(),
[squeeze_batch_dimension] * len(s))) + elif isinstance(s, gymnasium.spaces.Tuple): + list(map(check_untensorized_space, s, x, [squeeze_batch_dimension] * len(s))) + else: + raise ValueError(f"Invalid space type: {type(s)}") + + tensorized_space = tensorize_space(space, space.sample()) + + untensorized_space = untensorize_space(space, tensorized_space, squeeze_batch_dimension=False) + check_untensorized_space(space, untensorized_space, squeeze_batch_dimension=False) + + untensorized_space = untensorize_space(space, tensorized_space, squeeze_batch_dimension=True) + check_untensorized_space(space, untensorized_space, squeeze_batch_dimension=True) + +@hypothesis.given(space=gymnasium_space_stategy(), batch_size=st.integers(min_value=1, max_value=10)) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +def test_sample_space(capsys, space: gymnasium.spaces.Space, batch_size: int): + + sampled_space = sample_space(space, batch_size, backend="numpy") + check_sampled_space(space, sampled_space, batch_size, backend="numpy") + + sampled_space = sample_space(space, batch_size, backend="jax") + check_sampled_space(space, sampled_space, batch_size, backend="jax") + +@hypothesis.given(space=gymnasium_space_stategy()) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +def test_flatten_tensorized_space(capsys, space: gymnasium.spaces.Space): + space_size = compute_space_size(space, occupied_size=True) + + tensorized_space = tensorize_space(space, space.sample()) + flattened_space = flatten_tensorized_space(tensorized_space) + assert flattened_space.shape == (1, space_size) + + tensorized_space = sample_space(space, batch_size=5, backend="jax") + flattened_space = flatten_tensorized_space(tensorized_space) + assert flattened_space.shape == (5, space_size) + +@hypothesis.given(space=gymnasium_space_stategy()) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +def test_unflatten_tensorized_space(capsys, space: gymnasium.spaces.Space): + tensorized_space = tensorize_space(space, space.sample()) + flattened_space = flatten_tensorized_space(tensorized_space) + unflattened_space = unflatten_tensorized_space(space, flattened_space) + check_sampled_space(space, unflattened_space, 1, backend="jax") + + tensorized_space = sample_space(space, batch_size=5, backend="jax") + flattened_space = flatten_tensorized_space(tensorized_space) + unflattened_space = unflatten_tensorized_space(space, flattened_space) + check_sampled_space(space, unflattened_space, 5, backend="jax") + +@hypothesis.given(space=gym_space_stategy()) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +def test_convert_gym_space(capsys, space: gym.spaces.Space): + def check_converted_space(gym_space, gymnasium_space): + if isinstance(gym_space, gym.spaces.Box): + assert isinstance(gymnasium_space, gymnasium.spaces.Box) + assert np.all(gym_space.low == gymnasium_space.low) + assert np.all(gym_space.high == gymnasium_space.high) + assert gym_space.shape == gymnasium_space.shape + assert gym_space.dtype == gymnasium_space.dtype + elif isinstance(gym_space, gym.spaces.Discrete): + assert isinstance(gymnasium_space, gymnasium.spaces.Discrete) + assert gym_space.n == gymnasium_space.n + elif isinstance(gym_space, gym.spaces.MultiDiscrete): + assert isinstance(gymnasium_space, gymnasium.spaces.MultiDiscrete) + assert 
np.all(gym_space.nvec == gymnasium_space.nvec)
+        elif isinstance(gym_space, gym.spaces.Tuple):
+            assert isinstance(gymnasium_space, gymnasium.spaces.Tuple)
+            assert len(gym_space) == len(gymnasium_space)
+            list(map(check_converted_space, gym_space, gymnasium_space))
+        elif isinstance(gym_space, gym.spaces.Dict):
+            assert isinstance(gymnasium_space, gymnasium.spaces.Dict)
+            assert sorted(list(gym_space.keys())) == sorted(list(gymnasium_space.keys()))
+            for k in gym_space.keys():
+                check_converted_space(gym_space[k], gymnasium_space[k])
+        else:
+            raise ValueError(f"Invalid space type: {type(gym_space)}")
+
+    check_converted_space(space, convert_gym_space(space))
diff --git a/tests/jax/test_jax_wrapper_gym.py b/tests/jax/test_jax_wrapper_gym.py
index 1206bdf7..7d7042ff 100644
--- a/tests/jax/test_jax_wrapper_gym.py
+++ b/tests/jax/test_jax_wrapper_gym.py
@@ -2,6 +2,7 @@
 from collections.abc import Mapping
 
 import gym
+import gymnasium
 
 import jax
 import jax.numpy as jnp
@@ -28,8 +29,8 @@ def test_env(capsys: pytest.CaptureFixture, backend: str):
 
     # check properties
     assert env.state_space is None
-    assert isinstance(env.observation_space, gym.Space) and env.observation_space.shape == (3,)
-    assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,)
+    assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (3,)
+    assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,)
     assert isinstance(env.num_envs, int) and env.num_envs == num_envs
     assert isinstance(env.num_agents, int) and env.num_agents == 1
     assert isinstance(env.device, jax.Device)
@@ -70,8 +71,8 @@ def test_vectorized_env(capsys: pytest.CaptureFixture, backend: str, vectorizati
 
     # check properties
     assert env.state_space is None
-    assert isinstance(env.observation_space, gym.Space) and env.observation_space.shape == (3,)
-    assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,)
+    assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (3,)
+    assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,)
     assert isinstance(env.num_envs, int) and env.num_envs == num_envs
     assert isinstance(env.num_agents, int) and env.num_agents == 1
     assert isinstance(env.device, jax.Device)
diff --git a/tests/jax/test_jax_wrapper_isaacgym.py b/tests/jax/test_jax_wrapper_isaacgym.py
index 56efd160..4629e666 100644
--- a/tests/jax/test_jax_wrapper_isaacgym.py
+++ b/tests/jax/test_jax_wrapper_isaacgym.py
@@ -4,6 +4,7 @@
 from collections.abc import Mapping
 
 import gym
+import gymnasium
 
 import jax
 import jax.numpy as jnp
@@ -74,11 +75,11 @@ def test_env(capsys: pytest.CaptureFixture, backend: str, num_states: str):
 
     # check properties
     if num_states:
-        assert isinstance(env.state_space, gym.Space) and env.state_space.shape == (num_states,)
+        assert isinstance(env.state_space, gymnasium.Space) and env.state_space.shape == (num_states,)
     else:
         assert env.state_space is None
-    assert isinstance(env.observation_space, gym.Space) and env.observation_space.shape == (4,)
-    assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,)
+    assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (4,)
+    assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,)
     assert isinstance(env.num_envs, int) and env.num_envs == num_envs
     assert isinstance(env.num_agents, int) and env.num_agents == 1
     assert isinstance(env.device,
jax.Device) diff --git a/tests/stategies.py b/tests/stategies.py new file mode 100644 index 00000000..409b9771 --- /dev/null +++ b/tests/stategies.py @@ -0,0 +1,63 @@ +import hypothesis.strategies as st + +import gym +import gymnasium + + +@st.composite +def gymnasium_space_stategy(draw, space_type: str = "", remaining_iterations: int = 5) -> gymnasium.spaces.Space: + if not space_type: + space_type = draw(st.sampled_from(["Box", "Discrete", "MultiDiscrete", "Dict", "Tuple"])) + # recursion base case + if remaining_iterations <= 0 and space_type == "Dict": + space_type = "Box" + + if space_type == "Box": + shape = draw(st.lists(st.integers(min_value=1, max_value=5), min_size=1, max_size=5)) + return gymnasium.spaces.Box(low=-1, high=1, shape=shape) + elif space_type == "Discrete": + n = draw(st.integers(min_value=1, max_value=5)) + return gymnasium.spaces.Discrete(n) + elif space_type == "MultiDiscrete": + nvec = draw(st.lists(st.integers(min_value=1, max_value=5), min_size=1, max_size=5)) + return gymnasium.spaces.MultiDiscrete(nvec) + elif space_type == "Dict": + remaining_iterations -= 1 + keys = draw(st.lists(st.text(st.characters(codec="ascii"), min_size=1, max_size=5), min_size=1, max_size=3)) + spaces = {key: draw(gymnasium_space_stategy(remaining_iterations=remaining_iterations)) for key in keys} + return gymnasium.spaces.Dict(spaces) + elif space_type == "Tuple": + remaining_iterations -= 1 + spaces = draw(st.lists(gymnasium_space_stategy(remaining_iterations=remaining_iterations), min_size=1, max_size=3)) + return gymnasium.spaces.Tuple(spaces) + else: + raise ValueError(f"Invalid space type: {space_type}") + +@st.composite +def gym_space_stategy(draw, space_type: str = "", remaining_iterations: int = 5) -> gym.spaces.Space: + if not space_type: + space_type = draw(st.sampled_from(["Box", "Discrete", "MultiDiscrete", "Dict", "Tuple"])) + # recursion base case + if remaining_iterations <= 0 and space_type == "Dict": + space_type = "Box" + + if space_type == "Box": + shape = draw(st.lists(st.integers(min_value=1, max_value=5), min_size=1, max_size=5)) + return gym.spaces.Box(low=-1, high=1, shape=shape) + elif space_type == "Discrete": + n = draw(st.integers(min_value=1, max_value=5)) + return gym.spaces.Discrete(n) + elif space_type == "MultiDiscrete": + nvec = draw(st.lists(st.integers(min_value=1, max_value=5), min_size=1, max_size=5)) + return gym.spaces.MultiDiscrete(nvec) + elif space_type == "Dict": + remaining_iterations -= 1 + keys = draw(st.lists(st.text(st.characters(codec="ascii"), min_size=1, max_size=5), min_size=1, max_size=3)) + spaces = {key: draw(gym_space_stategy(remaining_iterations=remaining_iterations)) for key in keys} + return gym.spaces.Dict(spaces) + elif space_type == "Tuple": + remaining_iterations -= 1 + spaces = draw(st.lists(gym_space_stategy(remaining_iterations=remaining_iterations), min_size=1, max_size=3)) + return gym.spaces.Tuple(spaces) + else: + raise ValueError(f"Invalid space type: {space_type}") diff --git a/tests/torch/test_torch_model_instantiators_definition.py b/tests/torch/test_torch_model_instantiators_definition.py index e2fff0c6..b4c9d52a 100644 --- a/tests/torch/test_torch_model_instantiators_definition.py +++ b/tests/torch/test_torch_model_instantiators_definition.py @@ -48,7 +48,7 @@ def test_parse_input(capsys): assert item not in output, f"'{item}' in '{output}'" # Mixed operation input = 'OBSERVATIONS["joint"] + concatenate([net * ACTIONS[:, -3:]])' - statement = 'inputs["states"]["joint"] + torch.cat([net * 
inputs["taken_actions"][:, -3:]], dim=1)' + statement = 'states["joint"] + torch.cat([net * taken_actions[:, -3:]], dim=1)' output = _parse_input(str(input)) assert output.replace("'", '"') == statement, f"'{output}' != '{statement}'" diff --git a/tests/torch/test_torch_utils_spaces.py b/tests/torch/test_torch_utils_spaces.py new file mode 100644 index 00000000..aa3a08ac --- /dev/null +++ b/tests/torch/test_torch_utils_spaces.py @@ -0,0 +1,190 @@ +import hypothesis +import hypothesis.strategies as st + +import gym +import gymnasium + +import numpy as np +import torch + +from skrl.utils.spaces.torch import ( + compute_space_size, + convert_gym_space, + flatten_tensorized_space, + sample_space, + tensorize_space, + unflatten_tensorized_space, + untensorize_space +) + +from ..stategies import gym_space_stategy, gymnasium_space_stategy + + +def _check_backend(x, backend): + if backend == "torch": + assert isinstance(x, torch.Tensor) + elif backend == "numpy": + assert isinstance(x, np.ndarray) + else: + raise ValueError(f"Invalid backend type: {backend}") + +def check_sampled_space(space, x, n, backend): + if isinstance(space, gymnasium.spaces.Box): + _check_backend(x, backend) + assert x.shape == (n, *space.shape) + elif isinstance(space, gymnasium.spaces.Discrete): + _check_backend(x, backend) + assert x.shape == (n, 1) + elif isinstance(space, gymnasium.spaces.MultiDiscrete): + assert x.shape == (n, *space.nvec.shape) + elif isinstance(space, gymnasium.spaces.Dict): + list(map(check_sampled_space, space.values(), x.values(), [n] * len(space), [backend] * len(space))) + elif isinstance(space, gymnasium.spaces.Tuple): + list(map(check_sampled_space, space, x, [n] * len(space), [backend] * len(space))) + else: + raise ValueError(f"Invalid space type: {type(space)}") + + +@hypothesis.given(space=gymnasium_space_stategy()) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +def test_compute_space_size(capsys, space: gymnasium.spaces.Space): + def occupied_size(s): + if isinstance(s, gymnasium.spaces.Discrete): + return 1 + elif isinstance(s, gymnasium.spaces.MultiDiscrete): + return s.nvec.shape[0] + elif isinstance(s, gymnasium.spaces.Dict): + return sum([occupied_size(_s) for _s in s.values()]) + elif isinstance(s, gymnasium.spaces.Tuple): + return sum([occupied_size(_s) for _s in s]) + return gymnasium.spaces.flatdim(s) + + space_size = compute_space_size(space, occupied_size=False) + assert space_size == gymnasium.spaces.flatdim(space) + + space_size = compute_space_size(space, occupied_size=True) + assert space_size == occupied_size(space) + +@hypothesis.given(space=gymnasium_space_stategy()) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +def test_tensorize_space(capsys, space: gymnasium.spaces.Space): + def check_tensorized_space(s, x, n): + if isinstance(s, gymnasium.spaces.Box): + assert isinstance(x, torch.Tensor) and x.shape == (n, *s.shape) + elif isinstance(s, gymnasium.spaces.Discrete): + assert isinstance(x, torch.Tensor) and x.shape == (n, 1) + elif isinstance(s, gymnasium.spaces.MultiDiscrete): + assert isinstance(x, torch.Tensor) and x.shape == (n, *s.nvec.shape) + elif isinstance(s, gymnasium.spaces.Dict): + list(map(check_tensorized_space, s.values(), x.values(), [n] * len(s))) + elif isinstance(s, gymnasium.spaces.Tuple): + list(map(check_tensorized_space, s, x, [n] * len(s))) + else: + raise ValueError(f"Invalid space type: {type(s)}") + + tensorized_space 
= tensorize_space(space, space.sample())
+    check_tensorized_space(space, tensorized_space, 1)
+
+    tensorized_space = tensorize_space(space, tensorized_space)
+    check_tensorized_space(space, tensorized_space, 1)
+
+    sampled_space = sample_space(space, 5, backend="numpy")
+    tensorized_space = tensorize_space(space, sampled_space)
+    check_tensorized_space(space, tensorized_space, 5)
+
+    sampled_space = sample_space(space, 5, backend="torch")
+    tensorized_space = tensorize_space(space, sampled_space)
+    check_tensorized_space(space, tensorized_space, 5)
+
+@hypothesis.given(space=gymnasium_space_stategy())
+@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
+def test_untensorize_space(capsys, space: gymnasium.spaces.Space):
+    def check_untensorized_space(s, x, squeeze_batch_dimension):
+        if isinstance(s, gymnasium.spaces.Box):
+            assert isinstance(x, np.ndarray)
+            assert x.shape == (s.shape if squeeze_batch_dimension else (1, *s.shape))
+        elif isinstance(s, gymnasium.spaces.Discrete):
+            assert isinstance(x, (np.ndarray, int))
+            assert isinstance(x, int) if squeeze_batch_dimension else x.shape == (1, 1)
+        elif isinstance(s, gymnasium.spaces.MultiDiscrete):
+            assert isinstance(x, np.ndarray)
+            assert x.shape == (s.nvec.shape if squeeze_batch_dimension else (1, *s.nvec.shape))
+        elif isinstance(s, gymnasium.spaces.Dict):
+            list(map(check_untensorized_space, s.values(), x.values(), [squeeze_batch_dimension] * len(s)))
+        elif isinstance(s, gymnasium.spaces.Tuple):
+            list(map(check_untensorized_space, s, x, [squeeze_batch_dimension] * len(s)))
+        else:
+            raise ValueError(f"Invalid space type: {type(s)}")
+
+    tensorized_space = tensorize_space(space, space.sample())
+
+    untensorized_space = untensorize_space(space, tensorized_space, squeeze_batch_dimension=False)
+    check_untensorized_space(space, untensorized_space, squeeze_batch_dimension=False)
+
+    untensorized_space = untensorize_space(space, tensorized_space, squeeze_batch_dimension=True)
+    check_untensorized_space(space, untensorized_space, squeeze_batch_dimension=True)
+
+@hypothesis.given(space=gymnasium_space_stategy(), batch_size=st.integers(min_value=1, max_value=10))
+@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
+def test_sample_space(capsys, space: gymnasium.spaces.Space, batch_size: int):
+
+    sampled_space = sample_space(space, batch_size, backend="numpy")
+    check_sampled_space(space, sampled_space, batch_size, backend="numpy")
+
+    sampled_space = sample_space(space, batch_size, backend="torch")
+    check_sampled_space(space, sampled_space, batch_size, backend="torch")
+
+@hypothesis.given(space=gymnasium_space_stategy())
+@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
+def test_flatten_tensorized_space(capsys, space: gymnasium.spaces.Space):
+    space_size = compute_space_size(space, occupied_size=True)
+
+    tensorized_space = tensorize_space(space, space.sample())
+    flattened_space = flatten_tensorized_space(tensorized_space)
+    assert flattened_space.shape == (1, space_size)
+
+    tensorized_space = sample_space(space, batch_size=5, backend="torch")
+    flattened_space = flatten_tensorized_space(tensorized_space)
+    assert flattened_space.shape == (5, space_size)
+
+@hypothesis.given(space=gymnasium_space_stategy())
+@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
+def test_unflatten_tensorized_space(capsys, space:
gymnasium.spaces.Space):
+    tensorized_space = tensorize_space(space, space.sample())
+    flattened_space = flatten_tensorized_space(tensorized_space)
+    unflattened_space = unflatten_tensorized_space(space, flattened_space)
+    check_sampled_space(space, unflattened_space, 1, backend="torch")
+
+    tensorized_space = sample_space(space, batch_size=5, backend="torch")
+    flattened_space = flatten_tensorized_space(tensorized_space)
+    unflattened_space = unflatten_tensorized_space(space, flattened_space)
+    check_sampled_space(space, unflattened_space, 5, backend="torch")
+
+@hypothesis.given(space=gym_space_stategy())
+@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
+def test_convert_gym_space(capsys, space: gym.spaces.Space):
+    def check_converted_space(gym_space, gymnasium_space):
+        if isinstance(gym_space, gym.spaces.Box):
+            assert isinstance(gymnasium_space, gymnasium.spaces.Box)
+            assert np.all(gym_space.low == gymnasium_space.low)
+            assert np.all(gym_space.high == gymnasium_space.high)
+            assert gym_space.shape == gymnasium_space.shape
+            assert gym_space.dtype == gymnasium_space.dtype
+        elif isinstance(gym_space, gym.spaces.Discrete):
+            assert isinstance(gymnasium_space, gymnasium.spaces.Discrete)
+            assert gym_space.n == gymnasium_space.n
+        elif isinstance(gym_space, gym.spaces.MultiDiscrete):
+            assert isinstance(gymnasium_space, gymnasium.spaces.MultiDiscrete)
+            assert np.all(gym_space.nvec == gymnasium_space.nvec)
+        elif isinstance(gym_space, gym.spaces.Tuple):
+            assert isinstance(gymnasium_space, gymnasium.spaces.Tuple)
+            assert len(gym_space) == len(gymnasium_space)
+            list(map(check_converted_space, gym_space, gymnasium_space))
+        elif isinstance(gym_space, gym.spaces.Dict):
+            assert isinstance(gymnasium_space, gymnasium.spaces.Dict)
+            assert sorted(list(gym_space.keys())) == sorted(list(gymnasium_space.keys()))
+            for k in gym_space.keys():
+                check_converted_space(gym_space[k], gymnasium_space[k])
+        else:
+            raise ValueError(f"Invalid space type: {type(gym_space)}")
+
+    check_converted_space(space, convert_gym_space(space))
diff --git a/tests/torch/test_torch_wrapper_brax.py b/tests/torch/test_torch_wrapper_brax.py
index 383f5308..7f1291cc 100644
--- a/tests/torch/test_torch_wrapper_brax.py
+++ b/tests/torch/test_torch_wrapper_brax.py
@@ -2,7 +2,7 @@ import warnings
 from collections.abc import Mapping
 
-import gymnasium as gym
+import gymnasium
 
 import torch
@@ -28,8 +28,8 @@ def test_env(capsys: pytest.CaptureFixture):
 
     # check properties
     assert env.state_space is None
-    assert isinstance(env.observation_space, gym.Space) and env.observation_space.shape == (4,)
-    assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,)
+    assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (4,)
+    assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,)
     assert isinstance(env.num_envs, int) and env.num_envs == num_envs
     assert isinstance(env.num_agents, int) and env.num_agents == 1
     assert isinstance(env.device, torch.device)
@@ -44,7 +44,10 @@
     assert isinstance(info, Mapping)
     for _ in range(3):
         observation, reward, terminated, truncated, info = env.step(action)
-        env.render()
+        try:
+            env.render()
+        except AttributeError as e:
+            warnings.warn(f"Brax exception when rendering: {e}")
     assert isinstance(observation, torch.Tensor) and observation.shape == torch.Size([num_envs, 4])
     assert isinstance(reward,
torch.Tensor) and reward.shape == torch.Size([num_envs, 1]) assert isinstance(terminated, torch.Tensor) and terminated.shape == torch.Size([num_envs, 1]) diff --git a/tests/torch/test_torch_wrapper_deepmind.py b/tests/torch/test_torch_wrapper_deepmind.py index 5c7b4f04..57c46ca4 100644 --- a/tests/torch/test_torch_wrapper_deepmind.py +++ b/tests/torch/test_torch_wrapper_deepmind.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Mapping -import gym +import gymnasium as gym import torch diff --git a/tests/torch/test_torch_wrapper_gym.py b/tests/torch/test_torch_wrapper_gym.py index a9e1aefb..cfee7672 100644 --- a/tests/torch/test_torch_wrapper_gym.py +++ b/tests/torch/test_torch_wrapper_gym.py @@ -2,6 +2,7 @@ from collections.abc import Mapping import gym +import gymnasium import torch @@ -21,8 +22,8 @@ def test_env(capsys: pytest.CaptureFixture): # check properties assert env.state_space is None - assert isinstance(env.observation_space, gym.Space) and env.observation_space.shape == (3,) - assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,) + assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (3,) + assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,) assert isinstance(env.num_envs, int) and env.num_envs == num_envs assert isinstance(env.num_agents, int) and env.num_agents == 1 assert isinstance(env.device, torch.device) @@ -59,8 +60,8 @@ def test_vectorized_env(capsys: pytest.CaptureFixture, vectorization_mode: str): # check properties assert env.state_space is None - assert isinstance(env.observation_space, gym.Space) and env.observation_space.shape == (3,) - assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,) + assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (3,) + assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,) assert isinstance(env.num_envs, int) and env.num_envs == num_envs assert isinstance(env.num_agents, int) and env.num_agents == 1 assert isinstance(env.device, torch.device) diff --git a/tests/torch/test_torch_wrapper_isaacgym.py b/tests/torch/test_torch_wrapper_isaacgym.py index 1c35587e..e7d2b367 100644 --- a/tests/torch/test_torch_wrapper_isaacgym.py +++ b/tests/torch/test_torch_wrapper_isaacgym.py @@ -4,6 +4,7 @@ from collections.abc import Mapping import gym +import gymnasium import numpy as np import torch @@ -67,11 +68,11 @@ def test_env(capsys: pytest.CaptureFixture, num_states: int): # check properties if num_states: - assert isinstance(env.state_space, gym.Space) and env.state_space.shape == (num_states,) + assert isinstance(env.state_space, gymnasium.Space) and env.state_space.shape == (num_states,) else: assert env.state_space is None - assert isinstance(env.observation_space, gym.Space) and env.observation_space.shape == (4,) - assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,) + assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (4,) + assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,) assert isinstance(env.num_envs, int) and env.num_envs == num_envs assert isinstance(env.num_agents, int) and env.num_agents == 1 assert isinstance(env.device, torch.device)
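The diff introduces the space utilities without a usage walkthrough, so here is a minimal sketch of the intended round trip; the Dict space, its keys and the shapes below are illustrative, not taken from the diff:

# Round trip through the new PyTorch space utilities (illustrative space and shapes)
import gymnasium
from skrl.utils.spaces.torch import (
    compute_space_size,
    flatten_tensorized_space,
    tensorize_space,
    unflatten_tensorized_space,
    untensorize_space,
)

space = gymnasium.spaces.Dict({
    "joint": gymnasium.spaces.Box(low=-1, high=1, shape=(3,)),
    "mode": gymnasium.spaces.Discrete(4),
})

# raw sample (NumPy arrays / numbers) -> tensors with a batch dimension
tensorized = tensorize_space(space, space.sample())
# single (batch, space size) tensor, the flattened layout memories store by default since this release
flattened = flatten_tensorized_space(tensorized)
assert flattened.shape == (1, compute_space_size(space, occupied_size=True))
# back to the dict-of-tensors layout consumed by the generated compute() methods
restored = unflatten_tensorized_space(space, flattened)
# back to plain NumPy/int values, e.g. before handing actions to a non-tensor environment
values = untensorize_space(space, restored)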
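sample_space is what the new tests lean on for batched inputs; a small sketch (the space and batch size are arbitrary):

import gymnasium
from skrl.utils.spaces.torch import sample_space

space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,))
batch = sample_space(space, batch_size=5, backend="torch")  # torch.Tensor, shape (5, 3)
batch = sample_space(space, batch_size=5, backend="numpy")  # np.ndarray, shape (5, 3)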
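Since gym is no longer a dependency, gym-based environment wrappers go through convert_gym_space; a sketch (gym itself must now be installed manually):

import gym
from skrl.utils.spaces.torch import convert_gym_space

gym_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
gymnasium_space = convert_gym_space(gym_space)  # -> gymnasium.spaces.Box with the same bounds and dtype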
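The JAX helpers mirror the PyTorch ones; a sketch assuming a JAX version whose jnp.array accepts a device argument, as the implementation above does:

import gymnasium
from skrl.utils.spaces.jax import flatten_tensorized_space, sample_space, unflatten_tensorized_space

space = gymnasium.spaces.MultiDiscrete([3, 5])
batch = sample_space(space, batch_size=4, backend="jax")  # jax.Array, shape (4, 2)
flat = flatten_tensorized_space(batch)                    # shape (4, 2)
restored = unflatten_tensorized_space(space, flat)        # shape (4, 2)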