From aac2ee6cb650e6969a6d8b9f7c966f69b9e2df04 Mon Sep 17 00:00:00 2001 From: Vincent-Pierre BERGES Date: Wed, 31 Mar 2021 13:36:27 -0700 Subject: [PATCH] =?UTF-8?q?[=F0=9F=90=9B=20=F0=9F=94=A8=20]=20set=5Faction?= =?UTF-8?q?=5Ffor=5Fagent=20expects=20a=20ActionTuple=20with=20batch=20siz?= =?UTF-8?q?e=201.=20(#5208)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Bug Fix] set_action_for_agent expects a ActionTuple with batch size 1. * moving a line around --- ml-agents-envs/mlagents_envs/base_env.py | 14 +---- ml-agents-envs/mlagents_envs/environment.py | 4 +- .../mlagents_envs/tests/test_set_action.py | 55 +++++++++++++++++++ 3 files changed, 60 insertions(+), 13 deletions(-) create mode 100644 ml-agents-envs/mlagents_envs/tests/test_set_action.py diff --git a/ml-agents-envs/mlagents_envs/base_env.py b/ml-agents-envs/mlagents_envs/base_env.py index 000cef5709..2302491f7f 100644 --- a/ml-agents-envs/mlagents_envs/base_env.py +++ b/ml-agents-envs/mlagents_envs/base_env.py @@ -410,28 +410,20 @@ def random_action(self, n_agents: int) -> ActionTuple: return ActionTuple(continuous=_continuous, discrete=_discrete) def _validate_action( - self, actions: ActionTuple, n_agents: Optional[int], name: str + self, actions: ActionTuple, n_agents: int, name: str ) -> ActionTuple: """ Validates that action has the correct action dim for the correct number of agents and ensures the type. 
""" - _expected_shape = ( - (n_agents, self.continuous_size) - if n_agents is not None - else (self.continuous_size,) - ) + _expected_shape = (n_agents, self.continuous_size) if actions.continuous.shape != _expected_shape: raise UnityActionException( f"The behavior {name} needs a continuous input of dimension " f"{_expected_shape} for (<number of agents>, <action size>) but " f"received input of dimension {actions.continuous.shape}" ) - _expected_shape = ( - (n_agents, self.discrete_size) - if n_agents is not None - else (self.discrete_size,) - ) + _expected_shape = (n_agents, self.discrete_size) if actions.discrete.shape != _expected_shape: raise UnityActionException( f"The behavior {name} needs a discrete input of dimension " diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py index 633f298bb4..ef70eac987 100644 --- a/ml-agents-envs/mlagents_envs/environment.py +++ b/ml-agents-envs/mlagents_envs/environment.py @@ -365,9 +365,9 @@ def set_action_for_agent( if behavior_name not in self._env_state: return action_spec = self._env_specs[behavior_name].action_spec - num_agents = len(self._env_state[behavior_name][0]) - action = action_spec._validate_action(action, None, behavior_name) + action = action_spec._validate_action(action, 1, behavior_name) if behavior_name not in self._env_actions: + num_agents = len(self._env_state[behavior_name][0]) self._env_actions[behavior_name] = action_spec.empty_action(num_agents) try: index = np.where(self._env_state[behavior_name][0].agent_id == agent_id)[0][ diff --git a/ml-agents-envs/mlagents_envs/tests/test_set_action.py b/ml-agents-envs/mlagents_envs/tests/test_set_action.py new file mode 100644 index 0000000000..227e074c72 --- /dev/null +++ b/ml-agents-envs/mlagents_envs/tests/test_set_action.py @@ -0,0 +1,55 @@ +from mlagents_envs.registry import default_registry +from mlagents_envs.side_channel.engine_configuration_channel import ( + EngineConfigurationChannel, +) +from mlagents_envs.base_env import
ActionTuple +import numpy as np + +BALL_ID = "3DBall" + + +def test_set_action_single_agent(): + engine_config_channel = EngineConfigurationChannel() + env = default_registry[BALL_ID].make( + base_port=6000, + worker_id=0, + no_graphics=True, + side_channels=[engine_config_channel], + ) + engine_config_channel.set_configuration_parameters(time_scale=100) + for _ in range(3): + env.reset() + behavior_name = list(env.behavior_specs.keys())[0] + d, t = env.get_steps(behavior_name) + for _ in range(50): + for agent_id in d.agent_id: + action = np.ones((1, 2)) + action_tuple = ActionTuple() + action_tuple.add_continuous(action) + env.set_action_for_agent(behavior_name, agent_id, action_tuple) + env.step() + d, t = env.get_steps(behavior_name) + env.close() + + +def test_set_action_multi_agent(): + engine_config_channel = EngineConfigurationChannel() + env = default_registry[BALL_ID].make( + base_port=6001, + worker_id=0, + no_graphics=True, + side_channels=[engine_config_channel], + ) + engine_config_channel.set_configuration_parameters(time_scale=100) + for _ in range(3): + env.reset() + behavior_name = list(env.behavior_specs.keys())[0] + d, t = env.get_steps(behavior_name) + for _ in range(50): + action = np.ones((len(d), 2)) + action_tuple = ActionTuple() + action_tuple.add_continuous(action) + env.set_actions(behavior_name, action_tuple) + env.step() + d, t = env.get_steps(behavior_name) + env.close()