Skip to content

Commit

Permalink
feat(event_handler): support richer Tags on OpenAPI object
Browse files Browse the repository at this point in the history
  • Loading branch information
rubenfonseca committed Dec 20, 2023
1 parent cedb5c9 commit 26ca413
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 39 deletions.
17 changes: 9 additions & 8 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
License,
OpenAPI,
Server,
Tag,
)
from aws_lambda_powertools.event_handler.openapi.params import Dependant
from aws_lambda_powertools.event_handler.openapi.types import (
Expand Down Expand Up @@ -1360,7 +1361,7 @@ def get_openapi_schema(
openapi_version: str = DEFAULT_OPENAPI_VERSION,
summary: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
tags: Optional[List[Union["Tag", str]]] = None,
servers: Optional[List["Server"]] = None,
terms_of_service: Optional[str] = None,
contact: Optional["Contact"] = None,
Expand All @@ -1381,7 +1382,7 @@ def get_openapi_schema(
A short summary of what the application does.
description: str, optional
A verbose explanation of the application behavior.
tags: List[str], optional
tags: List[Tag | str], optional
A list of tags used by the specification with additional metadata.
servers: List[Server], optional
An array of Server Objects, which provide connectivity information to a target server.
Expand All @@ -1403,7 +1404,7 @@ def get_openapi_schema(
get_compat_model_name_map,
get_definitions,
)
from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Server
from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Server, Tag
from aws_lambda_powertools.event_handler.openapi.types import (
COMPONENT_REF_TEMPLATE,
)
Expand Down Expand Up @@ -1468,7 +1469,7 @@ def get_openapi_schema(
if components:
output["components"] = components
if tags:
output["tags"] = [{"name": tag} for tag in tags]
output["tags"] = [Tag(name=tag) if isinstance(tag, str) else tag for tag in tags]

output["paths"] = {k: PathItem(**v) for k, v in paths.items()}

Expand All @@ -1482,7 +1483,7 @@ def get_openapi_json_schema(
openapi_version: str = DEFAULT_OPENAPI_VERSION,
summary: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
tags: Optional[List[Union["Tag", str]]] = None,
servers: Optional[List["Server"]] = None,
terms_of_service: Optional[str] = None,
contact: Optional["Contact"] = None,
Expand All @@ -1503,7 +1504,7 @@ def get_openapi_json_schema(
A short summary of what the application does.
description: str, optional
A verbose explanation of the application behavior.
tags: List[str], optional
tags: List[Tag, str], optional
A list of tags used by the specification with additional metadata.
servers: List[Server], optional
An array of Server Objects, which provide connectivity information to a target server.
Expand Down Expand Up @@ -1548,7 +1549,7 @@ def enable_swagger(
openapi_version: str = DEFAULT_OPENAPI_VERSION,
summary: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
tags: Optional[List[Union["Tag", str]]] = None,
servers: Optional[List["Server"]] = None,
terms_of_service: Optional[str] = None,
contact: Optional["Contact"] = None,
Expand All @@ -1573,7 +1574,7 @@ def enable_swagger(
A short summary of what the application does.
description: str, optional
A verbose explanation of the application behavior.
tags: List[str], optional
tags: List[Tag, str], optional
A list of tags used by the specification with additional metadata.
servers: List[Server], optional
An array of Server Objects, which provide connectivity information to a target server.
Expand Down
31 changes: 0 additions & 31 deletions tests/functional/event_handler/test_openapi_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,37 +349,6 @@ def handler(user: Annotated[User, Body(embed=True)]):
assert body_post_handler_schema.properties["user"].ref == "#/components/schemas/User"


def test_openapi_with_tags():
app = APIGatewayRestResolver()

@app.get("/users")
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema(tags=["Orders"])
assert len(schema.tags) == 1

tag = schema.tags[0]
assert tag.name == "Orders"


def test_openapi_operation_with_tags():
app = APIGatewayRestResolver()

@app.get("/users", tags=["Users"])
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1

get = schema.paths["/users"].get
assert len(get.tags) == 1

tag = get.tags[0]
assert tag == "Users"


def test_openapi_with_excluded_operations():
app = APIGatewayRestResolver()

Expand Down
53 changes: 53 additions & 0 deletions tests/functional/event_handler/test_openapi_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.openapi.models import Tag


def test_openapi_with_tags():
app = APIGatewayRestResolver()

@app.get("/users")
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema(tags=["Orders"])
assert schema.tags is not None
assert len(schema.tags) == 1

tag = schema.tags[0]
assert tag.name == "Orders"


def test_openapi_with_object_tags():
app = APIGatewayRestResolver()

@app.get("/users")
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema(tags=[Tag(name="Orders", description="Order description tag")])
assert schema.tags is not None
assert len(schema.tags) == 1

tag = schema.tags[0]
assert tag.name == "Orders"
assert tag.description == "Order description tag"


def test_openapi_operation_with_tags():
app = APIGatewayRestResolver()

@app.get("/users", tags=["Users"])
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema()
assert schema.paths is not None
assert len(schema.paths.keys()) == 1

get = schema.paths["/users"].get
assert get is not None
assert get.tags is not None
assert len(get.tags) == 1

tag = get.tags[0]
assert tag == "Users"

0 comments on commit 26ca413

Please sign in to comment.