-
Notifications
You must be signed in to change notification settings - Fork 7k
[rllib] Fix segment_tree.py edge case #57599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[rllib] Fix segment_tree.py edge case #57599
Conversation
Signed-off-by: Mark Towers <mark@anyscale.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request addresses an edge case in SegmentTree.find_prefixsum_idx that could lead to returning an invalid index, particularly when the prefixsum value aligns with the sum of a subtree. The fix involves a small adjustment to the prefixsum value. The addition of a unit test to reproduce and verify the fix is a great step. My feedback focuses on making the logical fix more robust and less dependent on a fixed magic number, which could be brittle under certain conditions.
Signed-off-by: Mark Towers <mark@anyscale.com>
simonsays1980
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome PR @pseudo-rnd-thoughts!! Thanks for fixing this hard-to-find one.
| Args: | ||
| idx: The index to insert to. Must be in [0, `self.capacity`[ | ||
| idx: The index to insert to. Must be in [0, `self.capacity`) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
|
|
||
| # Edge case when prefixsum can clip into the invalid regions | ||
| # https://github.com/ray-project/ray/issues/54284 | ||
| if prefixsum >= self.value[idx]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome! Simple fix for a big problem :)
rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py
Outdated
Show resolved
Hide resolved
| sum_segment.sum() + 0.00001, | ||
| ]: | ||
| prefixsum_idx = sum_segment.find_prefixsum_idx(sample) | ||
| assert ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, like above. Could we use the TestCase checking methods or the one from us? Or is there a specific reason why you would prefer the assert here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm more familiar with pytest rather than unittest and the rest of the file uses assert rather than self.assertTrue.
I've updated to use TestCase implementation
| ), f"{sum_segment.sum()=}, {sample=}, {prefixsum_idx=}" | ||
|
|
||
| # edge cases | ||
| for sample in [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice! Could we add a comment of why this case could cause problems on the SumSegment?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added
Signed-off-by: Mark Towers <mark@anyscale.com>
simonsays1980
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
# Conflicts: # rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py
Signed-off-by: Mark Towers <mark@anyscale.com>
| # Edge case when prefixsum can clip into the invalid regions | ||
| # https://github.com/ray-project/ray/issues/54284 | ||
| if prefixsum >= self.value[idx]: | ||
| prefixsum = self.value[idx] - 1e-5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: Prefix Sum Edge Case Handling Fails
The edge case handling in _sample_prefixsum incorrectly modifies the prefixsum. It clamps valid input values that are slightly above the total sum, and can also result in a negative prefixsum when the total sum is very small or zero. This leads to assertion failures and logically inconsistent states for tree traversal.
## Why are these changes needed? `SegmentTree` is a component of the rllib `PrioritizedEpisodeReplayBuffer` however for extreme edge case prefix sum values then `find_prefixsum_idx` will return invalid out of bounds value. I couldn't find a bug, rather if the `prefixsum_value` is equal to the `SegmentTree.sum()` then traversing down the tree could cause it to return invalid indexes. I've added unittests to reproduce the original error and check against it ## Related issue number Close ray-project#54284 ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run pre-commit jobs to lint the changes in this PR. ([pre-commit setup](https://docs.ray.io/en/latest/ray-contribute/getting-involved.html#lint-and-formatting)) - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [x] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Mark Towers <mark@anyscale.com> Co-authored-by: Mark Towers <mark@anyscale.com> Co-authored-by: simonsays1980 <simon.zehnder@gmail.com> Signed-off-by: xgui <xgui@anyscale.com>
## Why are these changes needed? `SegmentTree` is a component of the rllib `PrioritizedEpisodeReplayBuffer` however for extreme edge case prefix sum values then `find_prefixsum_idx` will return invalid out of bounds value. I couldn't find a bug, rather if the `prefixsum_value` is equal to the `SegmentTree.sum()` then traversing down the tree could cause it to return invalid indexes. I've added unittests to reproduce the original error and check against it ## Related issue number Close #54284 ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run pre-commit jobs to lint the changes in this PR. ([pre-commit setup](https://docs.ray.io/en/latest/ray-contribute/getting-involved.html#lint-and-formatting)) - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [x] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Mark Towers <mark@anyscale.com> Co-authored-by: Mark Towers <mark@anyscale.com> Co-authored-by: simonsays1980 <simon.zehnder@gmail.com> Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
## Why are these changes needed? `SegmentTree` is a component of the rllib `PrioritizedEpisodeReplayBuffer` however for extreme edge case prefix sum values then `find_prefixsum_idx` will return invalid out of bounds value. I couldn't find a bug, rather if the `prefixsum_value` is equal to the `SegmentTree.sum()` then traversing down the tree could cause it to return invalid indexes. I've added unittests to reproduce the original error and check against it ## Related issue number Close ray-project#54284 ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run pre-commit jobs to lint the changes in this PR. ([pre-commit setup](https://docs.ray.io/en/latest/ray-contribute/getting-involved.html#lint-and-formatting)) - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [x] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Mark Towers <mark@anyscale.com> Co-authored-by: Mark Towers <mark@anyscale.com> Co-authored-by: simonsays1980 <simon.zehnder@gmail.com>
## Why are these changes needed? `SegmentTree` is a component of the rllib `PrioritizedEpisodeReplayBuffer` however for extreme edge case prefix sum values then `find_prefixsum_idx` will return invalid out of bounds value. I couldn't find a bug, rather if the `prefixsum_value` is equal to the `SegmentTree.sum()` then traversing down the tree could cause it to return invalid indexes. I've added unittests to reproduce the original error and check against it ## Related issue number Close ray-project#54284 ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run pre-commit jobs to lint the changes in this PR. ([pre-commit setup](https://docs.ray.io/en/latest/ray-contribute/getting-involved.html#lint-and-formatting)) - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [x] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Mark Towers <mark@anyscale.com> Co-authored-by: Mark Towers <mark@anyscale.com> Co-authored-by: simonsays1980 <simon.zehnder@gmail.com> Signed-off-by: Aydin Abiar <aydin@anyscale.com>
Why are these changes needed?
SegmentTreeis a component of the rllibPrioritizedEpisodeReplayBufferhowever for extreme edge case prefix sum values thenfind_prefixsum_idxwill return invalid out of bounds value.I couldn't find a bug, rather if the
prefixsum_valueis equal to theSegmentTree.sum()then traversing down the tree could cause it to return invalid indexes.I've added unittests to reproduce the original error and check against it
Related issue number
Close #54284
Checks
git commit -s) in this PR.method in Tune, I've added it in
doc/source/tune/api/under thecorresponding
.rstfile.