Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions aws_lambda_powertools/event_handler/openapi/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,42 @@ def model_rebuild(model: type[BaseModel]) -> None:
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
# Create a shallow copy of the field_info to preserve its type and all attributes
new_field = copy(field_info)
# Update only the annotation to the new one
new_field.annotation = annotation

# Recursively extract all metadata from nested Annotated types
def extract_metadata(ann: Any) -> tuple[Any, list[Any]]:
"""Extract base type and all non-FieldInfo metadata from potentially nested Annotated types."""
if get_origin(ann) is not Annotated:
return ann, []

args = get_args(ann)
base_type = args[0]
metadata = list(args[1:])

# If base type is also Annotated, recursively extract its metadata
if get_origin(base_type) is Annotated:
inner_base, inner_metadata = extract_metadata(base_type)
# Combine metadata from both levels, filtering out FieldInfo instances
from pydantic.fields import FieldInfo as PydanticFieldInfo

all_metadata = [m for m in inner_metadata + metadata if not isinstance(m, PydanticFieldInfo)]
return inner_base, all_metadata
else:
# Filter out FieldInfo instances from metadata
from pydantic.fields import FieldInfo as PydanticFieldInfo

constraint_metadata = [m for m in metadata if not isinstance(m, PydanticFieldInfo)]
return base_type, constraint_metadata

# Extract base type and constraints
base_type, constraints = extract_metadata(annotation)

# Set the annotation with base type and all constraint metadata
# Use tuple unpacking for Python 3.9+ compatibility
if constraints:
new_field.annotation = Annotated[(base_type, *constraints)]
else:
new_field.annotation = base_type

return new_field


Expand Down
11 changes: 11 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,10 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
type_annotation = annotated_args[0]
powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)]

# Preserve non-FieldInfo metadata (like annotated_types constraints)
# This is important for constraints like Interval, Gt, Lt, etc.
other_metadata = [arg for arg in annotated_args[1:] if not isinstance(arg, FieldInfo)]

# Determine which annotation to use
powertools_annotation: FieldInfo | None = None
has_discriminator_with_param = False
Expand All @@ -1124,6 +1128,13 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
else:
powertools_annotation = next(iter(powertools_annotations), None)

# Reconstruct type_annotation with non-FieldInfo metadata if present
# This ensures constraints like Interval are preserved
if other_metadata and not has_discriminator_with_param:
from typing_extensions import Annotated

type_annotation = Annotated[(type_annotation, *other_metadata)]

# Process the annotation if it exists
field_info: FieldInfo | None = None
if isinstance(powertools_annotation, FieldInfo): # pragma: no cover
Expand Down
184 changes: 184 additions & 0 deletions tests/functional/event_handler/_pydantic/test_openapi_params.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional, Tuple

import pytest
from pydantic import BaseModel, Field
from typing_extensions import Annotated

Expand Down Expand Up @@ -1044,3 +1046,185 @@ def complex_handler(params: Annotated[QueryParams, Query()]):
assert type_mapping["int_field"] == "integer"
assert type_mapping["float_field"] == "number"
assert type_mapping["bool_field"] == "boolean"


@pytest.mark.parametrize(
"body_value,expected_value",
[
("50", 50), # Valid: within range
("0", 0), # Valid: at lower bound
("100", 100), # Valid: at upper bound
],
)
def test_annotated_types_interval_constraints_in_body_params(body_value, expected_value):
"""
Test for issue #7600: Validate that annotated_types.Interval constraints
are properly enforced in Body parameters with valid values.
"""
from annotated_types import Interval

# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# AND a constrained type using annotated_types.Interval
ConstrainedInt = Annotated[int, Interval(ge=0, le=100)]

@app.post("/items")
def create_item(value: Annotated[ConstrainedInt, Body()]):
return {"value": value}

# WHEN sending a request with a valid value
event = {
"resource": "/items",
"path": "/items",
"httpMethod": "POST",
"body": body_value,
"isBase64Encoded": False,
}

# THEN the request should succeed
result = app(event, {})
assert result["statusCode"] == 200
body = json.loads(result["body"])
assert body["value"] == expected_value


@pytest.mark.parametrize(
"body_value",
[
"-1", # Invalid: below range
"101", # Invalid: above range
],
)
def test_annotated_types_interval_constraints_in_body_params_invalid(body_value):
"""
Test for issue #7600: Validate that annotated_types.Interval constraints
reject invalid values in Body parameters.
"""
from annotated_types import Interval

# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# AND a constrained type using annotated_types.Interval
constrained_int = Annotated[int, Interval(ge=0, le=100)]

@app.post("/items")
def create_item(value: Annotated[constrained_int, Body()]):
return {"value": value}

# WHEN sending a request with an invalid value
event = {
"resource": "/items",
"path": "/items",
"httpMethod": "POST",
"body": body_value,
"isBase64Encoded": False,
}

# THEN validation should fail
result = app(event, {})
assert result["statusCode"] == 422


@pytest.mark.parametrize(
"query_value,expected_value",
[
("50", 50), # Valid: within range
("0", 0), # Valid: at lower bound
("100", 100), # Valid: at upper bound
],
)
def test_annotated_types_interval_constraints_in_query_params(query_value, expected_value):
"""
Test for issue #7600: Validate that annotated_types.Interval constraints
are properly enforced in Query parameters with valid values.
"""
from annotated_types import Interval

# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# AND a constrained type using annotated_types.Interval
constrained_int = Annotated[int, Interval(ge=0, le=100)]

@app.get("/items")
def list_items(limit: Annotated[constrained_int, Query()]):
return {"limit": limit}

# WHEN sending a request with a valid value
event = {
"resource": "/items",
"path": "/items",
"httpMethod": "GET",
"queryStringParameters": {"limit": query_value},
"isBase64Encoded": False,
}

# THEN the request should succeed
result = app(event, {})
assert result["statusCode"] == 200
body = json.loads(result["body"])
assert body["limit"] == expected_value


@pytest.mark.parametrize(
"query_value",
[
"-1", # Invalid: below range
"101", # Invalid: above range
],
)
def test_annotated_types_interval_constraints_in_query_params_invalid(query_value):
"""
Test for issue #7600: Validate that annotated_types.Interval constraints
reject invalid values in Query parameters.
"""
from annotated_types import Interval

# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# AND a constrained type using annotated_types.Interval
constrained_int = Annotated[int, Interval(ge=0, le=100)]

@app.get("/items")
def list_items(limit: Annotated[constrained_int, Query()]):
return {"limit": limit}

# WHEN sending a request with an invalid value
event = {
"resource": "/items",
"path": "/items",
"httpMethod": "GET",
"queryStringParameters": {"limit": query_value},
"isBase64Encoded": False,
}

# THEN validation should fail
result = app(event, {})
assert result["statusCode"] == 422


def test_annotated_types_interval_in_openapi_schema():
"""
Test that annotated_types.Interval constraints are reflected in the OpenAPI schema.
"""
from annotated_types import Interval

app = APIGatewayRestResolver()
constrained_int = Annotated[int, Interval(ge=0, le=100)]

@app.get("/items")
def list_items(limit: Annotated[constrained_int, Query()] = 10):
return {"limit": limit}

schema = app.get_openapi_schema()

# Verify the Query parameter schema includes constraints
get_operation = schema.paths["/items"].get
limit_param = next(p for p in get_operation.parameters if p.name == "limit")

assert limit_param.schema_.type == "integer"
assert limit_param.schema_.default == 10
assert limit_param.required is False