|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import json |
3 | 4 | from typing import TYPE_CHECKING, Any, Callable
|
4 | 5 |
|
5 | 6 | from typing_extensions import override
|
|
10 | 11 | ProxyEventType,
|
11 | 12 | ResponseBuilder,
|
12 | 13 | )
|
| 14 | +from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION |
13 | 15 |
|
14 | 16 | if TYPE_CHECKING:
|
15 | 17 | from re import Match
|
16 | 18 |
|
| 19 | + from aws_lambda_powertools.event_handler.openapi.models import Contact, License, SecurityScheme, Server, Tag |
17 | 20 | from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse
|
18 | 21 | from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
|
19 | 22 |
|
@@ -273,3 +276,107 @@ def _convert_matches_into_route_keys(self, match: Match) -> dict[str, str]:
|
273 | 276 | if match.groupdict() and self.current_event.parameters:
|
274 | 277 | parameters = {parameter["name"]: parameter["value"] for parameter in self.current_event.parameters}
|
275 | 278 | return parameters
|
| 279 | + |
| 280 | + @override |
| 281 | + def get_openapi_json_schema( |
| 282 | + self, |
| 283 | + *, |
| 284 | + title: str = "Powertools API", |
| 285 | + version: str = DEFAULT_API_VERSION, |
| 286 | + openapi_version: str = DEFAULT_OPENAPI_VERSION, |
| 287 | + summary: str | None = None, |
| 288 | + description: str | None = None, |
| 289 | + tags: list[Tag | str] | None = None, |
| 290 | + servers: list[Server] | None = None, |
| 291 | + terms_of_service: str | None = None, |
| 292 | + contact: Contact | None = None, |
| 293 | + license_info: License | None = None, |
| 294 | + security_schemes: dict[str, SecurityScheme] | None = None, |
| 295 | + security: list[dict[str, list[str]]] | None = None, |
| 296 | + ) -> str: |
| 297 | + """ |
| 298 | + Returns the OpenAPI schema as a JSON serializable dict. |
| 299 | + Since Bedrock Agents only support OpenAPI 3.0.0, we convert OpenAPI 3.1.0 schemas |
| 300 | + and enforce 3.0.0 compatibility for seamless integration. |
| 301 | +
|
| 302 | + Parameters |
| 303 | + ---------- |
| 304 | + title: str |
| 305 | + The title of the application. |
| 306 | + version: str |
| 307 | + The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API |
| 308 | + openapi_version: str, default = "3.0.0" |
| 309 | + The version of the OpenAPI Specification (which the document uses). |
| 310 | + summary: str, optional |
| 311 | + A short summary of what the application does. |
| 312 | + description: str, optional |
| 313 | + A verbose explanation of the application behavior. |
| 314 | + tags: list[Tag, str], optional |
| 315 | + A list of tags used by the specification with additional metadata. |
| 316 | + servers: list[Server], optional |
| 317 | + An array of Server Objects, which provide connectivity information to a target server. |
| 318 | + terms_of_service: str, optional |
| 319 | + A URL to the Terms of Service for the API. MUST be in the format of a URL. |
| 320 | + contact: Contact, optional |
| 321 | + The contact information for the exposed API. |
| 322 | + license_info: License, optional |
| 323 | + The license information for the exposed API. |
| 324 | + security_schemes: dict[str, SecurityScheme]], optional |
| 325 | + A declaration of the security schemes available to be used in the specification. |
| 326 | + security: list[dict[str, list[str]]], optional |
| 327 | + A declaration of which security mechanisms are applied globally across the API. |
| 328 | +
|
| 329 | + Returns |
| 330 | + ------- |
| 331 | + str |
| 332 | + The OpenAPI schema as a JSON serializable dict. |
| 333 | + """ |
| 334 | + from aws_lambda_powertools.event_handler.openapi.compat import model_json |
| 335 | + |
| 336 | + schema = super().get_openapi_schema( |
| 337 | + title=title, |
| 338 | + version=version, |
| 339 | + openapi_version=openapi_version, |
| 340 | + summary=summary, |
| 341 | + description=description, |
| 342 | + tags=tags, |
| 343 | + servers=servers, |
| 344 | + terms_of_service=terms_of_service, |
| 345 | + contact=contact, |
| 346 | + license_info=license_info, |
| 347 | + security_schemes=security_schemes, |
| 348 | + security=security, |
| 349 | + ) |
| 350 | + schema.openapi = "3.0.3" |
| 351 | + |
| 352 | + # Transform OpenAPI 3.1 into 3.0 |
| 353 | + def inner(yaml_dict): |
| 354 | + if isinstance(yaml_dict, dict): |
| 355 | + if "anyOf" in yaml_dict and isinstance((anyOf := yaml_dict["anyOf"]), list): |
| 356 | + for i, item in enumerate(anyOf): |
| 357 | + if isinstance(item, dict) and item.get("type") == "null": |
| 358 | + anyOf.pop(i) |
| 359 | + yaml_dict["nullable"] = True |
| 360 | + if "examples" in yaml_dict: |
| 361 | + examples = yaml_dict["examples"] |
| 362 | + del yaml_dict["examples"] |
| 363 | + if isinstance(examples, list) and len(examples): |
| 364 | + yaml_dict["example"] = examples[0] |
| 365 | + for value in yaml_dict.values(): |
| 366 | + inner(value) |
| 367 | + elif isinstance(yaml_dict, list): |
| 368 | + for item in yaml_dict: |
| 369 | + inner(item) |
| 370 | + |
| 371 | + model = json.loads( |
| 372 | + model_json( |
| 373 | + schema, |
| 374 | + by_alias=True, |
| 375 | + exclude_none=True, |
| 376 | + indent=2, |
| 377 | + ), |
| 378 | + ) |
| 379 | + |
| 380 | + inner(model) |
| 381 | + |
| 382 | + return json.dumps(model) |
0 commit comments