@@ -152,3 +152,132 @@ async def test_parallel_rails_output_fail_2():
152152 and result .response [0 ]["content" ]
153153 == "I cannot express a term in the bot answer."
154154 )
155+
156+
157+ @pytest .mark .asyncio
158+ async def test_parallel_rails_input_stop_flag ():
159+ config = RailsConfig .from_path (os .path .join (CONFIGS_FOLDER , "parallel_rails" ))
160+ chat = TestChat (
161+ config ,
162+ llm_completions = [
163+ "No" ,
164+ "Hi there! How can I assist you with questions about the ABC Company today?" ,
165+ "No" ,
166+ ],
167+ )
168+
169+ chat >> "hi, this is a blocked term."
170+ result = await chat .app .generate_async (messages = chat .history , options = OPTIONS )
171+
172+ stopped_rails = [rail for rail in result .log .activated_rails if rail .stop ]
173+ assert len (stopped_rails ) == 1 , "Expected exactly one stopped rail"
174+ assert (
175+ "check blocked input terms" in stopped_rails [0 ].name
176+ ), f"Expected 'check blocked input terms' rail to be stopped, got { stopped_rails [0 ].name } "
177+
178+
179+ @pytest .mark .asyncio
180+ async def test_parallel_rails_output_stop_flag ():
181+ config = RailsConfig .from_path (os .path .join (CONFIGS_FOLDER , "parallel_rails" ))
182+ chat = TestChat (
183+ config ,
184+ llm_completions = [
185+ "No" ,
186+ "Hi there! This is a blocked term!" ,
187+ "No" ,
188+ ],
189+ )
190+
191+ chat >> "hi!"
192+ result = await chat .app .generate_async (messages = chat .history , options = OPTIONS )
193+
194+ stopped_rails = [rail for rail in result .log .activated_rails if rail .stop ]
195+ assert len (stopped_rails ) == 1 , "Expected exactly one stopped rail"
196+ assert (
197+ "check blocked output terms" in stopped_rails [0 ].name
198+ ), f"Expected 'check blocked output terms' rail to be stopped, got { stopped_rails [0 ].name } "
199+
200+
201+ @pytest .mark .asyncio
202+ async def test_parallel_rails_client_code_pattern ():
203+ config = RailsConfig .from_path (os .path .join (CONFIGS_FOLDER , "parallel_rails" ))
204+ chat = TestChat (
205+ config ,
206+ llm_completions = [
207+ "No" ,
208+ "Hi there! This is a blocked term!" ,
209+ "No" ,
210+ ],
211+ )
212+
213+ chat >> "hi!"
214+ result = await chat .app .generate_async (messages = chat .history , options = OPTIONS )
215+
216+ activated_rails = result .log .activated_rails if result .log else None
217+ assert activated_rails is not None , "Expected activated_rails to be present"
218+
219+ rails_to_check = [
220+ "self check output" ,
221+ "check blocked output terms $duration=1.0" ,
222+ ]
223+ rails_set = set (rails_to_check )
224+
225+ stopping_rails = [rail for rail in activated_rails if rail .stop ]
226+
227+ assert len (stopping_rails ) > 0 , "Expected at least one stopping rail"
228+
229+ blocked_rails = []
230+ for rail in stopping_rails :
231+ if rail .name in rails_set :
232+ blocked_rails .append (rail .name )
233+
234+ assert (
235+ len (blocked_rails ) == 1
236+ ), f"Expected exactly one blocked rail from our check list, got { len (blocked_rails )} : { blocked_rails } "
237+ assert (
238+ "check blocked output terms $duration=1.0" in blocked_rails
239+ ), f"Expected 'check blocked output terms $duration=1.0' to be blocked, got { blocked_rails } "
240+
241+ for rail in activated_rails :
242+ if (
243+ rail .name in rails_set
244+ and rail .name != "check blocked output terms $duration=1.0"
245+ ):
246+ assert (
247+ not rail .stop
248+ ), f"Non-blocked rail { rail .name } should not have stop=True"
249+
250+
251+ @pytest .mark .asyncio
252+ async def test_parallel_rails_multiple_activated_rails ():
253+ config = RailsConfig .from_path (os .path .join (CONFIGS_FOLDER , "parallel_rails" ))
254+ chat = TestChat (
255+ config ,
256+ llm_completions = [
257+ "No" ,
258+ "Hi there! This is a blocked term!" ,
259+ "No" ,
260+ ],
261+ )
262+
263+ chat >> "hi!"
264+ result = await chat .app .generate_async (messages = chat .history , options = OPTIONS )
265+
266+ activated_rails = result .log .activated_rails if result .log else None
267+ assert activated_rails is not None , "Expected activated_rails to be present"
268+ assert len (activated_rails ) > 1 , (
269+ f"Expected multiple activated_rails, got { len (activated_rails )} : "
270+ f"{ [rail .name for rail in activated_rails ]} "
271+ )
272+
273+ stopped_rails = [rail for rail in activated_rails if rail .stop ]
274+ assert len (stopped_rails ) == 1 , (
275+ f"Expected exactly one stopped rail, got { len (stopped_rails )} : "
276+ f"{ [rail .name for rail in stopped_rails ]} "
277+ )
278+
279+ rails_with_stop_true = [rail for rail in activated_rails if rail .stop is True ]
280+ assert len (rails_with_stop_true ) == 1 , (
281+ f"Expected exactly one rail with stop=True, got { len (rails_with_stop_true )} : "
282+ f"{ [rail .name for rail in rails_with_stop_true ]} "
283+ )
0 commit comments