|
18 | 18 |
|
19 | 19 | import pytest |
20 | 20 | from langchain.schema import Generation, LLMResult |
21 | | -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage |
| 21 | +from langchain_core.messages import ( |
| 22 | + AIMessage, |
| 23 | + BaseMessage, |
| 24 | + HumanMessage, |
| 25 | + SystemMessage, |
| 26 | + ToolMessage, |
| 27 | +) |
22 | 28 | from langchain_core.outputs import ChatGeneration |
23 | 29 |
|
24 | 30 | from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var |
@@ -212,3 +218,46 @@ async def test_tool_message_labeling_in_logging(): |
212 | 218 | assert "[cyan]Bot[/]" in logged_prompt |
213 | 219 | assert "[cyan]System[/]" in logged_prompt |
214 | 220 | assert "[cyan]Tool[/]" in logged_prompt |
| 221 | + |
| 222 | + |
| 223 | +@pytest.mark.asyncio |
| 224 | +async def test_unknown_message_type_labeling(): |
| 225 | + """Test that unknown message types display their actual type name.""" |
| 226 | + llm_call_info = LLMCallInfo() |
| 227 | + llm_call_info_var.set(llm_call_info) |
| 228 | + |
| 229 | + llm_stats = LLMStats() |
| 230 | + llm_stats_var.set(llm_stats) |
| 231 | + |
| 232 | + explain_info = ExplainInfo() |
| 233 | + explain_info_var.set(explain_info) |
| 234 | + |
| 235 | + handler = LoggingCallbackHandler() |
| 236 | + |
| 237 | + class CustomMessage(BaseMessage): |
| 238 | + def __init__(self, content, msg_type): |
| 239 | + super().__init__(content=content, type=msg_type) |
| 240 | + |
| 241 | + messages: list[BaseMessage] = [ |
| 242 | + CustomMessage("Custom message", "custom"), |
| 243 | + CustomMessage("Function message", "function"), |
| 244 | + ] |
| 245 | + |
| 246 | + with patch("nemoguardrails.logging.callbacks.log") as mock_log: |
| 247 | + await handler.on_chat_model_start( |
| 248 | + serialized={}, |
| 249 | + messages=[messages], |
| 250 | + run_id=uuid4(), |
| 251 | + ) |
| 252 | + |
| 253 | + mock_log.info.assert_called() |
| 254 | + |
| 255 | + logged_prompt = None |
| 256 | + for call in mock_log.info.call_args_list: |
| 257 | + if "Prompt Messages" in str(call): |
| 258 | + logged_prompt = call[0][1] |
| 259 | + break |
| 260 | + |
| 261 | + assert logged_prompt is not None |
| 262 | + assert "[cyan]Custom[/]" in logged_prompt |
| 263 | + assert "[cyan]Function[/]" in logged_prompt |
0 commit comments