Skip to content

Commit 21b7235

Browse files
feat(event_handler): add support for Pydantic models in Query and Header types (#7253)
* Adding supoort for Pydantic models in Query and Header * Improve error message and fix header serialization * Improving field validation method + tests * Improving field validation method + tests * support validate_by_name for pydantic BaseModels * support multi value headers and queries * remove unused get flat params for pydantic * remove unused pydantic openapi function * fix typing * fix sonarqube code smell findings * remove unnecessary checks and increase coverage * generalize body check with Param * rename for merge * consolidate normalize field value functions * Adding docs + small fix in tests --------- Co-authored-by: Leandro Damascena <lcdama@amazon.pt>
1 parent c1d53c2 commit 21b7235

File tree

9 files changed

+1116
-100
lines changed

9 files changed

+1116
-100
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 103 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
Server,
9393
Tag,
9494
)
95-
from aws_lambda_powertools.event_handler.openapi.params import Dependant
95+
from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param
9696
from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import (
9797
OAuth2Config,
9898
)
@@ -818,46 +818,123 @@ def _openapi_operation_parameters(
818818
"""
819819
Returns the OpenAPI operation parameters.
820820
"""
821-
from aws_lambda_powertools.event_handler.openapi.compat import (
822-
get_schema_from_model_field,
823-
)
824821
from aws_lambda_powertools.event_handler.openapi.params import Param
825822

826-
parameters = []
827-
parameter: dict[str, Any] = {}
823+
parameters: list[dict[str, Any]] = []
828824

829825
for param in all_route_params:
830-
field_info = param.field_info
831-
field_info = cast(Param, field_info)
826+
field_info = cast(Param, param.field_info)
832827
if not field_info.include_in_schema:
833828
continue
834829

835-
param_schema = get_schema_from_model_field(
836-
field=param,
837-
model_name_map=model_name_map,
838-
field_mapping=field_mapping,
839-
)
830+
# Check if this is a Pydantic model that should be expanded
831+
if Route._is_pydantic_model_param(field_info):
832+
parameters.extend(Route._expand_pydantic_model_parameters(field_info))
833+
else:
834+
parameters.append(Route._create_regular_parameter(param, model_name_map, field_mapping))
840835

841-
parameter = {
842-
"name": param.alias,
843-
"in": field_info.in_.value,
844-
"required": param.required,
845-
"schema": param_schema,
846-
}
836+
return parameters
847837

848-
if field_info.description:
849-
parameter["description"] = field_info.description
838+
@staticmethod
839+
def _is_pydantic_model_param(field_info: Param) -> bool:
840+
"""Check if the field info represents a Pydantic model parameter."""
841+
from pydantic import BaseModel
850842

851-
if field_info.openapi_examples:
852-
parameter["examples"] = field_info.openapi_examples
843+
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
853844

854-
if field_info.deprecated:
855-
parameter["deprecated"] = field_info.deprecated
845+
return lenient_issubclass(field_info.annotation, BaseModel)
856846

857-
parameters.append(parameter)
847+
@staticmethod
848+
def _expand_pydantic_model_parameters(field_info: Param) -> list[dict[str, Any]]:
849+
"""Expand a Pydantic model into individual OpenAPI parameters."""
850+
from pydantic import BaseModel
851+
852+
model_class = cast(type[BaseModel], field_info.annotation)
853+
parameters: list[dict[str, Any]] = []
854+
855+
for field_name, field_def in model_class.model_fields.items():
856+
param_name = field_def.alias or field_name
857+
individual_param = Route._create_pydantic_field_parameter(
858+
param_name=param_name,
859+
field_def=field_def,
860+
param_location=field_info.in_.value,
861+
)
862+
parameters.append(individual_param)
858863

859864
return parameters
860865

866+
@staticmethod
867+
def _create_pydantic_field_parameter(
868+
param_name: str,
869+
field_def: Any,
870+
param_location: str,
871+
) -> dict[str, Any]:
872+
"""Create an OpenAPI parameter from a Pydantic field definition."""
873+
individual_param: dict[str, Any] = {
874+
"name": param_name,
875+
"in": param_location,
876+
"required": field_def.is_required() if hasattr(field_def, "is_required") else field_def.default is ...,
877+
"schema": Route._get_basic_type_schema(field_def.annotation or type(None)),
878+
}
879+
880+
if field_def.description:
881+
individual_param["description"] = field_def.description
882+
883+
return individual_param
884+
885+
@staticmethod
886+
def _create_regular_parameter(
887+
param: ModelField,
888+
model_name_map: dict[TypeModelOrEnum, str],
889+
field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
890+
) -> dict[str, Any]:
891+
"""Create an OpenAPI parameter from a regular ModelField."""
892+
from aws_lambda_powertools.event_handler.openapi.compat import get_schema_from_model_field
893+
from aws_lambda_powertools.event_handler.openapi.params import Param
894+
895+
field_info = cast(Param, param.field_info)
896+
param_schema = get_schema_from_model_field(
897+
field=param,
898+
model_name_map=model_name_map,
899+
field_mapping=field_mapping,
900+
)
901+
902+
parameter: dict[str, Any] = {
903+
"name": param.alias,
904+
"in": field_info.in_.value,
905+
"required": param.required,
906+
"schema": param_schema,
907+
}
908+
909+
# Add optional attributes if present
910+
if field_info.description:
911+
parameter["description"] = field_info.description
912+
if field_info.openapi_examples:
913+
parameter["examples"] = field_info.openapi_examples
914+
if field_info.deprecated:
915+
parameter["deprecated"] = field_info.deprecated
916+
917+
return parameter
918+
919+
@staticmethod
920+
def _get_basic_type_schema(param_type: type) -> dict[str, str]:
921+
"""
922+
Get basic OpenAPI schema for simple types
923+
"""
924+
try:
925+
# Check bool before int, since bool is a subclass of int in Python
926+
if issubclass(param_type, bool):
927+
return {"type": "boolean"}
928+
elif issubclass(param_type, int):
929+
return {"type": "integer"}
930+
elif issubclass(param_type, float):
931+
return {"type": "number"}
932+
else:
933+
return {"type": "string"}
934+
except TypeError:
935+
# param_type may not be a type (e.g., typing.Optional[int]), fallback to string
936+
return {"type": "string"}
937+
861938
@staticmethod
862939
def _openapi_operation_return(
863940
*,

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 65 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import dataclasses
44
import json
55
import logging
6-
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
6+
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, cast
77
from urllib.parse import parse_qs
88

99
from pydantic import BaseModel
@@ -13,15 +13,18 @@
1313
_model_dump,
1414
_normalize_errors,
1515
_regenerate_error_with_loc,
16+
field_annotation_is_sequence,
1617
get_missing_field_error,
17-
is_sequence_field,
18+
lenient_issubclass,
1819
)
1920
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
2021
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
2122
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError
2223
from aws_lambda_powertools.event_handler.openapi.params import Param
2324

2425
if TYPE_CHECKING:
26+
from pydantic.fields import FieldInfo
27+
2528
from aws_lambda_powertools.event_handler import Response
2629
from aws_lambda_powertools.event_handler.api_gateway import Route
2730
from aws_lambda_powertools.event_handler.middlewares import NextMiddleware
@@ -64,7 +67,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
6467
)
6568

