@@ -73,6 +73,8 @@ def my_config(binder):
7373 inject.configure(my_config)
7474
7575"""
76+ import contextlib
77+
7678from inject ._version import __version__
7779
7880import inspect
@@ -335,12 +337,28 @@ def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[...,
335337 async def async_injection_wrapper (* args : Any , ** kwargs : Any ) -> T :
336338 provided_params = frozenset (
337339 arg_names [:len (args )]) | frozenset (kwargs .keys ())
340+ ctx_managers = {}
341+ async_ctx_managers = {}
338342 for param , cls in params_to_provide .items ():
339343 if param not in provided_params :
340- kwargs [param ] = instance (cls )
344+ inst = instance (cls )
345+ if isinstance (inst , contextlib .AbstractContextManager ):
346+ ctx_managers [param ] = inst
347+ elif isinstance (inst , contextlib .AbstractAsyncContextManager ):
348+ async_ctx_managers [param ] = inst
349+ else :
350+ kwargs [param ] = inst
341351 async_func = cast (Callable [..., Awaitable [T ]], func )
342352 try :
343- return await async_func (* args , ** kwargs )
353+ with contextlib .ExitStack () as sync_stack :
354+ ctx_kwargs = {param : sync_stack .enter_context (ctx_manager ) for param , ctx_manager in
355+ ctx_managers .items ()}
356+ kwargs .update (ctx_kwargs )
357+ async with contextlib .AsyncExitStack () as async_stack :
358+ asynx_ctx_kwargs = {param : await async_stack .enter_async_context (ctx_manager ) for param , ctx_manager in
359+ async_ctx_managers .items ()}
360+ kwargs .update (asynx_ctx_kwargs )
361+ return await async_func (* args , ** kwargs )
344362 except TypeError as previous_error :
345363 raise ConstructorTypeError (func , previous_error )
346364
@@ -350,12 +368,20 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T:
350368 def injection_wrapper (* args : Any , ** kwargs : Any ) -> T :
351369 provided_params = frozenset (
352370 arg_names [:len (args )]) | frozenset (kwargs .keys ())
371+ ctx_managers = {}
353372 for param , cls in params_to_provide .items ():
354373 if param not in provided_params :
355- kwargs [param ] = instance (cls )
374+ inst = instance (cls )
375+ if isinstance (inst , contextlib .AbstractContextManager ):
376+ ctx_managers [param ] = inst
377+ else :
378+ kwargs [param ] = inst
356379 sync_func = cast (Callable [..., T ], func )
357380 try :
358- return sync_func (* args , ** kwargs )
381+ with contextlib .ExitStack () as stack :
382+ ctx_kwargs = {param : stack .enter_context (ctx_manager ) for param , ctx_manager in ctx_managers .items ()}
383+ kwargs .update (ctx_kwargs )
384+ return sync_func (* args , ** kwargs )
359385 except TypeError as previous_error :
360386 raise ConstructorTypeError (func , previous_error )
361387 return injection_wrapper
0 commit comments