@@ -31,12 +31,17 @@
 class StreamingHandler(AsyncCallbackHandler, AsyncIterator):
     """Streaming async handler.
 
-    Implements the LangChain AsyncCallbackHandler, so it can be notified of new tokens.
-    It also implements the AsyncIterator interface, so it can be used directly to stream
+    Implements the LangChain AsyncCallbackHandler so it can be notified of new tokens.
+    It also implements the AsyncIterator interface so it can be used directly to stream
     back the response.
     """
 
-    def __init__(self, enable_print: bool = False, enable_buffer: bool = False):
+    def __init__(
+        self,
+        enable_print: bool = False,
+        enable_buffer: bool = False,
+        include_generation_metadata: Optional[bool] = False,
+    ):
         # A unique id for the stream handler
         self.uid = new_uuid()
 
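For orientation, here is a minimal sketch of how the extended handler might be consumed. The driver code is hypothetical and not part of this diff; in practice a LangChain model feeds the handler through its callbacks, but `push_chunk` works as a stand-in producer. With `include_generation_metadata=True`, items arrive as dicts carrying `text` and `generation_info`; otherwise they are plain strings.

import asyncio

async def consume():
    handler = StreamingHandler(include_generation_metadata=True)

    async def produce():
        # Hypothetical producer; normally a LangChain model drives the
        # on_llm_new_token callback, which forwards chunks via push_chunk.
        await handler.push_chunk("Hello, ")
        await handler.push_chunk("world!")
        await handler.push_chunk("")  # an empty chunk terminates the stream

    asyncio.create_task(produce())

    async for item in handler:
        # In metadata mode each item is a dict like
        # {"text": "Hello, ", "generation_info": {...}}; otherwise a str.
        # Note: per the generator change below, the terminating empty-text
        # dict is yielded as well before iteration stops.
        print(item)

asyncio.run(consume())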
@@ -50,34 +55,37 @@ def __init__(self, enable_print: bool = False, enable_buffer: bool = False):
         # When buffering is enabled, the chunks will gather in a buffer.
         self.enable_buffer = enable_buffer
 
-        # The prefix/suffix that should be removed
+        # The prefix/suffix that should be removed (text-only processing)
         self.prefix = None
         self.suffix = None
 
         # The current chunk which needs to be checked for prefix/suffix matching
         self.current_chunk = ""
 
-        # The current buffer, until we start the processing.
+        # The current buffer until we start processing.
        self.buffer = ""
 
         # The full completion
         self.completion = ""
 
-        # Weather we're interested in the top k non-empty lines
+        # Whether we're interested in the top k non-empty lines
         self.k = 0
         self.top_k_nonempty_lines_event = asyncio.Event()
 
-        # If set, the chunk will be piped to the specified handler rather than added to
-        # the queue or printed
+        # If set, the chunk will be piped to the specified handler rather than added to the queue or printed
         self.pipe_to = None
 
         self.first_token = True
 
         # The stop chunks
         self.stop = []
 
+        # Generation metadata handling
+        self.include_generation_metadata = include_generation_metadata
+        self.current_generation_info = {}
+
     def set_pattern(self, prefix: Optional[str] = None, suffix: Optional[str] = None):
-        """Sets the patter that is expected.
+        """Sets the pattern that is expected.
 
         If a prefix or a suffix are specified, they will be removed from the output.
         """
@@ -96,7 +104,7 @@ async def wait_top_k_nonempty_lines(self, k: int):
         """Waits for top k non-empty lines from the LLM.
 
         When k lines have been received (and k+1 has been started) it will return
-        and remove them from the buffer
+        and remove them from the buffer.
         """
         self.k = k
         await self.top_k_nonempty_lines_event.wait()
@@ -121,7 +129,6 @@ async def enable_buffering(self):
     async def disable_buffering(self):
         """When we disable the buffer, we process the buffer as a chunk."""
         self.enable_buffer = False
-
         await self.push_chunk(self.buffer)
         self.buffer = ""
 
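The buffering trio (`enable_buffering`, `wait_top_k_nonempty_lines`, `disable_buffering`) is typically used to peek at the head of a generation before streaming the rest. A sketch, assuming `wait_top_k_nonempty_lines` returns the lines it removes, as its docstring suggests; the helper name is hypothetical:

async def peek_first_lines(handler: StreamingHandler, k: int = 2):
    # Hypothetical helper: buffer incoming chunks, wait until the first k
    # non-empty, non-comment lines are complete, then resume streaming.
    await handler.enable_buffering()
    head = await handler.wait_top_k_nonempty_lines(k)
    await handler.disable_buffering()  # flushes the remaining buffer as a chunk
    return head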
@@ -136,6 +143,13 @@ async def generator():
                     raise ex
                 if element is None or element == "":
                     break
+
+                if isinstance(element, dict):
+                    if element is not None and (
+                        element.get("text") is None or element.get("text") == ""
+                    ):
+                        yield element
+                        break
                 yield element
 
         return generator()
@@ -147,31 +161,43 @@ async def __anext__(self):
         except RuntimeError as ex:
             if "Event loop is closed" not in str(ex):
                 raise ex
+        # The following check is needed because of the TestChat and FakeLLM
+        # implementations.
         if element is None or element == "":
             raise StopAsyncIteration
+
+        if isinstance(element, dict):
+            if element is not None and (
+                element.get("text") is None or element.get("text") == ""
+            ):
+                raise StopAsyncIteration
         else:
             return element
 
-    async def _process(self, chunk: str):
+    async def _process(
+        self, chunk: str, generation_info: Optional[Dict[str, Any]] = None
+    ):
         """Process a chunk of text.
 
-        If we're in buffering mode, we just record it.
-        If we need to pipe it to another streaming handler, we do that.
+        If we're in buffering mode, record the text.
+        Otherwise, update the full completion, check for stop tokens, and enqueue the chunk.
         """
+
+        if self.include_generation_metadata and generation_info:
+            self.current_generation_info = generation_info
+
         if self.enable_buffer:
             self.buffer += chunk
-
             lines = [line.strip() for line in self.buffer.split("\n")]
             lines = [line for line in lines if len(line) > 0 and line[0] != "#"]
-            # We wait until we got to k+1 lines, to make sure the k-th line is finished
             if len(lines) > self.k > 0:
                 self.top_k_nonempty_lines_event.set()
+
         else:
-            # Temporarily save the content of the completion before this new chunk.
             prev_completion = self.completion
+
             if chunk is not None:
                 self.completion += chunk
-
             # Check if the completion contains one of the stop chunks
             for stop_chunk in self.stop:
                 if stop_chunk in self.completion:
@@ -183,7 +209,6 @@ async def _process(self, chunk: str):
                     if len(self.completion) > len(prev_completion):
                         self.current_chunk = self.completion[len(prev_completion) :]
                         await self.push_chunk(None)
-
                     # And we stop the streaming
                     self.streaming_finished_event.set()
                     self.top_k_nonempty_lines_event.set()
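To make the stop-token path concrete, a small test-style sketch that drives `_process` directly; it pokes at internals for illustration and assumes the unchanged parts of the class behave as shown in the surrounding context:

import asyncio

async def demo_stop_tokens():
    handler = StreamingHandler()
    handler.stop = ["User:"]

    await handler._process("Hello")
    await handler._process("!\nUser: how are")  # "User:" lands in the completion

    # The completion is cut at the stop token, the residue before it is
    # flushed, and the stream is marked as finished.
    assert handler.streaming_finished_event.is_set()
    print(await handler.queue.get())  # "Hello"
    print(await handler.queue.get())  # "!\n" (the residue before "User:")

asyncio.run(demo_stop_tokens())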
@@ -197,14 +222,25 @@ async def _process(self, chunk: str):
         else:
             if self.enable_print and chunk is not None:
                 print(f"\033[92m{chunk}\033[0m", end="", flush=True)
-            await self.queue.put(chunk)
-
+            # await self.queue.put(chunk)
+            if self.include_generation_metadata:
+                await self.queue.put(
+                    {
+                        "text": chunk,
+                        "generation_info": self.current_generation_info.copy(),
+                    }
+                )
+            else:
+                await self.queue.put(chunk)
+        # If the chunk is empty (used as termination), mark the stream as finished.
         if chunk is None or chunk == "":
             self.streaming_finished_event.set()
             self.top_k_nonempty_lines_event.set()
 
     async def push_chunk(
-        self, chunk: Union[str, GenerationChunk, AIMessageChunk, None]
+        self,
+        chunk: Union[str, GenerationChunk, AIMessageChunk, None],
+        generation_info: Optional[Dict[str, Any]] = None,
     ):
         """Push a new chunk to the stream."""
         if isinstance(chunk, GenerationChunk):
@@ -222,44 +258,38 @@ async def push_chunk(
             log.info(f"{self.uid[0:3]} - CHUNK after finish: {chunk}")
             return
 
-        # Only after we get the expected prefix we remove it and start streaming
+        if self.include_generation_metadata and generation_info:
+            self.current_generation_info = generation_info
+
+        # Process prefix: accumulate until the expected prefix is received, then remove it.
         if self.prefix:
             if chunk is not None:
                 self.current_chunk += chunk
-
             if self.current_chunk.startswith(self.prefix):
                 self.current_chunk = self.current_chunk[len(self.prefix) :]
                 self.prefix = None
-
                 # If we're left with something, we "forward it".
                 if self.current_chunk:
                     await self._process(self.current_chunk)
                     self.current_chunk = ""
+        # Process suffix/stop tokens: accumulate and check whether the current chunk ends with one.
         elif self.suffix or self.stop:
-            # If we have a suffix, we always check that the total current chunk does not end
-            # with the suffix.
-
             if chunk is not None:
                 self.current_chunk += chunk
-
             _chunks = []
             if self.suffix:
                 _chunks.append(self.suffix)
             if self.stop:
                 _chunks.extend(self.stop)
-
             skip_processing = False
             for _chunk in _chunks:
                 if skip_processing:
                     break
-
                 for _len in range(len(_chunk)):
                     if self.current_chunk.endswith(_chunk[0 : _len + 1]):
                         skip_processing = True
                         break
 
-            # TODO: improve this logic to work for multi-token suffixes.
-            # if self.current_chunk.endswith(self.suffix):
             if skip_processing and chunk != "" and chunk is not None:
                 # We do nothing in this case. The suffix/stop chunks will be removed when
                 # the generation ends and if there's something left, will be processed then.
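The prefix/suffix handling pairs with `set_pattern`. A sketch with hypothetical pattern values, driving `push_chunk` by hand:

import asyncio

async def demo_pattern():
    handler = StreamingHandler()
    handler.set_pattern(prefix='Bot message: "', suffix='"')

    # The prefix is accumulated and stripped before anything is forwarded.
    await handler.push_chunk('Bot message: "Hi')  # forwards just "Hi"
    # A chunk ending in a (partial) suffix match is held back; the suffix is
    # only stripped when the generation ends (see on_llm_end below).
    await handler.push_chunk(' there"')           # held, nothing forwarded yet

    print(await handler.queue.get())  # "Hi"

asyncio.run(demo_pattern())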
@@ -274,13 +304,10 @@ async def push_chunk(
                     self.current_chunk = self.current_chunk[
                         0 : -1 * len(self.suffix)
                     ]
-
-                await self._process(self.current_chunk)
+                await self._process(self.current_chunk, generation_info)
                 self.current_chunk = ""
         else:
-            await self._process(chunk)
-
-    # Methods from the LangChain AsyncCallbackHandler
+            await self._process(chunk, generation_info)
 
     async def on_chat_model_start(
         self,
@@ -306,13 +333,15 @@ async def on_llm_new_token(
         **kwargs: Any,
     ) -> None:
         """Run on new LLM token. Only available when streaming is enabled."""
-        # If the first token is an empty one, we ignore.
         if self.first_token:
             self.first_token = False
            if token == "":
                 return
-
-        await self.push_chunk(chunk)
+        # Pass along the chunk's generation metadata, if available.
+        generation_info = (
+            chunk.generation_info if chunk and hasattr(chunk, "generation_info") else {}
+        )
+        await self.push_chunk(chunk, generation_info=generation_info)
 
     async def on_llm_end(
         self,
@@ -330,13 +359,11 @@ async def on_llm_end(
 
             await self._process(self.current_chunk)
             self.current_chunk = ""
-
         await self._process("")
-
         # We explicitly print a new line here
         if self.enable_print:
             print("")
 
-        # We also reset the prefix/suffix
+        # Reset prefix/suffix for the next generation.
         self.prefix = None
         self.suffix = None
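End to end, the handler is registered as a LangChain callback so `on_llm_new_token` feeds the queue while the caller consumes the async iterator. A sketch, assuming a recent `langchain-openai` (the provider and model choice are illustrative; any chat model with streaming support should behave the same way):

import asyncio

from langchain_openai import ChatOpenAI  # assumed provider package

async def main():
    handler = StreamingHandler(include_generation_metadata=True)
    llm = ChatOpenAI(model="gpt-4o-mini", streaming=True, callbacks=[handler])

    # Run the generation concurrently and consume tokens as they arrive.
    task = asyncio.create_task(llm.ainvoke("Write a haiku about streams."))
    async for item in handler:
        print(item["text"], end="", flush=True)

    await task

asyncio.run(main())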