Skip to content

Commit

Permalink
add form processing into starlette, falcon
Browse files Browse the repository at this point in the history
  • Loading branch information
yedpodtrzitko committed Jun 26, 2022
1 parent 92e095d commit b6e257d
Show file tree
Hide file tree
Showing 17 changed files with 230 additions and 165 deletions.
Empty file added examples/__init__.py
Empty file.
20 changes: 20 additions & 0 deletions examples/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from pydantic import BaseModel, Field

from spectree import BaseFile


class File(BaseModel):
uid: str = None
file: BaseFile


class FileResp(BaseModel):
filename: str
type: str


class Query(BaseModel):
text: str = Field(
...,
max_length=100,
)
23 changes: 3 additions & 20 deletions examples/falcon_asgi_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import uvicorn
from pydantic import BaseModel, Field

from spectree import BaseFile, Response, SpecTree, Tag
from examples.common import File, FileResp, Query
from spectree import Response, SpecTree, Tag

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()


api = SpecTree(
"falcon-asgi",
title="Demo Service",
Expand All @@ -21,13 +21,6 @@
demo = Tag(name="demo", description="😊", externalDocs={"url": "https://github.com"})


class Query(BaseModel):
text: str = Field(
...,
max_length=100,
)


class Resp(BaseModel):
label: int = Field(
...,
Expand All @@ -53,16 +46,6 @@ class Data(BaseModel):
vip: bool


class File(BaseModel):
uid: str = None
file: BaseFile


class FileResp(BaseModel):
filename: str
type: str


class Ping:
def check(self):
pass
Expand Down Expand Up @@ -130,7 +113,7 @@ async def on_post(self, req, resp):
app = falcon.asgi.App()
app.add_route("/ping", Ping())
app.add_route("/api/{source}/{target}", Classification())
app.add_route("/api/upload-file", FileUpload())
app.add_route("/api/file_upload", FileUpload())
api.register(app)

uvicorn.run(app, log_level="info")
55 changes: 3 additions & 52 deletions examples/falcon_demo.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import json
import logging
from random import random
from wsgiref import simple_server

import falcon # type: ignore
from pydantic import BaseModel, Field

from spectree import BaseFile, Response, SpecTree, Tag
from examples.common import File, FileResp, Query
from spectree import Response, SpecTree, Tag

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()


api = SpecTree(
"falcon",
title="Demo Service",
Expand All @@ -25,13 +24,6 @@
demo = Tag(name="demo", description="😊", externalDocs={"url": "https://github.com"})


class Query(BaseModel):
text: str = Field(
...,
max_length=100,
)


class Resp(BaseModel):
label: int = Field(
...,
Expand All @@ -57,16 +49,6 @@ class Data(BaseModel):
vip: bool


class File(BaseModel):
uid: str = None
file: BaseFile


class FileResp(BaseModel):
filename: str
type: str


class Ping:
def check(self):
pass
Expand Down Expand Up @@ -130,37 +112,6 @@ def on_post(self, req, resp):
resp.media = {"filename": file.filename, "type": file.type}


class JSONFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
lr = logging.LogRecord(None, None, "", 0, "", (), None, None)
self.default_keys = [key for key in lr.__dict__]

def extra_data(self, record):
return {
key: getattr(record, key)
for key in record.__dict__
if key not in self.default_keys
}

def format(self, record):
log_data = {
"severity": record.levelname,
"path_name": record.pathname,
"function_name": record.funcName,
"message": record.msg,
**self.extra_data(record),
}
return json.dumps(log_data)


logger = logging.getLogger()
handler = logging.StreamHandler()
handler.setFormatter(JSONFormatter())
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)


if __name__ == "__main__":
"""
cmd:
Expand All @@ -170,7 +121,7 @@ def format(self, record):
app = falcon.API()
app.add_route("/ping", Ping())
app.add_route("/api/{source}/{target}", Classification())
app.add_route("/api/upload-file", FileUpload())
app.add_route("/api/file_upload", FileUpload())
api.register(app)

httpd = simple_server.make_server("localhost", 8000, app)
Expand Down
22 changes: 4 additions & 18 deletions examples/flask_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,13 @@
from flask.views import MethodView
from pydantic import BaseModel, Field

from spectree import BaseFile, Response, SpecTree
from examples.common import File, FileResp, Query
from spectree import Response, SpecTree

app = Flask(__name__)
api = SpecTree("flask")


class Query(BaseModel):
text: str = "default query strings"


class Resp(BaseModel):
label: int
score: float = Field(
Expand All @@ -39,16 +36,6 @@ class Config:
}


class File(BaseModel):
uid: str = None
file: BaseFile


class FileResp(BaseModel):
filename: str
type: str


class Language(str, Enum):
en = "en-US"
zh = "zh-CN"
Expand Down Expand Up @@ -99,7 +86,7 @@ def with_code_header():
return jsonify(language=request.context.headers.Lang), 203, {"X": 233}


@app.route("/api/upload-file", methods=["POST"])
@app.route("/api/file_upload", methods=["POST"])
@api.validate(form=File, resp=Response(HTTP_200=FileResp), tags=["file-upload"])
def with_file():
"""
Expand All @@ -108,14 +95,13 @@ def with_file():
demo for 'form'
"""
file = request.context.form.file
return jsonify(filename=file.filename, type=file.content_type)
return {"filename": file.filename, "type": file.content_type}


class UserAPI(MethodView):
@api.validate(json=Data, resp=Response(HTTP_200=Resp), tags=["test"])
def post(self):
return jsonify(label=int(10 * random()), score=random())
# return Resp(label=int(10 * random()), score=random())


if __name__ == "__main__":
Expand Down
19 changes: 2 additions & 17 deletions examples/starlette_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route

from spectree import BaseFile, Response, SpecTree

# from spectree.plugins.starlette_plugin import PydanticResponse
from examples.common import File, FileResp, Query
from spectree import Response, SpecTree

api = SpecTree("starlette")


class Query(BaseModel):
text: str


class Resp(BaseModel):
label: int = Field(
...,
Expand All @@ -35,16 +30,6 @@ class Data(BaseModel):
vip: bool


class File(BaseModel):
uid: str = None
file: BaseFile


class FileResp(BaseModel):
filename: str
type: str


@api.validate(query=Query, json=Data, resp=Response(HTTP_200=Resp), tags=["api"])
async def predict(request):
"""
Expand Down
48 changes: 14 additions & 34 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import inspect
import re
from functools import partial
from io import BytesIO
from typing import Any, Callable, Dict, List, Mapping, Optional

from pydantic import ValidationError
Expand Down Expand Up @@ -176,22 +175,6 @@ def parse_path(self, route, path_parameter_descriptions):

return f'/{"/".join(subs)}', parameters

def parse_field(self, field):
if isinstance(field, list):
return [self.parse_field(subfield) for subfield in field]

# When file name isn't ascii FieldStorage will not consider it.
encoded = field.disposition_options.get("filename*")
if encoded:
encoding, filename = encoded.split("''")
field.filename = filename
field.file = BytesIO(field.file.read().encode(encoding))
if getattr(field, "filename", False):
return field

# This is not a file, thus get flat value (not FieldStorage instance).
return field.value

def request_validation(self, req, query, json, form, headers, cookies):
if query:
req.context.query = query.parse_obj(req.params)
Expand All @@ -208,11 +191,7 @@ def request_validation(self, req, query, json, form, headers, cookies):
media = None
req.context.json = json.parse_obj(media)
if form:
_form, files = multipart.parse_form_data(req.env)
req_form = {}
for key in _form:
req_form[key] = self.parse_field(_form[key])

req_form = {x.name: x for x in req.get_media()}
req.context.form = form.parse_obj(req_form)

def validate(
Expand Down Expand Up @@ -296,18 +275,19 @@ async def request_validation(self, req, query, json, form, headers, cookies):
raise
media = None
req.context.json = json.parse_obj(media)
elif form:
# For WSGI compatibility
# https://asgi.readthedocs.io/en/latest/specs/www.html#wsgi-compatibility
req.scope.setdefault("REQUEST_METHOD", req.method)

_form, files = multipart.parse_form_data(req.env)

req_form = {}
for key in _form:
req_form[key] = self.parse_field(_form[key])

req.context.form = form.parse_obj(req_form)
if form:
try:
form_data = await req.get_media()
except self.FALCON_HTTP_ERROR as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
req.context.form = None
else:
res_data = {}
async for x in form_data:
res_data[x.name] = x
await x.data # TODO - how to avoid this?
req.context.form = form.parse_obj(res_data)

async def validate(
self,
Expand Down
2 changes: 1 addition & 1 deletion spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class FlaskPlugin(BasePlugin):
blueprint_state = None

def find_routes(self):
from flask import current_app # type: ignore
from flask import current_app

for rule in current_app.url_map.iter_rules():
if any(
Expand Down
13 changes: 8 additions & 5 deletions spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,15 @@ def register_route(self, app):

async def request_validation(self, request, query, json, form, headers, cookies):
has_data = request.method not in ("GET", "DELETE")
use_json = json and has_data and request.mimetype == "application/json"
use_form = form and has_data and request.mimetype in self.FORM_MIMETYPE
content_type = request.headers.get("content-type", "").lower()
use_json = json and has_data and content_type == "application/json"
use_form = (
form and has_data and any([x in content_type for x in self.FORM_MIMETYPE])
)
request.context = Context(
query.parse_obj(request.query_params) if query else None,
json.parse_raw(await request.body() or "{}") if use_json else None,
form.parse_obj(await request.form() or "{}") if use_form else None,
json.parse_obj(await request.json() or {}) if use_json else None,
form.parse_obj((await request.form()) or {}) if use_form else None,
headers.parse_obj(request.headers) if headers else None,
cookies.parse_obj(request.cookies) if cookies else None,
)
Expand Down Expand Up @@ -97,7 +100,7 @@ async def validate(
try:
await self.request_validation(request, query, json, form, headers, cookies)
if self.config.annotations:
for name in ("query", "json", "headers", "cookies"):
for name in ("query", "json", "form", "headers", "cookies"):
if func.__annotations__.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
Expand Down
2 changes: 1 addition & 1 deletion tests/flask_imports/dry_plugin_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_flask_upload_file(client):
file_content = "abcdef"
data = {"file": (io.BytesIO(file_content.encode("utf-8")), "test.txt")}
resp = client.post(
"/api/file_upload/",
"/api/file_upload",
data=data,
content_type="multipart/form-data",
)
Expand Down
Loading

0 comments on commit b6e257d

Please sign in to comment.