import unittest
+import numpy as np

import ray
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.connectors.connector import ActionConnector, ConnectorContext
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.examples.env.debug_counter_env import DebugCounterEnv
-from ray.rllib.examples.env.multi_agent import BasicMultiAgent
+from ray.rllib.examples.env.multi_agent import BasicMultiAgent, GuessTheNumberGame
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import PolicySpec
from ray.tune import register_env
from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch

+from ray.rllib.utils.test_utils import check
+

register_env("basic_multiagent", lambda _: BasicMultiAgent(2))

@@ -92,6 +95,85 @@ def test_sample_batch_rollout_multi_agent_env(self): |
        self.assertEqual(sample_batch.env_steps(), 200)
        self.assertEqual(sample_batch.agent_steps(), 400)

+    def test_guess_the_number_multi_agent(self):
+        """Tests the env runner (EnvRunnerV2) on the GuessTheNumberGame env.
+
+        The policies are deterministic, so we can test for an expected reward.
+        Agent 1 always picks the number 1, and agent 2 always guesses that the
+        picked number is higher than 1; since the picked number is 1, agent 2
+        never guesses correctly and agent 1 wins. The reward is 100 for winning
+        plus 1 for each step the game is dragged on. The game lasts 19 steps
+        for agent 1 before it wins, so its expected reward is 100 + 19 = 119.
+        """
+
+        register_env("env_under_test", lambda config: GuessTheNumberGame(config))
+
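+        # Agent id 0 is the "picker" (pol1 = PickOne); the other agent is the
+        # "guesser" (pol2 = GuessHigherThanOne).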
+        def mapping_fn(agent_id, *args, **kwargs):
+            return "pol1" if agent_id == 0 else "pol2"
+
+        class PickOne(RandomPolicy):
+            """This policy will always pick 1."""
+
+            def compute_actions(
+                self,
+                obs_batch,
+                state_batches=None,
+                prev_action_batch=None,
+                prev_reward_batch=None,
+                **kwargs
+            ):
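+                # [2, 1] is assumed to encode "pick the number 1" in
+                # GuessTheNumberGame's action space.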
+                return [np.array([2, 1])] * len(obs_batch), [], {}
+
+        class GuessHigherThanOne(RandomPolicy):
+            """This policy will guess that the picked number is higher than 1."""
+
+            def compute_actions(
+                self,
+                obs_batch,
+                state_batches=None,
+                prev_action_batch=None,
+                prev_reward_batch=None,
+                **kwargs
+            ):
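+                # [1, 1] is assumed to encode "guess that the picked number is
+                # higher than 1" in GuessTheNumberGame's action space.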
+                return [np.array([1, 1])] * len(obs_batch), [], {}
+
+        config = (
+            PPOConfig()
+            .framework("torch")
+            .environment(disable_env_checking=True, env="env_under_test")
+            .rollouts(
+                num_envs_per_worker=1,
+                num_rollout_workers=0,
+                # Enable EnvRunnerV2.
+                enable_connectors=True,
+                rollout_fragment_length=100,
+            )
+            .multi_agent(
+                # Scripted policies make the test independent of neural networks.
+                policies={
+                    "pol1": PolicySpec(policy_class=PickOne),
+                    "pol2": PolicySpec(policy_class=GuessHigherThanOne),
+                },
+                policy_mapping_fn=mapping_fn,
+            )
+            .debugging(seed=42)
+        )
+
+        algo = PPO(config, env="env_under_test")
+
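+        # No training step is needed here; sample a rollout directly from the
+        # local worker.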
+        rollout_worker = algo.workers.local_worker()
+        sample_batch = rollout_worker.sample()
+        pol1_batch = sample_batch.policy_batches["pol1"]
+
+        # Reward should be 100 (for winning) + 19 (for dragging the game on
+        # for 19 steps).
+        check(pol1_batch["rewards"], 119 * np.ones_like(pol1_batch["rewards"]))
+        # Check that pol1 only has one timestep of transition information per episode.
+        check(len(set(pol1_batch["eps_id"])), len(pol1_batch["eps_id"]))
+        # Check that pol2 has 19 timesteps of transition information per episode.
+        pol2_batch = sample_batch.policy_batches["pol2"]
+        check(len(set(pol2_batch["eps_id"])) * 19, len(pol2_batch["eps_id"]))
+
    def test_inference_batches_are_grouped_by_policy(self):
        # Create 2 policies that have different inference batch shapes.
        class RandomPolicyOne(RandomPolicy):