Skip to content

Commit 26c8264

Browse files
Pouyanpitgasser-nv
authored andcommitted
fix(runtime): ensure stop flag is set for policy violations in parallel rails (#1467)
* fix(runtime): ensure stop flag is set for policy violations in parallel rails * Update nemoguardrails/colang/v1_0/runtime/runtime.py
1 parent f87618f commit 26c8264

File tree

2 files changed

+153
-9
lines changed

2 files changed

+153
-9
lines changed

nemoguardrails/colang/v1_0/runtime/runtime.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,13 @@ async def _run_flows_in_parallel(
311311
# Wrapper function to help reverse map the task result to the flow ID
312312
async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
313313
result = await func(*args, **kwargs)
314-
if post_event:
314+
315+
has_stop = any(
316+
event["type"] == "BotIntent" and event["intent"] == "stop"
317+
for event in result
318+
)
319+
320+
if post_event and not has_stop:
315321
result.append(post_event)
316322
args[1].append(
317323
{"type": "event", "timestamp": time(), "data": post_event}
@@ -361,6 +367,7 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
361367
unique_flow_ids[flow_uid] = task
362368

363369
stopped_task_results: List[dict] = []
370+
stopped_task_processing_logs: List[dict] = []
364371

365372
# Process tasks as they complete using as_completed
366373
try:
@@ -377,6 +384,9 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
377384
# If this flow had a stop event
378385
if has_stop:
379386
stopped_task_results = task_results[flow_id] + result
387+
stopped_task_processing_logs = task_processing_logs[
388+
flow_id
389+
].copy()
380390

381391
# Cancel all remaining tasks
382392
for pending_task in tasks:
@@ -433,14 +443,19 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
433443
finished_task_processing_logs.extend(task_processing_logs[flow_id])
434444

435445
if processing_log:
436-
for plog in finished_task_processing_logs:
437-
# Filter out "Listen" and "start_flow" events from task processing log
438-
if plog["type"] == "event" and (
439-
plog["data"]["type"] == "Listen"
440-
or plog["data"]["type"] == "start_flow"
441-
):
442-
continue
443-
processing_log.append(plog)
446+
447+
def filter_and_append(logs, target_log):
448+
for plog in logs:
449+
# Filter out "Listen" and "start_flow" events from task processing log
450+
if plog["type"] == "event" and (
451+
plog["data"]["type"] == "Listen"
452+
or plog["data"]["type"] == "start_flow"
453+
):
454+
continue
455+
target_log.append(plog)
456+
457+
filter_and_append(stopped_task_processing_logs, processing_log)
458+
filter_and_append(finished_task_processing_logs, processing_log)
444459

445460
# We pack all events into a single event to add it to the event history.
446461
history_events = new_event_dict(

tests/test_parallel_rails.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,132 @@ async def test_parallel_rails_output_fail_2():
152152
and result.response[0]["content"]
153153
== "I cannot express a term in the bot answer."
154154
)
155+
156+
157+
@pytest.mark.asyncio
158+
async def test_parallel_rails_input_stop_flag():
159+
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
160+
chat = TestChat(
161+
config,
162+
llm_completions=[
163+
"No",
164+
"Hi there! How can I assist you with questions about the ABC Company today?",
165+
"No",
166+
],
167+
)
168+
169+
chat >> "hi, this is a blocked term."
170+
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
171+
172+
stopped_rails = [rail for rail in result.log.activated_rails if rail.stop]
173+
assert len(stopped_rails) == 1, "Expected exactly one stopped rail"
174+
assert (
175+
"check blocked input terms" in stopped_rails[0].name
176+
), f"Expected 'check blocked input terms' rail to be stopped, got {stopped_rails[0].name}"
177+
178+
179+
@pytest.mark.asyncio
180+
async def test_parallel_rails_output_stop_flag():
181+
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
182+
chat = TestChat(
183+
config,
184+
llm_completions=[
185+
"No",
186+
"Hi there! This is a blocked term!",
187+
"No",
188+
],
189+
)
190+
191+
chat >> "hi!"
192+
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
193+
194+
stopped_rails = [rail for rail in result.log.activated_rails if rail.stop]
195+
assert len(stopped_rails) == 1, "Expected exactly one stopped rail"
196+
assert (
197+
"check blocked output terms" in stopped_rails[0].name
198+
), f"Expected 'check blocked output terms' rail to be stopped, got {stopped_rails[0].name}"
199+
200+
201+
@pytest.mark.asyncio
202+
async def test_parallel_rails_client_code_pattern():
203+
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
204+
chat = TestChat(
205+
config,
206+
llm_completions=[
207+
"No",
208+
"Hi there! This is a blocked term!",
209+
"No",
210+
],
211+
)
212+
213+
chat >> "hi!"
214+
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
215+
216+
activated_rails = result.log.activated_rails if result.log else None
217+
assert activated_rails is not None, "Expected activated_rails to be present"
218+
219+
rails_to_check = [
220+
"self check output",
221+
"check blocked output terms $duration=1.0",
222+
]
223+
rails_set = set(rails_to_check)
224+
225+
stopping_rails = [rail for rail in activated_rails if rail.stop]
226+
227+
assert len(stopping_rails) > 0, "Expected at least one stopping rail"
228+
229+
blocked_rails = []
230+
for rail in stopping_rails:
231+
if rail.name in rails_set:
232+
blocked_rails.append(rail.name)
233+
234+
assert (
235+
len(blocked_rails) == 1
236+
), f"Expected exactly one blocked rail from our check list, got {len(blocked_rails)}: {blocked_rails}"
237+
assert (
238+
"check blocked output terms $duration=1.0" in blocked_rails
239+
), f"Expected 'check blocked output terms $duration=1.0' to be blocked, got {blocked_rails}"
240+
241+
for rail in activated_rails:
242+
if (
243+
rail.name in rails_set
244+
and rail.name != "check blocked output terms $duration=1.0"
245+
):
246+
assert (
247+
not rail.stop
248+
), f"Non-blocked rail {rail.name} should not have stop=True"
249+
250+
251+
@pytest.mark.asyncio
252+
async def test_parallel_rails_multiple_activated_rails():
253+
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
254+
chat = TestChat(
255+
config,
256+
llm_completions=[
257+
"No",
258+
"Hi there! This is a blocked term!",
259+
"No",
260+
],
261+
)
262+
263+
chat >> "hi!"
264+
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
265+
266+
activated_rails = result.log.activated_rails if result.log else None
267+
assert activated_rails is not None, "Expected activated_rails to be present"
268+
assert len(activated_rails) > 1, (
269+
f"Expected multiple activated_rails, got {len(activated_rails)}: "
270+
f"{[rail.name for rail in activated_rails]}"
271+
)
272+
273+
stopped_rails = [rail for rail in activated_rails if rail.stop]
274+
assert len(stopped_rails) == 1, (
275+
f"Expected exactly one stopped rail, got {len(stopped_rails)}: "
276+
f"{[rail.name for rail in stopped_rails]}"
277+
)
278+
279+
rails_with_stop_true = [rail for rail in activated_rails if rail.stop is True]
280+
assert len(rails_with_stop_true) == 1, (
281+
f"Expected exactly one rail with stop=True, got {len(rails_with_stop_true)}: "
282+
f"{[rail.name for rail in rails_with_stop_true]}"
283+
)

0 commit comments

Comments
 (0)