Skip to content

Commit 5b43c7e

Browse files
committed
fix(logging): correct message type formatting in logs
add test fix
1 parent 08629b5 commit 5b43c7e

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

nemoguardrails/logging/callbacks.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ async def on_chat_model_start(
112112
else "Tool"
113113
if msg.type == "tool"
114114
else "System"
115+
if msg.type in {"system", "developer"}
116+
else msg.type.title()
115117
)
116118
+ "[/]"
117119
+ "\n"

tests/test_callbacks.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@
1818

1919
import pytest
2020
from langchain.schema import Generation, LLMResult
21-
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
21+
from langchain_core.messages import (
22+
AIMessage,
23+
BaseMessage,
24+
HumanMessage,
25+
SystemMessage,
26+
ToolMessage,
27+
)
2228
from langchain_core.outputs import ChatGeneration
2329

2430
from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var
@@ -212,3 +218,46 @@ async def test_tool_message_labeling_in_logging():
212218
assert "[cyan]Bot[/]" in logged_prompt
213219
assert "[cyan]System[/]" in logged_prompt
214220
assert "[cyan]Tool[/]" in logged_prompt
221+
222+
223+
@pytest.mark.asyncio
224+
async def test_unknown_message_type_labeling():
225+
"""Test that unknown message types display their actual type name."""
226+
llm_call_info = LLMCallInfo()
227+
llm_call_info_var.set(llm_call_info)
228+
229+
llm_stats = LLMStats()
230+
llm_stats_var.set(llm_stats)
231+
232+
explain_info = ExplainInfo()
233+
explain_info_var.set(explain_info)
234+
235+
handler = LoggingCallbackHandler()
236+
237+
class CustomMessage(BaseMessage):
238+
def __init__(self, content, msg_type):
239+
super().__init__(content=content, type=msg_type)
240+
241+
messages: list[BaseMessage] = [
242+
CustomMessage("Custom message", "custom"),
243+
CustomMessage("Function message", "function"),
244+
]
245+
246+
with patch("nemoguardrails.logging.callbacks.log") as mock_log:
247+
await handler.on_chat_model_start(
248+
serialized={},
249+
messages=[messages],
250+
run_id=uuid4(),
251+
)
252+
253+
mock_log.info.assert_called()
254+
255+
logged_prompt = None
256+
for call in mock_log.info.call_args_list:
257+
if "Prompt Messages" in str(call):
258+
logged_prompt = call[0][1]
259+
break
260+
261+
assert logged_prompt is not None
262+
assert "[cyan]Custom[/]" in logged_prompt
263+
assert "[cyan]Function[/]" in logged_prompt

0 commit comments

Comments
 (0)