Commit ee34fee

Feat: Support Output Rails Streaming (#966)
This commit introduces streaming support for output rails in the nemoguardrails configuration. It includes the following changes:

- Added `OutputRailsStreamingConfig` class to define streaming configuration for output rails.
- Updated `OutputRails` class to include streaming configuration.
- Modified `LLMGenerationActions` to handle streaming for output rails.
- Enhanced `LLMRails` to run output rails in streaming mode.
- Implemented `BufferStrategy` for buffering and processing streamed tokens.
- Implemented `RollingBuffer`.
- Updated `StreamingHandler` to support asynchronous iteration.
- Replaced direct string replacement with a partial function `get_action_name_from_flow_id` to get action names from flow ids.
- Added `_get_last_context_message` to retrieve the last context message.
- Modified `_get_latest_user_message` to return an empty dict instead of None.
- Updated `_prepare_params` to include context message and action params.
- Enhanced `_prepare_params` to handle placeholders in action params.
- Replaced `get_action_name_from_flow_id` with `get_action_details_from_flow_id`.
- feat: handle ABORT SSE in streaming output.
1 parent 9fa2740 commit ee34fee

File tree: 13 files changed (+905 −24 lines changed)

nemoguardrails/actions/actions.py

Lines changed: 3 additions & 2 deletions

@@ -37,8 +37,9 @@ def action(
         name (Optional[str]): The name to associate with the action.
         execute_async: Whether the function should be executed in async mode.
         output_mapping (Optional[Callable[[Any], bool]]): A function to interpret the action's result.
-            It should accept the return value (e.g. the first element of a tuple) and return True if the output
-            should be considered blocked.
+            It accepts the return value (e.g. the first element of a tuple) and return True if the output
+            is not safe.
+
     Returns:
         callable: The decorated function or class.
     """

nemoguardrails/actions/llm/generation.py

Lines changed: 11 additions & 1 deletion

@@ -762,6 +762,16 @@ async def generate_bot_message(
         streaming_handler = streaming_handler_var.get()

+        # When 'output rails streaming' is enabled, we must disable (skip) the
+        # output rails that get executed on $bot_message, as they are executed
+        # separately in llmrails.py. It does not work when passed as context
+        # in `run_output_rails_in_streaming`.
+        # streaming_handler is set when the stream_async method is used.
+
+        if streaming_handler and len(self.config.rails.output.flows) > 0:
+            # if streaming_handler and self.config.rails.output.streaming.enabled:
+            context_updates["skip_output_rails"] = True
+
         if bot_intent in self.config.bot_messages:
             # Choose a message randomly from self.config.bot_messages[bot_message]
             # However, in test mode, we always choose the first one, to keep it predictable.

@@ -779,7 +789,7 @@ async def generate_bot_message(
             context_updates["skip_output_rails"] = True

         # Check if the output is supposed to be the content of a context variable
-        elif bot_intent[0] == "$" and bot_intent[1:] in context:
+        elif bot_intent and bot_intent[0] == "$" and bot_intent[1:] in context:
             bot_utterance = context[bot_intent[1:]]

         else:

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Tuple
+
+
+def default_output_mapping(result: Any) -> bool:
+    """A fallback output mapping if an action does not provide one.
+
+    - For a boolean result: assume True means allowed (so block if False).
+    - For a numeric result: use 0.5 as a threshold (block if the value is less).
+    - Otherwise, assume the result is allowed.
+    """
+    if isinstance(result, bool):
+        return not result
+    elif isinstance(result, (int, float)):
+        return result < 0.5
+    else:
+        return False
+
+
+def is_output_blocked(result: Any, action_func: Any) -> bool:
+    """Determines if an action result is not allowed using its attached mapping.
+
+    Args:
+        result: The value returned by the action.
+        action_func: The action function (whose metadata contains the mapping).
+
+    Returns:
+        True if the mapping indicates that the output should be blocked, False otherwise.
+    """
+    mapping = getattr(action_func, "action_meta", {}).get("output_mapping")
+    if mapping is None:
+        mapping = default_output_mapping
+
+    if not isinstance(result, Tuple):
+        result = (result,)
+
+    return mapping(result[0])
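
A small sketch of how the mapping resolution behaves; the import path is hypothetical since the listing does not show the new file's name, and the actions are illustrative:

    from nemoguardrails.actions import action
    # Hypothetical import path for the new module shown above.
    from nemoguardrails.actions.output_mapping import is_output_blocked


    @action(output_mapping=lambda verdict: verdict == "unsafe")
    async def self_check() -> str:
        return "unsafe"


    async def plain_check() -> bool:
        return False


    # The mapping attached via the decorator's action_meta flags this result.
    print(is_output_blocked("unsafe", self_check))  # True

    # plain_check has no action_meta, so default_output_mapping applies:
    # booleans are treated as "True means allowed", so False is blocked.
    print(is_output_blocked(False, plain_check))  # True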

nemoguardrails/cli/chat.py

Lines changed: 22 additions & 9 deletions

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import asyncio
+import json
 import os
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, cast

@@ -82,17 +83,29 @@ async def _run_chat_v1_0(
     if not server_url:
         # If we have streaming from a locally loaded config, we initialize the handler.
         if streaming and not server_url and rails_app.main_llm_supports_streaming:
-            streaming_handler = StreamingHandler(enable_print=True)
-        else:
-            streaming_handler = None
-
-        bot_message = await rails_app.generate_async(
-            messages=history, streaming_handler=streaming_handler
-        )
-
-        if not streaming or not rails_app.main_llm_supports_streaming:
-            # We print bot messages in green.
-            console.print("[green]" + f"{bot_message['content']}" + "[/]")
+            bot_message_list = []
+            async for chunk in rails_app.stream_async(messages=history):
+                if '{"event": "ABORT"' in chunk:
+                    dict_chunk = json.loads(chunk)
+                    console.print(
+                        "\n\n[red]"
+                        + f"ABORT streaming. {dict_chunk['data']}"
+                        + "[/]"
+                    )
+                    break
+
+                console.print("[green]" + f"{chunk}" + "[/]", end="")
+                bot_message_list.append(chunk)
+
+            bot_message_text = "".join(bot_message_list)
+            bot_message = {"role": "assistant", "content": bot_message_text}
+
+        else:
+            bot_message = await rails_app.generate_async(messages=history)
+
+        if not streaming or not rails_app.main_llm_supports_streaming:
+            # We print bot messages in green.
+            console.print("[green]" + f"{bot_message['content']}" + "[/]")
     else:
         data = {
             "config_id": config_id,

nemoguardrails/logging/callbacks.py

Lines changed: 1 addition & 1 deletion

@@ -246,7 +246,7 @@ async def on_llm_end(
             llm_call_info.completion_tokens = token_usage.get("completion_tokens", 0)

         if not token_stats_found:
-            log.warning(
+            log.info(
                 "Token stats in LLM call info cannot be computed for current model!"
             )
nemoguardrails/rails/llm/buffer.py

Lines changed: 108 additions & 0 deletions

@@ -0,0 +1,108 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import AsyncGenerator, List, Tuple
+
+from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig
+
+
+class BufferStrategy(ABC):
+    @classmethod
+    @abstractmethod
+    def from_config(cls, config: OutputRailsStreamingConfig) -> "BufferStrategy":
+        pass
+
+    # The abstract method is not async to ensure the return type
+    # matches the async generator in the concrete implementation.
+    @abstractmethod
+    def __call__(
+        self, streaming_handler
+    ) -> AsyncGenerator[Tuple[List[str], str], None]:
+        pass
+
+    @abstractmethod
+    def generate_chunk_str(self, *args, **kwargs) -> str:
+        pass
+
+
+class RollingBuffer(BufferStrategy):
+    """A minimal buffer strategy that buffers chunks and yields them when the buffer is full.
+
+    Args:
+        buffer_context_size (int): The number of tokens carried over from the previous
+            chunk to provide context for continuity in processing.
+        buffer_chunk_size (int): The number of tokens in each processing chunk. This is
+            the size of the token block on which output rails are applied.
+    """
+
+    def __init__(self, buffer_context_size: int = 5, buffer_chunk_size: int = 10):
+        self.buffer_context_size = buffer_context_size
+        self.buffer_chunk_size = buffer_chunk_size
+        self.last_index = 0
+
+    @classmethod
+    def from_config(cls, config: OutputRailsStreamingConfig):
+        return cls(
+            buffer_context_size=config.context_size, buffer_chunk_size=config.chunk_size
+        )
+
+    async def __call__(
+        self, streaming_handler
+    ) -> AsyncGenerator[Tuple[List[str], str], None]:
+        buffer = []
+        index = 0
+
+        async for chunk in streaming_handler:
+            buffer.append(chunk)
+            index += 1
+
+            if len(buffer) >= self.buffer_chunk_size:
+                yield (
+                    # We apply output rails on the buffer.
+                    buffer[-self.buffer_chunk_size - self.buffer_context_size :],
+                    # generate_chunk_str is what gets printed in the console or
+                    # yielded to the user, to avoid repeating the already
+                    # streamed/printed chunk.
+                    self.generate_chunk_str(
+                        buffer[-self.buffer_chunk_size - self.buffer_context_size :],
+                        index,
+                    ),
+                )
+                buffer = buffer[-self.buffer_context_size :]
+
+        # Yield any remaining buffer if it's not empty.
+        if buffer:
+            yield (
+                buffer,
+                self.generate_chunk_str(
+                    buffer[-self.buffer_chunk_size - self.buffer_context_size :], index
+                ),
+            )
+
+    def generate_chunk_str(self, buffer, current_index) -> str:
+        if current_index <= self.last_index:
+            return ""
+
+        new_chunks = buffer[self.last_index - current_index :]
+        self.last_index = current_index
+        # TODO: something causes duplicate whitespaces between tokens; figure out why.
+        # If `return "".join(new_chunks)` works, the issue might be elsewhere in the
+        # code where the chunks are generated or processed. Ensure the chunks
+        # themselves do not contain extra spaces.
+        return "".join(new_chunks)
+
+
+def get_buffer_strategy(config: OutputRailsStreamingConfig) -> BufferStrategy:
+    # TODO: use a factory function or class.
+    # Currently we only have RollingBuffer; in the future we will use a registry.
+    return RollingBuffer.from_config(config)
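
A quick, self-contained sketch of the `RollingBuffer` contract; the fake token stream below stands in for a real `StreamingHandler`:

    import asyncio

    from nemoguardrails.rails.llm.buffer import RollingBuffer


    async def fake_stream():
        for token in ["All ", "work ", "and ", "no ", "play ", "makes ", "Jack ", "dull "]:
            yield token


    async def main():
        strategy = RollingBuffer(buffer_context_size=2, buffer_chunk_size=4)
        async for rails_window, new_text in strategy(fake_stream()):
            # rails_window is the token block the output rails run on;
            # new_text contains only the not-yet-emitted tokens, so nothing repeats.
            print(rails_window, "->", repr(new_text))


    asyncio.run(main())

With these settings, each yielded window overlaps the previous one by two tokens of context, while new_text advances through the stream without repetition.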

nemoguardrails/rails/llm/config.py

Lines changed: 35 additions & 6 deletions

@@ -304,6 +304,27 @@ class InputRails(BaseModel):
     )


+class OutputRailsStreamingConfig(BaseModel):
+    """Configuration for managing streaming output of LLM tokens."""
+
+    enabled: bool = Field(
+        default=False, description="Enables streaming mode when True."
+    )
+    chunk_size: int = Field(
+        default=200,
+        description="The number of tokens in each processing chunk. This is the size of the token block on which output rails are applied.",
+    )
+    context_size: int = Field(
+        default=50,
+        description="The number of tokens carried over from the previous chunk to provide context for continuity in processing.",
+    )
+    stream_first: bool = Field(
+        default=True,
+        description="If True, token chunks are streamed immediately before output rails are applied.",
+    )
+    model_config = ConfigDict(extra="allow")
+
+
 class OutputRails(BaseModel):
     """Configuration of output rails."""

@@ -312,6 +333,11 @@ class OutputRails(BaseModel):
         description="The names of all the flows that implement output rails.",
     )

+    streaming: Optional[OutputRailsStreamingConfig] = Field(
+        default_factory=OutputRailsStreamingConfig,
+        description="Configuration for streaming output rails.",
+    )
+

 class RetrievalRails(BaseModel):
     """Configuration of retrieval rails."""

@@ -1201,12 +1227,15 @@ def parse_object(cls, obj):

     @property
     def streaming_supported(self):
-        """Whether the current config supports streaming or not.
-
-        Currently, we don't support streaming if there are output rails.
-        """
-        if len(self.rails.output.flows) > 0:
-            return False
+        """Whether the current config supports streaming or not."""
+
+        # if len(self.rails.output.flows) > 0:
+        #     # If output rails streaming is enabled, we keep this in case it
+        #     # is needed once we have per-rail support.
+        #     if self.rails.output.streaming.enabled:
+        #         return True
+        #     return False

         return True
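
Putting the new fields together, a minimal sketch of a config that enables output rails streaming; the `self check output` flow name is illustrative and assumes a corresponding rail is defined:

    from nemoguardrails import RailsConfig

    config = RailsConfig.from_content(yaml_content="""
    streaming: True
    rails:
      output:
        flows:
          - self check output
        streaming:
          enabled: True
          chunk_size: 200
          context_size: 50
          stream_first: True
    """)
    # The new OutputRailsStreamingConfig fields are parsed from the YAML above.
    assert config.rails.output.streaming.enabled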
