
[Generator] Support token-in-token-out rollout#152

Merged
CharlieFRuan merged 14 commits into NovaSky-AI:main from CharlieFRuan:pr-0815-token-in-out on Aug 22, 2025

Conversation

CharlieFRuan (Collaborator) commented Aug 16, 2025

This PR ensures token-in-token-out rollout in most codepaths in SkyRL. The motivation for token-in/token-out is described in issue #123.

Changes

1. Making sure the inference engines are doing token-in-token-out (returning InferenceEngineOutput.response_ids)

  • While response_ids was already added for the local vLLM engine in [FlashRL 1/N] Add support for truncated importance sampling #145, this PR supports the field for {vllm, sglang} x {remote engine, local engine}.
    • SGLang remote engine:
      • With --skip-tokenizer-init passed at SGLang engine creation, we can use the /generate endpoint, which accepts batched input over HTTP and performs token-in-token-out (almost exactly like the Python .generate()).
    • SGLang local engine: same approach.
    • vLLM remote engine:
      • While /completions can take token-in, it does not do token-out, and vLLM's /generate endpoint is demo-only according to its docs. So the vLLM server does not support token-in-token-out yet.
      • Instead, we do token-in and prepare response_ids by re-encoding vLLM's text output as a fake token-out.
      • This will be fixed once the following PR lands in vLLM: Add return_token_ids parameter to OpenAI API endpoints vllm-project/vllm#22587.
    • In short, SGLang uses /generate and vLLM uses /completions (neither uses /chat/completions).
  • We always return both response_ids and responses in InferenceEngineOutput. Users writing their own generator should use response_ids (to prepare the next turn's input or the GeneratorOutput) so token-in-token-out is respected; responses should only be used to parse for tool calls (treat it as read-only).
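
For reference, a minimal sketch of InferenceEngineOutput after this change; the field types are inferred from this description and should be treated as assumptions, not copied from the source:

from typing import List, TypedDict

class InferenceEngineOutput(TypedDict):
    # Text responses; treat as read-only, e.g. for parsing tool calls.
    responses: List[str]
    # Token IDs of each response; use these to build the next turn's input
    # or the GeneratorOutput so the rollout stays token-in-token-out.
    response_ids: List[List[int]]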

2. Making sure SkyRLGymGenerator respects token-in-token-out

Once the InferenceEngineClient is token-in-token-out and returns token IDs to the generator, we need to make sure the generator itself respects token-in-token-out while conforming to (multi-turn) chat templates.

  • There are 2 codepaths in SkyRLGymGenerator: generate_batched() for single-turn, and agent_loop() for both single-turn and multi-turn
    • generate_batched() now supports token-in-token-out by simply passing InferenceEngineOutput.response_ids to GeneratorOutput
    • agent_loop() is trickier, as elaborated below

SkyRLGymGenerator has 3 ways of managing context in agent_loop(), selected by cfg.generator.use_conversation_multi_turn and by whether get_custom_chat_template() returns None (currently, a non-None custom template means the model is Qwen3).

  1. Always re-tokenize chat history, when use_conversation_multi_turn==True and get_custom_chat_template() is not None
  2. use_conversation_multi_turn == True, but has no custom chat template
  3. use_conversation_multi_turn == False

This PR ensures token-in-token-out for codepaths 2 and 3 (but not 1) and makes the 3 codepaths more distinct and clearer in the code.
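
As a rough illustration (not the actual implementation), the choice between the three codepaths can be summarized as follows; the helper method names are the ones introduced in this PR, everything else is a simplified stand-in:

def select_codepath(use_conversation_multi_turn: bool, custom_chat_template) -> str:
    """Mirrors how agent_loop() picks a context-management strategy (illustrative only)."""
    if not use_conversation_multi_turn:
        # Codepath 3: single-turn chat template -> _get_next_input_ids_with_single_turn_chat_template()
        return "single_turn_chat_template"
    if custom_chat_template is not None:
        # Codepath 1: always re-tokenize the chat history -> _get_next_input_ids_by_retokenizing_chat_history()
        return "retokenize_chat_history"
    # Codepath 2: multi-turn, delta-based observation tokenization
    # -> _get_next_input_ids_with_multiturn_chat_template()
    return "multiturn_chat_template"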

2.1. Always re-tokenize chat history (does not support token-in-token-out)

  • This codepath currently serves Qwen3 models, where we always re-tokenize the chat history for each turn
  • There are two reasons for this:
    • Qwen3 removes non-last-turn thinking tokens
    • We can get ["assistant_masks"] and ["input_ids"] from the final tokenized chat history with the help of the {% generation %} and {% endgeneration %} tags in the Jinja template
  • That is, this codepath rolls out Qwen3 by following the inference chat template and returns only the last turn's thinking tokens to the Generator for the training pipeline
  • Comment:
    • It is debatable whether this is the best method to train Qwen3 -- will revisit this codepath in a future PR
    • TODO: Users should be able to train Qwen3 following 2.2 by default, and follow 2.1 by setting a config flag
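
To illustrate the mechanism, here is a minimal sketch of recovering the assistant mask by re-tokenizing the chat history. The toy template below is only a stand-in that contains {% generation %} tags, not the actual template used by SkyRL, and the model name is just an example:

from transformers import AutoTokenizer

# Toy chat template: wrap assistant content in {% generation %} tags so that
# apply_chat_template(..., return_assistant_tokens_mask=True) can mark those tokens.
toy_template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n"
    "{% if message['role'] == 'assistant' %}"
    "{% generation %}{{ message['content'] }}{% endgeneration %}"
    "{% else %}{{ message['content'] }}{% endif %}"
    "<|im_end|>\n"
    "{% endfor %}"
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # example model only
chat_history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
encoded = tokenizer.apply_chat_template(
    chat_history,
    chat_template=toy_template,
    return_dict=True,
    return_assistant_tokens_mask=True,
)
input_ids = encoded["input_ids"]              # full tokenized chat history
assistant_masks = encoded["assistant_masks"]  # 1 only for tokens inside {% generation %} blocks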

2.2 use_conversation_multi_turn == True without re-tokenizing chat history each turn

  • In this codepath, the agent loop does the following

    1. Tokenize dataset's prompt to initialize input_ids
    2. Feed input_ids to LLM engine, get output_ids out (thanks to change 1)
    3. input_ids += output_ids (a.k.a. token-in-token-out) -- the next turn's input IDs are precisely what the LLM generated
    4. Tokenize the observations returned by SkyRL-Gym's environment (i.e. env.step()) and append them to input_ids -- this is a bit tricky, explained below
    5. Repeat 2-4 until env.step() marks done
  • To ensure that the observation is tokenized correctly, we follow the delta-based method proposed here

    • We instantiate self.base_conversation_token_ids in __init__() and use it to compute observation_ids
    • One tricky detail is how Qwen models format turns: <|assistant|>something<|im_end|>\n<|user|>something...
    • That is, after the assistant generates the EOS token <|im_end|>, the chat template adds a \n
    • We account for this \n by making sure self.base_conversation_token_ids ends with the EOS token ID, so the \n is picked up as part of the observation delta
    • For more, see changes in tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py
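
A minimal sketch of the delta-based approach, with assumed placeholder content (the real logic lives in __init__() and _get_next_input_ids_with_multiturn_chat_template()):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # example model only

# In __init__(): tokenize a fixed base conversation once, then trim it to end at the
# last EOS token. Anything the template appends after EOS (e.g. Qwen's trailing "\n")
# will then be captured by the observation delta below.
base_conversation = [{"role": "user", "content": "placeholder"}]  # assumed placeholder content
base_ids = tokenizer.apply_chat_template(base_conversation, tokenize=True)
if tokenizer.eos_token_id in base_ids:
    last_eos = len(base_ids) - 1 - base_ids[::-1].index(tokenizer.eos_token_id)
    base_ids = base_ids[: last_eos + 1]

def next_input_ids(input_ids, output_ids, observation):
    # Step 3: token-in-token-out -- the next turn's input is exactly what the LLM generated.
    input_ids = input_ids + output_ids
    # Step 4: tokenize only the observation delta beyond the base-conversation prefix.
    full_ids = tokenizer.apply_chat_template(
        base_conversation + [{"role": "user", "content": observation}],
        add_generation_prompt=True,
        tokenize=True,
    )
    observation_ids = full_ids[len(base_ids):]
    return input_ids + observation_ids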

2.3 use_conversation_multi_turn == False

  • This PR makes this codepath support token-in-token-out with little change (replacing self.tokenizer.encode(output) with the output_ids from InferenceEngineClient).
  • See _get_next_input_ids_with_single_turn_chat_template() for what this codepath does
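
In hedged pseudocode, the essence of the change here (the variable names are illustrative only):

# Before this PR (string-based): re-tokenize the engine's text output each turn.
#   output_ids = self.tokenizer.encode(output)
# After this PR (token-in-token-out): reuse the IDs the engine already returned.
#   output_ids = inference_engine_output["response_ids"][i]
# input_ids += output_ids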

Tests

  • GPU tests
    • tests/gpu/test_skyrl_gym_generator.py (besides search and text2sql)
    • test_engine_generation.py works (both remote and local)
  • CPU tests all passed
    • tests/cpu/generators/test_skyrl_gym_generator.py: this PR changes this test a lot because we no longer re-tokenize the LLM output (i.e. the mock tokenizer is used much less). We should rewrite this CPU test to use only a mock LLM but an actual tokenizer
    • tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py
    • We should consider testing more than Qwen and Llama3 tokenizers

TODOs

TODOs before ready for review

  • Support token-in-token-out for SkyRLGymGenerator.generate_batched()
  • Support token-in-token-out for SkyRLGymGenerator.agent_loop()
  • Make remote engines return tokens
  • Thoroughly test each codepath

Future PRs
Tracked by #179

  • CPU testing
    • Do not mock tokenizer in skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py, use an actual tokenizer. This should make the test cleaner. Can also parametrize tokenizer and test with different models.
    • Consider adding tokenizers beyond Qwen and Llama
    • Add single-turn code path to tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py
  • Deprecate post-processed action for SkyRL Gym to ensure token-in-token-out
  • Re-consider Qwen3 training, should it follow the "always re-tokenize chat history" codepath? Maybe by default make it follow 2.2 and only follow 2.1 with a config flag
  • Add a documentation page detailing the behavior of SkyRLGymGenerator, where the content is mainly this PR description and the current chunks of comments. With the doc, we can remove those chunks of comments in the file and keep it more lightweight
  • Consider maintaining string-based chat_history in agent_loop() even when token-in-token-out, with the goal of doing a sanity check of equivalence between the maintained input_ids and chat_history

Comment on lines +81 to +86
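# Trim base_conversation_token_ids so it ends at the last EOS token; anything the
# chat template adds after EOS (e.g. Qwen's trailing "\n") then falls into the observation delta.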
if self.tokenizer.eos_token_id in self.base_conversation_token_ids:
last_eos_token_index = (
len(self.base_conversation_token_ids)
- 1
- self.base_conversation_token_ids[::-1].index(self.tokenizer.eos_token_id)
)
self.base_conversation_token_ids = self.base_conversation_token_ids[: last_eos_token_index + 1]
CharlieFRuan (Collaborator, Author):

Want to highlight this change, slightly unsure about this

CharlieFRuan (Collaborator, Author):

/gemini summary

CharlieFRuan (Collaborator, Author):

/gemini review

gemini-code-assist[bot] (Contributor):

Summary of Changes

This pull request implements a comprehensive "token-in-token-out" rollout across the SkyRL framework, ensuring that language model inference engines and the SkyRLGymGenerator consistently operate using token IDs rather than raw text strings for improved efficiency and correctness in multi-turn interactions. This change standardizes the data flow, particularly for managing conversational context and observations, while addressing specific backend behaviors for vLLM and SGLang.

Highlights

  • Standardized Token-in-Token-out: Inference engines now consistently return response_ids (token IDs) alongside text responses, ensuring a token-based data flow for subsequent operations.
  • Backend-Specific Token Handling: SGLang remote and local engines are configured to use their /generate endpoint for native token-in-token-out. vLLM remote engine, due to current API limitations, re-encodes its text output into token IDs to simulate token-out, with a future fix noted.
  • Enhanced SkyRLGymGenerator: The generator has been updated to fully leverage token IDs for managing conversational context, especially within multi-turn agent_loop() scenarios.
  • Refactored Context Management: The agent_loop() in SkyRLGymGenerator now clearly distinguishes and handles three distinct codepaths for managing chat history and input IDs, including a delta-based approach for observations in multi-turn conversations.
  • Improved Test Coverage: CPU and GPU tests have been updated to reflect and validate the new token-in-token-out behavior, including adjustments to mocks and assertions.
Changelog
  • skyrl-train/skyrl_train/entrypoints/main_base.py
    • The create_remote_inference_engines_from_config function now requires a tokenizer argument to facilitate tokenization within remote inference engines.
  • skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
    • The init method now initializes self.base_conversation_token_ids to support delta-based tokenization for observations in multi-turn conversations.
    • The agent_loop method has been significantly refactored to consume response_ids directly from inference engines and to manage conversational context through three distinct, token-in-token-out compliant codepaths.
    • The generate_batched method has been updated to directly use response_ids from the inference engine output.
    • New helper methods _get_next_input_ids_by_retokenizing_chat_history, _get_next_input_ids_with_multiturn_chat_template, and _get_next_input_ids_with_single_turn_chat_template were introduced to clarify context management logic.
  • skyrl-train/skyrl_train/inference_engines/base.py
    • The InferenceEngineOutput TypedDict now explicitly requires response_ids to always be present, standardizing the output format.
  • skyrl-train/skyrl_train/inference_engines/inference_engine_client.py
    • The _generate_with_trajectory_routing and _generate_batched methods were simplified, removing conditional checks for response_ids as they are now guaranteed.
  • skyrl-train/skyrl_train/inference_engines/remote_inference_engine.py
    • The RemoteInferenceEngine constructor now accepts a tokenizer to enable internal tokenization.
    • The generate method now handles prompt tokenization if text prompts are provided and implements backend-specific logic for retrieving response_ids from vLLM (via re-encoding text output) and SGLang (direct token IDs).
    • The create_remote_inference_engines function now passes the tokenizer to the engine instances.
  • skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py
    • The SGLang engine is now initialized with skip_tokenizer_init=True to enforce token-in-token-out behavior, and its _postprocess_outputs method explicitly returns response_ids.
  • skyrl-train/skyrl_train/inference_engines/sglang/sglang_server.py
    • The SGLang server startup script now automatically appends --skip-tokenizer-init to ensure the server operates in token-in-token-out mode.
  • skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py
    • Mock objects for tokenizer and LLM were updated to return and expect token IDs, and test assertions were adjusted to validate the new token-based outputs and loss mask lengths.
  • skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py
    • The mock LLM now returns response_ids, and assertions were refined to account for specific chat template behaviors (e.g., Qwen's newline character after EOS token) when comparing tokenized and string outputs.
  • skyrl-train/tests/gpu/test_engine_generation.py
    • The init_remote_inference_servers function now accepts and passes a tokenizer, sampling parameters were updated to include logprobs: None, and the vLLM distributed executor backend was changed from ray to mp.
  • skyrl-train/tests/gpu/test_skyrl_gym_generator.py
    • Sampling parameters were updated to include logprobs: None, and test parametrization was adjusted to focus on single-turn generate_batched and agent_loop scenarios.
Activity
  • CharlieFRuan requested a summary of the pull request.
  • CharlieFRuan requested a review of the pull request.
  • gemini-code-assist[bot] provided a review comment suggesting the use of the logging module instead of print for warnings.
  • CharlieFRuan added a review comment highlighting a specific change and expressing slight uncertainty about it.

gemini-code-assist[bot] (Contributor) left a comment:

Code Review

This pull request introduces significant and valuable changes to support token-in-token-out generation, which is crucial for consistency in multi-turn scenarios. The refactoring, particularly in SkyRLGymGenerator, greatly improves the clarity and structure of the code by separating the different conversation management strategies. My review has identified a key performance issue in the remote vLLM engine client where batch requests are sent sequentially, as well as several areas in the SkyRLGymGenerator where docstrings are either missing or inconsistent with the implementation, which could impact future maintainability. The proposed suggestions aim to fix the performance bottleneck and improve the documentation's accuracy.

tyler-griggs (Member) left a comment:

Looking pretty good!

@@ -14,9 +14,15 @@ class InferenceEngineInput(TypedDict):


class InferenceEngineOutput(TypedDict):
tyler-griggs (Member):

I have small nits on this paragraph just for brevity / clarity, but will draft something in next round of updates :)

# By cutting `\n` out in `base_conversation_token_ids`, `observation_ids` in
# `_get_next_input_ids_with_multiturn_chat_template()` will be `\n<|im_start|>user\nObservation here<|im_end|>\n`.
# Note the `\n` at the final assistant turn will still be missing, but this is fine.
# See tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py for more details.
tyler-griggs (Member):

These explanations are important and useful, but I actually think they should be in a doc with a single comment linking to them. They do add quite a lot of bloat to the fast-growing generator file :D I think the skyrl gym generator has reached sufficient complexity to add a doc detailing its behavior. At the least, we could start with a briefer doc on "multi-turn chat tokenization" and discuss the three possible methods you have here. What do you think?

CharlieFRuan (Collaborator, Author):

Yes, a doc is definitely needed. I will add it in a follow-up PR

tyler-griggs (Member):

Got it, thanks! Once added, can you please remove some of the comments here and just link to the doc?

CharlieFRuan (Collaborator, Author):

yep will do!

CharlieFRuan marked this pull request as ready for review on August 21, 2025.

tyler-griggs (Member):

/gemini review

gemini-code-assist[bot] (Contributor) left a comment:

Code Review

This pull request introduces significant and valuable changes to support token-in-token-out generation throughout SkyRL. The refactoring of SkyRLGymGenerator greatly improves clarity by separating the different context management strategies. The updates to the remote inference engines to support batching are also a major enhancement. My review includes a couple of suggestions to further improve code quality by using standard library features for warnings and removing leftover development artifacts.

tyler-griggs (Member) left a comment:

Mostly nits. Let's test e2e and merge?


env.close() # does nothing for now

# TODO(Charlie): this makes the prompt_ids include the generation prompt, is this what we want?
tyler-griggs (Member):

Yes I believe it is. Primarily because maintaining the generation prompt is "on policy". If we remove those tokens, it will not actually reflect what the model observed when making its generation.

CharlieFRuan (Collaborator, Author):

Before this PR, the generation prompt was in response_ids (with proper masking). After this PR it is in prompt_ids, hence the comment here. I am guessing it is equivalent?

CharlieFRuan commented Aug 22, 2025

(Will in-place edit this comment / PR description to add full results and details when runs finish.)

E2E run with search-r1 with 4 turns.

use_conversation_multi_turn == False (codepath 2.3):
image

use_conversation_multi_turn == True (codepath 2.2):
image

tyler-griggs (Member) left a comment:

amazing

CharlieFRuan merged commit 7246726 into NovaSky-AI:main on Aug 22, 2025
3 checks passed
CharlieFRuan deleted the pr-0815-token-in-out branch August 22, 2025 23:01
tyler-griggs pushed a commit that referenced this pull request Aug 27, 2025
…zation (#186)

This PR adds a documentation page detailing the behavior of
`SkyRLGymGenerator`, including the 3 codepaths described in
#152, how multi-turn and
single-turn rollouts/tokenizations work, and how token-in-token-out is
enforced.

As a result, we can remove some bulky comments in
`skyrl_gym_generator.py`

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
SumanthRH pushed a commit that referenced this pull request Aug 28, 2025
…rocessing action in search and txt2sql environments (#190)

Motivated by token-in-token-out in
#152, we want to avoid
post-processing the action in `Env.step()`, which happens in the string
space and causes re-tokenization.

Currently, such post-processing happens in our search and txt2sql
environments. To fix it for these two environments, we add
`cfg.generator.sampling_params.stop` and
`cfg.generator.eval_sampling_params.stop`. This will be equivalent to
the post-processing done in the search and txt2sql environment, which
simply strips after tags like `"</search>"`.

Furthermore, for these two environments, if
`generator.use_conversation_multi_turn=true`, we need to append an EOS
token ID after these stop strings to adhere to the chat template, e.g.
ending in `</answer><|im_end|>`. This is only required in codepath 1
(out of the 3): https://github.com/NovaSky-AI/SkyRL/pull/186/files.
Codepath 3 (always retokenize) is not needed because the chat template
always applies EOS to the end of message content. We do this append when
`cfg.generator.append_eos_token_after_stop_str_in_multi_turn` is `True`
and add `run_search_multiturn.sh` and `run_skyrl_sql_multiturn.sh`
correspondingly.

In addition, this PR:
- Add `"no_stop_trim": True,` to SGLang sampling params, which is
equivalent to vLLM's `include_stop_str_in_output`.
- Throw error when `min_new_tokens` or `stop` is used with SGLang
backend, since SGLang will throw an error when these sampling parameters
are used when the engine is initialized with `skip_tokenizer_init=True`.
- See this issue for more:
sgl-project/sglang#9039 (comment)

Tested by:
-
`tests/gpu/test_policy_local_engines_e2e.py::test_policy_local_engines_e2e`
- E2E run of search r1 in #152
is run with these changes
- `run_gsm8k.sh` with sglang and vllm -- also made sure those
unsupported fields in sglang will raise error as expected

Note:
It is worth noting that, if say the stop string is `</search>`, and vLLM
generated a token equivalent to `>\n`, the string output of vLLM will
truncate to only have `</search>` while the token ID output will still
be `>\n`. This can be observed in this script:
https://gist.github.com/CharlieFRuan/ca3a8fee388263f7e2f96bc89e2ee7f5
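
As a rough sketch of the knobs described above (illustrative values; the key names come from this commit message, the stop strings depend on the environment):

# Hypothetical overrides for the search environment
sampling_params = {
    "stop": ["</search>", "</answer>"],  # truncate generation at these tags instead of post-processing in Env.step()
    "no_stop_trim": True,                # SGLang: keep the stop string in the output
                                         # (vLLM equivalent: include_stop_str_in_output)
}
append_eos_token_after_stop_str_in_multi_turn = True  # cfg.generator.* flag added in #190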