11import logging
22from itertools import groupby
3- from typing import Any , Callable , List , Optional , Type , TypeVar , Union
3+ from typing import Any , Callable , List , Optional , Type , Union
44
55from aws_lambda_powertools .utilities .data_classes import AppSyncResolverEvent
66from aws_lambda_powertools .utilities .typing import LambdaContext
77
88logger = logging .getLogger (__name__ )
99
10- AppSyncResolverEventT = TypeVar ("AppSyncResolverEventT" , bound = AppSyncResolverEvent )
1110
11+ class RouterContext :
12+ def __init__ (self ):
13+ super ().__init__ ()
14+ self .context = {}
1215
13- class BaseRouter :
14- current_event : Union [AppSyncResolverEventT , List [AppSyncResolverEventT ]] # type: ignore[valid-type]
15- lambda_context : LambdaContext
16- context : dict
16+ def append_context (self , ** additional_context ):
17+ """Append key=value data as routing context"""
18+ self .context .update (** additional_context )
1719
20+ def clear_context (self ):
21+ """Resets routing context"""
22+ self .context .clear ()
23+
24+
25+ class ResolverRegistry :
1826 def __init__ (self ):
27+ super ().__init__ ()
1928 self ._resolvers : dict = {}
29+ self ._batch_resolvers : dict = {}
2030
2131 def resolver (self , type_name : str = "*" , field_name : Optional [str ] = None ):
2232 """Registers the resolver for field_name
@@ -29,23 +39,33 @@ def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
2939 Field name
3040 """
3141
32- def register_resolver (func ):
42+ def register (func ):
3343 logger .debug (f"Adding resolver `{ func .__name__ } ` for field `{ type_name } .{ field_name } `" )
3444 self ._resolvers [f"{ type_name } .{ field_name } " ] = {"func" : func }
3545 return func
3646
37- return register_resolver
47+ return register
3848
39- def append_context (self , ** additional_context ):
40- """Append key=value data as routing context"""
41- self .context .update (** additional_context )
49+ def batch_resolver (self , type_name : str = "*" , field_name : Optional [str ] = None ):
50+ """Registers the resolver for field_name
4251
43- def clear_context (self ):
44- """Resets routing context"""
45- self .context .clear ()
52+ Parameters
53+ ----------
54+ type_name : str
55+ Type name
56+ field_name : str
57+ Field name
58+ """
4659
60+ def register (func ):
61+ logger .debug (f"Adding batch resolver `{ func .__name__ } ` for field `{ type_name } .{ field_name } `" )
62+ self ._batch_resolvers [f"{ type_name } .{ field_name } " ] = {"func" : func }
63+ return func
4764
48- class AppSyncResolver (BaseRouter ):
65+ return register
66+
67+
68+ class AppSyncResolver (ResolverRegistry , RouterContext ):
4969 """
5070 AppSync resolver decorator
5171
@@ -78,16 +98,20 @@ def common_field() -> str:
7898
7999 def __init__ (self ):
80100 super ().__init__ ()
81- self .context = {} # early init as customers might add context before event resolution
101+ self .current_batch_event : List [AppSyncResolverEvent ] = []
102+ self .current_event : Optional [AppSyncResolverEvent ] = None
82103
83104 def resolve (
84- self , event : dict , context : LambdaContext , data_model : Type [AppSyncResolverEvent ] = AppSyncResolverEvent
105+ self ,
106+ event : Union [dict , List [dict ]],
107+ context : LambdaContext ,
108+ data_model : Type [AppSyncResolverEvent ] = AppSyncResolverEvent ,
85109 ) -> Any :
86110 """Resolve field_name
87111
88112 Parameters
89113 ----------
90- event : dict
114+ event : dict | List[dict]
91115 Lambda event
92116 context : LambdaContext
93117 Lambda context
@@ -152,33 +176,38 @@ def lambda_handler(event, context):
152176 ValueError
153177 If we could not find a field resolver
154178 """
155- # Maintenance: revisit generics/overload to fix [attr-defined] in mypy usage
156-
157- BaseRouter .lambda_context = context
158-
159- # If event is a list it means that AppSync sent batch request
160- if isinstance (event , list ):
161- event_groups = [
162- {"field_name" : field_name , "events" : list (events )}
163- for field_name , events in groupby (event , key = lambda x : x ["info" ]["fieldName" ])
164- ]
165- if len (event_groups ) > 1 :
166- ValueError ("batch with different field names. It shouldn't happen!" )
167-
168- appconfig_events = [data_model (event ) for event in event_groups [0 ]["events" ]]
169- BaseRouter .current_event = appconfig_events
170- resolver = self ._get_resolver (appconfig_events [0 ].type_name , event_groups [0 ]["field_name" ])
171- response = resolver ()
172- else :
173- appconfig_event = data_model (event )
174- BaseRouter .current_event = appconfig_event
175- resolver = self ._get_resolver (appconfig_event .type_name , appconfig_event .field_name )
176- response = resolver (** appconfig_event .arguments )
177179
180+ self .lambda_context = context
181+
182+ response = (
183+ self ._call_batch_resolver (event , data_model )
184+ if isinstance (event , list )
185+ else self ._call_resolver (event , data_model )
186+ )
178187 self .clear_context ()
179188
180189 return response
181190
191+ def _call_resolver (self , event : dict , data_model : Type [AppSyncResolverEvent ]) -> Any :
192+ self .current_event = data_model (event )
193+ resolver = self ._get_resolver (self .current_event .type_name , self .current_event .field_name )
194+ return resolver (** self .current_event .arguments )
195+
196+ def _call_batch_resolver (self , event : List [dict ], data_model : Type [AppSyncResolverEvent ]) -> List [Any ]:
197+ event_groups = [
198+ {"field_name" : field_name , "events" : list (events )}
199+ for field_name , events in groupby (event , key = lambda x : x ["info" ]["fieldName" ])
200+ ]
201+ if len (event_groups ) > 1 :
202+ ValueError ("batch with different field names. It shouldn't happen!" )
203+
204+ self .current_batch_event = [data_model (event ) for event in event_groups [0 ]["events" ]]
205+ resolver = self ._get_batch_resolver (
206+ self .current_batch_event [0 ].type_name , self .current_batch_event [0 ].field_name
207+ )
208+
209+ return [resolver (event = appconfig_event ) for appconfig_event in self .current_batch_event ]
210+
182211 def _get_resolver (self , type_name : str , field_name : str ) -> Callable :
183212 """Get resolver for field_name
184213
@@ -200,8 +229,32 @@ def _get_resolver(self, type_name: str, field_name: str) -> Callable:
200229 raise ValueError (f"No resolver found for '{ full_name } '" )
201230 return resolver ["func" ]
202231
232+ def _get_batch_resolver (self , type_name : str , field_name : str ) -> Callable :
233+ """Get resolver for field_name
234+
235+ Parameters
236+ ----------
237+ type_name : str
238+ Type name
239+ field_name : str
240+ Field name
241+
242+ Returns
243+ -------
244+ Callable
245+ callable function and configuration
246+ """
247+ full_name = f"{ type_name } .{ field_name } "
248+ resolver = self ._batch_resolvers .get (full_name , self ._batch_resolvers .get (f"*.{ field_name } " ))
249+ if not resolver :
250+ raise ValueError (f"No batch resolver found for '{ full_name } '" )
251+ return resolver ["func" ]
252+
203253 def __call__ (
204- self , event : dict , context : LambdaContext , data_model : Type [AppSyncResolverEvent ] = AppSyncResolverEvent
254+ self ,
255+ event : Union [dict , List [dict ]],
256+ context : LambdaContext ,
257+ data_model : Type [AppSyncResolverEvent ] = AppSyncResolverEvent ,
205258 ) -> Any :
206259 """Implicit lambda handler which internally calls `resolve`"""
207260 return self .resolve (event , context , data_model )
@@ -222,7 +275,6 @@ def include_router(self, router: "Router") -> None:
222275 self ._resolvers .update (router ._resolvers )
223276
224277
225- class Router (BaseRouter ):
278+ class Router (RouterContext , ResolverRegistry ):
226279 def __init__ (self ):
227280 super ().__init__ ()
228- self .context = {} # early init as customers might add context before event resolution
0 commit comments