diff --git a/pyproject.toml b/pyproject.toml index 8d4b298..b7cc836 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ requires-python = ">=3.10.0" license = { text = "MIT" } dependencies = [ "Deprecated", - "gymnasium<1.0.0", + "gymnasium", "numpy", "pandas", "pettingzoo", @@ -29,7 +29,7 @@ dependencies = [ [project.optional-dependencies] docs = ["ipykernel", "ipywidgets", "nbdime", "nbsphinx", "sphinx-rtd-theme"] -rllib = ["dm_tree", "pyarrow", "ray[rllib]", "scikit-image", "torch", "typer"] +rllib = ["dm_tree", "pyarrow", "ray[rllib]==2.35.0", "scikit-image", "torch", "typer"] [project.scripts] finish_install = "bsk_rl.finish_install:pck_install" diff --git a/src/bsk_rl/obs/observations.py b/src/bsk_rl/obs/observations.py index 43a4825..324ef3d 100644 --- a/src/bsk_rl/obs/observations.py +++ b/src/bsk_rl/obs/observations.py @@ -37,18 +37,17 @@ def nested_obs_to_space(obs_dict): ) elif isinstance(obs_dict, list): return spaces.Box( - low=-1e16, high=1e16, shape=(len(obs_dict),), dtype=np.float64 + low=-1e16, high=1e16, shape=(len(obs_dict),), dtype=np.float32 ) elif isinstance(obs_dict, (float, int)): - return spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64) + return spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float32) elif isinstance(obs_dict, np.ndarray): - return spaces.Box(low=-1e16, high=1e16, shape=obs_dict.shape, dtype=np.float64) + return spaces.Box(low=-1e16, high=1e16, shape=obs_dict.shape, dtype=np.float32) else: raise TypeError(f"Cannot convert {obs_dict} to gym space.") class ObservationBuilder: - def __init__(self, satellite: "Satellite", obs_type: type = np.ndarray) -> None: """Satellite subclass for composing observations. @@ -312,7 +311,6 @@ def _r_LB_H(sat, opp): class OpportunityProperties(Observation): - _fn_map = { "priority": lambda sat, opp: opp["object"].priority, "r_LP_P": lambda sat, opp: opp["r_LP_P"], diff --git a/tests/integration/test_int_gym_env.py b/tests/integration/test_int_gym_env.py index 4d9d50b..8decd41 100644 --- a/tests/integration/test_int_gym_env.py +++ b/tests/integration/test_int_gym_env.py @@ -35,7 +35,9 @@ def test_action_space(self): assert self.env.action_space == spaces.Discrete(1) def test_observation_space(self): - assert self.env.observation_space == spaces.Box(-1e16, 1e16, (1,)) + assert self.env.observation_space == spaces.Box( + -1e16, 1e16, (1,), dtype=np.float32 + ) def test_step(self): observation, reward, terminated, truncated, info = self.env.step(0) @@ -124,7 +126,10 @@ def test_action_space(self): def test_observation_space(self): assert self.env.observation_space == spaces.Tuple( - (spaces.Box(-1e16, 1e16, (1,)), spaces.Box(-1e16, 1e16, (1,))) + ( + spaces.Box(-1e16, 1e16, (1,), dtype=np.float32), + spaces.Box(-1e16, 1e16, (1,), dtype=np.float32), + ) ) def test_step(self): diff --git a/tests/unittest/obs/test_observations.py b/tests/unittest/obs/test_observations.py index a689101..721ee43 100644 --- a/tests/unittest/obs/test_observations.py +++ b/tests/unittest/obs/test_observations.py @@ -69,23 +69,23 @@ def test_obs_cache(self): [ ( np.array([1]), - spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64), + spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float32), ), ( np.array([1, 2]), - spaces.Box(low=-1e16, high=1e16, shape=(2,), dtype=np.float64), + spaces.Box(low=-1e16, high=1e16, shape=(2,), dtype=np.float32), ), ( {"a": 1, "b": {"c": 1}}, spaces.Dict( { "a": spaces.Box( - low=-1e16, high=1e16, shape=(1,), dtype=np.float64 + low=-1e16, high=1e16, shape=(1,), dtype=np.float32 ), "b": spaces.Dict( { "c": spaces.Box( - low=-1e16, high=1e16, shape=(1,), dtype=np.float64 + low=-1e16, high=1e16, shape=(1,), dtype=np.float32 ) } ), diff --git a/tests/unittest/test_gym_env.py b/tests/unittest/test_gym_env.py index 9174470..eb943bb 100644 --- a/tests/unittest/test_gym_env.py +++ b/tests/unittest/test_gym_env.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock, patch +import numpy as np import pytest from gymnasium import spaces @@ -132,7 +133,9 @@ def test_get_obs_retasking_only(self): satellites=[ MagicMock( get_obs=MagicMock(return_value=[i + 1]), - observation_space=spaces.Box(-1e9, 1e9, shape=(1,)), + observation_space=spaces.Box( + -1e9, 1e9, shape=(1,), dtype=np.float32 + ), requires_retasking=(i == 1), ) for i in range(3)