From 259f0b01a1a321b399f7d461f562058d83473c08 Mon Sep 17 00:00:00 2001 From: Charlie Ruan Date: Sun, 24 Aug 2025 20:03:53 +0000 Subject: [PATCH 01/10] [Generator][Env] Add stop str, remove need for postprocessed action for search and txt2sql --- skyrl-gym/skyrl_gym/envs/search/env.py | 13 +++++-------- skyrl-gym/skyrl_gym/envs/sql/env.py | 15 +++++---------- skyrl-train/docs/examples/multi_turn_text2sql.rst | 4 ++++ skyrl-train/docs/examples/search.rst | 6 ++++++ skyrl-train/examples/search/run_search.sh | 2 ++ skyrl-train/examples/text_to_sql/run_skyrl_sql.sh | 2 ++ .../examples/text_to_sql/run_sql_deepspeed.sh | 2 ++ skyrl-train/examples/text_to_sql/run_sql_fsdp.sh | 2 ++ .../examples/text_to_sql/run_sql_fsdp_2node.sh | 2 ++ .../skyrl_train/config/ppo_base_config.yaml | 2 ++ .../skyrl_train/inference_engines/utils.py | 6 +++++- skyrl-train/skyrl_train/utils/utils.py | 15 +++++++++++++++ 12 files changed, 52 insertions(+), 19 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/search/env.py b/skyrl-gym/skyrl_gym/envs/search/env.py index 49a48b1eb7..9a932c5241 100644 --- a/skyrl-gym/skyrl_gym/envs/search/env.py +++ b/skyrl-gym/skyrl_gym/envs/search/env.py @@ -56,13 +56,11 @@ def _is_done(self, action: str) -> bool: return True return "" in action and "" in action - def _postprocess_action(self, action: str) -> str: + def _validate_action(self, action: str): if "" in action: - return action.split("")[0] + "" + assert action.split("")[1] == "", " detected in the response but it is not the last string generated. Use \"\" and \"\" as stop strings in the configuration." elif "" in action: - return action.split("")[0] + "" - else: - return action + assert action.split("")[1] == "", " detected in the response but it is not the last string generated. Use \"\" and \"\" as stop strings in the configuration." def _execute_tool(self, tool_group_name: str, tool_name: str, tool_input: Any) -> str: tool_output = super()._execute_tool(tool_group_name, tool_name, tool_input) @@ -71,7 +69,7 @@ def _execute_tool(self, tool_group_name: str, tool_name: str, tool_input: Any) - def step(self, action: str) -> BaseTextEnvStepOutput: self.turns += 1 - action = self._postprocess_action(action) + self._validate_action(action) self.chat_history.append({"role": "assistant", "content": action}) error = None @@ -80,7 +78,7 @@ def step(self, action: str) -> BaseTextEnvStepOutput: if done: return BaseTextEnvStepOutput( - observations=[], reward=reward, done=done, metadata={}, postprocessed_action=action + observations=[], reward=reward, done=done, metadata={} ) try: @@ -114,5 +112,4 @@ def step(self, action: str) -> BaseTextEnvStepOutput: reward=reward, done=done, metadata=info, - postprocessed_action=action, ) diff --git a/skyrl-gym/skyrl_gym/envs/sql/env.py b/skyrl-gym/skyrl_gym/envs/sql/env.py index 33151fda73..cc58ae35e9 100644 --- a/skyrl-gym/skyrl_gym/envs/sql/env.py +++ b/skyrl-gym/skyrl_gym/envs/sql/env.py @@ -89,17 +89,15 @@ def _is_done(self, action: str) -> bool: return True return "" in action and "" in action - def _postprocess_action(self, action: str) -> str: + def _validate_action(self, action: str): if "" in action: - return action.split("")[0] + "" + assert action.split("")[1] == "", " detected in the response but it is not the last string generated. Use \"\" and \"\" as stop strings in the configuration." elif "" in action: - return action.split("")[0] + "" - else: - return action + assert action.split("")[1] == "", " detected in the response but it is not the last string generated. Use \"\" and \"\" as stop strings in the configuration." def step(self, action: str) -> BaseTextEnvStepOutput: self.turns += 1 - action = self._postprocess_action(action) + self._validate_action(action) self.chat_history.append({"role": "assistant", "content": action}) error = None @@ -107,9 +105,7 @@ def step(self, action: str) -> BaseTextEnvStepOutput: reward = self._get_reward(action, done) if done: - return BaseTextEnvStepOutput( - observations=[], reward=reward, done=done, metadata={}, postprocessed_action=action - ) + return BaseTextEnvStepOutput(observations=[], reward=reward, done=done, metadata={}) try: tool_group_name, tool_name, tool_input = self._parse_action(action) @@ -140,5 +136,4 @@ def step(self, action: str) -> BaseTextEnvStepOutput: reward=reward, done=done, metadata=info, - postprocessed_action=action, ) diff --git a/skyrl-train/docs/examples/multi_turn_text2sql.rst b/skyrl-train/docs/examples/multi_turn_text2sql.rst index f459f86797..4fa3d26fdf 100644 --- a/skyrl-train/docs/examples/multi_turn_text2sql.rst +++ b/skyrl-train/docs/examples/multi_turn_text2sql.rst @@ -128,6 +128,8 @@ Now that we have our dataset and database files, let's walk through the some of #### generation sampling params (relevant to algorithm correctness) generator.sampling_params.temperature=0.6 \ generator.sampling_params.top_p=0.95 \ + generator.sampling_params.stop='["", ""]' \ + generator.eval_sampling_params.stop='["", ""]' \ #### training configuration trainer.policy.optimizer_config.lr=1.0e-6 \ @@ -146,6 +148,8 @@ Now that we have our dataset and database files, let's walk through the some of - Chat templating and loss masking for multi-turn conversations are handled by the ``SkyRLGymGenerator`` class. - In the above example, we set ``use_conversation_multi_turn=false`` to enforce that the multi-turn conversation is formatted as a single assistant response. + - We also set ``stop='["", ""]'`` for both ``sampling_params`` and ``eval_sampling_params`` as a part + of the training recipe. If you are using ``generator.use_conversation_multi_turn=true``, you might want to manually append an EOS token ID to the end of the response after these stop strings. - If you want to use a conversation-based format, you can set ``use_conversation_multi_turn=true`` and the model will generate a separate assistant response for each turn. This is supported only with ``backend="vllm"`` as of now. - See :code_link:`skyrl_train/generators/skyrl_gym_generator.py` for more details on both options! diff --git a/skyrl-train/docs/examples/search.rst b/skyrl-train/docs/examples/search.rst index a7043d50ba..ba717cb481 100644 --- a/skyrl-train/docs/examples/search.rst +++ b/skyrl-train/docs/examples/search.rst @@ -100,6 +100,7 @@ Let's walk through configuration for running GRPO to train a 4-turn search agent generator.use_conversation_multi_turn=false \ generator.sampling_params.temperature=1.0 \ generator.sampling_params.top_p=1.0 \ + generator.sampling_params.stop='["", ""]' \ # - Environment: environment class, max env workers, search env settings environment.env_class="search" \ @@ -112,12 +113,17 @@ Let's walk through configuration for running GRPO to train a 4-turn search agent trainer.eval_batch_size=256 \ trainer.eval_before_train=false \ generator.eval_sampling_params.temperature=0 \ + generator.eval_sampling_params.stop='["", ""]' \ trainer.eval_interval=50 \ ... # logging + checkpointing configuration (see `examples/search/run_search.sh` for the full script) To change the number of turns, you can simply change the ``generator.max_turns`` setting. For more details on environment implementation, see :skyrl_gym_link:`skyrl_gym/envs/search/env.py`. +Note we add ``stop='["", ""]'`` for both generation and evaluation sampling parameters +to adhere to the Search-R1 recipe. If you are using ``generator.use_conversation_multi_turn=true``, +you might want to manually append an EOS token ID to the end of the response after these stop strings. + Launching Your Training Run --------------------------- diff --git a/skyrl-train/examples/search/run_search.sh b/skyrl-train/examples/search/run_search.sh index dc5ce3d77c..bcaa7a82f0 100755 --- a/skyrl-train/examples/search/run_search.sh +++ b/skyrl-train/examples/search/run_search.sh @@ -48,6 +48,7 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ generator.use_conversation_multi_turn=false \ generator.sampling_params.temperature=1.0 \ generator.sampling_params.top_p=1.0 \ + generator.sampling_params.stop='["", ""]' \ environment.env_class="search" \ environment.skyrl_gym.max_env_workers=16 \ environment.skyrl_gym.search.log_requests=false \ @@ -64,6 +65,7 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ trainer.eval_batch_size=256 \ trainer.eval_before_train=false \ generator.eval_sampling_params.temperature=0 \ + generator.eval_sampling_params.stop='["", ""]' \ trainer.export_path="$HOME/skyrl-search_4turns_maxgeneratelen_500/exports" \ trainer.eval_interval=50 \ $@ diff --git a/skyrl-train/examples/text_to_sql/run_skyrl_sql.sh b/skyrl-train/examples/text_to_sql/run_skyrl_sql.sh index a70489f1f8..bf29bdf919 100644 --- a/skyrl-train/examples/text_to_sql/run_skyrl_sql.sh +++ b/skyrl-train/examples/text_to_sql/run_skyrl_sql.sh @@ -59,6 +59,8 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \ generator.use_conversation_multi_turn=false \ generator.sampling_params.temperature=0.6 \ generator.sampling_params.top_p=0.95 \ + generator.sampling_params.stop='["", ""]' \ + generator.eval_sampling_params.stop='["", ""]' \ environment.skyrl_gym.text2sql.db_path=$DB_PATH \ trainer.logger="wandb" \ trainer.project_name="skyrlsql" \ diff --git a/skyrl-train/examples/text_to_sql/run_sql_deepspeed.sh b/skyrl-train/examples/text_to_sql/run_sql_deepspeed.sh index 2a9575b95a..226e772e4c 100755 --- a/skyrl-train/examples/text_to_sql/run_sql_deepspeed.sh +++ b/skyrl-train/examples/text_to_sql/run_sql_deepspeed.sh @@ -49,6 +49,8 @@ uv run --isolated --frozen --extra vllm --extra deepspeed -m skyrl_train.entrypo generator.use_conversation_multi_turn=false \ generator.sampling_params.temperature=0.6 \ generator.sampling_params.top_p=0.95 \ + generator.sampling_params.stop='["", ""]' \ + generator.eval_sampling_params.stop='["", ""]' \ environment.skyrl_gym.text2sql.db_path=$DB_PATH \ trainer.logger="wandb" \ trainer.project_name="skyrlsql" \ diff --git a/skyrl-train/examples/text_to_sql/run_sql_fsdp.sh b/skyrl-train/examples/text_to_sql/run_sql_fsdp.sh index 16d5ab8ec7..a7a867a9d3 100755 --- a/skyrl-train/examples/text_to_sql/run_sql_fsdp.sh +++ b/skyrl-train/examples/text_to_sql/run_sql_fsdp.sh @@ -60,6 +60,8 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \ generator.use_conversation_multi_turn=false \ generator.sampling_params.temperature=0.6 \ generator.sampling_params.top_p=0.95 \ + generator.sampling_params.stop='["", ""]' \ + generator.eval_sampling_params.stop='["", ""]' \ trainer.seed=1234 \ environment.skyrl_gym.text2sql.db_path=$DB_PATH \ trainer.logger="wandb" \ diff --git a/skyrl-train/examples/text_to_sql/run_sql_fsdp_2node.sh b/skyrl-train/examples/text_to_sql/run_sql_fsdp_2node.sh index 07f90bc925..20d172e268 100644 --- a/skyrl-train/examples/text_to_sql/run_sql_fsdp_2node.sh +++ b/skyrl-train/examples/text_to_sql/run_sql_fsdp_2node.sh @@ -53,6 +53,8 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \ generator.use_conversation_multi_turn=false \ generator.sampling_params.temperature=0.6 \ generator.sampling_params.top_p=0.95 \ + generator.sampling_params.stop='["", ""]' \ + generator.eval_sampling_params.stop='["", ""]' \ environment.skyrl_gym.text2sql.db_path=$DB_PATH \ trainer.logger="wandb" \ trainer.project_name="skyrlsql" \ diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml index fdddf84757..e1427d05da 100644 --- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml +++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml @@ -186,6 +186,7 @@ generator: min_p: 0.0 top_k: -1 logprobs: null + stop: null # whether to use a conversation based format for multi-turn generations # if false, append multi-turn model responses and env observations to the original assistant response @@ -200,6 +201,7 @@ generator: min_p: 0.0 top_k: -1 logprobs: null + stop: null # number of samples per prompt for evaluation eval_n_samples_per_prompt: 1 diff --git a/skyrl-train/skyrl_train/inference_engines/utils.py b/skyrl-train/skyrl_train/inference_engines/utils.py index 4bf9ee0dcb..5d705ca31d 100644 --- a/skyrl-train/skyrl_train/inference_engines/utils.py +++ b/skyrl-train/skyrl_train/inference_engines/utils.py @@ -13,6 +13,7 @@ def get_vllm_sampling_params(sampling_params: DictConfig) -> Dict[str, Any]: "top_k": sampling_params.top_k, "min_p": sampling_params.min_p, "logprobs": sampling_params.logprobs, + "stop": list(sampling_params.stop) if sampling_params.stop is not None else None, } exclude_keys = ["max_generate_length"] for key, value in sampling_params.items(): @@ -25,9 +26,12 @@ def get_vllm_sampling_params(sampling_params: DictConfig) -> Dict[str, Any]: def get_sglang_sampling_params(sampling_params: DictConfig) -> Dict[str, Any]: - # min_tokens, include_stop_str_in_output are not used in sglang + # `min_tokens` in vllm is equivalent to `min_new_tokens` in sglang. However `min_new_tokens` and + # `stop` are not supported when `skip_tokenizer_init` is True, which we need for token-in-token-out. + # See this issue for more: https://github.com/sgl-project/sglang/issues/9039#issuecomment-3218331087 sglang_sampling_params = { "skip_special_tokens": True, + "no_stop_trim": True, # equivalent to include_stop_str_in_output=True "max_new_tokens": sampling_params.max_generate_length, "temperature": sampling_params.temperature, "top_p": sampling_params.top_p, diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 57b74c062f..2c585a7f9e 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -268,6 +268,21 @@ def validate_cfg(cfg: DictConfig): if not cfg.generator.run_engines_locally: raise NotImplementedError("Remote inference mode doesn't support `sampling_params.logprobs`") + if cfg.generator.backend == "sglang": + # Some sampling parameters are not supported in SGLang when `skip_tokenizer_init` is True. + if cfg.generator.sampling_params.stop is not None or cfg.generator.eval_sampling_params.stop is not None: + raise ValueError( + "`sampling_params.stop` and `eval_sampling_params.stop` are not supported for SGLang backend " + "since we always set `skip_tokenizer_init` to True. If you have to use these parameters, you can switch to vLLM. " + "See this issue for more: https://github.com/sgl-project/sglang/issues/9039#issuecomment-3218331087" + ) + if "min_new_tokens" in cfg.generator.sampling_params or "min_new_tokens" in cfg.generator.eval_sampling_params: + raise ValueError( + "`sampling_params.min_new_tokens` and `eval_sampling_params.min_new_tokens` are not " + "supported for SGLang backend since we always set `skip_tokenizer_init` to True. " + "If you have to use these parameters, you can switch to vLLM. " + "See this issue for more: https://github.com/sgl-project/sglang/issues/9039#issuecomment-3218331087" + ) @ray.remote def get_all_env_variables(): From 33fb74072c7fe1661c9a5bc45e11edf46bb99f9d Mon Sep 17 00:00:00 2001 From: Charlie Ruan Date: Sun, 24 Aug 2025 20:16:26 +0000 Subject: [PATCH 02/10] fix lint, add warning --- skyrl-gym/skyrl_gym/envs/search/env.py | 14 +++++++++----- skyrl-gym/skyrl_gym/envs/sql/env.py | 10 ++++++++-- skyrl-train/skyrl_train/utils/utils.py | 9 +++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/search/env.py b/skyrl-gym/skyrl_gym/envs/search/env.py index 9a932c5241..5db3261832 100644 --- a/skyrl-gym/skyrl_gym/envs/search/env.py +++ b/skyrl-gym/skyrl_gym/envs/search/env.py @@ -58,9 +58,15 @@ def _is_done(self, action: str) -> bool: def _validate_action(self, action: str): if "" in action: - assert action.split("")[1] == "", " detected in the response but it is not the last string generated. Use \"\" and \"\" as stop strings in the configuration." + assert action.split("")[1] == "", ( + " detected in the response but it is not the last string generated. " + 'Use "" and "" as stop strings in the configuration.' + ) elif "" in action: - assert action.split("")[1] == "", " detected in the response but it is not the last string generated. Use \"\" and \"\" as stop strings in the configuration." + assert action.split("")[1] == "", ( + " detected in the response but it is not the last string generated. " + 'Use "" and "" as stop strings in the configuration.' + ) def _execute_tool(self, tool_group_name: str, tool_name: str, tool_input: Any) -> str: tool_output = super()._execute_tool(tool_group_name, tool_name, tool_input) @@ -77,9 +83,7 @@ def step(self, action: str) -> BaseTextEnvStepOutput: reward = self._get_reward(action, done) if done: - return BaseTextEnvStepOutput( - observations=[], reward=reward, done=done, metadata={} - ) + return BaseTextEnvStepOutput(observations=[], reward=reward, done=done, metadata={}) try: query = self._parse_action(action) diff --git a/skyrl-gym/skyrl_gym/envs/sql/env.py b/skyrl-gym/skyrl_gym/envs/sql/env.py index cc58ae35e9..e018817659 100644 --- a/skyrl-gym/skyrl_gym/envs/sql/env.py +++ b/skyrl-gym/skyrl_gym/envs/sql/env.py @@ -91,9 +91,15 @@ def _is_done(self, action: str) -> bool: def _validate_action(self, action: str): if "" in action: - assert action.split("")[1] == "", " detected in the response but it is not the last string generated. Use \"\" and \"\" as stop strings in the configuration." + assert action.split("")[1] == "", ( + " detected in the response but it is not the last string generated. " + 'Use "" and "" as stop strings in the configuration.' + ) elif "" in action: - assert action.split("")[1] == "", " detected in the response but it is not the last string generated. Use \"\" and \"\" as stop strings in the configuration." + assert action.split("")[1] == "", ( + " detected in the response but it is not the last string generated. " + 'Use "" and "" as stop strings in the configuration.' + ) def step(self, action: str) -> BaseTextEnvStepOutput: self.turns += 1 diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 2c585a7f9e..da037599e2 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -284,6 +284,15 @@ def validate_cfg(cfg: DictConfig): "See this issue for more: https://github.com/sgl-project/sglang/issues/9039#issuecomment-3218331087" ) + if cfg.generator.use_conversation_multi_turn: + if cfg.generator.sampling_params.stop is not None or cfg.generator.eval_sampling_params.stop is not None: + print( + "WARNING: `sampling_params.stop` and `eval_sampling_params.stop` are specified but we " + "are using multi-turn generation. You might want to manually append tokenizer.eos_token_id " + "to the assistant-generated response to match the chat template." + ) + + @ray.remote def get_all_env_variables(): import os From 62d38fd79bebf0f41b720df3cfb69e63c3bf46e4 Mon Sep 17 00:00:00 2001 From: Charlie Ruan Date: Sun, 24 Aug 2025 20:17:59 +0000 Subject: [PATCH 03/10] address gemini comments --- skyrl-gym/skyrl_gym/envs/search/env.py | 4 ++-- skyrl-gym/skyrl_gym/envs/sql/env.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/search/env.py b/skyrl-gym/skyrl_gym/envs/search/env.py index 5db3261832..4393224c3f 100644 --- a/skyrl-gym/skyrl_gym/envs/search/env.py +++ b/skyrl-gym/skyrl_gym/envs/search/env.py @@ -58,12 +58,12 @@ def _is_done(self, action: str) -> bool: def _validate_action(self, action: str): if "" in action: - assert action.split("")[1] == "", ( + assert action.split("", 1)[1] == "", ( " detected in the response but it is not the last string generated. " 'Use "" and "" as stop strings in the configuration.' ) elif "" in action: - assert action.split("")[1] == "", ( + assert action.split("", 1)[1] == "", ( " detected in the response but it is not the last string generated. " 'Use "" and "" as stop strings in the configuration.' ) diff --git a/skyrl-gym/skyrl_gym/envs/sql/env.py b/skyrl-gym/skyrl_gym/envs/sql/env.py index e018817659..207352fdc4 100644 --- a/skyrl-gym/skyrl_gym/envs/sql/env.py +++ b/skyrl-gym/skyrl_gym/envs/sql/env.py @@ -91,12 +91,12 @@ def _is_done(self, action: str) -> bool: def _validate_action(self, action: str): if "" in action: - assert action.split("")[1] == "", ( + assert action.split("", 1)[1] == "", ( " detected in the response but it is not the last string generated. " 'Use "" and "" as stop strings in the configuration.' ) elif "" in action: - assert action.split("")[1] == "", ( + assert action.split("", 1)[1] == "", ( " detected in the response but it is not the last string generated. " 'Use "" and "" as stop strings in the configuration.' ) From cf9a63a152ded714542917d3cd5f317d9a22a821 Mon Sep 17 00:00:00 2001 From: Charlie Ruan Date: Sun, 24 Aug 2025 20:26:10 +0000 Subject: [PATCH 04/10] remove postprocess action tests from skyrlgym --- skyrl-gym/tests/test_search.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/skyrl-gym/tests/test_search.py b/skyrl-gym/tests/test_search.py index ce43b843c1..841383f6a4 100644 --- a/skyrl-gym/tests/test_search.py +++ b/skyrl-gym/tests/test_search.py @@ -132,36 +132,6 @@ def test_parse_action(search_env, action, expected_input): assert query == expected_input -# ============================================================================= -# ACTION POSTPROCESSING FUNCTIONALITY TESTS -# ============================================================================= - - -@pytest.mark.parametrize( - "action, expected", - [ - # Search with extra content after closing tag - ("Query extra content", "Query"), - # Answer with extra content after closing tag - ("Answer extra content", "Answer"), - # Both search and answer tags - ("Query Answer extra", "Query"), - # Only search tag (no extra content) - ("Query", "Query"), - # Only answer tag (no extra content) - ("Answer", "Answer"), - # No special tags - ("Just plain text", "Just plain text"), - # end tag before start tag - ("Query", ""), - ], -) -def test_postprocess_action(search_env, action, expected): - """Test action postprocessing.""" - result = search_env._postprocess_action(action) - assert result == expected - - # ============================================================================= # EPISODE TERMINATION CONDITIONS TESTS # ============================================================================= @@ -313,8 +283,6 @@ def test_invalid_search_parsing(search_env, mock_search_api): ("emmanuel macron", {"target": "Emmanuel Macron"}, 1.0, True), # Answer without articles ("Emmanuel Macron", {"target": "The Emmanuel Macron"}, 1.0, True), - # Multiple answer tags (should use first one) - ("Wrong Emmanuel Macron", {"target": "Emmanuel Macron"}, 0.0, True), # No answer tag ("Just text without answer tag", {"target": "Emmanuel Macron"}, 0.0, False), ], From cdf24a578758ca85c1d273933f6c1f207d538d95 Mon Sep 17 00:00:00 2001 From: Charlie Ruan Date: Tue, 26 Aug 2025 19:02:07 +0000 Subject: [PATCH 05/10] append eos to response after stop str in multiturn --- .../skyrl_train/config/ppo_base_config.yaml | 5 + .../generators/skyrl_gym_generator.py | 14 +++ ...est_skyrl_gym_generator_chat_templating.py | 91 ++++++++++++++++++- 3 files changed, 109 insertions(+), 1 deletion(-) diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml index e1427d05da..f39741f7be 100644 --- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml +++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml @@ -193,6 +193,11 @@ generator: # if true, each multi-turn model response and env observations is stored in a separate assistant/user message respectively use_conversation_multi_turn: true + # Used when use_conversation_multi_turn is true, and sampling_params.stop is not null. + # If true, append tokenizer.eos_token_id to the end of the generation if the generation ends + # with stop_reason "stop" and matched a stop string in sampling_params.stop. + append_eos_token_after_stop_str_in_multi_turn: true + # sampling params for evaluation eval_sampling_params: max_generate_length: ${generator.sampling_params.max_generate_length} diff --git a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py index 14419f9865..94d9dfdbdb 100644 --- a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py +++ b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py @@ -177,6 +177,20 @@ async def agent_loop( output_ids = engine_output["response_ids"][0] stop_reason = engine_output["stop_reasons"][0] + # append eos token if needed. Only applicable when sampling_params.stop is not None. + stop_strs = sampling_params.get("stop", None) if sampling_params is not None else None + if ( + stop_strs is not None + and self.generator_cfg.append_eos_token_after_stop_str_in_multi_turn + and (retokenize_chat_history or self.use_conversation_multi_turn) + ): + for stop_str in stop_strs: + if output.endswith(stop_str) and output_ids[-1] != self.tokenizer.eos_token_id: + # Append EOS token to output to match chat template termination. + # Do not mutate loss_mask here; it will be updated downstream together with output_ids. + output_ids.append(self.tokenizer.eos_token_id) + break + # 2. Environment step if self.env_executor is not None: loop = asyncio.get_running_loop() diff --git a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py index bfc1262f1f..d325662189 100644 --- a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py +++ b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py @@ -61,7 +61,8 @@ def mock_generate(input_batch): num_prompts = len(input_batch["prompts"]) if "prompts" in input_batch else len(input_batch["prompt_token_ids"]) mock_llm_output_text = "b" + tokenizer.eos_token return { - "responses": [mock_llm_output_text] * num_prompts, + # no tokenizer.eos_token for responses because `skip_special_tokens` is True in sampling params + "responses": ["b"] * num_prompts, "stop_reasons": ["stop"] * num_prompts, "response_logprobs": None, # add_special_tokens needs to be False, otherwise for instance Llama will always @@ -80,6 +81,7 @@ def mock_generate(input_batch): "zero_reward_on_non_stop": False, "apply_overlong_filtering": False, "use_conversation_multi_turn": True, + "append_eos_token_after_stop_str_in_multi_turn": True, } ) env_cfg = DictConfig( @@ -183,3 +185,90 @@ def mock_generate(input_batch): expected_loss_masks = expected_loss_masks[:-1] # remove the extra 0 for \n assert len(expected_loss_masks) == len(generator_output["loss_masks"][0]) assert generator_output["loss_masks"][0] == expected_loss_masks + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", ["Qwen/Qwen2.5-0.5B-Instruct", "unsloth/Llama-3.2-1B-Instruct", "Qwen/Qwen3-0.6B"] +) +async def test_append_eos_after_stop_multi_turn(model_name): + _register_test_env_if_needed() + tokenizer = AutoTokenizer.from_pretrained(model_name) + + stop_tag = "" + + async def make_generator(append_flag: bool): + mock_llm = MagicMock() + + def mock_generate(input_batch): + num_prompts = len(input_batch["prompts"]) if "prompts" in input_batch else len(input_batch["prompt_token_ids"]) + mock_text = "b" + stop_tag + return { + "responses": [mock_text] * num_prompts, + "stop_reasons": ["stop"] * num_prompts, + "response_logprobs": None, + "response_ids": [tokenizer.encode(mock_text, add_special_tokens=False)] * num_prompts, + } + + mock_llm.generate = AsyncMock(side_effect=mock_generate) + + generator_cfg = DictConfig( + { + "sampling_params": {"max_generate_length": 200, "logprobs": None}, + "max_input_length": 200, + "batched": False, + "max_turns": 3, + "zero_reward_on_non_stop": False, + "apply_overlong_filtering": False, + "use_conversation_multi_turn": True, + "append_eos_token_after_stop_str_in_multi_turn": append_flag, + } + ) + env_cfg = DictConfig({"max_env_workers": 0, "env_class": "cpu_test_env"}) + gen = SkyRLGymGenerator( + generator_cfg=generator_cfg, + skyrl_gym_cfg=env_cfg, + inference_engine_client=mock_llm, + tokenizer=tokenizer, + model_name=model_name, + ) + return gen + + prompt = [[{"role": "user", "content": "a"}]] + extras = [{"answer": "4"}] + sp = {"stop": [stop_tag]} + + # Case 1: append flag = True + generator_true = await make_generator(True) + out_true: GeneratorOutput = await generator_true.generate( + {"prompts": prompt, "env_extras": extras, "env_classes": ["cpu_test_env"], "sampling_params": sp} + ) + + # Case 2: append flag = False + generator_false = await make_generator(False) + out_false: GeneratorOutput = await generator_false.generate( + {"prompts": prompt, "env_extras": extras, "env_classes": ["cpu_test_env"], "sampling_params": sp} + ) + + # Common assertions + assert out_true["stop_reasons"][0] == "stop" + assert out_false["stop_reasons"][0] == "stop" + assert len(out_true["response_ids"][0]) == len(out_true["loss_masks"][0]) + assert len(out_false["response_ids"][0]) == len(out_false["loss_masks"][0]) + + last_true = out_true["response_ids"][0][-1] + last_false = out_false["response_ids"][0][-1] + + if "Qwen3" in model_name: + # Retokenize path: custom chat template appends "<|im_end|>\n". + # Ensure EOS appears and allow a trailing newline token after it. + assert tokenizer.eos_token_id in out_true["response_ids"][0] + assert tokenizer.eos_token_id in out_false["response_ids"][0] + if last_true != tokenizer.eos_token_id: + assert tokenizer.decode(out_true["response_ids"][0]).endswith("\n") + if last_false != tokenizer.eos_token_id: + assert tokenizer.decode(out_false["response_ids"][0]).endswith("\n") + else: + # Non-retokenize path: last token is eos only when append flag is True + assert last_true == tokenizer.eos_token_id + assert last_false != tokenizer.eos_token_id From b3fc6ff7e08122736c071c133079f6e25d420776 Mon Sep 17 00:00:00 2001 From: Charlie Ruan Date: Tue, 26 Aug 2025 19:29:46 +0000 Subject: [PATCH 06/10] fix scripts and doc according to append_eos --- .../docs/examples/multi_turn_text2sql.rst | 3 +- skyrl-train/docs/examples/search.rst | 10 ++- skyrl-train/examples/search/run_search.sh | 1 - .../examples/search/run_search_multiturn.sh | 79 ++++++++++++++++++ .../examples/text_to_sql/run_skyrl_sql.sh | 1 - .../text_to_sql/run_skyrl_sql_multiturn.sh | 80 +++++++++++++++++++ .../examples/text_to_sql/run_sql_deepspeed.sh | 1 - .../examples/text_to_sql/run_sql_fsdp.sh | 1 - .../text_to_sql/run_sql_fsdp_2node.sh | 1 - 9 files changed, 169 insertions(+), 8 deletions(-) create mode 100755 skyrl-train/examples/search/run_search_multiturn.sh create mode 100644 skyrl-train/examples/text_to_sql/run_skyrl_sql_multiturn.sh diff --git a/skyrl-train/docs/examples/multi_turn_text2sql.rst b/skyrl-train/docs/examples/multi_turn_text2sql.rst index 4fa3d26fdf..c10cf3ac71 100644 --- a/skyrl-train/docs/examples/multi_turn_text2sql.rst +++ b/skyrl-train/docs/examples/multi_turn_text2sql.rst @@ -149,7 +149,8 @@ Now that we have our dataset and database files, let's walk through the some of - In the above example, we set ``use_conversation_multi_turn=false`` to enforce that the multi-turn conversation is formatted as a single assistant response. - We also set ``stop='["", ""]'`` for both ``sampling_params`` and ``eval_sampling_params`` as a part - of the training recipe. If you are using ``generator.use_conversation_multi_turn=true``, you might want to manually append an EOS token ID to the end of the response after these stop strings. + of the training recipe. + - If you are using ``generator.use_conversation_multi_turn=true``, you might want to append an EOS token ID to the end of the response after these stop strings to adhere to the model's behavior (i.e. ending generation with an EOS token ID rather than say ````). This can be done by setting ``generator.append_eos_token_after_stop_str_in_multi_turn=true`` in the generator config. The full script is available in `examples/text_to_sql/run_skyrl_sql_multiturn.sh`. - If you want to use a conversation-based format, you can set ``use_conversation_multi_turn=true`` and the model will generate a separate assistant response for each turn. This is supported only with ``backend="vllm"`` as of now. - See :code_link:`skyrl_train/generators/skyrl_gym_generator.py` for more details on both options! diff --git a/skyrl-train/docs/examples/search.rst b/skyrl-train/docs/examples/search.rst index ba717cb481..fcb3c21d5b 100644 --- a/skyrl-train/docs/examples/search.rst +++ b/skyrl-train/docs/examples/search.rst @@ -121,8 +121,14 @@ To change the number of turns, you can simply change the ``generator.max_turns`` For more details on environment implementation, see :skyrl_gym_link:`skyrl_gym/envs/search/env.py`. Note we add ``stop='["", ""]'`` for both generation and evaluation sampling parameters -to adhere to the Search-R1 recipe. If you are using ``generator.use_conversation_multi_turn=true``, -you might want to manually append an EOS token ID to the end of the response after these stop strings. +to adhere to the Search-R1 recipe. + +If you are using ``generator.use_conversation_multi_turn=true``, +you might want to append an EOS token ID to the end of the response after these stop strings to adhere +to the model's behavior (i.e. ending generation with an EOS token ID rather than say ````). +This can be done by setting ``generator.append_eos_token_after_stop_str_in_multi_turn=true`` in the generator config. +The full script is available in `examples/search/run_search_multiturn.sh`. + Launching Your Training Run --------------------------- diff --git a/skyrl-train/examples/search/run_search.sh b/skyrl-train/examples/search/run_search.sh index bcaa7a82f0..a8e4d13340 100755 --- a/skyrl-train/examples/search/run_search.sh +++ b/skyrl-train/examples/search/run_search.sh @@ -45,7 +45,6 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ generator.use_conversation_multi_turn=false \ generator.n_samples_per_prompt=5 \ generator.max_turns=4 \ - generator.use_conversation_multi_turn=false \ generator.sampling_params.temperature=1.0 \ generator.sampling_params.top_p=1.0 \ generator.sampling_params.stop='["", ""]' \ diff --git a/skyrl-train/examples/search/run_search_multiturn.sh b/skyrl-train/examples/search/run_search_multiturn.sh new file mode 100755 index 0000000000..5dd9a777c6 --- /dev/null +++ b/skyrl-train/examples/search/run_search_multiturn.sh @@ -0,0 +1,79 @@ +set -x + +# The exact same script as `run_search.sh` but with `use_conversation_multi_turn=true` +# and hence `append_eos_token_after_stop_str_in_multi_turn=true` +# See https://skyrl.readthedocs.io/en/latest/tutorials/skyrl_gym_generator.html on the +# difference between the two options. You might want to change the data generation prompt +# to let the model know that we are doing multi-turn conversations (i.e. user will provide +# the search result for each turn). + +# Colocated GRPO training+generation for Qwen2.5-Coder-3B-Instruct on SearchR1 data. +# follow the instructions in examples/search/README.md for setting up the dataset +# and for starting the local search server +# export WANDB_API_KEY= +# bash examples/search/run_search.sh + +# path for dataset (.parquet files) containing the prompts and metadata for each question +DATA_DIR="$HOME/data/searchR1" + +uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ + data.train_data="['${DATA_DIR}/train.parquet']" \ + data.val_data="['${DATA_DIR}/validation.parquet']" \ + trainer.algorithm.advantage_estimator="grpo" \ + trainer.policy.optimizer_config.lr=1.0e-6 \ + trainer.policy.optimizer_config.max_grad_norm=0.5 \ + trainer.policy.optimizer_config.num_warmup_steps=94 \ + trainer.algorithm.use_kl_loss=true \ + trainer.algorithm.kl_loss_coef=0.001 \ + trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \ + trainer.placement.colocate_all=true \ + trainer.strategy=fsdp2 \ + trainer.policy.fsdp_config.cpu_offload=false \ + trainer.ref.fsdp_config.cpu_offload=true \ + trainer.placement.policy_num_gpus_per_node=8 \ + trainer.placement.ref_num_gpus_per_node=8 \ + generator.num_inference_engines=4 \ + generator.inference_engine_tensor_parallel_size=2 \ + generator.backend=vllm \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.gpu_memory_utilization=0.5 \ + trainer.epochs=1 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=512 \ + trainer.policy_mini_batch_size=256 \ + trainer.micro_forward_batch_size_per_gpu=4 \ + trainer.micro_train_batch_size_per_gpu=4 \ + trainer.max_prompt_length=2048 \ + generator.max_input_length=4096 \ + generator.sampling_params.max_generate_length=500 \ + generator.async_engine=true \ + generator.batched=false \ + generator.use_conversation_multi_turn=true \ + generator.n_samples_per_prompt=5 \ + generator.max_turns=4 \ + generator.sampling_params.temperature=1.0 \ + generator.sampling_params.top_p=1.0 \ + generator.sampling_params.stop='["", ""]' \ + generator.append_eos_token_after_stop_str_in_multi_turn=true \ + environment.env_class="search" \ + environment.skyrl_gym.max_env_workers=16 \ + environment.skyrl_gym.search.log_requests=false \ + environment.skyrl_gym.search.search_url="http://127.0.0.1:8000/retrieve" \ + environment.skyrl_gym.search.topk=3 \ + trainer.logger="wandb" \ + trainer.project_name="skyrl-search" \ + trainer.run_name="skyrl-search_4turns_maxgeneratelen_500" \ + trainer.ckpt_interval=20 \ + trainer.hf_save_interval=100 \ + trainer.max_ckpts_to_keep=5 \ + trainer.resume_mode=latest \ + trainer.ckpt_path="$HOME/skyrl-search_4turns_maxgeneratelen_500" \ + trainer.eval_batch_size=256 \ + trainer.eval_before_train=false \ + generator.eval_sampling_params.temperature=0 \ + generator.eval_sampling_params.stop='["", ""]' \ + trainer.export_path="$HOME/skyrl-search_4turns_maxgeneratelen_500/exports" \ + trainer.eval_interval=50 \ + $@ + \ No newline at end of file diff --git a/skyrl-train/examples/text_to_sql/run_skyrl_sql.sh b/skyrl-train/examples/text_to_sql/run_skyrl_sql.sh index bf29bdf919..344e2aa132 100644 --- a/skyrl-train/examples/text_to_sql/run_skyrl_sql.sh +++ b/skyrl-train/examples/text_to_sql/run_skyrl_sql.sh @@ -56,7 +56,6 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \ generator.n_samples_per_prompt=5 \ generator.gpu_memory_utilization=0.7 \ generator.max_turns=6 \ - generator.use_conversation_multi_turn=false \ generator.sampling_params.temperature=0.6 \ generator.sampling_params.top_p=0.95 \ generator.sampling_params.stop='["", ""]' \ diff --git a/skyrl-train/examples/text_to_sql/run_skyrl_sql_multiturn.sh b/skyrl-train/examples/text_to_sql/run_skyrl_sql_multiturn.sh new file mode 100644 index 0000000000..e77f951dad --- /dev/null +++ b/skyrl-train/examples/text_to_sql/run_skyrl_sql_multiturn.sh @@ -0,0 +1,80 @@ +set -x + +# The exact same script as `run_search.sh` but with `use_conversation_multi_turn=true` +# and hence `append_eos_token_after_stop_str_in_multi_turn=true` +# See https://skyrl.readthedocs.io/en/latest/tutorials/skyrl_gym_generator.html on what behavior +# use_conversation_multi_turn corresponds to. You might want to change the data generation prompt +# to let the model know that we are doing multi-turn conversations (i.e. user will provide +# the search result for each turn). + +# Colocated GRPO training+generation for Qwen2.5-Coder-7B-Instruct on SkyRL-SQL-653 data. +# Uses 1 node with 8 GPUs. +# huggingface-cli download NovaSky-AI/SkyRL-SQL-653-data-newfmt --local-dir $HOME/data/sql --repo-type dataset +# export WANDB_API_KEY= +# bash examples/text_to_sql/run_skyrl_sql.sh + +# change these paths to your own +DATA_DIR="$HOME/data/sql" +DB_PATH="$HOME/data/sql/db_files/data" +CKPT_PATH="$HOME/ckpts/skyrl_sql_7B_ckpt" + +NUM_GPUS=8 +NUM_INFERENCE_ENGINES=2 +TP_SIZE=4 +MAX_INPUT_LENGTH=29000 +MAX_GENERATE_LENGTH=3000 +TRAIN_BATCH_SIZE=256 + +uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \ + trainer.algorithm.advantage_estimator="grpo" \ + data.train_data="['$DATA_DIR/train.parquet']" \ + data.val_data="['$DATA_DIR/validation.parquet']" \ + trainer.policy.model.path="Qwen/Qwen2.5-Coder-7B-Instruct" \ + trainer.epochs=30 \ + trainer.placement.colocate_all=true \ + trainer.strategy=fsdp2 \ + trainer.policy.fsdp_config.cpu_offload=false \ + trainer.ref.fsdp_config.cpu_offload=true \ + trainer.policy.optimizer_config.max_grad_norm=0.5 \ + trainer.policy.sequence_parallel_size=1 \ + trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ + trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ + generator.num_inference_engines=$NUM_INFERENCE_ENGINES \ + generator.inference_engine_tensor_parallel_size=$TP_SIZE \ + trainer.train_batch_size=$TRAIN_BATCH_SIZE \ + trainer.micro_forward_batch_size_per_gpu=8 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.max_prompt_length=6000 \ + generator.max_input_length=$MAX_INPUT_LENGTH \ + generator.sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \ + trainer.policy.optimizer_config.lr=1.0e-6 \ + trainer.policy_mini_batch_size=256 \ + trainer.algorithm.use_kl_loss=false \ + trainer.ckpt_interval=60 \ + trainer.hf_save_interval=30 \ + trainer.dump_data_batch=true \ + generator.backend=vllm \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + environment.env_class=text2sql \ + generator.use_conversation_multi_turn=true \ + generator.n_samples_per_prompt=5 \ + generator.gpu_memory_utilization=0.7 \ + generator.max_turns=6 \ + generator.sampling_params.temperature=0.6 \ + generator.sampling_params.top_p=0.95 \ + generator.sampling_params.stop='["", ""]' \ + generator.append_eos_token_after_stop_str_in_multi_turn=true \ + generator.eval_sampling_params.stop='["", ""]' \ + environment.skyrl_gym.text2sql.db_path=$DB_PATH \ + trainer.logger="wandb" \ + trainer.project_name="skyrlsql" \ + trainer.run_name="skyrlsql_repro" \ + trainer.resume_mode=latest \ + trainer.ckpt_path=$CKPT_PATH \ + trainer.eval_batch_size=1024 \ + trainer.eval_before_train=true \ + trainer.eval_interval=5 \ + $@ \ No newline at end of file diff --git a/skyrl-train/examples/text_to_sql/run_sql_deepspeed.sh b/skyrl-train/examples/text_to_sql/run_sql_deepspeed.sh index 226e772e4c..2ee4b5e4e1 100755 --- a/skyrl-train/examples/text_to_sql/run_sql_deepspeed.sh +++ b/skyrl-train/examples/text_to_sql/run_sql_deepspeed.sh @@ -46,7 +46,6 @@ uv run --isolated --frozen --extra vllm --extra deepspeed -m skyrl_train.entrypo generator.n_samples_per_prompt=5 \ generator.gpu_memory_utilization=0.7 \ generator.max_turns=5 \ - generator.use_conversation_multi_turn=false \ generator.sampling_params.temperature=0.6 \ generator.sampling_params.top_p=0.95 \ generator.sampling_params.stop='["", ""]' \ diff --git a/skyrl-train/examples/text_to_sql/run_sql_fsdp.sh b/skyrl-train/examples/text_to_sql/run_sql_fsdp.sh index a7a867a9d3..07e5d749f3 100755 --- a/skyrl-train/examples/text_to_sql/run_sql_fsdp.sh +++ b/skyrl-train/examples/text_to_sql/run_sql_fsdp.sh @@ -57,7 +57,6 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \ generator.n_samples_per_prompt=5 \ generator.gpu_memory_utilization=0.7 \ generator.max_turns=6 \ - generator.use_conversation_multi_turn=false \ generator.sampling_params.temperature=0.6 \ generator.sampling_params.top_p=0.95 \ generator.sampling_params.stop='["", ""]' \ diff --git a/skyrl-train/examples/text_to_sql/run_sql_fsdp_2node.sh b/skyrl-train/examples/text_to_sql/run_sql_fsdp_2node.sh index 20d172e268..1546174fe3 100644 --- a/skyrl-train/examples/text_to_sql/run_sql_fsdp_2node.sh +++ b/skyrl-train/examples/text_to_sql/run_sql_fsdp_2node.sh @@ -50,7 +50,6 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \ generator.n_samples_per_prompt=5 \ generator.gpu_memory_utilization=0.7 \ generator.max_turns=5 \ - generator.use_conversation_multi_turn=false \ generator.sampling_params.temperature=0.6 \ generator.sampling_params.top_p=0.95 \ generator.sampling_params.stop='["", ""]' \ From cb67035ef55526b744a7430bf880e65493c31f87 Mon Sep 17 00:00:00 2001 From: Charlie Ruan Date: Tue, 26 Aug 2025 19:52:24 +0000 Subject: [PATCH 07/10] fix unit tests --- .../generators/skyrl_gym_generator.py | 1 - ...est_skyrl_gym_generator_chat_templating.py | 43 ++++++++++++------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py index 94d9dfdbdb..4390a3725e 100644 --- a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py +++ b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py @@ -187,7 +187,6 @@ async def agent_loop( for stop_str in stop_strs: if output.endswith(stop_str) and output_ids[-1] != self.tokenizer.eos_token_id: # Append EOS token to output to match chat template termination. - # Do not mutate loss_mask here; it will be updated downstream together with output_ids. output_ids.append(self.tokenizer.eos_token_id) break diff --git a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py index d325662189..5a66bf42e9 100644 --- a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py +++ b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py @@ -192,17 +192,26 @@ def mock_generate(input_batch): "model_name", ["Qwen/Qwen2.5-0.5B-Instruct", "unsloth/Llama-3.2-1B-Instruct", "Qwen/Qwen3-0.6B"] ) async def test_append_eos_after_stop_multi_turn(model_name): + """ + Test the behavior of `append_eos_token_after_stop_str_in_multi_turn`, which is applicable + when `sampling_params.stop` is not `null` and `use_conversation_multi_turn` is `true` in + the ``agent_loop()`` function. + It is used in scripts `examples/search/run_search_multiturn.sh` and `examples/text_to_sql/run_skyrl_sql_multiturn.sh`. + """ _register_test_env_if_needed() tokenizer = AutoTokenizer.from_pretrained(model_name) stop_tag = "" + mock_text = "b" + stop_tag async def make_generator(append_flag: bool): mock_llm = MagicMock() + # The LLM engine will generate and return the stop tag, but no EOS token ID. def mock_generate(input_batch): - num_prompts = len(input_batch["prompts"]) if "prompts" in input_batch else len(input_batch["prompt_token_ids"]) - mock_text = "b" + stop_tag + num_prompts = ( + len(input_batch["prompts"]) if "prompts" in input_batch else len(input_batch["prompt_token_ids"]) + ) return { "responses": [mock_text] * num_prompts, "stop_reasons": ["stop"] * num_prompts, @@ -214,7 +223,7 @@ def mock_generate(input_batch): generator_cfg = DictConfig( { - "sampling_params": {"max_generate_length": 200, "logprobs": None}, + "sampling_params": {"max_generate_length": 200, "logprobs": None, "stop": [stop_tag]}, "max_input_length": 200, "batched": False, "max_turns": 3, @@ -256,19 +265,21 @@ def mock_generate(input_batch): assert len(out_true["response_ids"][0]) == len(out_true["loss_masks"][0]) assert len(out_false["response_ids"][0]) == len(out_false["loss_masks"][0]) - last_true = out_true["response_ids"][0][-1] - last_false = out_false["response_ids"][0][-1] - if "Qwen3" in model_name: - # Retokenize path: custom chat template appends "<|im_end|>\n". - # Ensure EOS appears and allow a trailing newline token after it. - assert tokenizer.eos_token_id in out_true["response_ids"][0] - assert tokenizer.eos_token_id in out_false["response_ids"][0] - if last_true != tokenizer.eos_token_id: - assert tokenizer.decode(out_true["response_ids"][0]).endswith("\n") - if last_false != tokenizer.eos_token_id: - assert tokenizer.decode(out_false["response_ids"][0]).endswith("\n") + # Retokenize path is not affected by append_eos_token_after_stop_str_in_multi_turn + # The chat template does things like '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' + # So regardless of append_eos_token_after_stop_str_in_multi_turn, the last tokens are: + # stop_tag, eos_token_id and \n + last_token_ids = tokenizer.encode(stop_tag + tokenizer.eos_token + "\n") + num_last_tokens = len(last_token_ids) + response_ids_true = out_true["response_ids"][0] + response_ids_false = out_false["response_ids"][0] + assert response_ids_true[-num_last_tokens:] == last_token_ids + assert response_ids_false[-num_last_tokens:] == last_token_ids + assert response_ids_true == response_ids_false else: # Non-retokenize path: last token is eos only when append flag is True - assert last_true == tokenizer.eos_token_id - assert last_false != tokenizer.eos_token_id + last_token_id_true = out_true["response_ids"][0][-1] + last_token_id_false = out_false["response_ids"][0][-1] + assert last_token_id_true == tokenizer.eos_token_id + assert last_token_id_false == tokenizer.encode(mock_text, add_special_tokens=False)[-1] From f8971edefd5f5427faccdc894779747b8ba9971d Mon Sep 17 00:00:00 2001 From: Charlie Ruan Date: Tue, 26 Aug 2025 20:12:37 +0000 Subject: [PATCH 08/10] change logs --- .../skyrl_train/generators/skyrl_gym_generator.py | 1 + skyrl-train/skyrl_train/utils/utils.py | 12 +++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py index 4390a3725e..53f8cc1ec4 100644 --- a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py +++ b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py @@ -178,6 +178,7 @@ async def agent_loop( stop_reason = engine_output["stop_reasons"][0] # append eos token if needed. Only applicable when sampling_params.stop is not None. + # Note this does not affect 3.a because the chat template adds eos_token to the end. stop_strs = sampling_params.get("stop", None) if sampling_params is not None else None if ( stop_strs is not None diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index da037599e2..b6b90be127 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -285,11 +285,13 @@ def validate_cfg(cfg: DictConfig): ) if cfg.generator.use_conversation_multi_turn: - if cfg.generator.sampling_params.stop is not None or cfg.generator.eval_sampling_params.stop is not None: - print( - "WARNING: `sampling_params.stop` and `eval_sampling_params.stop` are specified but we " - "are using multi-turn generation. You might want to manually append tokenizer.eos_token_id " - "to the assistant-generated response to match the chat template." + if ( + cfg.generator.sampling_params.stop is not None or cfg.generator.eval_sampling_params.stop is not None + ) and not cfg.generator.append_eos_token_after_stop_str_in_multi_turn: + logger.warning( + "WARNING: `sampling_params.stop` and `eval_sampling_params.stop` are specified and we " + "are using multi-turn generation. You might want to set `append_eos_token_after_stop_str_in_multi_turn` " + "to `True` to append tokenizer.eos_token_id to the assistant-generated response to match the chat template." ) From c05f2e9646d969a2fe22322250fe6626a60b6868 Mon Sep 17 00:00:00 2001 From: Charlie Ruan Date: Wed, 27 Aug 2025 00:49:20 +0000 Subject: [PATCH 09/10] Address how sampling_params is None for rollout --- skyrl-train/skyrl_train/generators/skyrl_gym_generator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py index 53f8cc1ec4..5d65dc5706 100644 --- a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py +++ b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py @@ -179,7 +179,11 @@ async def agent_loop( # append eos token if needed. Only applicable when sampling_params.stop is not None. # Note this does not affect 3.a because the chat template adds eos_token to the end. - stop_strs = sampling_params.get("stop", None) if sampling_params is not None else None + stop_strs = None + if sampling_params is None: + stop_strs = self.generator_cfg.sampling_params.get("stop", None) + else: + stop_strs = sampling_params.get("stop", None) if ( stop_strs is not None and self.generator_cfg.append_eos_token_after_stop_str_in_multi_turn From 0d2effd3b64972abf3b6b6e815b6a955ab73f046 Mon Sep 17 00:00:00 2001 From: Charlie Ruan Date: Wed, 27 Aug 2025 21:44:20 +0000 Subject: [PATCH 10/10] Address comments --- skyrl-gym/skyrl_gym/envs/search/env.py | 17 ++++++-------- skyrl-gym/skyrl_gym/envs/sql/env.py | 17 ++++++-------- .../docs/examples/multi_turn_text2sql.rst | 2 +- skyrl-train/docs/examples/search.rst | 2 +- ...n.sh => run_search_conversation_format.sh} | 0 ...h => run_skyrl_sql_conversation_format.sh} | 0 .../generators/skyrl_gym_generator.py | 22 ++++++++----------- ...est_skyrl_gym_generator_chat_templating.py | 3 ++- 8 files changed, 27 insertions(+), 36 deletions(-) rename skyrl-train/examples/search/{run_search_multiturn.sh => run_search_conversation_format.sh} (100%) rename skyrl-train/examples/text_to_sql/{run_skyrl_sql_multiturn.sh => run_skyrl_sql_conversation_format.sh} (100%) diff --git a/skyrl-gym/skyrl_gym/envs/search/env.py b/skyrl-gym/skyrl_gym/envs/search/env.py index 4393224c3f..32fd752097 100644 --- a/skyrl-gym/skyrl_gym/envs/search/env.py +++ b/skyrl-gym/skyrl_gym/envs/search/env.py @@ -57,16 +57,13 @@ def _is_done(self, action: str) -> bool: return "" in action and "" in action def _validate_action(self, action: str): - if "" in action: - assert action.split("", 1)[1] == "", ( - " detected in the response but it is not the last string generated. " - 'Use "" and "" as stop strings in the configuration.' - ) - elif "" in action: - assert action.split("", 1)[1] == "", ( - " detected in the response but it is not the last string generated. " - 'Use "" and "" as stop strings in the configuration.' - ) + stop_tags = ["", ""] + for tag in stop_tags: + if tag in action: + assert action.split(tag, 1)[1] == "", ( + f"{tag} detected in the response but it is not the last string generated. " + f"Use {stop_tags} as stop strings in the configuration." + ) def _execute_tool(self, tool_group_name: str, tool_name: str, tool_input: Any) -> str: tool_output = super()._execute_tool(tool_group_name, tool_name, tool_input) diff --git a/skyrl-gym/skyrl_gym/envs/sql/env.py b/skyrl-gym/skyrl_gym/envs/sql/env.py index 207352fdc4..8a5e5d67cf 100644 --- a/skyrl-gym/skyrl_gym/envs/sql/env.py +++ b/skyrl-gym/skyrl_gym/envs/sql/env.py @@ -90,16 +90,13 @@ def _is_done(self, action: str) -> bool: return "" in action and "" in action def _validate_action(self, action: str): - if "" in action: - assert action.split("", 1)[1] == "", ( - " detected in the response but it is not the last string generated. " - 'Use "" and "" as stop strings in the configuration.' - ) - elif "" in action: - assert action.split("", 1)[1] == "", ( - " detected in the response but it is not the last string generated. " - 'Use "" and "" as stop strings in the configuration.' - ) + stop_tags = ["", ""] + for tag in stop_tags: + if tag in action: + assert action.split(tag, 1)[1] == "", ( + f"{tag} detected in the response but it is not the last string generated. " + f"Use {stop_tags} as stop strings in the configuration." + ) def step(self, action: str) -> BaseTextEnvStepOutput: self.turns += 1 diff --git a/skyrl-train/docs/examples/multi_turn_text2sql.rst b/skyrl-train/docs/examples/multi_turn_text2sql.rst index c10cf3ac71..d699d4609e 100644 --- a/skyrl-train/docs/examples/multi_turn_text2sql.rst +++ b/skyrl-train/docs/examples/multi_turn_text2sql.rst @@ -150,7 +150,7 @@ Now that we have our dataset and database files, let's walk through the some of - In the above example, we set ``use_conversation_multi_turn=false`` to enforce that the multi-turn conversation is formatted as a single assistant response. - We also set ``stop='["", ""]'`` for both ``sampling_params`` and ``eval_sampling_params`` as a part of the training recipe. - - If you are using ``generator.use_conversation_multi_turn=true``, you might want to append an EOS token ID to the end of the response after these stop strings to adhere to the model's behavior (i.e. ending generation with an EOS token ID rather than say ````). This can be done by setting ``generator.append_eos_token_after_stop_str_in_multi_turn=true`` in the generator config. The full script is available in `examples/text_to_sql/run_skyrl_sql_multiturn.sh`. + - If you are using ``generator.use_conversation_multi_turn=true``, you might want to append an EOS token ID to the end of the response after these stop strings to adhere to the model's behavior (i.e. ending generation with an EOS token ID rather than say ````). This can be done by setting ``generator.append_eos_token_after_stop_str_in_multi_turn=true`` in the generator config. The full script is available in `examples/text_to_sql/run_skyrl_sql_conversation_format.sh`. - If you want to use a conversation-based format, you can set ``use_conversation_multi_turn=true`` and the model will generate a separate assistant response for each turn. This is supported only with ``backend="vllm"`` as of now. - See :code_link:`skyrl_train/generators/skyrl_gym_generator.py` for more details on both options! diff --git a/skyrl-train/docs/examples/search.rst b/skyrl-train/docs/examples/search.rst index fcb3c21d5b..0e4346880d 100644 --- a/skyrl-train/docs/examples/search.rst +++ b/skyrl-train/docs/examples/search.rst @@ -127,7 +127,7 @@ If you are using ``generator.use_conversation_multi_turn=true``, you might want to append an EOS token ID to the end of the response after these stop strings to adhere to the model's behavior (i.e. ending generation with an EOS token ID rather than say ````). This can be done by setting ``generator.append_eos_token_after_stop_str_in_multi_turn=true`` in the generator config. -The full script is available in `examples/search/run_search_multiturn.sh`. +The full script is available in `examples/search/run_search_conversation_format.sh`. Launching Your Training Run diff --git a/skyrl-train/examples/search/run_search_multiturn.sh b/skyrl-train/examples/search/run_search_conversation_format.sh similarity index 100% rename from skyrl-train/examples/search/run_search_multiturn.sh rename to skyrl-train/examples/search/run_search_conversation_format.sh diff --git a/skyrl-train/examples/text_to_sql/run_skyrl_sql_multiturn.sh b/skyrl-train/examples/text_to_sql/run_skyrl_sql_conversation_format.sh similarity index 100% rename from skyrl-train/examples/text_to_sql/run_skyrl_sql_multiturn.sh rename to skyrl-train/examples/text_to_sql/run_skyrl_sql_conversation_format.sh diff --git a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py index 5d65dc5706..d849dfe18c 100644 --- a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py +++ b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py @@ -177,23 +177,19 @@ async def agent_loop( output_ids = engine_output["response_ids"][0] stop_reason = engine_output["stop_reasons"][0] - # append eos token if needed. Only applicable when sampling_params.stop is not None. - # Note this does not affect 3.a because the chat template adds eos_token to the end. - stop_strs = None - if sampling_params is None: - stop_strs = self.generator_cfg.sampling_params.get("stop", None) - else: - stop_strs = sampling_params.get("stop", None) + # Append eos when sampling_params.stop is not None. Does not affect 3.a as chat templates add eos_token. + # sampling_params is not None for eval, but None for training (which uses engine.sampling_params which are from cfg) + current_sampling_params = ( + sampling_params if sampling_params is not None else self.generator_cfg.sampling_params + ) + stop_strs = current_sampling_params.get("stop", None) if ( stop_strs is not None and self.generator_cfg.append_eos_token_after_stop_str_in_multi_turn - and (retokenize_chat_history or self.use_conversation_multi_turn) + and self.use_conversation_multi_turn ): - for stop_str in stop_strs: - if output.endswith(stop_str) and output_ids[-1] != self.tokenizer.eos_token_id: - # Append EOS token to output to match chat template termination. - output_ids.append(self.tokenizer.eos_token_id) - break + if output.endswith(tuple(stop_strs)) and output_ids[-1] != self.tokenizer.eos_token_id: + output_ids.append(self.tokenizer.eos_token_id) # 2. Environment step if self.env_executor is not None: diff --git a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py index 5a66bf42e9..eead5d5ccb 100644 --- a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py +++ b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py @@ -196,7 +196,8 @@ async def test_append_eos_after_stop_multi_turn(model_name): Test the behavior of `append_eos_token_after_stop_str_in_multi_turn`, which is applicable when `sampling_params.stop` is not `null` and `use_conversation_multi_turn` is `true` in the ``agent_loop()`` function. - It is used in scripts `examples/search/run_search_multiturn.sh` and `examples/text_to_sql/run_skyrl_sql_multiturn.sh`. + It is used in scripts `examples/search/run_search_conversation_format.sh` and + `examples/text_to_sql/run_skyrl_sql_conversation_format.sh`. """ _register_test_env_if_needed() tokenizer = AutoTokenizer.from_pretrained(model_name)