Skip to content

Commit

Permalink
Refactor code to use composition instead of inheritence
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Ploski committed Mar 31, 2023
1 parent bc45703 commit 5668cba
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 91 deletions.
166 changes: 81 additions & 85 deletions aws_lambda_powertools/event_handler/appsync.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
import logging
from itertools import groupby
from typing import Any, Callable, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union

from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
from aws_lambda_powertools.utilities.typing import LambdaContext
Expand All @@ -10,23 +11,44 @@

class RouterContext:
def __init__(self):
super().__init__()
self.context = {}
self._context = {}

def append_context(self, **additional_context):
@property
def context(self) -> Dict[str, Any]:
return self._context

@context.setter
def context(self, additional_context: Dict[str, Any]) -> None:
"""Append key=value data as routing context"""
self.context.update(**additional_context)
self._context.update(**additional_context)

def clear_context(self):
@context.deleter
def context(self):
"""Resets routing context"""
self.context.clear()
self._context.clear()


class IResolverRegistry(ABC):
@abstractmethod
def resolver(self, type_name: str = "*", field_name: Optional[str] = None) -> Callable:
...

@abstractmethod
def find_resolver(self, type_name: str, field_name: str) -> Callable:
...

class ResolverRegistry:

class ResolverRegistry(IResolverRegistry):
def __init__(self):
super().__init__()
self._resolvers: dict = {}
self._batch_resolvers: dict = {}
self._resolvers: Dict[str, Dict[str, Any]] = {}

@property
def resolvers(self) -> Dict[str, Dict[str, Any]]:
return self._resolvers

@resolvers.setter
def resolvers(self, resolvers: dict) -> None:
self._resolvers.update(resolvers)

def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
"""Registers the resolver for field_name
Expand All @@ -46,26 +68,15 @@ def register(func):

return register

def batch_resolver(self, type_name: str = "*", field_name: Optional[str] = None):
"""Registers the resolver for field_name
Parameters
----------
type_name : str
Type name
field_name : str
Field name
"""

def register(func):
logger.debug(f"Adding batch resolver `{func.__name__}` for field `{type_name}.{field_name}`")
self._batch_resolvers[f"{type_name}.{field_name}"] = {"func": func}
return func

return register
def find_resolver(self, type_name: str, field_name: str) -> Callable:
full_name = f"{type_name}.{field_name}"
resolver = self._resolvers.get(full_name, self._resolvers.get(f"*.{field_name}"))
if not resolver:
raise ValueError(f"No resolver found for '{full_name}'")
return resolver["func"]


class AppSyncResolver(ResolverRegistry, RouterContext):
class AppSyncResolver:
"""
AppSync resolver decorator
Expand Down Expand Up @@ -97,17 +108,20 @@ def common_field() -> str:
"""

def __init__(self):
super().__init__()
self._resolver_registry: IResolverRegistry = ResolverRegistry()
self._batch_resolver_registry: IResolverRegistry = ResolverRegistry()
self._router_context: RouterContext = RouterContext()
self.current_batch_event: List[AppSyncResolverEvent] = []
self.current_event: Optional[AppSyncResolverEvent] = None
self.lambda_context: Optional[LambdaContext] = None

def resolve(
self,
event: Union[dict, List[dict]],
event: Union[Dict[str, Any], List[Dict[str, Any]]],
context: LambdaContext,
data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent,
) -> Any:
"""Resolve field_name
"""Resolve field_name in single event or in a batch event
Parameters
----------
Expand Down Expand Up @@ -180,17 +194,17 @@ def lambda_handler(event, context):
self.lambda_context = context

response = (
self._call_batch_resolver(event, data_model)
self._call_batch_resolver(event=event, data_model=data_model)
if isinstance(event, list)
else self._call_resolver(event, data_model)
else self._call_single_resolver(event=event, data_model=data_model)
)
self.clear_context()
del self._router_context.context

return response

def _call_resolver(self, event: dict, data_model: Type[AppSyncResolverEvent]) -> Any:
def _call_single_resolver(self, event: dict, data_model: Type[AppSyncResolverEvent]) -> Any:
self.current_event = data_model(event)
resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name)
resolver = self._resolver_registry.find_resolver(self.current_event.type_name, self.current_event.field_name)
return resolver(**self.current_event.arguments)

def _call_batch_resolver(self, event: List[dict], data_model: Type[AppSyncResolverEvent]) -> List[Any]:
Expand All @@ -202,54 +216,12 @@ def _call_batch_resolver(self, event: List[dict], data_model: Type[AppSyncResolv
ValueError("batch with different field names. It shouldn't happen!")

self.current_batch_event = [data_model(event) for event in event_groups[0]["events"]]
resolver = self._get_batch_resolver(
resolver = self._batch_resolver_registry.find_resolver(
self.current_batch_event[0].type_name, self.current_batch_event[0].field_name
)

return [resolver(event=appconfig_event) for appconfig_event in self.current_batch_event]

def _get_resolver(self, type_name: str, field_name: str) -> Callable:
"""Get resolver for field_name
Parameters
----------
type_name : str
Type name
field_name : str
Field name
Returns
-------
Callable
callable function and configuration
"""
full_name = f"{type_name}.{field_name}"
resolver = self._resolvers.get(full_name, self._resolvers.get(f"*.{field_name}"))
if not resolver:
raise ValueError(f"No resolver found for '{full_name}'")
return resolver["func"]

def _get_batch_resolver(self, type_name: str, field_name: str) -> Callable:
"""Get resolver for field_name
Parameters
----------
type_name : str
Type name
field_name : str
Field name
Returns
-------
Callable
callable function and configuration
"""
full_name = f"{type_name}.{field_name}"
resolver = self._batch_resolvers.get(full_name, self._batch_resolvers.get(f"*.{field_name}"))
if not resolver:
raise ValueError(f"No batch resolver found for '{full_name}'")
return resolver["func"]

def __call__(
self,
event: Union[dict, List[dict]],
Expand All @@ -267,14 +239,38 @@ def include_router(self, router: "Router") -> None:
router : Router
A router containing a dict of field resolvers
"""

