diff --git a/gymnasium/__init__.py b/gymnasium/__init__.py index fabd35cd7..41d9d2268 100644 --- a/gymnasium/__init__.py +++ b/gymnasium/__init__.py @@ -17,7 +17,7 @@ pprint_registry, make_vec, ) -from gymnasium import envs, spaces, utils, vector, wrappers, error, logger, experimental +from gymnasium import envs, spaces, utils, vector, wrappers, error, logger __all__ = [ @@ -43,7 +43,6 @@ "wrappers", "error", "logger", - "experimental", ] __version__ = "0.27.1" diff --git a/gymnasium/envs/classic_control/cartpole.py b/gymnasium/envs/classic_control/cartpole.py index dc80c2cdd..5ef5474d9 100644 --- a/gymnasium/envs/classic_control/cartpole.py +++ b/gymnasium/envs/classic_control/cartpole.py @@ -12,6 +12,7 @@ from gymnasium import logger, spaces from gymnasium.envs.classic_control import utils from gymnasium.error import DependencyNotInstalled +from gymnasium.experimental.vector import VectorEnv from gymnasium.vector.utils import batch_space @@ -315,7 +316,7 @@ def close(self): self.isopen = False -class CartPoleVectorEnv(gym.experimental.VectorEnv): +class CartPoleVectorEnv(VectorEnv): metadata = { "render_modes": ["human", "rgb_array"], "render_fps": 50, diff --git a/gymnasium/envs/registration.py b/gymnasium/envs/registration.py index 946d977cb..37c8e06e3 100644 --- a/gymnasium/envs/registration.py +++ b/gymnasium/envs/registration.py @@ -13,8 +13,8 @@ from dataclasses import dataclass, field from typing import Any, Callable, Iterable, Sequence -import gymnasium as gym from gymnasium import Env, Wrapper, error, logger +from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv from gymnasium.wrappers import ( AutoResetWrapper, HumanRendering, @@ -63,7 +63,7 @@ def __call__(self, **kwargs: Any) -> Env: class VectorEnvCreator(Protocol): """Function type expected for an environment.""" - def __call__(self, **kwargs: Any) -> gym.experimental.VectorEnv: + def __call__(self, **kwargs: Any) -> VectorEnv: ... @@ -695,7 +695,7 @@ def make_vec( vector_kwargs: dict[str, Any] | None = None, wrappers: Sequence[Callable[[Env], Wrapper]] | None = None, **kwargs, -) -> gym.experimental.VectorEnv: +) -> VectorEnv: """Create a vector environment according to the given ID. Note: @@ -778,12 +778,12 @@ def _create_env(): return _env if vectorization_mode == "sync": - env = gym.experimental.SyncVectorEnv( + env = SyncVectorEnv( env_fns=[_create_env for _ in range(num_envs)], **vector_kwargs, ) elif vectorization_mode == "async": - env = gym.experimental.AsyncVectorEnv( + env = AsyncVectorEnv( env_fns=[_create_env for _ in range(num_envs)], **vector_kwargs, ) diff --git a/gymnasium/experimental/__init__.py b/gymnasium/experimental/__init__.py index 6d06443d5..4d678685f 100644 --- a/gymnasium/experimental/__init__.py +++ b/gymnasium/experimental/__init__.py @@ -1,7 +1,7 @@ """Root __init__ of the gym experimental wrappers.""" -from gymnasium.experimental import functional, wrappers +from gymnasium.experimental import functional from gymnasium.experimental.functional import FuncEnv from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv @@ -12,12 +12,9 @@ # Functional "FuncEnv", "functional", - # Wrappers - "wrappers", # Vector "VectorEnv", "VectorWrapper", "SyncVectorEnv", "AsyncVectorEnv", - # "vector", ] diff --git a/gymnasium/experimental/functional_jax_env.py b/gymnasium/experimental/functional_jax_env.py index f112da5ed..2d18f7ba9 100644 --- a/gymnasium/experimental/functional_jax_env.py +++ b/gymnasium/experimental/functional_jax_env.py @@ -11,7 +11,7 @@ import gymnasium as gym from gymnasium.envs.registration import EnvSpec from gymnasium.experimental.functional import ActType, FuncEnv, StateType -from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy +from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy from gymnasium.utils import seeding from gymnasium.vector.utils import batch_space diff --git a/gymnasium/experimental/wrappers/__init__.py b/gymnasium/experimental/wrappers/__init__.py index 74d7a0b12..42c2b204c 100644 --- a/gymnasium/experimental/wrappers/__init__.py +++ b/gymnasium/experimental/wrappers/__init__.py @@ -23,9 +23,6 @@ LambdaRewardV0, NormalizeRewardV0, ) -from gymnasium.experimental.wrappers.jax_to_numpy import JaxToNumpyV0 -from gymnasium.experimental.wrappers.jax_to_torch import JaxToTorchV0 -from gymnasium.experimental.wrappers.numpy_to_torch import NumpyToTorchV0 from gymnasium.experimental.wrappers.stateful_action import StickyActionV0 from gymnasium.experimental.wrappers.stateful_observation import ( TimeAwareObservationV0, @@ -85,10 +82,6 @@ "RenderCollectionV0", "RecordVideoV0", "HumanRenderingV0", - # --- Data Conversion --- - "JaxToNumpyV0", - "JaxToTorchV0", - "NumpyToTorchV0", # --- Vector --- "VectorRecordEpisodeStatistics", "VectorListInfo", diff --git a/gymnasium/experimental/wrappers/conversion/__init__.py b/gymnasium/experimental/wrappers/conversion/__init__.py new file mode 100644 index 000000000..bd1b3b471 --- /dev/null +++ b/gymnasium/experimental/wrappers/conversion/__init__.py @@ -0,0 +1 @@ +"""This is deliberately empty to avoid introducing redundant imports -- import each submodule individually.""" diff --git a/gymnasium/experimental/wrappers/jax_to_numpy.py b/gymnasium/experimental/wrappers/conversion/jax_to_numpy.py similarity index 100% rename from gymnasium/experimental/wrappers/jax_to_numpy.py rename to gymnasium/experimental/wrappers/conversion/jax_to_numpy.py diff --git a/gymnasium/experimental/wrappers/jax_to_torch.py b/gymnasium/experimental/wrappers/conversion/jax_to_torch.py similarity index 98% rename from gymnasium/experimental/wrappers/jax_to_torch.py rename to gymnasium/experimental/wrappers/conversion/jax_to_torch.py index 2211c8cde..1cc97deef 100644 --- a/gymnasium/experimental/wrappers/jax_to_torch.py +++ b/gymnasium/experimental/wrappers/conversion/jax_to_torch.py @@ -17,7 +17,7 @@ from gymnasium import Env, Wrapper from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType from gymnasium.error import DependencyNotInstalled -from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy +from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy try: diff --git a/gymnasium/experimental/wrappers/numpy_to_torch.py b/gymnasium/experimental/wrappers/conversion/numpy_to_torch.py similarity index 100% rename from gymnasium/experimental/wrappers/numpy_to_torch.py rename to gymnasium/experimental/wrappers/conversion/numpy_to_torch.py diff --git a/tests/experimental/wrappers/test_jax_to_numpy.py b/tests/experimental/wrappers/test_jax_to_numpy.py index 25c5ee62a..1c9815385 100644 --- a/tests/experimental/wrappers/test_jax_to_numpy.py +++ b/tests/experimental/wrappers/test_jax_to_numpy.py @@ -4,8 +4,11 @@ import numpy as np import pytest -from gymnasium.experimental.wrappers import JaxToNumpyV0 -from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax +from gymnasium.experimental.wrappers.conversion.jax_to_numpy import ( + JaxToNumpyV0, + jax_to_numpy, + numpy_to_jax, +) from gymnasium.utils.env_checker import data_equivalence from tests.testing_env import GenericTestEnv diff --git a/tests/experimental/wrappers/test_jax_to_torch.py b/tests/experimental/wrappers/test_jax_to_torch.py index a2313ae90..d4899c0fc 100644 --- a/tests/experimental/wrappers/test_jax_to_torch.py +++ b/tests/experimental/wrappers/test_jax_to_torch.py @@ -5,8 +5,11 @@ import pytest import torch -from gymnasium.experimental.wrappers import JaxToTorchV0 -from gymnasium.experimental.wrappers.jax_to_torch import jax_to_torch, torch_to_jax +from gymnasium.experimental.wrappers.conversion.jax_to_torch import ( + JaxToTorchV0, + jax_to_torch, + torch_to_jax, +) from tests.testing_env import GenericTestEnv