From 719e1ecf34b8dc0550909b66b08a3cd3cf6307b9 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Thu, 9 Oct 2025 14:18:11 +0100 Subject: [PATCH 1/4] [rllib] Fix segment_tree.py edge case Signed-off-by: Mark Towers --- rllib/execution/segment_tree.py | 7 +++- .../test_segment_tree_replay_buffer_api.py | 41 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/rllib/execution/segment_tree.py b/rllib/execution/segment_tree.py index 5e7a5fd102f6..255c99bca82f 100644 --- a/rllib/execution/segment_tree.py +++ b/rllib/execution/segment_tree.py @@ -136,7 +136,7 @@ def __setitem__(self, idx: int, val: float) -> None: Inserts/overwrites a value in/into the tree. Args: - idx: The index to insert to. Must be in [0, `self.capacity`[ + idx: The index to insert to. Must be in [0, `self.capacity`) val: The value to insert. """ assert 0 <= idx < self.capacity, f"idx={idx} capacity={self.capacity}" @@ -192,6 +192,11 @@ def find_prefixsum_idx(self, prefixsum: float) -> int: # Global sum node. idx = 1 + # Edge case when prefixsum can clip into the invalid regions + # https://github.com/ray-project/ray/issues/54284 + if prefixsum >= self.value[idx * 2]: + prefixsum -= 0.0001 + # While non-leaf (first half of tree). 
while idx < self.capacity: update_idx = 2 * idx diff --git a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py b/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py index 17b64bd5b57b..d351e6f1204b 100644 --- a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py +++ b/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py @@ -1,7 +1,9 @@ import numpy as np import unittest +from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree +from ray.rllib.utils.replay_buffers import PrioritizedEpisodeReplayBuffer class TestSegmentTree(unittest.TestCase): @@ -94,6 +96,45 @@ def test_max_interval_tree(self): assert np.isclose(tree.min(2, -1), 4.0) assert np.isclose(tree.min(3, 4), 3.0) + @staticmethod + def _get_episode(episode_len=None, id_=None, with_extra_model_outs=False): + eps = SingleAgentEpisode(id_=id_, observations=[0.0], infos=[{}]) + ts = np.random.randint(1, 200) if episode_len is None else episode_len + for t in range(ts): + eps.add_env_step( + observation=float(t + 1), + action=int(t), + reward=0.1 * (t + 1), + infos={}, + extra_model_outputs=( + {k: k for k in range(2)} if with_extra_model_outs else None + ), + ) + eps.is_terminated = np.random.random() > 0.5 + eps.is_truncated = False if eps.is_terminated else np.random.random() > 0.8 + return eps + + def test_find_prefixsum_idx(self, buffer_size=80): + """Fix edge case related to https://github.com/ray-project/ray/issues/54284""" + replay_buffer = PrioritizedEpisodeReplayBuffer(capacity=buffer_size) + sum_segment = replay_buffer._sum_segment + + for i in range(10): + replay_buffer.add(self._get_episode(id_=str(i), episode_len=10)) + + assert sum_segment.capacity >= buffer_size + + for sample in np.linspace(0, sum_segment.sum(), 50): + prefixsum_idx = sum_segment.find_prefixsum_idx(sample) + assert prefixsum_idx in 
replay_buffer._tree_idx_to_sample_idx + + prefixsum_idx = sum_segment.find_prefixsum_idx(sum_segment.sum() - 0.00001) + assert prefixsum_idx in replay_buffer._tree_idx_to_sample_idx + prefixsum_idx = sum_segment.find_prefixsum_idx(sum_segment.sum()) + assert prefixsum_idx in replay_buffer._tree_idx_to_sample_idx + prefixsum_idx = sum_segment.find_prefixsum_idx(sum_segment.sum() + 0.00001) + assert prefixsum_idx in replay_buffer._tree_idx_to_sample_idx + if __name__ == "__main__": import pytest From b2d75de0ec63b21c3e53769644cb3bdeb24069fc Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Thu, 9 Oct 2025 14:43:38 +0100 Subject: [PATCH 2/4] [rllib] Address cursors and gemini's problems Signed-off-by: Mark Towers --- rllib/execution/segment_tree.py | 4 ++-- .../test_segment_tree_replay_buffer_api.py | 23 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/rllib/execution/segment_tree.py b/rllib/execution/segment_tree.py index 255c99bca82f..5316fe34eecf 100644 --- a/rllib/execution/segment_tree.py +++ b/rllib/execution/segment_tree.py @@ -194,8 +194,8 @@ def find_prefixsum_idx(self, prefixsum: float) -> int: # Edge case when prefixsum can clip into the invalid regions # https://github.com/ray-project/ray/issues/54284 - if prefixsum >= self.value[idx * 2]: - prefixsum -= 0.0001 + if prefixsum >= self.value[idx]: + prefixsum = self.value[idx] - 1e-5 # While non-leaf (first half of tree). 
while idx < self.capacity: diff --git a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py b/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py index d351e6f1204b..b75ee98bba58 100644 --- a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py +++ b/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py @@ -124,16 +124,23 @@ def test_find_prefixsum_idx(self, buffer_size=80): assert sum_segment.capacity >= buffer_size + # standard cases for sample in np.linspace(0, sum_segment.sum(), 50): prefixsum_idx = sum_segment.find_prefixsum_idx(sample) - assert prefixsum_idx in replay_buffer._tree_idx_to_sample_idx - - prefixsum_idx = sum_segment.find_prefixsum_idx(sum_segment.sum() - 0.00001) - assert prefixsum_idx in replay_buffer._tree_idx_to_sample_idx - prefixsum_idx = sum_segment.find_prefixsum_idx(sum_segment.sum()) - assert prefixsum_idx in replay_buffer._tree_idx_to_sample_idx - prefixsum_idx = sum_segment.find_prefixsum_idx(sum_segment.sum() + 0.00001) - assert prefixsum_idx in replay_buffer._tree_idx_to_sample_idx + assert ( + prefixsum_idx in replay_buffer._tree_idx_to_sample_idx + ), f"{sum_segment.sum()=}, {sample=}, {prefixsum_idx=}" + + # edge cases + for sample in [ + sum_segment.sum() - 0.00001, + sum_segment.sum(), + sum_segment.sum() + 0.00001, + ]: + prefixsum_idx = sum_segment.find_prefixsum_idx(sample) + assert ( + prefixsum_idx in replay_buffer._tree_idx_to_sample_idx + ), f"{sum_segment.sum()=}, {sample=}, {prefixsum_idx=}" if __name__ == "__main__": From 62166e9cd4ed4086d7be738f7e4b4d45e1334722 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Tue, 14 Oct 2025 17:03:59 +0100 Subject: [PATCH 3/4] Code review by Simon Signed-off-by: Mark Towers --- .../test_segment_tree_replay_buffer_api.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py 
b/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py index b75ee98bba58..f65fd6d05adc 100644 --- a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py +++ b/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py @@ -122,25 +122,28 @@ def test_find_prefixsum_idx(self, buffer_size=80): for i in range(10): replay_buffer.add(self._get_episode(id_=str(i), episode_len=10)) - assert sum_segment.capacity >= buffer_size + self.assertTrue(sum_segment.capacity >= buffer_size) # standard cases for sample in np.linspace(0, sum_segment.sum(), 50): prefixsum_idx = sum_segment.find_prefixsum_idx(sample) - assert ( - prefixsum_idx in replay_buffer._tree_idx_to_sample_idx - ), f"{sum_segment.sum()=}, {sample=}, {prefixsum_idx=}" + self.assertTrue( + prefixsum_idx in replay_buffer._tree_idx_to_sample_idx, + f"{sum_segment.sum()=}, {sample=}, {prefixsum_idx=}", + ) - # edge cases + # Edge cases: at the boundary, the binary tree can "clip" into invalid regions, + # so test with values just below, equal to, and just above the total sum. for sample in [ sum_segment.sum() - 0.00001, sum_segment.sum(), sum_segment.sum() + 0.00001, ]: prefixsum_idx = sum_segment.find_prefixsum_idx(sample) - assert ( - prefixsum_idx in replay_buffer._tree_idx_to_sample_idx - ), f"{sum_segment.sum()=}, {sample=}, {prefixsum_idx=}" + self.assertTrue( + prefixsum_idx in replay_buffer._tree_idx_to_sample_idx, + f"{sum_segment.sum()=}, {sample=}, {prefixsum_idx=}", + ) if __name__ == "__main__": From 31f55006e0e1496b7bcd932b779e2dc0686372b0 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Tue, 21 Oct 2025 12:51:32 +0100 Subject: [PATCH 4/4] pre-commit Signed-off-by: Mark Towers --- .../replay_buffers/tests/test_segment_tree_replay_buffer_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py
824604dc07a2..9deb9e7f1387 100644 --- a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py +++ b/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py @@ -1,4 +1,5 @@ import unittest +import numpy as np from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree