@@ -17,7 +17,13 @@
 
 import pytest
 from langchain.schema import Generation, LLMResult
-from langchain_core.messages import AIMessage
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage,
+)
 from langchain_core.outputs import ChatGeneration
 
 from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var
@@ -168,3 +174,89 @@ async def test_multiple_generations_token_accumulation():
     assert llm_stats.get_stat("total_tokens") == 19
     assert llm_stats.get_stat("total_prompt_tokens") == 12
     assert llm_stats.get_stat("total_completion_tokens") == 7
+
+
+@pytest.mark.asyncio
+async def test_tool_message_labeling_in_logging():
+    """Test that tool messages are labeled as 'Tool' in logging output."""
+    llm_call_info = LLMCallInfo()
+    llm_call_info_var.set(llm_call_info)
+
+    llm_stats = LLMStats()
+    llm_stats_var.set(llm_stats)
+
+    explain_info = ExplainInfo()
+    explain_info_var.set(explain_info)
+
+    handler = LoggingCallbackHandler()
+
+    messages = [
+        HumanMessage(content="Hello"),
+        AIMessage(content="Hi there"),
+        SystemMessage(content="System message"),
+        ToolMessage(content="Tool result", tool_call_id="test_tool_call"),
+    ]
+
+    with patch("nemoguardrails.logging.callbacks.log") as mock_log:
+        await handler.on_chat_model_start(
+            serialized={},
+            messages=[messages],
+            run_id=uuid4(),
+        )
+
+    mock_log.info.assert_called()
+
+    logged_prompt = None
+    for call in mock_log.info.call_args_list:
+        if "Prompt Messages" in str(call):
+            logged_prompt = call[0][1]
+            break
+
+    assert logged_prompt is not None
+    assert "[cyan]User[/]" in logged_prompt
+    assert "[cyan]Bot[/]" in logged_prompt
+    assert "[cyan]System[/]" in logged_prompt
+    assert "[cyan]Tool[/]" in logged_prompt
+
+
+@pytest.mark.asyncio
+async def test_unknown_message_type_labeling():
+    """Test that unknown message types display their actual type name."""
+    llm_call_info = LLMCallInfo()
+    llm_call_info_var.set(llm_call_info)
+
+    llm_stats = LLMStats()
+    llm_stats_var.set(llm_stats)
+
+    explain_info = ExplainInfo()
+    explain_info_var.set(explain_info)
+
+    handler = LoggingCallbackHandler()
+
+    class CustomMessage(BaseMessage):
+        def __init__(self, content, msg_type):
+            super().__init__(content=content, type=msg_type)
+
+    messages: list[BaseMessage] = [
+        CustomMessage("Custom message", "custom"),
+        CustomMessage("Function message", "function"),
+    ]
+
+    with patch("nemoguardrails.logging.callbacks.log") as mock_log:
+        await handler.on_chat_model_start(
+            serialized={},
+            messages=[messages],
+            run_id=uuid4(),
+        )
+
+    mock_log.info.assert_called()
+
+    logged_prompt = None
+    for call in mock_log.info.call_args_list:
+        if "Prompt Messages" in str(call):
+            logged_prompt = call[0][1]
+            break
+
+    assert logged_prompt is not None
+    assert "[cyan]Custom[/]" in logged_prompt
+    assert "[cyan]Function[/]" in logged_prompt
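
For reference, a minimal sketch of the labeling rule these tests exercise: the handler renders each prompt message as `[cyan]<Label>[/] <content>`, mapping the LangChain message `type` to a display label and falling back to the capitalized type name for unknown types. The names `KNOWN_LABELS` and `format_prompt_messages` below are illustrative assumptions, not the actual helpers in `nemoguardrails/logging/callbacks.py`:

```python
# Hypothetical sketch of the labeling behavior asserted above; the real
# implementation in nemoguardrails/logging/callbacks.py may differ.
from langchain_core.messages import BaseMessage

# Assumed mapping from LangChain message types to display labels.
KNOWN_LABELS = {"human": "User", "ai": "Bot", "system": "System", "tool": "Tool"}


def format_prompt_messages(messages: list[BaseMessage]) -> str:
    """Render each message as '[cyan]<Label>[/] <content>' for Rich-style logs."""
    lines = []
    for msg in messages:
        # Unknown types fall back to the capitalized type name,
        # e.g. "custom" -> "Custom", "function" -> "Function".
        label = KNOWN_LABELS.get(msg.type, msg.type.capitalize())
        lines.append(f"[cyan]{label}[/] {msg.content}")
    return "\n".join(lines)
```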