Merged

65 commits
3a4b930
Added config for custom chat template (will make easier in subsequent…
devpatelio Aug 21, 2025
930cb19
Added test script for masking and demasking, seems to work well!
devpatelio Aug 21, 2025
700d7c0
Update run_gsm8k_thinking.sh
devpatelio Aug 21, 2025
64d6f7e
updated new
devpatelio Aug 21, 2025
c27930b
Added support for custom masking training with templates for batching…
devpatelio Aug 21, 2025
20e63b7
Fixed bug to append custom response for multi-turn
devpatelio Aug 22, 2025
66c23c9
stash changes
devpatelio Aug 23, 2025
7a22332
Updated config to support names in custom_chat_template or pass in ji…
devpatelio Aug 23, 2025
4233da9
fixed rebase
devpatelio Aug 23, 2025
14d8b92
nit: repetition
devpatelio Aug 23, 2025
b8cea88
done
devpatelio Aug 23, 2025
6a38f6d
fixed refactor issues with some misnamed variables from the previous …
devpatelio Aug 24, 2025
f7efc37
additional rebase fixes
devpatelio Aug 24, 2025
a077e8d
removed print statements and debug scripts
devpatelio Aug 24, 2025
7d41b60
deleted extra jinja template file
devpatelio Aug 24, 2025
5ab4dde
fixed bug where prompts were being processed by tokenizer as 1 elemen…
devpatelio Aug 24, 2025
58082a4
list comprehension
devpatelio Aug 24, 2025
fb7b2ea
PR comments applied
devpatelio Aug 24, 2025
dcffdaf
Remove assignment of enable_thinking_tokens variable
devpatelio Aug 26, 2025
9663547
removed uv.lock file
tyler-griggs Aug 26, 2025
b3398b9
removed redundant model_name
tyler-griggs Aug 27, 2025
0f80225
Update skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
devpatelio Aug 27, 2025
f2711a4
Update skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat…
devpatelio Aug 27, 2025
b13a647
done
tyler-griggs Aug 27, 2025
d2bbe58
fix by using without thinking so no parsing problem, like before
tyler-griggs Aug 27, 2025
c1d0679
fixed the test defaults to check without tokenization
tyler-griggs Aug 27, 2025
cee373c
allow model_name to pass to SkyRLGymGenerator, do not use in get_cust…
tyler-griggs Aug 27, 2025
99aeb3c
eos token for merge conflict
tyler-griggs Aug 29, 2025
d2ee9cd
Merge branch 'main' into devpatel/skyrl-issue-104
devpatelio Aug 29, 2025
85f946f
looks like a new merge conflict removed default logprobs truncation f…
tyler-griggs Aug 29, 2025
c0baade
removed redundant appends
tyler-griggs Aug 29, 2025
7dc2790
cleanup
tyler-griggs Aug 29, 2025
169ec5b
removed extra file + cleanup
tyler-griggs Aug 29, 2025
8a001b6
Update skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
devpatelio Aug 29, 2025
1922cb2
fixed issue of overriding log-probs
tyler-griggs Aug 29, 2025
f6e1be4
Update skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
devpatelio Aug 29, 2025
81cee5e
added
tyler-griggs Aug 29, 2025
6e75e09
reverted back to working base, no redundant code
devpatelio Sep 5, 2025
29d1446
condensed the if else logic
devpatelio Sep 5, 2025
12dbd9c
Merge branch 'main' into devpatel/skyrl-issue-104
devpatelio Sep 5, 2025
9aed08b
keep the last message thinking in a multi-turn conversation
devpatelio Sep 16, 2025
96d3ade
linter
devpatelio Sep 16, 2025
f7e5d28
fixed generator texts to support new format
devpatelio Sep 16, 2025
9e83a00
fixed comment
devpatelio Sep 16, 2025
2a06dbb
checked test, default qwen model has the newline in the unit test so …
devpatelio Sep 16, 2025
61e398e
added unit test for qwen3_without_thinking and default chat template …
devpatelio Sep 17, 2025
4f1eef9
fixed tests to support the config updates
devpatelio Sep 17, 2025
4bf821d
additional fix for get
devpatelio Sep 17, 2025
35459ca
removed unnecessary dictionary return style by not using custom_chat_…
devpatelio Sep 17, 2025
1ecc27e
Update skyrl-train/skyrl_train/generators/utils.py
devpatelio Sep 17, 2025
bccb846
addressed new review of comments
devpatelio Sep 17, 2025
64c9031
run_gsm8k update to defaults
devpatelio Sep 17, 2025
c9e92f5
added mock tokenizer test for custom chat template with jinja2 file
devpatelio Sep 17, 2025
63b4a23
applied nits and removed test redundancy for jinja2
devpatelio Sep 23, 2025
a57d7de
resolved ppo merge conflict
devpatelio Sep 23, 2025
162dcb5
resolved ppo merge conflict
devpatelio Sep 23, 2025
3fad0cb
Merge branch 'main' into devpatel/skyrl-issue-104
devpatelio Sep 23, 2025
8044b1d
linter
devpatelio Sep 23, 2025
5d873b3
FINAL CHAT UPDATES
devpatelio Oct 4, 2025
d3c735d
merge conflict
devpatelio Oct 4, 2025
ff09fd5
Merge branch 'main' into devpatel/skyrl-issue-104
devpatelio Oct 4, 2025
ecbf3e9
inter
devpatelio Oct 4, 2025
6f40559
last nigts
devpatelio Oct 4, 2025
bdea263
great
devpatelio Oct 4, 2025
e7f6be7
Update skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat…
devpatelio Oct 4, 2025
4 changes: 4 additions & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -181,6 +181,10 @@ generator:
  http_endpoint_port: 8000
  max_turns: 1

  # chat template configuration
  chat_template:
    source: "name"  # "name" or "file"
    name_or_path: null  # e.g., "qwen3_with_thinking" or "/path/to/template.j2"
  # Inference engine arguments. Arguments are passed directly to the vLLM or SGLang engine, so names must match
  # the engine's args. To specify an engine arg in the CLI override, use the format: +generator.engine_init_kwargs.arg_name=value
  engine_init_kwargs: {}
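For reference, a minimal sketch of how the two `source` modes resolve, using the `get_custom_chat_template` helper added in this PR's `skyrl_train/generators/utils.py` (the `.j2` path below is a placeholder, not a file in this repo):

```python
from skyrl_train.generators.utils import get_custom_chat_template

# "name": look up one of the built-in CUSTOM_CHAT_TEMPLATES by key.
tpl = get_custom_chat_template({"source": "name", "name_or_path": "qwen3_with_thinking"})

# "file": read a Jinja2 template from disk (placeholder path).
tpl = get_custom_chat_template({"source": "file", "name_or_path": "/path/to/template.j2"})

# name_or_path left null/None: fall back to the tokenizer's default template.
assert get_custom_chat_template({"source": "name", "name_or_path": None}) is None
```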
3 changes: 1 addition & 2 deletions skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
@@ -62,9 +62,8 @@ def __init__(
        self.max_turns = generator_cfg.max_turns
        self.batched = generator_cfg.batched
        self.use_conversation_multi_turn = generator_cfg.use_conversation_multi_turn

        # optionally use custom chat template to get loss masks (i.e. for Qwen3)
        self.custom_chat_template = get_custom_chat_template(model_name)
        self.custom_chat_template = get_custom_chat_template(generator_cfg.chat_template)
        # get generation prompt ids for the tokenizer if needed
        self.generation_prompt_ids = get_generation_prompt_ids(tokenizer) if self.use_conversation_multi_turn else None
        if self.skyrl_gym_cfg.max_env_workers > 0:
61 changes: 54 additions & 7 deletions skyrl-train/skyrl_train/generators/utils.py
@@ -1,12 +1,26 @@
import torch
from typing import List, Tuple, Union, Dict, Any
from typing import List, Tuple, Union, Optional, Dict, Any
from collections import defaultdict
import numpy as np
from skyrl_train.generators.base import GeneratorOutput, GeneratorInput, TrajectoryID, BatchMetadata, TrainingPhase
from omegaconf import DictConfig

CUSTOM_CHAT_TEMPLATES = {
    # chat template for qwen3 thinking mode to remove think tokens similar to generation phase
    "qwen3_thinking": (
    # chat template for qwen3 that preserves thinking tokens
    "qwen3_with_thinking": (
        "{% for message in messages %}"
        "{% if (message['role'] != 'assistant') %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
        "{% elif (message['role'] == 'assistant')%}"
        "{{'<|im_start|>' + message['role'] + '\n'}}"
        "{% generation %}"
        "{{message['content'] + '<|im_end|>'}}"
        "{% endgeneration %}"
        "{{'\n'}}"
        "{% endif %}"
        "{% endfor %}"
    ),
    "qwen3_without_thinking": (
        "{% for message in messages %}"
        "{% if (message['role'] != 'assistant') %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
@@ -28,12 +42,45 @@
}
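The `{% generation %}`/`{% endgeneration %}` markers are what make the loss masking work: Hugging Face tokenizers can return an assistant-token mask for spans wrapped by them. A minimal sketch of that call, assuming a Qwen3 tokenizer is available locally (this is the same `return_dict` shape the test mocks below emulate):

```python
from transformers import AutoTokenizer

from skyrl_train.generators.utils import CUSTOM_CHAT_TEMPLATES

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
messages = [
    {"role": "user", "content": "a"},
    {"role": "assistant", "content": "b"},
]
# Tokens rendered inside {% generation %} blocks get mask value 1; all other
# tokens get 0, which is what the generator turns into loss masks.
out = tokenizer.apply_chat_template(
    messages,
    chat_template=CUSTOM_CHAT_TEMPLATES["qwen3_with_thinking"],
    return_dict=True,
    return_assistant_tokens_mask=True,
)
input_ids, assistant_masks = out["input_ids"], out["assistant_masks"]
assert len(input_ids) == len(assistant_masks)
```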


def get_custom_chat_template(model_name: str) -> str:
    if "Qwen3" in model_name:
        return CUSTOM_CHAT_TEMPLATES["qwen3_thinking"]
    else:
def get_custom_chat_template(chat_template_config: Optional[Union[dict, DictConfig]] = None) -> Optional[str]:
    """
    Get custom chat template based on the new config structure.

    Args:
        chat_template_config: Config dict with 'source' and 'name_or_path' fields.

    Returns:
        Chat template string or None
    """
    if chat_template_config is None:
        return None

    source = chat_template_config.get("source")
    if not source:
        raise ValueError("'source' is required in chat_template_config")

    name_or_path = chat_template_config.get("name_or_path")
    if not name_or_path:
        return None  # if name_or_path is not provided, use the default chat template from the tokenizer

    if source == "name":
        if name_or_path in CUSTOM_CHAT_TEMPLATES:
            return CUSTOM_CHAT_TEMPLATES[name_or_path]
        else:
            raise ValueError(
                f"Template name '{name_or_path}' not found. Available templates: {list(CUSTOM_CHAT_TEMPLATES.keys())}"
            )
    elif source == "file":
        try:
            with open(name_or_path, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError as e:
            raise ValueError(f"Template file '{name_or_path}' not found") from e
        except OSError as e:
            raise ValueError(f"Error reading template file '{name_or_path}': {e}") from e
    else:
        raise ValueError(f"Invalid source '{source}'. Must be 'name' or 'file'")


def get_generation_prompt_ids(tokenizer) -> List[int]:
"""
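And a quick sketch of the helper's failure modes as implemented above (the bad `source` and name values are made up for illustration):

```python
import pytest

from skyrl_train.generators.utils import get_custom_chat_template

# Unknown template names fail fast and report the valid options.
with pytest.raises(ValueError, match="not found"):
    get_custom_chat_template({"source": "name", "name_or_path": "no_such_template"})

# Anything other than "name" or "file" is rejected outright.
with pytest.raises(ValueError, match="Invalid source"):
    get_custom_chat_template({"source": "url", "name_or_path": "qwen3_with_thinking"})
```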
1 change: 1 addition & 0 deletions skyrl-train/tests/cpu/generators/qwen3_acc_without_thinking.jinja2
@@ -0,0 +1 @@
{% for message in messages %}{% if (message['role'] != 'assistant') %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% elif (message['role'] == 'assistant')%}{{'<|im_start|>' + message['role'] + '\n'}}{% generation %}{% set full_content = message['content'] %}{% set mycontent = message['content'] %}{% set is_last_message = loop.last and messages[-1]['role'] == 'assistant' %}{% if '</think>' in full_content and not is_last_message %}{% set mycontent = full_content.split('</think>')[-1].lstrip('\n') %}{% endif %}{{mycontent + '<|im_end|>'}}{% endgeneration %}{{'\n'}}{% endif %}{% endfor %}
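This one-line template strips the reasoning block from every assistant turn except a final assistant turn, so an in-progress turn keeps its `<think>` block. A plain-Python sketch of its per-message logic (the helper name is illustrative; the split expression is copied from the Jinja above):

```python
def strip_think_block(content: str, is_last_assistant_message: bool) -> str:
    # Historical turns: drop everything up to and including '</think>'.
    # Final assistant turn: keep the reasoning intact, matching generation-time behavior.
    if "</think>" in content and not is_last_assistant_message:
        return content.split("</think>")[-1].lstrip("\n")
    return content


assert strip_think_block("<think>thinking</think>hi", False) == "hi"
assert strip_think_block("<think>thinking</think>hi", True) == "<think>thinking</think>hi"
```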
20 changes: 20 additions & 0 deletions skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py
@@ -30,6 +30,13 @@ def mock_apply_chat_template(x, **kwargs):
    if not kwargs.get("tokenize", True):
        return "".join([str(i["content"]) for i in x])
    else:
        # Check if return_dict is requested
        if kwargs.get("return_dict", False):
            # Return dictionary format for retokenization path
            return {
                "input_ids": MOCK_LLM_OUTPUT_IDS.copy(),
                "assistant_masks": [1] * len(MOCK_LLM_OUTPUT_IDS),
            }
        # Non-dict return
        if isinstance(x, list) and len(x) > 0 and isinstance(x[0], list):
            # Multiple prompts
@@ -95,6 +102,7 @@ def mock_generator_cfg():
    cfg.max_input_length = 512
    cfg.batched = True
    cfg.max_turns = 1
    cfg.chat_template = {"source": "name", "name_or_path": None}
    return cfg


@@ -461,6 +469,7 @@ async def test_length_limit_exceeded_during_conversation(
    mock_generator_cfg.batched = False  # Use agent_loop mode
    mock_generator_cfg.max_turns = 5  # Allow multiple turns
    mock_generator_cfg.use_conversation_multi_turn = True
    mock_generator_cfg.chat_template = {"source": "name", "name_or_path": None}
    mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {})

    # Configure environment to never set done=True naturally (we want to hit length limit)
@@ -545,6 +554,7 @@ async def test_multi_turn_response_truncation(
    mock_generator_cfg.max_turns = 3  # Ensure multi-turn logic is triggered
    mock_generator_cfg.batched = False  # Test is for agent_loop
    mock_generator_cfg.use_conversation_multi_turn = True
    mock_generator_cfg.chat_template = {"source": "name", "name_or_path": None}
    mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {})

    # Configure environment to run for multiple turns to generate enough tokens for truncation
@@ -632,6 +642,8 @@ async def test_postprocessed_action_used(
    mock_make.return_value = mock_env
    mock_generator_cfg.max_turns = 1  # Single turn
    mock_generator_cfg.batched = False
    # Override to avoid retokenization path for this test
    mock_generator_cfg.chat_template = {"source": "name", "name_or_path": None}
    mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {})

    postprocessed_response = "This is a clean response."
@@ -949,6 +961,7 @@ def close(self):
    cfg.max_turns = 10
    cfg.zero_reward_on_non_stop = False
    cfg.use_conversation_multi_turn = False
    cfg.chat_template = {"source": "name", "name_or_path": None}

    generator = SkyRLGymGenerator(
        generator_cfg=cfg,
@@ -983,6 +996,7 @@ async def test_agent_loop_token_level_rewards_multi_turn_conversation_format(
    mock_tokenizer.eos_token_id = 4

    # Tokenizer: initial prompt -> 2 tokens; observation template -> 2 tokens each call

    def apply_chat_template_side_effect(messages, **kwargs):
        if kwargs.get("tokenize", True):
            # For observations path, generator passes [*base_conversation, *new_obs] with add_generation_prompt=True
@@ -1038,6 +1052,7 @@ def close(self):
    cfg.max_turns = 10
    cfg.zero_reward_on_non_stop = False
    cfg.use_conversation_multi_turn = True
    cfg.chat_template = {"source": "name", "name_or_path": None}

    mock_env_cfg.env_class = "mt_env"

@@ -1130,6 +1145,10 @@ def close(self):
    cfg.max_turns = 10
    cfg.zero_reward_on_non_stop = False
    cfg.use_conversation_multi_turn = True
    cfg.chat_template = {
        "source": "name",
        "name_or_path": "qwen3_without_thinking",  # TODO: revisit this test once we separate the retokenize config from the custom chat template config
    }

    generator = SkyRLGymGenerator(
        generator_cfg=cfg,
@@ -1207,6 +1226,7 @@ def close(self):
    cfg.max_turns = 1
    cfg.zero_reward_on_non_stop = False
    cfg.use_conversation_multi_turn = False
    cfg.chat_template = {"source": "name", "name_or_path": None}

    generator = SkyRLGymGenerator(
        generator_cfg=cfg,
Changes to skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat…
@@ -12,7 +12,8 @@
from omegaconf import DictConfig
from transformers import AutoTokenizer
from skyrl_gym.envs import register
from skyrl_train.generators.utils import get_custom_chat_template
from skyrl_train.generators.utils import get_custom_chat_template, CUSTOM_CHAT_TEMPLATES
from pathlib import Path


# Setup for formatting tests
@@ -49,20 +50,26 @@ def _register_test_env_if_needed():

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name", ["Qwen/Qwen2.5-0.5B-Instruct", "unsloth/Llama-3.2-1B-Instruct", "Qwen/Qwen3-0.6B"]
    "model_name",
    ["Qwen/Qwen2.5-0.5B-Instruct", "unsloth/Llama-3.2-1B-Instruct", "Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B-FROM_PATH"],
)
async def test_skyrl_gym_generator_chat_templating_exact(model_name):
    _register_test_env_if_needed()  # Register only when needed
    is_custom_jinja_from_file = model_name.endswith("-FROM_PATH")
    model_name = model_name.replace("-FROM_PATH", "")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    mock_llm = MagicMock()

    # Mock the new generate method
    def mock_generate(input_batch):
        num_prompts = len(input_batch["prompts"]) if "prompts" in input_batch else len(input_batch["prompt_token_ids"])

        mock_llm_output_text = "b" + tokenizer.eos_token
        mock_response_text = "b"

        return {
            # no tokenizer.eos_token for responses because `skip_special_tokens` is True in sampling params
            "responses": ["b"] * num_prompts,
            "responses": [mock_response_text] * num_prompts,
            "stop_reasons": ["stop"] * num_prompts,
            "response_logprobs": None,
            # add_special_tokens needs to be False, otherwise for instance Llama will always
@@ -72,6 +79,16 @@ def mock_generate(input_batch):

    mock_llm.generate = AsyncMock(side_effect=mock_generate)
    # Create a mock generator config

    chat_template_config = None
    if is_custom_jinja_from_file:
        template_path = Path(__file__).parent / "qwen3_acc_without_thinking.jinja2"
        chat_template_config = {"source": "file", "name_or_path": str(template_path)}
    elif "Qwen3" in model_name:
        chat_template_config = {"source": "name", "name_or_path": "qwen3_without_thinking"}
    else:
        chat_template_config = {"source": "name", "name_or_path": None}

    generator_cfg = DictConfig(
        {
            "sampling_params": {"max_generate_length": 200, "logprobs": None},
@@ -81,6 +98,7 @@
            "zero_reward_on_non_stop": False,
            "apply_overlong_filtering": False,
            "use_conversation_multi_turn": True,
            "chat_template": chat_template_config,
            "append_eos_token_after_stop_str_in_multi_turn": True,
        }
    )
@@ -106,16 +124,19 @@ def mock_generate(input_batch):
        "env_extras": extras,
        "env_classes": [env_cfg.env_class],
    }

    generator_output: GeneratorOutput = await generator.generate(input_batch)

    # assume every actual message is 1 token for loss mask checking
    expected_assistant_content = "b"

    expected_chat_history = [
        {"role": "user", "content": "a"},
        {"role": "assistant", "content": "b"},
        {"role": "assistant", "content": expected_assistant_content},
        {"role": "user", "content": "1"},
        {"role": "assistant", "content": "b"},
        {"role": "assistant", "content": expected_assistant_content},
        {"role": "user", "content": "2"},
        {"role": "assistant", "content": "b"},
        {"role": "assistant", "content": expected_assistant_content},
    ]

    # For Qwen2.5 generator_output_str, we have (note the missing \n after the eos token):
@@ -127,7 +148,7 @@
    # check that the full response is exactly string matching with applying the chat template on history
    prompt_str = tokenizer.decode(generator_output["prompt_token_ids"][0])
    resp_str = tokenizer.decode(generator_output["response_ids"][0])
    custom_chat_template = get_custom_chat_template(model_name)
    custom_chat_template = get_custom_chat_template(chat_template_config)
    if custom_chat_template is not None:
        assert prompt_str + resp_str == tokenizer.apply_chat_template(
            expected_chat_history, chat_template=custom_chat_template, tokenize=False
@@ -187,6 +208,40 @@ def mock_generate(input_batch):
    assert generator_output["loss_masks"][0] == expected_loss_masks

def test_qwen3_original_vs_without_thinking_chat_template():
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

    messages = [
        {"content": "hi", "role": "system"},
        {"content": "hi", "role": "user"},
        {"content": "<think>thinking</think>hi", "role": "assistant"},
        {"content": "hi", "role": "user"},
        {"content": "<think>thinking</think>hi", "role": "assistant"},
        {"content": "hi", "role": "user"},
    ]

    # Apply custom chat template
    qwen3_without_thinking_str = tokenizer.apply_chat_template(
        messages, chat_template=CUSTOM_CHAT_TEMPLATES["qwen3_without_thinking"], tokenize=False
    )

    # Apply custom chat template from file
    file_path = Path(__file__).parent / "qwen3_acc_without_thinking.jinja2"
    with open(file_path, "r", encoding="utf-8") as f:
        template_from_file = f.read()

    qwen3_without_thinking_str_from_file = tokenizer.apply_chat_template(
        messages, chat_template=template_from_file, tokenize=False
    )

    # Apply default chat template
    default_template_str = tokenizer.apply_chat_template(messages, chat_template=None, tokenize=False)

    # The qwen3_without_thinking template should match the tokenizer's default template exactly
    assert default_template_str == qwen3_without_thinking_str
    assert qwen3_without_thinking_str == qwen3_without_thinking_str_from_file

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name", ["Qwen/Qwen2.5-0.5B-Instruct", "unsloth/Llama-3.2-1B-Instruct", "Qwen/Qwen3-0.6B"]
@@ -222,6 +277,11 @@ def mock_generate(input_batch):

    mock_llm.generate = AsyncMock(side_effect=mock_generate)

    chat_template_config = None
    if "Qwen3" in model_name:
        chat_template_config = {"source": "name", "name_or_path": "qwen3_without_thinking"}
    else:
        chat_template_config = {"source": "name", "name_or_path": None}
    generator_cfg = DictConfig(
        {
            "sampling_params": {"max_generate_length": 200, "logprobs": None, "stop": [stop_tag]},
@@ -231,6 +291,7 @@
            "zero_reward_on_non_stop": False,
            "apply_overlong_filtering": False,
            "use_conversation_multi_turn": True,
            "chat_template": chat_template_config,
            "append_eos_token_after_stop_str_in_multi_turn": append_flag,
        }
    )
2 changes: 1 addition & 1 deletion skyrl-train/uv.lock
