[RLlib] Add "shuffle batch per epoch" option. #47458
Conversation
LGTM. For the future, we might want to critically inspect whether we can reuse code we have written elsewhere and whether we can replace a lot of the hand-rolled iterating with `ray.data` dataset iteration. That would reduce code further and place the logic where it belongs, i.e. hand iterating through data over to `ray.data`.
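For reference, a minimal sketch of what that could look like with `ray.data` (purely illustrative; the dataset contents, shuffle buffer size, and loop structure are assumptions, not part of this PR):

```python
import ray

# Hypothetical stand-in for the collected train batch, as a ray.data Dataset of rows.
train_ds = ray.data.from_items(
    [{"obs": [0.1, 0.2], "action": 0, "reward": 1.0} for _ in range(4000)]
)

# Let ray.data handle minibatch iteration (and optional local shuffling) instead of
# a hand-rolled cyclic iterator.
for epoch in range(10):
    for minibatch in train_ds.iter_batches(
        batch_size=128,
        local_shuffle_buffer_size=1000,  # approximates a per-epoch shuffle
    ):
        ...  # run the learner update on `minibatch`
```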
```diff
@@ -2103,6 +2113,15 @@ def training(
                 stack, this setting should no longer be used. Instead, use
                 `train_batch_size_per_learner` (in combination with
                 `num_learners`).
+            num_epochs: The number of complete passes over the entire train batch (per
```
Awesome! For Offline RL we might want to add here that an epoch might loop over the entire dataset?
Will add!
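For readers of this thread, the options discussed here are plain config settings. A hedged sketch of how they might be set (values are placeholders; only `num_epochs`, `minibatch_size`, and `shuffle_batch_per_epoch` come from this PR's diffs):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(
        train_batch_size_per_learner=4000,
        minibatch_size=128,
        # Number of complete passes over the train batch per training iteration.
        num_epochs=10,
        # Option added by this PR: re-shuffle the train batch before each epoch.
        shuffle_batch_per_epoch=True,
    )
)
```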
rllib/algorithms/appo/appo.py
```diff
@@ -185,7 +188,7 @@ def training(
             target_network_update_freq: The frequency to update the target policy and
                 tune the kl loss coefficients that are used during training. After
                 setting this parameter, the algorithm waits for at least
-                `target_network_update_freq * minibatch_size * num_sgd_iter` number of
+                `target_network_update_freq * minibatch_size * num_epochs` number of
```
I might not completely understand this, but isn't `minibatch_size` the size of a minibatch and not necessarily the number of minibatches per epoch?
Ah, maybe this means the number of samples that have been trained on until we update the target networks, correct?
Oh wait, great catch. I think this comment here is incorrect.
When we update e.g. PPO with a batch of 4000, the `num_env_steps_trained_lifetime` counter only(!) gets increased by those 4000 steps, and NOT by `num_epochs * 4000`. So for APPO here, this is also wrong. Will fix the comment and clarify.
fixed
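To spell out the counting semantics from the comment above, a tiny, purely illustrative sketch (plain Python, not RLlib code; the variable name only mirrors the metric key):

```python
train_batch_size = 4000
num_epochs = 10  # passes over the same batch within one training iteration

num_env_steps_trained_lifetime = 0
for training_iteration in range(3):
    # num_epochs epochs run over the same 4000-step batch here, yet the lifetime
    # counter advances only by the batch size, NOT by num_epochs * 4000.
    num_env_steps_trained_lifetime += train_batch_size

print(num_env_steps_trained_lifetime)  # 12000 after 3 training iterations
```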
```diff
@@ -734,6 +705,9 @@ def training_step(self) -> ResultDict:
                         NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
                     ),
                 },
+                num_epochs=self.config.num_epochs,
+                minibatch_size=self.config.minibatch_size,
+                shuffle_batch_per_epoch=self.config.shuffle_batch_per_epoch,
             )
         else:
             learner_results = self.learner_group.update_from_episodes(
```
I wonder: isn't it possible to just turn a `ray.data.DataIterator` over to the learner via `update_from_iterator` and then iterate over the train batch (as a materialized dataset) in `minibatch_size`-sized batches?
We could run all of this (in the new stack) through the `PreLearner` to prefetch and make the learner connector run.
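A rough sketch of the direction suggested in this thread, i.e. handing a `ray.data` iterator over to the learner group. The exact keyword arguments of `update_from_iterator` and the objects `sample_rows`, `learner_group`, and `config` are placeholders/assumptions, not something this PR implements:

```python
import ray

# Hypothetical: the collected train batch, materialized as a ray.data Dataset.
train_ds = ray.data.from_items(sample_rows)  # `sample_rows`: a list of row dicts

learner_results = learner_group.update_from_iterator(
    train_ds.iterator(),
    # Pull minibatch_size-sized batches from the iterator instead of slicing the
    # batch inside the algorithm's training_step().
    minibatch_size=config.minibatch_size,
    num_iters=config.dataset_num_iters_per_learner,
)
```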
rllib/algorithms/marwil/marwil.py
```diff
@@ -398,7 +398,7 @@ class (multi-/single-learner setup) and evaluation on
             learner_results = self.learner_group.update_from_batch(
                 batch,
                 minibatch_size=self.config.train_batch_size_per_learner,
-                num_iters=self.config.dataset_num_iters_per_learner,
+                num_epochs=self.config.dataset_num_iters_per_learner,
```
Here, for example, this is a bit confusing: `dataset_num_iters_per_learner` is irrelevant in this branch because we pass a batch that is trained on as a whole when there is a single learner. In the multi-learner setup we pass an iterator, and `dataset_num_iters_per_learner` defines how many batches should be pulled from it in a single RLlib training iteration (it is `None` by default, which means the iterator runs over the entire dataset once, i.e. a single epoch, per RLlib training iteration).
I know this is still somewhat messy, but that's because the different entry points of the Learner API are not really aligned with offline RL yet.
Great catch. Clarified the arg names and added this protection to `Learner.update_from_iterator` for now:

```python
if "num_epochs" in kwargs:
    raise ValueError(
        "`num_epochs` arg NOT supported by Learner.update_from_iterator! Use "
        "`num_iters` instead."
    )
```

such that it cannot be confused with a `num_epochs` passed in by accident.
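To make the naming distinction explicit, a hedged usage sketch (all objects are placeholders; the argument names follow the snippets quoted in this thread):

```python
# In-memory batch path: `num_epochs` = complete passes over this one batch.
learner_group.update_from_batch(
    batch,
    minibatch_size=config.minibatch_size,
    num_epochs=config.num_epochs,
    shuffle_batch_per_epoch=config.shuffle_batch_per_epoch,
)

# Streaming/offline path: `num_iters` = number of batches pulled from the iterator
# per RLlib training iteration; passing `num_epochs` here now raises a ValueError.
learner_group.update_from_iterator(
    iterator,
    minibatch_size=config.train_batch_size_per_learner,
    num_iters=config.dataset_num_iters_per_learner,
)
```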
```diff
-                MiniBatchCyclicIterator,
-                uses_new_env_runners=True,
-                num_total_mini_batches=num_total_mini_batches,
+                MiniBatchCyclicIterator, _uses_new_env_runners=True
```
Do we need this here still? I thought we were deprecating the hybrid stack?
Officially, not yet. PR is still pending ...
```diff
-        self._mini_batch_count = 0
-        self._num_total_mini_batches = num_total_mini_batches
+        self._minibatch_count = 0
+        self._num_total_minibatches = num_total_minibatches

     def __iter__(self):
```
In the long run, we might want to override `DataIterator` from `ray.data` to build batches from `MultiAgentEpisode`s. Less code.
```diff
@@ -140,6 +159,11 @@ def get_len(b):
                     n_steps -= len_sample
                     s = 0
                     self._num_covered_epochs[module_id] += 1
```
This is actually the same logic as the `independent` sampling mode in our `MultiAgentEpisodeBuffer`s. In the future we should reduce code here again.
Ah, great catch! We'll get rid of this iterator anyway b/c it relies on `MultiAgentBatch`. We should rewrite it to rely only on a plain dict.
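For illustration, a minimal sketch of a shuffle-per-epoch minibatch iterator over a plain dict (the helper `iter_minibatches` is hypothetical, not the PR's `MiniBatchCyclicIterator`):

```python
import numpy as np

def iter_minibatches(batch, minibatch_size, num_epochs, shuffle_batch_per_epoch=False):
    """Yield minibatches from a plain dict of equal-length numpy arrays."""
    batch_size = len(next(iter(batch.values())))
    for _ in range(num_epochs):
        indices = np.arange(batch_size)
        if shuffle_batch_per_epoch:
            # Re-shuffle the entire train batch at the start of each epoch.
            np.random.shuffle(indices)
        for start in range(0, batch_size, minibatch_size):
            idx = indices[start:start + minibatch_size]
            yield {key: value[idx] for key, value in batch.items()}

# Usage with a toy batch:
batch = {"obs": np.random.rand(8, 4), "actions": np.random.randint(2, size=8)}
for minibatch in iter_minibatches(batch, minibatch_size=4, num_epochs=2,
                                  shuffle_batch_per_epoch=True):
    print(minibatch["obs"].shape)  # (4, 4)
```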
{"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (56, 55)}, | ||
{"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (400, 400)}, | ||
{"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (64, 64)}, | ||
{"minibatch_size": 256, "num_epochs": 30, "agent_steps": (1652, 1463)}, |
Nice test!
Add "shuffle batch per epoch" option.
Why are these changes needed?
Related issue number
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/
under thecorresponding
.rst
file.