From 47bc00e6037f6ff0952ae5021c0273e9e30f3813 Mon Sep 17 00:00:00 2001 From: Nikhil Narayen Date: Fri, 7 Feb 2025 18:28:39 +0000 Subject: [PATCH] Fix integration test for span cleanup --- truss/templates/server/model_wrapper.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index 49b1d256e..8b78c00ba 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -641,7 +641,7 @@ async def _stream_with_background_task( generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]], span: trace.Span, trace_ctx: trace.Context, - release_and_end: Callable[[], None], + cleanup_fn: Callable[[], None], ) -> AsyncGenerator[bytes, None]: # The streaming read timeout is the amount of time in between streamed chunk # before a timeout is triggered. @@ -661,7 +661,7 @@ async def _stream_with_background_task( self._write_response_to_queue(response_queue, async_generator, span) ) # Defer the release of the semaphore until the write_response_to_queue task. - gen_task.add_done_callback(lambda _: release_and_end()) + gen_task.add_done_callback(lambda _: cleanup_fn()) # The gap between responses in a stream must be < streaming_read_timeout # TODO: this whole buffering might be superfluous and sufficiently done by @@ -717,7 +717,7 @@ async def _process_model_fn( if inspect.isgenerator(result) or inspect.isasyncgen(result): return await self._handle_generator_response( - request, result, fn_span, detached_ctx, release_and_end=lambda: None + request, result, fn_span, detached_ctx ) return result @@ -738,13 +738,13 @@ async def _handle_generator_response( generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]], span: trace.Span, trace_ctx: trace.Context, - release_and_end: Callable[[], None], + get_cleanup_fn: Callable[[], Callable[[], None]] = lambda: lambda: None, ): if self._should_gather_generator(request): return await _gather_generator(generator) else: return await self._stream_with_background_task( - generator, span, trace_ctx, release_and_end + generator, span, trace_ctx, cleanup_fn=get_cleanup_fn() ) async def completions( @@ -824,7 +824,7 @@ async def __call__( predict_result, span_predict, detached_ctx, - release_and_end=get_defer_fn(), + get_cleanup_fn=get_defer_fn, ) if isinstance(predict_result, starlette.responses.Response):