6669
# Normalize query values before validate this
67-
query_string = _normalize_multi_query_string_with_param(
70+
query_string = _normalize_multi_params(
6871
app.current_event.resolved_query_string_parameters,
6972
route.dependant.query_params,
7073
)
@@ -76,7 +79,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
7679
)
7780

7881
# Normalize header values before validate this
79-
headers = _normalize_multi_header_values_with_param(
82+
headers = _normalize_multi_params(
8083
app.current_event.resolved_headers_field,
8184
route.dependant.header_params,
8285
)
@@ -366,7 +369,7 @@ def _request_body_to_args(
366369
_handle_missing_field_value(field, values, errors, loc)
367370
continue
368371

369-
value = _normalize_field_value(field, value)
372+
value = _normalize_field_value(value=value, field_info=field.field_info)
370373
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
371374

372375
return values, errors
@@ -409,10 +412,13 @@ def _handle_missing_field_value(
409412
values[field.name] = field.get_default()
410413

411414

412-
def _normalize_field_value(field: ModelField, value: Any) -> Any:
415+
def _normalize_field_value(value: Any, field_info: FieldInfo) -> Any:
413416
"""Normalize field value, converting lists to single values for non-sequence fields."""
414-
if isinstance(value, list) and not is_sequence_field(field):
417+
if field_annotation_is_sequence(field_info.annotation):
418+
return value
419+
elif isinstance(value, list) and value:
415420
return value[0]
421+
416422
return value
417423

418424

@@ -454,57 +460,70 @@ def _get_embed_body(
454460
return received_body, field_alias_omitted
455461

456462

457-
def _normalize_multi_query_string_with_param(
458-
query_string: dict[str, list[str]],
463+
def _normalize_multi_params(
464+
input_dict: MutableMapping[str, Any],
459465
params: Sequence[ModelField],
460-
) -> dict[str, Any]:
466+
) -> MutableMapping[str, Any]:
461467
"""
462-
Extract and normalize resolved_query_string_parameters
468+
Extract and normalize query string or header parameters with Pydantic model support.
463469
464470
Parameters
465471
----------
466-
query_string: dict
467-
A dictionary containing the initial query string parameters.
472+
input_dict: MutableMapping[str, Any]
473+
A dictionary containing the initial query string or header parameters.
468474
params: Sequence[ModelField]
469475
A sequence of ModelField objects representing parameters.
470476
471477
Returns
472478
-------
473-
A dictionary containing the processed multi_query_string_parameters.
479+
MutableMapping[str, Any]
480+
A dictionary containing the processed parameters with normalized values.
474481
"""
475-
resolved_query_string: dict[str, Any] = query_string
476-
for param in filter(is_scalar_field, params):
477-
try:
478-
# if the target parameter is a scalar, we keep the first value of the query string
479-
# regardless if there are more in the payload
480-
resolved_query_string[param.alias] = query_string[param.alias][0]
481-
except KeyError:
482-
pass
483-
return resolved_query_string
482+
for param in params:
483+
if is_scalar_field(param):
484+
_process_scalar_param(input_dict, param)
485+
elif lenient_issubclass(param.field_info.annotation, BaseModel):
486+
_process_model_param(input_dict, param)
487+
return input_dict
484488

485489

486-
def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]):
487-
"""
488-
Extract and normalize resolved_headers_field
490+
def _process_scalar_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None:
491+
"""Process a scalar parameter by normalizing single-item lists."""
492+
try:
493+
value = input_dict[param.alias]
494+
if isinstance(value, list) and len(value) == 1:
495+
input_dict[param.alias] = value[0]
496+
except KeyError:
497+
pass
489498

490-
Parameters
491-
----------
492-
headers: MutableMapping[str, Any]
493-
A dictionary containing the initial header parameters.
494-
params: Sequence[ModelField]
495-
A sequence of ModelField objects representing parameters.
496499

497-
Returns
498-
-------
499-
A dictionary containing the processed headers.
500-
"""
501-
if headers:
502-
for param in filter(is_scalar_field, params):
503-
try:
504-
if len(headers[param.alias]) == 1:
505-
# if the target parameter is a scalar and the list contains only 1 element
506-
# we keep the first value of the headers regardless if there are more in the payload
507-
headers[param.alias] = headers[param.alias][0]
508-
except KeyError:
509-
pass
510-
return headers
500+
def _process_model_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None:
501+
"""Process a Pydantic model parameter by extracting model fields."""
502+
model_class = cast(type[BaseModel], param.field_info.annotation)
503+
504+
model_data = {}
505+
for field_name, field_info in model_class.model_fields.items():
506+
field_alias = field_info.alias or field_name
507+
value = _get_param_value(input_dict, field_alias, field_name, model_class)
508+
509+
if value is not None:
510+
model_data[field_alias] = _normalize_field_value(value=value, field_info=field_info)
511+
512+
input_dict[param.alias] = model_data
513+
514+
515+
def _get_param_value(
516+
input_dict: MutableMapping[str, Any],
517+
field_alias: str,
518+
field_name: str,
519+
model_class: type[BaseModel],
520+
) -> Any:
521+
"""Get parameter value, checking both alias and field name if needed."""
522+
value = input_dict.get(field_alias)
523+
if value is not None:
524+
return value
525+
526+
if model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name"):
527+
value = input_dict.get(field_name)
528+
529+
return value

aws_lambda_powertools/event_handler/openapi/dependant.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,13 @@
99
create_body_model,
1010
evaluate_forwardref,
1111
is_scalar_field,
12-
is_scalar_sequence_field,
1312
)
1413
from aws_lambda_powertools.event_handler.openapi.params import (
1514
Body,
1615
Dependant,
1716
Form,
18-
Header,
1917
Param,
2018
ParamTypes,
21-
Query,
2219
_File,
2320
analyze_param,
2421
create_response_field,
@@ -275,7 +272,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
275272
return False
276273
elif is_scalar_field(field=param_field):
277274
return False
278-
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
275+
elif isinstance(param_field.field_info, Param):
279276
return False
280277
else:
281278
if not isinstance(param_field.field_info, Body):

0 commit comments

Comments
 (0)