Make Qwen3 custom chat template selection depend on `train_on_thinking_tokens`.

skyrl_gym_generator.py: the generator now passes
generator_cfg.get("train_on_thinking_tokens", False) as the second argument
to get_custom_chat_template().

utils.py: get_custom_chat_template() gains a `train_on_thinking_tokens`
parameter (default False). It returns CUSTOM_CHAT_TEMPLATES["qwen3_thinking"]
only when "Qwen3" is in model_name and train_on_thinking_tokens is false;
in every other case it returns None. The return annotation changes from
`str` to `Optional[str]` to reflect the None branch.

NOTE(review): the utils.py hunk adds no import for `Optional`; confirm that
utils.py already imports it from `typing`.
NOTE(review): the leading indentation of the Python lines below was
reconstructed (8 spaces in the method body, 4/8 in the function); verify it
against the target files before applying.

diff --git a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
index d337793eac..5ea73e22a1 100644
--- a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
+++ b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
@@ -39,7 +39,7 @@ def __init__(
         self.use_conversation_multi_turn = generator_cfg.use_conversation_multi_turn
 
         # optionally use custom chat template to get loss masks (i.e. for Qwen3)
-        self.custom_chat_template = get_custom_chat_template(model_name)
+        self.custom_chat_template = get_custom_chat_template(model_name, generator_cfg.get("train_on_thinking_tokens", False))
         # get generation prompt ids for the tokenizer if needed
         self.generation_prompt_ids = get_generation_prompt_ids(tokenizer) if self.use_conversation_multi_turn else None
         if self.skyrl_gym_cfg.max_env_workers > 0:
diff --git a/skyrl-train/skyrl_train/generators/utils.py b/skyrl-train/skyrl_train/generators/utils.py
index 22b6648258..a31dfa8c5c 100644
--- a/skyrl-train/skyrl_train/generators/utils.py
+++ b/skyrl-train/skyrl_train/generators/utils.py
@@ -28,8 +28,8 @@
 }
 
 
-def get_custom_chat_template(model_name: str) -> str:
-    if "Qwen3" in model_name:
+def get_custom_chat_template(model_name: str, train_on_thinking_tokens: bool = False) -> Optional[str]:
+    if "Qwen3" in model_name and not train_on_thinking_tokens:
         return CUSTOM_CHAT_TEMPLATES["qwen3_thinking"]
     else:
         return None