Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion skyrl-train/docs/configuration/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ Generator Configuration

zero_reward_on_non_stop: false

apply_overlong_filtering: false


Inference Engine Placement Configuration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -457,4 +459,5 @@ Generation Parameters
Misc Configuration
~~~~~~~~~~~~~~~~~~

- ``generator.zero_reward_on_non_stop``: Whether to set the reward to 0 if the `stop_reason` is not `stop`. Cases where this is useful: Often, we have format rewards for the LLM to follow, but in cases where the LLM didn't finish the response, we typically don't want to reward it. This is a general setting for all environments.
- ``generator.zero_reward_on_non_stop``: Whether to set the reward to 0 if the `stop_reason` is not `stop`. Cases where this is useful: Often, we have format rewards for the LLM to follow, but in cases where the LLM didn't finish the response, we typically don't want to reward it. This is a general setting for all environments.
- ``generator.apply_overlong_filtering``: Whether to apply DAPO Overlong Filtering to the loss masks. For each trajectory that exceeds the max length (i.e., truncated and does not end with an EOS token), this masks out every token in the loss mask.
5 changes: 5 additions & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ generator:
# TODO (erictang000): Show clear ablations for benefits of this on GSM8K or SQL.
zero_reward_on_non_stop: false

# Whether to apply DAPO Overlong Filtering to the loss masks.
# For each trajectory that exceeds the max length (i.e., truncated and does not end with an
# EOS token), this masks out every token in the loss mask.
apply_overlong_filtering: false
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call


environment:
env_class: "gsm8k"
# NOTE: environment specific defaults for environment.skyrl_gym are set at the following path:
Expand Down
9 changes: 8 additions & 1 deletion skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from skyrl_train.inference_engines.base import InferenceEngineInput, ConversationType
from omegaconf import DictConfig
from skyrl_gym.envs.base_text_env import BaseTextEnvStepOutput
from skyrl_train.generators.utils import get_custom_chat_template, get_generation_prompt_ids
from skyrl_train.generators.utils import get_custom_chat_template, get_generation_prompt_ids, apply_overlong_filtering


