Skip to content

Commit 9a8c910

Browse files
committed
fix(logging): correct message type formatting in logs (#1416)
* refactor(logging): replace if-else with dict for type mapping
1 parent 9817123 commit 9a8c910

File tree

3 files changed

+104
-8
lines changed

3 files changed

+104
-8
lines changed

nemoguardrails/logging/callbacks.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,17 @@ async def on_chat_model_start(
101101
if explain_info:
102102
explain_info.llm_calls.append(llm_call_info)
103103

104+
type_map = {
105+
"human": "User",
106+
"ai": "Bot",
107+
"tool": "Tool",
108+
"system": "System",
109+
"developer": "Developer",
110+
}
104111
prompt = "\n" + "\n".join(
105112
[
106113
"[cyan]"
107-
+ (
108-
"User"
109-
if msg.type == "human"
110-
else "Bot"
111-
if msg.type == "ai"
112-
else "System"
113-
)
114+
+ type_map.get(msg.type, msg.type.title())
114115
+ "[/]"
115116
+ "\n"
116117
+ (msg.content if isinstance(msg.content, str) else "")

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ yara-python = "^4.5.1"
152152
opentelemetry-api = "^1.34.1"
153153
opentelemetry-sdk = "^1.34.1"
154154

155+
# Directories in which to run Pyright type-checking
156+
[tool.pyright]
157+
include = ["nemoguardrails/rails/**", "tests/test_callbacks.py"]
155158

156159
[tool.poetry.group.docs]
157160
optional = true

tests/test_callbacks.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717

1818
import pytest
1919
from langchain.schema import Generation, LLMResult
20-
from langchain_core.messages import AIMessage
20+
from langchain_core.messages import (
21+
AIMessage,
22+
BaseMessage,
23+
HumanMessage,
24+
SystemMessage,
25+
ToolMessage,
26+
)
2127
from langchain_core.outputs import ChatGeneration
2228

2329
from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var
@@ -168,3 +174,89 @@ async def test_multiple_generations_token_accumulation():
168174
assert llm_stats.get_stat("total_tokens") == 19
169175
assert llm_stats.get_stat("total_prompt_tokens") == 12
170176
assert llm_stats.get_stat("total_completion_tokens") == 7
177+
178+
179+
@pytest.mark.asyncio
async def test_tool_message_labeling_in_logging():
    """Test that tool messages are labeled as 'Tool' in logging output."""
    # Install fresh per-call context state so the handler has somewhere
    # to record the LLM call and its stats.
    call_info = LLMCallInfo()
    llm_call_info_var.set(call_info)

    stats = LLMStats()
    llm_stats_var.set(stats)

    explanation = ExplainInfo()
    explain_info_var.set(explanation)

    handler = LoggingCallbackHandler()

    # One message of each mapped type, including a tool result.
    conversation = [
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there"),
        SystemMessage(content="System message"),
        ToolMessage(content="Tool result", tool_call_id="test_tool_call"),
    ]

    with patch("nemoguardrails.logging.callbacks.log") as mock_log:
        await handler.on_chat_model_start(
            serialized={},
            messages=[conversation],
            run_id=uuid4(),
        )

    mock_log.info.assert_called()

    # Pull the rendered prompt out of the "Prompt Messages" log call.
    prompt_text = next(
        (
            logged[0][1]
            for logged in mock_log.info.call_args_list
            if "Prompt Messages" in str(logged)
        ),
        None,
    )

    assert prompt_text is not None
    # Each message type must be rendered with its human-readable label.
    for label in ("User", "Bot", "System", "Tool"):
        assert f"[cyan]{label}[/]" in prompt_text
220+
221+
222+
@pytest.mark.asyncio
async def test_unknown_message_type_labeling():
    """Test that unknown message types display their actual type name."""
    # Fresh per-call context state for the logging handler.
    llm_call_info_var.set(LLMCallInfo())
    llm_stats_var.set(LLMStats())
    explain_info_var.set(ExplainInfo())

    handler = LoggingCallbackHandler()

    class CustomMessage(BaseMessage):
        # Minimal BaseMessage subclass letting the test choose an
        # arbitrary message ``type`` string.
        def __init__(self, content, msg_type):
            super().__init__(content=content, type=msg_type)

    unknown_messages: list[BaseMessage] = [
        CustomMessage("Custom message", "custom"),
        CustomMessage("Function message", "function"),
    ]

    with patch("nemoguardrails.logging.callbacks.log") as mock_log:
        await handler.on_chat_model_start(
            serialized={},
            messages=[unknown_messages],
            run_id=uuid4(),
        )

    mock_log.info.assert_called()

    # Locate the rendered prompt in the "Prompt Messages" log call.
    prompt_text = None
    for logged in mock_log.info.call_args_list:
        if "Prompt Messages" not in str(logged):
            continue
        prompt_text = logged[0][1]
        break

    assert prompt_text is not None
    # Unmapped types fall back to a title-cased version of the raw type.
    assert "[cyan]Custom[/]" in prompt_text
    assert "[cyan]Function[/]" in prompt_text

0 commit comments

Comments
 (0)