@@ -8,7 +8,10 @@
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
+    CompletionMessage,
+    StopReason,
     SystemMessage,
+    ToolCall,
     ToolConfig,
     UserMessage,
 )
@@ -20,6 +23,7 @@
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
+    chat_completion_request_to_prompt,
 )
 
 MODEL = "Llama3.1-8B-Instruct"
@@ -119,6 +123,49 @@ async def test_system_custom_and_builtin(self):
         self.assertTrue("Return function calls in JSON format" in messages[1].content)
         self.assertEqual(messages[-1].content, content)
 
+    async def test_completion_message_encoding(self):
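+        # A chat that ends in an assistant turn carrying a tool call; encoding
+        # it back into a prompt must serialize the ToolCall.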
+        request = ChatCompletionRequest(
+            model=MODEL3_2,
+            messages=[
+                UserMessage(content="hello"),
+                CompletionMessage(
+                    content="",
+                    stop_reason=StopReason.end_of_turn,
+                    tool_calls=[
+                        ToolCall(
+                            tool_name="custom1",
+                            arguments={"param1": "value1"},
+                            call_id="123",
+                        )
+                    ],
+                ),
+            ],
+            tools=[
+                ToolDefinition(
+                    tool_name="custom1",
+                    description="custom1 tool",
+                    parameters={
+                        "param1": ToolParamDefinition(
+                            param_type="str",
+                            description="param1 description",
+                            required=True,
+                        ),
+                    },
+                ),
+            ],
+            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
+        )
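+        # With the 3.2 model and python_list format, the call is rendered
+        # as a Python-style function call.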
+        prompt = await chat_completion_request_to_prompt(request, request.model)
+        self.assertIn('[custom1(param1="value1")]', prompt)
+
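+        # The same request with the 3.1 model and JSON format encodes the
+        # call as a JSON object.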
+        request.model = MODEL
+        request.tool_config.tool_prompt_format = ToolPromptFormat.json
+        prompt = await chat_completion_request_to_prompt(request, request.model)
+        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+
     async def test_user_provided_system_message(self):
         content = "Hello !"
         system_prompt = "You are a pirate"