Skip to content

Commit 4f5ca01

Browse files
Pouyanpitgasser-nv
authored and committed
fix(logging): add "Tool" type to message sender labeling (#1412)
Previously, messages of type "tool" were not distinctly labeled in the LoggingCallbackHandler output, causing them to be grouped under "System". This change adds explicit handling for "tool" messages, labeling them as "Tool" in the logs for improved clarity and debugging.
1 parent 66958e3 commit 4f5ca01

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

nemoguardrails/logging/callbacks.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -109,6 +109,8 @@ async def on_chat_model_start(
109109
if msg.type == "human"
110110
else "Bot"
111111
if msg.type == "ai"
112+
else "Tool"
113+
if msg.type == "tool"
112114
else "System"
113115
)
114116
+ "[/]"

tests/test_callbacks.py

Lines changed: 45 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
from unittest.mock import patch
1617
from uuid import uuid4
1718

1819
import pytest
1920
from langchain.schema import Generation, LLMResult
20-
from langchain_core.messages import AIMessage
21+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
2122
from langchain_core.outputs import ChatGeneration
2223

2324
from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var
@@ -168,3 +169,46 @@ async def test_multiple_generations_token_accumulation():
168169
assert llm_stats.get_stat("total_tokens") == 19
169170
assert llm_stats.get_stat("total_prompt_tokens") == 12
170171
assert llm_stats.get_stat("total_completion_tokens") == 7
172+
173+
174+
@pytest.mark.asyncio
175+
async def test_tool_message_labeling_in_logging():
176+
"""Test that tool messages are labeled as 'Tool' in logging output."""
177+
llm_call_info = LLMCallInfo()
178+
llm_call_info_var.set(llm_call_info)
179+
180+
llm_stats = LLMStats()
181+
llm_stats_var.set(llm_stats)
182+
183+
explain_info = ExplainInfo()
184+
explain_info_var.set(explain_info)
185+
186+
handler = LoggingCallbackHandler()
187+
188+
messages = [
189+
HumanMessage(content="Hello"),
190+
AIMessage(content="Hi there"),
191+
SystemMessage(content="System message"),
192+
ToolMessage(content="Tool result", tool_call_id="test_tool_call"),
193+
]
194+
195+
with patch("nemoguardrails.logging.callbacks.log") as mock_log:
196+
await handler.on_chat_model_start(
197+
serialized={},
198+
messages=[messages],
199+
run_id=uuid4(),
200+
)
201+
202+
mock_log.info.assert_called()
203+
204+
logged_prompt = None
205+
for call in mock_log.info.call_args_list:
206+
if "Prompt Messages" in str(call):
207+
logged_prompt = call[0][1]
208+
break
209+
210+
assert logged_prompt is not None
211+
assert "[cyan]User[/]" in logged_prompt
212+
assert "[cyan]Bot[/]" in logged_prompt
213+
assert "[cyan]System[/]" in logged_prompt
214+
assert "[cyan]Tool[/]" in logged_prompt

0 commit comments

Comments
 (0)