Skip to content

Commit 9acfee4

Browse files
feat(event_handler): Ensure Bedrock Agents resolver works with Pydantic v2 (#5156)
Make sure Bedrock Agent works with Pydantic v2
1 parent dbfb0db commit 9acfee4

File tree

2 files changed

+127
-1
lines changed

2 files changed

+127
-1
lines changed

aws_lambda_powertools/event_handler/bedrock_agent.py

+107
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import json
34
from typing import TYPE_CHECKING, Any, Callable
45

56
from typing_extensions import override
@@ -10,10 +11,12 @@
1011
ProxyEventType,
1112
ResponseBuilder,
1213
)
14+
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION
1315

1416
if TYPE_CHECKING:
1517
from re import Match
1618

19+
from aws_lambda_powertools.event_handler.openapi.models import Contact, License, SecurityScheme, Server, Tag
1720
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse
1821
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
1922

@@ -273,3 +276,107 @@ def _convert_matches_into_route_keys(self, match: Match) -> dict[str, str]:
273276
if match.groupdict() and self.current_event.parameters:
274277
parameters = {parameter["name"]: parameter["value"] for parameter in self.current_event.parameters}
275278
return parameters
279+
280+
@override
281+
def get_openapi_json_schema(
282+
self,
283+
*,
284+
title: str = "Powertools API",
285+
version: str = DEFAULT_API_VERSION,
286+
openapi_version: str = DEFAULT_OPENAPI_VERSION,
287+
summary: str | None = None,
288+
description: str | None = None,
289+
tags: list[Tag | str] | None = None,
290+
servers: list[Server] | None = None,
291+
terms_of_service: str | None = None,
292+
contact: Contact | None = None,
293+
license_info: License | None = None,
294+
security_schemes: dict[str, SecurityScheme] | None = None,
295+
security: list[dict[str, list[str]]] | None = None,
296+
) -> str:
297+
"""
298+
Returns the OpenAPI schema as a JSON serializable dict.
299+
Since Bedrock Agents only support OpenAPI 3.0.0, we convert OpenAPI 3.1.0 schemas
300+
and enforce 3.0.0 compatibility for seamless integration.
301+
302+
Parameters
303+
----------
304+
title: str
305+
The title of the application.
306+
version: str
307+
The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API
308+
openapi_version: str, default = "3.0.0"
309+
The version of the OpenAPI Specification (which the document uses).
310+
summary: str, optional
311+
A short summary of what the application does.
312+
description: str, optional
313+
A verbose explanation of the application behavior.
314+
tags: list[Tag, str], optional
315+
A list of tags used by the specification with additional metadata.
316+
servers: list[Server], optional
317+
An array of Server Objects, which provide connectivity information to a target server.
318+
terms_of_service: str, optional
319+
A URL to the Terms of Service for the API. MUST be in the format of a URL.
320+
contact: Contact, optional
321+
The contact information for the exposed API.
322+
license_info: License, optional
323+
The license information for the exposed API.
324+
security_schemes: dict[str, SecurityScheme]], optional
325+
A declaration of the security schemes available to be used in the specification.
326+
security: list[dict[str, list[str]]], optional
327+
A declaration of which security mechanisms are applied globally across the API.
328+
329+
Returns
330+
-------
331+
str
332+
The OpenAPI schema as a JSON serializable dict.
333+
"""
334+
from aws_lambda_powertools.event_handler.openapi.compat import model_json
335+
336+
schema = super().get_openapi_schema(
337+
title=title,
338+
version=version,
339+
openapi_version=openapi_version,
340+
summary=summary,
341+
description=description,
342+
tags=tags,
343+
servers=servers,
344+
terms_of_service=terms_of_service,
345+
contact=contact,
346+
license_info=license_info,
347+
security_schemes=security_schemes,
348+
security=security,
349+
)
350+
schema.openapi = "3.0.3"
351+
352+
# Transform OpenAPI 3.1 into 3.0
353+
def inner(yaml_dict):
354+
if isinstance(yaml_dict, dict):
355+
if "anyOf" in yaml_dict and isinstance((anyOf := yaml_dict["anyOf"]), list):
356+
for i, item in enumerate(anyOf):
357+
if isinstance(item, dict) and item.get("type") == "null":
358+
anyOf.pop(i)
359+
yaml_dict["nullable"] = True
360+
if "examples" in yaml_dict:
361+
examples = yaml_dict["examples"]
362+
del yaml_dict["examples"]
363+
if isinstance(examples, list) and len(examples):
364+
yaml_dict["example"] = examples[0]
365+
for value in yaml_dict.values():
366+
inner(value)
367+
elif isinstance(yaml_dict, list):
368+
for item in yaml_dict:
369+
inner(item)
370+
371+
model = json.loads(
372+
model_json(
373+
schema,
374+
by_alias=True,
375+
exclude_none=True,
376+
indent=2,
377+
),
378+
)
379+
380+
inner(model)
381+
382+
return json.dumps(model)

tests/functional/event_handler/test_bedrock_agent.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
2-
from typing import Any, Dict
2+
from typing import Any, Dict, Optional
33

4+
import pytest
45
from typing_extensions import Annotated
56

67
from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types
@@ -181,3 +182,21 @@ def send_reminders(
181182
# THEN return the correct result
182183
body = result["response"]["responseBody"]["application/json"]["body"]
183184
assert json.loads(body) is True
185+
186+
187+
@pytest.mark.usefixtures("pydanticv2_only")
188+
def test_openapi_schema_for_pydanticv2(openapi30_schema):
189+
# GIVEN BedrockAgentResolver is initialized with enable_validation=True
190+
app = BedrockAgentResolver(enable_validation=True)
191+
192+
# WHEN we have a simple handler
193+
@app.get("/", description="Testing")
194+
def handler() -> Optional[Dict]:
195+
pass
196+
197+
# WHEN we get the schema
198+
schema = json.loads(app.get_openapi_json_schema())
199+
200+
# THEN the schema must be a valid 3.0.3 version
201+
assert openapi30_schema(schema)
202+
assert schema.get("openapi") == "3.0.3"

0 commit comments

Comments
 (0)