From f7f83c2b0d8110b8acd924d5fee5a17454d9722d Mon Sep 17 00:00:00 2001 From: zenghsh3 Date: Wed, 30 Jun 2021 14:48:56 +0800 Subject: [PATCH 1/2] make ActionWrapper support low/high bound is different in different dimensions --- parl/env/continuous_wrappers.py | 6 +++--- parl/env/tests/continuous_wrappers_test.py | 21 +++++++++++++++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/parl/env/continuous_wrappers.py b/parl/env/continuous_wrappers.py index a4a5944e3..b44a9815e 100644 --- a/parl/env/continuous_wrappers.py +++ b/parl/env/continuous_wrappers.py @@ -24,9 +24,9 @@ def __init__(self, env): """ gym.Wrapper.__init__(self, env) assert isinstance(self.env.action_space, gym.spaces.Box) - self.low_bound = self.env.action_space.low[0] - self.high_bound = self.env.action_space.high[0] - assert self.high_bound > self.low_bound + self.low_bound = self.env.action_space.low + self.high_bound = self.env.action_space.high + assert np.all(self.high_bound >= self.low_bound) if hasattr(env, '_max_episode_steps'): self._max_episode_steps = int(self.env._max_episode_steps) diff --git a/parl/env/tests/continuous_wrappers_test.py b/parl/env/tests/continuous_wrappers_test.py index f7423f200..d1fdeeec7 100644 --- a/parl/env/tests/continuous_wrappers_test.py +++ b/parl/env/tests/continuous_wrappers_test.py @@ -20,7 +20,8 @@ class MockEnv(gym.Env): def __init__(self, low, high): - self.action_space = gym.spaces.Box(low=low, high=high, shape=(3, )) + self.action_space = gym.spaces.Box( + low=np.array(low), high=np.array(high)) self._max_episode_steps = 1000 def step(self, action): @@ -34,21 +35,33 @@ class TestActionMappingWrapper(unittest.TestCase): def test_action_mapping(self): origin_act = np.array([-1.0, 0.0, 1.0]) - env = MockEnv(0.0, 1.0) + env = MockEnv([0.0] * 3, [1.0] * 3) wrapper_env = ActionMappingWrapper(env) wrapper_env.step(origin_act) self.assertListEqual(list(env.action), [0.0, 0.5, 1.0]) - env = MockEnv(-2.0, 2.0) + env = MockEnv([-2.0] * 3, [2.0] * 3) wrapper_env = ActionMappingWrapper(env) wrapper_env.step(origin_act) self.assertListEqual(list(env.action), [-2.0, 0.0, 2.0]) - env = MockEnv(-5.0, 10.0) + env = MockEnv([-5.0] * 3, [10.0] * 3) wrapper_env = ActionMappingWrapper(env) wrapper_env.step(origin_act) self.assertListEqual(list(env.action), [-5.0, 2.5, 10.0]) + # test low bound or high bound is different in different dimensions. + env = MockEnv([0.0, -2.0, -5.0], [1.0, 2.0, 10.0]) + wrapper_env = ActionMappingWrapper(env) + wrapper_env.step(origin_act) + self.assertListEqual(list(env.action), [0.0, 0.0, 10.0]) + + origin_act = np.array([0.0, 0.0, 0.0]) + env = MockEnv([0.0, -2.0, -5.0], [1.0, 2.0, 10.0]) + wrapper_env = ActionMappingWrapper(env) + wrapper_env.step(origin_act) + self.assertListEqual(list(env.action), [0.5, 0.0, 7.5]) + if __name__ == '__main__': unittest.main() From e52cb2e835ebb52e1e319c72c1316e6d872c807e Mon Sep 17 00:00:00 2001 From: zenghsh3 Date: Wed, 30 Jun 2021 15:02:19 +0800 Subject: [PATCH 2/2] fix bug --- parl/env/tests/continuous_wrappers_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parl/env/tests/continuous_wrappers_test.py b/parl/env/tests/continuous_wrappers_test.py index d1fdeeec7..5c09f939b 100644 --- a/parl/env/tests/continuous_wrappers_test.py +++ b/parl/env/tests/continuous_wrappers_test.py @@ -60,7 +60,7 @@ def test_action_mapping(self): env = MockEnv([0.0, -2.0, -5.0], [1.0, 2.0, 10.0]) wrapper_env = ActionMappingWrapper(env) wrapper_env.step(origin_act) - self.assertListEqual(list(env.action), [0.5, 0.0, 7.5]) + self.assertListEqual(list(env.action), [0.5, 0.0, 2.5]) if __name__ == '__main__':