-
Notifications
You must be signed in to change notification settings - Fork 7k
[rllib] Fix segment_tree.py edge case #57599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
719e1ec
b2d75de
62166e9
e7eb784
a076f58
3e14aa4
31f5500
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -136,7 +136,7 @@ def __setitem__(self, idx: int, val: float) -> None: | |
| Inserts/overwrites a value in/into the tree. | ||
|
|
||
| Args: | ||
| idx: The index to insert to. Must be in [0, `self.capacity`[ | ||
| idx: The index to insert to. Must be in [0, `self.capacity`) | ||
| val: The value to insert. | ||
| """ | ||
| assert 0 <= idx < self.capacity, f"idx={idx} capacity={self.capacity}" | ||
|
|
@@ -192,6 +192,11 @@ def find_prefixsum_idx(self, prefixsum: float) -> int: | |
| # Global sum node. | ||
| idx = 1 | ||
|
|
||
| # Edge case when prefixsum can clip into the invalid regions | ||
| # https://github.com/ray-project/ray/issues/54284 | ||
| if prefixsum >= self.value[idx]: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Awesome! Simple fix for a big problem :) |
||
| prefixsum = self.value[idx] - 1e-5 | ||
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: Prefix Sum Edge Case Handling Fails. The edge case handling in |
||
|
|
||
| # While non-leaf (first half of tree). | ||
| while idx < self.capacity: | ||
| update_idx = 2 * idx | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,9 @@ | ||
| import unittest | ||
|
|
||
| import numpy as np | ||
|
|
||
| from ray.rllib.execution.segment_tree import MinSegmentTree, SumSegmentTree | ||
| from ray.rllib.env.single_agent_episode import SingleAgentEpisode | ||
| from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree | ||
| from ray.rllib.utils.replay_buffers import PrioritizedEpisodeReplayBuffer | ||
|
|
||
|
|
||
| class TestSegmentTree(unittest.TestCase): | ||
|
|
@@ -95,6 +96,55 @@ def test_max_interval_tree(self): | |
| assert np.isclose(tree.min(2, -1), 4.0) | ||
| assert np.isclose(tree.min(3, 4), 3.0) | ||
|
|
||
| @staticmethod | ||
| def _get_episode(episode_len=None, id_=None, with_extra_model_outs=False): | ||
| eps = SingleAgentEpisode(id_=id_, observations=[0.0], infos=[{}]) | ||
| ts = np.random.randint(1, 200) if episode_len is None else episode_len | ||
| for t in range(ts): | ||
| eps.add_env_step( | ||
| observation=float(t + 1), | ||
| action=int(t), | ||
| reward=0.1 * (t + 1), | ||
| infos={}, | ||
| extra_model_outputs=( | ||
| {k: k for k in range(2)} if with_extra_model_outs else None | ||
| ), | ||
| ) | ||
| eps.is_terminated = np.random.random() > 0.5 | ||
| eps.is_truncated = False if eps.is_terminated else np.random.random() > 0.8 | ||
| return eps | ||
|
|
||
| def test_find_prefixsum_idx(self, buffer_size=80): | ||
| """Fix edge case related to https://github.com/ray-project/ray/issues/54284""" | ||
| replay_buffer = PrioritizedEpisodeReplayBuffer(capacity=buffer_size) | ||
| sum_segment = replay_buffer._sum_segment | ||
|
|
||
| for i in range(10): | ||
| replay_buffer.add(self._get_episode(id_=str(i), episode_len=10)) | ||
|
|
||
| self.assertTrue(sum_segment.capacity >= buffer_size) | ||
|
|
||
| # standard cases | ||
| for sample in np.linspace(0, sum_segment.sum(), 50): | ||
| prefixsum_idx = sum_segment.find_prefixsum_idx(sample) | ||
| self.assertTrue( | ||
| prefixsum_idx in replay_buffer._tree_idx_to_sample_idx, | ||
| f"{sum_segment.sum()=}, {sample=}, {prefixsum_idx=}", | ||
| ) | ||
|
|
||
| # Edge cases: at the boundary, the binary tree search can "clip" into invalid regions. | ||
| # Therefore, test using values close to or above the max valid value. | ||
| for sample in [ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Very nice! Could we add a comment on why this case could cause problems on the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added |
||
| sum_segment.sum() - 0.00001, | ||
| sum_segment.sum(), | ||
| sum_segment.sum() + 0.00001, | ||
| ]: | ||
| prefixsum_idx = sum_segment.find_prefixsum_idx(sample) | ||
| self.assertTrue( | ||
| prefixsum_idx in replay_buffer._tree_idx_to_sample_idx, | ||
| f"{sum_segment.sum()=}, {sample=}, {prefixsum_idx=}", | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| import sys | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!