Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(event_handler): support richer top level Tags #3543

Merged
merged 1 commit into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
License,
OpenAPI,
Server,
Tag,
)
from aws_lambda_powertools.event_handler.openapi.params import Dependant
from aws_lambda_powertools.event_handler.openapi.types import (
Expand Down Expand Up @@ -1360,7 +1361,7 @@ def get_openapi_schema(
openapi_version: str = DEFAULT_OPENAPI_VERSION,
summary: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
tags: Optional[List[Union["Tag", str]]] = None,
servers: Optional[List["Server"]] = None,
terms_of_service: Optional[str] = None,
contact: Optional["Contact"] = None,
Expand All @@ -1381,7 +1382,7 @@ def get_openapi_schema(
A short summary of what the application does.
description: str, optional
A verbose explanation of the application behavior.
tags: List[str], optional
tags: List[Tag | str], optional
A list of tags used by the specification with additional metadata.
servers: List[Server], optional
An array of Server Objects, which provide connectivity information to a target server.
Expand All @@ -1403,7 +1404,7 @@ def get_openapi_schema(
get_compat_model_name_map,
get_definitions,
)
from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Server
from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Server, Tag
from aws_lambda_powertools.event_handler.openapi.types import (
COMPONENT_REF_TEMPLATE,
)
Expand Down Expand Up @@ -1468,7 +1469,7 @@ def get_openapi_schema(
if components:
output["components"] = components
if tags:
output["tags"] = [{"name": tag} for tag in tags]
output["tags"] = [Tag(name=tag) if isinstance(tag, str) else tag for tag in tags]

output["paths"] = {k: PathItem(**v) for k, v in paths.items()}

Expand All @@ -1482,7 +1483,7 @@ def get_openapi_json_schema(
openapi_version: str = DEFAULT_OPENAPI_VERSION,
summary: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
tags: Optional[List[Union["Tag", str]]] = None,
servers: Optional[List["Server"]] = None,
terms_of_service: Optional[str] = None,
contact: Optional["Contact"] = None,
Expand All @@ -1503,7 +1504,7 @@ def get_openapi_json_schema(
A short summary of what the application does.
description: str, optional
A verbose explanation of the application behavior.
tags: List[str], optional
tags: List[Tag, str], optional
A list of tags used by the specification with additional metadata.
servers: List[Server], optional
An array of Server Objects, which provide connectivity information to a target server.
Expand Down Expand Up @@ -1548,7 +1549,7 @@ def enable_swagger(
openapi_version: str = DEFAULT_OPENAPI_VERSION,
summary: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
tags: Optional[List[Union["Tag", str]]] = None,
servers: Optional[List["Server"]] = None,
terms_of_service: Optional[str] = None,
contact: Optional["Contact"] = None,
Expand All @@ -1573,7 +1574,7 @@ def enable_swagger(
A short summary of what the application does.
description: str, optional
A verbose explanation of the application behavior.
tags: List[str], optional
tags: List[Tag, str], optional
A list of tags used by the specification with additional metadata.
servers: List[Server], optional
An array of Server Objects, which provide connectivity information to a target server.
Expand Down
31 changes: 0 additions & 31 deletions tests/functional/event_handler/test_openapi_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,37 +349,6 @@ def handler(user: Annotated[User, Body(embed=True)]):
assert body_post_handler_schema.properties["user"].ref == "#/components/schemas/User"


def test_openapi_with_tags():
app = APIGatewayRestResolver()

@app.get("/users")
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema(tags=["Orders"])
assert len(schema.tags) == 1

tag = schema.tags[0]
assert tag.name == "Orders"


def test_openapi_operation_with_tags():
app = APIGatewayRestResolver()

@app.get("/users", tags=["Users"])
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1

get = schema.paths["/users"].get
assert len(get.tags) == 1

tag = get.tags[0]
assert tag == "Users"


def test_openapi_with_excluded_operations():
app = APIGatewayRestResolver()

Expand Down
53 changes: 53 additions & 0 deletions tests/functional/event_handler/test_openapi_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.openapi.models import Tag


def test_openapi_with_tags():
app = APIGatewayRestResolver()

@app.get("/users")
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema(tags=["Orders"])
assert schema.tags is not None
assert len(schema.tags) == 1

tag = schema.tags[0]
assert tag.name == "Orders"


def test_openapi_with_object_tags():
app = APIGatewayRestResolver()

@app.get("/users")
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema(tags=[Tag(name="Orders", description="Order description tag")])
assert schema.tags is not None
assert len(schema.tags) == 1

tag = schema.tags[0]
assert tag.name == "Orders"
assert tag.description == "Order description tag"


def test_openapi_operation_with_tags():
app = APIGatewayRestResolver()

@app.get("/users", tags=["Users"])
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema()
assert schema.paths is not None
assert len(schema.paths.keys()) == 1

get = schema.paths["/users"].get
assert get is not None
assert get.tags is not None
assert len(get.tags) == 1

tag = get.tags[0]
assert tag == "Users"