-
Notifications
You must be signed in to change notification settings - Fork 7k
[rllib] Add support for complex observations in SingleAgentEpisode
#57017
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[rllib] Add support for complex observations in SingleAgentEpisode
#57017
Conversation
Signed-off-by: Mark Towers <mark@anyscale.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request aims to add support for complex observation structures in SingleAgentEpisode.concat. The approach of using tree.flatten is a good direction, but the current implementation of the equality check is flawed and will raise a ValueError for nested observations containing numpy arrays. I've provided a critical fix for this issue. Additionally, the new test case to verify this functionality is not as robust as it could be, as it compares an object to itself rather than checking for value equality. I've added a comment with a suggestion to improve the test's reliability.
| assert len(episode_1) == 4 | ||
|
|
||
| # cut episode 1 to create episode 2 | ||
| episode_2 = episode_1.cut() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test for concatenating episodes with complex observations is not as robust as it could be. By creating episode_2 using episode_1.cut(), the overlapping observation (episode_2.observations[0]) is a reference to episode_1.observations[-1], not a deep copy. This means the assertion in concat_episode is testing for object identity rather than value equality.
To make this test stronger and ensure it correctly validates the logic for complex observation structures, consider constructing episode_2 in a way that it holds a deep copy of the overlapping observation. This will properly test the value comparison logic.
Signed-off-by: Mark Towers <mark@anyscale.com>
simonsays1980
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for these improvements and the test @pseudo-rnd-thoughts
|
|
||
| # Make sure, end matches other episode chunk's beginning. | ||
| assert np.all(other.observations[0] == self.observations[-1]) | ||
| tree.assert_same_structure(other.observations[0], self.observations[-1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sweet!
| tree.assert_same_structure(other.observations[0], self.observations[-1]) | ||
| # Use tree.map_structure with np.array_equal to check every leaf node are equivalent | ||
| # then np.all on flatten to validate all are tree | ||
| assert np.all( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
| # id(episode_1.observations[105])) | ||
|
|
||
| def test_concat_episode_with_complex_obs(self): | ||
| """Tests if concatenation of two `SingleAgentEpisode`s works with complex observations (e.g. dict).""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great! The Dict space cases are not fully covered, yet. Thanks for the initiative!
…ay-project#57017) ## Why are these changes needed? `SingleAgentEpisode.concat` would only support numpy array based observations due to `np.all(old_episode.observations[-1] == new_episode.observations[0])`. I've changed the implementation to use `tree.assert_same_structure` and `np.all` on the flatten structures to verify that observations are equivalent even for complex observation structures. In addition, I've added a test using a dict obs environment to verify this works. ## Related issue number Closes ray-project#54659 ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run `scripts/format.sh` to lint the changes in this PR. - [x] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [ ] Unit tests - [ ] Release tests - [x] This PR is not tested :( <!-- CURSOR_SUMMARY --> --- > [!NOTE] > Use structure-aware equality for observations during episode concatenation and add a test with dict observations; minor docstring tweaks. > > - **rllib/env**: > - **`single_agent_episode.py`**: > - `concat_episode`: Replace `np.all(a == b)` with `tree.assert_same_structure` and per-leaf `np.array_equal` to compare complex/nested observations. > - Add `tree` import. > - Minor docstring wording tweaks for `len_lookback_buffer`. > - **Tests**: > - **`rllib/env/tests/test_single_agent_episode.py`**: > - Add `DictTestEnv` and `test_concat_episode_with_complex_obs` to validate concatenation with dict observations. > - Fix test class name typo. > > <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit dc4856f. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup> <!-- /CURSOR_SUMMARY --> --------- Signed-off-by: Mark Towers <mark@anyscale.com> Co-authored-by: Mark Towers <mark@anyscale.com> Co-authored-by: Kamil Kaczmarek <kaczmarek.poczta@gmail.com> Signed-off-by: xgui <xgui@anyscale.com>
…57017) ## Why are these changes needed? `SingleAgentEpisode.concat` would only support numpy array based observations due to `np.all(old_episode.observations[-1] == new_episode.observations[0])`. I've changed the implementation to use `tree.assert_same_structure` and `np.all` on the flatten structures to verify that observations are equivalent even for complex observation structures. In addition, I've added a test using a dict obs environment to verify this works. ## Related issue number Closes #54659 ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run `scripts/format.sh` to lint the changes in this PR. - [x] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [ ] Unit tests - [ ] Release tests - [x] This PR is not tested :( <!-- CURSOR_SUMMARY --> --- > [!NOTE] > Use structure-aware equality for observations during episode concatenation and add a test with dict observations; minor docstring tweaks. > > - **rllib/env**: > - **`single_agent_episode.py`**: > - `concat_episode`: Replace `np.all(a == b)` with `tree.assert_same_structure` and per-leaf `np.array_equal` to compare complex/nested observations. > - Add `tree` import. > - Minor docstring wording tweaks for `len_lookback_buffer`. > - **Tests**: > - **`rllib/env/tests/test_single_agent_episode.py`**: > - Add `DictTestEnv` and `test_concat_episode_with_complex_obs` to validate concatenation with dict observations. > - Fix test class name typo. > > <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit dc4856f. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup> <!-- /CURSOR_SUMMARY --> --------- Signed-off-by: Mark Towers <mark@anyscale.com> Co-authored-by: Mark Towers <mark@anyscale.com> Co-authored-by: Kamil Kaczmarek <kaczmarek.poczta@gmail.com> Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
…ay-project#57017) ## Why are these changes needed? `SingleAgentEpisode.concat` would only support numpy array based observations due to `np.all(old_episode.observations[-1] == new_episode.observations[0])`. I've changed the implementation to use `tree.assert_same_structure` and `np.all` on the flatten structures to verify that observations are equivalent even for complex observation structures. In addition, I've added a test using a dict obs environment to verify this works. ## Related issue number Closes ray-project#54659 ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run `scripts/format.sh` to lint the changes in this PR. - [x] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [ ] Unit tests - [ ] Release tests - [x] This PR is not tested :( <!-- CURSOR_SUMMARY --> --- > [!NOTE] > Use structure-aware equality for observations during episode concatenation and add a test with dict observations; minor docstring tweaks. > > - **rllib/env**: > - **`single_agent_episode.py`**: > - `concat_episode`: Replace `np.all(a == b)` with `tree.assert_same_structure` and per-leaf `np.array_equal` to compare complex/nested observations. > - Add `tree` import. > - Minor docstring wording tweaks for `len_lookback_buffer`. > - **Tests**: > - **`rllib/env/tests/test_single_agent_episode.py`**: > - Add `DictTestEnv` and `test_concat_episode_with_complex_obs` to validate concatenation with dict observations. > - Fix test class name typo. > > <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit dc4856f. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup> <!-- /CURSOR_SUMMARY --> --------- Signed-off-by: Mark Towers <mark@anyscale.com> Co-authored-by: Mark Towers <mark@anyscale.com> Co-authored-by: Kamil Kaczmarek <kaczmarek.poczta@gmail.com>
…ay-project#57017) ## Why are these changes needed? `SingleAgentEpisode.concat` would only support numpy array based observations due to `np.all(old_episode.observations[-1] == new_episode.observations[0])`. I've changed the implementation to use `tree.assert_same_structure` and `np.all` on the flatten structures to verify that observations are equivalent even for complex observation structures. In addition, I've added a test using a dict obs environment to verify this works. ## Related issue number Closes ray-project#54659 ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run `scripts/format.sh` to lint the changes in this PR. - [x] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [ ] Unit tests - [ ] Release tests - [x] This PR is not tested :( <!-- CURSOR_SUMMARY --> --- > [!NOTE] > Use structure-aware equality for observations during episode concatenation and add a test with dict observations; minor docstring tweaks. > > - **rllib/env**: > - **`single_agent_episode.py`**: > - `concat_episode`: Replace `np.all(a == b)` with `tree.assert_same_structure` and per-leaf `np.array_equal` to compare complex/nested observations. > - Add `tree` import. > - Minor docstring wording tweaks for `len_lookback_buffer`. > - **Tests**: > - **`rllib/env/tests/test_single_agent_episode.py`**: > - Add `DictTestEnv` and `test_concat_episode_with_complex_obs` to validate concatenation with dict observations. > - Fix test class name typo. > > <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit dc4856f. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup> <!-- /CURSOR_SUMMARY --> --------- Signed-off-by: Mark Towers <mark@anyscale.com> Co-authored-by: Mark Towers <mark@anyscale.com> Co-authored-by: Kamil Kaczmarek <kaczmarek.poczta@gmail.com> Signed-off-by: Aydin Abiar <aydin@anyscale.com>
Why are these changes needed?
SingleAgentEpisode.concatwould only support numpy array based observations due tonp.all(old_episode.observations[-1] == new_episode.observations[0]).I've changed the implementation to use
tree.assert_same_structureandnp.allon the flatten structures to verify that observations are equivalent even for complex observation structures.In addition, I've added a test using a dict obs environment to verify this works.
Related issue number
Closes #54659
Checks
git commit -s) in this PR.scripts/format.shto lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/under thecorresponding
.rstfile.Note
Use structure-aware equality for observations during episode concatenation and add a test with dict observations; minor docstring tweaks.
single_agent_episode.py:concat_episode: Replacenp.all(a == b)withtree.assert_same_structureand per-leafnp.array_equalto compare complex/nested observations.treeimport.len_lookback_buffer.rllib/env/tests/test_single_agent_episode.py:DictTestEnvandtest_concat_episode_with_complex_obsto validate concatenation with dict observations.Written by Cursor Bugbot for commit dc4856f. This will update automatically on new commits. Configure here.