Skip to content

Commit

Permalink
enable function annotation based type detection
Browse files Browse the repository at this point in the history
  • Loading branch information
yoursvivek committed Jan 7, 2021
1 parent 5af4730 commit aef28f6
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 52 deletions.
1 change: 1 addition & 0 deletions spectree/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self, **kwargs):
self._SUPPORT_UI = {"redoc", "swagger"}
self.MODE = "normal"
self._SUPPORT_MODE = {"normal", "strict", "greedy"}
self.ANNOTATIONS = False

self.TITLE = "Service API Document"
self.VERSION = "0.1"
Expand Down
4 changes: 4 additions & 0 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ def validate(
req_validation_error, resp_validation_error = None, None
try:
self.request_validation(_req, query, json, headers, cookies)
if self.config.ANNOTATIONS:
for name in ("query", "json", "headers", "cookies"):
if func.__annotations__.get(name):
kwargs[name] = getattr(_req.context, name)

except ValidationError as err:
req_validation_error = err
Expand Down
4 changes: 4 additions & 0 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ def validate(
response, req_validation_error, resp_validation_error = None, None, None
try:
self.request_validation(request, query, json, headers, cookies)
if self.config.ANNOTATIONS:
for name in ("query", "json", "headers", "cookies"):
if func.__annotations__.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
req_validation_error = err
response = make_response(jsonify(err.errors()), 422)
Expand Down
4 changes: 4 additions & 0 deletions spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ async def validate(

try:
await self.request_validation(request, query, json, headers, cookies)
if self.config.ANNOTATIONS:
for name in ("query", "json", "headers", "cookies"):
if func.__annotations__.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
req_validation_error = err
response = JSONResponse(err.errors(), 422)
Expand Down
10 changes: 10 additions & 0 deletions spectree/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ async def async_validate(*args, **kwargs):

validation = async_validate if self.backend.ASYNC else sync_validate

if self.config.ANNOTATIONS:
nonlocal query
query = func.__annotations__.get("query", query)
nonlocal json
json = func.__annotations__.get("json", json)
nonlocal headers
headers = func.__annotations__.get("headers", headers)
nonlocal cookies
cookies = func.__annotations__.get("cookies", cookies)

# register
for name, model in zip(
("query", "json", "headers", "cookies"), (query, json, headers, cookies)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def test_plugin_spec(api):

assert api.spec["tags"] == [{"name": tag} for tag in ("test", "health", "api")]

assert get_paths(api.spec) == ["/api/user/{name}", "/ping"]
assert get_paths(api.spec) == [
"/api/user/{name}", "/api/user_annotated/{name}", "/ping"
]

ping = api.spec["paths"]["/ping"]["get"]
assert ping["tags"] == ["test", "health"]
Expand Down
36 changes: 35 additions & 1 deletion tests/test_plugin_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def after_handler(req, resp, err, instance):
print(resp.get_header("X-Name"))


api = SpecTree("falcon", before=before_handler, after=after_handler)
api = SpecTree(
"falcon", before=before_handler, after=after_handler, annotations=True
)


class Ping:
Expand Down Expand Up @@ -60,9 +62,41 @@ def on_post(self, req, resp, name):
resp.media = {"name": req.context.json.name, "score": score}


class UserScoreAnnotated:
name = "sorted random score"

def extra_method(self):
pass

@api.validate(resp=Response(HTTP_200=StrDict))
def on_get(self, req, resp, name):
self.extra_method()
resp.media = {"name": name}

@api.validate(
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=["api", "test"],
)
def on_post(
self,
req,
resp,
name,
query: Query,
json: JSON,
cookies: Cookies
):
score = [randint(0, req.context.json.limit) for _ in range(5)]
score.sort(reverse=req.context.query.order)
assert req.context.cookies.pub == "abcdefg"
assert req.cookies["pub"] == "abcdefg"
resp.media = {"name": req.context.json.name, "score": score}


app = falcon.API()
app.add_route("/ping", Ping())
app.add_route("/api/user/{name}", UserScore())
app.add_route("/api/user_annotated/{name}", UserScoreAnnotated())
api.register(app)


Expand Down
72 changes: 47 additions & 25 deletions tests/test_plugin_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def api_after_handler(req, resp, err, _):
resp.headers["X-API"] = "OK"


api = SpecTree("flask", before=before_handler, after=after_handler)
api = SpecTree(
"flask", before=before_handler, after=after_handler, annotations=True
)
app = Flask(__name__)


Expand Down Expand Up @@ -51,6 +53,25 @@ def user_score(name):
return jsonify(name=request.context.json.name, score=score)


@app.route("/api/user_annotated/<name>", methods=["POST"])
@api.validate(
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=["api", "test"],
after=api_after_handler,
)
def user_score_annotated(
name,
query: Query,
json: JSON,
cookies: Cookies
):
score = [randint(0, json.limit) for _ in range(5)]
score.sort(reverse=query.order)
assert cookies.pub == "abcdefg"
assert request.cookies["pub"] == "abcdefg"
return jsonify(name=json.name, score=score)


api.register(app)


Expand All @@ -75,30 +96,31 @@ def test_flask_validate(client):
assert resp.headers.get("X-Error") == "Validation Error"

client.set_cookie("flask", "pub", "abcdefg")
resp = client.post(
"/api/user/flask?order=1",
data=json.dumps(dict(name="flask", limit=10)),
content_type="application/json",
)
assert resp.status_code == 200, resp.json
assert resp.headers.get("X-Validation") is None
assert resp.headers.get("X-API") == "OK"
assert resp.json["name"] == "flask"
assert resp.json["score"] == sorted(resp.json["score"], reverse=True)

resp = client.post(
"/api/user/flask?order=0",
data=json.dumps(dict(name="flask", limit=10)),
content_type="application/json",
)
assert resp.json["score"] == sorted(resp.json["score"], reverse=False)

resp = client.post(
"/api/user/flask?order=0",
data="name=flask&limit=10",
content_type="application/x-www-form-urlencoded",
)
assert resp.json["score"] == sorted(resp.json["score"], reverse=False)
for fragment in ("user", "user_annotated"):
resp = client.post(
f"/api/{fragment}/flask?order=1",
data=json.dumps(dict(name="flask", limit=10)),
content_type="application/json",
)
assert resp.status_code == 200, resp.json
assert resp.headers.get("X-Validation") is None
assert resp.headers.get("X-API") == "OK"
assert resp.json["name"] == "flask"
assert resp.json["score"] == sorted(resp.json["score"], reverse=True)

resp = client.post(
f"/api/{fragment}/flask?order=0",
data=json.dumps(dict(name="flask", limit=10)),
content_type="application/json",
)
assert resp.json["score"] == sorted(resp.json["score"], reverse=False)

resp = client.post(
f"/api/{fragment}/flask?order=0",
data="name=flask&limit=10",
content_type="application/x-www-form-urlencoded",
)
assert resp.json["score"] == sorted(resp.json["score"], reverse=False)


def test_flask_doc(client):
Expand Down
75 changes: 50 additions & 25 deletions tests/test_plugin_starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def method_handler(req, resp, err, instance):
resp.headers["X-Name"] = instance.name


api = SpecTree("starlette", before=before_handler, after=after_handler)
api = SpecTree(
"starlette", before=before_handler, after=after_handler, annotations=True
)


class Ping(HTTPEndpoint):
Expand Down Expand Up @@ -59,6 +61,22 @@ async def user_score(request):
return JSONResponse({"name": request.context.json.name, "score": score})


@api.validate(
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=["api", "test"],
)
async def user_score_annotated(
request,
query: Query,
json: JSON,
cookies: Cookies
):
score = [randint(0, json.limit) for _ in range(5)]
score.sort(reverse=query.order)
assert cookies.pub == "abcdefg"
assert request.cookies["pub"] == "abcdefg"
return JSONResponse({"name": json.name, "score": score})

app = Starlette(
routes=[
Route("/ping", Ping),
Expand All @@ -70,7 +88,13 @@ async def user_score(request):
routes=[
Route("/{name}", user_score, methods=["POST"]),
],
)
),
Mount(
"/user_annotated",
routes=[
Route("/{name}", user_score_annotated, methods=["POST"]),
],
),
],
),
Mount("/static", app=StaticFiles(directory="docs"), name="static"),
Expand All @@ -96,29 +120,30 @@ def test_starlette_validate(client):
assert resp.headers.get("X-Name") == "Ping"
assert resp.headers.get("X-Validation") is None

resp = client.post("/api/user/starlette")
assert resp.status_code == 422
assert resp.headers.get("X-Error") == "Validation Error"

resp = client.post(
"/api/user/starlette?order=1",
json=dict(name="starlette", limit=10),
cookies=dict(pub="abcdefg"),
)
resp_body = resp.json()
assert resp_body["name"] == "starlette"
assert resp_body["score"] == sorted(resp_body["score"], reverse=True)
assert resp.headers.get("X-Validation") == "Pass"

resp = client.post(
"/api/user/starlette?order=0",
json=dict(name="starlette", limit=10),
cookies=dict(pub="abcdefg"),
)
resp_body = resp.json()
assert resp_body["name"] == "starlette"
assert resp_body["score"] == sorted(resp_body["score"], reverse=False)
assert resp.headers.get("X-Validation") == "Pass"
for fragment in ("user", "user_annotated"):
resp = client.post(f"/api/{fragment}/starlette")
assert resp.status_code == 422
assert resp.headers.get("X-Error") == "Validation Error"

resp = client.post(
f"/api/{fragment}/starlette?order=1",
json=dict(name="starlette", limit=10),
cookies=dict(pub="abcdefg"),
)
resp_body = resp.json()
assert resp_body["name"] == "starlette"
assert resp_body["score"] == sorted(resp_body["score"], reverse=True)
assert resp.headers.get("X-Validation") == "Pass"

resp = client.post(
f"/api/{fragment}/starlette?order=0",
json=dict(name="starlette", limit=10),
cookies=dict(pub="abcdefg"),
)
resp_body = resp.json()
assert resp_body["name"] == "starlette"
assert resp_body["score"] == sorted(resp_body["score"], reverse=False)
assert resp.headers.get("X-Validation") == "Pass"


def test_starlette_doc(client):
Expand Down

0 comments on commit aef28f6

Please sign in to comment.