Skip to content

Commit dcd0d4d

Browse files
rubenfonsecaleandrodamascenaCavalcante Damascena
authored
feat(event_handler): generate OpenAPI specifications and validate input/output (aws-powertools#3109)
* feat: generate OpenAPI spec from event handler * fix: resolver circular dependencies * fix: rebase * fix: document the new methods * fix: linter * fix: remove unneeded code * fix: reduce duplication * fix: types and sonarcube * chore: refactor complex function * fix: typing extensions * fix: tests * fix: mypy * fix: security baseline * feat: add simultaneous support for Pydantic v2 * fix: disable mypy and ruff on openapi compat * chore: add explanation to imports * chore: add first test * fix: test * fix: test * fix: don't require pydantic to run normal things * chore: added first tests * fix: refactored tests to remove code smell * fix: customize the handler methods * fix: tests * feat: add a validation middleware * fix: uniontype * fix: types * fix: ignore unused-ignore * fix: moved things around * fix: compatibility with pydantic v2 * chore: add tests on the body request * chore: add tests for validation middleware * fix: assorted fixes * fix: make tests pass in both pydantic versions * fix: remove assert * fix: complexity * fix: move Response class back * fix: more fix * fix: more fix * fix: one more fix * fix: refactor OpenAPI validation middleware * fix: refactor dependant.py * fix: beautify encoders * fix: move things around * fix: costmetic changes * fix: add more comments * fix: format * fix: cyclomatic * fix: change method of generating operation id * fix: allow validation in all resolvers * fix: use proper resolver in tests * fix: move from flake8 to ruff * fix: customizing responses * fix: add documentation to a method * fix: more explicit comments * fix: typo * fix: add extra comment * fix: comment * fix: add comments * fix: comments * fix: typo * fix: remove leftover comment * fix: addressing comments * fix: pydantic2 models * fix: typing extension problems * Adding more tests and fixing small things * Adding more tests and fixing small things * Adding more tests and fixing small things * Removing flaky tests * fix: improve coverage of encoders * fix: mark test as pydantic v1 only * fix: make sonarcube happy * fix: improve coverage of params.py * fix: add codecov.yml file to ignore compat.py * Increasing coverage --------- Signed-off-by: Leandro Damascena <lcdama@amazon.pt> Co-authored-by: Leandro Damascena <lcdama@amazon.pt> Co-authored-by: Cavalcante Damascena <lcdama@b0be8355743f.ant.amazon.com>
1 parent 14cb407 commit dcd0d4d

22 files changed

+4733
-31
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+769-25
Large diffs are not rendered by default.

Diff for: aws_lambda_powertools/event_handler/lambda_function_url.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,13 @@ def __init__(
5252
debug: Optional[bool] = None,
5353
serializer: Optional[Callable[[Dict], str]] = None,
5454
strip_prefixes: Optional[List[Union[str, Pattern]]] = None,
55+
enable_validation: bool = False,
5556
):
56-
super().__init__(ProxyEventType.LambdaFunctionUrlEvent, cors, debug, serializer, strip_prefixes)
57+
super().__init__(
58+
ProxyEventType.LambdaFunctionUrlEvent,
59+
cors,
60+
debug,
61+
serializer,
62+
strip_prefixes,
63+
enable_validation,
64+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
import dataclasses
2+
import json
3+
import logging
4+
from copy import deepcopy
5+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
6+
7+
from pydantic import BaseModel
8+
9+
from aws_lambda_powertools.event_handler import Response
10+
from aws_lambda_powertools.event_handler.api_gateway import Route
11+
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
12+
from aws_lambda_powertools.event_handler.openapi.compat import (
13+
ModelField,
14+
_model_dump,
15+
_normalize_errors,
16+
_regenerate_error_with_loc,
17+
get_missing_field_error,
18+
)
19+
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
20+
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
21+
from aws_lambda_powertools.event_handler.openapi.params import Param
22+
from aws_lambda_powertools.event_handler.openapi.types import IncEx
23+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
24+
25+
logger = logging.getLogger(__name__)
26+
27+
28+
class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
29+
"""
30+
OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the
31+
Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It
32+
should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`.
33+
34+
Examples
35+
--------
36+
37+
```python
38+
from typing import List
39+
40+
from pydantic import BaseModel
41+
42+
from aws_lambda_powertools.event_handler.api_gateway import (
43+
APIGatewayRestResolver,
44+
)
45+
46+
class Todo(BaseModel):
47+
name: str
48+
49+
app = APIGatewayRestResolver(enable_validation=True)
50+
51+
@app.get("/todos")
52+
def get_todos(): List[Todo]:
53+
return [Todo(name="hello world")]
54+
```
55+
"""
56+
57+
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
58+
logger.debug("OpenAPIValidationMiddleware handler")
59+
60+
route: Route = app.context["_route"]
61+
62+
values: Dict[str, Any] = {}
63+
errors: List[Any] = []
64+
65+
try:
66+
# Process path values, which can be found on the route_args
67+
path_values, path_errors = _request_params_to_args(
68+
route.dependant.path_params,
69+
app.context["_route_args"],
70+
)
71+
72+
# Process query values
73+
query_values, query_errors = _request_params_to_args(
74+
route.dependant.query_params,
75+
app.current_event.query_string_parameters or {},
76+
)
77+
78+
values.update(path_values)
79+
values.update(query_values)
80+
errors += path_errors + query_errors
81+
82+
# Process the request body, if it exists
83+
if route.dependant.body_params:
84+
(body_values, body_errors) = _request_body_to_args(
85+
required_params=route.dependant.body_params,
86+
received_body=self._get_body(app),
87+
)
88+
values.update(body_values)
89+
errors.extend(body_errors)
90+
91+
if errors:
92+
# Raise the validation errors
93+
raise RequestValidationError(_normalize_errors(errors))
94+
else:
95+
# Re-write the route_args with the validated values, and call the next middleware
96+
app.context["_route_args"] = values
97+
response = next_middleware(app)
98+
99+
# Process the response body if it exists
100+
raw_response = jsonable_encoder(response.body)
101+
102+
# Validate and serialize the response
103+
return self._serialize_response(field=route.dependant.return_param, response_content=raw_response)
104+
except RequestValidationError as e:
105+
return Response(
106+
status_code=422,
107+
content_type="application/json",
108+
body=json.dumps({"detail": e.errors()}),
109+
)
110+
111+
def _serialize_response(
112+
self,
113+
*,
114+
field: Optional[ModelField] = None,
115+
response_content: Any,
116+
include: Optional[IncEx] = None,
117+
exclude: Optional[IncEx] = None,
118+
by_alias: bool = True,
119+
exclude_unset: bool = False,
120+
exclude_defaults: bool = False,
121+
exclude_none: bool = False,
122+
) -> Any:
123+
"""
124+
Serialize the response content according to the field type.
125+
"""
126+
if field:
127+
errors: List[Dict[str, Any]] = []
128+
# MAINTENANCE: remove this when we drop pydantic v1
129+
if not hasattr(field, "serializable"):
130+
response_content = self._prepare_response_content(
131+
response_content,
132+
exclude_unset=exclude_unset,
133+
exclude_defaults=exclude_defaults,
134+
exclude_none=exclude_none,
135+
)
136+
137+
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
138+
if errors:
139+
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
140+
141+
if hasattr(field, "serialize"):
142+
return field.serialize(
143+
value,
144+
include=include,
145+
exclude=exclude,
146+
by_alias=by_alias,
147+
exclude_unset=exclude_unset,
148+
exclude_defaults=exclude_defaults,
149+
exclude_none=exclude_none,
150+
)
151+
152+
return jsonable_encoder(
153+
value,
154+
include=include,
155+
exclude=exclude,
156+
by_alias=by_alias,
157+
exclude_unset=exclude_unset,
158+
exclude_defaults=exclude_defaults,
159+
exclude_none=exclude_none,
160+
)
161+
else:
162+
# Just serialize the response content returned from the handler
163+
return jsonable_encoder(response_content)
164+
165+
def _prepare_response_content(
166+
self,
167+
res: Any,
168+
*,
169+
exclude_unset: bool,
170+
exclude_defaults: bool = False,
171+
exclude_none: bool = False,
172+
) -> Any:
173+
"""
174+
Prepares the response content for serialization.
175+
"""
176+
if isinstance(res, BaseModel):
177+
return _model_dump(
178+
res,
179+
by_alias=True,
180+
exclude_unset=exclude_unset,
181+
exclude_defaults=exclude_defaults,
182+
exclude_none=exclude_none,
183+
)
184+
elif isinstance(res, list):
185+
return [
186+
self._prepare_response_content(item, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
187+
for item in res
188+
]
189+
elif isinstance(res, dict):
190+
return {
191+
k: self._prepare_response_content(v, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
192+
for k, v in res.items()
193+
}
194+
elif dataclasses.is_dataclass(res):
195+
return dataclasses.asdict(res)
196+
return res
197+
198+
def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
199+
"""
200+
Get the request body from the event, and parse it as JSON.
201+
"""
202+
203+
content_type_value = app.current_event.get_header_value("content-type")
204+
if not content_type_value or content_type_value.startswith("application/json"):
205+
try:
206+
return app.current_event.json_body
207+
except json.JSONDecodeError as e:
208+
raise RequestValidationError(
209+
[
210+
{
211+
"type": "json_invalid",
212+
"loc": ("body", e.pos),
213+
"msg": "JSON decode error",
214+
"input": {},
215+
"ctx": {"error": e.msg},
216+
},
217+
],
218+
body=e.doc,
219+
) from e
220+
else:
221+
raise NotImplementedError("Only JSON body is supported")
222+
223+
224+
def _request_params_to_args(
225+
required_params: Sequence[ModelField],
226+
received_params: Mapping[str, Any],
227+
) -> Tuple[Dict[str, Any], List[Any]]:
228+
"""
229+
Convert the request params to a dictionary of values using validation, and returns a list of errors.
230+
"""
231+
values = {}
232+
errors = []
233+
234+
for field in required_params:
235+
value = received_params.get(field.alias)
236+
237+
field_info = field.field_info
238+
if not isinstance(field_info, Param):
239+
raise AssertionError(f"Expected Param field_info, got {field_info}")
240+
241+
loc = (field_info.in_.value, field.alias)
242+
243+
# If we don't have a value, see if it's required or has a default
244+
if value is None:
245+
if field.required:
246+
errors.append(get_missing_field_error(loc=loc))
247+
else:
248+
values[field.name] = deepcopy(field.default)
249+
continue
250+
251+
# Finally, validate the value
252+
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
253+
254+
return values, errors
255+
256+
257+
def _request_body_to_args(
258+
required_params: List[ModelField],
259+
received_body: Optional[Dict[str, Any]],
260+
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
261+
"""
262+
Convert the request body to a dictionary of values using validation, and returns a list of errors.
263+
"""
264+
values: Dict[str, Any] = {}
265+
errors: List[Dict[str, Any]] = []
266+
267+
received_body, field_alias_omitted = _get_embed_body(
268+
field=required_params[0],
269+
required_params=required_params,
270+
received_body=received_body,
271+
)
272+
273+
for field in required_params:
274+
# This sets the location to:
275+
# { "user": { object } } if field.alias == user
276+
# { { object } if field_alias is omitted
277+
loc: Tuple[str, ...] = ("body", field.alias)
278+
if field_alias_omitted:
279+
loc = ("body",)
280+
281+
value: Optional[Any] = None
282+
283+
# Now that we know what to look for, try to get the value from the received body
284+
if received_body is not None:
285+
try:
286+
value = received_body.get(field.alias)
287+
except AttributeError:
288+
errors.append(get_missing_field_error(loc))
289+
continue
290+
291+
# Determine if the field is required
292+
if value is None:
293+
if field.required:
294+
errors.append(get_missing_field_error(loc))
295+
else:
296+
values[field.name] = deepcopy(field.default)
297+
continue
298+
299+
# MAINTENANCE: Handle byte and file fields
300+
301+
# Finally, validate the value
302+
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
303+
304+
return values, errors
305+
306+
307+
def _validate_field(
308+
*,
309+
field: ModelField,
310+
value: Any,
311+
loc: Tuple[str, ...],
312+
existing_errors: List[Dict[str, Any]],
313+
):
314+
"""
315+
Validate a field, and append any errors to the existing_errors list.
316+
"""
317+
validated_value, errors = field.validate(value, value, loc=loc)
318+
319+
if isinstance(errors, list):
320+
processed_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=())
321+
existing_errors.extend(processed_errors)
322+
elif errors:
323+
existing_errors.append(errors)
324+
325+
return validated_value
326+
327+
328+
def _get_embed_body(
329+
*,
330+
field: ModelField,
331+
required_params: List[ModelField],
332+
received_body: Optional[Dict[str, Any]],
333+
) -> Tuple[Optional[Dict[str, Any]], bool]:
334+
field_info = field.field_info
335+
embed = getattr(field_info, "embed", None)
336+
337+
# If the field is an embed, and the field alias is omitted, we need to wrap the received body in the field alias.
338+
field_alias_omitted = len(required_params) == 1 and not embed
339+
if field_alias_omitted:
340+
received_body = {field.alias: received_body}
341+
342+
return received_body, field_alias_omitted

Diff for: aws_lambda_powertools/event_handler/openapi/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)