20 changes: 15 additions & 5 deletions examples/multi_agent/agent_system.py
@@ -20,11 +20,21 @@ async def generate_response(args, prompt, key):
 
     url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
 
-    prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
+    if args.apply_chat_template:
+        assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
+        prompt_text = tokenizer.apply_chat_template(
+            prompt,
+            tokenize=False,
+            add_generation_prompt=True,  # Add generation prompt for the assistant
+            **(args.apply_chat_template_kwargs or {}),
+        )
+        sample.prompt = prompt_text
+    else:
+        assert isinstance(prompt, str), "prompt should be a string when apply_chat_template is False"
+        sample.prompt = prompt
+    prompt_token_ids = tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
     sample.tokens = prompt_token_ids
-    sample.prompt = prompt
-    input_token_ids = prompt_token_ids
-    prompt_length = len(input_token_ids)
+    prompt_length = len(prompt_token_ids)
     current_sampling_params = deepcopy(sampling_params)
     current_sampling_params["max_new_tokens"] = min(
         sampling_params["max_new_tokens"], max_context_length - prompt_length
@@ -33,7 +43,7 @@ async def generate_response(args, prompt, key):
     if current_sampling_params["max_new_tokens"] <= 0:
         return None
 
-    payload = {"input_ids": input_token_ids, "sampling_params": current_sampling_params, "return_logprob": True}
+    payload = {"input_ids": prompt_token_ids, "sampling_params": current_sampling_params, "return_logprob": True}
 
     output = await post(url, payload)
 
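In short, generate_response now branches on args.apply_chat_template: a conversation list is rendered through the chat template, while a raw string is tokenized verbatim, and sample.prompt always ends up holding the final text. A minimal, self-contained sketch of that branch follows; it mirrors the diff above, and only the tokenizer checkpoint in the usage comment is an assumption.

# Minimal sketch of the prompt handling introduced above; the logic mirrors the new
# branch in generate_response, the checkpoint name in the usage comment is assumed.
def build_prompt_ids(tokenizer, prompt, apply_chat_template, apply_chat_template_kwargs=None):
    """Return (prompt_text, prompt_token_ids) for a chat conversation or a raw string."""
    if apply_chat_template:
        # Conversation form: a list of {"role": ..., "content": ...} dicts rendered via the chat template.
        assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
        prompt_text = tokenizer.apply_chat_template(
            prompt,
            tokenize=False,
            add_generation_prompt=True,
            **(apply_chat_template_kwargs or {}),
        )
    else:
        # Raw-string form: the text is tokenized as-is, no template applied.
        assert isinstance(prompt, str), "prompt should be a string when apply_chat_template is False"
        prompt_text = prompt
    prompt_token_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    return prompt_text, prompt_token_ids

# Hypothetical usage (model name is an assumption):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
#   text, ids = build_prompt_ids(tok, [{"role": "user", "content": "Hi"}], apply_chat_template=True)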
15 changes: 13 additions & 2 deletions examples/search-r1/generate_with_search.py
@@ -151,15 +151,26 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
 
     # Handle partial rollout samples: continue generation from existing response
     prompt = sample.prompt
-    prompt_tokens_ids = state.tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
+    if args.apply_chat_template:
+        assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
+        prompt_text = state.tokenizer.apply_chat_template(
+            prompt,
+            tokenize=False,
+            add_generation_prompt=True,  # Add generation prompt for the assistant
+            **(args.apply_chat_template_kwargs or {}),
+        )
+    else:
+        assert isinstance(prompt, str), "prompt should be a string when apply_chat_template is False"
+        prompt_text = prompt
+    prompt_tokens_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
     response = ""
     response_token_ids = []
     loss_mask = []
     rollout_log_probs = [] if SEARCH_R1_CONFIGS["return_logprob"] else None
 
     for _turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]):
         payload = {
-            "text": prompt + response,
+            "text": prompt_text + response,
             "sampling_params": sampling_params,
         }
         # Add log probability collection if enabled
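For the search-R1 example, the effect is that the chat template is rendered once up front and the rendered prompt_text (not the raw prompt object) is concatenated with the growing response on every turn. A rough sketch of that loop is below; post() and the {"text": ...} response field follow the example script's usage and are assumptions here, not part of this diff.

# Hedged sketch of the multi-turn loop after this change; `post`, `url`, and the
# response field name are stand-ins for the helpers used by the example script.
async def multi_turn_generate(post, url, prompt_text, sampling_params, max_turns):
    response = ""
    for _turn_idx in range(max_turns):
        payload = {
            # The rendered prompt stays fixed; only `response` grows across turns.
            "text": prompt_text + response,
            "sampling_params": sampling_params,
        }
        output = await post(url, payload)
        response += output["text"]
        # ... tool calls / search results would be appended to `response` here ...
    return response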
1 change: 1 addition & 0 deletions slime/rollout/sglang_rollout.py
@@ -101,6 +101,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
         state.tokenizer,
         state.processor,
         sample.metadata,
+        args.apply_chat_template,
         args.apply_chat_template_kwargs,
     )
 
32 changes: 22 additions & 10 deletions slime/utils/data.py
@@ -49,23 +49,32 @@ def _parse_generalized_path(s: str):
     return s, None
 
 
-def _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs):
+def _should_skip_prompt(
+    prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
+):
     if max_length is None:
         return False
 
     from slime.utils.processing_utils import prepare_model_inputs
 
-    input_ids, _ = prepare_model_inputs(prompt, tokenizer, processor, None, apply_chat_template_kwargs)
+    input_ids, _ = prepare_model_inputs(
+        prompt, tokenizer, processor, metadata, apply_chat_template, apply_chat_template_kwargs
+    )
     return len(input_ids) > max_length
 
 
-def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
-    messages = data.get(prompt_key)
+def _build_messages(data: dict, prompt_key: str, as_conversation: bool, multimodal_keys: dict = None):
+    prompt = data.get(prompt_key)
 
-    if isinstance(messages, str):
-        messages = [{"role": "user", "content": messages}]
+    if isinstance(prompt, str):
+        # If prompt is a string and we don't apply chat template, return the prompt as is.
+        if not as_conversation:
+            return prompt
+        else:
+            prompt = [{"role": "user", "content": prompt}]
 
     if multimodal_keys:
+        assert as_conversation, "as_conversation must be True when multimodal_keys is not None"
         # Build mapping: placeholder -> (MultimodalType, content_list)
         multimodals = {}
         for type_name, data_key in multimodal_keys.items():
@@ -75,7 +84,7 @@ def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
 
         pattern = "(" + "|".join(re.escape(p) for p in multimodals.keys()) + ")"
 
-        for message in messages:
+        for message in prompt:
            if isinstance(message["content"], str):
                content_list = []
                for segment in re.split(pattern, message["content"]):
@@ -105,7 +114,7 @@ def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
                    f"Unsupported content type: {type(message['content'])}, expected str or list of dicts"
                )
 
-    return messages
+    return prompt
 
 
 class Dataset:
@@ -127,7 +136,8 @@ def __init__(
     ):
         self.origin_samples = []
         for data in read_file(path):
-            prompt = _build_messages(data, prompt_key, multimodal_keys)
+            as_conversation = apply_chat_template
+            prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys)
 
             metadata = data.get(metadata_key) or {}
             if tool_key is not None and tool_key in data:
@@ -140,7 +150,9 @@ def __init__(
                    metadata["tools"] = tools
 
            # TODO: this is slow.
-            if _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs):
+            if _should_skip_prompt(
+                prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
+            ):
                continue
 
            self.origin_samples.append(
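The behavioral change in data.py is that a plain-string prompt is only wrapped into a single-turn conversation when the chat template will actually be applied; otherwise it passes through untouched. A stripped-down sketch of that contract follows, with the multimodal placeholder expansion from the real function omitted.

# Stripped-down sketch of the _build_messages contract after this change;
# the multimodal placeholder handling is intentionally left out.
def build_messages_sketch(data: dict, prompt_key: str, as_conversation: bool):
    prompt = data.get(prompt_key)
    if isinstance(prompt, str):
        if not as_conversation:
            # apply_chat_template is off: keep the raw string so it is tokenized as-is later.
            return prompt
        # apply_chat_template is on: wrap the string as a single-turn conversation.
        return [{"role": "user", "content": prompt}]
    # Already a list of {"role": ..., "content": ...} messages: pass it through unchanged.
    return prompt

# build_messages_sketch({"q": "2+2?"}, "q", as_conversation=False) -> "2+2?"
# build_messages_sketch({"q": "2+2?"}, "q", as_conversation=True)  -> [{"role": "user", "content": "2+2?"}]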
30 changes: 22 additions & 8 deletions slime/utils/processing_utils.py
@@ -2,6 +2,7 @@
 import io
 import logging
 
+import numpy as np
 from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin
 
 logger = logging.getLogger(__name__)
@@ -25,7 +26,9 @@ def load_processor(name_or_path: str, **kwargs):
     return proc
 
 
-def prepare_model_inputs(prompt, tokenizer, processor=None, metadata=None, apply_chat_template_kwargs=None):
+def prepare_model_inputs(
+    prompt, tokenizer, processor=None, metadata=None, apply_chat_template=False, apply_chat_template_kwargs=None
+):
     """Prepare all inputs for model inference.
 
     Returns:
@@ -34,13 +37,24 @@ def prepare_model_inputs(prompt, tokenizer, processor=None, metadata=None, apply
         - extra_info: Dict with 'images', 'videos', 'multimodal_inputs' (or empty dict)
     """
     tools = metadata.get("tools") if metadata else None
-    text_prompt = tokenizer.apply_chat_template(
-        prompt,
-        tools=tools,
-        tokenize=False,
-        add_generation_prompt=True,
-        **(apply_chat_template_kwargs or {}),
-    )
+    if isinstance(prompt, (list, np.ndarray)):
+        assert (
+            apply_chat_template
+        ), f"apply_chat_template must be True when prompt is a list or numpy array, current prompt is {prompt}"
+        text_prompt = tokenizer.apply_chat_template(
+            prompt,
+            tools=tools,
+            tokenize=False,
+            add_generation_prompt=True,
+            **(apply_chat_template_kwargs or {}),
+        )
+    elif isinstance(prompt, str):
+        assert (
+            not apply_chat_template
+        ), f"apply_chat_template must be False when prompt is a string, current prompt is {prompt}"
+        text_prompt = prompt
+    else:
+        raise ValueError(f"Invalid prompt type: {type(prompt)}, current prompt is {prompt}")
 
     if not processor:
         input_ids = tokenizer.encode(text_prompt, add_special_tokens=False)
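Net effect of the prepare_model_inputs change: the prompt type and the apply_chat_template flag must agree, and a mismatch fails fast on the assertions above. A hedged usage sketch follows; the tokenizer checkpoint is an assumption, while the call signature and the two-value return come from this diff.

# Usage sketch only: the checkpoint name is assumed, the signature and the
# (input_ids, extra_info) return shape are taken from the diff above.
from transformers import AutoTokenizer

from slime.utils.processing_utils import prepare_model_inputs

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # assumed checkpoint

# Conversation prompt -> apply_chat_template must be True.
input_ids, extra_info = prepare_model_inputs(
    [{"role": "user", "content": "Hello"}],
    tokenizer,
    apply_chat_template=True,
)

# Raw string prompt -> apply_chat_template must be False; the text is tokenized verbatim.
input_ids, extra_info = prepare_model_inputs(
    "Hello",
    tokenizer,
    apply_chat_template=False,
)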