Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions nemoguardrails/colang/v1_0/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,13 @@ async def _run_flows_in_parallel(
# Wrapper function to help reverse map the task result to the flow ID
async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
result = await func(*args, **kwargs)
if post_event:

has_stop = any(
event["type"] == "BotIntent" and event["intent"] == "stop"
for event in result
)

if post_event and not has_stop:
result.append(post_event)
args[1].append(
{"type": "event", "timestamp": time(), "data": post_event}
Expand Down Expand Up @@ -361,6 +367,7 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
unique_flow_ids[flow_uid] = task

stopped_task_results: List[dict] = []
stopped_task_processing_logs: List[dict] = []

# Process tasks as they complete using as_completed
try:
Expand All @@ -377,6 +384,9 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
# If this flow had a stop event
if has_stop:
stopped_task_results = task_results[flow_id] + result
stopped_task_processing_logs = task_processing_logs[
flow_id
].copy()

# Cancel all remaining tasks
for pending_task in tasks:
Expand Down Expand Up @@ -433,14 +443,19 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
finished_task_processing_logs.extend(task_processing_logs[flow_id])

if processing_log:
for plog in finished_task_processing_logs:
# Filter out "Listen" and "start_flow" events from task processing log
if plog["type"] == "event" and (
plog["data"]["type"] == "Listen"
or plog["data"]["type"] == "start_flow"
):
continue
processing_log.append(plog)

def filter_and_append(logs, target_log):
for plog in logs:
# Filter out "Listen" and "start_flow" events from task processing log
if plog["type"] == "event" and (
plog["data"]["type"] == "Listen"
or plog["data"]["type"] == "start_flow"
):
continue
target_log.append(plog)

filter_and_append(stopped_task_processing_logs, processing_log)
filter_and_append(finished_task_processing_logs, processing_log)

# We pack all events into a single event to add it to the event history.
history_events = new_event_dict(
Expand Down
129 changes: 129 additions & 0 deletions tests/test_parallel_rails.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,132 @@ async def test_parallel_rails_output_fail_2():
and result.response[0]["content"]
== "I cannot express a term in the bot answer."
)


@pytest.mark.asyncio
async def test_parallel_rails_input_stop_flag():
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
chat = TestChat(
config,
llm_completions=[
"No",
"Hi there! How can I assist you with questions about the ABC Company today?",
"No",
],
)

chat >> "hi, this is a blocked term."
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)

stopped_rails = [rail for rail in result.log.activated_rails if rail.stop]
assert len(stopped_rails) == 1, "Expected exactly one stopped rail"
assert (
"check blocked input terms" in stopped_rails[0].name
), f"Expected 'check blocked input terms' rail to be stopped, got {stopped_rails[0].name}"


@pytest.mark.asyncio
async def test_parallel_rails_output_stop_flag():
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
chat = TestChat(
config,
llm_completions=[
"No",
"Hi there! This is a blocked term!",
"No",
],
)

chat >> "hi!"
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)

stopped_rails = [rail for rail in result.log.activated_rails if rail.stop]
assert len(stopped_rails) == 1, "Expected exactly one stopped rail"
assert (
"check blocked output terms" in stopped_rails[0].name
), f"Expected 'check blocked output terms' rail to be stopped, got {stopped_rails[0].name}"


@pytest.mark.asyncio
async def test_parallel_rails_client_code_pattern():
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
chat = TestChat(
config,
llm_completions=[
"No",
"Hi there! This is a blocked term!",
"No",
],
)

chat >> "hi!"
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)

activated_rails = result.log.activated_rails if result.log else None
assert activated_rails is not None, "Expected activated_rails to be present"

rails_to_check = [
"self check output",
"check blocked output terms $duration=1.0",
]
rails_set = set(rails_to_check)

stopping_rails = [rail for rail in activated_rails if rail.stop]

assert len(stopping_rails) > 0, "Expected at least one stopping rail"

blocked_rails = []
for rail in stopping_rails:
if rail.name in rails_set:
blocked_rails.append(rail.name)

assert (
len(blocked_rails) == 1
), f"Expected exactly one blocked rail from our check list, got {len(blocked_rails)}: {blocked_rails}"
assert (
"check blocked output terms $duration=1.0" in blocked_rails
), f"Expected 'check blocked output terms $duration=1.0' to be blocked, got {blocked_rails}"

for rail in activated_rails:
if (
rail.name in rails_set
and rail.name != "check blocked output terms $duration=1.0"
):
assert (
not rail.stop
), f"Non-blocked rail {rail.name} should not have stop=True"


@pytest.mark.asyncio
async def test_parallel_rails_multiple_activated_rails():
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
chat = TestChat(
config,
llm_completions=[
"No",
"Hi there! This is a blocked term!",
"No",
],
)

chat >> "hi!"
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)

activated_rails = result.log.activated_rails if result.log else None
assert activated_rails is not None, "Expected activated_rails to be present"
assert len(activated_rails) > 1, (
f"Expected multiple activated_rails, got {len(activated_rails)}: "
f"{[rail.name for rail in activated_rails]}"
)

stopped_rails = [rail for rail in activated_rails if rail.stop]
assert len(stopped_rails) == 1, (
f"Expected exactly one stopped rail, got {len(stopped_rails)}: "
f"{[rail.name for rail in stopped_rails]}"
)

rails_with_stop_true = [rail for rail in activated_rails if rail.stop is True]
assert len(rails_with_stop_true) == 1, (
f"Expected exactly one rail with stop=True, got {len(rails_with_stop_true)}: "
f"{[rail.name for rail in rails_with_stop_true]}"
)