Skip to content

Commit

Permalink
added component scan feature
Browse files Browse the repository at this point in the history
  • Loading branch information
npgrosser committed Nov 25, 2023
1 parent bca6159 commit 5faa8ae
Show file tree
Hide file tree
Showing 9 changed files with 268 additions and 119 deletions.
166 changes: 47 additions & 119 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ class PluginB(Plugin):
print("Plugin B")


class PluginController:
class PluginManager:
def __init__(self, plugins: list[Plugin]):
self.plugins = plugins

Expand All @@ -639,7 +639,7 @@ class PluginController:


class ApplicationContext(Context):
plugin_controller: PluginController = autowired()
plugin_manager: PluginManager = autowired()


# usage
Expand All @@ -649,150 +649,78 @@ ctx = ApplicationContext()
ctx.container.add(PluginA())
ctx.container.add(PluginB())

ctx.plugin_controller.run_all()
ctx.plugin_manager.run_all()

```

---

## Example Application — FastAPI
### Component Scan

Although FastAPI already provides a powerful dependency injection mechanism, you might want to reuse your
autowired-based context classes.
The following example shows how to use autowired in a FastAPI application.
It does not aim to fully replace FastAPI's dependency injection, but rather demonstrates
how to seamlessly combine both approaches.
In many applications, you might want to automatically discover all components in a specific package.
Like list injection, this can be useful for implementing a plugin system.
Another common use case is to automatically discover all controllers in a web application to easily set up routing.
You can use the `@component` decorator to mark a class as a component.
When you call `component_scan()` on a container, it will automatically discover all components in the given package
and add them to the container.

```python
from dataclasses import dataclass
from autowired import Context, autowired, provided
# my_module/controllers/controller.py

from abc import ABC, abstractmethod

# Components

@dataclass
class DatabaseService:
conn_str: str

def load_allowed_tokens(self):
return ["123", "456", ""]

def get_user_name_by_id(self, user_id: int) -> str | None:
print(f"Loading user {user_id} from database {self.conn_str}")
d = {1: "John", 2: "Jane"}
return d.get(user_id)
class Controller(ABC):
@abstractmethod
def run(self):
...


@dataclass
class UserService:
db_service: DatabaseService
# my_module/controllers/controller1.py
from autowired import component
from .controller import Controller

def get_user_name_by_id(self, user_id: int) -> str | None:
if user_id == 0:
return "admin"
return self.db_service.get_user_name_by_id(user_id)

@component
class Controller1(Controller):
def run(self):
print("Starting Controller 1")

@dataclass
class UserController:
user_service: UserService

def get_user(self, user_id: int) -> str:
user_name = self.user_service.get_user_name_by_id(user_id)
if user_name is None:
raise HTTPException(status_code=404, detail="User not found")
# my_module/controllers/controller2.py
from autowired import component
from .controller import Controller

return user_name

@component
class Controller2(Controller):
def run(self):
print("Starting Controller 2")

# Application Settings and Context

# my_module/main.py

@dataclass
class ApplicationSettings:
database_connection_string: str = "db://localhost"
from autowired import Context, autowired
import my_module.controllers


# Application Context
class ControllerManager:
def __init__(self, controllers: list[Controller]):
self.controllers = controllers

def start(self):
for controller in self.controllers:
controller.run()

class ApplicationContext(Context):
user_controller: UserController = autowired()
database_service: DatabaseService = autowired(
lambda self: dict(conn_str=self.settings.database_connection_string)
)

def __init__(self, settings: ApplicationSettings = ApplicationSettings()):
self.settings = settings


from fastapi import FastAPI, Request, Depends, HTTPException


# Request Scoped Service for the FastAPI Application


@dataclass
class RequestAuthService:
db_service: DatabaseService
request: Request

def is_authorised(self):
token = self.request.headers.get("Authorization") or ""
token = token.replace("Bearer ", "")
if token in self.db_service.load_allowed_tokens():
return True
return False


# Request Context


class RequestContext(Context):
request_auth_service: RequestAuthService = autowired()
request: Request = provided()

def __init__(self, parent_context: Context, request: Request):
self.derive_from(parent_context)
self.request = request


# Setting up the FastAPI Application

app = FastAPI()
application_context = ApplicationContext()


def request_context(request: Request):
return RequestContext(application_context, request)


# We can seamlessly combine autowired's and FastAPIs dependency injection mechanisms
def request_auth_service(request_context: RequestContext = Depends(request_context)):
return request_context.request_auth_service


def user_controller():
return application_context.user_controller


@app.get("/users/{user_id}")
def get_user(
user_id: int,
request_auth_service: RequestAuthService = Depends(request_auth_service),
user_controller=Depends(user_controller),
):
if request_auth_service.is_authorised():
return user_controller.get_user(user_id=int(user_id))
else:
return {"detail": "Not authorised"}
controller_manager: ControllerManager = autowired()

