Commit cfa752f

fix: pass tool_prompt_format to chat_formatter (#1198)
Summary: Need this to format the completion message with tool_calls correctly. See the added unit test.

Test Plan: python -m unittest llama_stack.providers.tests.inference.test_prompt_adapter
1 parent 33a64eb commit cfa752f
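For context: before this change, encode_dialog_prompt was called without the request's tool_prompt_format, so an assistant CompletionMessage carrying tool_calls could be rendered in the wrong tool-call syntax. A hand-written illustration of the two renderings the new test asserts on (strings copied from its assertions, not library output):

# The same ToolCall(tool_name="custom1", arguments={"param1": "value1"}),
# as it should appear in the rendered prompt under each ToolPromptFormat:

# ToolPromptFormat.python_list
expected_python_list = '[custom1(param1="value1")]'

# ToolPromptFormat.json
expected_json = '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}'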

2 files changed: +50 -2 lines changed


llama_stack/providers/tests/inference/test_prompt_adapter.py

Lines changed: 44 additions & 0 deletions
@@ -8,7 +8,10 @@
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
+    CompletionMessage,
+    StopReason,
     SystemMessage,
+    ToolCall,
     ToolConfig,
     UserMessage,
 )
@@ -20,6 +23,7 @@
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
+    chat_completion_request_to_prompt,
 )
 
 MODEL = "Llama3.1-8B-Instruct"
@@ -119,6 +123,46 @@ async def test_system_custom_and_builtin(self):
         self.assertTrue("Return function calls in JSON format" in messages[1].content)
         self.assertEqual(messages[-1].content, content)
 
+    async def test_completion_message_encoding(self):
+        request = ChatCompletionRequest(
+            model=MODEL3_2,
+            messages=[
+                UserMessage(content="hello"),
+                CompletionMessage(
+                    content="",
+                    stop_reason=StopReason.end_of_turn,
+                    tool_calls=[
+                        ToolCall(
+                            tool_name="custom1",
+                            arguments={"param1": "value1"},
+                            call_id="123",
+                        )
+                    ],
+                ),
+            ],
+            tools=[
+                ToolDefinition(
+                    tool_name="custom1",
+                    description="custom1 tool",
+                    parameters={
+                        "param1": ToolParamDefinition(
+                            param_type="str",
+                            description="param1 description",
+                            required=True,
+                        ),
+                    },
+                ),
+            ],
+            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
+        )
+        prompt = await chat_completion_request_to_prompt(request, request.model)
+        self.assertIn('[custom1(param1="value1")]', prompt)
+
+        request.model = MODEL
+        request.tool_config.tool_prompt_format = ToolPromptFormat.json
+        prompt = await chat_completion_request_to_prompt(request, request.model)
+        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+
     async def test_user_provided_system_message(self):
         content = "Hello !"
         system_prompt = "You are a pirate"

llama_stack/providers/utils/inference/prompt_adapter.py

Lines changed: 6 additions & 2 deletions
@@ -252,7 +252,9 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
     request = await convert_request_to_raw(request)
 
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
-    model_input = formatter.encode_dialog_prompt(request.messages)
+    model_input = formatter.encode_dialog_prompt(
+        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+    )
     return formatter.tokenizer.decode(model_input.tokens)
 
 
@@ -264,7 +266,9 @@ async def chat_completion_request_to_model_input_info(
     request = await convert_request_to_raw(request)
 
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
-    model_input = formatter.encode_dialog_prompt(request.messages)
+    model_input = formatter.encode_dialog_prompt(
+        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+    )
     return (
         formatter.tokenizer.decode(model_input.tokens),
         len(model_input.tokens),
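A hedged, self-contained driver for the patched helper, condensed from the new unit test above. Two assumptions are flagged inline: the import path for ToolDefinition/ToolParamDefinition/ToolPromptFormat (the test's second import block is truncated in the diff) and the concrete model string behind the test's MODEL3_2 constant. Everything else mirrors the test.

import asyncio

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    CompletionMessage,
    StopReason,
    ToolCall,
    ToolConfig,
    UserMessage,
)
# Assumed import location; the test's second import block is truncated above.
from llama_stack.models.llama.datatypes import (
    ToolDefinition,
    ToolParamDefinition,
    ToolPromptFormat,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)


async def main() -> None:
    request = ChatCompletionRequest(
        model="Llama3.2-3B-Instruct",  # assumed stand-in for the test's MODEL3_2
        messages=[
            UserMessage(content="hello"),
            # The assistant turn whose tool_calls rendering the fix corrects.
            CompletionMessage(
                content="",
                stop_reason=StopReason.end_of_turn,
                tool_calls=[
                    ToolCall(
                        tool_name="custom1",
                        arguments={"param1": "value1"},
                        call_id="123",
                    )
                ],
            ),
        ],
        tools=[
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                parameters={
                    "param1": ToolParamDefinition(
                        param_type="str",
                        description="param1 description",
                        required=True,
                    ),
                },
            ),
        ],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
    )
    # Expect the tool call rendered as [custom1(param1="value1")]
    print(await chat_completion_request_to_prompt(request, request.model))

    # Same dialog re-rendered in JSON format on the 3.1 model, as in the test.
    # Expect {"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}
    request.model = "Llama3.1-8B-Instruct"  # the test's MODEL constant
    request.tool_config.tool_prompt_format = ToolPromptFormat.json
    print(await chat_completion_request_to_prompt(request, request.model))


if __name__ == "__main__":
    asyncio.run(main())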
