-
Notifications
You must be signed in to change notification settings - Fork 7k
[RLlib] Fix access to self._minibatch_size #58595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Fix access to self._minibatch_size #58595
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request correctly fixes an issue where an incorrect attribute self._minibatch_size was accessed in the IMPALA algorithm's configuration validation. The correct attribute self.minibatch_size is now used. Additionally, a new test is introduced to cover this validation logic, which was previously untested. The fix is correct and the addition of the test is a great improvement. I've added one suggestion to enhance the new test by parameterizing it and adding another validation scenario to make it more comprehensive and maintainable.
| def test_impala_minibatch_size_check(self): | ||
| config = ( | ||
| impala.IMPALAConfig() | ||
| .environment("CartPole-v1") | ||
| .training(minibatch_size=100) | ||
| .env_runners(rollout_fragment_length=30) | ||
| ) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, | ||
| match=r"`minibatch_size` \(100\) must either be None or a multiple of `rollout_fragment_length` \(30\)", | ||
| ): | ||
| config.validate() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great job adding a test for this validation logic! To make it even more robust and easier to maintain, I suggest parameterizing the test using pytest.mark.parametrize. This allows us to easily cover more scenarios. I've included a second test case to also validate the check against total_train_batch_size, which is part of the same validation logic.
@pytest.mark.parametrize(
"minibatch_size, rollout_fragment_length, match",
[
(100, 30, r"`minibatch_size` \(100\) must either be None or a multiple of `rollout_fragment_length` \(30\)"),
(600, 30, r"smaller than or equal to `total_train_batch_size` \(500\)"),
],
)
def test_impala_minibatch_size_check(
self, minibatch_size, rollout_fragment_length, match
):
config = (
impala.IMPALAConfig()
.environment("CartPole-v1")
.training(minibatch_size=minibatch_size)
.env_runners(rollout_fragment_length=rollout_fragment_length)
)
with pytest.raises(ValueError, match=match):
config.validate()## Description In IMPALA, we access an attribute `self._minibatch_size` which does not exist anymore. It should be `self._minibatch_size`. While this check is nice, it's effectively untested code. This PR introduces a test that adds a small test that triggers the relevant code path.
## Description In IMPALA, we access an attribute `self._minibatch_size` which does not exist anymore. It should be `self._minibatch_size`. While this check is nice, it's effectively untested code. This PR introduces a test that adds a small test that triggers the relevant code path. Signed-off-by: Aydin Abiar <aydin@anyscale.com>
## Description In IMPALA, we access an attribute `self._minibatch_size` which does not exist anymore. It should be `self._minibatch_size`. While this check is nice, it's effectively untested code. This PR introduces a test that adds a small test that triggers the relevant code path. Signed-off-by: YK <1811651+ykdojo@users.noreply.github.com>
## Description In IMPALA, we access an attribute `self._minibatch_size` which does not exist anymore. It should be `self._minibatch_size`. While this check is nice, it's effectively untested code. This PR introduces a test that adds a small test that triggers the relevant code path.
Description
In IMPALA, we access an attribute
self._minibatch_sizewhich does not exist anymore.It should be
self._minibatch_size. While this check is nice, it's effectively untested code.This PR introduces a test that adds a small test that triggers the relevant code path.