From 6f746ab16236e7edfebdee583c7b1d103c9b474f Mon Sep 17 00:00:00 2001
From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com>
Date: Mon, 15 Sep 2025 10:40:54 +0200
Subject: [PATCH] feat(tool-calling): add tool call passthrough support in
 LLMRails

Implements tool call extraction and passthrough functionality in LLMRails:

- Add tool_calls_var context variable for storing LLM tool calls
- Refactor llm_call utils to extract and store tool calls from responses
- Support tool calls in both GenerationResponse and dict message formats
- Add ToolMessage support for langchain message conversion
- Comprehensive test coverage for tool calling integration
---
 nemoguardrails/actions/llm/utils.py           | 149 +++++---
 nemoguardrails/context.py                     |   5 +
 nemoguardrails/rails/llm/llmrails.py          |   8 +
 nemoguardrails/rails/llm/options.py           |   4 +
 tests/rails/llm/test_options.py               |  85 ++++-
 ...st_tool_calling_passthrough_integration.py | 360 ++++++++++++++++++
 tests/test_tool_calling_utils.py              | 252 ++++++++++++
 tests/test_tool_calls_context.py              |  71 ++++
 8 files changed, 890 insertions(+), 44 deletions(-)
 create mode 100644 tests/test_tool_calling_passthrough_integration.py
 create mode 100644 tests/test_tool_calling_utils.py
 create mode 100644 tests/test_tool_calls_context.py

diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py
index 7b80d9d37..7b4f8dce1 100644
--- a/nemoguardrails/actions/llm/utils.py
+++ b/nemoguardrails/actions/llm/utils.py
@@ -20,11 +20,15 @@
 from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
-from langchain.schema import AIMessage, HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
 
 from nemoguardrails.colang.v2_x.lang.colang_ast import Flow
 from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
-from nemoguardrails.context import llm_call_info_var, reasoning_trace_var
+from nemoguardrails.context import (
+    llm_call_info_var,
+    reasoning_trace_var,
+    tool_calls_var,
+)
 from nemoguardrails.logging.callbacks import logging_callbacks
 from nemoguardrails.logging.explain import LLMCallInfo
 
@@ -72,7 +76,22 @@ async def llm_call(
     custom_callback_handlers: Optional[List[AsyncCallbackHandler]] = None,
 ) -> str:
     """Calls the LLM with a prompt and returns the generated text."""
-    # We initialize a new LLM call if we don't have one already
+    _setup_llm_call_info(llm, model_name, model_provider)
+    all_callbacks = _prepare_callbacks(custom_callback_handlers)
+
+    if isinstance(prompt, str):
+        response = await _invoke_with_string_prompt(llm, prompt, all_callbacks, stop)
+    else:
+        response = await _invoke_with_message_list(llm, prompt, all_callbacks, stop)
+
+    _store_tool_calls(response)
+    return _extract_content(response)
+
+
+def _setup_llm_call_info(
+    llm: BaseLanguageModel, model_name: Optional[str], model_provider: Optional[str]
+) -> None:
+    """Initialize or update LLM call info in context."""
     llm_call_info = llm_call_info_var.get()
     if llm_call_info is None:
         llm_call_info = LLMCallInfo()