# Merge app and router context
self.context.update(**router.context)
self._router_context.context = router._router_context.context
# use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx)
router.context = self.context
router._router_context._context = self._router_context.context

self._resolvers.update(router._resolvers)
self._resolver_registry.resolvers = router._resolver_registry.resolvers
self._batch_resolver_registry.resolvers = router._batch_resolver_registry.resolvers

# Interfaces
def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
return self._resolver_registry.resolver(field_name=field_name, type_name=type_name)

class Router(RouterContext, ResolverRegistry):
def batch_resolver(self, type_name: str = "*", field_name: Optional[str] = None):
return self._batch_resolver_registry.resolver(field_name=field_name, type_name=type_name)

def append_context(self, **additional_context) -> None:
self._router_context.context = additional_context


class Router:
def __init__(self):
super().__init__()
self._resolver_registry = ResolverRegistry()
self._batch_resolver_registry = ResolverRegistry()
self._router_context = RouterContext()

# Interfaces
def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
return self._resolver_registry.resolver(field_name=field_name, type_name=type_name)

def batch_resolver(self, type_name: str = "*", field_name: Optional[str] = None):
return self._batch_resolver_registry.resolver(field_name=field_name, type_name=type_name)

def append_context(self, **additional_context) -> None:
self._router_context.context = additional_context
115 changes: 109 additions & 6 deletions tests/functional/event_handler/test_appsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,119 @@ def get_locations2(name: str):
assert result2 == "get_locations2#value"


def test_resolver_include_batch_resolver():
# GIVEN
app = AppSyncResolver()
router = Router()

@router.batch_resolver(type_name="Query", field_name="listLocations")
def get_locations(event: AppSyncResolverEvent) -> str:
return "get_locations#" + event.arguments["name"]

@app.batch_resolver(field_name="listLocations2")
def get_locations2(event: AppSyncResolverEvent) -> str:
return "get_locations2#" + event.arguments["name"]

app.include_router(router)

# WHEN
mock_event1 = [
{
"typeName": "Query",
"info": {
"fieldName": "listLocations",
"parentTypeName": "Query",
},
"fieldName": "listLocations",
"arguments": {"name": "value"},
"source": {
"id": "1",
},
}
]
mock_event2 = [
{
"typeName": "Query",
"info": {
"fieldName": "listLocations2",
"parentTypeName": "Post",
},
"fieldName": "listLocations2",
"arguments": {"name": "value"},
"source": {
"id": "2",
},
}
]
result1 = app.resolve(mock_event1, LambdaContext())
result2 = app.resolve(mock_event2, LambdaContext())

# THEN
assert result1 == ["get_locations#value"]
assert result2 == ["get_locations2#value"]


def test_resolver_include_mixed_resolver():
# GIVEN
app = AppSyncResolver()
router = Router()

@router.batch_resolver(type_name="Query", field_name="listLocations")
def get_locations(event: AppSyncResolverEvent) -> str:
return "get_locations#" + event.arguments["name"]

@app.resolver(field_name="listLocations2")
def get_locations2(name: str) -> str:
return "get_locations2#" + name

app.include_router(router)

# WHEN
mock_event1 = [
{
"typeName": "Query",
"info": {
"fieldName": "listLocations",
"parentTypeName": "Query",
},
"fieldName": "listLocations",
"arguments": {"name": "value"},
"source": {
"id": "1",
},
}
]
mock_event2 = {
"typeName": "Query",
"info": {
"fieldName": "listLocations2",
"parentTypeName": "Post",
},
"fieldName": "listLocations2",
"arguments": {"name": "value"},
"source": {
"id": "2",
},
}

result1 = app.resolve(mock_event1, LambdaContext())
result2 = app.resolve(mock_event2, LambdaContext())

# THEN
assert result1 == ["get_locations#value"]
assert result2 == "get_locations2#value"


def test_append_context():
app = AppSyncResolver()
app.append_context(is_admin=True)
assert app.context.get("is_admin") is True
assert app._router_context.context.get("is_admin") is True


def test_router_append_context():
router = Router()
router.append_context(is_admin=True)
assert router.context.get("is_admin") is True
assert router._router_context.context.get("is_admin") is True


def test_route_context_is_cleared_after_resolve():
Expand All @@ -271,7 +374,7 @@ def get_locations(name: str):
app.resolve(event, {})

# THEN context should be empty
assert app.context == {}
assert app._router_context.context == {}


def test_router_has_access_to_app_context():
Expand All @@ -282,7 +385,7 @@ def test_router_has_access_to_app_context():

@router.resolver(type_name="Query", field_name="listLocations")
def get_locations(name: str):
if router.context["is_admin"]:
if router._router_context.context.get("is_admin"):
return f"get_locations#{name}"

app.include_router(router)
Expand All @@ -293,7 +396,7 @@ def get_locations(name: str):

# THEN
assert ret == "get_locations#value"
assert router.context == {}
assert router._router_context.context == {}


def test_include_router_merges_context():
Expand All @@ -307,4 +410,4 @@ def test_include_router_merges_context():

app.include_router(router)

assert app.context == router.context
assert app._router_context.context == router._router_context.context

0 comments on commit 5668cba

Please sign in to comment.