def __init__(self):
# register all components from the my_module.controllers package
self.container.component_scan(my_module.controllers)


if __name__ == "__main__":
import uvicorn
# usage

uvicorn.run(app)
ctx = ApplicationContext()
ctx.controller_manager.start()

# http://127.0.0.1:8000/users/0 should now return "admin"
```

3 changes: 3 additions & 0 deletions autowired/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
"Provider",
"autowired",
"provided",
"component",
"Module",
]

from functools import cached_property
from ._component_scan import component, Module

from ._container import Container, Dependency, Provider
from ._context import Context, autowired, provided
Expand Down
72 changes: 72 additions & 0 deletions autowired/_component_scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import importlib
import inspect
import pkgutil
import sys
from dataclasses import dataclass
from typing import Type, Iterable, Set, Optional

Module = type(sys)


@dataclass
class ClassComponentInfo:
cls: Type
transient: bool = False


def component(cls=None, *, transient: bool = False):
def wrap_class(cls_to_wrap):
cls_to_wrap._component_info = ClassComponentInfo(cls_to_wrap, transient)
return cls_to_wrap

if cls is None:
# The decorator has been called like @component(transient=True)
# Return the wrapper function to apply later with the class
return wrap_class
else:
# The decorator has been called like @component without parentheses
# Apply the wrapper function directly to the class
return wrap_class(cls)


def get_component_info(cls) -> Optional[ClassComponentInfo]:
return getattr(cls, "_component_info", None)


class ClassScanner:
def __init__(self, root_module: Module):
if not inspect.ismodule(root_module):
raise TypeError(f"Expected a module, got {type(root_module)}")
self.root_module = root_module

def _get_classes(self) -> Iterable[Type]:
for name, cls in inspect.getmembers(self.root_module, inspect.isclass):
if cls.__module__ == self.root_module.__name__:
yield cls

path = self.root_module.__path__
prefix = self.root_module.__name__ + "."

for importer, modname, is_pkg in pkgutil.walk_packages(path, prefix):
sub_module = importlib.import_module(modname)
for name, cls in inspect.getmembers(sub_module, inspect.isclass):
if cls.__module__ == sub_module.__name__:
yield cls

def get_classes(self) -> Iterable[Type]:
seen: Set[str] = set()

def cls_to_key(cls: Type) -> str:
return f"{cls.__module__}.{cls.__name__}"

for cls in self._get_classes():
key = cls_to_key(cls)
if key not in seen:
seen.add(key)
yield cls


def component_scan(root_module: Module) -> Iterable[ClassComponentInfo]:
scanner = ClassScanner(root_module)
component_infos = (get_component_info(cls) for cls in scanner.get_classes())
return (c for c in component_infos if c is not None)
41 changes: 41 additions & 0 deletions autowired/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass, field
from types import FunctionType
from typing import Type, Callable, Any, List, Optional, Union, Generic, Dict, TypeVar
from ._component_scan import component_scan, Module

from ._exceptions import (
MissingTypeAnnotation,
Expand Down Expand Up @@ -111,6 +112,31 @@ def from_supplier(

return _SimpleProvider(name, type, supplier)

@staticmethod
def from_class(cls, container: "Container", transient: bool) -> "Provider[_T]":
def supplier():
return container.autowire(cls)

if not transient:
supplier = _cached(supplier)

return _SimpleProvider(_camel_to_snake(cls.__name__), cls, supplier)


def _cached(supplier: Callable[[], _T]) -> Callable[[], _T]:
cached = False
result = None

def wrapper():
nonlocal cached
nonlocal result
if not cached:
result = supplier()
cached = True
return result

return wrapper


@dataclass(frozen=True)
class _SimpleProvider(Provider[_T]):
Expand Down Expand Up @@ -315,6 +341,21 @@ def autowire(
except Exception as e:
raise InstantiationError(f"Failed to initialize {t.__name__}") from e

# use component scan to add providers by type

def component_scan(self, root_module: Module) -> None:
"""
Scans the given module and all submodules for classes annotated with `@component`.
Adds a singleton provider for each class found.
:param root_module: The root module to scan
"""
for component_info in component_scan(root_module):
provider = Provider.from_class(
component_info.cls, self, component_info.transient
)
self.add(provider)


# region utils

Expand Down
11 changes: 11 additions & 0 deletions tests/component_scan_test_module/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from autowired import component
from tests.component_scan_test_module.file_module import FileComponentInitExposed


@component
class RootModuleComponent:
pass


class RootModuleNotComponent:
pass
15 changes: 15 additions & 0 deletions tests/component_scan_test_module/file_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from autowired import component


@component
class FileComponent:
pass


@component
class FileComponentInitExposed:
pass


class FileNotComponent:
pass
Loading

0 comments on commit 5faa8ae

Please sign in to comment.