Skip to content

Commit 0f02fd1

Browse files
authored
Merge pull request #779 from NVIDIA/feat/passthrough-colang-2
Feat: support passthrough mode in colang 2
2 parents c038b08 + 5e5ea8f commit 0f02fd1

File tree

3 files changed

+187
-0
lines changed

3 files changed

+187
-0
lines changed

nemoguardrails/actions/v2_x/generation.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,19 @@
4848
get_element_from_head,
4949
get_event_from_element,
5050
)
51+
from nemoguardrails.context import (
52+
generation_options_var,
53+
llm_call_info_var,
54+
raw_llm_request,
55+
streaming_handler_var,
56+
)
5157
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
5258
from nemoguardrails.llm.filters import colang
5359
from nemoguardrails.llm.params import llm_params
5460
from nemoguardrails.llm.types import Task
5561
from nemoguardrails.logging import verbose
62+
from nemoguardrails.logging.explain import LLMCallInfo
63+
from nemoguardrails.rails.llm.options import GenerationOptions
5664
from nemoguardrails.utils import console, new_uuid
5765

5866
log = logging.getLogger(__name__)
@@ -390,6 +398,54 @@ async def generate_user_intent_and_bot_action(
390398
"bot_action": bot_action,
391399
}
392400

401+
@action(name="PassthroughLLMAction", is_system_action=True, execute_async=True)
async def passthrough_llm_action(
    self,
    user_message: str,
    state: State,
    events: List[dict],
    llm: Optional[BaseLLM] = None,
):
    """Make a single raw LLM call in passthrough mode.

    Builds the prompt from the raw LLM request when one is available (so the
    full chat history is preserved), falling back to the last user utterance
    otherwise, and returns the LLM's text response.

    Args:
        user_message: The (possibly input-rail-altered) last user message.
        state: The current Colang state (unused here, kept for the action
            signature expected by the runtime).
        events: The history of events; used to extract the last user utterance.
        llm: The LLM to call; the default main LLM is injected when None.

    Returns:
        The raw text generated by the LLM.

    Raises:
        ValueError: If the raw prompt is neither a string nor a list of messages.
    """
    event = get_last_user_utterance_event_v2_x(events)

    # We check if we have a raw request. If the guardrails API is using
    # the `generate_events` API, this will not be set.
    raw_prompt = raw_llm_request.get()

    if raw_prompt is None:
        prompt = event["final_transcript"]
    else:
        if isinstance(raw_prompt, str):
            # If we're in completion mode, we use directly the last $user_message
            # as it may have been altered by the input rails.
            prompt = event["final_transcript"]
        elif isinstance(raw_prompt, list):
            prompt = raw_prompt.copy()

            # If the last message is from the user, we replace the text
            # just in case the input rails may have altered it.
            if prompt[-1]["role"] == "user":
                # FIX: update the copy, not the caller's raw prompt. A fresh
                # dict is used because list.copy() is shallow and the message
                # dicts are shared with `raw_prompt`.
                prompt[-1] = {**prompt[-1], "content": event["final_transcript"]}
        else:
            raise ValueError(f"Unsupported type for raw prompt: {type(raw_prompt)}")

    # Initialize the LLMCallInfo object
    llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))

    generation_options: GenerationOptions = generation_options_var.get()

    with llm_params(
        llm,
        **((generation_options and generation_options.llm_params) or {}),
    ):
        # FIX: the computed `prompt` (full history / rail-sanitized text) was
        # previously ignored and `user_message` was sent instead, dropping the
        # conversation context in chat mode.
        text = await llm_call(
            llm,
            prompt,
            custom_callback_handlers=[streaming_handler_var.get()],
        )

    return text
448+
393449
@action(name="CheckValidFlowExistsAction", is_system_action=True)
394450
async def check_if_flow_exists(self, state: "State", flow_id: str) -> bool:
395451
"""Return True if a flow with the provided flow_id exists."""
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
2+
import llm

