Skip to content

Commit

Permalink
added support for list and tuple injection
Browse files Browse the repository at this point in the history
  • Loading branch information
npgrosser committed Nov 25, 2023
1 parent f31201c commit 7628c79
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 4 deletions.
15 changes: 14 additions & 1 deletion autowired/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
AutowiredException,
)
from ._logging import logger
from ._typing_utils import is_subtype
from ._typing_utils import is_subtype, get_sequence_type

_T = TypeVar("_T")

Expand Down Expand Up @@ -245,6 +245,19 @@ def resolve(self, dependency: Union[Dependency, Type[_T]]) -> _T:

logger.trace(f"Existing not found, auto-wiring {dependency}")

# region list injection special case
# check if the dependency type is a list
sequence_type, element_type = get_sequence_type(dependency.type)
if element_type is not None:
element_type: Any
element_dependency = Dependency(dependency.name, element_type, True)
elements = []
for provider in self.get_providers(element_dependency):
elements.append(provider.get_instance(element_dependency, self))
return sequence_type(elements)

# endregion

result = self.autowire(dependency.type)

self.add(
Expand Down
32 changes: 31 additions & 1 deletion autowired/_typing_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from types import UnionType
from typing import Type, get_args, get_origin, Union, Any, Optional
from typing import Type, get_args, get_origin, Union, Any, Optional, List, Tuple


def is_subtype(t1: Type, t2: Type) -> bool:
Expand Down Expand Up @@ -77,3 +77,33 @@ def _as_union_types(t: Type | UnionType) -> tuple[Type, ...]:
if isinstance(t, UnionType) or get_origin(t) is Union:
return get_args(t)
return (t,)


def get_list_element_type(t: Type) -> Optional[Type]:
"""
Returns the type of the elements of a list type, or None if t is not a list type.
"""
origin = get_origin(t)
if origin is list or origin is List:
args = get_args(t)
if args:
return args[0]
return None


def get_sequence_type(t: Type) -> Union[Tuple[Type, Type], Tuple[None, None]]:
"""
Returns the type of the elements of a list type, or None if t is not a list type.
"""
origin = get_origin(t)
if origin is list or origin is List:
args = get_args(t)
if args:
return list, args[0]

if origin is tuple or origin is Tuple:
args = get_args(t)
if len(args) == 2 and args[1] is Ellipsis:
return tuple, args[0]

return None, None
58 changes: 57 additions & 1 deletion tests/test_autowired.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
from abc import ABC
from typing import List, Tuple

import pytest

Expand Down Expand Up @@ -879,3 +880,58 @@ def data(self) -> List[SomeData]:
assert isinstance(ctx.service, Service)
for data in ctx.service.data:
assert isinstance(data, SomeData)


def test_list_injection():
class Plugin(ABC):
pass

@dataclass
class PluginService:
plugins: List[Plugin]

class PluginA(Plugin):
pass

class PluginB(Plugin):
pass

def plugin_container():
container = Container()
container.add(PluginA())
container.add(PluginB())
return container

def assert_plugin_service(plugin_service, sequence_type):
assert isinstance(plugin_service, PluginService)
assert len(plugin_service.plugins) == 2
assert isinstance(plugin_service.plugins, sequence_type)
assert isinstance(plugin_service.plugins[0], PluginA)
assert isinstance(plugin_service.plugins[1], PluginB)

container = plugin_container()

plugin_service = container.resolve(PluginService)

assert_plugin_service(plugin_service, list)

# test with tuple

@dataclass
class PluginService:
plugins: Tuple[Plugin, ...]

container = plugin_container()

plugin_service = container.resolve(PluginService)
assert_plugin_service(plugin_service, tuple)

# test illegal tuple type
@dataclass
class PluginService:
plugins: Tuple[Plugin]

container = plugin_container()

with pytest.raises(UnresolvableDependencyException):
container.resolve(PluginService)
9 changes: 8 additions & 1 deletion tests/test_typing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import List, Dict, Tuple, Union, Any, Set

# noinspection PyProtectedMember
from autowired._typing_utils import is_subtype
from autowired._typing_utils import is_subtype, get_list_element_type


def test_non_generic_types():
Expand Down Expand Up @@ -139,3 +139,10 @@ def union(*args):
# case 4: one is Any
assert is_subtype(Any, union(int, str)) is True
assert is_subtype(union(int, str), Any) is True


def test_get_list_type():
assert get_list_element_type(List[int]) == int
assert get_list_element_type(List) is None
assert get_list_element_type(List[object]) is object
assert get_list_element_type(int) is None

0 comments on commit 7628c79

Please sign in to comment.