From 7628c79e8b10eb8987421eec1aa2ad564532f5b6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicolai=20Gro=C3=9Fer?= <nicolai.grosser@googlemail.com>
Date: Sat, 25 Nov 2023 14:30:31 +0100
Subject: [PATCH] added support for list and tuple injection

---
 autowired/_container.py    | 15 +++++++++-
 autowired/_typing_utils.py | 32 ++++++++++++++++++++-
 tests/test_autowired.py    | 58 +++++++++++++++++++++++++++++++++++++-
 tests/test_typing_utils.py |  9 +++++-
 4 files changed, 110 insertions(+), 4 deletions(-)

diff --git a/autowired/_container.py b/autowired/_container.py
index e1632de..6c97ed7 100644
--- a/autowired/_container.py
+++ b/autowired/_container.py
@@ -14,7 +14,7 @@
     AutowiredException,
 )
 from ._logging import logger
-from ._typing_utils import is_subtype
+from ._typing_utils import is_subtype, get_sequence_type
 
 _T = TypeVar("_T")
 
@@ -245,6 +245,19 @@ def resolve(self, dependency: Union[Dependency, Type[_T]]) -> _T:
 
         logger.trace(f"Existing not found, auto-wiring {dependency}")
 
+        # region list injection special case
+        # check if the dependency type is a list
+        sequence_type, element_type = get_sequence_type(dependency.type)
+        if element_type is not None:
+            element_type: Any
+            element_dependency = Dependency(dependency.name, element_type, True)
+            elements = []
+            for provider in self.get_providers(element_dependency):
+                elements.append(provider.get_instance(element_dependency, self))
+            return sequence_type(elements)
+
+        # endregion
+
         result = self.autowire(dependency.type)
 
         self.add(
diff --git a/autowired/_typing_utils.py b/autowired/_typing_utils.py
index 8a989e8..4b53b1e 100644
--- a/autowired/_typing_utils.py
+++ b/autowired/_typing_utils.py
@@ -1,5 +1,5 @@
 from types import UnionType
-from typing import Type, get_args, get_origin, Union, Any, Optional
+from typing import Type, get_args, get_origin, Union, Any, Optional, List, Tuple
 
 
 def is_subtype(t1: Type, t2: Type) -> bool:
@@ -77,3 +77,33 @@ def _as_union_types(t: Type | UnionType) -> tuple[Type, ...]:
     if isinstance(t, UnionType) or get_origin(t) is Union:
         return get_args(t)
     return (t,)
+
+
+def get_list_element_type(t: Type) -> Optional[Type]:
+    """
+    Returns the type of the elements of a list type, or None if t is not a list type.
+    """
+    origin = get_origin(t)
+    if origin is list or origin is List:
+        args = get_args(t)
+        if args:
+            return args[0]
+    return None
+
+
+def get_sequence_type(t: Type) -> Union[Tuple[Type, Type], Tuple[None, None]]:
+    """
+    Returns the type of the elements of a list type, or None if t is not a list type.
+    """
+    origin = get_origin(t)
+    if origin is list or origin is List:
+        args = get_args(t)
+        if args:
+            return list, args[0]
+
+    if origin is tuple or origin is Tuple:
+        args = get_args(t)
+        if len(args) == 2 and args[1] is Ellipsis:
+            return tuple, args[0]
+
+    return None, None
diff --git a/tests/test_autowired.py b/tests/test_autowired.py
index 92cd48f..e3672e0 100644
--- a/tests/test_autowired.py
+++ b/tests/test_autowired.py
@@ -1,4 +1,5 @@
-from typing import List
+from abc import ABC
+from typing import List, Tuple
 
 import pytest
 
@@ -879,3 +880,58 @@ def data(self) -> List[SomeData]:
     assert isinstance(ctx.service, Service)
     for data in ctx.service.data:
         assert isinstance(data, SomeData)
+
+
+def test_list_injection():
+    class Plugin(ABC):
+        pass
+
+    @dataclass
+    class PluginService:
+        plugins: List[Plugin]
+
+    class PluginA(Plugin):
+        pass
+
+    class PluginB(Plugin):
+        pass
+
+    def plugin_container():
+        container = Container()
+        container.add(PluginA())
+        container.add(PluginB())
+        return container
+
+    def assert_plugin_service(plugin_service, sequence_type):
+        assert isinstance(plugin_service, PluginService)
+        assert len(plugin_service.plugins) == 2
+        assert isinstance(plugin_service.plugins, sequence_type)
+        assert isinstance(plugin_service.plugins[0], PluginA)
+        assert isinstance(plugin_service.plugins[1], PluginB)
+
+    container = plugin_container()
+
+    plugin_service = container.resolve(PluginService)
+
+    assert_plugin_service(plugin_service, list)
+
+    # test with tuple
+
+    @dataclass
+    class PluginService:
+        plugins: Tuple[Plugin, ...]
+
+    container = plugin_container()
+
+    plugin_service = container.resolve(PluginService)
+    assert_plugin_service(plugin_service, tuple)
+
+    # test illegal tuple type
+    @dataclass
+    class PluginService:
+        plugins: Tuple[Plugin]
+
+    container = plugin_container()
+
+    with pytest.raises(UnresolvableDependencyException):
+        container.resolve(PluginService)
diff --git a/tests/test_typing_utils.py b/tests/test_typing_utils.py
index 51c0b48..e0d9b71 100644
--- a/tests/test_typing_utils.py
+++ b/tests/test_typing_utils.py
@@ -4,7 +4,7 @@
 from typing import List, Dict, Tuple, Union, Any, Set
 
 # noinspection PyProtectedMember
-from autowired._typing_utils import is_subtype
+from autowired._typing_utils import is_subtype, get_list_element_type
 
 
 def test_non_generic_types():
@@ -139,3 +139,10 @@ def union(*args):
     # case 4: one is Any
     assert is_subtype(Any, union(int, str)) is True
     assert is_subtype(union(int, str), Any) is True
+
+
+def test_get_list_type():
+    assert get_list_element_type(List[int]) == int
+    assert get_list_element_type(List) is None
+    assert get_list_element_type(List[object]) is object
+    assert get_list_element_type(int) is None