Skip to content

Commit

Permalink
[πŸ› πŸ”¨ ] set_action_for_agent expects a ActionTuple with batch size 1. (#…
Browse files Browse the repository at this point in the history
…5208)

* [Bug Fix] set_action_for_agent expects a ActionTuple with batch size 1.

* moving a line around
  • Loading branch information
vincentpierre authored Mar 31, 2021
1 parent 21548e0 commit aac2ee6
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 13 deletions.
14 changes: 3 additions & 11 deletions ml-agents-envs/mlagents_envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,28 +410,20 @@ def random_action(self, n_agents: int) -> ActionTuple:
return ActionTuple(continuous=_continuous, discrete=_discrete)

def _validate_action(
self, actions: ActionTuple, n_agents: Optional[int], name: str
self, actions: ActionTuple, n_agents: int, name: str
) -> ActionTuple:
"""
Validates that action has the correct action dim
for the correct number of agents and ensures the type.
"""
_expected_shape = (
(n_agents, self.continuous_size)
if n_agents is not None
else (self.continuous_size,)
)
_expected_shape = (n_agents, self.continuous_size)
if actions.continuous.shape != _expected_shape:
raise UnityActionException(
f"The behavior {name} needs a continuous input of dimension "
f"{_expected_shape} for (<number of agents>, <action size>) but "
f"received input of dimension {actions.continuous.shape}"
)
_expected_shape = (
(n_agents, self.discrete_size)
if n_agents is not None
else (self.discrete_size,)
)
_expected_shape = (n_agents, self.discrete_size)
if actions.discrete.shape != _expected_shape:
raise UnityActionException(
f"The behavior {name} needs a discrete input of dimension "
Expand Down
4 changes: 2 additions & 2 deletions ml-agents-envs/mlagents_envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,9 @@ def set_action_for_agent(
if behavior_name not in self._env_state:
return
action_spec = self._env_specs[behavior_name].action_spec
num_agents = len(self._env_state[behavior_name][0])
action = action_spec._validate_action(action, None, behavior_name)
action = action_spec._validate_action(action, 1, behavior_name)
if behavior_name not in self._env_actions:
num_agents = len(self._env_state[behavior_name][0])
self._env_actions[behavior_name] = action_spec.empty_action(num_agents)
try:
index = np.where(self._env_state[behavior_name][0].agent_id == agent_id)[0][
Expand Down
55 changes: 55 additions & 0 deletions ml-agents-envs/mlagents_envs/tests/test_set_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from mlagents_envs.registry import default_registry
from mlagents_envs.side_channel.engine_configuration_channel import (
EngineConfigurationChannel,
)
from mlagents_envs.base_env import ActionTuple
import numpy as np

BALL_ID = "3DBall"


def test_set_action_single_agent():
engine_config_channel = EngineConfigurationChannel()
env = default_registry[BALL_ID].make(
base_port=6000,
worker_id=0,
no_graphics=True,
side_channels=[engine_config_channel],
)
engine_config_channel.set_configuration_parameters(time_scale=100)
for _ in range(3):
env.reset()
behavior_name = list(env.behavior_specs.keys())[0]
d, t = env.get_steps(behavior_name)
for _ in range(50):
for agent_id in d.agent_id:
action = np.ones((1, 2))
action_tuple = ActionTuple()
action_tuple.add_continuous(action)
env.set_action_for_agent(behavior_name, agent_id, action_tuple)
env.step()
d, t = env.get_steps(behavior_name)
env.close()


def test_set_action_multi_agent():
engine_config_channel = EngineConfigurationChannel()
env = default_registry[BALL_ID].make(
base_port=6001,
worker_id=0,
no_graphics=True,
side_channels=[engine_config_channel],
)
engine_config_channel.set_configuration_parameters(time_scale=100)
for _ in range(3):
env.reset()
behavior_name = list(env.behavior_specs.keys())[0]
d, t = env.get_steps(behavior_name)
for _ in range(50):
action = np.ones((len(d), 2))
action_tuple = ActionTuple()
action_tuple.add_continuous(action)
env.set_actions(behavior_name, action_tuple)
env.step()
d, t = env.get_steps(behavior_name)
env.close()

0 comments on commit aac2ee6

Please sign in to comment.