From 7e5ec326c72667f2dc8d6fe73c4c9aeeb7c6207e Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Fri, 28 Jun 2024 12:11:43 +0200 Subject: [PATCH 1/4] Modified sampling for 'batch_mode=complete_episodes' as this was a) not reducing workload when scaled and b) was using 'train_batch_size' neglecting 'train_batch_size_per_learner'. Signed-off-by: simonsays1980 --- rllib/env/single_agent_env_runner.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py index f9b20b498291..76a07e39a171 100644 --- a/rllib/env/single_agent_env_runner.py +++ b/rllib/env/single_agent_env_runner.py @@ -197,19 +197,14 @@ def sample( explore=explore, random_actions=random_actions, ) - # For complete episodes mode, sample as long as the number of timesteps - # done is smaller than the `train_batch_size`. + # For complete episodes mode, sample a single episode and + # leave coordination of sampling to `synchronous_parallel_sample`. else: - total = 0 - samples = [] - while total < self.config.train_batch_size: - episodes = self._sample_episodes( - num_episodes=self.num_envs, - explore=explore, - random_actions=random_actions, - ) - total += sum(len(e) for e in episodes) - samples.extend(episodes) + samples = self._sample_episodes( + num_episodes=1, + explore=explore, + random_actions=random_actions, + ) # Make the `on_sample_end` callback. self._callbacks.on_sample_end( From 7a9184bf6f528dac9e786d3a8376004c9436253e Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Fri, 28 Jun 2024 13:31:56 +0200 Subject: [PATCH 2/4] Added a comment in regard to @sven1977's review. 
Signed-off-by: simonsays1980 --- rllib/env/single_agent_env_runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py index 76a07e39a171..78e36143ea38 100644 --- a/rllib/env/single_agent_env_runner.py +++ b/rllib/env/single_agent_env_runner.py @@ -199,6 +199,9 @@ def sample( ) # For complete episodes mode, sample a single episode and # leave coordination of sampling to `synchronous_parallel_sample`. + # TODO (simon, sven): The coordination will eventually move + # to `EnvRunnerGroup` in the future. So from the algorithm one + # would do `EnvRunnerGroup.sample()`. else: samples = self._sample_episodes( num_episodes=1, From af85b4ac870eb1491b73539d35d5e354d61d6fa3 Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Wed, 3 Jul 2024 14:36:49 +0200 Subject: [PATCH 3/4] Removed an assertion from a callback test that was not passing b/c with 'complete_episodes' sampling happens multiple times until the number of timesteps for the 'train_batch_size' is reached. Signed-off-by: simonsays1980 --- rllib/algorithms/tests/test_callbacks_on_env_runner.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/rllib/algorithms/tests/test_callbacks_on_env_runner.py b/rllib/algorithms/tests/test_callbacks_on_env_runner.py index 898717c6c2b2..8351c22fd07a 100644 --- a/rllib/algorithms/tests/test_callbacks_on_env_runner.py +++ b/rllib/algorithms/tests/test_callbacks_on_env_runner.py @@ -85,7 +85,7 @@ class TestCallbacks(unittest.TestCase): @classmethod def setUpClass(cls): tune.register_env("multi_cart", lambda _: MultiAgentCartPole({"num_agents": 2})) - ray.init() + ray.init(local_mode=True) @classmethod def tearDownClass(cls): @@ -179,9 +179,6 @@ def test_episode_and_sample_callbacks_batch_mode_complete_episodes(self): # Train one iteration. algo.train() - # We must have had exactly one `sample()` call on our EnvRunner. 
- if not multi_agent: - self.assertEqual(callback_obj.counts["sample"], 1) # We should have had at least one episode start. self.assertGreater(callback_obj.counts["start"], 0) # Episode starts must be exact same as episode ends (b/c we always complete From 8169d7b92a1ed632c965df92d7286befff9c42cf Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Wed, 3 Jul 2024 14:37:31 +0200 Subject: [PATCH 4/4] Removed local mode used for debugging the test. Signed-off-by: simonsays1980 --- rllib/algorithms/tests/test_callbacks_on_env_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/algorithms/tests/test_callbacks_on_env_runner.py b/rllib/algorithms/tests/test_callbacks_on_env_runner.py index 8351c22fd07a..6afa874509e0 100644 --- a/rllib/algorithms/tests/test_callbacks_on_env_runner.py +++ b/rllib/algorithms/tests/test_callbacks_on_env_runner.py @@ -85,7 +85,7 @@ class TestCallbacks(unittest.TestCase): @classmethod def setUpClass(cls): tune.register_env("multi_cart", lambda _: MultiAgentCartPole({"num_agents": 2})) - ray.init(local_mode=True) + ray.init() @classmethod def tearDownClass(cls):