diff --git a/parl/env/continuous_wrappers.py b/parl/env/continuous_wrappers.py index c61181919..ee8808ad5 100644 --- a/parl/env/continuous_wrappers.py +++ b/parl/env/continuous_wrappers.py @@ -23,15 +23,17 @@ def __init__(self, env): [low_bound, high_bound]. """ gym.Wrapper.__init__(self, env) + assert hasattr( self.env.action_space, 'low'), 'action space should be instance of gym.spaces.Box' assert hasattr( self.env.action_space, 'high'), 'action space should be instance of gym.spaces.Box' - self.low_bound = self.env.action_space.low[0] - self.high_bound = self.env.action_space.high[0] - assert self.high_bound > self.low_bound + self.low_bound = self.env.action_space.low + self.high_bound = self.env.action_space.high + assert np.all(self.high_bound >= self.low_bound) + if hasattr(env, '_max_episode_steps'): self._max_episode_steps = int(self.env._max_episode_steps) diff --git a/parl/env/tests/continuous_wrappers_test.py b/parl/env/tests/continuous_wrappers_test.py index f7423f200..5c09f939b 100644 --- a/parl/env/tests/continuous_wrappers_test.py +++ b/parl/env/tests/continuous_wrappers_test.py @@ -20,7 +20,8 @@ class MockEnv(gym.Env): def __init__(self, low, high): - self.action_space = gym.spaces.Box(low=low, high=high, shape=(3, )) + self.action_space = gym.spaces.Box( + low=np.array(low), high=np.array(high)) self._max_episode_steps = 1000 def step(self, action): @@ -34,21 +35,33 @@ class TestActionMappingWrapper(unittest.TestCase): def test_action_mapping(self): origin_act = np.array([-1.0, 0.0, 1.0]) - env = MockEnv(0.0, 1.0) + env = MockEnv([0.0] * 3, [1.0] * 3) wrapper_env = ActionMappingWrapper(env) wrapper_env.step(origin_act) self.assertListEqual(list(env.action), [0.0, 0.5, 1.0]) - env = MockEnv(-2.0, 2.0) + env = MockEnv([-2.0] * 3, [2.0] * 3) wrapper_env = ActionMappingWrapper(env) wrapper_env.step(origin_act) self.assertListEqual(list(env.action), [-2.0, 0.0, 2.0]) - env = MockEnv(-5.0, 10.0) + env = MockEnv([-5.0] * 3, [10.0] * 3) wrapper_env = ActionMappingWrapper(env) wrapper_env.step(origin_act) self.assertListEqual(list(env.action), [-5.0, 2.5, 10.0]) + # test low bound or high bound is different in different dimensions. + env = MockEnv([0.0, -2.0, -5.0], [1.0, 2.0, 10.0]) + wrapper_env = ActionMappingWrapper(env) + wrapper_env.step(origin_act) + self.assertListEqual(list(env.action), [0.0, 0.0, 10.0]) + + origin_act = np.array([0.0, 0.0, 0.0]) + env = MockEnv([0.0, -2.0, -5.0], [1.0, 2.0, 10.0]) + wrapper_env = ActionMappingWrapper(env) + wrapper_env.step(origin_act) + self.assertListEqual(list(env.action), [0.5, 0.0, 2.5]) + if __name__ == '__main__': unittest.main()