[Generator] Support token-in-token-out rollout #152
CharlieFRuan merged 14 commits into NovaSky-AI:main
Conversation
```python
if self.tokenizer.eos_token_id in self.base_conversation_token_ids:
    last_eos_token_index = (
        len(self.base_conversation_token_ids)
        - 1
        - self.base_conversation_token_ids[::-1].index(self.tokenizer.eos_token_id)
    )
    self.base_conversation_token_ids = self.base_conversation_token_ids[: last_eos_token_index + 1]
```
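For intuition, here is a tiny worked example of the reverse-index trick above, with made-up token ID values (2 stands in for `tokenizer.eos_token_id`):

```python
# Made-up token IDs purely for illustration; 2 plays the role of tokenizer.eos_token_id.
eos_token_id = 2
base_ids = [1, 5, 6, 2, 7, 8, 2, 9, 10]  # ids trailing the last EOS (e.g. a "\n") get dropped

# Reverse the list, find the first EOS there, and map it back to a forward index.
last_eos_index = len(base_ids) - 1 - base_ids[::-1].index(eos_token_id)  # -> 6

assert base_ids[: last_eos_index + 1] == [1, 5, 6, 2, 7, 8, 2]  # now ends exactly at the last EOS
```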
Want to highlight this change; I'm slightly unsure about it.
/gemini summary

/gemini review
Summary of Changes

This pull request implements a comprehensive "token-in-token-out" rollout across the SkyRL framework, ensuring that language model inference engines and the SkyRLGymGenerator consistently operate on token IDs rather than raw text strings, for improved efficiency and correctness in multi-turn interactions. This change standardizes the data flow, particularly for managing conversational context and observations, while addressing specific backend behaviors for vLLM and SGLang.
Code Review
This pull request introduces significant and valuable changes to support token-in-token-out generation, which is crucial for consistency in multi-turn scenarios. The refactoring, particularly in SkyRLGymGenerator, greatly improves the clarity and structure of the code by separating the different conversation management strategies. My review has identified a key performance issue in the remote vLLM engine client where batch requests are sent sequentially, as well as several areas in the SkyRLGymGenerator where docstrings are either missing or inconsistent with the implementation, which could impact future maintainability. The proposed suggestions aim to fix the performance bottleneck and improve the documentation's accuracy.
skyrl-train/skyrl_train/inference_engines/remote_inference_engine.py
```
@@ -14,9 +14,15 @@ class InferenceEngineInput(TypedDict):
...
class InferenceEngineOutput(TypedDict):
```
I have small nits on this paragraph just for brevity / clarity, but will draft something in the next round of updates :)
skyrl-train/skyrl_train/inference_engines/remote_inference_engine.py
```python
# By cutting `\n` out in `base_conversation_token_ids`, `observation_ids` in
# `_get_next_input_ids_with_multiturn_chat_template()` will be `\n<|im_start|>user\nObservation here<|im_end|>\n`.
# Note the `\n` at the final assistant turn will still be missing, but this is fine.
# See tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py for more details.
```
These explanations are important and useful, but I actually think they should live in a doc, with a single comment linking to them. They add quite a lot of bloat to the fast-growing generator file :D I think the SkyRL gym generator has reached sufficient complexity to warrant a doc detailing its behavior. At the least, we could start with a briefer doc on "multi-turn chat tokenization" and discuss the three possible methods you have here. What do you think?
Yes, a doc is definitely needed. I will add it in a follow-up PR.
Got it, thanks! Once added, can you please remove some of the comments here and just link to the doc?
Force-pushed from ff7f9f2 to efd544a.
/gemini review
Code Review
This pull request introduces significant and valuable changes to support token-in-token-out generation throughout SkyRL. The refactoring of SkyRLGymGenerator greatly improves clarity by separating the different context management strategies. The updates to the remote inference engines to support batching are also a major enhancement. My review includes a couple of suggestions to further improve code quality by using standard library features for warnings and removing leftover development artifacts.
tyler-griggs left a comment
Mostly nits. Let's test e2e and merge?
```python
env.close()  # does nothing for now
```

```python
# TODO(Charlie): this makes the prompt_ids include the generation prompt, is this what we want?
```
Yes, I believe it is, primarily because maintaining the generation prompt is "on policy": if we remove those tokens, the input will not reflect what the model actually observed when making its generation.
Before this PR, the generation prompt was in response_ids (with proper masking). After this PR it is in prompt_ids, hence the comment here. I am guessing it is equivalent?
…zation (#186)

This PR adds a documentation page detailing the behavior of `SkyRLGymGenerator`, including the 3 codepaths described in #152, how multi-turn and single-turn rollouts/tokenizations work, and how token-in-token-out is enforced. As a result, we can remove some bulky comments in `skyrl_gym_generator.py`.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…rocessing action in search and txt2sql environments (#190)

Motivated by token-in-token-out in #152, we want to avoid post-processing the action in `Env.step()`, which happens in the string space and causes re-tokenization. Currently, such post-processing happens in our search and txt2sql environments.

To fix it for these two environments, we add `cfg.generator.sampling_params.stop` and `cfg.generator.eval_sampling_params.stop`. This is equivalent to the post-processing done in the search and txt2sql environments, which simply strips everything after tags like `"</search>"`. Furthermore, for these two environments, if `generator.use_conversation_multi_turn=true`, we need to append an EOS token ID after these stop strings to adhere to the chat template, e.g. ending in `</answer><|im_end|>`. This is only required in codepath 1 (out of the 3): https://github.com/NovaSky-AI/SkyRL/pull/186/files. Codepath 3 (always retokenize) does not need it because the chat template always applies EOS to the end of message content. We do this append when `cfg.generator.append_eos_token_after_stop_str_in_multi_turn` is `True`, and add `run_search_multiturn.sh` and `run_skyrl_sql_multiturn.sh` correspondingly.

In addition, this PR:
- Adds `"no_stop_trim": True` to SGLang sampling params, which is equivalent to vLLM's `include_stop_str_in_output`.
- Throws an error when `min_new_tokens` or `stop` is used with the SGLang backend, since SGLang errors out when these sampling parameters are used while the engine is initialized with `skip_tokenizer_init=True`. See sgl-project/sglang#9039 (comment) for more.

Tested by:
- `tests/gpu/test_policy_local_engines_e2e.py::test_policy_local_engines_e2e`
- E2E run of search r1 in #152 with these changes
- `run_gsm8k.sh` with sglang and vllm -- also made sure the unsupported fields in sglang raise an error as expected

Note: if, say, the stop string is `</search>` and vLLM generated a token equivalent to `>\n`, vLLM's string output will be truncated to end at `</search>` while the token ID output will still contain `>\n`. This can be observed in this script: https://gist.github.com/CharlieFRuan/ca3a8fee388263f7e2f96bc89e2ee7f5
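To make the append-EOS-after-stop-string behavior concrete, here is a minimal hedged sketch; the helper name and signature are assumptions, and only the config flag and the `</search>`/`<|im_end|>` example come from the description above:

```python
# Hypothetical helper, not the actual SkyRL implementation: when generation stopped on a
# stop string (e.g. "</search>") in multi-turn codepath 1 and
# cfg.generator.append_eos_token_after_stop_str_in_multi_turn is enabled, append the EOS
# token ID so the turn still ends like `</search><|im_end|>`, as the chat template expects.
from typing import List

def maybe_append_eos(response_ids: List[int], eos_token_id: int,
                     stopped_on_stop_str: bool, append_eos_after_stop_str: bool) -> List[int]:
    if append_eos_after_stop_str and stopped_on_stop_str and (not response_ids or response_ids[-1] != eos_token_id):
        return response_ids + [eos_token_id]
    return response_ids
```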
This PR ensures token-in-token-out rollout in most codepaths in SkyRL. We make various codepaths of inference engines return token IDs, and make the 3 codepaths in SkyRLGymGenerator clearer while enforcing token-in-token-out.


This PR ensures token-in-token-out rollout in most codepaths in SkyRL. The motivation for token-in/token-out is described in issue #123.
Changes
1. Making sure the inference engines are doing token-in-token-out (returning `InferenceEngineOutput.response_ids`)

- `response_ids` were already added for the local vLLM engine in "[FlashRL 1/N] Add support for truncated importance sampling" #145; this PR supports the field for {vllm, sglang} x {remote engine, local engine}.
- With `--skip-tokenizer-init` passed to SGLang engine creation, we can use the `/generate` endpoint, which can take in batch input via HTTP, performing token-in-token-out (almost exactly the same as the python `.generate()`).
- vLLM's `/completions` can take token-in, but it does not do token-out, and vLLM's `/generate` is demo-only according to its doc. So the vLLM server endpoint does not support token-in-token-out yet; for now we fill `response_ids` by re-encoding the text output from vLLM for a fake token-out.
- For remote engines, SGLang uses `/generate` and vLLM uses `/completions` (neither uses `/chat/completions`).
- We keep both `response_ids` and `responses` in `InferenceEngineOutput`, but users writing their own generator should use `response_ids` (either to prepare the next turn's input or `GeneratorOutput`) to respect token-in-token-out. `responses` should only be used to parse for tool calls (like a "read-only" thing). A minimal sketch of this contract is shown below.
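The sketch below illustrates that contract; it is not the actual SkyRL definition (the real `InferenceEngineOutput` has more fields), and only `response_ids` and `responses` are taken from the description above:

```python
from typing import List, TypedDict

class InferenceEngineOutput(TypedDict, total=False):
    response_ids: List[List[int]]  # token IDs generated by the engine -- use these downstream
    responses: List[str]           # decoded text -- treat as read-only (e.g. parsing tool calls)

def next_turn_input_ids(prev_input_ids: List[int], out: InferenceEngineOutput, i: int) -> List[int]:
    # Token-in-token-out: extend the context with exactly the tokens the engine produced,
    # instead of re-tokenizing out["responses"][i], which may not round-trip to the same IDs.
    return prev_input_ids + out["response_ids"][i]
```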
2. Making sure SkyRLGymGenerator respects token-in-token-out

After making sure the InferenceEngineClient is token-in-token-out and returns token IDs to the generator, we need to make sure the generator itself respects token-in-token-out while conforming to (multi-turn) chat templates.
- `generate_batched()` is used for single-turn, and `agent_loop()` for both single-turn and multi-turn.
- `generate_batched()` now supports token-in-token-out by simply passing `InferenceEngineOutput.response_ids` to `GeneratorOutput`.
- `agent_loop()` is more tricky, elaborated below.

`SkyRLGymGenerator` has 3 ways of managing context in `agent_loop()`, configured by `cfg.generator.use_conversation_multi_turn` and whether `get_custom_chat_template()` returns None (currently it means that the model is Qwen3):

1. `use_conversation_multi_turn == True` and `get_custom_chat_template()` is not None
2. `use_conversation_multi_turn == True`, but with no custom chat template
3. `use_conversation_multi_turn == False`

This PR ensures token-in-token-out for codepaths 2 and 3 (but not 1) and makes the 3 codepaths more distinct and clearer in the code (see the dispatch sketch below).
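A rough dispatch sketch of how these codepaths are selected; the two `_get_next_input_ids_*` helpers and `get_custom_chat_template()` are named in this PR, everything else here is illustrative:

```python
# Illustrative only -- not the actual control flow in SkyRLGymGenerator.agent_loop().
def describe_codepath(use_conversation_multi_turn: bool, custom_chat_template) -> str:
    if not use_conversation_multi_turn:
        # Codepath 3: _get_next_input_ids_with_single_turn_chat_template()
        return "single-turn chat template, token-in-token-out"
    if custom_chat_template is not None:
        # Codepath 1: re-tokenize the full chat history each turn (no token-in-token-out)
        return "custom chat template with {% generation %} tags"
    # Codepath 2: _get_next_input_ids_with_multiturn_chat_template()
    return "multi-turn chat template with delta-based observation tokenization"
```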
2.1. Always re-tokenize chat history (does not support token-in-token-out)

- We get `["assistant_masks"]` and `["input_ids"]` from the final tokenized chat history, with the help of `{% generation %}` and `{% endgeneration %}` tags in the jinja template. A hedged sketch of this mechanism is shown below.
2.2 `use_conversation_multi_turn == True` without re-tokenizing chat history each turn

In this codepath, the agent loop does the following:
- Maintain the conversation as a growing list of token IDs, `input_ids`.
- Pass `input_ids` to the LLM engine, get `output_ids` out (thanks to change 1).
- `input_ids += output_ids` (a.k.a. token-in-token-out) -- the next turn's input IDs are precisely what the LLM generated.
- Tokenize the observation (returned by `env.step()`), and append it to `input_ids` -- this is a bit tricky, explained below.
- Repeat until `env.step()` marks done.

To ensure that the observation is correctly tokenized, we follow the delta-based method proposed here:
- We compute `self.base_conversation_token_ids` in `__init__()` and use that to get `observation_ids`.
- For turns like `<|assistant|>something<|im_end|>\n<|user|>something...<|im_end|>`, the chat template adds a `\n` between turns; we account for it by making sure `self.base_conversation_token_ids` ends with the EOS token ID.
- See `tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py` for details. A minimal sketch of the delta extraction follows this list.
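The sketch below shows the delta-based extraction; the helper name and signature are assumptions, not the exact SkyRL code, and it assumes `base_conversation_token_ids` has already been trimmed to end at the EOS token as shown earlier:

```python
from typing import Dict, List

def get_observation_ids(tokenizer, base_conversation: List[Dict[str, str]],
                        base_conversation_token_ids: List[int], observation: str) -> List[int]:
    # Apply the chat template to the base conversation plus the new observation turn...
    full_ids = tokenizer.apply_chat_template(
        base_conversation + [{"role": "user", "content": observation}],
        add_generation_prompt=False,
        tokenize=True,
    )
    # ...and keep only the delta, i.e. the tokens contributed by the new turn, e.g.
    # `\n<|im_start|>user\nObservation here<|im_end|>\n` for a Qwen-style template.
    return full_ids[len(base_conversation_token_ids):]
```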
2.3 `use_conversation_multi_turn == False`

- We ensure token-in-token-out by replacing `self.tokenizer.encode(output)` with `output_ids` from `InferenceEngineClient`.
- See the docstring of `_get_next_input_ids_with_single_turn_chat_template()` for what this codepath does.
Tests

- `tests/gpu/test_skyrl_gym_generator.py` (besides search and text2sql)
- `test_engine_generation.py` works (both remote and local)
- `tests/cpu/generators/test_skyrl_gym_generator.py`: this PR changes this test a lot because we no longer re-tokenize the LLM output (i.e. the mock tokenizer is used much less). We should re-write this CPU test to use only a mock LLM but an actual tokenizer.
- `tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py`
TODOs

TODOs till ready for a review:
- `SkyRLGymGenerator.generate_batched()`
- `SkyRLGymGenerator.agent_loop()`

Future PRs
Tracked by #179
- In `skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py`, use an actual tokenizer. This should make the test cleaner. Can also parametrize the tokenizer and test with different models.
- `tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py`
- Keep maintaining `chat_history` in `agent_loop()` even when token-in-token-out, with the goal of doing a sanity check of equivalence between the maintained `input_ids` and `chat_history`.