Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Add "shuffle batch per epoch" option. #47458

Merged

Conversation

sven1977
Copy link
Contributor

@sven1977 sven1977 commented Sep 3, 2024

Add "shuffle batch per epoch" option.

  • For PPO and any other algo using minibatching AND more than 1 epochs AND per train batch.

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Copy link
Collaborator

@simonsays1980 simonsays1980 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. For the future we might need to inspect critically if we can reuse code we have written elsewhere and if we can replace a lot of iterating by ray.data dataset iteration. This reduces code further and places the logic where it belongs, i.e. iterating through data to ray.data.

@@ -2103,6 +2113,15 @@ def training(
stack, this setting should no longer be used. Instead, use
`train_batch_size_per_learner` (in combination with
`num_learners`).
num_epochs: The number of complete passes over the entire train batch (per
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! For Offline RL we might want to add here that an epoch might loop over the entire dataset?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add!

@@ -185,7 +188,7 @@ def training(
target_network_update_freq: The frequency to update the target policy and
tune the kl loss coefficients that are used during training. After
setting this parameter, the algorithm waits for at least
`target_network_update_freq * minibatch_size * num_sgd_iter` number of
`target_network_update_freq * minibatch_size * num_epochs` number of
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might not completely understand this, but isn't minibatch_size the size of a minibatch and not necessarily the number of minibatches per epoch?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, maybe this means the number of samples that have been trained on until we update the target networks, correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait, great catch. I think this comment here is incorrect.
When we update e.g. PPO with a batch of 4000, the num_env_steps_trained_lifetime counter only(!) gets increased by that 4000, and NOT by: num_epochs * 4000. So for APPO here, this is also wrong. Will fix the comment and clarify.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

@@ -734,6 +705,9 @@ def training_step(self) -> ResultDict:
NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
),
},
num_epochs=self.config.num_epochs,
minibatch_size=self.config.minibatch_size,
shuffle_batch_per_epoch=self.config.shuffle_batch_per_epoch,
)
else:
learner_results = self.learner_group.update_from_episodes(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder: isn't it possible to just turn over a ray.data.DataIterator ti the learner via update_from_iterator and then iterate over the train batch (as a materialized dataset) in minibatch_size batches?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could run all of this (in the new stack) through the PreLearner to prefetch and make the learner connector run.

@@ -398,7 +398,7 @@ class (multi-/single-learner setup) and evaluation on
learner_results = self.learner_group.update_from_batch(
batch,
minibatch_size=self.config.train_batch_size_per_learner,
num_iters=self.config.dataset_num_iters_per_learner,
num_epochs=self.config.dataset_num_iters_per_learner,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here for example this is a bit confusing: the dataset_num_iters_per_learner is here irrelevant because we pass over a batch that is trained on as a whole if this is a single learner. In the multi learner setup we pass an iterator and dataset_num_iters_per_learner defines many batches should be pulled from it in a single RLlib training iteration (this is set to None by default which would mean it runs over the entire dataset once - so only a single epoch - during a single RLlib training iteration).

I know this is somehow still messy, but due to the different entries of the learner API not really aligned with offline RL.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch. Clarified the arg names and added this protection to Learner.update_from_iterator for now:

        if "num_epochs" in kwargs:
            raise ValueError(
                "`num_epochs` arg NOT supported by Learner.update_from_iterator! Use "
                "`num_iters` instead."
            )

such that it cannot be confused with num_epochs passed in by accident.

MiniBatchCyclicIterator,
uses_new_env_runners=True,
num_total_mini_batches=num_total_mini_batches,
MiniBatchCyclicIterator, _uses_new_env_runners=True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this here still? I thought we deprecate the hybrid stack?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Officially, not yet. PR is still pending ...

self._mini_batch_count = 0
self._num_total_mini_batches = num_total_mini_batches
self._minibatch_count = 0
self._num_total_minibatches = num_total_minibatches

def __iter__(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the long run we might want to override DataIterator from ray.data to build batches from MultiAgentEpisodes. Less code.

@@ -140,6 +159,11 @@ def get_len(b):
n_steps -= len_sample
s = 0
self._num_covered_epochs[module_id] += 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually the same logic like independent sampling mode in our MultiAgentEpisodeBuffers. For the future we should reduce code again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, great catch! We'll get rid of this iterator anyways b/c it relies on MultiAgentBatch. We should rewrite it to rely only on a plain dict.

{"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (56, 55)},
{"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (400, 400)},
{"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (64, 64)},
{"minibatch_size": 256, "num_epochs": 30, "agent_steps": (1652, 1463)},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice test!

Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
@sven1977 sven1977 enabled auto-merge (squash) September 4, 2024 10:51
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Sep 4, 2024
Signed-off-by: sven1977 <svenmika1977@gmail.com>
@sven1977 sven1977 enabled auto-merge (squash) September 4, 2024 11:19
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
@sven1977 sven1977 enabled auto-merge (squash) September 4, 2024 14:18
Signed-off-by: sven1977 <svenmika1977@gmail.com>
@sven1977 sven1977 added rllib RLlib related issues rllib-newstack labels Sep 17, 2024
@sven1977 sven1977 enabled auto-merge (squash) September 17, 2024 10:13
Signed-off-by: sven1977 <svenmika1977@gmail.com>
@sven1977 sven1977 enabled auto-merge (squash) September 17, 2024 11:17
@sven1977 sven1977 merged commit ed5b382 into ray-project:master Sep 17, 2024
6 checks passed
@sven1977 sven1977 deleted the add_shuffle_batch_option_to_cyclic_iterator branch September 18, 2024 07:57
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
go add ONLY when ready to merge, run all tests rllib RLlib related issues rllib-newstack
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants