-
Notifications
You must be signed in to change notification settings - Fork 256
[DAPO] Add support for overlong filtering #111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
9966861
Add trajectory dumping functionality and GSM8K example updates
tyler-griggs b9e1131
x
tyler-griggs 3e02948
Add tests for apply_overlong_filtering functionality
tyler-griggs 0a30cbe
Implement DAPO Overlong Filtering feature
tyler-griggs 84681c7
x
tyler-griggs 2c1c28b
switch from stop reasons to eos token IDs
tyler-griggs d2e3699
x
tyler-griggs 102c8e4
minor edits
tyler-griggs 6e31837
add doc
tyler-griggs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| """ | ||
| uv run --extra dev --isolated pytest tests/cpu/generators/test_utils.py | ||
| """ | ||
|
|
||
| import pytest | ||
| from skyrl_train.generators.utils import apply_overlong_filtering | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "loss_masks,response_ids,eos_token_id,expected_masks", | ||
| [ | ||
| # Test case 1: All responses end with eos token - masks should remain unchanged | ||
| ( | ||
| [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]], | ||
| [[1, 2, 3, 4], [5, 6, 7, 4], [8, 9, 4]], # All end with eos_token_id=4 | ||
| 4, | ||
| [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]], | ||
| ), | ||
| # Test case 2: No responses end with eos token - all masks should be zeroed | ||
| ( | ||
| [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]], | ||
| [[1, 2, 3, 5], [5, 6, 7, 8], [8, 9, 10]], # None end with eos_token_id=4 | ||
| 4, | ||
| [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0]], | ||
| ), | ||
| # Test case 3: Mixed responses - only non-eos ending masks should be zeroed | ||
| ( | ||
| [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1, 0, 1]], | ||
| [[1, 2, 3, 4], [5, 6, 7, 8], [8, 9, 10, 11, 4]], # First and third end with eos_token_id=4 | ||
| 4, | ||
| [[1, 1, 0, 1], [0, 0, 0, 0], [1, 0, 1, 0, 1]], | ||
| ), | ||
| # Test case 4: Empty responses should be zeroed | ||
| ( | ||
| [[1, 1], [1, 0, 1], [0, 1, 1, 1]], | ||
| [[], [1, 2, 3], [4, 5, 6, 7]], # Empty, no eos, no eos (eos_token_id=4) | ||
| 4, | ||
| [[0, 0], [0, 0, 0], [0, 0, 0, 0]], | ||
| ), | ||
| # Test case 5: Empty lists | ||
| ([], [], 4, []), | ||
| # Test case 6: Different eos token id | ||
| ( | ||
| [[1, 1], [1, 0, 1], [0, 1, 1, 1]], | ||
| [[1, 2], [3, 4, 99], [5, 6, 7, 99]], # Second and third end with eos_token_id=99 | ||
| 99, | ||
| [[0, 0], [1, 0, 1], [0, 1, 1, 1]], | ||
| ), | ||
| ], | ||
| ) | ||
| def test_apply_overlong_filtering(loss_masks, response_ids, eos_token_id, expected_masks): | ||
| """ | ||
| Test the apply_overlong_filtering function which implements DAPO Overlong Filtering. | ||
|
|
||
| This function should zero-out every token's mask whenever the response does not end | ||
| with the eos token id (i.e. truncated), while leaving other masks unchanged. | ||
| """ | ||
| result = apply_overlong_filtering(loss_masks, response_ids, eos_token_id) | ||
|
|
||
| assert result == expected_masks, f"Expected {expected_masks}, but got {result}" | ||
|
|
||
| # Verify that the original inputs are not modified (immutability check) | ||
| assert len(result) == len(loss_masks), "Result should have same length as input" | ||
|
|
||
| # Check that each individual mask is processed correctly | ||
| for i, (original_mask, response, expected_mask) in enumerate(zip(loss_masks, response_ids, expected_masks)): | ||
| if len(response) == 0 or response[-1] != eos_token_id: | ||
| # Should be all zeros with same length as original | ||
| assert result[i] == [0] * len(original_mask), f"Mask {i} should be all zeros for truncated response" | ||
| else: | ||
| # Should be unchanged | ||
| assert result[i] == original_mask, f"Mask {i} should be unchanged for response ending with eos token" | ||
|
|
||
|
|
||
| def test_apply_overlong_filtering_immutability(): | ||
| """ | ||
| Test that apply_overlong_filtering doesn't modify the original input lists. | ||
| """ | ||
| original_loss_masks = [[1, 1, 0, 1], [0, 1, 1]] | ||
| original_response_ids = [[1, 2, 3, 4], [5, 6, 7]] # First ends with eos=4, second doesn't | ||
| eos_token_id = 4 | ||
|
|
||
| # Create copies to compare against later | ||
| loss_masks_copy = [mask[:] for mask in original_loss_masks] # Deep copy of lists | ||
| response_ids_copy = [response[:] for response in original_response_ids] # Deep copy of lists | ||
|
|
||
| result = apply_overlong_filtering(original_loss_masks, original_response_ids, eos_token_id) | ||
|
|
||
| # Verify original inputs are unchanged | ||
| assert original_loss_masks == loss_masks_copy, "Original loss_masks should not be modified" | ||
| assert original_response_ids == response_ids_copy, "Original response_ids should not be modified" | ||
|
|
||
| # Verify result is correct | ||
| expected = [[1, 1, 0, 1], [0, 0, 0]] # Second mask zeroed due to not ending with eos | ||
| assert result == expected, f"Expected {expected}, got {result}" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "loss_masks,response_ids", | ||
| [ | ||
| # Test case 1: More loss_masks than response_ids | ||
| ([[1, 1], [0, 1]], [[1, 2]]), | ||
| # Test case 2: More response_ids than loss_masks | ||
| ([[1, 1]], [[1, 2], [3, 4]]), | ||
| # Test case 3: Empty loss_masks but non-empty response_ids | ||
| ([], [[1, 2]]), | ||
| # Test case 4: Non-empty loss_masks but empty response_ids | ||
| ([[1, 0]], []), | ||
| ], | ||
| ) | ||
| def test_apply_overlong_filtering_length_mismatch_assertion(loss_masks, response_ids): | ||
| """ | ||
| Test that apply_overlong_filtering raises AssertionError when loss_masks and response_ids | ||
| have different lengths. | ||
| """ | ||
| eos_token_id = 4 | ||
| with pytest.raises(AssertionError, match="loss_masks and response_ids must have the same length"): | ||
| apply_overlong_filtering(loss_masks, response_ids, eos_token_id) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually can you make sure this is updated in the docs as well: https://github.com/NovaSky-AI/SkyRL/blob/main/skyrl-train/docs/configuration/config.rst
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call