diff --git a/examples/multi_agent/agent_system.py b/examples/multi_agent/agent_system.py
index b7e2b3a89..49b62ca6d 100644
--- a/examples/multi_agent/agent_system.py
+++ b/examples/multi_agent/agent_system.py
@@ -20,11 +20,21 @@ async def generate_response(args, prompt, key):
     url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
 
-    prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
+    if args.apply_chat_template:
+        assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
+        prompt_text = tokenizer.apply_chat_template(
+            prompt,
+            tokenize=False,
+            add_generation_prompt=True,  # Add generation prompt for the assistant
+            **(args.apply_chat_template_kwargs or {}),
+        )
+        sample.prompt = prompt_text
+    else:
+        assert isinstance(prompt, str), "prompt should be a string when apply_chat_template is False"
+        sample.prompt = prompt
+    prompt_token_ids = tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
     sample.tokens = prompt_token_ids
-    sample.prompt = prompt
-    input_token_ids = prompt_token_ids
-    prompt_length = len(input_token_ids)
+    prompt_length = len(prompt_token_ids)
 
     current_sampling_params = deepcopy(sampling_params)
     current_sampling_params["max_new_tokens"] = min(
         sampling_params["max_new_tokens"], max_context_length - prompt_length
@@ -33,7 +43,7 @@ async def generate_response(args, prompt, key):
     if current_sampling_params["max_new_tokens"] <= 0:
         return None
 
-    payload = {"input_ids": input_token_ids, "sampling_params": current_sampling_params, "return_logprob": True}
+    payload = {"input_ids": prompt_token_ids, "sampling_params": current_sampling_params, "return_logprob": True}
 
     output = await post(url, payload)
 
diff --git a/examples/search-r1/generate_with_search.py b/examples/search-r1/generate_with_search.py
index 968c3bebb..65ea9e399 100644
--- a/examples/search-r1/generate_with_search.py
+++ b/examples/search-r1/generate_with_search.py
@@ -151,7 +151,18 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
 
     # Handle partial rollout samples: continue generation from existing response
     prompt = sample.prompt
-    prompt_tokens_ids = state.tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
+    if args.apply_chat_template:
+        assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
+        prompt_text = state.tokenizer.apply_chat_template(
+            prompt,
+            tokenize=False,
+            add_generation_prompt=True,  # Add generation prompt for the assistant
+            **(args.apply_chat_template_kwargs or {}),
+        )
+    else:
+        assert isinstance(prompt, str), "prompt should be a string when apply_chat_template is False"
+        prompt_text = prompt
+    prompt_tokens_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
     response = ""
     response_token_ids = []
     loss_mask = []
@@ -159,7 +170,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
 
     for _turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]):
         payload = {
-            "text": prompt + response,
+            "text": prompt_text + response,
             "sampling_params": sampling_params,
         }
         # Add log probability collection if enabled
diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py
index d4a4ba660..bf7106563 100644
--- a/slime/rollout/sglang_rollout.py
+++ b/slime/rollout/sglang_rollout.py
@@ -101,6 +101,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
         state.tokenizer,
         state.processor,
         sample.metadata,
+        args.apply_chat_template,
         args.apply_chat_template_kwargs,
     )
 
diff --git a/slime/utils/data.py b/slime/utils/data.py
index 3b3e6f4b2..45ee103fa 100644
--- a/slime/utils/data.py
+++ b/slime/utils/data.py
@@ -49,23 +49,32 @@ def _parse_generalized_path(s: str):
     return s, None
 
 
-def _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs):
+def _should_skip_prompt(
+    prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
+):
     if max_length is None:
         return False
 
     from slime.utils.processing_utils import prepare_model_inputs
 
-    input_ids, _ = prepare_model_inputs(prompt, tokenizer, processor, None, apply_chat_template_kwargs)
+    input_ids, _ = prepare_model_inputs(
+        prompt, tokenizer, processor, metadata, apply_chat_template, apply_chat_template_kwargs
+    )
     return len(input_ids) > max_length
 
 
-def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
-    messages = data.get(prompt_key)
+def _build_messages(data: dict, prompt_key: str, as_conversation: bool, multimodal_keys: dict = None):
+    prompt = data.get(prompt_key)
 
-    if isinstance(messages, str):
-        messages = [{"role": "user", "content": messages}]
+    if isinstance(prompt, str):
+        # If prompt is a string and we don't apply chat template, return the prompt as is.
+        if not as_conversation:
+            return prompt
+        else:
+            prompt = [{"role": "user", "content": prompt}]
 
     if multimodal_keys:
+        assert as_conversation, "as_conversation must be True when multimodal_keys is not None"
         # Build mapping: placeholder -> (MultimodalType, content_list)
         multimodals = {}
         for type_name, data_key in multimodal_keys.items():
@@ -75,7 +84,7 @@ def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
 
         pattern = "(" + "|".join(re.escape(p) for p in multimodals.keys()) + ")"
 
-        for message in messages:
+        for message in prompt:
             if isinstance(message["content"], str):
                 content_list = []
                 for segment in re.split(pattern, message["content"]):
@@ -105,7 +114,7 @@ def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
                     f"Unsupported content type: {type(message['content'])}, expected str or list of dicts"
                 )
 
-    return messages
+    return prompt
 
 
 class Dataset:
@@ -127,7 +136,8 @@ def __init__(
     ):
         self.origin_samples = []
         for data in read_file(path):
-            prompt = _build_messages(data, prompt_key, multimodal_keys)
+            as_conversation = apply_chat_template
+            prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys)
 
             metadata = data.get(metadata_key) or {}
             if tool_key is not None and tool_key in data:
@@ -140,7 +150,9 @@ def __init__(
                 metadata["tools"] = tools
 
             # TODO: this is slow.
-            if _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs):
+            if _should_skip_prompt(
+                prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
+            ):
                 continue
 
             self.origin_samples.append(
diff --git a/slime/utils/processing_utils.py b/slime/utils/processing_utils.py
index fc837e613..f5952d5d0 100644
--- a/slime/utils/processing_utils.py
+++ b/slime/utils/processing_utils.py
@@ -2,6 +2,7 @@
 import io
 import logging
 
+import numpy as np
 from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin
 
 logger = logging.getLogger(__name__)
@@ -25,7 +26,9 @@ def load_processor(name_or_path: str, **kwargs):
     return proc
 
 
-def prepare_model_inputs(prompt, tokenizer, processor=None, metadata=None, apply_chat_template_kwargs=None):
+def prepare_model_inputs(
+    prompt, tokenizer, processor=None, metadata=None, apply_chat_template=False, apply_chat_template_kwargs=None
+):
     """Prepare all inputs for model inference.
 
     Returns:
@@ -34,13 +37,24 @@ def prepare_model_inputs(prompt, tokenizer, processor=None, metadata=None, apply
         - extra_info: Dict with 'images', 'videos', 'multimodal_inputs' (or empty dict)
     """
     tools = metadata.get("tools") if metadata else None
-    text_prompt = tokenizer.apply_chat_template(
-        prompt,
-        tools=tools,
-        tokenize=False,
-        add_generation_prompt=True,
-        **(apply_chat_template_kwargs or {}),
-    )
+    if isinstance(prompt, (list, np.ndarray)):
+        assert (
+            apply_chat_template
+        ), f"apply_chat_template must be True when prompt is a list or numpy array, current prompt is {prompt}"
+        text_prompt = tokenizer.apply_chat_template(
+            prompt,
+            tools=tools,
+            tokenize=False,
+            add_generation_prompt=True,
+            **(apply_chat_template_kwargs or {}),
+        )
+    elif isinstance(prompt, str):
+        assert (
+            not apply_chat_template
+        ), f"apply_chat_template must be False when prompt is a string, current prompt is {prompt}"
+        text_prompt = prompt
+    else:
+        raise ValueError(f"Invalid prompt type: {type(prompt)}, current prompt is {prompt}")
 
     if not processor:
         input_ids = tokenizer.encode(text_prompt, add_special_tokens=False)
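For reference, a minimal usage sketch (not part of the diff) of the `apply_chat_template` switch added to `prepare_model_inputs`: a message list must be rendered through the tokenizer's chat template, while a plain string is tokenized as-is. The model name below is an arbitrary example.

```python
# Usage sketch only; assumes the signature introduced above. The model name
# is an arbitrary example, not something mandated by the diff.
from transformers import AutoTokenizer

from slime.utils.processing_utils import prepare_model_inputs

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model

# Conversation-style prompt: apply_chat_template=True renders it with the
# tokenizer's chat template (add_generation_prompt=True) before tokenizing.
messages = [{"role": "user", "content": "What is 2 + 2?"}]
chat_input_ids, _ = prepare_model_inputs(messages, tokenizer, apply_chat_template=True)

# Plain-string prompt: apply_chat_template=False tokenizes the text as-is.
raw_input_ids, _ = prepare_model_inputs("What is 2 + 2?", tokenizer, apply_chat_template=False)

# Mismatched combinations (string with apply_chat_template=True, or a message
# list with apply_chat_template=False) now fail the corresponding assertion.
```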