13 changes: 13 additions & 0 deletions nemoguardrails/rails/llm/llmrails.py
@@ -1248,6 +1248,18 @@ async def generate_async(
new_message["tool_calls"] = tool_calls
return new_message

    def _validate_streaming_with_output_rails(self) -> None:
        if len(self.config.rails.output.flows) > 0 and (
            not self.config.rails.output.streaming
            or not self.config.rails.output.streaming.enabled
        ):
            raise ValueError(
                "stream_async() cannot be used when output rails are configured but "
                "rails.output.streaming.enabled is False. Either set "
                "rails.output.streaming.enabled to True in your configuration, or use "
                "generate_async() instead of stream_async()."
            )

    def stream_async(
        self,
        prompt: Optional[str] = None,
@@ -1259,6 +1271,7 @@ def stream_async(
    ) -> AsyncIterator[str]:
        """Simplified interface for getting directly the streamed tokens from the LLM."""

        self._validate_streaming_with_output_rails()
        # if an external generator is provided, use it directly
        if generator:
            if (
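In practice, the new guard makes `stream_async()` fail fast instead of silently returning an unguarded stream when the output rails cannot run on streamed chunks. The sketch below is illustrative only: it assumes a hypothetical `./config` directory whose output rails are configured without `rails.output.streaming.enabled: true`, and shows one way a caller might handle the error.

```python
import asyncio

from nemoguardrails import LLMRails, RailsConfig


async def main():
    # Hypothetical config path: output rails are defined, but
    # rails.output.streaming.enabled has been left unset or set to false.
    config = RailsConfig.from_path("./config")
    rails = LLMRails(config)

    messages = [{"role": "user", "content": "Hi!"}]
    try:
        # With this change, the ValueError is raised before any token is produced.
        async for chunk in rails.stream_async(messages=messages):
            print(chunk, end="")
    except ValueError:
        # Fall back to the non-streaming API, which still applies the output rails.
        response = await rails.generate_async(messages=messages)
        print(response["content"])


asyncio.run(main())
```

The other fix, as the error message itself suggests, is to enable streaming output rails in the configuration rather than catching the exception.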
27 changes: 15 additions & 12 deletions tests/test_parallel_streaming_output_rails.py
@@ -605,21 +605,24 @@ async def test_parallel_streaming_output_rails_performance_benefits():
async def test_parallel_streaming_output_rails_default_config_behavior(
parallel_output_rails_default_config,
):
"""Tests parallel output rails with default streaming configuration"""
"""Tests that stream_async raises an error with default config (no explicit streaming config)"""

llm_completions = [
' express greeting\nbot express greeting\n "Hi, how are you doing?"',
' "This is a test message with default streaming config."',
]
from nemoguardrails import LLMRails

chunks = await run_parallel_self_check_test(
parallel_output_rails_default_config, llm_completions
)
llmrails = LLMRails(parallel_output_rails_default_config)

response = "".join(chunks)
assert len(response) > 0
assert len(chunks) > 0
assert "test message" in response
with pytest.raises(ValueError) as exc_info:
async for chunk in llmrails.stream_async(
messages=[{"role": "user", "content": "Hi!"}]
):
pass

assert str(exc_info.value) == (
"stream_async() cannot be used when output rails are configured but "
"rails.output.streaming.enabled is False. Either set "
"rails.output.streaming.enabled to True in your configuration, or use "
"generate_async() instead of stream_async()."
)

await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})

89 changes: 89 additions & 0 deletions tests/test_streaming.py
@@ -474,6 +474,95 @@ def _calculate_number_of_actions(input_length, chunk_size, context_size):
return math.ceil((input_length - context_size) / (chunk_size - context_size))


@pytest.mark.asyncio
async def test_streaming_with_output_rails_disabled_raises_error():
    config = RailsConfig.from_content(
        config={
            "models": [],
            "rails": {
                "output": {
                    "flows": {"self check output"},
                    "streaming": {
                        "enabled": False,
                    },
                }
            },
            "streaming": True,
            "prompts": [{"task": "self_check_output", "content": "a test template"}],
        },
        colang_content="""
    define user express greeting
      "hi"

    define flow
      user express greeting
      bot tell joke
    """,
    )

    chat = TestChat(
        config,
        llm_completions=[],
        streaming=True,
    )

    with pytest.raises(ValueError) as exc_info:
        async for chunk in chat.app.stream_async(
            messages=[{"role": "user", "content": "Hi!"}],
        ):
            pass

    assert str(exc_info.value) == (
        "stream_async() cannot be used when output rails are configured but "
        "rails.output.streaming.enabled is False. Either set "
        "rails.output.streaming.enabled to True in your configuration, or use "
        "generate_async() instead of stream_async()."
    )


@pytest.mark.asyncio
async def test_streaming_with_output_rails_no_streaming_config_raises_error():
    config = RailsConfig.from_content(
        config={
            "models": [],
            "rails": {
                "output": {
                    "flows": {"self check output"},
                }
            },
            "streaming": True,
            "prompts": [{"task": "self_check_output", "content": "a test template"}],
        },
        colang_content="""
    define user express greeting
      "hi"

    define flow
      user express greeting
      bot tell joke
    """,
    )

    chat = TestChat(
        config,
        llm_completions=[],
        streaming=True,
    )

    with pytest.raises(ValueError) as exc_info:
        async for chunk in chat.app.stream_async(
            messages=[{"role": "user", "content": "Hi!"}],
        ):
            pass

    assert str(exc_info.value) == (
        "stream_async() cannot be used when output rails are configured but "
        "rails.output.streaming.enabled is False. Either set "
        "rails.output.streaming.enabled to True in your configuration, or use "
        "generate_async() instead of stream_async()."
    )


@pytest.mark.asyncio
async def test_streaming_error_handling():
"""Test that errors during streaming are properly formatted and returned."""
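For contrast with the failing configurations exercised above, a configuration that the new validation accepts keeps the same shape but enables streaming for the output rails. This is a minimal sketch mirroring the test dicts; the empty models list is illustrative only, and a real deployment would define a main model.

```python
from nemoguardrails import RailsConfig

# Same structure as the test configs above, but with streaming enabled for
# output rails, which is the shape stream_async() requires once output flows
# are configured.
config = RailsConfig.from_content(
    config={
        "models": [],  # illustrative only; a real setup defines a main model here
        "rails": {
            "output": {
                "flows": ["self check output"],
                "streaming": {
                    "enabled": True,
                },
            }
        },
        "streaming": True,
        "prompts": [{"task": "self_check_output", "content": "a test template"}],
    },
)
```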
69 changes: 13 additions & 56 deletions tests/test_streaming_output_rails.py
@@ -87,18 +87,6 @@ def output_rails_streaming_config_default():
)


@pytest.mark.asyncio
async def test_stream_async_streaming_disabled(output_rails_streaming_config_default):
"""Tests if stream_async returns a StreamingHandler instance when streaming is disabled"""

llmrails = LLMRails(output_rails_streaming_config_default)

result = llmrails.stream_async(prompt="test")
assert isinstance(
result, StreamingHandler
), "Expected StreamingHandler instance when streaming is disabled"


@pytest.mark.asyncio
async def test_stream_async_streaming_enabled(output_rails_streaming_config):
"""Tests if stream_async returns does not return StreamingHandler instance when streaming is enabled"""
@@ -175,33 +163,23 @@ async def test_streaming_output_rails_blocked_explicit(output_rails_streaming_co
async def test_streaming_output_rails_blocked_default_config(
output_rails_streaming_config_default,
):
"""Tests if output rails streaming default config do not block content with BLOCK keyword"""
"""Tests that stream_async raises an error with default config (output rails without explicit streaming config)"""

# text with a BLOCK keyword
llm_completions = [
' express greeting\nbot express greeting\n "Hi, how are you doing?"',
' "This is a [BLOCK] joke that should be blocked."',
]
llmrails = LLMRails(output_rails_streaming_config_default)

chunks = await run_self_check_test(
output_rails_streaming_config_default, llm_completions
with pytest.raises(ValueError) as exc_info:
async for chunk in llmrails.stream_async(
messages=[{"role": "user", "content": "Hi!"}]
):
pass

assert str(exc_info.value) == (
"stream_async() cannot be used when output rails are configured but "
"rails.output.streaming.enabled is False. Either set "
"rails.output.streaming.enabled to True in your configuration, or use "
"generate_async() instead of stream_async()."
)

expected_error = {
"error": {
"message": "Blocked by self check output rails.",
"type": "guardrails_violation",
"param": "self check output",
"code": "content_blocked",
}
}

error_chunks = [
json.loads(chunk) for chunk in chunks if chunk.startswith('{"error":')
]
assert len(error_chunks) == 0
assert expected_error not in error_chunks

await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})


@@ -231,27 +209,6 @@ async def test_streaming_output_rails_blocked_at_start(output_rails_streaming_co
await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})


@pytest.mark.asyncio
async def test_streaming_output_rails_default_config_not_blocked_at_start(
output_rails_streaming_config_default,
):
"""Tests blocking with BLOCK at the very beginning of the response does not return abort sse"""

llm_completions = [
' express greeting\nbot express greeting\n "Hi, how are you doing?"',
' "[BLOCK] This should be blocked immediately at the start."',
]

chunks = await run_self_check_test(
output_rails_streaming_config_default, llm_completions
)

with pytest.raises(JSONDecodeError):
json.loads(chunks[0])

await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})


async def simple_token_generator() -> AsyncIterator[str]:
"""Simple generator that yields tokens."""
tokens = ["Hello", " ", "world", "!"]