@@ -32,42 +32,63 @@ def __init__(self, streaming_handler: StreamingHandler):
3232
3333 self .chunks = []
3434 self .finished = False
35+ self ._task = None
3536 self ._start ()
3637
3738 async def process_tokens (self ):
38- async for chunk in self .streaming_handler :
39- self .chunks .append (chunk )
40-
41- self .finished = True
39+ try :
40+ async for chunk in self .streaming_handler :
41+ self .chunks .append (chunk )
42+ except asyncio .CancelledError :
43+ # task was cancelled. this is expected during cleanup
44+ pass
45+ finally :
46+ self .finished = True
4247
4348 def _start (self ):
44- asyncio .create_task (self .process_tokens ())
49+ self . _task = asyncio .create_task (self .process_tokens ())
4550
4651 async def get_chunks (self ):
4752 """Helper to get the chunks."""
4853 # We wait a bit to allow all asyncio callbacks to get called.
4954 await asyncio .sleep (0.01 )
5055 return self .chunks
5156
57+ async def cancel (self ):
58+ """Cancel the background task and wait for it to finish."""
59+ if self ._task and not self ._task .done ():
60+ self ._task .cancel ()
61+ try :
62+ await self ._task
63+ except asyncio .CancelledError :
64+ # this is expected when cancelling the task
65+ pass
66+
5267
5368@pytest .mark .asyncio
5469async def test_single_chunk ():
5570 streaming_handler = StreamingHandler ()
5671 streaming_consumer = StreamingConsumer (streaming_handler )
5772
58- await streaming_handler .push_chunk ("a" )
59- assert await streaming_consumer .get_chunks () == ["a" ]
73+ try :
74+ await streaming_handler .push_chunk ("a" )
75+ assert await streaming_consumer .get_chunks () == ["a" ]
76+ finally :
77+ await streaming_consumer .cancel ()
6078
6179
6280@pytest .mark .asyncio
6381async def test_sequence_of_chunks ():
6482 streaming_handler = StreamingHandler ()
6583 streaming_consumer = StreamingConsumer (streaming_handler )
6684
67- for chunk in ["1" , "2" , "3" , "4" , "5" ]:
68- await streaming_handler .push_chunk (chunk )
85+ try :
86+ for chunk in ["1" , "2" , "3" , "4" , "5" ]:
87+ await streaming_handler .push_chunk (chunk )
6988
70- assert await streaming_consumer .get_chunks () == ["1" , "2" , "3" , "4" , "5" ]
89+ assert await streaming_consumer .get_chunks () == ["1" , "2" , "3" , "4" , "5" ]
90+ finally :
91+ await streaming_consumer .cancel ()
7192
7293
7394async def _test_pattern_case (
@@ -93,16 +114,19 @@ async def _test_pattern_case(
93114 else :
94115 streaming_consumer = StreamingConsumer (streaming_handler )
95116
96- for chunk in chunks :
97- if chunk is None :
98- assert await streaming_consumer .get_chunks () == []
99- else :
100- await streaming_handler .push_chunk (chunk )
117+ try :
118+ for chunk in chunks :
119+ if chunk is None :
120+ assert await streaming_consumer .get_chunks () == []
121+ else :
122+ await streaming_handler .push_chunk (chunk )
101123
102- # Push an empty chunk to signal the ending.
103- await streaming_handler .push_chunk ("" )
124+ # Push an empty chunk to signal the ending.
125+ await streaming_handler .push_chunk ("" )
104126
105- assert await streaming_consumer .get_chunks () == final_chunks
127+ assert await streaming_consumer .get_chunks () == final_chunks
128+ finally :
129+ await streaming_consumer .cancel ()
106130
107131
108132@pytest .mark .asyncio
@@ -218,7 +242,7 @@ async def test_suffix_with_stop_and_pipe_3():
218242 " message: " ,
219243 '"' ,
220244 "This is a message" ,
221- '."' " \n User" ,
245+ '."\n User' ,
222246 " intent: " ,
223247 " xxx" ,
224248 ],
@@ -238,7 +262,7 @@ async def test_suffix_with_stop_and_pipe_4():
238262 " message: " ,
239263 '"' ,
240264 "This is a message" ,
241- '."' " \n User" ,
265+ '."\n User' ,
242266 " intent: " ,
243267 " xxx" ,
244268 ],
0 commit comments