flow context free bot response generation on unhandled user intent
  """Just make a call to LLM in passthrough mode"""
  # Keep forwarding raw LLM requests while this flow is active.
  activate polling llm request response
  await _user_said_something_unexpected as $user_said
  $event = $user_said.event

  # we need to wait for the automatic intent detection
  await unhandled user intent as $flow
  log 'unexpected user utterance: "{$event.final_transcript}"'

  $user_message = $event.final_transcript

  log 'start generating bot response in passthrough mode...'
  # Delegate the response entirely to the LLM (no canned bot intents).
  $bot_message = await PassthroughLLMAction(user_message=$user_message)
  bot say $bot_message

@override
flow llm continuation
  # Overrides the stock `llm continuation` so that unhandled user input is
  # answered by the passthrough flow above instead of intent-driven generation.
  activate automating intent detection
  activate generating user intent for unhandled user utterance
  activate context free bot response generation on unhandled user intent

tests/v2_x/test_passthroug_mode.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import logging
16+
import unittest
17+
18+
from nemoguardrails import RailsConfig
19+
from tests.utils import TestChat
20+
21+
# Test configuration: a main flow with two handled intents (greeting, bored);
# anything else falls through to the passthrough LLM continuation.
colang_content = '''
import core
import passthrough

flow main
  activate llm continuation
  activate greeting
  activate other reactions

flow greeting
  user expressed greeting
  bot say "Hello world!"

flow other reactions
  user expressed to be bored
  bot say "No problem!"

flow user expressed greeting
  """User expressed greeting in any way or form."""
  user said "hi"

flow user expressed to be bored
  """User expressed to be bored."""
  user said "This is boring"
'''

yaml_content = """
colang_version: "2.x"
models:
  - type: main
    engine: openai
    model: gpt-3.5-turbo-instruct

"""


# Shared rails configuration used by both test cases below.
config = RailsConfig.from_content(colang_content, yaml_content)
58+
59+
60+
class TestPassthroughLLMActionLogging(unittest.IsolatedAsyncioTestCase):
    """Verify via the state-machine logs whether PassthroughLLMAction runs.

    NOTE(review): both tests are synchronous even though the class is an
    IsolatedAsyncioTestCase — confirm whether async test methods were intended.
    """

    def test_passthrough_llm_action_not_invoked_via_logs(self):
        """A handled intent (greeting) must NOT trigger the passthrough action."""
        chat = TestChat(
            config,
            llm_completions=["user expressed greeting"],
        )
        rails = chat.app

        logger = logging.getLogger("nemoguardrails.colang.v2_x.runtime.statemachine")

        with self.assertLogs(logger, level="INFO") as log:
            messages = [{"role": "user", "content": "hi"}]
            response = rails.generate(messages=messages)
            # Check that 'PassthroughLLMActionFinished' never appears, i.e.
            # the passthrough action was not executed for a handled intent.
            passthrough_invoked = any(
                "PassthroughLLMActionFinished" in message for message in log.output
            )
            self.assertFalse(
                passthrough_invoked, "PassthroughLLMAction was invoked unexpectedly."
            )

        self.assertIn("content", response)
        self.assertIsInstance(response["content"], str)

    def test_passthrough_llm_action_invoked_via_logs(self):
        """An unhandled intent must fall through to the passthrough action."""
        chat = TestChat(
            config,
            # First completion: the detected (unhandled) intent; second: the
            # raw passthrough LLM answer.
            llm_completions=["user asked about capabilities", "a random text from llm"],
        )
        rails = chat.app

        logger = logging.getLogger("nemoguardrails.colang.v2_x.runtime.statemachine")

        with self.assertLogs(logger, level="INFO") as log:
            messages = [{"role": "user", "content": "What can you do?"}]
            response = rails.generate(messages=messages)
            # Check that 'StartPassthroughLLMAction' is in the logs
            passthrough_invoked = any(
                "StartPassthroughLLMAction" in message for message in log.output
            )
            self.assertTrue(
                passthrough_invoked, "PassthroughLLMAction was not invoked."
            )

        self.assertIn("content", response)
        self.assertIsInstance(response["content"], str)

0 commit comments

Comments
 (0)