
Commit 4267ac2

Jun Gong and kouroshHakha authored
[RLlib] Fix reward collection for OpenSpiel games (#31156)
Signed-off-by: Jun Gong <jungong@anyscale.com>
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Co-authored-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
1 parent 80d0bc7 commit 4267ac2

File tree

3 files changed: +172 -3 lines changed


rllib/evaluation/env_runner_v2.py

Lines changed: 9 additions & 2 deletions
@@ -543,13 +543,20 @@ def _process_observations(
                 # Create a fake observation by sampling the original env
                 # observation space.
                 obs_space = get_original_space(policy.observation_space)
+                # Although there is no obs for this agent, there may be
+                # good rewards and info dicts for it.
+                # This is the case for e.g. OpenSpiel games, where a reward
+                # is only earned with the last step, but the obs for that
+                # step is {}.
+                reward = rewards[env_id].get(agent_id, 0.0)
+                info = infos[env_id].get(agent_id, {})
                 values_dict = {
                     SampleBatch.T: episode.length,
                     SampleBatch.ENV_ID: env_id,
                     SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
-                    SampleBatch.REWARDS: 0.0,
+                    SampleBatch.REWARDS: reward,
                     SampleBatch.DONES: True,
-                    SampleBatch.INFOS: {},
+                    SampleBatch.INFOS: info,
                     SampleBatch.NEXT_OBS: obs_space.sample(),
                 }

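For context, a minimal sketch of the fallback these added lines implement, using made-up data rather than RLlib internals: in turn-based games such as those from OpenSpiel, the terminal step can report a reward and info dict for an agent that receives no final observation, so those values are now looked up with a default instead of being hard-coded to 0.0 and {}.

# Hypothetical per-env reward/info dicts as a turn-based game might report them
# on its terminal step: agent "low" earns a reward even though it gets no final
# observation.
rewards = {0: {"low": -1.0, "high": 1.0}}  # env_id -> agent_id -> reward
infos = {0: {"high": {"winner": True}}}    # no info dict emitted for "low" here

env_id, agent_id = 0, "low"

# The previous code hard-coded 0.0 and {}, dropping the terminal reward.
reward = rewards[env_id].get(agent_id, 0.0)  # -> -1.0
info = infos[env_id].get(agent_id, {})       # -> {}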

rllib/evaluation/tests/test_env_runner_v2.py

Lines changed: 83 additions & 1 deletion
@@ -1,17 +1,20 @@
 import unittest
+import numpy as np

 import ray
 from ray.rllib.algorithms.callbacks import DefaultCallbacks
 from ray.rllib.algorithms.ppo import PPO, PPOConfig
 from ray.rllib.connectors.connector import ActionConnector, ConnectorContext
 from ray.rllib.evaluation.metrics import RolloutMetrics
 from ray.rllib.examples.env.debug_counter_env import DebugCounterEnv
-from ray.rllib.examples.env.multi_agent import BasicMultiAgent
+from ray.rllib.examples.env.multi_agent import BasicMultiAgent, GuessTheNumberGame
 from ray.rllib.examples.policy.random_policy import RandomPolicy
 from ray.rllib.policy.policy import PolicySpec
 from ray.tune import register_env
 from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch

+from ray.rllib.utils.test_utils import check
+

 register_env("basic_multiagent", lambda _: BasicMultiAgent(2))

@@ -92,6 +95,85 @@ def test_sample_batch_rollout_multi_agent_env(self):
         self.assertEqual(sample_batch.env_steps(), 200)
         self.assertEqual(sample_batch.agent_steps(), 400)

+    def test_guess_the_number_multi_agent(self):
+        """Tests the env runner on the GuessTheNumberGame env.
+
+        The policies are chosen to be deterministic, so that we can test for an
+        expected reward. Agent 0 will always pick 1, and agent 1 will always
+        guess that the picked number is higher than 1, so it never guesses
+        correctly and agent 0 wins once the step limit is reached. The reward
+        is 100 for winning plus 1 for each step that the game is dragged on
+        for, so the expected reward for agent 0 is 100 + 19 = 119, where 19 is
+        the number of steps the game lasts before agent 0 wins.
+        """
+
+        register_env("env_under_test", lambda config: GuessTheNumberGame(config))
+
+        def mapping_fn(agent_id, *args, **kwargs):
+            return "pol1" if agent_id == 0 else "pol2"
+
+        class PickOne(RandomPolicy):
+            """This policy will always pick 1."""
+
+            def compute_actions(
+                self,
+                obs_batch,
+                state_batches=None,
+                prev_action_batch=None,
+                prev_reward_batch=None,
+                **kwargs
+            ):
+                return [np.array([2, 1])] * len(obs_batch), [], {}
+
+        class GuessHigherThanOne(RandomPolicy):
+            """This policy will guess that the picked number is higher than 1."""
+
+            def compute_actions(
+                self,
+                obs_batch,
+                state_batches=None,
+                prev_action_batch=None,
+                prev_reward_batch=None,
+                **kwargs
+            ):
+                return [np.array([1, 1])] * len(obs_batch), [], {}
+
+        config = (
+            PPOConfig()
+            .framework("torch")
+            .environment(disable_env_checking=True, env="env_under_test")
+            .rollouts(
+                num_envs_per_worker=1,
+                num_rollout_workers=0,
+                # Enable EnvRunnerV2.
+                enable_connectors=True,
+                rollout_fragment_length=100,
+            )
+            .multi_agent(
+                # This makes the test independent of neural networks.
+                policies={
+                    "pol1": PolicySpec(policy_class=PickOne),
+                    "pol2": PolicySpec(policy_class=GuessHigherThanOne),
+                },
+                policy_mapping_fn=mapping_fn,
+            )
+            .debugging(seed=42)
+        )
+
+        algo = PPO(config, env="env_under_test")
+
+        rollout_worker = algo.workers.local_worker()
+        sample_batch = rollout_worker.sample()
+        pol1_batch = sample_batch.policy_batches["pol1"]
+
+        # Reward should be 100 (for winning) + 19 (for dragging the game on for
+        # 19 steps).
+        check(pol1_batch["rewards"], 119 * np.ones_like(pol1_batch["rewards"]))
+        # Check that pol1 only has one timestep of transition information per
+        # episode.
+        check(len(set(pol1_batch["eps_id"])), len(pol1_batch["eps_id"]))
+        # Check that pol2 has 19 timesteps of transition information per episode.
+        pol2_batch = sample_batch.policy_batches["pol2"]
+        check(len(set(pol2_batch["eps_id"])) * 19, len(pol2_batch["eps_id"]))
+
     def test_inference_batches_are_grouped_by_policy(self):
         # Create 2 policies that have different inference batch shapes.
         class RandomPolicyOne(RandomPolicy):
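To make the test docstring's arithmetic concrete, here is a hedged sketch (not part of the commit) that replays the test's deterministic actions directly against the raw GuessTheNumberGame env: agent 0 picks 1, agent 1 keeps asking whether the picked number is higher than 1, the game runs to its 20-step limit, and agent 0 collects 1 per wasted step plus 100 for winning.

import numpy as np

from ray.rllib.examples.env.multi_agent import GuessTheNumberGame

env = GuessTheNumberGame(config={})  # defaults: MAX_NUMBER=3, MAX_STEPS=20
env.reset()
env.step({0: np.array([2, 1])})  # agent 0 picks the number 1

agent_0_return = 0.0
done = {"__all__": False}
while not done["__all__"]:
    # Agent 1 always asks: "is the picked number higher than 1?" -> answer is no.
    _, reward, done, _ = env.step({1: np.array([1, 1])})
    agent_0_return += reward[0]

print(agent_0_return)  # 19 * 1 + 100 = 119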

rllib/examples/env/multi_agent.py

Lines changed: 80 additions & 0 deletions
@@ -289,6 +289,86 @@ def step(self, action_dict):
         return obs, rew, done, info


+class GuessTheNumberGame(MultiAgentEnv):
+    """
+    We have two players, 0 and 1. Agent 0 has to pick a number between 0 and
+    MAX-1 at reset. Agent 1 has to guess the number by asking N questions of
+    the form "is <number> lower than | higher than | equal to the picked
+    number?". The action space is MultiDiscrete [3, MAX]. For the first index,
+    0 means lower, 1 means higher and 2 means equal. The environment answers
+    with yes (1) or no (0) via the reward function. For every time step that
+    agent 1 wastes, agent 0 gets a reward of 1. After N steps the game is
+    terminated. If agent 1 guesses the number correctly, it gets a reward of
+    100 points, otherwise it gets a reward of 0. On the other hand, if agent 0
+    wins, it gets 100 points. The optimal policy controlling agent 1 should
+    converge to a binary search strategy.
+    """
+
+    MAX_NUMBER = 3
+    MAX_STEPS = 20
+
+    def __init__(self, config):
+        super().__init__()
+        self._agent_ids = {0, 1}
+
+        self.max_number = config.get("max_number", self.MAX_NUMBER)
+        self.max_steps = config.get("max_steps", self.MAX_STEPS)
+
+        self._number = None
+        self.observation_space = gym.spaces.Discrete(2)
+        self.action_space = gym.spaces.MultiDiscrete([3, self.max_number])
+
+    def reset(self):
+        self._step = 0
+        self._number = None
+        # Agent 0 has to pick a number, so the returned obs does not matter.
+        return {0: 0}
+
+    def step(self, action_dict):
+        # Get agent 0's action.
+        agent_0_action = action_dict.get(0)
+
+        if agent_0_action is not None:
+            # Ignore the first part of the action and look at the number.
+            self._number = agent_0_action[1]
+            # The next obs should tell agent 1 to start guessing.
+            # The returned reward and dones should be on agent 0, who picked a
+            # number.
+            return {1: 0}, {0: 0}, {0: False, "__all__": False}, {}
+
+        if self._number is None:
+            raise ValueError(
+                "No number is selected by agent 0. Have you restarted "
+                "the environment?"
+            )
+
+        # Get agent 1's action.
+        direction, number = action_dict.get(1)
+        info = {}
+        # Always the same; we don't need agent 0 to act ever again, agent 1
+        # should keep guessing.
+        obs = {1: 0}
+        guessed_correctly = False
+        # Every time agent 1 does not guess correctly, agent 0 gets a reward of 1.
+        if direction == 0:  # lower
+            reward = {1: int(number > self._number), 0: 1}
+            done = {1: False, "__all__": False}
+        elif direction == 1:  # higher
+            reward = {1: int(number < self._number), 0: 1}
+            done = {1: False, "__all__": False}
+        else:  # equal
+            guessed_correctly = number == self._number
+            reward = {1: guessed_correctly * 100, 0: guessed_correctly * -100}
+            done = {1: guessed_correctly, "__all__": guessed_correctly}
+
+        self._step += 1
+        if self._step >= self.max_steps:  # Max number of steps: episode is over.
+            done["__all__"] = True
+            if not guessed_correctly:
+                reward[0] = 100  # Agent 0 wins.
+        return obs, reward, done, info
+
+
 MultiAgentCartPole = make_multi_agent("CartPole-v1")
 MultiAgentMountainCar = make_multi_agent("MountainCarContinuous-v0")
 MultiAgentPendulum = make_multi_agent("Pendulum-v1")
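The class docstring says the optimal policy for agent 1 should converge to a binary search strategy. Purely as an illustration (not part of the commit, and with an arbitrary config), a scripted binary-search guesser can interact with the env like this:

import numpy as np

from ray.rllib.examples.env.multi_agent import GuessTheNumberGame

env = GuessTheNumberGame(config={"max_number": 16})
env.reset()
# Phase 1: agent 0 picks a number (the second action component; the first
# component is ignored for agent 0).
obs, reward, done, info = env.step({0: np.array([2, 7])})

lo, hi = 0, env.max_number - 1
while not done["__all__"]:
    if lo == hi:
        # Only one candidate left: guess "equal" (direction 2).
        obs, reward, done, info = env.step({1: np.array([2, lo])})
    else:
        mid = (lo + hi) // 2
        # Ask: "is the picked number higher than mid?" (direction 1).
        obs, reward, done, info = env.step({1: np.array([1, mid])})
        if reward[1] == 1:  # yes -> the picked number is above mid
            lo = mid + 1
        else:               # no -> the picked number is mid or below
            hi = mid

print("guessed correctly:", reward[1] == 100)  # True after ~log2(16) questions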

0 commit comments