@@ -73,6 +73,8 @@ def my_config(binder):
7373 inject.configure(my_config)
7474
7575"""
76+ import contextlib
77+
7678from inject ._version import __version__
7779
7880import inspect
@@ -156,7 +158,10 @@ def bind_to_constructor(self, cls: Binding, constructor: Constructor) -> 'Binder
156158 return self
157159
158160 def bind_to_provider (self , cls : Binding , provider : Provider ) -> 'Binder' :
159- """Bind a class to a callable instance provider executed for each injection."""
161+ """
162+ Bind a class to a callable instance provider executed for each injection.
163+ A provider can be a normal function or a context manager. Both sync and async are supported.
164+ """
160165 self ._check_class (cls )
161166 if provider is None :
162167 raise InjectorException ('Provider cannot be None, key=%s' % cls )
@@ -323,6 +328,35 @@ class _ParametersInjection(Generic[T]):
323328 def __init__ (self , ** kwargs : Any ) -> None :
324329 self ._params = kwargs
325330
331+ @staticmethod
332+ def _aggregate_sync_stack (
333+ sync_stack : contextlib .ExitStack ,
334+ provided_params : frozenset [str ],
335+ kwargs : dict [str , Any ]
336+ ) -> None :
337+ """Extracts context managers, aggregate them in an ExitStack and swap out the param value with results of
338+ running __enter__(). The result is equivalent to using `with` multiple times """
339+ executed_kwargs = {
340+ param : sync_stack .enter_context (inst )
341+ for param , inst in kwargs .items ()
342+ if param not in provided_params and isinstance (inst , contextlib ._GeneratorContextManager )
343+ }
344+ kwargs .update (executed_kwargs )
345+
346+ @staticmethod
347+ async def _aggregate_async_stack (
348+ async_stack : contextlib .AsyncExitStack ,
349+ provided_params : frozenset [str ],
350+ kwargs : dict [str , Any ]
351+ ) -> None :
352+ """Similar to _aggregate_sync_stack, but for async context managers"""
353+ executed_kwargs = {
354+ param : await async_stack .enter_async_context (inst )
355+ for param , inst in kwargs .items ()
356+ if param not in provided_params and isinstance (inst , contextlib ._AsyncGeneratorContextManager )
357+ }
358+ kwargs .update (executed_kwargs )
359+
326360 def __call__ (self , func : Callable [..., Union [Awaitable [T ], T ]]) -> Callable [..., Union [Awaitable [T ], T ]]:
327361 if sys .version_info .major == 2 :
328362 arg_names = inspect .getargspec (func ).args
@@ -340,7 +374,11 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T:
340374 kwargs [param ] = instance (cls )
341375 async_func = cast (Callable [..., Awaitable [T ]], func )
342376 try :
343- return await async_func (* args , ** kwargs )
377+ with contextlib .ExitStack () as sync_stack :
378+ async with contextlib .AsyncExitStack () as async_stack :
379+ self ._aggregate_sync_stack (sync_stack , provided_params , kwargs )
380+ await self ._aggregate_async_stack (async_stack , provided_params , kwargs )
381+ return await async_func (* args , ** kwargs )
344382 except TypeError as previous_error :
345383 raise ConstructorTypeError (func , previous_error )
346384
@@ -355,7 +393,9 @@ def injection_wrapper(*args: Any, **kwargs: Any) -> T:
355393 kwargs [param ] = instance (cls )
356394 sync_func = cast (Callable [..., T ], func )
357395 try :
358- return sync_func (* args , ** kwargs )
396+ with contextlib .ExitStack () as sync_stack :
397+ self ._aggregate_sync_stack (sync_stack , provided_params , kwargs )
398+ return sync_func (* args , ** kwargs )
359399 except TypeError as previous_error :
360400 raise ConstructorTypeError (func , previous_error )
361401 return injection_wrapper
0 commit comments