
Commit 62975e9

feat: add generation metadata to streaming chunks (#1011)
1 parent 2d9a41d commit 62975e9

File tree

2 files changed: +75 -45 lines


nemoguardrails/rails/llm/llmrails.py

Lines changed: 4 additions & 1 deletion
@@ -925,9 +925,12 @@ def stream_async(
         messages: Optional[List[dict]] = None,
         options: Optional[Union[dict, GenerationOptions]] = None,
         state: Optional[Union[dict, State]] = None,
+        include_generation_metadata: Optional[bool] = False,
     ) -> AsyncIterator[str]:
         """Simplified interface for getting directly the streamed tokens from the LLM."""
-        streaming_handler = StreamingHandler()
+        streaming_handler = StreamingHandler(
+            include_generation_metadata=include_generation_metadata
+        )

         # todo use a context var for buffer strategy and return it here?
         # then iterating over buffer strategy is nested loop?
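
With this change, a caller can opt in to metadata-bearing chunks when streaming. A minimal consumer sketch; the config path and message are illustrative, and the dict shape with "text" and "generation_info" keys is an assumption based on the payload StreamingHandler enqueues in the streaming.py changes below:

import asyncio

from nemoguardrails import LLMRails, RailsConfig


async def main():
    # Hypothetical config directory; any valid rails configuration would do.
    config = RailsConfig.from_path("./config")
    rails = LLMRails(config)

    async for chunk in rails.stream_async(
        messages=[{"role": "user", "content": "Hello!"}],
        include_generation_metadata=True,
    ):
        # With the flag on, chunks are expected to be dicts; with it off
        # (the default), they remain plain strings as before.
        if isinstance(chunk, dict):
            print(chunk.get("text") or "", end="", flush=True)
        else:
            print(chunk, end="", flush=True)


asyncio.run(main())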

nemoguardrails/streaming.py

Lines changed: 71 additions & 44 deletions
@@ -31,12 +31,17 @@
 class StreamingHandler(AsyncCallbackHandler, AsyncIterator):
     """Streaming async handler.

-    Implements the LangChain AsyncCallbackHandler, so it can be notified of new tokens.
-    It also implements the AsyncIterator interface, so it can be used directly to stream
+    Implements the LangChain AsyncCallbackHandler so it can be notified of new tokens.
+    It also implements the AsyncIterator interface so it can be used directly to stream
     back the response.
     """

-    def __init__(self, enable_print: bool = False, enable_buffer: bool = False):
+    def __init__(
+        self,
+        enable_print: bool = False,
+        enable_buffer: bool = False,
+        include_generation_metadata: Optional[bool] = False,
+    ):
         # A unique id for the stream handler
         self.uid = new_uuid()

@@ -50,34 +55,37 @@ def __init__(self, enable_print: bool = False, enable_buffer: bool = False):
         # When buffering is enabled, the chunks will gather in a buffer.
         self.enable_buffer = enable_buffer

-        # The prefix/suffix that should be removed
+        # The prefix/suffix that should be removed (text-only processing)
         self.prefix = None
         self.suffix = None

         # The current chunk which needs to be checked for prefix/suffix matching
         self.current_chunk = ""

-        # The current buffer, until we start the processing.
+        # The current buffer until we start processing.
         self.buffer = ""

         # The full completion
         self.completion = ""

-        # Weather we're interested in the top k non-empty lines
+        # Whether we're interested in the top k non-empty lines
         self.k = 0
         self.top_k_nonempty_lines_event = asyncio.Event()

-        # If set, the chunk will be piped to the specified handler rather than added to
-        # the queue or printed
+        # If set, the chunk will be piped to the specified handler rather than added to the queue or printed
         self.pipe_to = None

         self.first_token = True

         # The stop chunks
         self.stop = []

+        # Generation metadata handling
+        self.include_generation_metadata = include_generation_metadata
+        self.current_generation_info = {}
+
     def set_pattern(self, prefix: Optional[str] = None, suffix: Optional[str] = None):
-        """Sets the patter that is expected.
+        """Sets the pattern that is expected.

         If a prefix or a suffix are specified, they will be removed from the output.
         """
@@ -96,7 +104,7 @@ async def wait_top_k_nonempty_lines(self, k: int):
         """Waits for top k non-empty lines from the LLM.

         When k lines have been received (and k+1 has been started) it will return
-        and remove them from the buffer
+        and remove them from the buffer.
         """
         self.k = k
         await self.top_k_nonempty_lines_event.wait()
@@ -121,7 +129,6 @@ async def enable_buffering(self):
     async def disable_buffering(self):
         """When we disable the buffer, we process the buffer as a chunk."""
         self.enable_buffer = False
-
         await self.push_chunk(self.buffer)
         self.buffer = ""

@@ -136,6 +143,13 @@ async def generator():
                     raise ex
                 if element is None or element == "":
                     break
+
+                if isinstance(element, dict):
+                    if element is not None and (
+                        element.get("text") is None or element.get("text") == ""
+                    ):
+                        yield element
+                        break
                 yield element

         return generator()
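
Because the iterator above now yields either plain strings (metadata off) or dicts (metadata on), downstream code may want a small normalization step. A hypothetical helper, not part of this commit:

from typing import Any, Dict, Tuple, Union


def split_chunk(
    element: Union[str, Dict[str, Any], None]
) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: normalize a streamed element to (text, generation_info)."""
    if isinstance(element, dict):
        # Shape produced when include_generation_metadata=True:
        # {"text": ..., "generation_info": {...}}
        return element.get("text") or "", element.get("generation_info") or {}
    # Plain string chunk (or the None/"" terminator) when the flag is off.
    return element or "", {}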
@@ -147,31 +161,43 @@ async def __anext__(self):
         except RuntimeError as ex:
             if "Event loop is closed" not in str(ex):
                 raise ex
+        # following test is because of TestChat and FakeLLM implementation
+        #
         if element is None or element == "":
             raise StopAsyncIteration
+
+        if isinstance(element, dict):
+            if element is not None and (
+                element.get("text") is None or element.get("text") == ""
+            ):
+                raise StopAsyncIteration
         else:
             return element

-    async def _process(self, chunk: str):
+    async def _process(
+        self, chunk: str, generation_info: Optional[Dict[str, Any]] = None
+    ):
         """Process a chunk of text.

-        If we're in buffering mode, we just record it.
-        If we need to pipe it to another streaming handler, we do that.
+        If we're in buffering mode, record the text.
+        Otherwise, update the full completion, check for stop tokens, and enqueue the chunk.
         """
+
+        if self.include_generation_metadata and generation_info:
+            self.current_generation_info = generation_info
+
         if self.enable_buffer:
             self.buffer += chunk
-
             lines = [line.strip() for line in self.buffer.split("\n")]
             lines = [line for line in lines if len(line) > 0 and line[0] != "#"]
-            # We wait until we got to k+1 lines, to make sure the k-th line is finished
             if len(lines) > self.k > 0:
                 self.top_k_nonempty_lines_event.set()
+
         else:
-            # Temporarily save the content of the completion before this new chunk.
             prev_completion = self.completion
+
             if chunk is not None:
                 self.completion += chunk
-
             # Check if the completion contains one of the stop chunks
             for stop_chunk in self.stop:
                 if stop_chunk in self.completion:
@@ -183,7 +209,6 @@ async def _process(self, chunk: str):
                     if len(self.completion) > len(prev_completion):
                         self.current_chunk = self.completion[len(prev_completion) :]
                         await self.push_chunk(None)
-
                     # And we stop the streaming
                     self.streaming_finished_event.set()
                     self.top_k_nonempty_lines_event.set()
@@ -197,14 +222,25 @@ async def _process(self, chunk: str):
             else:
                 if self.enable_print and chunk is not None:
                     print(f"\033[92m{chunk}\033[0m", end="", flush=True)
-                await self.queue.put(chunk)
-
+                # await self.queue.put(chunk)
+                if self.include_generation_metadata:
+                    await self.queue.put(
+                        {
+                            "text": chunk,
+                            "generation_info": self.current_generation_info.copy(),
+                        }
+                    )
+                else:
+                    await self.queue.put(chunk)
+                # If the chunk is empty (used as termination), mark the stream as finished.
                 if chunk is None or chunk == "":
                     self.streaming_finished_event.set()
                     self.top_k_nonempty_lines_event.set()

     async def push_chunk(
-        self, chunk: Union[str, GenerationChunk, AIMessageChunk, None]
+        self,
+        chunk: Union[str, GenerationChunk, AIMessageChunk, None],
+        generation_info: Optional[Dict[str, Any]] = None,
     ):
         """Push a new chunk to the stream."""
         if isinstance(chunk, GenerationChunk):
@@ -222,44 +258,38 @@ async def push_chunk(
             log.info(f"{self.uid[0:3]} - CHUNK after finish: {chunk}")
             return

-        # Only after we get the expected prefix we remove it and start streaming
+        if self.include_generation_metadata and generation_info:
+            self.current_generation_info = generation_info
+
+        # Process prefix: accumulate until the expected prefix is received, then remove it.
         if self.prefix:
             if chunk is not None:
                 self.current_chunk += chunk
-
             if self.current_chunk.startswith(self.prefix):
                 self.current_chunk = self.current_chunk[len(self.prefix) :]
                 self.prefix = None
-
                 # If we're left with something, we "forward it".
                 if self.current_chunk:
                     await self._process(self.current_chunk)
                     self.current_chunk = ""
+        # Process suffix/stop tokens: accumulate and check whether the current chunk ends with one.
         elif self.suffix or self.stop:
-            # If we have a suffix, we always check that the total current chunk does not end
-            # with the suffix.
-
             if chunk is not None:
                 self.current_chunk += chunk
-
             _chunks = []
             if self.suffix:
                 _chunks.append(self.suffix)
             if self.stop:
                 _chunks.extend(self.stop)
-
             skip_processing = False
             for _chunk in _chunks:
                 if skip_processing:
                     break
-
                 for _len in range(len(_chunk)):
                     if self.current_chunk.endswith(_chunk[0 : _len + 1]):
                         skip_processing = True
                         break

-            # TODO: improve this logic to work for multi-token suffixes.
-            # if self.current_chunk.endswith(self.suffix):
             if skip_processing and chunk != "" and chunk is not None:
                 # We do nothing in this case. The suffix/stop chunks will be removed when
                 # the generation ends and if there's something left, will be processed then.
@@ -274,13 +304,10 @@ async def push_chunk(
                     self.current_chunk = self.current_chunk[
                         0 : -1 * len(self.suffix)
                     ]
-
-                await self._process(self.current_chunk)
+                await self._process(self.current_chunk, generation_info)
                 self.current_chunk = ""
         else:
-            await self._process(chunk)
-
-    # Methods from the LangChain AsyncCallbackHandler
+            await self._process(chunk, generation_info)

     async def on_chat_model_start(
         self,
@@ -306,13 +333,15 @@ async def on_llm_new_token(
         **kwargs: Any,
     ) -> None:
         """Run on new LLM token. Only available when streaming is enabled."""
-        # If the first token is an empty one, we ignore.
         if self.first_token:
             self.first_token = False
             if token == "":
                 return
-
-        await self.push_chunk(chunk)
+        # Pass token as generation metadata.
+        generation_info = (
+            chunk.generation_info if chunk and hasattr(chunk, "generation_info") else {}
+        )
+        await self.push_chunk(chunk, generation_info=generation_info)

     async def on_llm_end(
         self,
@@ -330,13 +359,11 @@ async def on_llm_end(

         await self._process(self.current_chunk)
         self.current_chunk = ""
-
         await self._process("")
-
         # We explicitly print a new line here
         if self.enable_print:
             print("")

-        # We also reset the prefix/suffix
+        # Reset prefix/suffix for the next generation.
         self.prefix = None
         self.suffix = None
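
The on_llm_new_token change above forwards whatever generation_info the LangChain chunk carries. A sketch that simulates this callback path without a real LLM call; the generation_info keys and run_id are illustrative, and the exact callback signature beyond token and chunk is assumed from the LangChain AsyncCallbackHandler interface:

import asyncio
from uuid import uuid4

from langchain_core.outputs import GenerationChunk

from nemoguardrails.streaming import StreamingHandler


async def main():
    handler = StreamingHandler(include_generation_metadata=True)

    async def consume():
        async for element in handler:
            if element.get("text"):
                print(element["text"], element["generation_info"])

    consumer = asyncio.create_task(consume())

    # Simulate what LangChain does during streaming: each token arrives with an
    # optional GenerationChunk whose generation_info the handler now forwards.
    for text, info in [("Hello", {}), (" world", {"finish_reason": "stop"})]:
        chunk = GenerationChunk(text=text, generation_info=info)
        await handler.on_llm_new_token(text, chunk=chunk, run_id=uuid4())

    # Terminate the stream for this sketch.
    await handler.push_chunk("")
    await consumer


asyncio.run(main())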
