Skip to content

Commit 02e989e

Browse files
Pouyanpitgasser-nv
authored andcommitted
fix(logging): correct message type formatting in logs (#1416)
* refactor(logging): replace if-else with dict for type mapping
1 parent ea15382 commit 02e989e

File tree

3 files changed

+59
-11
lines changed

3 files changed

+59
-11
lines changed

nemoguardrails/logging/callbacks.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,18 +101,17 @@ async def on_chat_model_start(
101101
if explain_info:
102102
explain_info.llm_calls.append(llm_call_info)
103103

104+
type_map = {
105+
"human": "User",
106+
"ai": "Bot",
107+
"tool": "Tool",
108+
"system": "System",
109+
"developer": "Developer",
110+
}
104111
prompt = "\n" + "\n".join(
105112
[
106113
"[cyan]"
107-
+ (
108-
"User"
109-
if msg.type == "human"
110-
else "Bot"
111-
if msg.type == "ai"
112-
else "Tool"
113-
if msg.type == "tool"
114-
else "System"
115-
)
114+
+ type_map.get(msg.type, msg.type.title())
116115
+ "[/]"
117116
+ "\n"
118117
+ (msg.content if isinstance(msg.content, str) else "")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ pyright = "^1.1.405"
155155

156156
# Directories in which to run Pyright type-checking
157157
[tool.pyright]
158-
include = ["nemoguardrails/rails/**"]
158+
include = ["nemoguardrails/rails/**", "tests/test_callbacks.py"]
159159

160160
[tool.poetry.group.docs]
161161
optional = true

tests/test_callbacks.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@
1818

1919
import pytest
2020
from langchain.schema import Generation, LLMResult
21-
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
21+
from langchain_core.messages import (
22+
AIMessage,
23+
BaseMessage,
24+
HumanMessage,
25+
SystemMessage,
26+
ToolMessage,
27+
)
2228
from langchain_core.outputs import ChatGeneration
2329

2430
from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var
@@ -212,3 +218,46 @@ async def test_tool_message_labeling_in_logging():
212218
assert "[cyan]Bot[/]" in logged_prompt
213219
assert "[cyan]System[/]" in logged_prompt
214220
assert "[cyan]Tool[/]" in logged_prompt
221+
222+
223+
@pytest.mark.asyncio
224+
async def test_unknown_message_type_labeling():
225+
"""Test that unknown message types display their actual type name."""
226+
llm_call_info = LLMCallInfo()
227+
llm_call_info_var.set(llm_call_info)
228+
229+
llm_stats = LLMStats()
230+
llm_stats_var.set(llm_stats)
231+
232+
explain_info = ExplainInfo()
233+
explain_info_var.set(explain_info)
234+
235+
handler = LoggingCallbackHandler()
236+
237+
class CustomMessage(BaseMessage):
238+
def __init__(self, content, msg_type):
239+
super().__init__(content=content, type=msg_type)
240+
241+
messages: list[BaseMessage] = [
242+
CustomMessage("Custom message", "custom"),
243+
CustomMessage("Function message", "function"),
244+
]
245+
246+
with patch("nemoguardrails.logging.callbacks.log") as mock_log:
247+
await handler.on_chat_model_start(
248+
serialized={},
249+
messages=[messages],
250+
run_id=uuid4(),
251+
)
252+
253+
mock_log.info.assert_called()
254+
255+
logged_prompt = None
256+
for call in mock_log.info.call_args_list:
257+
if "Prompt Messages" in str(call):
258+
logged_prompt = call[0][1]
259+
break
260+
261+
assert logged_prompt is not None
262+
assert "[cyan]Custom[/]" in logged_prompt
263+
assert "[cyan]Function[/]" in logged_prompt

0 commit comments

Comments
 (0)