41 changes: 22 additions & 19 deletions nemoguardrails/rails/llm/llmrails.py
@@ -51,7 +51,6 @@
generation_options_var,
llm_stats_var,
raw_llm_request,
reasoning_trace_var,
streaming_handler_var,
)
from nemoguardrails.embeddings.index import EmbeddingsIndex
@@ -576,6 +575,20 @@ def _get_events_for_messages(self, messages: List[dict], state: Any):

return events

@staticmethod
def _ensure_explain_info() -> ExplainInfo:
"""Ensure that the ExplainInfo variable is present in the current context

Returns:
An ExplainInfo instance containing the LLM calls' statistics.
"""
explain_info = explain_info_var.get()
if explain_info is None:
explain_info = ExplainInfo()
explain_info_var.set(explain_info)

return explain_info

async def generate_async(
self,
prompt: Optional[str] = None,
@@ -634,14 +647,7 @@ async def generate_async(
# Initialize the object with additional explanation information.
# We allow this to also be set externally. This is useful when multiple parallel
# requests are made.
explain_info = explain_info_var.get()
if explain_info is None:
explain_info = ExplainInfo()
explain_info_var.set(explain_info)

# We also keep a general reference to this object
self.explain_info = explain_info
self.explain_info = explain_info
self.explain_info = self._ensure_explain_info()

if prompt is not None:
# Currently, we transform the prompt request into a single turn conversation
@@ -805,9 +811,11 @@ async def generate_async(

# If logging is enabled, we log the conversation
# TODO: add support for logging flag
explain_info.colang_history = get_colang_history(events)
self.explain_info.colang_history = get_colang_history(events)
if self.verbose:
log.info(f"Conversation history so far: \n{explain_info.colang_history}")
log.info(
f"Conversation history so far: \n{self.explain_info.colang_history}"
)

total_time = time.time() - t0
log.info(
@@ -960,6 +968,8 @@ def stream_async(
include_generation_metadata: Optional[bool] = False,
) -> AsyncIterator[str]:
"""Simplified interface for getting directly the streamed tokens from the LLM."""
self.explain_info = self._ensure_explain_info()

streaming_handler = StreamingHandler(
include_generation_metadata=include_generation_metadata
)
@@ -1278,13 +1288,6 @@ def _prepare_params(
**action_params,
}

def _update_explain_info():
explain_info = explain_info_var.get()
if explain_info is None:
explain_info = ExplainInfo()
explain_info_var.set(explain_info)
self.explain_info = explain_info

output_rails_streaming_config = self.config.rails.output.streaming
buffer_strategy = get_buffer_strategy(output_rails_streaming_config)
output_rails_flows_id = self.config.rails.output.flows
@@ -1329,7 +1332,7 @@ def _update_explain_info():
action_name, params
)
# Include explain info (whatever _update_explain_info does)
_update_explain_info()
self.explain_info = self._ensure_explain_info()

# Retrieve the action function from the dispatcher
action_func = self.runtime.action_dispatcher.get_action(action_name)
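For context: the changes above replace two copies of the get-or-create logic around `explain_info_var` (in `generate_async` and in the nested `_update_explain_info` helper) with a single `_ensure_explain_info()` static method, which `stream_async` now also calls. Below is a minimal, self-contained sketch of the underlying context-variable pattern, assuming `explain_info_var` is a standard `contextvars.ContextVar`; `ExplainInfo` here is a simplified stand-in for the real nemoguardrails class, not its actual definition.

```python
# Minimal sketch of the pattern behind _ensure_explain_info(); the ExplainInfo
# class and variable name mirror the diff but are simplified stand-ins.
import contextvars
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ExplainInfo:
    llm_calls: List[object] = field(default_factory=list)
    colang_history: str = ""


explain_info_var: contextvars.ContextVar[Optional[ExplainInfo]] = contextvars.ContextVar(
    "explain_info", default=None
)


def ensure_explain_info() -> ExplainInfo:
    """Return the ExplainInfo bound to the current context, creating it if missing."""
    explain_info = explain_info_var.get()
    if explain_info is None:
        explain_info = ExplainInfo()
        explain_info_var.set(explain_info)
    return explain_info
```

Because each asyncio task works on its own copy of the context, parallel requests that set `explain_info_var` before calling `generate_async` keep separate statistics, and the helper reuses the existing value instead of overwriting it.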
117 changes: 117 additions & 0 deletions tests/test_llm_rails_context_variables.py
@@ -0,0 +1,117 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio

import pytest

from nemoguardrails import RailsConfig
from tests.utils import TestChat


@pytest.mark.asyncio
async def test_1():
config = RailsConfig.from_content(
"""
define user express greeting
"hello"

define flow
user express greeting
bot express greeting
"""
)
chat = TestChat(
config,
llm_completions=[
"express greeting",
"Hello! I'm doing great, thank you. How can I assist you today?",
],
)

new_messages = await chat.app.generate_async(
messages=[{"role": "user", "content": "hi, how are you"}]
)

assert new_messages == {
"content": "Hello! I'm doing great, thank you. How can I assist you today?",
"role": "assistant",
}, "message content do not match"

# note that 2 LLM calls are expected as we matched the bot intent
assert (
len(chat.app.explain().llm_calls) == 2
), "number of llm call not as expected. Expected 2, found {}".format(
len(chat.app.explain().llm_calls)
)


@pytest.mark.asyncio
async def test_2():
config = RailsConfig.from_content(
config={
"models": [],
"rails": {
"output": {
# run the real self check output rails
"flows": {"self check output"},
"streaming": {
"enabled": True,
"chunk_size": 4,
"context_size": 2,
"stream_first": False,
},
}
},
"streaming": False,
"prompts": [{"task": "self_check_output", "content": "a test template"}],
},
colang_content="""
define user express greeting
"hi"

define flow
user express greeting
bot tell joke
""",
)

llm_completions = [
' express greeting\nbot express greeting\n "Hi, how are you doing?"',
' "This is a joke that should be blocked."',
# add as many "No" completions as there are chunks for the output rails to check
"No",
"No",
"Yes",
]

chat = TestChat(
config,
llm_completions=llm_completions,
streaming=True,
)
chunks = []
async for chunk in chat.app.stream_async(
messages=[{"role": "user", "content": "Hi!"}],
):
chunks.append(chunk)

# note that 5 LLM calls are expected: 2 for generation and 3 for the output rails checks
assert (
len(chat.app.explain().llm_calls) == 5
), "number of llm call not as expected. Expected 5, found {}".format(
len(chat.app.explain().llm_calls)
)

await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
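The tests above read the statistics back through `chat.app.explain()`. For reference, here is a hedged usage sketch of the same flow against a standalone `LLMRails` instance; the config path is a placeholder and is not part of this PR.

```python
# Illustrative usage only: inspect explain info after a generation. explain()
# returns the ExplainInfo that _ensure_explain_info() bound to the request context.
import asyncio

from nemoguardrails import LLMRails, RailsConfig


async def main():
    config = RailsConfig.from_path("./config")  # placeholder config directory
    rails = LLMRails(config)

    await rails.generate_async(
        messages=[{"role": "user", "content": "hi, how are you"}]
    )

    info = rails.explain()
    print(f"LLM calls: {len(info.llm_calls)}")
    print(info.colang_history)


if __name__ == "__main__":
    asyncio.run(main())
```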