Merged
6 changes: 5 additions & 1 deletion skyrl-train/examples/mini_swe_agent/mini_swe_generator.py
@@ -19,6 +19,7 @@
from skyrl_train.inference_engines.utils import get_sampling_params_for_backend
from skyrl_train.generators.utils import (
get_rollout_metrics,
encode_messages_subset,
)


@@ -128,6 +129,9 @@ def __init__(
self.model_name = model_name
self.litellm_model_name = "openai/" + self.model_name

if self.generator_cfg.chat_template.name_or_path is not None:
raise NotImplementedError("MiniSWEAgentGenerator doesn't support custom chat template")

async def minisweagent_agent_loop(
self,
prompt: ConversationType,
@@ -182,7 +186,7 @@ async def minisweagent_agent_loop(

for message in response_messages:
# Apply chat template and tokenize each message
msg_encoding = self.tokenizer.apply_chat_template([message], add_generation_prompt=False, tokenize=True)
msg_encoding = encode_messages_subset([message], self.tokenizer)

# Extend response_ids with the tokens
response_ids.extend(msg_encoding)
Collaborator
Can we add an assertion that

initial_input_ids + response_ids == self.tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=True)

And the same for terminal bench?

Or at least a warning, perhaps.
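
A minimal sketch of the kind of check being suggested (hypothetical; the helper name and call site are made up for illustration, and the discussion below argues such an assertion can be misleading):

import logging

logger = logging.getLogger(__name__)


def warn_if_tokenization_diverges(tokenizer, messages, initial_input_ids, response_ids):
    # Warn when per-message encoding disagrees with encoding the full conversation.
    full_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=True)
    if initial_input_ids + response_ids != full_ids:
        logger.warning(
            "initial_input_ids + response_ids differs from tokenizing the full conversation; "
            "per-message and full-conversation chat template encodings have diverged."
        )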

Member Author
Hmm, such an assertion or warning could be misleading or incorrect, because applying the chat template message by message can differ quite a bit from applying it to the full conversation.

For models like Qwen 3, the thinking tokens for previous messages in the history are discarded by default. Now, if we call encode_messages_subset on each message, we end up preserving the think tokens for each message (even with the base conversation present, this is okay).

But with the RHS, the think tokens for previous messages are removed.

I don't think either is the correct behaviour we want for on-policy training, but in any case we shouldn't have this assertion.
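
To make this concrete, here is a rough, self-contained sketch of the divergence (the model name, the sample conversation, and the assumption that Qwen3's template drops think blocks from earlier assistant turns are mine, not part of this PR):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

messages = [
    {"role": "user", "content": "What is 1+1?"},
    {"role": "assistant", "content": "<think>1+1 is 2.</think>The answer is 2."},
    {"role": "user", "content": "And 2+2?"},
    {"role": "assistant", "content": "<think>2+2 is 4.</think>The answer is 4."},
]

# LHS of the proposed assertion: encode message by message (roughly what
# encode_messages_subset does), which keeps every <think> block.
per_message_ids = []
for message in messages:
    per_message_ids.extend(tokenizer.apply_chat_template([message], add_generation_prompt=False, tokenize=True))

# RHS: encode the full conversation, where the template may strip <think>
# content from non-final assistant turns.
full_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=True)

# The two can legitimately differ, so an equality assertion would be misleading.
print(per_message_ids == full_ids)
print(tokenizer.decode(per_message_ids))
print(tokenizer.decode(full_ids))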

Member Author
As such, for Qwen3 8B I re-ran the mini SWE agent example and it is okay. Actually, the previous expression was also correct, because for Qwen3 8B there is no default system prompt added:

print(self.tokenizer.apply_chat_template([{"role": "assistant", "content": "What is 1+1?"}], tokenize=False))
# '<|im_start|>assistant\nWhat is 1+1?<|im_end|>\n'

Member Author
I believe the tests should be sufficient.

Collaborator
@CharlieFRuan CharlieFRuan Oct 6, 2025
> For models like Qwen 3, the thinking tokens for previous messages in the history are discarded by default. Now, if we call encode_messages_subset on each message, we end up preserving the think tokens for each message (even with the base conversation present, this is okay).

Good point... so the current behavior is: during inference we discard thinking tokens, and for training we keep all thinking tokens.

Made an issue for this: #410

@@ -2,7 +2,7 @@
from dataclasses import dataclass
from typing import List
from skyrl_train.generators.base import GeneratorInterface, GeneratorInput, GeneratorOutput
from skyrl_train.generators.utils import get_rollout_metrics
from skyrl_train.generators.utils import get_rollout_metrics, encode_messages_subset
from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
from skyrl_train.inference_engines.base import ConversationType
from omegaconf import DictConfig
@@ -48,6 +48,9 @@ def __init__(
self.sandboxes_dir = terminal_bench_cfg.sandboxes_dir
self.max_episodes = terminal_bench_cfg.max_episodes

if self.generator_cfg.chat_template.name_or_path is not None:
raise NotImplementedError("TerminalBenchGenerator doesn't support custom chat template")

async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
# TODO(tgriggs): Plumb the sandboxes task list here instead of using (and ignoring) empty prompts
prompts = input_batch["prompts"]
@@ -134,7 +137,7 @@ async def terminal_bench_agent_loop(

for message in response_messages:
# Apply chat template and tokenize each message
msg_encoding = self.tokenizer.apply_chat_template([message], add_generation_prompt=False, tokenize=True)
msg_encoding = encode_messages_subset([message], self.tokenizer)

# Extend response_ids with the tokens
response_ids.extend(msg_encoding)
51 changes: 51 additions & 0 deletions skyrl-train/skyrl_train/generators/utils.py
@@ -3,6 +3,7 @@
from collections import defaultdict
import numpy as np
from skyrl_train.generators.base import GeneratorOutput, GeneratorInput, TrajectoryID, BatchMetadata, TrainingPhase
from skyrl_train.inference_engines.base import ConversationType
from omegaconf import DictConfig

CUSTOM_CHAT_TEMPLATES = {
@@ -267,3 +268,53 @@ def prepare_generator_input(
}

return generator_input, uids


def encode_messages_subset(messages: ConversationType, tokenizer):
"""Encodes a subset of messages from a multi-turn conversation using the fixed base approach.

This function tokenizes messages as if they are part of a larger conversation, ensuring
no additional default system messages are prepended by the tokenizer's chat template.

The "fixed base approach" works by:
- Creating a dummy base conversation to establish context
- Appending the target messages to this base
- Tokenizing the full conversation and extracting only the tokens for the target messages

For simple chat templates without complex token-splitting behavior, this produces the same
result as directly tokenizing the messages. For templates like Qwen's ChatML format, where
a default system prompt can be prepended when none is present, this ensures correct tokenization.

Reference: https://jybsuper.github.io/posts/multiturn_tokenization/#the-breakthrough-fixed-base-approach

Args:
messages: List of message dicts with 'role' and 'content' keys. Must contain at least
one message. These are assumed to be a subset from a larger conversation.
tokenizer: HuggingFace tokenizer with chat_template support and eos_token_id defined.

Returns:
List[int]: Token IDs for the given messages, with proper multi-turn context handling.
"""
assert len(messages), "messages list cannot be empty"
# Follows https://jybsuper.github.io/posts/multiturn_tokenization/#the-breakthrough-fixed-base-approach
base_conversation = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "I am a user."},
]
if messages[0]["role"] != "assistant":
# add an assistant message as well if the first role is user/tool
base_conversation.append({"role": "assistant", "content": "I am an assistant."})
base_conversation_token_ids = tokenizer.apply_chat_template(
base_conversation,
add_generation_prompt=False,
tokenize=True,
)

full_conversation = base_conversation + messages
full_conversation_token_ids = tokenizer.apply_chat_template(
full_conversation,
add_generation_prompt=False,
tokenize=True,
)
conversation_token_ids = full_conversation_token_ids[len(base_conversation_token_ids) :]
return conversation_token_ids
93 changes: 92 additions & 1 deletion skyrl-train/tests/cpu/generators/test_utils.py
@@ -3,7 +3,8 @@
"""

import pytest
from skyrl_train.generators.utils import apply_overlong_filtering
from skyrl_train.generators.utils import apply_overlong_filtering, encode_messages_subset
from transformers import AutoTokenizer


@pytest.mark.parametrize(
@@ -116,3 +117,93 @@ def test_apply_overlong_filtering_length_mismatch_assertion(loss_masks, response
eos_token_id = 4
with pytest.raises(AssertionError, match="loss_masks and response_ids must have the same length"):
apply_overlong_filtering(loss_masks, response_ids, eos_token_id)


dummy_chat_template = (
"{%- for message in messages %}"
"{%- if message['role'] == 'user' %}"
"<USER>{{ message['content'] }}</s>\n"
"{%- elif message['role'] == 'assistant' %}"
"<ASSISTANT>{{ message['content'] }}</s>\n"
"{%- elif message['role'] == 'system' %}"
"<SYSTEM>{{ message['content'] }}</s>\n"
"{%- endif %}"
"{%- endfor %}"
"{%- if add_generation_prompt %}"
"<ASSISTANT>"
"{%- endif %}"
)


@pytest.fixture
def tokenizer_w_dummy_template():
tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b")
tokenizer.chat_template = dummy_chat_template
return tokenizer


@pytest.mark.parametrize(
"messages",
[
# Test case 1: Single assistant message
[{"role": "assistant", "content": "Hello, I can help you."}],
# Test case 2: Single user message
[{"role": "user", "content": "What is the weather today?"}],
# Test case 3: Multiple messages (user-assistant exchange)
[{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "The answer is 4."}],
# Test case 4: Multiple messages starting with assistant
[
{"role": "assistant", "content": "I'm here to help."},
{"role": "user", "content": "Can you explain Python?"},
{"role": "assistant", "content": "Python is a programming language."},
],
],
)
def test_encode_messages(messages, tokenizer_w_dummy_template):
# For a simple chat template, the fixed base approach is expected to behave the same
# as `apply_chat_template`
expected_token_ids = tokenizer_w_dummy_template.apply_chat_template(messages)
actual_token_ids = encode_messages_subset(messages, tokenizer_w_dummy_template)
assert expected_token_ids == actual_token_ids


@pytest.fixture
def qwen_tokenizer():
return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")


@pytest.mark.parametrize(
"messages, expected_str",
[
# Test case 1: Single assistant message
(
[{"role": "assistant", "content": "Hello, I can help you."}],
"<|im_start|>assistant\nHello, I can help you.<|im_end|>\n",
),
# Test case 2: Single user message - additional \n because the expectation is that there is a previous assistant turn
(
[{"role": "user", "content": "What is the weather today?"}],
"<|im_start|>user\nWhat is the weather today?<|im_end|>\n",
),
# Test case 3: Multiple messages (user-assistant exchange)
(
[{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "The answer is 4."}],
# NOTE: Additional \n because the expectation is that there is a previous assistant turn.
# All tokens after EOS in the previous turn get pushed into the next user/tool message.
"<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n",
),
# Test case 4: Multiple messages starting with assistant
(
[
{"role": "assistant", "content": "I'm here to help."},
{"role": "user", "content": "Can you explain Python?"},
{"role": "assistant", "content": "Python is a programming language."},
],
"<|im_start|>assistant\nI'm here to help.<|im_end|>\n<|im_start|>user\nCan you explain Python?<|im_end|>\n<|im_start|>assistant\nPython is a programming language.<|im_end|>\n",
),
],
)
def test_encode_messages_qwen(messages, expected_str, qwen_tokenizer):
expected_token_ids = qwen_tokenizer.encode(expected_str, add_special_tokens=False)
actual_token_ids = encode_messages_subset(messages, qwen_tokenizer)
assert expected_token_ids == actual_token_ids, f"Got actual tokens: {qwen_tokenizer.decode(actual_token_ids)}"