@@ -81,52 +100,84 @@ async def llm_call(
     llm_call_info.llm_model_name = model_name or _infer_model_name(llm)
     llm_call_info.llm_provider_name = model_provider
 
+
+def _prepare_callbacks(
+    custom_callback_handlers: Optional[List[AsyncCallbackHandler]],
+) -> BaseCallbackManager:
+    """Prepare callback manager with custom handlers if provided."""
     if custom_callback_handlers and custom_callback_handlers != [None]:
-        all_callbacks = BaseCallbackManager(
+        return BaseCallbackManager(
             handlers=logging_callbacks.handlers + custom_callback_handlers,
             inheritable_handlers=logging_callbacks.handlers + custom_callback_handlers,
         )
-    else:
-        all_callbacks = logging_callbacks
+    return logging_callbacks
 
-    if isinstance(prompt, str):
-        # stop sinks here
-        try:
-            result = await llm.agenerate_prompt(
-                [StringPromptValue(text=prompt)], callbacks=all_callbacks, stop=stop
+
+async def _invoke_with_string_prompt(
+    llm: BaseLanguageModel,
+    prompt: str,
+    callbacks: BaseCallbackManager,
+    stop: Optional[List[str]],
+):
+    """Invoke LLM with string prompt."""
+    try:
+        return await llm.ainvoke(prompt, config={"callbacks": callbacks}, stop=stop)
+    except Exception as e:
+        raise LLMCallException(e)
+
+
+async def _invoke_with_message_list(
+    llm: BaseLanguageModel,
+    prompt: List[dict],
+    callbacks: BaseCallbackManager,
+    stop: Optional[List[str]],
+):
+    """Invoke LLM with message list after converting to LangChain format."""
+    messages = _convert_messages_to_langchain_format(prompt)
+    try:
+        return await llm.ainvoke(
+            messages, config={"callbacks": callbacks}, stop=stop
+        )
+    except Exception as e:
+        raise LLMCallException(e)
+
+
+def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
+    """Convert message list to LangChain message format."""
+    messages = []
+    for msg in prompt:
+        msg_type = msg["type"] if "type" in msg else msg["role"]
+
+        if msg_type == "user":
+            messages.append(HumanMessage(content=msg["content"]))
+        elif msg_type in ["bot", "assistant"]:
+            messages.append(AIMessage(content=msg["content"]))
+        elif msg_type == "system":
+            messages.append(SystemMessage(content=msg["content"]))
+        elif msg_type == "tool":
+            messages.append(
+                ToolMessage(
+                    content=msg["content"],
+                    tool_call_id=msg.get("tool_call_id", ""),
+                )
             )
-        except Exception as e:
-            raise LLMCallException(e)
-        llm_call_info.raw_response = result.llm_output
+        else:
+            raise ValueError(f"Unknown message type {msg_type}")
 
-        # TODO: error handling
-        return result.generations[0][0].text
-    else:
-        # We first need to translate the array of messages into LangChain message format
-        messages = []
-        for _msg in prompt:
-            msg_type = _msg["type"] if "type" in _msg else _msg["role"]
-            if msg_type == "user":
-                messages.append(HumanMessage(content=_msg["content"]))
-            elif msg_type in ["bot", "assistant"]:
-                messages.append(AIMessage(content=_msg["content"]))
-            elif msg_type == "system":
-                messages.append(SystemMessage(content=_msg["content"]))
-            else:
-                # TODO: add support for tool-related messages
-                raise ValueError(f"Unknown message type {msg_type}")
+    return messages
 
-        try:
-            result = await llm.agenerate_prompt(
-                [ChatPromptValue(messages=messages)], callbacks=all_callbacks, stop=stop
-            )
-        except Exception as e:
-            raise LLMCallException(e)
 
+def _store_tool_calls(response) -> None:
+    """Extract and store tool calls from response in context."""
+    tool_calls = getattr(response, "tool_calls", None)
+    tool_calls_var.set(tool_calls)
 
-        llm_call_info.raw_response = result.llm_output
-        return result.generations[0][0].text
+
+def _extract_content(response) -> str:
+    """Extract text content from response."""
+    if hasattr(response, "content"):
+        return response.content
+    return str(response)
 
 
 def get_colang_history(
@@ -175,15 +226,15 @@
             history += f'user "{event["text"]}"\n'
         elif event["type"] == "UserIntent":
             if include_texts:
-                history += f'  {event["intent"]}\n'
+                history += f"  {event['intent']}\n"
             else:
-                history += f'user {event["intent"]}\n'
+                history += f"user {event['intent']}\n"
         elif event["type"] == "BotIntent":
             # If we have instructions, we add them before the bot message.
             # But we only do that for the last bot message.
             if "instructions" in event and idx == last_bot_intent_idx:
                 history += f"# {event['instructions']}\n"
-            history += f'bot {event["intent"]}\n'
+            history += f"bot {event['intent']}\n"
         elif event["type"] == "StartUtteranceBotAction" and include_texts:
             history += f'  "{event["script"]}"\n'
         # We skip system actions from this log
@@ -352,9 +403,9 @@ def flow_to_colang(flow: Union[dict, Flow]) -> str:
         if "_type" not in element:
             raise Exception("bla")
         if element["_type"] == "UserIntent":
-            colang_flow += f'user {element["intent_name"]}\n'
+            colang_flow += f"user {element['intent_name']}\n"
         elif element["_type"] == "run_action" and element["action_name"] == "utter":
-            colang_flow += f'bot {element["action_params"]["value"]}\n'
+            colang_flow += f"bot {element['action_params']['value']}\n"
 
     return colang_flow
@@ -592,3 +643,15 @@ def get_and_clear_reasoning_trace_contextvar() -> Optional[str]:
         reasoning_trace_var.set(None)
         return reasoning_trace
     return None
+
+
+def get_and_clear_tool_calls_contextvar() -> Optional[list]:
+    """Get the current tool calls and clear them from the context.
+
+    Returns:
+        Optional[list]: The tool calls if they exist, None otherwise.
+    """
+    if tool_calls := tool_calls_var.get():
+        tool_calls_var.set(None)
+        return tool_calls
+    return None
diff --git a/nemoguardrails/context.py b/nemoguardrails/context.py
index e66f1a0d5..ff6a3a2a5 100644
--- a/nemoguardrails/context.py
+++ b/nemoguardrails/context.py
@@ -37,3 +37,8 @@
 reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
     "reasoning_trace", default=None
 )
+
+# The tool calls from the current LLM response.
+tool_calls_var: contextvars.ContextVar[Optional[list]] = contextvars.ContextVar(
+    "tool_calls", default=None
+)
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index 0027b7fc5..c05f3ddf9 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -33,6 +33,7 @@
 from nemoguardrails.actions.llm.generation import LLMGenerationActions
 from nemoguardrails.actions.llm.utils import (
     get_and_clear_reasoning_trace_contextvar,
+    get_and_clear_tool_calls_contextvar,
     get_colang_history,
 )
 from nemoguardrails.actions.output_mapping import is_output_blocked
@@ -1084,6 +1085,8 @@
             options.log.llm_calls = True
             options.log.internal_events = True
 
+        tool_calls = get_and_clear_tool_calls_contextvar()
+
         # If we have generation options, we prepare a GenerationResponse instance.
         if options:
             # If a prompt was used, we only need to return the content of the message.
@@ -1100,6 +1103,9 @@
                     reasoning_trace + res.response[0]["content"]
                 )
 
+            if tool_calls:
+                res.tool_calls = tool_calls
+
             if self.config.colang_version == "1.0":
                 # If output variables are specified, we extract their values
                 if options.output_vars:
@@ -1228,6 +1234,8 @@
             if prompt:
                 return new_message["content"]
             else:
+                if tool_calls:
+                    new_message["tool_calls"] = tool_calls
                 return new_message
 
     def stream_async(
diff --git a/nemoguardrails/rails/llm/options.py b/nemoguardrails/rails/llm/options.py
index 51c712f03..40decabbd 100644
--- a/nemoguardrails/rails/llm/options.py
+++ b/nemoguardrails/rails/llm/options.py
@@ -408,6 +408,10 @@ class GenerationResponse(BaseModel):
         default=None,
         description="A state object which can be used in subsequent calls to continue the interaction.",
     )
+    tool_calls: Optional[list] = Field(
+        default=None,
+        description="Tool calls extracted from the LLM response, if any.",
+    )
 
 
 if __name__ == "__main__":
diff --git a/tests/rails/llm/test_options.py b/tests/rails/llm/test_options.py
index a2c99742d..d7f575acd 100644
--- a/tests/rails/llm/test_options.py
+++ b/tests/rails/llm/test_options.py
@@ -15,7 +15,11 @@
 
 import pytest
 
-from nemoguardrails.rails.llm.options import GenerationOptions, GenerationRailsOptions
+from nemoguardrails.rails.llm.options import (
+    GenerationOptions,
+    GenerationRailsOptions,
+    GenerationResponse,
+)
 
 
 def test_generation_options_initialization():
@@ -110,3 +114,82 @@ def test_generation_options_serialization():
     assert '"output":false' in options_json
     assert '"activated_rails":true' in options_json
     assert '"llm_calls":true' in options_json
+
+
+def test_generation_response_initialization():
+    """Test GenerationResponse initialization."""
+    response = GenerationResponse(response="Hello, world!")
+    assert response.response == "Hello, world!"
+    assert response.tool_calls is None
+    assert response.llm_output is None
+    assert response.state is None
+
+
+def test_generation_response_with_tool_calls():
+    test_tool_calls = [
+        {
+            "name": "get_weather",
+            "args": {"location": "NYC"},
+            "id": "call_123",
+            "type": "tool_call",
+        },
+        {
+            "name": "calculate",
+            "args": {"expression": "2+2"},
+            "id": "call_456",
+            "type": "tool_call",
+        },
+    ]
+
+    response = GenerationResponse(
+        response=[{"role": "assistant", "content": "I'll help you with that."}],
+        tool_calls=test_tool_calls,
+    )
+
+    assert response.tool_calls == test_tool_calls
+    assert len(response.tool_calls) == 2
+    assert response.tool_calls[0]["id"] == "call_123"
+    assert response.tool_calls[1]["name"] == "calculate"
+
+
+def test_generation_response_empty_tool_calls():
+    """Test GenerationResponse with empty tool calls list."""
+    response = GenerationResponse(response="No tools needed", tool_calls=[])
+
+    assert response.tool_calls == []
+    assert len(response.tool_calls) == 0
+
+
+def test_generation_response_serialization_with_tool_calls():
+    test_tool_calls = [
+        {"name": "test_func", "args": {}, "id": "call_test", "type": "tool_call"}
+    ]
+
+    response = GenerationResponse(response="Response text", tool_calls=test_tool_calls)
+
+    response_dict = response.dict()
+    assert "tool_calls" in response_dict
+    assert response_dict["tool_calls"] == test_tool_calls
+
+    response_json = response.json()
+    assert "tool_calls" in response_json
+    assert "test_func" in response_json
+
+
+def test_generation_response_model_validation():
+    """Test GenerationResponse model validation."""
+    test_tool_calls = [
+        {"name": "valid_function", "args": {}, "id": "call_123", "type": "tool_call"},
+        {"name": "another_function", "args": {}, "id": "call_456", "type": "tool_call"},
+    ]
+
+    response = GenerationResponse(
+        response=[{"role": "assistant", "content": "Test response"}],
+        tool_calls=test_tool_calls,
+        llm_output={"token_usage": {"total_tokens": 50}},
+    )
+
+    assert response.tool_calls is not None
+    assert isinstance(response.tool_calls, list)
+    assert len(response.tool_calls) == 2
+    assert response.llm_output["token_usage"]["total_tokens"] == 50
diff --git a/tests/test_tool_calling_passthrough_integration.py b/tests/test_tool_calling_passthrough_integration.py
new file mode 100644
index 000000000..ca1689b97
--- /dev/null
+++ b/tests/test_tool_calling_passthrough_integration.py
@@ -0,0 +1,360 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import patch
+
+import pytest
+
+from nemoguardrails import RailsConfig
+from nemoguardrails.context import tool_calls_var
+from nemoguardrails.rails.llm.llmrails import GenerationOptions, GenerationResponse
+from tests.utils import TestChat
+
+
+class TestToolCallingPassthroughIntegration:
+    def setup_method(self):
+        self.passthrough_config = RailsConfig.from_content(
+            colang_content="",
+            yaml_content="""
+            models: []
+            passthrough: true
+            """,
+        )
+
+    @pytest.mark.asyncio
+    async def test_tool_calls_work_in_passthrough_mode_with_options(self):
+        test_tool_calls = [
+            {
+                "name": "get_weather",
+                "args": {"location": "NYC"},
+                "id": "call_123",
+                "type": "tool_call",
+            },
+            {
+                "name": "calculate",
+                "args": {"a": 2, "b": 2},
+                "id": "call_456",
+                "type": "tool_call",
+            },
+        ]
+
+        with patch(
+            "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar"
+        ) as mock_get_clear:
+            mock_get_clear.return_value = test_tool_calls
+
+            chat = TestChat(
+                self.passthrough_config,
+                llm_completions=["I'll help you with the weather and calculation."],
+            )
+
+            result = await chat.app.generate_async(
+                messages=[
+                    {
+                        "role": "user",
+                        "content": "What's the weather in NYC and what's 2+2?",
+                    }
+                ],
+                options=GenerationOptions(),
+            )
+
+            assert isinstance(result, GenerationResponse)
+            assert result.tool_calls == test_tool_calls
+            assert len(result.tool_calls) == 2
+            assert isinstance(result.response, list)
+            assert result.response[0]["role"] == "assistant"
+            assert "help you" in result.response[0]["content"]
+
+    @pytest.mark.asyncio
+    async def test_tool_calls_work_in_passthrough_mode_dict_response(self):
+        test_tool_calls = [
+            {
+                "name": "get_weather",
+                "args": {"location": "London"},
+                "id": "call_weather",
+                "type": "tool_call",
+            }
+        ]
+
+        with patch(
+            "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar"
+        ) as mock_get_clear:
+            mock_get_clear.return_value = test_tool_calls
+
+            chat = TestChat(
+                self.passthrough_config,
+                llm_completions=["I'll check the weather for you."],
+            )
+
+            result = await chat.app.generate_async(
+                messages=[{"role": "user", "content": "What's the weather like?"}]
+            )
+
+            assert isinstance(result, dict)
+            assert "tool_calls" in result
+            assert result["tool_calls"] == test_tool_calls
+            assert result["role"] == "assistant"
+            assert "check the weather" in result["content"]
+
+    @pytest.mark.asyncio
+    async def test_no_tool_calls_in_passthrough_mode(self):
+        with patch(
+            "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar"
+        ) as mock_get_clear:
+            mock_get_clear.return_value = None
+
+            chat = TestChat(
+                self.passthrough_config,
+                llm_completions=["Hello! How can I help you today?"],
+            )
+
+            result = await chat.app.generate_async(
+                messages=[{"role": "user", "content": "Hello"}],
+                options=GenerationOptions(),
+            )
+
+            assert isinstance(result, GenerationResponse)
+            assert result.tool_calls is None
+            assert "Hello! How can I help" in result.response[0]["content"]
+
+    @pytest.mark.asyncio
+    async def test_empty_tool_calls_in_passthrough_mode(self):
+        with patch(
+            "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar"
+        ) as mock_get_clear:
+            mock_get_clear.return_value = []
+
+            chat = TestChat(
+                self.passthrough_config, llm_completions=["I understand your request."]
+            )
+
+            result = await chat.app.generate_async(
+                messages=[{"role": "user", "content": "Tell me a joke"}]
+            )
+
+            assert isinstance(result, dict)
+            assert "tool_calls" not in result
+            assert "understand your request" in result["content"]
+
+    @pytest.mark.asyncio
+    async def test_tool_calls_with_prompt_mode_passthrough(self):
+        test_tool_calls = [
+            {
+                "name": "search",
+                "args": {"query": "latest news"},
+                "id": "call_prompt",
+                "type": "tool_call",
+            }
+        ]
+
+        with patch(
+            "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar"
+        ) as mock_get_clear:
+            mock_get_clear.return_value = test_tool_calls
+
+            chat = TestChat(
+                self.passthrough_config,
+                llm_completions=["I'll search for that information."],
+            )
+
+            result = await chat.app.generate_async(
+                prompt="Search for the latest news", options=GenerationOptions()
+            )
+
+            assert isinstance(result, GenerationResponse)
+            assert result.tool_calls == test_tool_calls
+            assert isinstance(result.response, str)
+            assert "search for that information" in result.response
+
+    @pytest.mark.asyncio
+    async def test_complex_tool_calls_passthrough_integration(self):
+        complex_tool_calls = [
+            {
+                "name": "get_current_weather",
+                "args": {"location": "San Francisco", "unit": "fahrenheit"},
+                "id": "call_weather_001",
+                "type": "tool_call",
+            },
+            {
+                "name": "calculate_tip",
+                "args": {"bill_amount": 85.50, "tip_percentage": 18},
+                "id": "call_calc_002",
+                "type": "tool_call",
+            },
+            {
+                "name": "web_search",
+                "args": {"query": "best restaurants near me", "limit": 5},
+                "id": "call_search_003",
+                "type": "tool_call",
+            },
+        ]
+
+        with patch(
+            "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar"
+        ) as mock_get_clear:
+            mock_get_clear.return_value = complex_tool_calls
+
+            chat = TestChat(
+                self.passthrough_config,
+                llm_completions=[
+                    "I'll help you with the weather, calculate the tip, and find restaurants."
+                ],
+            )
+
+            result = await chat.app.generate_async(
+                messages=[
+                    {
+                        "role": "user",
+                        "content": "I need weather, tip calculation, and restaurant search",
+                    }
+                ],
+                options=GenerationOptions(),
+            )
+
+            assert isinstance(result, GenerationResponse)
+            assert result.tool_calls == complex_tool_calls
+            assert len(result.tool_calls) == 3
+
+            weather_call = result.tool_calls[0]
+            assert weather_call["name"] == "get_current_weather"
+            assert weather_call["args"]["location"] == "San Francisco"
+            assert weather_call["args"]["unit"] == "fahrenheit"
+            assert weather_call["id"] == "call_weather_001"
+            assert weather_call["type"] == "tool_call"
+
+            tip_call = result.tool_calls[1]
+            assert tip_call["name"] == "calculate_tip"
+            assert tip_call["args"]["bill_amount"] == 85.50
+            assert tip_call["args"]["tip_percentage"] == 18
+            assert tip_call["id"] == "call_calc_002"
+
+            search_call = result.tool_calls[2]
+            assert search_call["name"] == "web_search"
+            assert search_call["args"]["query"] == "best restaurants near me"
+            assert search_call["args"]["limit"] == 5
+            assert search_call["id"] == "call_search_003"
+
+    def test_get_and_clear_tool_calls_called_correctly(self):
+        test_tool_calls = [
+            {
+                "name": "test_func",
+                "args": {"param": "value"},
+                "id": "call_test",
+                "type": "tool_call",
+            }
+        ]
+
+        tool_calls_var.set(test_tool_calls)
+
+        from nemoguardrails.actions.llm.utils import get_and_clear_tool_calls_contextvar
+
+        result = get_and_clear_tool_calls_contextvar()
+
+        assert result == test_tool_calls
+        assert tool_calls_var.get() is None
+
+    @pytest.mark.asyncio
+    async def test_tool_calls_integration_preserves_other_response_data(self):
+        test_tool_calls = [
+            {
+                "name": "preserve_test",
+                "args": {"data": "preserved"},
+                "id": "call_preserve",
+                "type": "tool_call",
+            }
+        ]
+
+        with patch(
+            "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar"
+        ) as mock_get_clear:
+            mock_get_clear.return_value = test_tool_calls
+
+            chat = TestChat(
+                self.passthrough_config,
+                llm_completions=["Response with preserved data."],
+            )
+
+            result = await chat.app.generate_async(
+                messages=[{"role": "user", "content": "Test message"}],
+                options=GenerationOptions(),
+            )
+
+            assert isinstance(result, GenerationResponse)
+            assert result.tool_calls == test_tool_calls
+            assert result.response is not None
+            assert result.llm_output is None
+            assert result.state is None
+            assert isinstance(result.response, list)
+            assert len(result.response) == 1
+            assert result.response[0]["role"] == "assistant"
+            assert result.response[0]["content"] == "Response with preserved data."
+
+    @pytest.mark.asyncio
+    async def test_tool_calls_with_real_world_examples(self):
+        realistic_tool_calls = [
+            {
+                "name": "get_weather",
+                "args": {"location": "London"},
+                "id": "call_JMTxzsfy21izMf248MHZvj3G",
+                "type": "tool_call",
+            },
+            {
+                "name": "add",
+                "args": {"a": 15, "b": 27},
+                "id": "call_INoaqHesFOrZdjHynU78qjX4",
+                "type": "tool_call",
+            },
+        ]
+
+        with patch(
+            "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar"
+        ) as mock_get_clear:
+            mock_get_clear.return_value = realistic_tool_calls
+
+            chat = TestChat(
+                self.passthrough_config,
+                llm_completions=[
+                    "I'll get the weather in London and add 15 + 27 for you."
+                ],
+            )
+
+            result = await chat.app.generate_async(
+                messages=[
+                    {
+                        "role": "user",
+                        "content": "What's the weather in London and what's 15 + 27?",
+                    }
+                ],
+                options=GenerationOptions(),
+            )
+
+            assert isinstance(result, GenerationResponse)
+            assert result.tool_calls == realistic_tool_calls
+
+            weather_call = result.tool_calls[0]
+            assert weather_call["name"] == "get_weather"
+            assert weather_call["args"] == {"location": "London"}
+            assert weather_call["id"] == "call_JMTxzsfy21izMf248MHZvj3G"
+            assert weather_call["type"] == "tool_call"
+
+            add_call = result.tool_calls[1]
+            assert add_call["name"] == "add"
+            assert add_call["args"] == {"a": 15, "b": 27}
+            assert add_call["id"] == "call_INoaqHesFOrZdjHynU78qjX4"
+            assert add_call["type"] == "tool_call"
+
+    @pytest.mark.asyncio
+    async def test_passthrough_config_requirement(self):
+        assert self.passthrough_config.passthrough is True
diff --git a/tests/test_tool_calling_utils.py b/tests/test_tool_calling_utils.py
new file mode 100644
index 000000000..5312b0b6f
--- /dev/null
+++ b/tests/test_tool_calling_utils.py
@@ -0,0 +1,252 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
+
+from nemoguardrails.actions.llm.utils import (
+    _convert_messages_to_langchain_format,
+    _extract_content,
+    _store_tool_calls,
+    get_and_clear_tool_calls_contextvar,
+    llm_call,
+)
+from nemoguardrails.context import tool_calls_var
+from nemoguardrails.rails.llm.llmrails import GenerationResponse
+
+
+def test_get_and_clear_tool_calls_contextvar():
+    test_tool_calls = [
+        {"name": "test_func", "args": {}, "id": "call_123", "type": "tool_call"}
+    ]
+    tool_calls_var.set(test_tool_calls)
+
+    result = get_and_clear_tool_calls_contextvar()
+
+    assert result == test_tool_calls
+    assert tool_calls_var.get() is None
+
+
+def test_get_and_clear_tool_calls_contextvar_empty():
+    """Test that it returns None when no tool calls exist."""
+    tool_calls_var.set(None)
+
+    result = get_and_clear_tool_calls_contextvar()
+
+    assert result is None
+
+
+def test_convert_messages_to_langchain_format_user():
+    """Test converting user messages to LangChain format."""
+    messages = [{"role": "user", "content": "Hello"}]
+
+    result = _convert_messages_to_langchain_format(messages)
+
+    assert len(result) == 1
+    assert isinstance(result[0], HumanMessage)
+    assert result[0].content == "Hello"
+
+
+def test_convert_messages_to_langchain_format_assistant():
+    """Test converting assistant messages to LangChain format."""
+    messages = [{"role": "assistant", "content": "Hi there"}]
+
+    result = _convert_messages_to_langchain_format(messages)
+
+    assert len(result) == 1
+    assert isinstance(result[0], AIMessage)
+    assert result[0].content == "Hi there"
+
+
+def test_convert_messages_to_langchain_format_bot():
+    """Test converting bot messages to LangChain format."""
+    messages = [{"type": "bot", "content": "Hello from bot"}]
+
+    result = _convert_messages_to_langchain_format(messages)
+
+    assert len(result) == 1
+    assert isinstance(result[0], AIMessage)
+    assert result[0].content == "Hello from bot"
+
+
+def test_convert_messages_to_langchain_format_system():
+    """Test converting system messages to LangChain format."""
+    messages = [{"role": "system", "content": "You are a helpful assistant"}]
+
+    result = _convert_messages_to_langchain_format(messages)
+
+    assert len(result) == 1
+    assert isinstance(result[0], SystemMessage)
+    assert result[0].content == "You are a helpful assistant"
+
+
+def test_convert_messages_to_langchain_format_tool():
+    """Test converting tool messages to LangChain format."""
+    messages = [{"role": "tool", "content": "Tool result", "tool_call_id": "call_123"}]
+
+    result = _convert_messages_to_langchain_format(messages)
+
+    assert len(result) == 1
+    assert isinstance(result[0], ToolMessage)
+    assert result[0].content == "Tool result"
+    assert result[0].tool_call_id == "call_123"
+
+
+def test_convert_messages_to_langchain_format_tool_no_id():
+    """Test converting tool messages without tool_call_id."""
+    messages = [{"role": "tool", "content": "Tool result"}]
+
+    result = _convert_messages_to_langchain_format(messages)
+
+    assert len(result) == 1
+    assert isinstance(result[0], ToolMessage)
+    assert result[0].content == "Tool result"
+    assert result[0].tool_call_id == ""
+
+
+def test_convert_messages_to_langchain_format_mixed():
+    """Test converting mixed message types."""
+    messages = [
+        {"role": "system", "content": "System prompt"},
+        {"role": "user", "content": "User message"},
+        {"type": "bot", "content": "Bot response"},
+        {"role": "tool", "content": "Tool output", "tool_call_id": "call_456"},
+    ]
+
+    result = _convert_messages_to_langchain_format(messages)
+
+    assert len(result) == 4
+    assert isinstance(result[0], SystemMessage)
+    assert isinstance(result[1], HumanMessage)
+    assert isinstance(result[2], AIMessage)
+    assert isinstance(result[3], ToolMessage)
+    assert result[3].tool_call_id == "call_456"
+
+
+def test_convert_messages_to_langchain_format_unknown_type():
+    """Test that unknown message types raise ValueError."""
+    messages = [{"role": "unknown", "content": "Unknown message"}]
+
+    with pytest.raises(ValueError, match="Unknown message type unknown"):
+        _convert_messages_to_langchain_format(messages)
+
+
+def test_store_tool_calls():
+    """Test storing tool calls from response."""
+    mock_response = MagicMock()
+    test_tool_calls = [
+        {"name": "another_func", "args": {}, "id": "call_789", "type": "tool_call"}
+    ]
+    mock_response.tool_calls = test_tool_calls
+
+    _store_tool_calls(mock_response)
+
+    assert tool_calls_var.get() == test_tool_calls
+
+
+def test_store_tool_calls_no_tool_calls():
+    """Test storing tool calls when response has no tool_calls attribute."""
+    mock_response = MagicMock()
+    del mock_response.tool_calls
+
+    _store_tool_calls(mock_response)
+
+    assert tool_calls_var.get() is None
+
+
+def test_extract_content_with_content_attr():
+    """Test extracting content from response with content attribute."""
+    mock_response = MagicMock()
+    mock_response.content = "Response content"
+
+    result = _extract_content(mock_response)
+
+    assert result == "Response content"
+
+
+def test_extract_content_without_content_attr():
+    """Test extracting content from response without content attribute."""
+    mock_response = "Plain string response"
+
+    result = _extract_content(mock_response)
+
+    assert result == "Plain string response"
+
+
+@pytest.mark.asyncio
+async def test_llm_call_with_string_prompt():
+    """Test llm_call with string prompt."""
+    mock_llm = AsyncMock()
+    mock_response = MagicMock()
+    mock_response.content = "LLM response"
+    mock_llm.ainvoke.return_value = mock_response
+
+    result = await llm_call(mock_llm, "Test prompt")
+
+    assert result == "LLM response"
+    mock_llm.ainvoke.assert_called_once()
+    call_args = mock_llm.ainvoke.call_args
+    assert call_args[0][0] == "Test prompt"
+
+
+@pytest.mark.asyncio
+async def test_llm_call_with_message_list():
+    """Test llm_call with message list."""
+    mock_llm = AsyncMock()
+    mock_response = MagicMock()
+    mock_response.content = "LLM response"
+    mock_llm.ainvoke.return_value = mock_response
+
+    messages = [{"role": "user", "content": "Hello"}]
+    result = await llm_call(mock_llm, messages)
+
+    assert result == "LLM response"
+    mock_llm.ainvoke.assert_called_once()
+    call_args = mock_llm.ainvoke.call_args
+    assert len(call_args[0][0]) == 1
+    assert isinstance(call_args[0][0][0], HumanMessage)
+
+
+@pytest.mark.asyncio
+async def test_llm_call_stores_tool_calls():
+    """Test that llm_call stores tool calls from response."""
+    mock_llm = AsyncMock()
+    mock_response = MagicMock()
+    mock_response.content = "Response with tools"
+    test_tool_calls = [
+        {"name": "test", "args": {}, "id": "call_test", "type": "tool_call"}
+    ]
+    mock_response.tool_calls = test_tool_calls
+    mock_llm.ainvoke.return_value = mock_response
+
+    result = await llm_call(mock_llm, "Test prompt")
+
+    assert result == "Response with tools"
+    assert tool_calls_var.get() == test_tool_calls
+
+
+def test_generation_response_tool_calls_field():
+    """Test that GenerationResponse can store tool calls."""
+    test_tool_calls = [
+        {"name": "test_function", "args": {}, "id": "call_test", "type": "tool_call"}
+    ]
+
+    response = GenerationResponse(
+        response=[{"role": "assistant", "content": "Hello"}], tool_calls=test_tool_calls
+    )
+
+    assert response.tool_calls == test_tool_calls
diff --git a/tests/test_tool_calls_context.py b/tests/test_tool_calls_context.py
new file mode 100644
index 000000000..e155946f4
--- /dev/null
+++ b/tests/test_tool_calls_context.py
@@ -0,0 +1,71 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from nemoguardrails.context import tool_calls_var
+
+
+def test_tool_calls_var_default():
+    """Test that tool_calls_var defaults to None."""
+    assert tool_calls_var.get() is None
+
+
+def test_tool_calls_var_set_and_get():
+    """Test setting and getting tool calls from context."""
+    test_tool_calls = [
+        {
+            "name": "get_weather",
+            "args": {"location": "New York"},
+            "id": "call_123",
+            "type": "tool_call",
+        },
+        {
+            "name": "calculate",
+            "args": {"expression": "2+2"},
+            "id": "call_456",
+            "type": "tool_call",
+        },
+    ]
+
+    tool_calls_var.set(test_tool_calls)
+
+    result = tool_calls_var.get()
+    assert result == test_tool_calls
+    assert len(result) == 2
+    assert result[0]["id"] == "call_123"
+    assert result[1]["name"] == "calculate"
+
+
+def test_tool_calls_var_clear():
+    """Test clearing tool calls from context."""
+    test_tool_calls = [
+        {"name": "test", "args": {}, "id": "call_test", "type": "tool_call"}
+    ]
+
+    tool_calls_var.set(test_tool_calls)
+    assert tool_calls_var.get() == test_tool_calls
+
+    tool_calls_var.set(None)
+    assert tool_calls_var.get() is None
+
+
+def test_tool_calls_var_empty_list():
+    """Test setting empty list of tool calls."""
+    tool_calls_var.set([])
+
+    result = tool_calls_var.get()
+    assert result == []
+    assert len(result) == 0