class SkyRLGymGenerator(GeneratorInterface):
Expand Down Expand Up @@ -230,6 +230,9 @@ async def generate_batched(
responses = truncated_responses
rollout_metrics = self._rollout_metrics(responses, rewards)

if self.generator_cfg.apply_overlong_filtering:
loss_masks = apply_overlong_filtering(loss_masks, responses, self.tokenizer.eos_token_id)

generator_output: GeneratorOutput = {
"prompt_token_ids": prompt_token_ids,
"response_ids": responses,
Expand Down Expand Up @@ -293,10 +296,14 @@ async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
prompt_token_ids = sum([[output[4]] for output in all_outputs], [])

rollout_metrics = self._rollout_metrics(responses, rewards)

if self.generator_cfg.zero_reward_on_non_stop:
# set reward to 0 if the stop reason is not "stop"
rewards = self._zero_reward_if_not_stop(rewards, stop_reasons)

if self.generator_cfg.apply_overlong_filtering:
loss_masks = apply_overlong_filtering(loss_masks, responses, self.tokenizer.eos_token_id)

generator_output: GeneratorOutput = {
"prompt_token_ids": prompt_token_ids,
"response_ids": responses,
Expand Down
19 changes: 19 additions & 0 deletions skyrl-train/skyrl_train/generators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,22 @@ def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput]) -> G
result["stop_reasons"] = sum([output["stop_reasons"] for output in generator_outputs], [])

return result


def apply_overlong_filtering(
    loss_masks: List[List[int]],
    response_ids: List[List[int]],
    eos_token_id: int,
) -> List[List[int]]:
    """
    Apply DAPO Overlong Filtering to a batch of loss masks.

    A response is considered truncated when it is empty or its final token is
    not ``eos_token_id``; in that case every entry of its loss mask is zeroed
    so the trajectory contributes nothing to the loss. Responses that end with
    the eos token keep their original mask.

    Returns:
        - The loss masks with tokens zeroed out for truncated responses
    """
    assert len(loss_masks) == len(response_ids), "loss_masks and response_ids must have the same length"
    filtered_masks = []
    for mask, response in zip(loss_masks, response_ids):
        ends_with_eos = bool(response) and response[-1] == eos_token_id
        filtered_masks.append(mask if ends_with_eos else [0] * len(mask))
    return filtered_masks
161 changes: 161 additions & 0 deletions skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,3 +616,164 @@ def mock_encode(text, **kwargs):
assert reward == 1.0
assert stop_reason == "stop"
assert len(response_ids) == len(loss_mask)


@pytest.mark.asyncio
@patch("skyrl_gym.make")
async def test_apply_overlong_filtering_non_batched(
    mock_make, mock_tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg
):
    """
    Test that apply_overlong_filtering correctly zeroes out loss masks for truncated trajectories
    in non-batched mode (using agent_loop).

    Tests both truncated and non-truncated responses to verify that:
    - Trajectories with responses not ending with eos token have their loss masks zeroed out
    - Trajectories with responses ending with eos token keep their original loss masks

    Note: all arguments except ``mock_make`` (injected by ``@patch``) are pytest
    fixtures resolved by parameter name.
    """
    mock_make.return_value = mock_env
    mock_generator_cfg.apply_overlong_filtering = True  # Enable filtering
    mock_generator_cfg.batched = False
    mock_generator_cfg.max_turns = 1
    mock_generator_cfg.use_conversation_multi_turn = False
    mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {})

    # Mock out the environment and inference engine generation.
    # Every env step immediately terminates the episode with reward 1.0.
    mock_env.step.side_effect = lambda x: BaseTextEnvStepOutput(observations=[], reward=1.0, done=True, metadata={})

    def mock_apply_chat_template(messages, **kwargs):
        # tokenize=True path returns prompt token ids; tokenize=False returns plain text.
        if kwargs.get("tokenize", True):
            return [1, 2, 3, 4, 5]  # 5 tokens for prompt
        else:
            return "".join([msg.get("content", "") for msg in messages])

    def mock_encode_or_tokenize(text, **kwargs):
        # Return different token patterns for different responses
        if "truncated" in str(text):
            # Simulate a long response that will get truncated by max_response_tokens
            return [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]  # 10 tokens, will be truncated
        else:
            return [20, 21, 4]  # 3 tokens, ends with eos_token_id=4

    mock_tokenizer.apply_chat_template.side_effect = mock_apply_chat_template
    mock_tokenizer.encode.side_effect = mock_encode_or_tokenize
    mock_tokenizer.eos_token_id = 4  # Set EOS token ID

    generator = SkyRLGymGenerator(
        generator_cfg=mock_generator_cfg,
        skyrl_gym_cfg=mock_env_cfg,
        inference_engine_client=mock_llm,
        tokenizer=mock_tokenizer,
        model_name="test_model",
    )

    # First test: response that doesn't end with eos token (should be filtered)
    mock_llm.generate = AsyncMock(return_value={"responses": ["truncated response"], "stop_reasons": ["length"]})

    input_batch_truncated: GeneratorInput = {
        "prompts": [[{"role": "user", "content": "Test prompt"}]],
        "env_extras": [{"test": "value"}],
        "env_classes": [mock_env_cfg.env_class],
    }

    output_truncated = await generator.generate(input_batch_truncated)

    # Verify truncated response has zeroed loss mask
    assert len(output_truncated["loss_masks"]) == 1
    # presumably mock_generator_cfg caps generation at 5 tokens — TODO confirm fixture value
    assert len(output_truncated["loss_masks"][0]) == 5  # Truncated to max_generate_length=5
    assert output_truncated["loss_masks"][0] == [
        0,
        0,
        0,
        0,
        0,
    ], "Loss mask should be all zeros for response not ending with eos token"
    # Note: The long response gets truncated by max_response_tokens, so it doesn't end with eos token

    # Second test: response that ends with eos token (should not be filtered)
    # Reset the environment init to ensure clean state
    mock_env.init.return_value = ([{"role": "user", "content": "Fresh input"}], {})
    mock_llm.generate = AsyncMock(return_value={"responses": ["normal response"], "stop_reasons": ["stop"]})

    input_batch_normal: GeneratorInput = {
        "prompts": [[{"role": "user", "content": "Another test prompt"}]],
        "env_extras": [{"test": "value"}],
        "env_classes": [mock_env_cfg.env_class],
    }

    output_normal = await generator.generate(input_batch_normal)

    # Verify normal response keeps original loss mask (all 1s)
    assert len(output_normal["loss_masks"]) == 1
    assert len(output_normal["loss_masks"][0]) == 3  # 3 response tokens (already includes EOS token)
    assert output_normal["loss_masks"][0] == [
        1,
        1,
        1,
    ], "Loss mask should remain as 1s for response ending with eos token"


