diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst
index 0ea688175a..1f925e762e 100644
--- a/skyrl-train/docs/configuration/config.rst
+++ b/skyrl-train/docs/configuration/config.rst
@@ -405,6 +405,8 @@ Generator Configuration
 
   zero_reward_on_non_stop: false
 
+  apply_overlong_filtering: false
+
 Inference Engine Placement Configuration
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -457,4 +459,5 @@ Generation Parameters
 Misc Configuration
 ~~~~~~~~~~~~~~~~~~
 
-- ``generator.zero_reward_on_non_stop``: Whether to set the reward to 0 if the `stop_reason` is not `stop`. Cases where this is useful: Often, we have format rewards for the LLM to follow, but in cases where the LLM didn't finish the response, we typically don't want to reward it. This is a general setting for all environments.
\ No newline at end of file
+- ``generator.zero_reward_on_non_stop``: Whether to set the reward to 0 if the `stop_reason` is not `stop`. Cases where this is useful: Often, we have format rewards for the LLM to follow, but in cases where the LLM didn't finish the response, we typically don't want to reward it. This is a general setting for all environments.
+- ``generator.apply_overlong_filtering``: Whether to apply DAPO Overlong Filtering to the loss masks. Every trajectory whose response does not end with an EOS token (i.e., it was truncated at the maximum response length) has every token in its loss mask zeroed out.
\ No newline at end of file
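The option implements the trajectory-level loss masking that the DAPO paper calls "Overlong Filtering". A minimal sketch of the documented semantics, with made-up token ids and an assumed `eos_token_id` of 2 (values are illustrative, not from this change):

```python
# Toy illustration: a finished response keeps its mask, a truncated one is fully masked.
eos_token_id = 2
responses = [
    [11, 7, 2],  # ends with EOS -> finished naturally, mask kept
    [11, 7, 9],  # no trailing EOS -> truncated, mask zeroed
]
loss_masks = [[1, 1, 1], [1, 1, 1]]

filtered = [
    mask if response and response[-1] == eos_token_id else [0] * len(mask)
    for mask, response in zip(loss_masks, responses)
]
assert filtered == [[1, 1, 1], [0, 0, 0]]
```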
diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index 949d61f13d..6a7f6351d6 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -194,6 +194,11 @@ generator:
   # TODO (erictang000): Show clear ablations for benefits of this on GSM8K or SQL.
   zero_reward_on_non_stop: false
 
+  # Whether to apply DAPO Overlong Filtering to the loss masks.
+  # Every trajectory whose response does not end with an EOS token (i.e., it was truncated
+  # at the maximum response length) has every token in its loss mask zeroed out.
+  apply_overlong_filtering: false
+
 environment:
   env_class: "gsm8k"
   # NOTE: environment specific defaults for environment.skyrl_gym are set at the following path:
diff --git a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
index d337793eac..c82fc2699a 100644
--- a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
+++ b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
@@ -12,7 +12,7 @@
 from skyrl_train.inference_engines.base import InferenceEngineInput, ConversationType
 from omegaconf import DictConfig
 from skyrl_gym.envs.base_text_env import BaseTextEnvStepOutput
-from skyrl_train.generators.utils import get_custom_chat_template, get_generation_prompt_ids
+from skyrl_train.generators.utils import get_custom_chat_template, get_generation_prompt_ids, apply_overlong_filtering
 
 
 class SkyRLGymGenerator(GeneratorInterface):
@@ -230,6 +230,9 @@ async def generate_batched(
             responses = truncated_responses
 
         rollout_metrics = self._rollout_metrics(responses, rewards)
+        if self.generator_cfg.apply_overlong_filtering:
+            loss_masks = apply_overlong_filtering(loss_masks, responses, self.tokenizer.eos_token_id)
+
         generator_output: GeneratorOutput = {
             "prompt_token_ids": prompt_token_ids,
             "response_ids": responses,
@@ -293,10 +296,14 @@ async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
         prompt_token_ids = sum([[output[4]] for output in all_outputs], [])
 
         rollout_metrics = self._rollout_metrics(responses, rewards)
+
         if self.generator_cfg.zero_reward_on_non_stop:
             # set reward to 0 if the stop reason is not "stop"
             rewards = self._zero_reward_if_not_stop(rewards, stop_reasons)
 
+        if self.generator_cfg.apply_overlong_filtering:
+            loss_masks = apply_overlong_filtering(loss_masks, responses, self.tokenizer.eos_token_id)
+
         generator_output: GeneratorOutput = {
             "prompt_token_ids": prompt_token_ids,
             "response_ids": responses,
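Both call sites run the filter as a final post-processing step on the rollout; in `generate`, reward zeroing (`zero_reward_on_non_stop`) runs first, then mask filtering. A sketch of the combined effect when both flags are enabled; `postprocess` is a hypothetical standalone helper written for illustration, not part of this change:

```python
from typing import List, Tuple


def postprocess(
    rewards: List[float],
    stop_reasons: List[str],
    loss_masks: List[List[int]],
    responses: List[List[int]],
    eos_token_id: int,
) -> Tuple[List[float], List[List[int]]]:
    # Mirrors the order in generate(): zero rewards for non-"stop" finishes,
    # then zero the loss masks of responses that do not end with EOS.
    rewards = [r if s == "stop" else 0.0 for r, s in zip(rewards, stop_reasons)]
    loss_masks = [
        mask if resp and resp[-1] == eos_token_id else [0] * len(mask)
        for mask, resp in zip(loss_masks, responses)
    ]
    return rewards, loss_masks


rewards, masks = postprocess(
    rewards=[1.0, 1.0],
    stop_reasons=["stop", "length"],
    loss_masks=[[1, 1], [1, 1, 1]],
    responses=[[5, 4], [7, 8, 9]],  # with eos_token_id=4, only the first response finished
    eos_token_id=4,
)
assert rewards == [1.0, 0.0] and masks == [[1, 1], [0, 0, 0]]
```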
diff --git a/skyrl-train/skyrl_train/generators/utils.py b/skyrl-train/skyrl_train/generators/utils.py
index 22b6648258..de40b33f9f 100644
--- a/skyrl-train/skyrl_train/generators/utils.py
+++ b/skyrl-train/skyrl_train/generators/utils.py
@@ -94,3 +94,22 @@ def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput]) -> G
         result["stop_reasons"] = sum([output["stop_reasons"] for output in generator_outputs], [])
 
     return result
+
+
+def apply_overlong_filtering(
+    loss_masks: List[List[int]],
+    response_ids: List[List[int]],
+    eos_token_id: int,
+) -> List[List[int]]:
+    """
+    Implements DAPO Overlong Filtering: zero out every token's mask whenever
+    the response does not end with the eos token id (i.e. truncated).
+
+    Returns:
+        - The loss masks with tokens zeroed out for truncated responses
+    """
+    assert len(loss_masks) == len(response_ids), "loss_masks and response_ids must have the same length"
+    return [
+        [0] * len(mask) if not response or response[-1] != eos_token_id else mask
+        for mask, response in zip(loss_masks, response_ids)
+    ]
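Given the helper above, its behavior can be exercised directly; the ids below are toy values with `eos_token_id=4`:

```python
from skyrl_train.generators.utils import apply_overlong_filtering

loss_masks = [[1, 1, 1], [1, 1], [1]]
response_ids = [[9, 8, 4], [9, 8], []]  # only the first response ends with eos_token_id=4

filtered = apply_overlong_filtering(loss_masks, response_ids, eos_token_id=4)
# Truncated and empty responses are fully masked; finished ones are untouched.
assert filtered == [[1, 1, 1], [0, 0], [0]]
```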
+ ], "Loss mask should be all zeros for response not ending with eos token" + # Note: The long response gets truncated by max_response_tokens, so it doesn't end with eos token + + # Second test: response that ends with eos token (should not be filtered) + # Reset the environment init to ensure clean state + mock_env.init.return_value = ([{"role": "user", "content": "Fresh input"}], {}) + mock_llm.generate = AsyncMock(return_value={"responses": ["normal response"], "stop_reasons": ["stop"]}) + + input_batch_normal: GeneratorInput = { + "prompts": [[{"role": "user", "content": "Another test prompt"}]], + "env_extras": [{"test": "value"}], + "env_classes": [mock_env_cfg.env_class], + } + + output_normal = await generator.generate(input_batch_normal) + + # Verify normal response keeps original loss mask (all 1s) + assert len(output_normal["loss_masks"]) == 1 + assert len(output_normal["loss_masks"][0]) == 3 # 3 response tokens (already includes EOS token) + assert output_normal["loss_masks"][0] == [ + 1, + 1, + 1, + ], "Loss mask should remain as 1s for response ending with eos token" + + +@pytest.mark.asyncio +@patch("skyrl_gym.make") +async def test_apply_overlong_filtering_batched( + mock_make, mock_tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg +): + """ + Test that apply_overlong_filtering correctly zeroes out loss masks for truncated trajectories + in batched mode. + + Tests a response that doesn't end with eos token to verify that it gets filtered. + """ + mock_make.return_value = mock_env + mock_generator_cfg.apply_overlong_filtering = True # Enable filtering + mock_generator_cfg.batched = True + mock_generator_cfg.max_turns = 1 + mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {}) + + # Mock out environment and inference engine generation. 
+    mock_env.step.side_effect = lambda x: BaseTextEnvStepOutput(observations=[], reward=1.0, done=True, metadata={})
+    mock_llm.generate = AsyncMock(return_value={"responses": ["truncated response"], "stop_reasons": ["length"]})
+
+    def mock_apply_chat_template(messages, **kwargs):
+        if kwargs.get("tokenize", True):
+            return [[1, 2, 3, 4, 5] for _ in messages]  # 5 tokens for each prompt
+        else:
+            return "".join([msg.get("content", "") for msg in messages])
+
+    def mock_encode_or_tokenize(text):
+        return [10, 11, 12, 13]  # 4 tokens, doesn't end with eos_token_id=4
+
+    mock_tokenizer.apply_chat_template.side_effect = mock_apply_chat_template
+    mock_tokenizer.side_effect = lambda text: {"input_ids": mock_encode_or_tokenize(text)}
+    mock_tokenizer.eos_token_id = 4  # Set EOS token ID
+
+    generator = SkyRLGymGenerator(
+        generator_cfg=mock_generator_cfg,
+        skyrl_gym_cfg=mock_env_cfg,
+        inference_engine_client=mock_llm,
+        tokenizer=mock_tokenizer,
+        model_name="test_model",
+    )
+
+    # Test batched mode with response that doesn't end with eos token
+    prompts = [[{"role": "user", "content": "Test prompt"}]]
+    env_extras = [{"test": "value"}]
+    env_classes = [mock_env_cfg.env_class]
+
+    input_batch: GeneratorInput = {
+        "prompts": prompts,
+        "env_extras": env_extras,
+        "env_classes": env_classes,
+    }
+
+    generator_output = await generator.generate(input_batch)
+
+    # Verify that the loss mask is zeroed out for the response not ending with eos token
+    assert len(generator_output["loss_masks"]) == 1
+    assert len(generator_output["loss_masks"][0]) == 4  # Should match response length
+    assert generator_output["loss_masks"][0] == [
+        0,
+        0,
+        0,
+        0,
+    ], "Loss mask should be all zeros for response not ending with eos token"
diff --git a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py
index d42c278df4..856bc9498e 100644
--- a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py
+++ b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py
@@ -70,6 +70,7 @@ def mock_generate(input_batch):
             "batched": False,
             "max_turns": 3,
             "zero_reward_on_non_stop": False,
+            "apply_overlong_filtering": False,
             "use_conversation_multi_turn": True,
         }
     )
diff --git a/skyrl-train/tests/cpu/generators/test_utils.py b/skyrl-train/tests/cpu/generators/test_utils.py
new file mode 100644
index 0000000000..ab789bbdb5
--- /dev/null
+++ b/skyrl-train/tests/cpu/generators/test_utils.py
@@ -0,0 +1,118 @@
+"""
+uv run --extra dev --isolated pytest tests/cpu/generators/test_utils.py
+"""
+
+import pytest
+from skyrl_train.generators.utils import apply_overlong_filtering
+
+
+@pytest.mark.parametrize(
+    "loss_masks,response_ids,eos_token_id,expected_masks",
+    [
+        # Test case 1: All responses end with eos token - masks should remain unchanged
+        (
+            [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]],
+            [[1, 2, 3, 4], [5, 6, 7, 4], [8, 9, 4]],  # All end with eos_token_id=4
+            4,
+            [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]],
+        ),
+        # Test case 2: No responses end with eos token - all masks should be zeroed
+        (
+            [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]],
+            [[1, 2, 3, 5], [5, 6, 7, 8], [8, 9, 10]],  # None end with eos_token_id=4
+            4,
+            [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0]],
+        ),
+        # Test case 3: Mixed responses - only non-eos ending masks should be zeroed
+        (
+            [[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1, 0, 1]],
+            [[1, 2, 3, 4], [5, 6, 7, 8], [8, 9, 10, 11, 4]],  # First and third end with eos_token_id=4
+            4,
+            [[1, 1, 0, 1], [0, 0, 0, 0], [1, 0, 1, 0, 1]],
+        ),
+        # Test case 4: Empty responses should be zeroed
+        (
+            [[1, 1], [1, 0, 1], [0, 1, 1, 1]],
+            [[], [1, 2, 3], [4, 5, 6, 7]],  # Empty, no eos, no eos (eos_token_id=4)
+            4,
+            [[0, 0], [0, 0, 0], [0, 0, 0, 0]],
+        ),
+        # Test case 5: Empty lists
+        ([], [], 4, []),
+        # Test case 6: Different eos token id
+        (
+            [[1, 1], [1, 0, 1], [0, 1, 1, 1]],
+            [[1, 2], [3, 4, 99], [5, 6, 7, 99]],  # Second and third end with eos_token_id=99
+            99,
+            [[0, 0], [1, 0, 1], [0, 1, 1, 1]],
+        ),
+    ],
+)
+def test_apply_overlong_filtering(loss_masks, response_ids, eos_token_id, expected_masks):
+    """
+    Test the apply_overlong_filtering function which implements DAPO Overlong Filtering.
+
+    This function should zero out every token's mask whenever the response does not end
+    with the eos token id (i.e. truncated), while leaving other masks unchanged.
+    """
+    result = apply_overlong_filtering(loss_masks, response_ids, eos_token_id)
+
+    assert result == expected_masks, f"Expected {expected_masks}, but got {result}"
+
+    # The result should stay parallel to the input: one mask per response
+    assert len(result) == len(loss_masks), "Result should have same length as input"
+
+    # Check that each individual mask is processed correctly
+    for i, (original_mask, response, expected_mask) in enumerate(zip(loss_masks, response_ids, expected_masks)):
+        if len(response) == 0 or response[-1] != eos_token_id:
+            # Should be all zeros with same length as original
+            assert result[i] == [0] * len(original_mask), f"Mask {i} should be all zeros for truncated response"
+        else:
+            # Should be unchanged
+            assert result[i] == original_mask, f"Mask {i} should be unchanged for response ending with eos token"
+
+
+def test_apply_overlong_filtering_immutability():
+    """
+    Test that apply_overlong_filtering doesn't modify the original input lists.
+    """
+    original_loss_masks = [[1, 1, 0, 1], [0, 1, 1]]
+    original_response_ids = [[1, 2, 3, 4], [5, 6, 7]]  # First ends with eos=4, second doesn't
+    eos_token_id = 4
+
+    # Create copies to compare against later
+    loss_masks_copy = [mask[:] for mask in original_loss_masks]  # Deep copy of lists
+    response_ids_copy = [response[:] for response in original_response_ids]  # Deep copy of lists
+
+    result = apply_overlong_filtering(original_loss_masks, original_response_ids, eos_token_id)
+
+    # Verify original inputs are unchanged
+    assert original_loss_masks == loss_masks_copy, "Original loss_masks should not be modified"
+    assert original_response_ids == response_ids_copy, "Original response_ids should not be modified"
+
+    # Verify result is correct
+    expected = [[1, 1, 0, 1], [0, 0, 0]]  # Second mask zeroed due to not ending with eos
+    assert result == expected, f"Expected {expected}, got {result}"
+
+
+@pytest.mark.parametrize(
+    "loss_masks,response_ids",
+    [
+        # Test case 1: More loss_masks than response_ids
+        ([[1, 1], [0, 1]], [[1, 2]]),
+        # Test case 2: More response_ids than loss_masks
+        ([[1, 1]], [[1, 2], [3, 4]]),
+        # Test case 3: Empty loss_masks but non-empty response_ids
+        ([], [[1, 2]]),
+        # Test case 4: Non-empty loss_masks but empty response_ids
+        ([[1, 0]], []),
+    ],
+)
+def test_apply_overlong_filtering_length_mismatch_assertion(loss_masks, response_ids):
+    """
+    Test that apply_overlong_filtering raises AssertionError when loss_masks and response_ids
+    have different lengths.
+    """
+    eos_token_id = 4
+    with pytest.raises(AssertionError, match="loss_masks and response_ids must have the same length"):
+        apply_overlong_filtering(loss_masks, response_ids, eos_token_id)