From 5faa8ae24cadd3e3ced6803862b42875a089f6c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolai=20Gro=C3=9Fer?= Date: Sat, 25 Nov 2023 19:17:02 +0100 Subject: [PATCH] added component scan feature --- README.md | 166 +++++------------- autowired/__init__.py | 3 + autowired/_component_scan.py | 72 ++++++++ autowired/_container.py | 41 +++++ tests/component_scan_test_module/__init__.py | 11 ++ .../component_scan_test_module/file_module.py | 15 ++ .../test_package/__init__.py | 13 ++ .../test_package/file_module.py | 20 +++ tests/test_component_scan.py | 46 +++++ 9 files changed, 268 insertions(+), 119 deletions(-) create mode 100644 autowired/_component_scan.py create mode 100644 tests/component_scan_test_module/__init__.py create mode 100644 tests/component_scan_test_module/file_module.py create mode 100644 tests/component_scan_test_module/test_package/__init__.py create mode 100644 tests/component_scan_test_module/test_package/file_module.py create mode 100644 tests/test_component_scan.py diff --git a/README.md b/README.md index 55e9ca6..85ca95a 100644 --- a/README.md +++ b/README.md @@ -629,7 +629,7 @@ class PluginB(Plugin): print("Plugin B") -class PluginController: +class PluginManager: def __init__(self, plugins: list[Plugin]): self.plugins = plugins @@ -639,7 +639,7 @@ class PluginController: class ApplicationContext(Context): - plugin_controller: PluginController = autowired() + plugin_manager: PluginManager = autowired() # usage @@ -649,150 +649,78 @@ ctx = ApplicationContext() ctx.container.add(PluginA()) ctx.container.add(PluginB()) -ctx.plugin_controller.run_all() +ctx.plugin_manager.run_all() ``` ---- - -## Example Application — FastAPI +### Component Scan -Although FastAPI already provides a powerful dependency injection mechanism, you might want to reuse your -autowired-based context classes. -The following example shows how to use autowired in a FastAPI application. -It does not aim to fully replace FastAPI's dependency injection, but rather demonstrates -how to seamlessly combine both approaches. +In many applications, you might want to automatically discover all components in a specific package. +Like list injection, this can be useful for implementing a plugin system. +Another common use case is to automatically discover all controllers in a web application to easily set up routing. +You can use the `@component` decorator to mark a class as a component. +When you call `component_scan()` on a container, it will automatically discover all components in the given package +and add them to the container. ```python -from dataclasses import dataclass -from autowired import Context, autowired, provided +# my_module/controllers/controller.py +from abc import ABC, abstractmethod -# Components - -@dataclass -class DatabaseService: - conn_str: str - - def load_allowed_tokens(self): - return ["123", "456", ""] - def get_user_name_by_id(self, user_id: int) -> str | None: - print(f"Loading user {user_id} from database {self.conn_str}") - d = {1: "John", 2: "Jane"} - return d.get(user_id) +class Controller(ABC): + @abstractmethod + def run(self): + ... -@dataclass -class UserService: - db_service: DatabaseService +# my_module/controllers/controller1.py +from autowired import component +from .controller import Controller - def get_user_name_by_id(self, user_id: int) -> str | None: - if user_id == 0: - return "admin" - return self.db_service.get_user_name_by_id(user_id) +@component +class Controller1(Controller): + def run(self): + print("Starting Controller 1") -@dataclass -class UserController: - user_service: UserService - def get_user(self, user_id: int) -> str: - user_name = self.user_service.get_user_name_by_id(user_id) - if user_name is None: - raise HTTPException(status_code=404, detail="User not found") +# my_module/controllers/controller2.py +from autowired import component +from .controller import Controller - return user_name +@component +class Controller2(Controller): + def run(self): + print("Starting Controller 2") -# Application Settings and Context +# my_module/main.py -@dataclass -class ApplicationSettings: - database_connection_string: str = "db://localhost" +from autowired import Context, autowired +import my_module.controllers -# Application Context +class ControllerManager: + def __init__(self, controllers: list[Controller]): + self.controllers = controllers + def start(self): + for controller in self.controllers: + controller.run() class ApplicationContext(Context): - user_controller: UserController = autowired() - database_service: DatabaseService = autowired( - lambda self: dict(conn_str=self.settings.database_connection_string) - ) - - def __init__(self, settings: ApplicationSettings = ApplicationSettings()): - self.settings = settings - - -from fastapi import FastAPI, Request, Depends, HTTPException - - -# Request Scoped Service for the FastAPI Application - - -@dataclass -class RequestAuthService: - db_service: DatabaseService - request: Request - - def is_authorised(self): - token = self.request.headers.get("Authorization") or "" - token = token.replace("Bearer ", "") - if token in self.db_service.load_allowed_tokens(): - return True - return False - - -# Request Context - - -class RequestContext(Context): - request_auth_service: RequestAuthService = autowired() - request: Request = provided() - - def __init__(self, parent_context: Context, request: Request): - self.derive_from(parent_context) - self.request = request - - -# Setting up the FastAPI Application - -app = FastAPI() -application_context = ApplicationContext() - - -def request_context(request: Request): - return RequestContext(application_context, request) - - -# We can seamlessly combine autowired's and FastAPIs dependency injection mechanisms -def request_auth_service(request_context: RequestContext = Depends(request_context)): - return request_context.request_auth_service - - -def user_controller(): - return application_context.user_controller - - -@app.get("/users/{user_id}") -def get_user( - user_id: int, - request_auth_service: RequestAuthService = Depends(request_auth_service), - user_controller=Depends(user_controller), -): - if request_auth_service.is_authorised(): - return user_controller.get_user(user_id=int(user_id)) - else: - return {"detail": "Not authorised"} + controller_manager: ControllerManager = autowired() + + def __init__(self): + # register all components from the my_module.controllers package + self.container.component_scan(my_module.controllers) -if __name__ == "__main__": - import uvicorn +# usage - uvicorn.run(app) +ctx = ApplicationContext() +ctx.controller_manager.start() - # http://127.0.0.1:8000/users/0 should now return "admin" ``` - diff --git a/autowired/__init__.py b/autowired/__init__.py index 7fc2e11..845e97e 100644 --- a/autowired/__init__.py +++ b/autowired/__init__.py @@ -16,9 +16,12 @@ "Provider", "autowired", "provided", + "component", + "Module", ] from functools import cached_property +from ._component_scan import component, Module from ._container import Container, Dependency, Provider from ._context import Context, autowired, provided diff --git a/autowired/_component_scan.py b/autowired/_component_scan.py new file mode 100644 index 0000000..832db51 --- /dev/null +++ b/autowired/_component_scan.py @@ -0,0 +1,72 @@ +import importlib +import inspect +import pkgutil +import sys +from dataclasses import dataclass +from typing import Type, Iterable, Set, Optional + +Module = type(sys) + + +@dataclass +class ClassComponentInfo: + cls: Type + transient: bool = False + + +def component(cls=None, *, transient: bool = False): + def wrap_class(cls_to_wrap): + cls_to_wrap._component_info = ClassComponentInfo(cls_to_wrap, transient) + return cls_to_wrap + + if cls is None: + # The decorator has been called like @component(transient=True) + # Return the wrapper function to apply later with the class + return wrap_class + else: + # The decorator has been called like @component without parentheses + # Apply the wrapper function directly to the class + return wrap_class(cls) + + +def get_component_info(cls) -> Optional[ClassComponentInfo]: + return getattr(cls, "_component_info", None) + + +class ClassScanner: + def __init__(self, root_module: Module): + if not inspect.ismodule(root_module): + raise TypeError(f"Expected a module, got {type(root_module)}") + self.root_module = root_module + + def _get_classes(self) -> Iterable[Type]: + for name, cls in inspect.getmembers(self.root_module, inspect.isclass): + if cls.__module__ == self.root_module.__name__: + yield cls + + path = self.root_module.__path__ + prefix = self.root_module.__name__ + "." + + for importer, modname, is_pkg in pkgutil.walk_packages(path, prefix): + sub_module = importlib.import_module(modname) + for name, cls in inspect.getmembers(sub_module, inspect.isclass): + if cls.__module__ == sub_module.__name__: + yield cls + + def get_classes(self) -> Iterable[Type]: + seen: Set[str] = set() + + def cls_to_key(cls: Type) -> str: + return f"{cls.__module__}.{cls.__name__}" + + for cls in self._get_classes(): + key = cls_to_key(cls) + if key not in seen: + seen.add(key) + yield cls + + +def component_scan(root_module: Module) -> Iterable[ClassComponentInfo]: + scanner = ClassScanner(root_module) + component_infos = (get_component_info(cls) for cls in scanner.get_classes()) + return (c for c in component_infos if c is not None) diff --git a/autowired/_container.py b/autowired/_container.py index 4df24a6..15565ce 100644 --- a/autowired/_container.py +++ b/autowired/_container.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from types import FunctionType from typing import Type, Callable, Any, List, Optional, Union, Generic, Dict, TypeVar +from ._component_scan import component_scan, Module from ._exceptions import ( MissingTypeAnnotation, @@ -111,6 +112,31 @@ def from_supplier( return _SimpleProvider(name, type, supplier) + @staticmethod + def from_class(cls, container: "Container", transient: bool) -> "Provider[_T]": + def supplier(): + return container.autowire(cls) + + if not transient: + supplier = _cached(supplier) + + return _SimpleProvider(_camel_to_snake(cls.__name__), cls, supplier) + + +def _cached(supplier: Callable[[], _T]) -> Callable[[], _T]: + cached = False + result = None + + def wrapper(): + nonlocal cached + nonlocal result + if not cached: + result = supplier() + cached = True + return result + + return wrapper + @dataclass(frozen=True) class _SimpleProvider(Provider[_T]): @@ -315,6 +341,21 @@ def autowire( except Exception as e: raise InstantiationError(f"Failed to initialize {t.__name__}") from e + # use component scan to add providers by type + + def component_scan(self, root_module: Module) -> None: + """ + Scans the given module and all submodules for classes annotated with `@component`. + Adds a singleton provider for each class found. + + :param root_module: The root module to scan + """ + for component_info in component_scan(root_module): + provider = Provider.from_class( + component_info.cls, self, component_info.transient + ) + self.add(provider) + # region utils diff --git a/tests/component_scan_test_module/__init__.py b/tests/component_scan_test_module/__init__.py new file mode 100644 index 0000000..53b00c1 --- /dev/null +++ b/tests/component_scan_test_module/__init__.py @@ -0,0 +1,11 @@ +from autowired import component +from tests.component_scan_test_module.file_module import FileComponentInitExposed + + +@component +class RootModuleComponent: + pass + + +class RootModuleNotComponent: + pass diff --git a/tests/component_scan_test_module/file_module.py b/tests/component_scan_test_module/file_module.py new file mode 100644 index 0000000..9acfe97 --- /dev/null +++ b/tests/component_scan_test_module/file_module.py @@ -0,0 +1,15 @@ +from autowired import component + + +@component +class FileComponent: + pass + + +@component +class FileComponentInitExposed: + pass + + +class FileNotComponent: + pass diff --git a/tests/component_scan_test_module/test_package/__init__.py b/tests/component_scan_test_module/test_package/__init__.py new file mode 100644 index 0000000..73ddbed --- /dev/null +++ b/tests/component_scan_test_module/test_package/__init__.py @@ -0,0 +1,13 @@ +from autowired import component +from tests.component_scan_test_module.test_package.file_module import ( + TestPackageFileComponentInitExposed, +) + + +@component +class TestPackageRootComponent: + pass + + +class TestPackageRootNotComponent: + pass diff --git a/tests/component_scan_test_module/test_package/file_module.py b/tests/component_scan_test_module/test_package/file_module.py new file mode 100644 index 0000000..2d5c60b --- /dev/null +++ b/tests/component_scan_test_module/test_package/file_module.py @@ -0,0 +1,20 @@ +from autowired import component + + +@component +class TestPackageFileComponent: + pass + + +@component +class TestPackageFileComponentInitExposed: + pass + + +@component(transient=True) +class TransientComponent: + pass + + +class TestPackageFileNotComponent: + pass diff --git a/tests/test_component_scan.py b/tests/test_component_scan.py new file mode 100644 index 0000000..53db3ce --- /dev/null +++ b/tests/test_component_scan.py @@ -0,0 +1,46 @@ +from typing import List + +import pytest + +import tests.component_scan_test_module +from autowired import Container, Dependency + + +def test_component_scan(): + container = Container() + + container.component_scan(tests.component_scan_test_module) + + providers = container.get_providers() + assert len(providers) == 7 + + components = container.resolve(Dependency("components", List[object])) + + assert len(components) == 7 + + component_class_names = set(type(c).__name__ for c in components) + + assert component_class_names == { + "FileComponent", + "FileComponentInitExposed", + "RootModuleComponent", + "TestPackageRootComponent", + "TestPackageFileComponent", + "TestPackageFileComponentInitExposed", + "TransientComponent", + } + + for provider in providers: + expect_singleton = provider.get_name() != "transient_component" + + instance1 = provider.get_instance(Dependency("instance1", object), container) + instance2 = provider.get_instance(Dependency("instance2", object), container) + + assert (instance1 is instance2) == expect_singleton + + +def test_component_scan_invalid_module(): + container = Container() + + with pytest.raises(TypeError): + container.component_scan("tests.component_scan_test_module")