@pytest.mark.asyncio
@patch("skyrl_gym.make")
async def test_apply_overlong_filtering_batched(
    mock_make, mock_tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg
):
    """
    Test that apply_overlong_filtering correctly zeroes out loss masks for truncated trajectories
    in batched mode.

    Tests a response that doesn't end with eos token to verify that it gets filtered.

    Note: all arguments except ``mock_make`` (injected by ``@patch``) are pytest
    fixtures resolved by parameter name.
    """
    mock_make.return_value = mock_env
    mock_generator_cfg.apply_overlong_filtering = True  # Enable filtering
    mock_generator_cfg.batched = True
    mock_generator_cfg.max_turns = 1
    mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {})

    # Mock out environment and inference engine generation.
    # Episode terminates after one step; engine reports a length-capped stop reason.
    mock_env.step.side_effect = lambda x: BaseTextEnvStepOutput(observations=[], reward=1.0, done=True, metadata={})
    mock_llm.generate = AsyncMock(return_value={"responses": ["truncated response"], "stop_reasons": ["length"]})

    def mock_apply_chat_template(messages, **kwargs):
        # Batched mode tokenizes a list of prompts at once, hence the list-of-lists.
        if kwargs.get("tokenize", True):
            return [[1, 2, 3, 4, 5] for _ in messages]  # 5 tokens for each prompt
        else:
            return "".join([msg.get("content", "") for msg in messages])

    def mock_encode_or_tokenize(text):
        return [10, 11, 12, 13]  # 4 tokens, doesn't end with eos_token_id=4

    mock_tokenizer.apply_chat_template.side_effect = mock_apply_chat_template
    # Batched path calls the tokenizer object directly, which returns an input_ids dict.
    mock_tokenizer.side_effect = lambda text: {"input_ids": mock_encode_or_tokenize(text)}
    mock_tokenizer.eos_token_id = 4  # Set EOS token ID

    generator = SkyRLGymGenerator(
        generator_cfg=mock_generator_cfg,
        skyrl_gym_cfg=mock_env_cfg,
        inference_engine_client=mock_llm,
        tokenizer=mock_tokenizer,
        model_name="test_model",
    )

    # Test batched mode with response that doesn't end with eos token
    prompts = [[{"role": "user", "content": "Test prompt"}]]
    env_extras = [{"test": "value"}]
    env_classes = [mock_env_cfg.env_class]

    input_batch: GeneratorInput = {
        "prompts": prompts,
        "env_extras": env_extras,
        "env_classes": env_classes,
    }

    generator_output = await generator.generate(input_batch)

    # Verify that the loss mask is zeroed out for the response not ending with eos token
    assert len(generator_output["loss_masks"]) == 1
    assert len(generator_output["loss_masks"][0]) == 4  # Should match response length
    assert generator_output["loss_masks"][0] == [
        0,
        0,
        0,
        0,
    ], "Loss mask should be all zeros for response not ending with eos token"
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def mock_generate(input_batch):
"batched": False,
"max_turns": 3,
"zero_reward_on_non_stop": False,
"apply_overlong_filtering": False,
"use_conversation_multi_turn": True,
}
)
Expand Down
118 changes: 118 additions & 0 deletions skyrl-train/tests/cpu/generators/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
uv run --extra dev --isolated pytest tests/cpu/generators/test_utils.py
"""

import pytest
from skyrl_train.generators.utils import apply_overlong_filtering


@pytest.mark.parametrize(
    "loss_masks,response_ids,eos_token_id,expected_masks",
    [
        # Every response terminates with the eos token: masks pass through untouched
        (
            [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]],
            [[1, 2, 3, 4], [5, 6, 7, 4], [8, 9, 4]],
            4,
            [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]],
        ),
        # No response terminates with the eos token: every mask is fully zeroed
        (
            [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]],
            [[1, 2, 3, 5], [5, 6, 7, 8], [8, 9, 10]],
            4,
            [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0]],
        ),
        # Mixed batch: only the middle (non-eos-terminated) mask is zeroed
        (
            [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1, 0, 1]],
            [[1, 2, 3, 4], [5, 6, 7, 8], [8, 9, 10, 11, 4]],
            4,
            [[1, 1, 0, 1], [0, 0, 0, 0], [1, 0, 1, 0, 1]],
        ),
        # An empty response counts as truncated and is zeroed like the rest
        (
            [[1, 1], [1, 0, 1], [0, 1, 1, 1]],
            [[], [1, 2, 3], [4, 5, 6, 7]],
            4,
            [[0, 0], [0, 0, 0], [0, 0, 0, 0]],
        ),
        # Degenerate empty batch
        ([], [], 4, []),
        # Non-default eos token id is honored
        (
            [[1, 1], [1, 0, 1], [0, 1, 1, 1]],
            [[1, 2], [3, 4, 99], [5, 6, 7, 99]],
            99,
            [[0, 0], [1, 0, 1], [0, 1, 1, 1]],
        ),
    ],
)
def test_apply_overlong_filtering(loss_masks, response_ids, eos_token_id, expected_masks):
    """
    Exercise apply_overlong_filtering (DAPO Overlong Filtering) over a range of
    batches: masks for responses not terminated by the eos token must come back
    fully zeroed, while eos-terminated responses keep their masks as-is.
    """
    result = apply_overlong_filtering(loss_masks, response_ids, eos_token_id)

    assert result == expected_masks, f"Expected {expected_masks}, but got {result}"
    assert len(result) == len(loss_masks), "Result should have same length as input"

    # Cross-check each mask independently against the truncation rule.
    for i, (original_mask, response, expected_mask) in enumerate(zip(loss_masks, response_ids, expected_masks)):
        is_truncated = not response or response[-1] != eos_token_id
        if is_truncated:
            assert result[i] == [0] * len(original_mask), f"Mask {i} should be all zeros for truncated response"
        else:
            assert result[i] == original_mask, f"Mask {i} should be unchanged for response ending with eos token"


def test_apply_overlong_filtering_immutability():
    """
    Ensure apply_overlong_filtering never mutates its inputs: both the loss
    masks and the response ids must be byte-identical after the call.
    """
    eos_token_id = 4
    original_loss_masks = [[1, 1, 0, 1], [0, 1, 1]]
    original_response_ids = [[1, 2, 3, 4], [5, 6, 7]]  # only the first ends with eos=4

    # Snapshot the inputs so any in-place mutation is detectable afterwards.
    loss_masks_copy = [list(mask) for mask in original_loss_masks]
    response_ids_copy = [list(response) for response in original_response_ids]

    result = apply_overlong_filtering(original_loss_masks, original_response_ids, eos_token_id)

    # Inputs must be untouched.
    assert original_loss_masks == loss_masks_copy, "Original loss_masks should not be modified"
    assert original_response_ids == response_ids_copy, "Original response_ids should not be modified"

    # And the output still reflects the filtering rule (second mask zeroed).
    expected = [[1, 1, 0, 1], [0, 0, 0]]
    assert result == expected, f"Expected {expected}, got {result}"


@pytest.mark.parametrize(
    "loss_masks,response_ids",
    [
        ([[1, 1], [0, 1]], [[1, 2]]),  # more loss_masks than response_ids
        ([[1, 1]], [[1, 2], [3, 4]]),  # more response_ids than loss_masks
        ([], [[1, 2]]),  # empty loss_masks against non-empty response_ids
        ([[1, 0]], []),  # non-empty loss_masks against empty response_ids
    ],
)
def test_apply_overlong_filtering_length_mismatch_assertion(loss_masks, response_ids):
    """
    Mismatched batch sizes must be rejected: apply_overlong_filtering asserts
    that loss_masks and response_ids have the same length.
    """
    eos_token_id = 4
    with pytest.raises(AssertionError, match="loss_masks and response_ids must have the same length"):
        apply_overlong_filtering(loss_masks, response_ids, eos_token_id)