Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 295 additions & 0 deletions tests/common/vllm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ def get_model_path() -> str:
return path


# Module-wide switch for verbose test diagnostics; flip to True when debugging.
DEBUG = False


def print_debug(*args):
    """Forward *args* to :func:`print`, but only when ``DEBUG`` is enabled."""
    if not DEBUG:
        return
    print(*args)


CHAT_TEMPLATE = r"""
{%- if tools %}
{{- '<|im_start|>system\n' }}
Expand Down Expand Up @@ -209,6 +217,7 @@ def setUp(self):
self.config.explorer.rollout_model.use_v1 = True
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.enable_openai_api = True

self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(
Expand Down Expand Up @@ -299,3 +308,289 @@ def test_assistant_token_mask(self):
self.assertTrue(torch.equal(token_ids, token_ids_hf))
self.assertTrue(torch.equal(action_mask, action_mask_hf))
self.assertEqual(prompt_length, prompt_length_hf)


@parameterized_class(
    ("enable_thinking", "reasoning_parser"),
    [
        (True, "deepseek_r1"),
        (False, None),
    ],
)
class TestAPIServerToolCall(RayUnittestBase):
    """End-to-end test of OpenAI-API tool calling through the vLLM rollout model.

    Parameterized over thinking mode: when ``enable_thinking`` is True the
    configured ``reasoning_parser`` is expected to populate
    ``message.reasoning_content`` on the first response.
    """

    def setUp(self):
        # Build an explore-mode config with a single async vLLM engine and the
        # OpenAI-compatible API server enabled.
        self.config = get_template_config()
        self.config.mode = "explore"
        self.config.model.model_path = get_model_path()
        self.config.explorer.rollout_model.engine_type = "vllm_async"
        self.config.explorer.rollout_model.engine_num = 1
        self.config.explorer.rollout_model.tensor_parallel_size = 1
        self.config.explorer.rollout_model.use_v1 = True
        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
        self.config.explorer.rollout_model.enable_openai_api = True
        # added for toolcalls
        self.config.explorer.rollout_model.enable_auto_tool_choice = True
        self.config.explorer.rollout_model.tool_call_parser = "hermes"
        self.config.explorer.rollout_model.enable_thinking = self.enable_thinking
        self.config.explorer.rollout_model.reasoning_parser = self.reasoning_parser

        self.config.check_and_update()
        self.engines, self.auxiliary_engines = create_inference_models(self.config)
        self.model_wrapper = ModelWrapper(
            self.engines[0], model_type="vllm_async", enable_history=True
        )
        self.model_wrapper_no_history = ModelWrapper(
            self.engines[0], model_type="vllm_async", enable_history=False
        )

    def _check_experience_tokens(self, exp, tokenizer):
        """Decode an experience's prompt/response tokens for debugging and
        assert that the action mask length matches the response tokens (±1).

        Shared by both experience-history checks in ``test_api_tool_calls``.
        """
        print_debug("\n" + "-" * 15 + " Decoding Experience Tokens " + "-" * 15)

        full_decoded_text = tokenizer.decode(exp.tokens, skip_special_tokens=False)
        print_debug(
            f" > Full Decoded Text ({len(exp.tokens)} tokens):\n---\n{full_decoded_text}\n---"
        )

        prompt_length = exp.prompt_length
        prompt_tokens = exp.tokens[:prompt_length]
        response_tokens = exp.tokens[prompt_length:]

        prompt_decoded_text = tokenizer.decode(prompt_tokens, skip_special_tokens=False)
        response_decoded_text = tokenizer.decode(response_tokens, skip_special_tokens=False)

        print_debug(
            f" > Decoded Prompt Part ({len(prompt_tokens)} tokens):\n---\n{prompt_decoded_text}\n---"
        )
        print_debug(
            f" > Decoded Response Part ({len(response_tokens)} tokens):\n---\n{response_decoded_text}\n---"
        )

        action_mask = getattr(exp, "action_mask", None)
        if action_mask is not None:
            print_debug(f"\n > Action Mask (Length: {len(action_mask)}):")
            masked_tokens_info = []
            for i, token_id in enumerate(response_tokens):
                token_text = tokenizer.decode([token_id])
                mask_value = action_mask[i] if i < len(action_mask) else "N/A"
                masked_tokens_info.append(f"({repr(token_text)}, Mask: {mask_value})")

            print_debug("  " + " ".join(masked_tokens_info))

            # The mask may be off by one token (e.g. a trailing EOS) relative
            # to the response tokens; anything larger indicates a real bug.
            self.assertTrue(
                abs(len(action_mask) - len(response_tokens)) <= 1,
                f"Length of action_mask ({len(action_mask)}) does not match "
                f"length of response_tokens ({len(response_tokens)})",
            )
        else:
            print_debug(" > Action Mask: Not found in experience.")

    def test_api_tool_calls(self):
        """Tests the full conversation flow of a tool call via the OpenAI API."""
        import json
        import time

        tokenizer = AutoTokenizer.from_pretrained(get_model_path())
        print_debug("\n\n" + "=" * 30 + " Running test_api_tool_calls " + "=" * 30)
        start_time = time.time()

        # --- Step 0: Get OpenAI Client ---
        print_debug(f"[{time.time() - start_time:.2f}s] Getting OpenAI client...")
        openai_client = self.model_wrapper.get_openai_client()
        model_id = openai_client.models.list().data[0].id
        print_debug(
            f"[{time.time() - start_time:.2f}s] Successfully got client. Model ID: {model_id}"
        )

        # --- Step 1: Define Tools and Messages ---
        print_debug(f"[{time.time() - start_time:.2f}s] Defining tools and initial message...")
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather in a given location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            },
                            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                        },
                        "required": ["location"],
                    },
                },
            }
        ]
        messages = [{"role": "user", "content": "What's the weather like in Boston?"}]
        print_debug(
            f"[{time.time() - start_time:.2f}s] Initial user message: {messages[0]['content']}"
        )
        print_debug("-" * 80)

        # --- Step 2: First API Call (Expecting a tool call) ---
        print_debug(f"[{time.time() - start_time:.2f}s] Making first API call to the model...")
        response = openai_client.chat.completions.create(
            model=model_id,
            messages=messages,
            tools=tools,
            tool_choice="auto",
            extra_body={
                "repetition_penalty": 1.05,
                "chat_template_kwargs": {
                    "enable_thinking": self.enable_thinking
                },  # default to True
            },
        )
        print_debug(f"[{time.time() - start_time:.2f}s] First API call completed.")

        # --- Step 3: Assert and Print the Tool Call Response ---
        print_debug(f"[{time.time() - start_time:.2f}s] Asserting response is a tool call...")
        self.assertEqual(len(response.choices), 1)
        choice = response.choices[0]
        print_debug(f"  > Finish Reason: {choice.finish_reason}")
        self.assertEqual(choice.finish_reason, "tool_calls")
        if self.enable_thinking:
            self.assertIsNotNone(choice.message.reasoning_content)
        self.assertIsNotNone(choice.message.tool_calls)
        self.assertEqual(len(choice.message.tool_calls), 1)

        tool_call = choice.message.tool_calls[0]
        print_debug(f"  > Tool Call ID: {tool_call.id}")
        print_debug(f"  > Function Name: {tool_call.function.name}")
        print_debug(f"  > Function Arguments: {tool_call.function.arguments}")
        self.assertEqual(tool_call.type, "function")
        self.assertEqual(tool_call.function.name, "get_current_weather")
        self.assertIn("Boston", tool_call.function.arguments)
        print_debug(f"[{time.time() - start_time:.2f}s] Assertions for tool call passed.")
        print_debug("-" * 80)

        # --- Step 4: Check Experience History ---
        print_debug(f"[{time.time() - start_time:.2f}s] Checking experience history...")
        exps = self.model_wrapper.extract_experience_from_history()
        self.assertEqual(len(exps), 1)
        # The response text in the experience should contain the tool call info
        print_debug(f"  > Recorded experience response_text: {exps[0].response_text}")
        print_debug(f"  > Recorded experience: {exps[0]}")
        print_debug(f"  > message: {choice.message}")

        self._check_experience_tokens(exps[0], tokenizer)

        print_debug("-" * 52 + "\n")

        # pass this part
        # self.assertIn("get_current_weather", exps[0].response_text)

        self.assertEqual(
            len(self.model_wrapper.extract_experience_from_history()), 0
        )  # Verify cleared
        print_debug(f"[{time.time() - start_time:.2f}s] Experience history check passed.")
        print_debug("-" * 80)

        # --- Step 5: Second API Call (Providing tool result) ---
        print_debug(
            f"[{time.time() - start_time:.2f}s] Preparing for the second API call with tool result..."
        )
        messages.append(response.choices[0].message)  # Add assistant's tool call message

        # Mock the result of our tool
        tool_response_content = json.dumps(
            {"location": "Boston", "temperature": "72", "unit": "fahrenheit"}
        )

        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": tool_response_content,
            }
        )
        print_debug(f"[{time.time() - start_time:.2f}s] Full message list for second call:")
        for msg in messages:
            print_debug(f"  - {msg}")

        print_debug(f"[{time.time() - start_time:.2f}s] Making second API call...")
        second_response = openai_client.chat.completions.create(
            model=model_id,
            messages=messages,
            tools=tools,
            extra_body={
                "repetition_penalty": 1.05,
                "chat_template_kwargs": {
                    "enable_thinking": self.enable_thinking
                },  # default to True
            },
        )
        print_debug(f"[{time.time() - start_time:.2f}s] Second API call completed.")

        # --- Step 6: Assert and Print the Final Response ---
        print_debug(
            f"[{time.time() - start_time:.2f}s] Asserting final natural language response..."
        )
        self.assertEqual(len(second_response.choices), 1)
        final_choice = second_response.choices[0]
        print_debug(f"  > Final Finish Reason: {final_choice.finish_reason}")
        print_debug(f"  > Final Message Content: {final_choice.message.content}")
        print_debug(f"  > Final Message: {final_choice.message}")
        self.assertEqual(final_choice.finish_reason, "stop")
        # self.assertIsNone(final_choice.message.tool_calls)
        self.assertEqual(final_choice.message.tool_calls, [])
        self.assertIsNotNone(final_choice.message.content)
        # Check if the model used the information from the tool response
        self.assertIn("72", final_choice.message.content)
        self.assertIn("Boston", final_choice.message.content)
        print_debug(f"[{time.time() - start_time:.2f}s] Assertions for final response passed.")
        print_debug("-" * 80)

        # --- Step 7: Check Final Experience History ---
        print_debug(f"[{time.time() - start_time:.2f}s] Checking final experience history...")
        final_exps = self.model_wrapper.extract_experience_from_history()
        self.assertEqual(len(final_exps), 1)
        print_debug(f"  > Final recorded experience response_text: {final_exps[0].response_text}")
        self.assertEqual(final_exps[0].response_text, final_choice.message.content)
        print_debug(f"[{time.time() - start_time:.2f}s] Final experience history check passed.")

        self._check_experience_tokens(final_exps[0], tokenizer)

        total_time = time.time() - start_time
        print_debug(
            "\n" + "=" * 28 + f" test_api_tool_calls PASSED in {total_time:.2f}s " + "=" * 28 + "\n"
        )
7 changes: 7 additions & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ class InferenceModelConfig:
# For OpenAI API
enable_openai_api: bool = False

# For tool calls in OpenAI API
enable_auto_tool_choice: bool = False

tool_call_parser: Optional[str] = None

reasoning_parser: Optional[str] = None

# ! DO NOT SET
bundle_indices: str = ""

Expand Down
26 changes: 24 additions & 2 deletions trinity/common/models/api/vllm_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,15 @@ def get_vllm_version():
return vllm_version


async def run_api_server_in_ray_actor(async_llm, host: str, port: int, model_path: str):
async def run_api_server_in_ray_actor(
async_llm,
host: str,
port: int,
model_path: str,
enable_auto_tool_choice: bool = False,
tool_call_parser: Optional[str] = None,
reasoning_parser: Optional[str] = None,
):
vllm_version = get_vllm_version()
if vllm_version < parse_version("0.8.5") or vllm_version >= parse_version("0.10.0"):
raise ValueError(
Expand All @@ -347,6 +355,20 @@ async def run_api_server_in_ray_actor(async_llm, host: str, port: int, model_pat

parser = FlexibleArgumentParser(description="Run the OpenAI API server.")
args = make_arg_parser(parser)
args = parser.parse_args(["--host", str(host), "--port", str(port), "--model", model_path])
cli_args = [
"--host",
str(host),
"--port",
str(port),
"--model",
model_path,
]
if enable_auto_tool_choice:
cli_args.append("--enable-auto-tool-choice")
if tool_call_parser:
cli_args.extend(["--tool-call-parser", tool_call_parser])
if reasoning_parser:
cli_args.extend(["--reasoning-parser", reasoning_parser])
args = parser.parse_args(cli_args)
print(args)
await run_server_in_ray(args, async_llm)
8 changes: 7 additions & 1 deletion trinity/common/models/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,13 @@ async def run_api_server(self):

self.api_server_host, self.api_server_port = self.get_available_address()
await run_api_server_in_ray_actor(
self.async_llm, self.api_server_host, self.api_server_port, self.config.model_path
self.async_llm,
self.api_server_host,
self.api_server_port,
self.config.model_path,
self.config.enable_auto_tool_choice,
self.config.tool_call_parser,
self.config.reasoning_parser,
)

async def has_api_server(self) -> bool:
Expand Down
Loading