Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottower committed Nov 15, 2023
1 parent 48fbb7d commit a5c7dbc
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def add_agent(self, type):
self.rewards[agent] = 0
self._cumulative_rewards[agent] = 0
num_actions = self._act_spaces[type].n
self.infos[agent] = {"action_mask": np.eye(num_actions)[self.np_random.choice(num_actions)].astype(np.int8)}
self.infos[agent] = {
"action_mask": np.eye(num_actions)[
self.np_random.choice(num_actions)
].astype(np.int8)
}
return agent

def reset(self, seed=None, options=None):
Expand Down Expand Up @@ -152,7 +156,11 @@ def step(self, action):
# Sample info action mask randomly
type = self.agent_selection.split("_")[0]
num_actions = self._act_spaces[type].n
self.infos[self.agent_selection] = {"action_mask": np.eye(num_actions)[self.np_random.choice(num_actions)].astype(np.int8)}
self.infos[self.agent_selection] = {
"action_mask": np.eye(num_actions)[
self.np_random.choice(num_actions)
].astype(np.int8)
}

# Cycle agents
self.agent_selection = self._agent_selector.next()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def add_type(self):
obs_space = gymnasium.spaces.Dict(
{
"observation": gymnasium.spaces.Box(low=0, high=1, shape=(obs_size,)),
"action_mask": gymnasium.spaces.Box(low=0, high=1, shape=(num_actions,), dtype=np.int8),
"action_mask": gymnasium.spaces.Box(
low=0, high=1, shape=(num_actions,), dtype=np.int8
),
}
)
act_space = gymnasium.spaces.Discrete(num_actions)
Expand Down
79 changes: 43 additions & 36 deletions test/action_mask_test.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,43 @@
from typing import Type

import pytest
from pettingzoo.utils.env import AECEnv

from pettingzoo.test import seed_test, api_test

from pettingzoo.test.example_envs import generated_agents_env_action_mask_obs_v0, generated_agents_env_action_mask_info_v0

@pytest.mark.parametrize("env_constructor", [generated_agents_env_action_mask_info_v0.env, generated_agents_env_action_mask_obs_v0.env])
def test_action_mask(env_constructor: Type[AECEnv]):
"""Test that environments function deterministically in cases where action mask is in observation, or in info."""
seed_test(env_constructor)
api_test(env_constructor())

# Step through the environment according to example code given in AEC documentation (following action mask)
env = env_constructor()
env.reset(seed=42)
for agent in env.agent_iter():
observation, reward, termination, truncation, info = env.last()

if termination or truncation:
action = None
else:
# invalid action masking is optional and environment-dependent
if "action_mask" in info:
mask = info["action_mask"]
elif isinstance(observation, dict) and "action_mask" in observation:
mask = observation["action_mask"]
else:
mask = None
action = env.action_space(agent).sample(mask)
env.step(action)
env.close()


from typing import Type

import pytest

from pettingzoo.test import api_test, seed_test
from pettingzoo.test.example_envs import (
generated_agents_env_action_mask_info_v0,
generated_agents_env_action_mask_obs_v0,
)
from pettingzoo.utils.env import AECEnv


@pytest.mark.parametrize(
"env_constructor",
[
generated_agents_env_action_mask_info_v0.env,
generated_agents_env_action_mask_obs_v0.env,
],
)
def test_action_mask(env_constructor: Type[AECEnv]):
"""Test that environments function deterministically in cases where action mask is in observation, or in info."""
seed_test(env_constructor)
api_test(env_constructor())

# Step through the environment according to example code given in AEC documentation (following action mask)
env = env_constructor()
env.reset(seed=42)
for agent in env.agent_iter():
observation, reward, termination, truncation, info = env.last()

if termination or truncation:
action = None
else:
# invalid action masking is optional and environment-dependent
if "action_mask" in info:
mask = info["action_mask"]
elif isinstance(observation, dict) and "action_mask" in observation:
mask = observation["action_mask"]
else:
mask = None
action = env.action_space(agent).sample(mask)
env.step(action)
env.close()

0 comments on commit a5c7dbc

Please sign in to comment.