149 changes: 106 additions & 43 deletions nemoguardrails/actions/llm/utils.py
@@ -20,11 +20,15 @@
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

from nemoguardrails.colang.v2_x.lang.colang_ast import Flow
from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
from nemoguardrails.context import llm_call_info_var, reasoning_trace_var
from nemoguardrails.context import (
llm_call_info_var,
reasoning_trace_var,
tool_calls_var,
)
from nemoguardrails.logging.callbacks import logging_callbacks
from nemoguardrails.logging.explain import LLMCallInfo

@@ -72,7 +76,22 @@ async def llm_call(
custom_callback_handlers: Optional[List[AsyncCallbackHandler]] = None,
) -> str:
"""Calls the LLM with a prompt and returns the generated text."""
# We initialize a new LLM call if we don't have one already
_setup_llm_call_info(llm, model_name, model_provider)
all_callbacks = _prepare_callbacks(custom_callback_handlers)

if isinstance(prompt, str):
response = await _invoke_with_string_prompt(llm, prompt, all_callbacks, stop)
else:
response = await _invoke_with_message_list(llm, prompt, all_callbacks, stop)

_store_tool_calls(response)
return _extract_content(response)


def _setup_llm_call_info(
llm: BaseLanguageModel, model_name: Optional[str], model_provider: Optional[str]
) -> None:
"""Initialize or update LLM call info in context."""
llm_call_info = llm_call_info_var.get()
if llm_call_info is None:
llm_call_info = LLMCallInfo()
@@ -81,52 +100,84 @@ async def llm_call(
llm_call_info.llm_model_name = model_name or _infer_model_name(llm)
llm_call_info.llm_provider_name = model_provider


def _prepare_callbacks(
custom_callback_handlers: Optional[List[AsyncCallbackHandler]],
) -> BaseCallbackManager:
"""Prepare callback manager with custom handlers if provided."""
if custom_callback_handlers and custom_callback_handlers != [None]:
all_callbacks = BaseCallbackManager(
return BaseCallbackManager(
handlers=logging_callbacks.handlers + custom_callback_handlers,
inheritable_handlers=logging_callbacks.handlers + custom_callback_handlers,
)
else:
all_callbacks = logging_callbacks
return logging_callbacks

if isinstance(prompt, str):
# stop sinks here
try:
result = await llm.agenerate_prompt(
[StringPromptValue(text=prompt)], callbacks=all_callbacks, stop=stop

async def _invoke_with_string_prompt(
llm: BaseLanguageModel,
prompt: str,
callbacks: BaseCallbackManager,
stop: Optional[List[str]],
):
"""Invoke LLM with string prompt."""
try:
return await llm.ainvoke(prompt, config={"callbacks": callbacks, "stop": stop})
except Exception as e:
raise LLMCallException(e)


async def _invoke_with_message_list(
llm: BaseLanguageModel,
prompt: List[dict],
callbacks: BaseCallbackManager,
stop: Optional[List[str]],
):
"""Invoke LLM with message list after converting to LangChain format."""
messages = _convert_messages_to_langchain_format(prompt)
try:
return await llm.ainvoke(
messages, config={"callbacks": callbacks, "stop": stop}
)
except Exception as e:
raise LLMCallException(e)


def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
"""Convert message list to LangChain message format."""
messages = []
for msg in prompt:
msg_type = msg["type"] if "type" in msg else msg["role"]

if msg_type == "user":
messages.append(HumanMessage(content=msg["content"]))
elif msg_type in ["bot", "assistant"]:
messages.append(AIMessage(content=msg["content"]))
elif msg_type == "system":
messages.append(SystemMessage(content=msg["content"]))
elif msg_type == "tool":
messages.append(
ToolMessage(
content=msg["content"],
tool_call_id=msg.get("tool_call_id", ""),
)
)
except Exception as e:
raise LLMCallException(e)
llm_call_info.raw_response = result.llm_output
else:
raise ValueError(f"Unknown message type {msg_type}")

# TODO: error handling
return result.generations[0][0].text
else:
# We first need to translate the array of messages into LangChain message format
messages = []
for _msg in prompt:
msg_type = _msg["type"] if "type" in _msg else _msg["role"]
if msg_type == "user":
messages.append(HumanMessage(content=_msg["content"]))
elif msg_type in ["bot", "assistant"]:
messages.append(AIMessage(content=_msg["content"]))
elif msg_type == "system":
messages.append(SystemMessage(content=_msg["content"]))
else:
# TODO: add support for tool-related messages
raise ValueError(f"Unknown message type {msg_type}")
return messages

try:
result = await llm.agenerate_prompt(
[ChatPromptValue(messages=messages)], callbacks=all_callbacks, stop=stop
)

except Exception as e:
raise LLMCallException(e)
def _store_tool_calls(response) -> None:
"""Extract and store tool calls from response in context."""
tool_calls = getattr(response, "tool_calls", None)
tool_calls_var.set(tool_calls)

llm_call_info.raw_response = result.llm_output

return result.generations[0][0].text
def _extract_content(response) -> str:
"""Extract text content from response."""
if hasattr(response, "content"):
return response.content
return str(response)


def get_colang_history(
@@ -175,15 +226,15 @@ def get_colang_history(
history += f'user "{event["text"]}"\n'
elif event["type"] == "UserIntent":
if include_texts:
history += f' {event["intent"]}\n'
history += f" {event['intent']}\n"
else:
history += f'user {event["intent"]}\n'
history += f"user {event['intent']}\n"
elif event["type"] == "BotIntent":
# If we have instructions, we add them before the bot message.
# But we only do that for the last bot message.
if "instructions" in event and idx == last_bot_intent_idx:
history += f"# {event['instructions']}\n"
history += f'bot {event["intent"]}\n'
history += f"bot {event['intent']}\n"
elif event["type"] == "StartUtteranceBotAction" and include_texts:
history += f' "{event["script"]}"\n'
# We skip system actions from this log
@@ -352,9 +403,9 @@ def flow_to_colang(flow: Union[dict, Flow]) -> str:
if "_type" not in element:
raise Exception("bla")
if element["_type"] == "UserIntent":
colang_flow += f'user {element["intent_name"]}\n'
colang_flow += f"user {element['intent_name']}\n"
elif element["_type"] == "run_action" and element["action_name"] == "utter":
colang_flow += f'bot {element["action_params"]["value"]}\n'
colang_flow += f"bot {element['action_params']['value']}\n"

return colang_flow

@@ -592,3 +643,15 @@ def get_and_clear_reasoning_trace_contextvar() -> Optional[str]:
reasoning_trace_var.set(None)
return reasoning_trace
return None


def get_and_clear_tool_calls_contextvar() -> Optional[list]:
"""Get the current tool calls and clear them from the context.

Returns:
Optional[list]: The tool calls if they exist, None otherwise.
"""
if tool_calls := tool_calls_var.get():
tool_calls_var.set(None)
return tool_calls
return None
5 changes: 5 additions & 0 deletions nemoguardrails/context.py
@@ -37,3 +37,8 @@
reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"reasoning_trace", default=None
)

# The tool calls from the current LLM response.
tool_calls_var: contextvars.ContextVar[Optional[list]] = contextvars.ContextVar(
"tool_calls", default=None
)
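
Not part of the diff: a minimal sketch of the contextvar round-trip introduced above, using only symbols added in this PR (tool_calls_var, get_and_clear_tool_calls_contextvar); the sample tool-call dict is illustrative.

# Illustrative only: mirrors what _store_tool_calls and generate_async do in this PR.
from nemoguardrails.actions.llm.utils import get_and_clear_tool_calls_contextvar
from nemoguardrails.context import tool_calls_var

# After an LLM call, _store_tool_calls leaves the response's tool calls in the contextvar.
tool_calls_var.set([{"name": "get_weather", "args": {"location": "NYC"}, "id": "call_123"}])

# A later consumer (generate_async in this PR) pops them exactly once.
assert get_and_clear_tool_calls_contextvar() == [
    {"name": "get_weather", "args": {"location": "NYC"}, "id": "call_123"}
]
# The helper clears the variable, so a second read returns None.
assert get_and_clear_tool_calls_contextvar() is None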
8 changes: 8 additions & 0 deletions nemoguardrails/rails/llm/llmrails.py
@@ -33,6 +33,7 @@
from nemoguardrails.actions.llm.generation import LLMGenerationActions
from nemoguardrails.actions.llm.utils import (
get_and_clear_reasoning_trace_contextvar,
get_and_clear_tool_calls_contextvar,
get_colang_history,
)
from nemoguardrails.actions.output_mapping import is_output_blocked
Expand Down Expand Up @@ -1084,6 +1085,8 @@ async def generate_async(
options.log.llm_calls = True
options.log.internal_events = True

tool_calls = get_and_clear_tool_calls_contextvar()

# If we have generation options, we prepare a GenerationResponse instance.
if options:
# If a prompt was used, we only need to return the content of the message.
@@ -1100,6 +1103,9 @@
reasoning_trace + res.response[0]["content"]
)

if tool_calls:
res.tool_calls = tool_calls

if self.config.colang_version == "1.0":
# If output variables are specified, we extract their values
if options.output_vars:
@@ -1228,6 +1234,8 @@ async def generate_async(
if prompt:
return new_message["content"]
else:
if tool_calls:
new_message["tool_calls"] = tool_calls
return new_message

def stream_async(
4 changes: 4 additions & 0 deletions nemoguardrails/rails/llm/options.py
@@ -408,6 +408,10 @@ class GenerationResponse(BaseModel):
default=None,
description="A state object which can be used in subsequent calls to continue the interaction.",
)
tool_calls: Optional[list] = Field(
default=None,
description="Tool calls extracted from the LLM response, if any.",
)


if __name__ == "__main__":
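
Not part of the diff: a sketch of how a caller might read the new field, assuming a local guardrails config at "config/" (a placeholder path) and the existing LLMRails API; tool_calls stays None unless the model actually requested a tool.

import asyncio

from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.rails.llm.options import GenerationOptions


async def main():
    # Placeholder config path; any rails config with a tool-capable LLM would do.
    config = RailsConfig.from_path("config/")
    rails = LLMRails(config)

    # Passing options makes generate_async return a GenerationResponse.
    response = await rails.generate_async(
        messages=[{"role": "user", "content": "What's the weather in NYC?"}],
        options=GenerationOptions(),
    )

    # tool_calls is the optional field added to GenerationResponse in this PR.
    if response.tool_calls:
        for call in response.tool_calls:
            print(call["name"], call["args"])
    else:
        print(response.response)


asyncio.run(main())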
85 changes: 84 additions & 1 deletion tests/rails/llm/test_options.py
@@ -15,7 +15,11 @@

import pytest

from nemoguardrails.rails.llm.options import GenerationOptions, GenerationRailsOptions
from nemoguardrails.rails.llm.options import (
GenerationOptions,
GenerationRailsOptions,
GenerationResponse,
)


def test_generation_options_initialization():
@@ -110,3 +114,82 @@ def test_generation_options_serialization():
assert '"output":false' in options_json
assert '"activated_rails":true' in options_json
assert '"llm_calls":true' in options_json


def test_generation_response_initialization():
"""Test GenerationResponse initialization."""
response = GenerationResponse(response="Hello, world!")
assert response.response == "Hello, world!"
assert response.tool_calls is None
assert response.llm_output is None
assert response.state is None


def test_generation_response_with_tool_calls():
test_tool_calls = [
{
"name": "get_weather",
"args": {"location": "NYC"},
"id": "call_123",
"type": "tool_call",
},
{
"name": "calculate",
"args": {"expression": "2+2"},
"id": "call_456",
"type": "tool_call",
},
]

response = GenerationResponse(
response=[{"role": "assistant", "content": "I'll help you with that."}],
tool_calls=test_tool_calls,
)

assert response.tool_calls == test_tool_calls
assert len(response.tool_calls) == 2
assert response.tool_calls[0]["id"] == "call_123"
assert response.tool_calls[1]["name"] == "calculate"


def test_generation_response_empty_tool_calls():
"""Test GenerationResponse with empty tool calls list."""
response = GenerationResponse(response="No tools needed", tool_calls=[])

assert response.tool_calls == []
assert len(response.tool_calls) == 0


def test_generation_response_serialization_with_tool_calls():
test_tool_calls = [
{"name": "test_func", "args": {}, "id": "call_test", "type": "tool_call"}
]

response = GenerationResponse(response="Response text", tool_calls=test_tool_calls)

response_dict = response.dict()
assert "tool_calls" in response_dict
assert response_dict["tool_calls"] == test_tool_calls

response_json = response.json()
assert "tool_calls" in response_json
assert "test_func" in response_json


def test_generation_response_model_validation():
"""Test GenerationResponse model validation."""
test_tool_calls = [
{"name": "valid_function", "args": {}, "id": "call_123", "type": "tool_call"},
{"name": "another_function", "args": {}, "id": "call_456", "type": "tool_call"},
]

response = GenerationResponse(
response=[{"role": "assistant", "content": "Test response"}],
tool_calls=test_tool_calls,
llm_output={"token_usage": {"total_tokens": 50}},
)

assert response.tool_calls is not None
assert isinstance(response.tool_calls, list)
assert len(response.tool_calls) == 2
assert response.llm_output["token_usage"]["total_tokens"] == 50