From 7628c79e8b10eb8987421eec1aa2ad564532f5b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolai=20Gro=C3=9Fer?= <nicolai.grosser@googlemail.com> Date: Sat, 25 Nov 2023 14:30:31 +0100 Subject: [PATCH] added support for list and tuple injection --- autowired/_container.py | 15 +++++++++- autowired/_typing_utils.py | 32 ++++++++++++++++++++- tests/test_autowired.py | 58 +++++++++++++++++++++++++++++++++++++- tests/test_typing_utils.py | 9 +++++- 4 files changed, 110 insertions(+), 4 deletions(-) diff --git a/autowired/_container.py b/autowired/_container.py index e1632de..6c97ed7 100644 --- a/autowired/_container.py +++ b/autowired/_container.py @@ -14,7 +14,7 @@ AutowiredException, ) from ._logging import logger -from ._typing_utils import is_subtype +from ._typing_utils import is_subtype, get_sequence_type _T = TypeVar("_T") @@ -245,6 +245,19 @@ def resolve(self, dependency: Union[Dependency, Type[_T]]) -> _T: logger.trace(f"Existing not found, auto-wiring {dependency}") + # region list injection special case + # check if the dependency type is a list + sequence_type, element_type = get_sequence_type(dependency.type) + if element_type is not None: + element_type: Any + element_dependency = Dependency(dependency.name, element_type, True) + elements = [] + for provider in self.get_providers(element_dependency): + elements.append(provider.get_instance(element_dependency, self)) + return sequence_type(elements) + + # endregion + result = self.autowire(dependency.type) self.add( diff --git a/autowired/_typing_utils.py b/autowired/_typing_utils.py index 8a989e8..4b53b1e 100644 --- a/autowired/_typing_utils.py +++ b/autowired/_typing_utils.py @@ -1,5 +1,5 @@ from types import UnionType -from typing import Type, get_args, get_origin, Union, Any, Optional +from typing import Type, get_args, get_origin, Union, Any, Optional, List, Tuple def is_subtype(t1: Type, t2: Type) -> bool: @@ -77,3 +77,33 @@ def _as_union_types(t: Type | UnionType) -> tuple[Type, ...]: if isinstance(t, UnionType) or get_origin(t) is Union: return get_args(t) return (t,) + + +def get_list_element_type(t: Type) -> Optional[Type]: + """ + Returns the type of the elements of a list type, or None if t is not a list type. + """ + origin = get_origin(t) + if origin is list or origin is List: + args = get_args(t) + if args: + return args[0] + return None + + +def get_sequence_type(t: Type) -> Union[Tuple[Type, Type], Tuple[None, None]]: + """ + Returns the type of the elements of a list type, or None if t is not a list type. + """ + origin = get_origin(t) + if origin is list or origin is List: + args = get_args(t) + if args: + return list, args[0] + + if origin is tuple or origin is Tuple: + args = get_args(t) + if len(args) == 2 and args[1] is Ellipsis: + return tuple, args[0] + + return None, None diff --git a/tests/test_autowired.py b/tests/test_autowired.py index 92cd48f..e3672e0 100644 --- a/tests/test_autowired.py +++ b/tests/test_autowired.py @@ -1,4 +1,5 @@ -from typing import List +from abc import ABC +from typing import List, Tuple import pytest @@ -879,3 +880,58 @@ def data(self) -> List[SomeData]: assert isinstance(ctx.service, Service) for data in ctx.service.data: assert isinstance(data, SomeData) + + +def test_list_injection(): + class Plugin(ABC): + pass + + @dataclass + class PluginService: + plugins: List[Plugin] + + class PluginA(Plugin): + pass + + class PluginB(Plugin): + pass + + def plugin_container(): + container = Container() + container.add(PluginA()) + container.add(PluginB()) + return container + + def assert_plugin_service(plugin_service, sequence_type): + assert isinstance(plugin_service, PluginService) + assert len(plugin_service.plugins) == 2 + assert isinstance(plugin_service.plugins, sequence_type) + assert isinstance(plugin_service.plugins[0], PluginA) + assert isinstance(plugin_service.plugins[1], PluginB) + + container = plugin_container() + + plugin_service = container.resolve(PluginService) + + assert_plugin_service(plugin_service, list) + + # test with tuple + + @dataclass + class PluginService: + plugins: Tuple[Plugin, ...] + + container = plugin_container() + + plugin_service = container.resolve(PluginService) + assert_plugin_service(plugin_service, tuple) + + # test illegal tuple type + @dataclass + class PluginService: + plugins: Tuple[Plugin] + + container = plugin_container() + + with pytest.raises(UnresolvableDependencyException): + container.resolve(PluginService) diff --git a/tests/test_typing_utils.py b/tests/test_typing_utils.py index 51c0b48..e0d9b71 100644 --- a/tests/test_typing_utils.py +++ b/tests/test_typing_utils.py @@ -4,7 +4,7 @@ from typing import List, Dict, Tuple, Union, Any, Set # noinspection PyProtectedMember -from autowired._typing_utils import is_subtype +from autowired._typing_utils import is_subtype, get_list_element_type def test_non_generic_types(): @@ -139,3 +139,10 @@ def union(*args): # case 4: one is Any assert is_subtype(Any, union(int, str)) is True assert is_subtype(union(int, str), Any) is True + + +def test_get_list_type(): + assert get_list_element_type(List[int]) == int + assert get_list_element_type(List) is None + assert get_list_element_type(List[object]) is object + assert get_list_element_type(int) is None