33import dataclasses
44import json
55import logging
6- from typing import TYPE_CHECKING , Any , Callable , Mapping , MutableMapping , Sequence
6+ from typing import TYPE_CHECKING , Any , Callable , Mapping , MutableMapping , Sequence , cast
77from urllib .parse import parse_qs
88
99from pydantic import BaseModel
1313 _model_dump ,
1414 _normalize_errors ,
1515 _regenerate_error_with_loc ,
16+ field_annotation_is_sequence ,
1617 get_missing_field_error ,
17- is_sequence_field ,
18+ lenient_issubclass ,
1819)
1920from aws_lambda_powertools .event_handler .openapi .dependant import is_scalar_field
2021from aws_lambda_powertools .event_handler .openapi .encoders import jsonable_encoder
2122from aws_lambda_powertools .event_handler .openapi .exceptions import RequestValidationError , ResponseValidationError
2223from aws_lambda_powertools .event_handler .openapi .params import Param
2324
2425if TYPE_CHECKING :
26+ from pydantic .fields import FieldInfo
27+
2528 from aws_lambda_powertools .event_handler import Response
2629 from aws_lambda_powertools .event_handler .api_gateway import Route
2730 from aws_lambda_powertools .event_handler .middlewares import NextMiddleware
@@ -64,7 +67,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
6467 )
6568
6669 # Normalize query values before validate this
67- query_string = _normalize_multi_query_string_with_param (
70+ query_string = _normalize_multi_params (
6871 app .current_event .resolved_query_string_parameters ,
6972 route .dependant .query_params ,
7073 )
@@ -76,7 +79,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
7679 )
7780
7881 # Normalize header values before validate this
79- headers = _normalize_multi_header_values_with_param (
82+ headers = _normalize_multi_params (
8083 app .current_event .resolved_headers_field ,
8184 route .dependant .header_params ,
8285 )
@@ -366,7 +369,7 @@ def _request_body_to_args(
366369 _handle_missing_field_value (field , values , errors , loc )
367370 continue
368371
369- value = _normalize_field_value (field , value )
372+ value = _normalize_field_value (value = value , field_info = field . field_info )
370373 values [field .name ] = _validate_field (field = field , value = value , loc = loc , existing_errors = errors )
371374
372375 return values , errors
@@ -409,10 +412,13 @@ def _handle_missing_field_value(
409412 values [field .name ] = field .get_default ()
410413
411414
412- def _normalize_field_value (field : ModelField , value : Any ) -> Any :
415+ def _normalize_field_value (value : Any , field_info : FieldInfo ) -> Any :
413416 """Normalize field value, converting lists to single values for non-sequence fields."""
414- if isinstance (value , list ) and not is_sequence_field (field ):
417+ if field_annotation_is_sequence (field_info .annotation ):
418+ return value
419+ elif isinstance (value , list ) and value :
415420 return value [0 ]
421+
416422 return value
417423
418424
@@ -454,57 +460,70 @@ def _get_embed_body(
454460 return received_body , field_alias_omitted
455461
456462
457- def _normalize_multi_query_string_with_param (
458- query_string : dict [str , list [ str ] ],
463+ def _normalize_multi_params (
464+ input_dict : MutableMapping [str , Any ],
459465 params : Sequence [ModelField ],
460- ) -> dict [str , Any ]:
466+ ) -> MutableMapping [str , Any ]:
461467 """
462- Extract and normalize resolved_query_string_parameters
468+ Extract and normalize query string or header parameters with Pydantic model support.
463469
464470 Parameters
465471 ----------
466- query_string: dict
467- A dictionary containing the initial query string parameters.
472+ input_dict: MutableMapping[str, Any]
473+ A dictionary containing the initial query string or header parameters.
468474 params: Sequence[ModelField]
469475 A sequence of ModelField objects representing parameters.
470476
471477 Returns
472478 -------
473- A dictionary containing the processed multi_query_string_parameters.
479+ MutableMapping[str, Any]
480+ A dictionary containing the processed parameters with normalized values.
474481 """
475- resolved_query_string : dict [str , Any ] = query_string
476- for param in filter (is_scalar_field , params ):
477- try :
478- # if the target parameter is a scalar, we keep the first value of the query string
479- # regardless if there are more in the payload
480- resolved_query_string [param .alias ] = query_string [param .alias ][0 ]
481- except KeyError :
482- pass
483- return resolved_query_string
482+ for param in params :
483+ if is_scalar_field (param ):
484+ _process_scalar_param (input_dict , param )
485+ elif lenient_issubclass (param .field_info .annotation , BaseModel ):
486+ _process_model_param (input_dict , param )
487+ return input_dict
484488
485489
486- def _normalize_multi_header_values_with_param (headers : MutableMapping [str , Any ], params : Sequence [ModelField ]):
487- """
488- Extract and normalize resolved_headers_field
490+ def _process_scalar_param (input_dict : MutableMapping [str , Any ], param : ModelField ) -> None :
491+ """Process a scalar parameter by normalizing single-item lists."""
492+ try :
493+ value = input_dict [param .alias ]
494+ if isinstance (value , list ) and len (value ) == 1 :
495+ input_dict [param .alias ] = value [0 ]
496+ except KeyError :
497+ pass
489498
490- Parameters
491- ----------
492- headers: MutableMapping[str, Any]
493- A dictionary containing the initial header parameters.
494- params: Sequence[ModelField]
495- A sequence of ModelField objects representing parameters.
496499
497- Returns
498- -------
499- A dictionary containing the processed headers.
500- """
501- if headers :
502- for param in filter (is_scalar_field , params ):
503- try :
504- if len (headers [param .alias ]) == 1 :
505- # if the target parameter is a scalar and the list contains only 1 element
506- # we keep the first value of the headers regardless if there are more in the payload
507- headers [param .alias ] = headers [param .alias ][0 ]
508- except KeyError :
509- pass
510- return headers
500+ def _process_model_param (input_dict : MutableMapping [str , Any ], param : ModelField ) -> None :
501+ """Process a Pydantic model parameter by extracting model fields."""
502+ model_class = cast (type [BaseModel ], param .field_info .annotation )
503+
504+ model_data = {}
505+ for field_name , field_info in model_class .model_fields .items ():
506+ field_alias = field_info .alias or field_name
507+ value = _get_param_value (input_dict , field_alias , field_name , model_class )
508+
509+ if value is not None :
510+ model_data [field_alias ] = _normalize_field_value (value = value , field_info = field_info )
511+
512+ input_dict [param .alias ] = model_data
513+
514+
515+ def _get_param_value (
516+ input_dict : MutableMapping [str , Any ],
517+ field_alias : str ,
518+ field_name : str ,
519+ model_class : type [BaseModel ],
520+ ) -> Any :
521+ """Get parameter value, checking both alias and field name if needed."""
522+ value = input_dict .get (field_alias )
523+ if value is not None :
524+ return value
525+
526+ if model_class .model_config .get ("validate_by_name" ) or model_class .model_config .get ("populate_by_name" ):
527+ value = input_dict .get (field_name )
528+
529+ return value
0 commit comments