|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +from unittest.mock import patch |
16 | 17 | from uuid import uuid4 |
17 | 18 |
|
18 | 19 | import pytest |
19 | 20 | from langchain.schema import Generation, LLMResult |
20 | | -from langchain_core.messages import AIMessage |
| 21 | +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage |
21 | 22 | from langchain_core.outputs import ChatGeneration |
22 | 23 |
|
23 | 24 | from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var |
@@ -168,3 +169,46 @@ async def test_multiple_generations_token_accumulation(): |
168 | 169 | assert llm_stats.get_stat("total_tokens") == 19 |
169 | 170 | assert llm_stats.get_stat("total_prompt_tokens") == 12 |
170 | 171 | assert llm_stats.get_stat("total_completion_tokens") == 7 |
| 172 | + |
| 173 | + |
| 174 | +@pytest.mark.asyncio |
| 175 | +async def test_tool_message_labeling_in_logging(): |
| 176 | + """Test that tool messages are labeled as 'Tool' in logging output.""" |
| 177 | + llm_call_info = LLMCallInfo() |
| 178 | + llm_call_info_var.set(llm_call_info) |
| 179 | + |
| 180 | + llm_stats = LLMStats() |
| 181 | + llm_stats_var.set(llm_stats) |
| 182 | + |
| 183 | + explain_info = ExplainInfo() |
| 184 | + explain_info_var.set(explain_info) |
| 185 | + |
| 186 | + handler = LoggingCallbackHandler() |
| 187 | + |
| 188 | + messages = [ |
| 189 | + HumanMessage(content="Hello"), |
| 190 | + AIMessage(content="Hi there"), |
| 191 | + SystemMessage(content="System message"), |
| 192 | + ToolMessage(content="Tool result", tool_call_id="test_tool_call"), |
| 193 | + ] |
| 194 | + |
| 195 | + with patch("nemoguardrails.logging.callbacks.log") as mock_log: |
| 196 | + await handler.on_chat_model_start( |
| 197 | + serialized={}, |
| 198 | + messages=[messages], |
| 199 | + run_id=uuid4(), |
| 200 | + ) |
| 201 | + |
| 202 | + mock_log.info.assert_called() |
| 203 | + |
| 204 | + logged_prompt = None |
| 205 | + for call in mock_log.info.call_args_list: |
| 206 | + if "Prompt Messages" in str(call): |
| 207 | + logged_prompt = call[0][1] |
| 208 | + break |
| 209 | + |
| 210 | + assert logged_prompt is not None |
| 211 | + assert "[cyan]User[/]" in logged_prompt |
| 212 | + assert "[cyan]Bot[/]" in logged_prompt |
| 213 | + assert "[cyan]System[/]" in logged_prompt |
| 214 | + assert "[cyan]Tool[/]" in logged_prompt |
0 commit comments