Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable view function annotation based type detection #100

Merged
merged 2 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ Check the [examples](/examples) folder.

If the request doesn't pass the validation, it will return a 422 with JSON error message(ctx, loc, msg, type).

### Opt-in type annotation feature
This library also supports injection of validated fields into view function arguments along with parameter annotation based type declaration. This works well with linters that can take advantage of typing features like mypy. See examples section below.

## How To

> How to add summary and description to endpoints?
Expand All @@ -65,7 +68,7 @@ Check the [pydantic](https://pydantic-docs.helpmanual.io/usage/schema/) document

Of course. Check the [config](https://spectree.readthedocs.io/en/latest/config.html) document.

You can update the config when init the spectree like:
You can update the config when init the spectree like:

```py
SpecTree('flask', title='Demo API', version='v1.0', path='doc')
Expand Down Expand Up @@ -153,6 +156,25 @@ if __name__ == "__main__":

```

#### Flask example with type annotation

```python
# opt in into annotations feature
api = SpecTree("flask", annotations=True)


@app.route('/api/user', methods=['POST'])
@api.validate(resp=Response(HTTP_200=Message, HTTP_403=None), tags=['api'])
def user_profile(json: Profile):
"""
verify user profile (summary of this endpoint)

user's name, user's age, ... (long description)
"""
print(json) # or `request.json`
return jsonify(text='it works')
```

### Falcon

```py
Expand Down Expand Up @@ -201,6 +223,25 @@ if __name__ == "__main__":

```

#### Falcon with type annotations

```python
# opt in into annotations feature
api = SpecTree("flask", annotations=True)


class UserProfile:
@api.validate(resp=Response(HTTP_200=Message, HTTP_403=None), tags=['api'])
def on_post(self, req, resp, json: Profile):
"""
verify user profile (summary of this endpoint)

user's name, user's age, ... (long description)
"""
print(req.context.json) # or `req.media`
resp.media = {'text': 'it works'}
```

### Starlette

```py
Expand Down Expand Up @@ -252,6 +293,25 @@ if __name__ == "__main__":

```

#### Starlette example with type annotations

```python
# opt in into annotations feature
api = SpecTree("flask", annotations=True)


@api.validate(resp=Response(HTTP_200=Message, HTTP_403=None), tags=['api'])
async def user_profile(request, json=Profile):
"""
verify user profile (summary of this endpoint)

user's name, user's age, ... (long description)
"""
print(request.context.json) # or await request.json()
return JSONResponse({'text': 'it works'})
```


## FAQ

> ValidationError: missing field for headers
Expand Down
1 change: 1 addition & 0 deletions spectree/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self, **kwargs):
self._SUPPORT_UI = {"redoc", "swagger"}
self.MODE = "normal"
self._SUPPORT_MODE = {"normal", "strict", "greedy"}
self.ANNOTATIONS = False

self.TITLE = "Service API Document"
self.VERSION = "0.1"
Expand Down
4 changes: 4 additions & 0 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ def validate(
req_validation_error, resp_validation_error = None, None
try:
self.request_validation(_req, query, json, headers, cookies)
if self.config.ANNOTATIONS:
for name in ("query", "json", "headers", "cookies"):
if func.__annotations__.get(name):
kwargs[name] = getattr(_req.context, name)

except ValidationError as err:
req_validation_error = err
Expand Down
4 changes: 4 additions & 0 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def validate(
response, req_validation_error, resp_validation_error = None, None, None
try:
self.request_validation(request, query, json, headers, cookies)
if self.config.ANNOTATIONS:
for name in ("query", "json", "headers", "cookies"):
if func.__annotations__.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
req_validation_error = err
response = make_response(jsonify(err.errors()), 422)
Expand Down
4 changes: 4 additions & 0 deletions spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ async def validate(

try:
await self.request_validation(request, query, json, headers, cookies)
if self.config.ANNOTATIONS:
for name in ("query", "json", "headers", "cookies"):
if func.__annotations__.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
req_validation_error = err
response = JSONResponse(err.errors(), 422)
Expand Down
10 changes: 10 additions & 0 deletions spectree/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ async def async_validate(*args, **kwargs):

validation = async_validate if self.backend.ASYNC else sync_validate

if self.config.ANNOTATIONS:
nonlocal query
query = func.__annotations__.get("query", query)
nonlocal json
json = func.__annotations__.get("json", json)
nonlocal headers
headers = func.__annotations__.get("headers", headers)
nonlocal cookies
cookies = func.__annotations__.get("cookies", cookies)

# register
for name, model in zip(
("query", "json", "headers", "cookies"), (query, json, headers, cookies)
Expand Down
6 changes: 5 additions & 1 deletion tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ def test_plugin_spec(api):

assert api.spec["tags"] == [{"name": tag} for tag in ("test", "health", "api")]

assert get_paths(api.spec) == ["/api/user/{name}", "/ping"]
assert get_paths(api.spec) == [
"/api/user/{name}",
"/api/user_annotated/{name}",
"/ping",
]

ping = api.spec["paths"]["/ping"]["get"]
assert ping["tags"] == ["test", "health"]
Expand Down
26 changes: 25 additions & 1 deletion tests/test_plugin_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def after_handler(req, resp, err, instance):
print(resp.get_header("X-Name"))


api = SpecTree("falcon", before=before_handler, after=after_handler)
api = SpecTree("falcon", before=before_handler, after=after_handler, annotations=True)


class Ping:
Expand Down Expand Up @@ -60,9 +60,33 @@ def on_post(self, req, resp, name):
resp.media = {"name": req.context.json.name, "score": score}


class UserScoreAnnotated:
name = "sorted random score"

def extra_method(self):
pass

@api.validate(resp=Response(HTTP_200=StrDict))
def on_get(self, req, resp, name):
self.extra_method()
resp.media = {"name": name}

@api.validate(
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=["api", "test"],
)
def on_post(self, req, resp, name, query: Query, json: JSON, cookies: Cookies):
score = [randint(0, req.context.json.limit) for _ in range(5)]
score.sort(reverse=req.context.query.order)
assert req.context.cookies.pub == "abcdefg"
assert req.cookies["pub"] == "abcdefg"
resp.media = {"name": req.context.json.name, "score": score}


app = falcon.API()
app.add_route("/ping", Ping())
app.add_route("/api/user/{name}", UserScore())
app.add_route("/api/user_annotated/{name}", UserScoreAnnotated())
api.register(app)


Expand Down
65 changes: 40 additions & 25 deletions tests/test_plugin_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def api_after_handler(req, resp, err, _):
resp.headers["X-API"] = "OK"


api = SpecTree("flask", before=before_handler, after=after_handler)
api = SpecTree("flask", before=before_handler, after=after_handler, annotations=True)
app = Flask(__name__)
app.config["TESTING"] = True

Expand Down Expand Up @@ -52,6 +52,20 @@ def user_score(name):
return jsonify(name=request.context.json.name, score=score)


@app.route("/api/user_annotated/<name>", methods=["POST"])
@api.validate(
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=["api", "test"],
after=api_after_handler,
)
def user_score_annotated(name, query: Query, json: JSON, cookies: Cookies):
score = [randint(0, json.limit) for _ in range(5)]
score.sort(reverse=query.order)
assert cookies.pub == "abcdefg"
assert request.cookies["pub"] == "abcdefg"
return jsonify(name=json.name, score=score)


# INFO: ensures that spec is calculated and cached _after_ registering
# view functions for validations. This enables tests to access `api.spec`
# without app_context.
Expand Down Expand Up @@ -83,30 +97,31 @@ def test_flask_validate(client):
assert resp.headers.get("X-Error") == "Validation Error"

client.set_cookie("flask", "pub", "abcdefg")
resp = client.post(
"/api/user/flask?order=1",
data=json.dumps(dict(name="flask", limit=10)),
content_type="application/json",
)
assert resp.status_code == 200, resp.json
assert resp.headers.get("X-Validation") is None
assert resp.headers.get("X-API") == "OK"
assert resp.json["name"] == "flask"
assert resp.json["score"] == sorted(resp.json["score"], reverse=True)

resp = client.post(
"/api/user/flask?order=0",
data=json.dumps(dict(name="flask", limit=10)),
content_type="application/json",
)
assert resp.json["score"] == sorted(resp.json["score"], reverse=False)

resp = client.post(
"/api/user/flask?order=0",
data="name=flask&limit=10",
content_type="application/x-www-form-urlencoded",
)
assert resp.json["score"] == sorted(resp.json["score"], reverse=False)
for fragment in ("user", "user_annotated"):
resp = client.post(
f"/api/{fragment}/flask?order=1",
data=json.dumps(dict(name="flask", limit=10)),
content_type="application/json",
)
assert resp.status_code == 200, resp.json
assert resp.headers.get("X-Validation") is None
assert resp.headers.get("X-API") == "OK"
assert resp.json["name"] == "flask"
assert resp.json["score"] == sorted(resp.json["score"], reverse=True)

resp = client.post(
f"/api/{fragment}/flask?order=0",
data=json.dumps(dict(name="flask", limit=10)),
content_type="application/json",
)
assert resp.json["score"] == sorted(resp.json["score"], reverse=False)

resp = client.post(
f"/api/{fragment}/flask?order=0",
data="name=flask&limit=10",
content_type="application/x-www-form-urlencoded",
)
assert resp.json["score"] == sorted(resp.json["score"], reverse=False)


def test_flask_doc(client):
Expand Down
71 changes: 46 additions & 25 deletions tests/test_plugin_starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def method_handler(req, resp, err, instance):
resp.headers["X-Name"] = instance.name


api = SpecTree("starlette", before=before_handler, after=after_handler)
api = SpecTree(
"starlette", before=before_handler, after=after_handler, annotations=True
)


class Ping(HTTPEndpoint):
Expand Down Expand Up @@ -59,6 +61,18 @@ async def user_score(request):
return JSONResponse({"name": request.context.json.name, "score": score})


@api.validate(
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=["api", "test"],
)
async def user_score_annotated(request, query: Query, json: JSON, cookies: Cookies):
score = [randint(0, json.limit) for _ in range(5)]
score.sort(reverse=query.order)
assert cookies.pub == "abcdefg"
assert request.cookies["pub"] == "abcdefg"
return JSONResponse({"name": json.name, "score": score})


app = Starlette(
routes=[
Route("/ping", Ping),
Expand All @@ -70,7 +84,13 @@ async def user_score(request):
routes=[
Route("/{name}", user_score, methods=["POST"]),
],
)
),
Mount(
"/user_annotated",
routes=[
Route("/{name}", user_score_annotated, methods=["POST"]),
],
),
],
),
Mount("/static", app=StaticFiles(directory="docs"), name="static"),
Expand All @@ -96,29 +116,30 @@ def test_starlette_validate(client):
assert resp.headers.get("X-Name") == "Ping"
assert resp.headers.get("X-Validation") is None

resp = client.post("/api/user/starlette")
assert resp.status_code == 422
assert resp.headers.get("X-Error") == "Validation Error"

resp = client.post(
"/api/user/starlette?order=1",
json=dict(name="starlette", limit=10),
cookies=dict(pub="abcdefg"),
)
resp_body = resp.json()
assert resp_body["name"] == "starlette"
assert resp_body["score"] == sorted(resp_body["score"], reverse=True)
assert resp.headers.get("X-Validation") == "Pass"

resp = client.post(
"/api/user/starlette?order=0",
json=dict(name="starlette", limit=10),
cookies=dict(pub="abcdefg"),
)
resp_body = resp.json()
assert resp_body["name"] == "starlette"
assert resp_body["score"] == sorted(resp_body["score"], reverse=False)
assert resp.headers.get("X-Validation") == "Pass"
for fragment in ("user", "user_annotated"):
resp = client.post(f"/api/{fragment}/starlette")
assert resp.status_code == 422
assert resp.headers.get("X-Error") == "Validation Error"

resp = client.post(
f"/api/{fragment}/starlette?order=1",
json=dict(name="starlette", limit=10),
cookies=dict(pub="abcdefg"),
)
resp_body = resp.json()
assert resp_body["name"] == "starlette"
assert resp_body["score"] == sorted(resp_body["score"], reverse=True)
assert resp.headers.get("X-Validation") == "Pass"

resp = client.post(
f"/api/{fragment}/starlette?order=0",
json=dict(name="starlette", limit=10),
cookies=dict(pub="abcdefg"),
)
resp_body = resp.json()
assert resp_body["name"] == "starlette"
assert resp_body["score"] == sorted(resp_body["score"], reverse=False)
assert resp.headers.get("X-Validation") == "Pass"


def test_starlette_doc(client):
Expand Down