Commit 85400a5
refactor(streaming): introduce END_OF_STREAM sentinel and update handling (#1185)
* refactor(streaming): introduce END_OF_STREAM sentinel and update handling

  - Replaced the inconsistent use of `None` and `""` for stream termination in `StreamingHandler` with a dedicated `END_OF_STREAM` sentinel object.
  - Modified `push_chunk` to convert `None` to `END_OF_STREAM`.
  - Updated `__anext__` to raise `StopAsyncIteration` only for `END_OF_STREAM`, and to return empty strings, or dicts with empty/`None` text, as data.
  - Adjusted `_process` to correctly handle `END_OF_STREAM` in the buffering and queueing logic.
  - Updated `on_llm_end` to use `END_OF_STREAM`.
  - Revised the tests in `test_streaming_handler.py` to reflect these changes, including how empty first tokens are handled and how `__anext__` behaves with various inputs.

* coverage to the moon: fix missing generation_info and add more tests
1 parent ca02cb6 commit 85400a5

Showing 3 changed files with 243 additions and 89 deletions.
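The change is built on the standard Python sentinel idiom: a module-level `object()` instance is compared by identity (`is`), so it can never collide with legitimate stream data such as `""` or `None`. A minimal, generic sketch of the idiom (illustration only, not code from this commit):

```python
import asyncio

# A fresh object() is identical only to itself, so an `is` check can never
# be triggered by real data such as "" or None arriving from the stream.
END_OF_STREAM = object()


async def consume(queue: asyncio.Queue) -> None:
    while True:
        element = await queue.get()
        if element is END_OF_STREAM:  # identity check, not equality
            break
        # "" is ordinary data now, not an accidental terminator.
        print(repr(element))


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for element in ["Hello", "", " world", END_OF_STREAM]:
        await queue.put(element)
    await consume(queue)


asyncio.run(main())
```

Because the check uses `is` rather than `==`, an empty string coming out of the LLM is just data, which is exactly the ambiguity this commit removes.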

nemoguardrails/rails/llm/llmrails.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -74,7 +74,7 @@
     GenerationResponse,
 )
 from nemoguardrails.rails.llm.utils import get_history_cache_key
-from nemoguardrails.streaming import StreamingHandler
+from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler
 from nemoguardrails.utils import (
     extract_error_json,
     get_or_create_event_loop,
@@ -712,7 +712,7 @@ async def generate_async(
                 error_payload = json.dumps(error_dict)
                 await streaming_handler.push_chunk(error_payload)
                 # push a termination signal
-                await streaming_handler.push_chunk(None)
+                await streaming_handler.push_chunk(END_OF_STREAM)
                 # Re-raise the exact exception
                 raise
             else:
@@ -827,7 +827,7 @@ async def generate_async(
         streaming_handler = streaming_handler_var.get()
         if streaming_handler:
             # print("Closing the stream handler explicitly")
-            await streaming_handler.push_chunk(None)
+            await streaming_handler.push_chunk(END_OF_STREAM)
 
         # IF tracing is enabled we need to set GenerationLog attrs
         if self.config.tracing.enabled:
```
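For context, here is how the new termination contract might be exercised end to end. This is an illustrative sketch, not code from the repository; it assumes only what the diffs show: `StreamingHandler` is an async iterator, `push_chunk` converts `None` to the sentinel, and `__anext__` raises `StopAsyncIteration` only for `END_OF_STREAM`. A no-argument `StreamingHandler()` constructor is assumed here.

```python
import asyncio

from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler


async def main() -> None:
    handler = StreamingHandler()  # default flags assumed

    async def produce() -> None:
        await handler.push_chunk("Hello")
        await handler.push_chunk("")  # empty string is now valid data
        await handler.push_chunk(" world")
        # push_chunk(None) would be converted to END_OF_STREAM;
        # pushing the sentinel directly is equivalent.
        await handler.push_chunk(END_OF_STREAM)

    asyncio.create_task(produce())

    # Iteration stops when __anext__ sees END_OF_STREAM.
    async for chunk in handler:
        print(repr(chunk))


asyncio.run(main())
```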

nemoguardrails/streaming.py

Lines changed: 106 additions & 50 deletions
```diff
@@ -27,6 +27,9 @@
 
 log = logging.getLogger(__name__)
 
+# sentinel object to indicate end of stream
+END_OF_STREAM = object()
+
 
 class StreamingHandler(AsyncCallbackHandler, AsyncIterator):
     """Streaming async handler.
@@ -141,13 +144,11 @@ async def generator():
                 except RuntimeError as ex:
                     if "Event loop is closed" not in str(ex):
                         raise ex
-                if element is None or element == "":
+                if element is END_OF_STREAM:
                     break
 
                 if isinstance(element, dict):
-                    if element is not None and (
-                        element.get("text") is None or element.get("text") == ""
-                    ):
+                    if element is not None and (element.get("text") is END_OF_STREAM):
                         yield element
                         break
                 yield element
@@ -161,21 +162,20 @@ async def __anext__(self):
         except RuntimeError as ex:
             if "Event loop is closed" not in str(ex):
                 raise ex
-        # following test is because of TestChat and FakeLLM implementation
-        #
-        if element is None or element == "":
+        if element is END_OF_STREAM:
             raise StopAsyncIteration
 
         if isinstance(element, dict):
-            if element is not None and (
-                element.get("text") is None or element.get("text") == ""
-            ):
+            if element is not None and (element.get("text") is END_OF_STREAM):
                 raise StopAsyncIteration
+            return element
         else:
             return element
 
     async def _process(
-        self, chunk: str, generation_info: Optional[Dict[str, Any]] = None
+        self,
+        chunk: Union[str, object],
+        generation_info: Optional[Dict[str, Any]] = None,
     ):
         """Process a chunk of text.
 
@@ -187,16 +187,17 @@ async def _process(
         self.current_generation_info = generation_info
 
         if self.enable_buffer:
-            self.buffer += chunk
-            lines = [line.strip() for line in self.buffer.split("\n")]
-            lines = [line for line in lines if len(line) > 0 and line[0] != "#"]
-            if len(lines) > self.k > 0:
-                self.top_k_nonempty_lines_event.set()
+            if chunk is not END_OF_STREAM:
+                self.buffer += chunk if chunk is not None else ""
+                lines = [line.strip() for line in self.buffer.split("\n")]
+                lines = [line for line in lines if len(line) > 0 and line[0] != "#"]
+                if len(lines) > self.k > 0:
+                    self.top_k_nonempty_lines_event.set()
 
         else:
             prev_completion = self.completion
 
-            if chunk is not None:
+            if chunk is not None and chunk is not END_OF_STREAM:
                 self.completion += chunk
                 # Check if the completion contains one of the stop chunks
                 for stop_chunk in self.stop:
@@ -208,48 +209,84 @@ async def _process(
                         # We push that as well.
                         if len(self.completion) > len(prev_completion):
                             self.current_chunk = self.completion[len(prev_completion) :]
-                        await self.push_chunk(None)
+                        await self.push_chunk(END_OF_STREAM)
                         # And we stop the streaming
                         self.streaming_finished_event.set()
                         self.top_k_nonempty_lines_event.set()
                         return
 
         if self.pipe_to:
-            asyncio.create_task(self.pipe_to.push_chunk(chunk))
-            if chunk is None or chunk == "":
+            # only add explicit empty strings, not ones created during processing
+            if chunk is END_OF_STREAM or chunk is not None:
+                asyncio.create_task(self.pipe_to.push_chunk(chunk))
+            if chunk is END_OF_STREAM:
                 self.streaming_finished_event.set()
                 self.top_k_nonempty_lines_event.set()
         else:
-            if self.enable_print and chunk is not None:
+            if (
+                self.enable_print
+                and chunk is not None
+                and chunk is not END_OF_STREAM
+            ):
                 print(f"\033[92m{chunk}\033[0m", end="", flush=True)
-            # await self.queue.put(chunk)
-            if self.include_generation_metadata:
-                await self.queue.put(
-                    {
-                        "text": chunk,
-                        "generation_info": self.current_generation_info.copy(),
-                    }
-                )
-            else:
-                await self.queue.put(chunk)
-            # If the chunk is empty (used as termination), mark the stream as finished.
-            if chunk is None or chunk == "":
-                self.streaming_finished_event.set()
-                self.top_k_nonempty_lines_event.set()
+
+            # we only want to filter out empty strings that are created during suffix processing,
+            # not ones directly pushed by the user
+            if chunk is not None:
+                # process all valid chunks, including empty strings directly from the user
+                if self.include_generation_metadata:
+                    if chunk is not END_OF_STREAM:
+                        await self.queue.put(
+                            {
+                                "text": chunk,
+                                "generation_info": self.current_generation_info.copy(),
+                            }
+                        )
+                    else:
+                        await self.queue.put(
+                            {
+                                "text": END_OF_STREAM,
+                                "generation_info": self.current_generation_info.copy(),
+                            }
+                        )
+                else:
+                    await self.queue.put(chunk)
+
+            # If the chunk is the special end of stream marker, mark the stream as finished.
+            if chunk is END_OF_STREAM:
+                self.streaming_finished_event.set()
+                self.top_k_nonempty_lines_event.set()
 
     async def push_chunk(
         self,
-        chunk: Union[str, GenerationChunk, AIMessageChunk, None],
+        chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, None],
         generation_info: Optional[Dict[str, Any]] = None,
     ):
         """Push a new chunk to the stream."""
+
+        # if generation_info is not explicitly passed,
+        # try to get it from the chunk itself if it's a GenerationChunk or ChatGenerationChunk
+        if generation_info is None:
+            if isinstance(chunk, (GenerationChunk, ChatGenerationChunk)) and hasattr(
+                chunk, "generation_info"
+            ):
+                if chunk.generation_info is not None:
+                    generation_info = chunk.generation_info.copy()
+
         if isinstance(chunk, GenerationChunk):
             chunk = chunk.text
         elif isinstance(chunk, AIMessageChunk):
             chunk = chunk.content
         elif isinstance(chunk, ChatGenerationChunk):
             chunk = chunk.text
-        elif isinstance(chunk, str) or chunk is None:
+        elif chunk is None:
+            # replace None with the END_OF_STREAM marker
+            chunk = END_OF_STREAM
+        elif chunk is END_OF_STREAM:
+            # already the correct marker, no conversion needed
+            pass
+        elif isinstance(chunk, str):
+            # empty string is a valid chunk and should be processed normally
             pass
         else:
             raise Exception(f"Unsupported chunk type: {chunk.__class__.__name__}")
@@ -263,7 +300,7 @@ async def push_chunk(
 
         # Process prefix: accumulate until the expected prefix is received, then remove it.
         if self.prefix:
-            if chunk is not None:
+            if chunk is not None and chunk is not END_OF_STREAM:
                 self.current_chunk += chunk
             if self.current_chunk.startswith(self.prefix):
                 self.current_chunk = self.current_chunk[len(self.prefix) :]
@@ -274,7 +311,7 @@ async def push_chunk(
             self.current_chunk = ""
         # Process suffix/stop tokens: accumulate and check whether the current chunk ends with one.
         elif self.suffix or self.stop:
-            if chunk is not None:
+            if chunk is not None and chunk is not END_OF_STREAM:
                 self.current_chunk += chunk
             _chunks = []
             if self.suffix:
@@ -290,12 +327,12 @@ async def push_chunk(
                     skip_processing = True
                     break
 
-            if skip_processing and chunk != "" and chunk is not None:
+            if skip_processing and chunk is not END_OF_STREAM and chunk != "":
                 # We do nothing in this case. The suffix/stop chunks will be removed when
                 # the generation ends and if there's something left, will be processed then.
                 return
             else:
-                if chunk == "" or chunk is None:
+                if chunk is END_OF_STREAM:
                     if (
                         self.current_chunk
                         and self.suffix
@@ -304,8 +341,15 @@ async def push_chunk(
                         self.current_chunk = self.current_chunk[
                             0 : -1 * len(self.suffix)
                         ]
-                    await self._process(self.current_chunk, generation_info)
-                    self.current_chunk = ""
+
+                    # only process the current_chunk if it's not empty
+                    if self.current_chunk:
+                        await self._process(self.current_chunk, generation_info)
+                        self.current_chunk = ""
+
+                    # if this is the end of stream, pass it through after processing the current chunk
+                    if chunk is END_OF_STREAM:
+                        await self._process(END_OF_STREAM, generation_info)
                 else:
                     await self._process(chunk, generation_info)
 
@@ -333,15 +377,27 @@ async def on_llm_new_token(
         **kwargs: Any,
     ) -> None:
         """Run on new LLM token. Only available when streaming is enabled."""
+        # Log the first token if it's empty to help with debugging
+        if self.first_token and token == "":
+            log.debug(f"{self.uid[0:3]} - Received empty first token from LLM")
+
+        # set first_token to False regardless of token content
+        # we always process tokens, even empty ones
         if self.first_token:
             self.first_token = False
-            if token == "":
-                return
-        # Pass token as generation metadata.
-        generation_info = (
-            chunk.generation_info if chunk and hasattr(chunk, "generation_info") else {}
+
+        generation_info = None
+        if chunk and hasattr(chunk, "generation_info"):
+            if chunk.generation_info is not None:
+                generation_info = chunk.generation_info.copy()
+            else:
+                generation_info = {}
+        else:
+            generation_info = {}
+
+        await self.push_chunk(
+            token if chunk is None else chunk, generation_info=generation_info
         )
-        await self.push_chunk(chunk, generation_info=generation_info)
 
     async def on_llm_end(
         self,
@@ -359,7 +415,7 @@ async def on_llm_end(
 
         await self._process(self.current_chunk)
         self.current_chunk = ""
-        await self._process("")
+        await self._process(END_OF_STREAM)
         # We explicitly print a new line here
         if self.enable_print:
             print("")
```
