diff --git a/connexion/__init__.py b/connexion/__init__.py index 6cd8d8cf9..6174c4bbc 100755 --- a/connexion/__init__.py +++ b/connexion/__init__.py @@ -1,12 +1,13 @@ -from flask import (abort, request, send_file, send_from_directory, # NOQA - render_template, render_template_string, url_for) import werkzeug.exceptions as exceptions # NOQA -from .app import App # NOQA -from .api import Api # NOQA +from .apps import AbstractApp, FlaskApp # NOQA +from .apis import AbstractAPI, FlaskApi # NOQA from .exceptions import ProblemException # NOQA from .problem import problem # NOQA from .decorators.produces import NoContent # NOQA from .resolver import Resolution, Resolver, RestyResolver # NOQA +App = FlaskApp +Api = FlaskApi + # This version is replaced during release process. __version__ = '2016.0.dev1' diff --git a/connexion/apis/__init__.py b/connexion/apis/__init__.py new file mode 100644 index 000000000..defe04391 --- /dev/null +++ b/connexion/apis/__init__.py @@ -0,0 +1,4 @@ +from .abstract import AbstractAPI +from .flask_api import FlaskApi + +__all__ = ['AbstractAPI', 'FlaskApi'] diff --git a/connexion/api.py b/connexion/apis/abstract.py similarity index 78% rename from connexion/api.py rename to connexion/apis/abstract.py index b051a1ab1..f84d7a930 100644 --- a/connexion/api.py +++ b/connexion/apis/abstract.py @@ -1,29 +1,33 @@ +import abc import copy import logging import pathlib import sys -import flask -import jinja2 import six -import werkzeug.exceptions import yaml from swagger_spec_validator.validator20 import validate_spec -from . import utils -from .exceptions import ResolverError -from .handlers import AuthErrorHandler -from .operation import Operation -from .resolver import Resolver +import jinja2 +from ..exceptions import ResolverError +from ..operation import Operation +from ..resolver import Resolver -MODULE_PATH = pathlib.Path(__file__).absolute().parent +MODULE_PATH = pathlib.Path(__file__).absolute().parent.parent SWAGGER_UI_PATH = MODULE_PATH / 'vendor' / 'swagger-ui' SWAGGER_UI_URL = 'ui' RESOLVER_ERROR_ENDPOINT_RANDOM_DIGITS = 6 -logger = logging.getLogger('connexion.api') +logger = logging.getLogger('connexion.apis') + + +def canonical_base_url(base_path): + """ + Make given "basePath" a canonical base URL which can be prepended to paths starting with "/". + """ + return base_path.rstrip('/') def compatibility_layer(spec): @@ -47,19 +51,13 @@ def compatibility_layer(spec): return spec -def canonical_base_url(base_path): +@six.add_metaclass(abc.ABCMeta) +class AbstractAPI(object): """ - Make given "basePath" a canonical base URL which can be prepended to paths starting with "/". - """ - return base_path.rstrip('/') - - -class Api(object): - """ - Single API that corresponds to a flask blueprint + Defines an abstract interface for a Swagger API """ - def __init__(self, specification, base_url=None, arguments=None, + def __init__(self, specification, jsonifier, base_url=None, arguments=None, swagger_json=None, swagger_ui=None, swagger_path=None, swagger_url=None, validate_responses=False, strict_validation=False, resolver=None, auth_all_paths=False, debug=False, resolver_error_handler=None, @@ -111,13 +109,12 @@ def __init__(self, specification, base_url=None, arguments=None, spec = copy.deepcopy(self.specification) validate_spec(spec) + self.swagger_path = swagger_path or SWAGGER_UI_PATH + self.swagger_url = swagger_url or SWAGGER_UI_URL + # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#fixed-fields # If base_url is not on provided then we try to read it from the swagger.yaml or use / by default - if base_url is None: - self.base_url = canonical_base_url(self.specification.get('basePath', '')) - else: - self.base_url = canonical_base_url(base_url) - self.specification['basePath'] = base_url + self._set_base_url(base_url) # A list of MIME types the APIs can produce. This is global to all APIs but can be overridden on specific # API calls. @@ -135,9 +132,6 @@ def __init__(self, specification, base_url=None, arguments=None, self.parameter_definitions = self.specification.get('parameters', {}) self.response_definitions = self.specification.get('responses', {}) - self.swagger_path = swagger_path or SWAGGER_UI_PATH - self.swagger_url = swagger_url or SWAGGER_UI_URL - self.resolver = resolver or Resolver() logger.debug('Validate Responses: %s', str(validate_responses)) @@ -149,8 +143,7 @@ def __init__(self, specification, base_url=None, arguments=None, logger.debug('Pythonic params: %s', str(pythonic_params)) self.pythonic_params = pythonic_params - # Create blueprint and endpoints - self.blueprint = self.create_blueprint() + self.jsonifier = jsonifier if swagger_json: self.add_swagger_json() @@ -160,7 +153,32 @@ def __init__(self, specification, base_url=None, arguments=None, self.add_paths() if auth_all_paths: - self.add_auth_on_not_found() + self.add_auth_on_not_found(self.security, self.security_definitions) + + def _set_base_url(self, base_url): + if base_url is None: + self.base_url = canonical_base_url(self.specification.get('basePath', '')) + else: + self.base_url = canonical_base_url(base_url) + self.specification['basePath'] = base_url + + @abc.abstractmethod + def add_swagger_json(self): + """ + Adds swagger json to {base_url}/swagger.json + """ + + @abc.abstractmethod + def add_swagger_ui(self): + """ + Adds swagger ui to {base_url}/ui/ + """ + + @abc.abstractmethod + def add_auth_on_not_found(self, security, security_definitions): + """ + Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass. + """ def add_operation(self, method, path, swagger_operation, path_parameters): """ @@ -179,7 +197,8 @@ def add_operation(self, method, path, swagger_operation, path_parameters): :type path: str :type swagger_operation: dict """ - operation = Operation(method=method, + operation = Operation(self, + method=method, path=path, path_parameters=path_parameters, operation=swagger_operation, @@ -197,6 +216,13 @@ def add_operation(self, method, path, swagger_operation, path_parameters): pythonic_params=self.pythonic_params) self._add_operation_internal(method, path, operation) + @abc.abstractmethod + def _add_operation_internal(self, method, path, operation): + """ + Adds the operation according to the user framework in use. + It will be used to register the operation on the user framework router. + """ + def _add_resolver_error_handler(self, method, path, err): """ Adds a handler for ResolverError for the given method and path. @@ -216,14 +242,6 @@ def _add_resolver_error_handler(self, method, path, err): randomize_endpoint=RESOLVER_ERROR_ENDPOINT_RANDOM_DIGITS) self._add_operation_internal(method, path, operation) - def _add_operation_internal(self, method, path, operation): - operation_id = operation.operation_id - logger.debug('... Adding %s -> %s', method.upper(), operation_id, - extra=vars(operation)) - - flask_path = utils.flaskify_path(path, operation.get_path_parameter_types()) - self.blueprint.add_url_rule(flask_path, operation.endpoint_name, operation.function, methods=[method]) - def add_paths(self, paths=None): """ Adds the paths defined in the specification as endpoints @@ -244,8 +262,7 @@ def add_paths(self, paths=None): try: self.add_operation(method, path, endpoint, path_parameters) except ResolverError as err: - # If we have an error handler for resolver errors, add it - # as an operation (but randomize the flask endpoint name). + # If we have an error handler for resolver errors, add it as an operation. # Otherwise treat it as any other error. if self.resolver_error_handler is not None: self._add_resolver_error_handler(method, path, err) @@ -269,59 +286,6 @@ def _handle_add_operation_error(self, path, method, exc_info): logger.error(error_msg) six.reraise(*exc_info) - def add_auth_on_not_found(self): - """ - Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass. - """ - logger.debug('Adding path not found authentication') - not_found_error = AuthErrorHandler(werkzeug.exceptions.NotFound(), security=self.security, - security_definitions=self.security_definitions) - endpoint_name = "{name}_not_found".format(name=self.blueprint.name) - self.blueprint.add_url_rule('/', endpoint_name, not_found_error.function) - - def add_swagger_json(self): - """ - Adds swagger json to {base_url}/swagger.json - """ - logger.debug('Adding swagger.json: %s/swagger.json', self.base_url) - endpoint_name = "{name}_swagger_json".format(name=self.blueprint.name) - self.blueprint.add_url_rule('/swagger.json', - endpoint_name, - lambda: flask.jsonify(self.specification)) - - def add_swagger_ui(self): - """ - Adds swagger ui to {base_url}/ui/ - """ - logger.debug('Adding swagger-ui: %s/%s/', self.base_url, self.swagger_url) - static_endpoint_name = "{name}_swagger_ui_static".format(name=self.blueprint.name) - self.blueprint.add_url_rule('/{swagger_url}/'.format(swagger_url=self.swagger_url), - static_endpoint_name, self.swagger_ui_static) - index_endpoint_name = "{name}_swagger_ui_index".format(name=self.blueprint.name) - self.blueprint.add_url_rule('/{swagger_url}/'.format(swagger_url=self.swagger_url), - index_endpoint_name, self.swagger_ui_index) - - def create_blueprint(self, base_url=None): - """ - :type base_url: str | None - :rtype: flask.Blueprint - """ - base_url = base_url or self.base_url - logger.debug('Creating API blueprint: %s', base_url) - endpoint = utils.flaskify_endpoint(base_url) - blueprint = flask.Blueprint(endpoint, __name__, url_prefix=base_url, - template_folder=str(self.swagger_path)) - return blueprint - - def swagger_ui_index(self): - return flask.render_template('index.html', api_url=self.base_url) - - def swagger_ui_static(self, filename): - """ - :type filename: str - """ - return flask.send_from_directory(str(self.swagger_path), filename) - def load_spec_from_file(self, arguments, specification): arguments = arguments or {} @@ -334,3 +298,23 @@ def load_spec_from_file(self, arguments, specification): swagger_string = jinja2.Template(swagger_template).render(**arguments) return yaml.safe_load(swagger_string) # type: dict + + @classmethod + @abc.abstractmethod + def get_request(self, *args, **kwargs): + """ + This method converts the user framework request to a ConnexionRequest. + """ + + @classmethod + @abc.abstractmethod + def get_response(self, response, mimetype=None, request=None): + """ + This method converts the ConnexionResponse to a user framework response. + :param response: A response to cast. + :param mimetype: The response mimetype. + :param request: The request associated with this response (the user framework request). + + :type response: ConnexionResponse + :type mimetype: str + """ diff --git a/connexion/apis/flask_api.py b/connexion/apis/flask_api.py new file mode 100644 index 000000000..ad970a2b2 --- /dev/null +++ b/connexion/apis/flask_api.py @@ -0,0 +1,254 @@ +import logging + +import six + +import flask +import werkzeug.exceptions +from connexion import flask_utils +from connexion.apis.abstract import AbstractAPI +from connexion.decorators.produces import BaseSerializer, NoContent +from connexion.handlers import AuthErrorHandler +from connexion.request import ConnexionRequest +from connexion.response import ConnexionResponse +from connexion.utils import is_json_mimetype + +logger = logging.getLogger('connexion.apis.flask_api') + + +class Jsonifier(BaseSerializer): + @staticmethod + def dumps(data): + """ Central point where JSON serialization happens inside + Connexion. + """ + return "{}\n".format(flask.json.dumps(data, indent=2)) + + @staticmethod + def loads(data): + """ Central point where JSON serialization happens inside + Connexion. + """ + if isinstance(data, six.binary_type): + data = data.decode() + + try: + return flask.json.loads(data) + except Exception as error: + if isinstance(data, six.string_types): + return data + + def __repr__(self): + """ + :rtype: str + """ + return ''.format(self.mimetype) + + +class FlaskApi(AbstractAPI): + jsonifier = Jsonifier + + def __init__(self, specification, base_url=None, arguments=None, + swagger_json=None, swagger_ui=None, swagger_path=None, swagger_url=None, + validate_responses=False, strict_validation=False, resolver=None, + auth_all_paths=False, debug=False, resolver_error_handler=None, + validator_map=None, pythonic_params=False): + super(FlaskApi, self).__init__( + specification, FlaskApi.jsonifier, base_url=base_url, arguments=arguments, + swagger_json=swagger_json, swagger_ui=swagger_ui, + swagger_path=swagger_path, swagger_url=swagger_url, + validate_responses=validate_responses, strict_validation=strict_validation, + resolver=resolver, auth_all_paths=auth_all_paths, debug=debug, + resolver_error_handler=resolver_error_handler, validator_map=validator_map, + pythonic_params=pythonic_params + ) + + def _set_base_url(self, base_url): + super(FlaskApi, self)._set_base_url(base_url) + self._set_blueprint() + + def _set_blueprint(self): + logger.debug('Creating API blueprint: %s', self.base_url) + endpoint = flask_utils.flaskify_endpoint(self.base_url) + self.blueprint = flask.Blueprint(endpoint, __name__, url_prefix=self.base_url, + template_folder=str(self.swagger_path)) + + def add_swagger_json(self): + """ + Adds swagger json to {base_url}/swagger.json + """ + logger.debug('Adding swagger.json: %s/swagger.json', self.base_url) + endpoint_name = "{name}_swagger_json".format(name=self.blueprint.name) + self.blueprint.add_url_rule('/swagger.json', + endpoint_name, + lambda: flask.jsonify(self.specification)) + + def add_swagger_ui(self): + """ + Adds swagger ui to {base_url}/ui/ + """ + logger.debug('Adding swagger-ui: %s/%s/', self.base_url, self.swagger_url) + static_endpoint_name = "{name}_swagger_ui_static".format(name=self.blueprint.name) + self.blueprint.add_url_rule('/{swagger_url}/'.format(swagger_url=self.swagger_url), + static_endpoint_name, self.swagger_ui_static) + index_endpoint_name = "{name}_swagger_ui_index".format(name=self.blueprint.name) + self.blueprint.add_url_rule('/{swagger_url}/'.format(swagger_url=self.swagger_url), + index_endpoint_name, self.swagger_ui_index) + + def swagger_ui_index(self): + return flask.render_template('index.html', api_url=self.base_url) + + def swagger_ui_static(self, filename): + """ + :type filename: str + """ + return flask.send_from_directory(str(self.swagger_path), filename) + + def add_auth_on_not_found(self, security, security_definitions): + """ + Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass. + """ + logger.debug('Adding path not found authentication') + not_found_error = AuthErrorHandler(self, werkzeug.exceptions.NotFound(), security=security, + security_definitions=security_definitions) + endpoint_name = "{name}_not_found".format(name=self.blueprint.name) + self.blueprint.add_url_rule('/', endpoint_name, not_found_error.function) + + def _add_operation_internal(self, method, path, operation): + operation_id = operation.operation_id + logger.debug('... Adding %s -> %s', method.upper(), operation_id, + extra=vars(operation)) + + flask_path = flask_utils.flaskify_path(path, operation.get_path_parameter_types()) + endpoint_name = flask_utils.flaskify_endpoint(operation.operation_id, + operation.randomize_endpoint) + function = operation.function + self.blueprint.add_url_rule(flask_path, endpoint_name, function, methods=[method]) + + @classmethod + def get_response(cls, response, mimetype=None, request=None): + """Gets ConnexionResponse instance for the operation handler + result. Status Code and Headers for response. If only body + data is returned by the endpoint function, then the status + code will be set to 200 and no headers will be added. + + If the returned object is a flask.Response then it will just + pass the information needed to recreate it. + + :type operation_handler_result: flask.Response | (flask.Response, int) | (flask.Response, int, dict) + :rtype: ConnexionRequest + """ + logger.debug('Getting data and status code', + extra={ + 'data': response, + 'data_type': type(response), + 'url': flask.request.url + }) + + if isinstance(response, ConnexionResponse): + flask_response = cls._get_flask_response_from_connexion(response, mimetype) + else: + flask_response = cls._get_flask_response(response, mimetype) + + logger.debug('Got data and status code (%d)', + flask_response.status_code, + extra={ + 'data': response, + 'datatype': type(response), + 'url': flask.request.url + }) + + return flask_response + + @classmethod + def _get_flask_response_from_connexion(cls, response, mimetype): + data = response.body + status_code = response.status_code + mimetype = response.mimetype or mimetype + content_type = response.content_type or mimetype + headers = response.headers + + flask_response = cls._build_flask_response(mimetype, content_type, + headers, status_code, data) + + return flask_response + + @classmethod + def _build_flask_response(cls, mimetype=None, content_type=None, + headers=None, status_code=None, data=None): + kwargs = { + 'mimetype': mimetype, + 'content_type': content_type, + 'headers': headers + } + kwargs = {k: v for k, v in six.iteritems(kwargs) if v is not None} + flask_response = flask.current_app.response_class(**kwargs) # type: flask.Response + + if status_code is not None: + flask_response.status = str(status_code) + + if data is not None and data is not NoContent: + data = cls._jsonify_data(data, mimetype) + flask_response.set_data(data) + + elif data is NoContent: + flask_response.set_data('') + + return flask_response + + @classmethod + def _jsonify_data(cls, data, mimetype): + if (isinstance(mimetype, six.string_types) and is_json_mimetype(mimetype)) \ + or not (isinstance(data, six.binary_type) or isinstance(data, six.text_type)): + return cls.jsonifier.dumps(data) + + return data + + @classmethod + def _get_flask_response(cls, response, mimetype): + if flask_utils.is_flask_response(response): + return response + + elif isinstance(response, tuple) and len(response) == 3: + data, status_code, headers = response + return cls._build_flask_response(mimetype, None, + headers, status_code, data) + + elif isinstance(response, tuple) and len(response) == 2: + data, status_code = response + return cls._build_flask_response(mimetype, None, None, + status_code, data) + + else: + return cls._build_flask_response(mimetype=mimetype, data=response) + + @classmethod + def get_request(cls, **params): + """Gets ConnexionRequest instance for the operation handler + result. Status Code and Headers for response. If only body + data is returned by the endpoint function, then the status + code will be set to 200 and no headers will be added. + + If the returned object is a flask.Response then it will just + pass the information needed to recreate it. + + :type operation_handler_result: flask.Response | (flask.Response, int) | (flask.Response, int, dict) + :rtype: ConnexionRequest + """ + request = flask.request + request = ConnexionRequest( + request.url, request.method, + headers=request.headers, + form=request.form, + query=request.args, + body=request.get_data(), + json=request.get_json(silent=True), + files=request.files, + path_params=params + ) + logger.debug('Getting data and status code', + extra={ + 'data': request.body, + 'data_type': type(request.body), + 'url': request.url + }) + return request diff --git a/connexion/apps/__init__.py b/connexion/apps/__init__.py new file mode 100644 index 000000000..83067db54 --- /dev/null +++ b/connexion/apps/__init__.py @@ -0,0 +1,4 @@ +from .abstract import AbstractApp +from .flask_app import FlaskApp + +__all__ = ['AbstractApp', 'FlaskApp'] diff --git a/connexion/app.py b/connexion/apps/abstract.py similarity index 69% rename from connexion/app.py rename to connexion/apps/abstract.py index cf836fa78..c2e3e47d8 100644 --- a/connexion/app.py +++ b/connexion/apps/abstract.py @@ -1,20 +1,16 @@ +import abc import logging import pathlib +import six -import flask -import werkzeug.exceptions - -from .api import Api -from .decorators.produces import JSONEncoder as ConnexionJSONEncoder -from .exceptions import ProblemException -from .problem import problem -from .resolver import Resolver +from ..resolver import Resolver logger = logging.getLogger('connexion.app') -class App(object): - def __init__(self, import_name, port=None, specification_dir='', +@six.add_metaclass(abc.ABCMeta) +class AbstractApp(object): + def __init__(self, import_name, api_cls, port=None, specification_dir='', server=None, arguments=None, auth_all_paths=False, debug=False, swagger_json=True, swagger_ui=True, swagger_path=None, swagger_url=None, host=None, validator_map=None): @@ -46,12 +42,25 @@ def __init__(self, import_name, port=None, specification_dir='', :param validator_map: map of validators :type validator_map: dict """ - self.app = flask.Flask(import_name) + self.port = port + self.host = host + self.debug = debug + self.import_name = import_name + self.arguments = arguments or {} + self.swagger_json = swagger_json + self.swagger_ui = swagger_ui + self.swagger_path = swagger_path + self.swagger_url = swagger_url + self.auth_all_paths = auth_all_paths + self.resolver_error = None + self.validator_map = validator_map + self.api_cls = api_cls - self.app.json_encoder = ConnexionJSONEncoder + self.app = self.create_app() + self.server = server - # we get our application root path from flask to avoid duplicating logic - self.root_path = pathlib.Path(self.app.root_path) + # we get our application root path to avoid duplicating logic + self.root_path = self.get_root_path() logger.debug('Root Path: %s', self.root_path) specification_dir = pathlib.Path(specification_dir) # Ensure specification dir is a Path @@ -63,44 +72,31 @@ def __init__(self, import_name, port=None, specification_dir='', logger.debug('Specification directory: %s', self.specification_dir) logger.debug('Setting error handlers') - for error_code in werkzeug.exceptions.default_exceptions: - self.add_error_handler(error_code, self.common_error_handler) + self.set_errors_handlers() - self.add_error_handler(ProblemException, self.common_error_handler) - - self.port = port - self.host = host - self.server = server or 'flask' - self.debug = debug - self.import_name = import_name - self.arguments = arguments or {} - self.swagger_json = swagger_json - self.swagger_ui = swagger_ui - self.swagger_path = swagger_path - self.swagger_url = swagger_url - self.auth_all_paths = auth_all_paths - self.resolver_error = None - self.validator_map = validator_map - - @staticmethod - def common_error_handler(exception): + @abc.abstractmethod + def create_app(self): """ - :type exception: Exception + Creates the user framework application """ - if isinstance(exception, ProblemException): - response_container = exception.to_problem() - else: - if not isinstance(exception, werkzeug.exceptions.HTTPException): - exception = werkzeug.exceptions.InternalServerError() - response_container = problem(title=exception.name, detail=exception.description, - status=exception.code) + @abc.abstractmethod + def get_root_path(self): + """ + Gets the root path of the user framework application + """ - return response_container.flask_response_object() + @abc.abstractmethod + def set_errors_handlers(self): + """ + Sets all errors handlers of the user framework application + """ - def add_api(self, specification, base_path=None, arguments=None, auth_all_paths=None, swagger_json=None, - swagger_ui=None, swagger_path=None, swagger_url=None, validate_responses=False, - strict_validation=False, resolver=Resolver(), resolver_error=None, pythonic_params=False): + def add_api(self, specification, base_path=None, arguments=None, + auth_all_paths=None, swagger_json=None, swagger_ui=None, + swagger_path=None, swagger_url=None, validate_responses=False, + strict_validation=False, resolver=Resolver(), resolver_error=None, + pythonic_params=False): """ Adds an API to the application based on a swagger file or API dict @@ -131,7 +127,7 @@ def add_api(self, specification, base_path=None, arguments=None, auth_all_paths= :type resolver_error: int | None :param pythonic_params: When True CamelCase parameters are converted to snake_case :type pythonic_params: bool - :rtype: Api + :rtype: AbstractAPI """ # Turn the resolver_error code into a handler object self.resolver_error = resolver_error @@ -155,21 +151,20 @@ def add_api(self, specification, base_path=None, arguments=None, auth_all_paths= else: specification = self.specification_dir / specification - api = Api(specification=specification, - base_url=base_path, arguments=arguments, - swagger_json=swagger_json, - swagger_ui=swagger_ui, - swagger_path=swagger_path, - swagger_url=swagger_url, - resolver=resolver, - resolver_error_handler=resolver_error_handler, - validate_responses=validate_responses, - strict_validation=strict_validation, - auth_all_paths=auth_all_paths, - debug=self.debug, - validator_map=self.validator_map, - pythonic_params=pythonic_params) - self.app.register_blueprint(api.blueprint) + api = self.api_cls(specification=specification, + base_url=base_path, arguments=arguments, + swagger_json=swagger_json, + swagger_ui=swagger_ui, + swagger_path=swagger_path, + swagger_url=swagger_url, + resolver=resolver, + resolver_error_handler=resolver_error_handler, + validate_responses=validate_responses, + strict_validation=strict_validation, + auth_all_paths=auth_all_paths, + debug=self.debug, + validator_map=self.validator_map, + pythonic_params=pythonic_params) return api def _resolver_error_handler(self, *args, **kwargs): @@ -178,15 +173,7 @@ def _resolver_error_handler(self, *args, **kwargs): 'operationId': 'connexion.handlers.ResolverErrorHandler', } kwargs.setdefault('app_consumes', ['application/json']) - return ResolverErrorHandler(self.resolver_error, *args, **kwargs) - - def add_error_handler(self, error_code, function): - """ - - :type error_code: int - :type function: types.FunctionType - """ - self.app.register_error_handler(error_code, function) + return ResolverErrorHandler(self.api_cls, self.resolver_error, *args, **kwargs) def add_url_rule(self, rule, endpoint=None, view_func=None, **options): """ @@ -252,6 +239,7 @@ def index(): logger.debug('Adding %s with decorator', rule, extra=options) return self.app.route(rule, **options) + @abc.abstractmethod def run(self, port=None, server=None, debug=None, host=None, **options): # pragma: no cover """ Runs the application on a local development server. @@ -266,47 +254,6 @@ def run(self, port=None, server=None, debug=None, host=None, **options): # prag :param options: options to be forwarded to the underlying server :type options: dict """ - # this functions is not covered in unit tests because we would effectively testing the mocks - - # overwrite constructor parameter - if port is not None: - self.port = port - elif self.port is None: - self.port = 5000 - - self.host = host or self.host or '0.0.0.0' - - if server is not None: - self.server = server - - if debug is not None: - self.debug = debug - - logger.debug('Starting %s HTTP server..', self.server, extra=vars(self)) - if self.server == 'flask': - self.app.run(self.host, port=self.port, debug=self.debug, **options) - elif self.server == 'tornado': - try: - import tornado.wsgi - import tornado.httpserver - import tornado.ioloop - except: - raise Exception('tornado library not installed') - wsgi_container = tornado.wsgi.WSGIContainer(self.app) - http_server = tornado.httpserver.HTTPServer(wsgi_container, **options) - http_server.listen(self.port, address=self.host) - logger.info('Listening on %s:%s..', self.host, self.port) - tornado.ioloop.IOLoop.instance().start() - elif self.server == 'gevent': - try: - import gevent.wsgi - except: - raise Exception('gevent library not installed') - http_server = gevent.wsgi.WSGIServer((self.host, self.port), self.app, **options) - logger.info('Listening on %s:%s..', self.host, self.port) - http_server.serve_forever() - else: - raise Exception('Server %s not recognized', self.server) def __call__(self, environ, start_response): # pragma: no cover """ diff --git a/connexion/apps/flask_app.py b/connexion/apps/flask_app.py new file mode 100644 index 000000000..feedbdfc4 --- /dev/null +++ b/connexion/apps/flask_app.py @@ -0,0 +1,164 @@ +import datetime +import logging +import pathlib +from decimal import Decimal + +import flask +import werkzeug.exceptions +from flask import json + +from ..apis.flask_api import FlaskApi +from ..exceptions import ProblemException +from ..problem import problem +from ..resolver import Resolver +from .abstract import AbstractApp + + +logger = logging.getLogger('connexion.app') + + +class FlaskApp(AbstractApp): + def __init__(self, import_name, port=None, specification_dir='', + server=None, arguments=None, auth_all_paths=False, + debug=False, swagger_json=True, swagger_ui=True, swagger_path=None, + swagger_url=None, host=None, validator_map=None): + server = server or 'flask' + super(FlaskApp, self).__init__( + import_name, port=port, specification_dir=specification_dir, + server=server, arguments=arguments, auth_all_paths=auth_all_paths, + debug=debug, swagger_json=swagger_json, swagger_ui=swagger_ui, + swagger_path=swagger_path, swagger_url=swagger_url, + host=host, validator_map=validator_map, api_cls=FlaskApi + ) + + def create_app(self): + app = flask.Flask(self.import_name) + app.json_encoder = FlaskJSONEncoder + return app + + def get_root_path(self): + return pathlib.Path(self.app.root_path) + + def set_errors_handlers(self): + for error_code in werkzeug.exceptions.default_exceptions: + self.add_error_handler(error_code, self.common_error_handler) + + self.add_error_handler(ProblemException, self.common_error_handler) + + @staticmethod + def common_error_handler(exception): + """ + :type exception: Exception + """ + if isinstance(exception, ProblemException): + response = exception.to_problem() + else: + if not isinstance(exception, werkzeug.exceptions.HTTPException): + exception = werkzeug.exceptions.InternalServerError() + + response = problem(title=exception.name, detail=exception.description, + status=exception.code) + kwargs = {attr_name: getattr(response, attr_name) for attr_name in response._fields} + response = type(response)(**kwargs) + + response = FlaskApi.get_response(response) + return response + + def add_api(self, specification, base_path=None, arguments=None, + auth_all_paths=None, swagger_json=None, swagger_ui=None, + swagger_path=None, swagger_url=None, validate_responses=False, + strict_validation=False, resolver=Resolver(), resolver_error=None, + pythonic_params=False): + api = super(FlaskApp, self).add_api( + specification, base_path=base_path, + arguments=arguments, auth_all_paths=auth_all_paths, swagger_json=swagger_json, + swagger_ui=swagger_ui, swagger_path=swagger_path, swagger_url=swagger_url, + validate_responses=validate_responses, strict_validation=strict_validation, + resolver=resolver, resolver_error=resolver_error, pythonic_params=pythonic_params + ) + self.app.register_blueprint(api.blueprint) + return api + + def add_error_handler(self, error_code, function): + """ + + :type error_code: int + :type function: types.FunctionType + """ + self.app.register_error_handler(error_code, function) + + def run(self, port=None, server=None, debug=None, host=None, **options): # pragma: no cover + """ + Runs the application on a local development server. + :param host: the host interface to bind on. + :type host: str + :param port: port to listen to + :type port: int + :param server: which wsgi server to use + :type server: str | None + :param debug: include debugging information + :type debug: bool + :param options: options to be forwarded to the underlying server + :type options: dict + """ + # this functions is not covered in unit tests because we would effectively testing the mocks + + # overwrite constructor parameter + if port is not None: + self.port = port + elif self.port is None: + self.port = 5000 + + self.host = host or self.host or '0.0.0.0' + + if server is not None: + self.server = server + + if debug is not None: + self.debug = debug + + logger.debug('Starting %s HTTP server..', self.server, extra=vars(self)) + if self.server == 'flask': + self.app.run(self.host, port=self.port, debug=self.debug, **options) + elif self.server == 'tornado': + try: + import tornado.wsgi + import tornado.httpserver + import tornado.ioloop + except: + raise Exception('tornado library not installed') + wsgi_container = tornado.wsgi.WSGIContainer(self.app) + http_server = tornado.httpserver.HTTPServer(wsgi_container, **options) + http_server.listen(self.port, address=self.host) + logger.info('Listening on %s:%s..', self.host, self.port) + tornado.ioloop.IOLoop.instance().start() + elif self.server == 'gevent': + try: + import gevent.wsgi + except: + raise Exception('gevent library not installed') + http_server = gevent.wsgi.WSGIServer((self.host, self.port), self.app, **options) + logger.info('Listening on %s:%s..', self.host, self.port) + http_server.serve_forever() + else: + raise Exception('Server %s not recognized', self.server) + + +class FlaskJSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, datetime.datetime): + if o.tzinfo: + # eg: '2015-09-25T23:14:42.588601+00:00' + return o.isoformat('T') + else: + # No timezone present - assume UTC. + # eg: '2015-09-25T23:14:42.588601Z' + return o.isoformat('T') + 'Z' + + if isinstance(o, datetime.date): + return o.isoformat() + + if isinstance(o, Decimal): + return float(o) + + return json.JSONEncoder.default(self, o) diff --git a/connexion/cli.py b/connexion/cli.py index 3e5891db9..cb7323211 100644 --- a/connexion/cli.py +++ b/connexion/cli.py @@ -128,13 +128,13 @@ def run(spec_file, resolver = MockResolver(mock_all=mock == 'all') api_extra_args['resolver'] = resolver - app = connexion.App(__name__, - swagger_json=not hide_spec, - swagger_ui=not hide_console_ui, - swagger_path=console_ui_from or None, - swagger_url=console_ui_url or None, - auth_all_paths=auth_all_paths, - debug=debug) + app = connexion.FlaskApp(__name__, + swagger_json=not hide_spec, + swagger_ui=not hide_console_ui, + swagger_path=console_ui_from or None, + swagger_url=console_ui_url or None, + auth_all_paths=auth_all_paths, + debug=debug) app.add_api(spec_file_full_path, base_path=base_path, diff --git a/connexion/decorators/decorator.py b/connexion/decorators/decorator.py index 94dc1b73e..471e3b2a2 100644 --- a/connexion/decorators/decorator.py +++ b/connexion/decorators/decorator.py @@ -1,10 +1,6 @@ import functools import logging -import flask - -from ..utils import is_flask_response - logger = logging.getLogger('connexion.decorators.decorator') @@ -27,87 +23,38 @@ def __repr__(self): # pragma: no cover class BeginOfRequestLifecycleDecorator(BaseDecorator): """Manages the lifecycle of the request internally in Connexion. - Transforms the operation handler response into a `ResponseContainer` + Transforms the operation handler response into a `ConnexionRequest` that can be manipulated by the series of decorators during the lifecycle of the request. """ - def __init__(self, mimetype): + def __init__(self, api, mimetype): + self.api = api self.mimetype = mimetype - def get_response_container(self, operation_handler_result): - """Gets ResponseContainer instance for the operation handler - result. Status Code and Headers for response. If only body - data is returned by the endpoint function, then the status - code will be set to 200 and no headers will be added. - - If the returned object is a flask.Response then it will just - pass the information needed to recreate it. - - :type operation_handler_result: flask.Response | (flask.Response, int) | (flask.Response, int, dict) - :rtype: ResponseContainer - """ - url = flask.request.url - logger.debug('Getting data and status code', - extra={ - 'data': operation_handler_result, - 'data_type': type(operation_handler_result), - 'url': url - }) - - response_container = None - if is_flask_response(operation_handler_result): - response_container = ResponseContainer( - self.mimetype, response=operation_handler_result) - - elif isinstance(operation_handler_result, ResponseContainer): - response_container = operation_handler_result - - elif isinstance(operation_handler_result, tuple) and len( - operation_handler_result) == 3: - data, status_code, headers = operation_handler_result - response_container = ResponseContainer( - self.mimetype, data=data, status_code=status_code, headers=headers) - - elif isinstance(operation_handler_result, tuple) and len( - operation_handler_result) == 2: - data, status_code = operation_handler_result - response_container = ResponseContainer( - self.mimetype, data=data, status_code=status_code) - else: - response_container = ResponseContainer( - self.mimetype, data=operation_handler_result) - - logger.debug('Got data and status code (%d)', - response_container.status_code, - extra={ - 'data': operation_handler_result, - 'datatype': type(operation_handler_result), - 'url': url - }) - - return response_container - def __call__(self, function): """ :type function: types.FunctionType :rtype: types.FunctionType """ @functools.wraps(function) - def wrapper(*args, **kwargs): - operation_handler_result = function(*args, **kwargs) - return self.get_response_container(operation_handler_result) + def wrapper(request): + response = function(request) + return self.api.get_response(response, self.mimetype, request) return wrapper class EndOfRequestLifecycleDecorator(BaseDecorator): """Manages the lifecycle of the request internally in Connexion. - - Filter the ResponseContainer instance to return the corresponding + Filter the ConnexionRequest instance to return the corresponding flask.Response object. """ + def __init__(self, api, mimetype): + self.api = api + self.mimetype = mimetype + def __call__(self, function): """ :type function: types.FunctionType @@ -115,67 +62,8 @@ def __call__(self, function): """ @functools.wraps(function) def wrapper(*args, **kwargs): - response_container = function(*args, **kwargs) - return response_container.flask_response_object() + request = self.api.get_request(*args, **kwargs) + response = function(request) + return self.api.get_response(response, self.mimetype, request) return wrapper - - -class ResponseContainer(object): - """ Internal response object to be passed among Connexion - decorators that want to manipulate response data. - - Partially matches the flask.Response interface for easy access and - manipulation. - - The methods ResponseContainer#get_data and ResponseContainer#set_data - are added here following the recommendation found in: - - http://flask.pocoo.org/docs/0.11/api/#flask.Response.data - """ - - def __init__(self, mimetype, data=None, status_code=200, headers=None, response=None): - """ - :type data: dict | None - :type status_code: int | None - :type headers: dict | None - :type response: flask.Response | None - """ - self.mimetype = mimetype - self.data = data - self.status_code = status_code - self.headers = headers or {} - - self._response = response - self.is_handler_response_object = bool(response) - - if self._response: - self.data = self._response.get_data() - self.status_code = self._response.status_code - self.headers = self._response.headers - - def get_data(self): - """ Get the current data to be used when creating the - flask.Response instance. - """ - return self.data - - def set_data(self, data): - """ Gets the data that is going to be used to create the - flask.Response instance. - """ - self.data = data - - def flask_response_object(self): - """ - Builds an Flask response using the contained data, - status_code, and headers. - - :rtype: flask.Response - """ - self._response = flask.current_app.response_class( - self.data, mimetype=self.mimetype, content_type=self.headers.get('content-type'), - headers=self.headers) # type: flask.Response - self._response.status_code = self.status_code - - return self._response diff --git a/connexion/decorators/metrics.py b/connexion/decorators/metrics.py index 446504af3..b87c46228 100644 --- a/connexion/decorators/metrics.py +++ b/connexion/decorators/metrics.py @@ -29,11 +29,11 @@ def __call__(self, function): """ @functools.wraps(function) - def wrapper(*args, **kwargs): + def wrapper(request): status_code = 500 start_time_s = time.time() try: - response = function(*args, **kwargs) + response = function(request) status_code = response.status_code finally: end_time_s = time.time() diff --git a/connexion/decorators/parameter.py b/connexion/decorators/parameter.py index b7b6bd2cb..4d905af85 100644 --- a/connexion/decorators/parameter.py +++ b/connexion/decorators/parameter.py @@ -4,10 +4,8 @@ import logging import re -import flask import inflection import six -import werkzeug.exceptions as exceptions from ..utils import all_json, boolean, is_null, is_nullable @@ -116,24 +114,23 @@ def sanitize_param(name): for param in parameters if param['in'] == 'formData' and 'default' in param} @functools.wraps(function) - def wrapper(*args, **kwargs): + def wrapper(request): logger.debug('Function Arguments: %s', arguments) + kwargs = {} if all_json(consumes): - try: - request_body = flask.request.get_json() - except exceptions.BadRequest: - request_body = None + request_body = request.json else: - request_body = flask.request.data + request_body = request.body if default_body and not request_body: request_body = default_body # Parse path parameters + path_params = request.path_params for key, path_param_definitions in path_types.items(): - if key in kwargs: - kwargs[key] = get_val_from_param(kwargs[key], + if key in path_params: + kwargs[key] = get_val_from_param(path_params[key], path_param_definitions) # Add body parameters @@ -145,7 +142,7 @@ def wrapper(*args, **kwargs): # Add query parameters query_arguments = copy.deepcopy(default_query_params) - query_arguments.update({sanitize_param(k): v for k, v in flask.request.args.items()}) + query_arguments.update({sanitize_param(k): v for k, v in request.query.items()}) for key, value in query_arguments.items(): if not has_kwargs and key not in arguments: logger.debug("Query Parameter '%s' not in function arguments", key) @@ -161,7 +158,7 @@ def wrapper(*args, **kwargs): # Add formData parameters form_arguments = copy.deepcopy(default_form_params) - form_arguments.update({sanitize_param(k): v for k, v in flask.request.form.items()}) + form_arguments.update({sanitize_param(k): v for k, v in request.form.items()}) for key, value in form_arguments.items(): if not has_kwargs and key not in arguments: logger.debug("FormData parameter '%s' not in function arguments", key) @@ -175,7 +172,7 @@ def wrapper(*args, **kwargs): kwargs[key] = get_val_from_param(value, form_param) # Add file parameters - file_arguments = flask.request.files + file_arguments = request.files for key, value in file_arguments.items(): if not has_kwargs and key not in arguments: logger.debug("File parameter (formData) '%s' not in function arguments", key) @@ -186,6 +183,8 @@ def wrapper(*args, **kwargs): # optionally convert parameter variable names to un-shadowed, snake_case form if pythonic_params: kwargs = {snake_and_shadow(k): v for k, v in kwargs.items()} - return function(*args, **kwargs) + + kwargs.update(request.context) + return function(**kwargs) return wrapper diff --git a/connexion/decorators/produces.py b/connexion/decorators/produces.py index 8b4c46169..7cbc6748a 100644 --- a/connexion/decorators/produces.py +++ b/connexion/decorators/produces.py @@ -1,14 +1,7 @@ # Decorators to change the return type of endpoints -import datetime import functools import logging -from decimal import Decimal - -import flask -import six -from flask import json - from .decorator import BaseDecorator logger = logging.getLogger('connexion.decorators.produces') @@ -18,26 +11,6 @@ NoContent = object() -class JSONEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, datetime.datetime): - if o.tzinfo: - # eg: '2015-09-25T23:14:42.588601+00:00' - return o.isoformat('T') - else: - # No timezone present - assume UTC. - # eg: '2015-09-25T23:14:42.588601Z' - return o.isoformat('T') + 'Z' - - if isinstance(o, datetime.date): - return o.isoformat() - - if isinstance(o, Decimal): - return float(o) - - return json.JSONEncoder.default(self, o) - - class BaseSerializer(BaseDecorator): def __init__(self, mimetype='text/plain'): """ @@ -60,9 +33,9 @@ def __call__(self, function): """ @functools.wraps(function) - def wrapper(*args, **kwargs): - url = flask.request.url - response = function(*args, **kwargs) + def wrapper(request): + url = request.url + response = function(request) logger.debug('Returning %s', url, extra={'url': url, 'mimetype': self.mimetype}) return response @@ -74,63 +47,3 @@ def __repr__(self): :rtype: str """ return ''.format(self.mimetype) - - -class Jsonifier(BaseSerializer): - @staticmethod - def dumps(data): - """ Central point where JSON serialization happens inside - Connexion. - """ - if six.PY2: - json_content = json.dumps(data, indent=2, encoding="utf-8") - else: - json_content = json.dumps(data, indent=2) - - return "{}\n".format(json_content) - - def __call__(self, function): - """ - :type function: types.FunctionType - :rtype: types.FunctionType - """ - - @functools.wraps(function) - def wrapper(*args, **kwargs): - url = flask.request.url - - logger.debug('Jsonifing %s', url, - extra={'url': url, 'mimetype': self.mimetype}) - - response = function(*args, **kwargs) - - if response.is_handler_response_object: - logger.debug('Endpoint returned a Flask Response', - extra={'url': url, 'mimetype': self.mimetype}) - return response - - elif response.data is NoContent: - response.set_data('') - return response - - elif response.status_code == 204: - logger.debug('Endpoint returned an empty response (204)', - extra={'url': url, 'mimetype': self.mimetype}) - response.set_data('') - return response - - elif response.mimetype == 'application/problem+json' and isinstance(response.data, str): - # connexion.problem() already adds data as a serialized JSON - return response - - json_content = Jsonifier.dumps(response.get_data()) - response.set_data(json_content) - return response - - return wrapper - - def __repr__(self): - """ - :rtype: str - """ - return ''.format(self.mimetype) diff --git a/connexion/decorators/response.py b/connexion/decorators/response.py index c04f71d7d..8bc0ec251 100644 --- a/connexion/decorators/response.py +++ b/connexion/decorators/response.py @@ -2,7 +2,6 @@ import functools import logging -from flask import json from jsonschema import ValidationError from ..exceptions import (NonConformingResponseBody, @@ -24,7 +23,7 @@ def __init__(self, operation, mimetype): self.operation = operation self.mimetype = mimetype - def validate_response(self, data, status_code, headers): + def validate_response(self, data, status_code, headers, url): """ Validates the Response object based on what has been declared in the specification. Ensures the response body matches the declated schema. @@ -42,12 +41,8 @@ def validate_response(self, data, status_code, headers): schema = response_definition.get("schema") v = ResponseBodyValidator(schema) try: - # For cases of custom encoders, we need to encode and decode to - # transform to the actual types that are going to be returned. - data = json.dumps(data) - data = json.loads(data) - - v.validate_schema(data) + data = self.operation.json_loads(data) + v.validate_schema(data, url) except ValidationError as e: raise NonConformingResponseBody(message=str(e)) @@ -86,14 +81,17 @@ def __call__(self, function): :rtype: types.FunctionType """ @functools.wraps(function) - def wrapper(*args, **kwargs): - response = function(*args, **kwargs) + def wrapper(request): + response = function(request) try: - self.validate_response(response.get_data(), response.status_code, response.headers) - except NonConformingResponseBody as e: - return problem(500, e.reason, e.message) - except NonConformingResponseHeaders as e: - return problem(500, e.reason, e.message) + self.validate_response( + response.get_data(), response.status_code, + response.headers, request.url) + + except (NonConformingResponseBody, NonConformingResponseHeaders) as e: + response = problem(500, e.reason, e.message) + return self.operation.api.get_response(response) + return response return wrapper diff --git a/connexion/decorators/security.py b/connexion/decorators/security.py index 8bcda7d95..2a044f279 100644 --- a/connexion/decorators/security.py +++ b/connexion/decorators/security.py @@ -5,7 +5,6 @@ import textwrap import requests -from flask import request from ..exceptions import OAuthProblem, OAuthResponseProblem, OAuthScopeProblem @@ -52,7 +51,7 @@ def verify_oauth(token_info_url, allowed_scopes, function): """ @functools.wraps(function) - def wrapper(*args, **kwargs): + def wrapper(request): logger.debug("%s Oauth verification...", request.url) authorization = request.headers.get('Authorization') # type: str if not authorization: @@ -86,8 +85,8 @@ def wrapper(*args, **kwargs): token_scopes=user_scopes ) logger.info("... Token authenticated.") - request.user = token_info.get('uid') - request.token_info = token_info - return function(*args, **kwargs) + request.context['user'] = token_info.get('uid') + request.context['token_info'] = token_info + return function(request) return wrapper diff --git a/connexion/decorators/validation.py b/connexion/decorators/validation.py index 51c5518d9..ced1216f4 100644 --- a/connexion/decorators/validation.py +++ b/connexion/decorators/validation.py @@ -4,7 +4,6 @@ import logging import sys -import flask import six from jsonschema import Draft4Validator, ValidationError, draft4_format_checker from werkzeug import FileStorage @@ -80,7 +79,7 @@ def validate_parameter_list(request_params, spec_params): class RequestBodyValidator(object): - def __init__(self, schema, consumes, is_null_value_valid=False, validator=None): + def __init__(self, schema, consumes, api, is_null_value_valid=False, validator=None): """ :param schema: The schema of the request body :param consumes: The list of content types the operation consumes @@ -92,8 +91,9 @@ def __init__(self, schema, consumes, is_null_value_valid=False, validator=None): self.consumes = consumes self.has_default = schema.get('default', False) self.is_null_value_valid = is_null_value_valid - ValidatorClass = validator or Draft4Validator - self.validator = ValidatorClass(schema, format_checker=draft4_format_checker) + validatorClass = validator or Draft4Validator + self.validator = validatorClass(schema, format_checker=draft4_format_checker) + self.api = api def __call__(self, function): """ @@ -102,29 +102,29 @@ def __call__(self, function): """ @functools.wraps(function) - def wrapper(*args, **kwargs): + def wrapper(request): if all_json(self.consumes): - data = flask.request.get_json() + data = request.json # flask does not process json if the Content-Type header is not equal to "application/json" - if data is None and len(flask.request.data) > 0 and not self.is_null_value_valid: + if data is None and len(request.body) > 0 and not self.is_null_value_valid: return problem(415, "Unsupported Media Type", "Invalid Content-type ({content_type}), expected JSON data".format( - content_type=flask.request.headers["Content-Type"] + content_type=request.headers["Content-Type"] )) - logger.debug("%s validating schema...", flask.request.url) - error = self.validate_schema(data) + logger.debug("%s validating schema...", request.url) + error = self.validate_schema(data, request.url) if error and not self.has_default: return error - response = function(*args, **kwargs) + response = function(request) return response return wrapper - def validate_schema(self, data): + def validate_schema(self, data, url): """ :type data: dict :rtype: flask.Response | None @@ -135,7 +135,7 @@ def validate_schema(self, data): try: self.validator.validate(data) except ValidationError as exception: - logger.error("{url} validation error: {error}".format(url=flask.request.url, + logger.error("{url} validation error: {error}".format(url=url, error=exception.message)) return problem(400, 'Bad Request', str(exception.message)) @@ -153,7 +153,7 @@ def __init__(self, schema, validator=None): ValidatorClass = validator or Draft4Validator self.validator = ValidatorClass(schema, format_checker=draft4_format_checker) - def validate_schema(self, data): + def validate_schema(self, data, url): """ :type data: dict :rtype: flask.Response | None @@ -161,7 +161,7 @@ def validate_schema(self, data): try: self.validator.validate(data) except ValidationError as exception: - logger.error("{url} validation error: {error}".format(url=flask.request.url, + logger.error("{url} validation error: {error}".format(url=url, error=exception)) six.reraise(*sys.exc_info()) @@ -169,7 +169,7 @@ def validate_schema(self, data): class ParameterValidator(object): - def __init__(self, parameters, strict_validation=False): + def __init__(self, parameters, api, strict_validation=False): """ :param parameters: List of request parameter dictionaries :param strict_validation: Flag indicating if parameters not in spec are allowed @@ -178,6 +178,7 @@ def __init__(self, parameters, strict_validation=False): for p in parameters: self.parameters[p['in']].append(p) + self.api = api self.strict_validation = strict_validation @staticmethod @@ -218,39 +219,39 @@ def validate_parameter(parameter_type, value, param): elif param.get('required'): return "Missing {parameter_type} parameter '{param[name]}'".format(**locals()) - def validate_query_parameter_list(self): - request_params = flask.request.args.keys() + def validate_query_parameter_list(self, request): + request_params = request.query.keys() spec_params = [x['name'] for x in self.parameters.get('query', [])] return validate_parameter_list(request_params, spec_params) - def validate_formdata_parameter_list(self): - request_params = flask.request.form.keys() + def validate_formdata_parameter_list(self, request): + request_params = request.form.keys() spec_params = [x['name'] for x in self.parameters.get('formData', [])] return validate_parameter_list(request_params, spec_params) - def validate_query_parameter(self, param): + def validate_query_parameter(self, param, request): """ Validate a single query parameter (request.args in Flask) :type param: dict :rtype: str """ - val = flask.request.args.get(param['name']) + val = request.query.get(param['name']) return self.validate_parameter('query', val, param) - def validate_path_parameter(self, args, param): - val = args.get(param['name'].replace('-', '_')) + def validate_path_parameter(self, param, request): + val = request.path_params.get(param['name'].replace('-', '_')) return self.validate_parameter('path', val, param) - def validate_header_parameter(self, param): - val = flask.request.headers.get(param['name']) + def validate_header_parameter(self, param, request): + val = request.headers.get(param['name']) return self.validate_parameter('header', val, param) - def validate_formdata_parameter(self, param): + def validate_formdata_parameter(self, param, request): if param.get('type') == 'file': - val = flask.request.files.get(param['name']) + val = request.files.get(param['name']) else: - val = flask.request.form.get(param['name']) + val = request.form.get(param['name']) return self.validate_parameter('formdata', val, param) @@ -261,36 +262,40 @@ def __call__(self, function): """ @functools.wraps(function) - def wrapper(*args, **kwargs): - logger.debug("%s validating parameters...", flask.request.url) + def wrapper(request): + logger.debug("%s validating parameters...", request.url) if self.strict_validation: - query_errors = self.validate_query_parameter_list() - formdata_errors = self.validate_formdata_parameter_list() + query_errors = self.validate_query_parameter_list(request) + formdata_errors = self.validate_formdata_parameter_list(request) if formdata_errors or query_errors: raise ExtraParameterProblem(formdata_errors, query_errors) for param in self.parameters.get('query', []): - error = self.validate_query_parameter(param) + error = self.validate_query_parameter(param, request) if error: - return problem(400, 'Bad Request', error) + response = problem(400, 'Bad Request', error) + return self.api.get_response(response) for param in self.parameters.get('path', []): - error = self.validate_path_parameter(kwargs, param) + error = self.validate_path_parameter(param, request) if error: - return problem(400, 'Bad Request', error) + response = problem(400, 'Bad Request', error) + return self.api.get_response(response) for param in self.parameters.get('header', []): - error = self.validate_header_parameter(param) + error = self.validate_header_parameter(param, request) if error: - return problem(400, 'Bad Request', error) + response = problem(400, 'Bad Request', error) + return self.api.get_response(response) for param in self.parameters.get('formData', []): - error = self.validate_formdata_parameter(param) + error = self.validate_formdata_parameter(param, request) if error: - return problem(400, 'Bad Request', error) + response = problem(400, 'Bad Request', error) + return self.api.get_response(response) - return function(*args, **kwargs) + return function(request) return wrapper diff --git a/connexion/flask_utils.py b/connexion/flask_utils.py new file mode 100644 index 000000000..f6e86e393 --- /dev/null +++ b/connexion/flask_utils.py @@ -0,0 +1,82 @@ +import functools +import random +import re +import string + + +import flask +import werkzeug.wrappers + +PATH_PARAMETER = re.compile(r'\{([^}]*)\}') + +# map Swagger type to flask path converter +# see http://flask.pocoo.org/docs/0.10/api/#url-route-registrations +PATH_PARAMETER_CONVERTERS = { + 'integer': 'int', + 'number': 'float', + 'path': 'path' +} + + +def flaskify_endpoint(identifier, randomize=None): + """ + Converts the provided identifier in a valid flask endpoint name + + :type identifier: str + :param randomize: If specified, add this many random characters (upper case + and digits) to the endpoint name, separated by a pipe character. + :type randomize: int | None + :rtype: str + """ + result = identifier.replace('.', '_') + if randomize is None: + return result + + chars = string.ascii_uppercase + string.digits + return "{result}|{random_string}".format( + result=result, + random_string=''.join(random.SystemRandom().choice(chars) for _ in range(randomize))) + + +def convert_path_parameter(match, types): + name = match.group(1) + swagger_type = types.get(name) + converter = PATH_PARAMETER_CONVERTERS.get(swagger_type) + return '<{0}{1}{2}>'.format(converter or '', + ':' if converter else '', + name.replace('-', '_')) + + +def flaskify_path(swagger_path, types=None): + """ + Convert swagger path templates to flask path templates + + :type swagger_path: str + :type types: dict + :rtype: str + + >>> flaskify_path('/foo-bar/{my-param}') + '/foo-bar/' + + >>> flaskify_path('/foo/{someint}', {'someint': 'int'}) + '/foo/' + """ + if types is None: + types = {} + convert_match = functools.partial(convert_path_parameter, types=types) + return PATH_PARAMETER.sub(convert_match, swagger_path) + + +def is_flask_response(obj): + """ + Verifies if obj is a default Flask response instance. + + :type obj: object + :rtype bool + + >>> is_flask_response(redirect('http://example.com/')) + True + >>> is_flask_response(flask.Response()) + True + """ + return isinstance(obj, flask.Response) or isinstance(obj, werkzeug.wrappers.Response) diff --git a/connexion/handlers.py b/connexion/handlers.py index e4838863f..87093b48f 100644 --- a/connexion/handlers.py +++ b/connexion/handlers.py @@ -11,7 +11,7 @@ class AuthErrorHandler(SecureOperation): Wraps an error with authentication. """ - def __init__(self, exception, security, security_definitions): + def __init__(self, api, exception, security, security_definitions): """ This class uses the exception instance to produce the proper response problem in case the request is authenticated. @@ -25,7 +25,7 @@ def __init__(self, exception, security, security_definitions): :type security_definitions: dict """ self.exception = exception - SecureOperation.__init__(self, security, security_definitions) + SecureOperation.__init__(self, api, security, security_definitions) @property def function(self): @@ -44,12 +44,12 @@ def handle(self, *args, **kwargs): """ Actual handler for the execution after authentication. """ - response_container = problem( + response = problem( title=self.exception.name, detail=self.exception.description, status=self.exception.code ) - return response_container + return self.api.get_response(response) class ResolverErrorHandler(Operation): @@ -57,19 +57,19 @@ class ResolverErrorHandler(Operation): Handler for responding to ResolverError. """ - def __init__(self, status_code, exception, *args, **kwargs): + def __init__(self, api, status_code, exception, *args, **kwargs): self.status_code = status_code self.exception = exception - Operation.__init__(self, *args, **kwargs) + Operation.__init__(self, api, *args, **kwargs) @property def function(self): return self.handle def handle(self, *args, **kwargs): - response_container = problem( + response = problem( title='Not Implemented', detail=self.exception.reason, status=self.status_code ) - return response_container.flask_response_object() + return self.api.get_response(response) diff --git a/connexion/operation.py b/connexion/operation.py index c68280f1f..3f2656934 100644 --- a/connexion/operation.py +++ b/connexion/operation.py @@ -9,14 +9,14 @@ EndOfRequestLifecycleDecorator) from .decorators.metrics import UWSGIMetricsCollector from .decorators.parameter import parameter_to_arg -from .decorators.produces import BaseSerializer, Jsonifier, Produces +from .decorators.produces import BaseSerializer, Produces from .decorators.response import ResponseValidator from .decorators.security import (get_tokeninfo_url, security_passthrough, verify_oauth) from .decorators.validation import (ParameterValidator, RequestBodyValidator, TypeValidationError) from .exceptions import InvalidSpecification -from .utils import all_json, flaskify_endpoint, is_nullable +from .utils import all_json, is_nullable logger = logging.getLogger('connexion.operation') @@ -32,7 +32,7 @@ class SecureOperation(object): - def __init__(self, security, security_definitions): + def __init__(self, api, security, security_definitions): """ :param security: list of security rules the application uses by default :type security: list @@ -40,6 +40,7 @@ def __init__(self, security, security_definitions): `_ :type security_definitions: dict """ + self.api = api self.security = security self.security_definitions = security_definitions @@ -103,24 +104,23 @@ def get_mimetype(self): def _request_begin_lifecycle_decorator(self): """ Transforms the result of the operation handler in a internal - representation (connexion.decorators.ResponseContainer) to be + representation (connexion.decorators.ConnexionRequest) to be used by internal Connexion decorators. :rtype: types.FunctionType """ - return BeginOfRequestLifecycleDecorator(self.get_mimetype()) + return BeginOfRequestLifecycleDecorator(self.api, self.get_mimetype()) @property def _request_end_lifecycle_decorator(self): """ Guarantees that instead of the internal representation of the operation handler response - (connexion.decorators.ResponseContainer) a flask.Response + (connexion.decorators.ConnexionRequest) a flask.Response object is returned. - :rtype: types.FunctionType """ - return EndOfRequestLifecycleDecorator() + return EndOfRequestLifecycleDecorator(self.api, self.get_mimetype()) class Operation(SecureOperation): @@ -129,7 +129,7 @@ class Operation(SecureOperation): A single API operation on a path. """ - def __init__(self, method, path, operation, resolver, app_produces, app_consumes, + def __init__(self, api, method, path, operation, resolver, app_produces, app_consumes, path_parameters=None, app_security=None, security_definitions=None, definitions=None, parameter_definitions=None, response_definitions=None, validate_responses=False, strict_validation=False, randomize_endpoint=None, @@ -182,6 +182,7 @@ def __init__(self, method, path, operation, resolver, app_produces, app_consumes :type pythonic_params: bool """ + self.api = api self.method = method self.path = path self.validator_map = dict(VALIDATOR_MAP) @@ -213,7 +214,6 @@ def __init__(self, method, path, operation, resolver, app_produces, app_consumes resolution = resolver.resolve(self) self.operation_id = resolution.operation_id - self.endpoint_name = flaskify_endpoint(self.operation_id, self.randomize_endpoint) self.__undecorated_function = resolution.function self.validate_defaults() @@ -391,6 +391,7 @@ def function(self): function = decorator(function) function = self._request_end_lifecycle_decorator(function) + return function @property @@ -415,12 +416,14 @@ def __content_type_decorator(self): mimetype = self.get_mimetype() if all_json(self.produces): # endpoint will return json logger.debug('... Produces json', extra=vars(self)) - jsonify = Jsonifier(mimetype) + jsonify = self.api.jsonifier(mimetype) return jsonify + elif len(self.produces) == 1: logger.debug('... Produces %s', mimetype, extra=vars(self)) decorator = Produces(mimetype) return decorator + else: return BaseSerializer() @@ -432,9 +435,11 @@ def __validation_decorators(self): ParameterValidator = self.validator_map['parameter'] RequestBodyValidator = self.validator_map['body'] if self.parameters: - yield ParameterValidator(self.parameters, strict_validation=self.strict_validation) + yield ParameterValidator(self.parameters, + self.api, + strict_validation=self.strict_validation) if self.body_schema: - yield RequestBodyValidator(self.body_schema, self.consumes, + yield RequestBodyValidator(self.body_schema, self.consumes, self.api, is_nullable(self.body_definition)) @property @@ -445,3 +450,11 @@ def __response_validation_decorator(self): """ ResponseValidator = self.validator_map['response'] return ResponseValidator(self, self.get_mimetype()) + + def json_loads(self, data): + """ + A Wrapper for calling the jsonifier. + :param data: The json to loads + :type data: bytes + """ + return self.api.jsonifier.loads(data) diff --git a/connexion/problem.py b/connexion/problem.py index 83f2f8c45..f5c83dd26 100644 --- a/connexion/problem.py +++ b/connexion/problem.py @@ -1,5 +1,4 @@ -from .decorators.decorator import ResponseContainer -from .decorators.produces import Jsonifier +from .response import ConnexionResponse def problem(status, title, detail, type=None, instance=None, headers=None, ext=None): @@ -31,7 +30,7 @@ def problem(status, title, detail, type=None, instance=None, headers=None, ext=N if not type: type = 'about:blank' - problem_response = {'type': type, 'title': title, 'detail': detail, 'status': status, } + problem_response = {'type': type, 'title': title, 'detail': detail, 'status': status} if instance: problem_response['instance'] = instance if ext: @@ -42,9 +41,7 @@ def problem(status, title, detail, type=None, instance=None, headers=None, ext=N # `decorators.produces.Jsonifier` will not be added to the request # life-cycle (so we cannot rely on that serialization), we will # return a problem payload in JSON format. - return ResponseContainer( - mimetype='application/problem+json', - data=Jsonifier.dumps(problem_response), - status_code=status, - headers=headers - ) + mimetype = content_type = 'application/problem+json' + return ConnexionResponse(status, mimetype, content_type, + body=problem_response, + headers=headers) diff --git a/connexion/request.py b/connexion/request.py new file mode 100644 index 000000000..fb4bd5bf1 --- /dev/null +++ b/connexion/request.py @@ -0,0 +1,24 @@ +from collections import namedtuple + + +_ConnexionRequest = namedtuple('SwaggerRequest', [ + 'url', 'method', 'path_params', 'query', 'headers', + 'form', 'body', 'json', 'files', 'context' +]) + + +class ConnexionRequest(_ConnexionRequest): + + def __new__(cls, url, method, path_params=None, query=None, headers=None, + form=None, body=None, json=None, files=None, context=None): + return _ConnexionRequest.__new__( + cls, url, method, + path_params=path_params or {}, + query=query or {}, + headers=headers or {}, + form=form or {}, + body=body, + json=json, + files=files, + context=context or {} + ) diff --git a/connexion/response.py b/connexion/response.py new file mode 100644 index 000000000..a3f947756 --- /dev/null +++ b/connexion/response.py @@ -0,0 +1,18 @@ +from collections import namedtuple + + +_ConnexionResponse = namedtuple('SwaggerRequest', [ + 'mimetype', 'content_type', 'status_code', 'body', 'headers' +]) + + +class ConnexionResponse(_ConnexionResponse): + + def __new__(cls, status_code=200, mimetype=None, content_type=None, body=None, headers=None): + return _ConnexionResponse.__new__( + cls, mimetype, + content_type, + status_code, + body=body, + headers=headers or {} + ) diff --git a/connexion/utils.py b/connexion/utils.py index f682b1b0e..6795c9c6d 100644 --- a/connexion/utils.py +++ b/connexion/utils.py @@ -1,85 +1,5 @@ import functools import importlib -import random -import re -import string - -import flask -import werkzeug.wrappers - -PATH_PARAMETER = re.compile(r'\{([^}]*)\}') - -# map Swagger type to flask path converter -# see http://flask.pocoo.org/docs/0.10/api/#url-route-registrations -PATH_PARAMETER_CONVERTERS = { - 'integer': 'int', - 'number': 'float', - 'path': 'path' -} - - -def flaskify_endpoint(identifier, randomize=None): - """ - Converts the provided identifier in a valid flask endpoint name - - :type identifier: str - :param randomize: If specified, add this many random characters (upper case - and digits) to the endpoint name, separated by a pipe character. - :type randomize: int | None - :rtype: str - """ - result = identifier.replace('.', '_') - if randomize is None: - return result - - chars = string.ascii_uppercase + string.digits - return "{result}|{random_string}".format( - result=result, - random_string=''.join(random.SystemRandom().choice(chars) for _ in range(randomize))) - - -def convert_path_parameter(match, types): - name = match.group(1) - swagger_type = types.get(name) - converter = PATH_PARAMETER_CONVERTERS.get(swagger_type) - return '<{0}{1}{2}>'.format(converter or '', - ':' if converter else '', - name.replace('-', '_')) - - -def flaskify_path(swagger_path, types=None): - """ - Convert swagger path templates to flask path templates - - :type swagger_path: str - :type types: dict - :rtype: str - - >>> flaskify_path('/foo-bar/{my-param}') - '/foo-bar/' - - >>> flaskify_path('/foo/{someint}', {'someint': 'int'}) - '/foo/' - """ - if types is None: - types = {} - convert_match = functools.partial(convert_path_parameter, types=types) - return PATH_PARAMETER.sub(convert_match, swagger_path) - - -def is_flask_response(obj): - """ - Verifies if obj is a default Flask response instance. - - :type obj: object - :rtype bool - - >>> is_flask_response(redirect('http://example.com/')) - True - >>> is_flask_response(flask.Response()) - True - """ - return isinstance(obj, flask.Response) or isinstance(obj, werkzeug.wrappers.Response) def deep_getattr(obj, attr): diff --git a/docs/cookbook.rst b/docs/cookbook.rst index 4b13027a8..8faadf1e9 100644 --- a/docs/cookbook.rst +++ b/docs/cookbook.rst @@ -80,7 +80,7 @@ to set CORS headers: import connexion from flask.ext.cors import CORS - app = connexion.App(__name__) + app = connexion.FlaskApp(__name__) app.add_api('swagger.yaml') # add CORS support diff --git a/docs/exceptions.rst b/docs/exceptions.rst index d8a883b7d..26ce9839f 100644 --- a/docs/exceptions.rst +++ b/docs/exceptions.rst @@ -42,7 +42,7 @@ exception and render it in some sort of custom format. For example def render_unauthorized(exception): return Response(response=json.dumps({'error': 'There is an error in the oAuth token supplied'}), status=401, mimetype="application/json") - app = connexion.App(__name__, specification_dir='./../swagger/', debug=False, swagger_ui=False) + app = connexion.FlaskApp(__name__, specification_dir='./../swagger/', debug=False, swagger_ui=False) app = app.add_error_handler(OAuthResponseProblem, render_unauthorized) Custom Exceptions diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 1f419067d..6893bcfe3 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -26,7 +26,7 @@ Put your API YAML inside a folder in the root path of your application (e.g ``sw import connexion - app = connexion.App(__name__, specification_dir='swagger/') + app = connexion.FlaskApp(__name__, specification_dir='swagger/') app.add_api('my_api.yaml') app.run(port=8080) @@ -41,7 +41,7 @@ for each specific API in the `connexion.App#add_api` method: .. code-block:: python - app = connexion.App(__name__, specification_dir='swagger/', + app = connexion.FlaskApp(__name__, specification_dir='swagger/', arguments={'global': 'global_value'}) app.add_api('my_api.yaml', arguments={'api_local': 'local_value'}) app.run(port=8080) @@ -58,7 +58,7 @@ You can disable the Swagger UI at the application level: .. code-block:: python - app = connexion.App(__name__, specification_dir='swagger/', + app = connexion.FlaskApp(__name__, specification_dir='swagger/', swagger_ui=False) app.add_api('my_api.yaml') @@ -67,7 +67,7 @@ You can also disable it at the API level: .. code-block:: python - app = connexion.App(__name__, specification_dir='swagger/') + app = connexion.FlaskApp(__name__, specification_dir='swagger/') app.add_api('my_api.yaml', swagger_ui=False) Server Backend @@ -79,7 +79,7 @@ to ``tornado`` or ``gevent``: import connexion - app = connexion.App(__name__, port = 8080, specification_dir='swagger/', server='tornado') + app = connexion.FlaskApp(__name__, port = 8080, specification_dir='swagger/', server='tornado') .. _Jinja2: http://jinja.pocoo.org/ diff --git a/docs/request.rst b/docs/request.rst index dc5d9b2f7..87bd45df6 100644 --- a/docs/request.rst +++ b/docs/request.rst @@ -164,6 +164,6 @@ change the validation, you can override the defaults with: 'body': CustomRequestBodyValidator, 'parameter': CustomParameterValidator } - app = connexion.App(__name__, ..., validator_map=validator_map) + app = connexion.FlaskApp(__name__, ..., validator_map=validator_map) See custom validator example in ``examples/enforcedefaults``. diff --git a/docs/response.rst b/docs/response.rst index 2d5550e0f..52d769ea6 100644 --- a/docs/response.rst +++ b/docs/response.rst @@ -59,7 +59,7 @@ do so by opting in when adding the API: import connexion - app = connexion.App(__name__, specification_dir='swagger/') + app = connexion.FlaskApp(__name__, specification_dir='swagger/') app.add_api('my_api.yaml', validate_responses=True) app.run(port=8080) @@ -79,7 +79,7 @@ the validation, you can override the default class with: validator_map = { 'response': CustomResponseValidator } - app = connexion.App(__name__, ..., validator_map=validator_map) + app = connexion.FlaskApp(__name__, ..., validator_map=validator_map) Error Handling diff --git a/docs/routing.rst b/docs/routing.rst index 752443194..4cf7cf2a1 100644 --- a/docs/routing.rst +++ b/docs/routing.rst @@ -42,7 +42,7 @@ the endpoints in your specification: from connexion.resolver import RestyResolver - app = connexion.App(__name__) + app = connexion.FlaskApp(__name__) app.add_api('swagger.yaml', resolver=RestyResolver('api')) .. code-block:: yaml @@ -146,7 +146,7 @@ You can disable the Swagger JSON at the application level: .. code-block:: python - app = connexion.App(__name__, specification_dir='swagger/', + app = connexion.FlaskApp(__name__, specification_dir='swagger/', swagger_json=False) app.add_api('my_api.yaml') @@ -154,5 +154,5 @@ You can also disable it at the API level: .. code-block:: python - app = connexion.App(__name__, specification_dir='swagger/') + app = connexion.FlaskApp(__name__, specification_dir='swagger/') app.add_api('my_api.yaml', swagger_json=False) diff --git a/examples/basicauth/app.py b/examples/basicauth/app.py index 351c03828..34ffb8ac0 100755 --- a/examples/basicauth/app.py +++ b/examples/basicauth/app.py @@ -46,6 +46,6 @@ def get_secret() -> str: return 'This is a very secret string requiring authentication!' if __name__ == '__main__': - app = connexion.App(__name__) + app = connexion.FlaskApp(__name__) app.add_api('swagger.yaml') app.run(port=8080) diff --git a/examples/enforcedefaults/enforcedefaults.py b/examples/enforcedefaults/enforcedefaults.py index b4a6d102e..dd2af66f4 100755 --- a/examples/enforcedefaults/enforcedefaults.py +++ b/examples/enforcedefaults/enforcedefaults.py @@ -42,7 +42,7 @@ def __init__(self, *args, **kwargs): if __name__ == '__main__': - app = connexion.App( + app = connexion.FlaskApp( __name__, 8080, specification_dir='.', validator_map=validator_map) app.add_api('enforcedefaults-api.yaml') app.run() diff --git a/examples/helloworld/hello.py b/examples/helloworld/hello.py index ca08aab5f..54f275e14 100755 --- a/examples/helloworld/hello.py +++ b/examples/helloworld/hello.py @@ -7,6 +7,6 @@ def post_greeting(name: str) -> str: return 'Hello {name}'.format(name=name) if __name__ == '__main__': - app = connexion.App(__name__, 9090, specification_dir='swagger/') + app = connexion.FlaskApp(__name__, 9090, specification_dir='swagger/') app.add_api('helloworld-api.yaml', arguments={'title': 'Hello World Example'}) app.run() diff --git a/examples/oauth2/app.py b/examples/oauth2/app.py index d515053a4..f53b5f560 100755 --- a/examples/oauth2/app.py +++ b/examples/oauth2/app.py @@ -12,6 +12,6 @@ def get_secret() -> str: return 'You are: {uid}'.format(uid=flask.request.user) if __name__ == '__main__': - app = connexion.App(__name__) + app = connexion.FlaskApp(__name__) app.add_api('app.yaml') app.run(port=8080) diff --git a/examples/oauth2/mock_tokeninfo.py b/examples/oauth2/mock_tokeninfo.py index aac216d01..f45384fc6 100755 --- a/examples/oauth2/mock_tokeninfo.py +++ b/examples/oauth2/mock_tokeninfo.py @@ -17,6 +17,6 @@ def get_tokeninfo(access_token: str) -> dict: return {'uid': uid, 'scope': ['uid']} if __name__ == '__main__': - app = connexion.App(__name__) + app = connexion.FlaskApp(__name__) app.add_api('mock_tokeninfo.yaml') app.run(port=7979) diff --git a/examples/restyresolver/resty.py b/examples/restyresolver/resty.py index 969e423af..ae0237b05 100755 --- a/examples/restyresolver/resty.py +++ b/examples/restyresolver/resty.py @@ -7,6 +7,8 @@ logging.basicConfig(level=logging.INFO) if __name__ == '__main__': - app = connexion.App(__name__) - app.add_api('resty-api.yaml', arguments={'title': 'RestyResolver Example'}, resolver=RestyResolver('api')) + app = connexion.FlaskApp(__name__) + app.add_api('resty-api.yaml', + arguments={'title': 'RestyResolver Example'}, + resolver=RestyResolver('api')) app.run(port=9090) diff --git a/examples/sqlalchemy/app.py b/examples/sqlalchemy/app.py index bedf9eb2d..fb6daca5a 100755 --- a/examples/sqlalchemy/app.py +++ b/examples/sqlalchemy/app.py @@ -48,7 +48,7 @@ def delete_pet(pet_id): logging.basicConfig(level=logging.INFO) db_session = orm.init_db('sqlite:///:memory:') -app = connexion.App(__name__) +app = connexion.FlaskApp(__name__) app.add_api('swagger.yaml') application = app.app diff --git a/setup.py b/setup.py index 3e5af3939..e99c83a38 100755 --- a/setup.py +++ b/setup.py @@ -26,7 +26,6 @@ def read_version(package): install_requires = [ 'clickclick>=1.2', - 'flask>=0.10.1', 'jsonschema>=2.5.1', 'PyYAML>=3.11', 'requests>=2.9.1', @@ -35,6 +34,8 @@ def read_version(package): 'inflection>=0.3.1' ] +flask_require = 'flask>=0.10.1' + if py_major_minor_version < (3, 4): install_requires.append('pathlib>=1.0.1') @@ -42,7 +43,8 @@ def read_version(package): 'decorator', 'mock', 'pytest', - 'pytest-cov' + 'pytest-cov', + flask_require ] @@ -89,7 +91,7 @@ def readme(): setup_requires=['flake8'], install_requires=install_requires, tests_require=tests_require, - extras_require={'tests': tests_require}, + extras_require={'tests': tests_require, 'flask': flask_require}, cmdclass={'test': PyTest}, test_suite='tests', classifiers=[ diff --git a/tests/api/test_bootstrap.py b/tests/api/test_bootstrap.py index 11ca4aae6..0aaa1da44 100644 --- a/tests/api/test_bootstrap.py +++ b/tests/api/test_bootstrap.py @@ -1,15 +1,16 @@ -import jinja2 import yaml +import jinja2 import pytest from conftest import TEST_FOLDER, build_app_from_fixture -from connexion.app import App +from connexion.apis import FlaskApi +from connexion.apps import FlaskApp from connexion.exceptions import InvalidSpecification def test_app_with_relative_path(simple_api_spec_dir): # Create the app with a realative path and run the test_app testcase below. - app = App(__name__, 5001, '..' / simple_api_spec_dir.relative_to(TEST_FOLDER), + app = FlaskApp(__name__, 5001, '..' / simple_api_spec_dir.relative_to(TEST_FOLDER), debug=True) app.add_api('swagger.yaml') @@ -20,14 +21,15 @@ def test_app_with_relative_path(simple_api_spec_dir): def test_no_swagger_ui(simple_api_spec_dir): - app = App(__name__, 5001, simple_api_spec_dir, swagger_ui=False, debug=True) + app = FlaskApp(__name__, 5001, simple_api_spec_dir, swagger_ui=False, debug=True) + # app = FlaskApp(__name__, 5001, simple_api_spec_dir, debug=True) app.add_api('swagger.yaml') app_client = app.app.test_client() swagger_ui = app_client.get('/v1.0/ui/') # type: flask.Response assert swagger_ui.status_code == 404 - app2 = App(__name__, 5001, simple_api_spec_dir, debug=True) + app2 = FlaskApp(__name__, 5001, simple_api_spec_dir, debug=True) app2.add_api('swagger.yaml', swagger_ui=False) app2_client = app2.app.test_client() swagger_ui2 = app2_client.get('/v1.0/ui/') # type: flask.Response @@ -36,7 +38,7 @@ def test_no_swagger_ui(simple_api_spec_dir): def test_swagger_json_app(simple_api_spec_dir): """ Verify the swagger.json file is returned for default setting passed to app. """ - app = App(__name__, 5001, simple_api_spec_dir, debug=True) + app = FlaskApp(__name__, 5001, simple_api_spec_dir, debug=True) app.add_api('swagger.yaml') app_client = app.app.test_client() @@ -46,7 +48,7 @@ def test_swagger_json_app(simple_api_spec_dir): def test_no_swagger_json_app(simple_api_spec_dir): """ Verify the swagger.json file is not returned when set to False when creating app. """ - app = App(__name__, 5001, simple_api_spec_dir, swagger_json=False, debug=True) + app = FlaskApp(__name__, 5001, simple_api_spec_dir, swagger_json=False, debug=True) app.add_api('swagger.yaml') app_client = app.app.test_client() @@ -68,7 +70,7 @@ def test_dict_as_yaml_path(simple_api_spec_dir): swagger_string = jinja2.Template(swagger_template).render({}) specification = yaml.safe_load(swagger_string) # type: dict - app = App(__name__, 5001, simple_api_spec_dir, debug=True) + app = FlaskApp(__name__, 5001, simple_api_spec_dir, debug=True) app.add_api(specification) app_client = app.app.test_client() @@ -78,7 +80,7 @@ def test_dict_as_yaml_path(simple_api_spec_dir): def test_swagger_json_api(simple_api_spec_dir): """ Verify the swagger.json file is returned for default setting passed to api. """ - app = App(__name__, 5001, simple_api_spec_dir, debug=True) + app = FlaskApp(__name__, 5001, simple_api_spec_dir, debug=True) app.add_api('swagger.yaml') app_client = app.app.test_client() @@ -88,7 +90,7 @@ def test_swagger_json_api(simple_api_spec_dir): def test_no_swagger_json_api(simple_api_spec_dir): """ Verify the swagger.json file is not returned when set to False when adding api. """ - app = App(__name__, 5001, simple_api_spec_dir, debug=True) + app = FlaskApp(__name__, 5001, simple_api_spec_dir, debug=True) app.add_api('swagger.yaml', swagger_json=False) app_client = app.app.test_client() @@ -133,7 +135,7 @@ def route2(): def test_resolve_method(simple_app): app_client = simple_app.app.test_client() resp = app_client.get('/v1.0/resolver-test/method') # type: flask.Response - assert resp.data.decode('utf-8', 'replace') == '"DummyClass"\n' + assert resp.data == b'"DummyClass"\n' def test_resolve_classmethod(simple_app): @@ -143,7 +145,7 @@ def test_resolve_classmethod(simple_app): def test_add_api_with_function_resolver_function_is_wrapped(simple_api_spec_dir): - app = App(__name__, specification_dir=simple_api_spec_dir) + app = FlaskApp(__name__, specification_dir=simple_api_spec_dir) api = app.add_api('swagger.yaml', resolver=lambda oid: (lambda foo: 'bar')) assert api.resolver.resolve_function_from_operation_id('faux')('bah') == 'bar' diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py index 5065ffaaa..1e871adab 100644 --- a/tests/api/test_errors.py +++ b/tests/api/test_errors.py @@ -1,4 +1,9 @@ import json +import flask + + +def fix_data(data): + return data.replace(b'\\"', b'"') def test_errors(problem_app): @@ -7,7 +12,7 @@ def test_errors(problem_app): greeting404 = app_client.get('/v1.0/greeting') # type: flask.Response assert greeting404.content_type == 'application/problem+json' assert greeting404.status_code == 404 - error404 = json.loads(greeting404.data.decode('utf-8', 'replace')) + error404 = flask.json.loads(fix_data(greeting404.data)) assert error404['type'] == 'about:blank' assert error404['title'] == 'Not Found' assert error404['detail'] == 'The requested URL was not found on the server. ' \ diff --git a/tests/api/test_responses.py b/tests/api/test_responses.py index 8cdbc7698..97c221bae 100644 --- a/tests/api/test_responses.py +++ b/tests/api/test_responses.py @@ -1,6 +1,7 @@ import json from struct import unpack -from connexion.decorators.produces import JSONEncoder + +from connexion.apps.flask_app import FlaskJSONEncoder def test_app(simple_app): @@ -113,11 +114,11 @@ def test_default_object_body(simple_app): def test_custom_encoder(simple_app): - class CustomEncoder(JSONEncoder): + class CustomEncoder(FlaskJSONEncoder): def default(self, o): if o.__class__.__name__ == 'DummyClass': return "cool result" - return JSONEncoder.default(self, o) + return FlaskJSONEncoder.default(self, o) flask_app = simple_app.app flask_app.json_encoder = CustomEncoder diff --git a/tests/api/test_secure_api.py b/tests/api/test_secure_api.py index 41dfc7663..1a8179595 100644 --- a/tests/api/test_secure_api.py +++ b/tests/api/test_secure_api.py @@ -1,10 +1,11 @@ import json -from connexion.app import App +from connexion.apis import FlaskApi +from connexion.apps import FlaskApp def test_security_over_inexistent_endpoints(oauth_requests, secure_api_spec_dir): - app1 = App(__name__, 5001, secure_api_spec_dir, swagger_ui=False, + app1 = FlaskApp(__name__, 5001, secure_api_spec_dir, swagger_ui=False, debug=True, auth_all_paths=True) app1.add_api('swagger.yaml') assert app1.port == 5001 diff --git a/tests/conftest.py b/tests/conftest.py index d1bbc7091..2641329a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,8 @@ import pathlib import pytest -from connexion.app import App +from connexion.apis import FlaskApi +from connexion.apps import FlaskApp logging.basicConfig(level=logging.DEBUG) @@ -55,7 +56,7 @@ def fake_get(url, params=None, timeout=None): @pytest.fixture def app(): - app = App(__name__, 5001, SPEC_FOLDER, debug=True) + app = FlaskApp(__name__, 5001, SPEC_FOLDER, debug=True) app.add_api('api.yaml', validate_responses=True) return app @@ -85,7 +86,7 @@ def build_app_from_fixture(api_spec_folder, **kwargs): if 'debug' in kwargs: debug = kwargs['debug'] del(kwargs['debug']) - app = App(__name__, 5001, FIXTURES_FOLDER / api_spec_folder, debug=debug) + app = FlaskApp(__name__, 5001, FIXTURES_FOLDER / api_spec_folder, debug=debug) app.add_api('swagger.yaml', **kwargs) return app diff --git a/tests/decorators/test_decorators.py b/tests/decorators/test_decorators.py deleted file mode 100644 index cdea5e9e3..000000000 --- a/tests/decorators/test_decorators.py +++ /dev/null @@ -1,15 +0,0 @@ -import flask -from connexion.decorators.decorator import ResponseContainer - - -def test_response_container_content_type(): - app = flask.Flask(__name__) - response = flask.Response(response='test response', content_type='text/plain') - container = ResponseContainer(mimetype='application/json', response=response) - - with app.app_context(): - headers = container.flask_response_object().headers - - content_types = [value for key, value in headers.items() if key == 'Content-Type'] - assert len(content_types) == 1 - assert content_types[0] == 'text/plain' diff --git a/tests/decorators/test_security.py b/tests/decorators/test_security.py index 02c1a2435..eaa58f579 100644 --- a/tests/decorators/test_security.py +++ b/tests/decorators/test_security.py @@ -29,8 +29,7 @@ def func(): request = MagicMock() app = MagicMock() - monkeypatch.setattr('connexion.decorators.security.request', request) monkeypatch.setattr('flask.current_app', app) - with pytest.raises(OAuthProblem): - wrapped_func() + with pytest.raises(OAuthProblem) as exc_info: + wrapped_func(MagicMock()) diff --git a/tests/fakeapi/hello.py b/tests/fakeapi/hello.py index faef0b9d5..ddc3d4c1f 100755 --- a/tests/fakeapi/hello.py +++ b/tests/fakeapi/hello.py @@ -1,8 +1,10 @@ #!/usr/bin/env python3 -from flask import redirect +import json -from connexion import NoContent, ProblemException, problem, request +from connexion import NoContent, ProblemException, problem +from connexion.apis import FlaskApi +from flask import redirect, request class DummyClass(object): @@ -32,7 +34,7 @@ def post(): return '' -def post_greeting(name): +def post_greeting(name, **kwargs): data = {'greeting': 'Hello {name}'.format(name=name)} return data @@ -59,11 +61,11 @@ def get_list(name): def get_bye(name): - return 'Goodbye {name}'.format(name=name), 200 + return 'Goodbye {name}'.format(name=name) -def get_bye_secure(name): - return 'Goodbye {name} (Secure: {user})'.format(name=name, user=request.user) +def get_bye_secure(name, user, token_info): + return 'Goodbye {name} (Secure: {user})'.format(name=name, user=user) def with_problem(): @@ -360,7 +362,7 @@ def unordered_params_response(first, path_param, second): return dict(first=int(first), path_param=str(path_param), second=int(second)) -def more_than_one_scope_defined(): +def more_than_one_scope_defined(**kwargs): return "OK" diff --git a/tests/test_api.py b/tests/test_api.py index 225d919fd..622574a4a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -7,7 +7,9 @@ from yaml import YAMLError import pytest -from connexion.api import Api, canonical_base_url +from connexion.apis import FlaskApi +from connexion.apis.abstract import canonical_base_url +from connexion.apis.flask_api import FlaskApi from connexion.exceptions import InvalidSpecification, ResolverError TEST_FOLDER = pathlib.Path(__file__).parent @@ -21,70 +23,70 @@ def test_canonical_base_url(): def test_api(): - api = Api(TEST_FOLDER / "fixtures/simple/swagger.yaml", "/api/v1.0", {}) + api = FlaskApi(TEST_FOLDER / "fixtures/simple/swagger.yaml", "/api/v1.0", {}) assert api.blueprint.name == '/api/v1_0' assert api.blueprint.url_prefix == '/api/v1.0' # TODO test base_url in spec - api2 = Api(TEST_FOLDER / "fixtures/simple/swagger.yaml") + api2 = FlaskApi(TEST_FOLDER / "fixtures/simple/swagger.yaml") assert api2.blueprint.name == '/v1_0' assert api2.blueprint.url_prefix == '/v1.0' def test_api_basepath_slash(): - api = Api(TEST_FOLDER / "fixtures/simple/basepath-slash.yaml", None, {}) + api = FlaskApi(TEST_FOLDER / "fixtures/simple/basepath-slash.yaml", None, {}) assert api.blueprint.name == '' assert api.blueprint.url_prefix == '' def test_template(): - api1 = Api(TEST_FOLDER / "fixtures/simple/swagger.yaml", "/api/v1.0", {'title': 'test'}) + api1 = FlaskApi(TEST_FOLDER / "fixtures/simple/swagger.yaml", "/api/v1.0", {'title': 'test'}) assert api1.specification['info']['title'] == 'test' - api2 = Api(TEST_FOLDER / "fixtures/simple/swagger.yaml", "/api/v1.0", {'title': 'other test'}) + api2 = FlaskApi(TEST_FOLDER / "fixtures/simple/swagger.yaml", "/api/v1.0", {'title': 'other test'}) assert api2.specification['info']['title'] == 'other test' def test_invalid_operation_does_stop_application_to_setup(): with pytest.raises(ImportError): - Api(TEST_FOLDER / "fixtures/op_error_api/swagger.yaml", "/api/v1.0", + FlaskApi(TEST_FOLDER / "fixtures/op_error_api/swagger.yaml", "/api/v1.0", {'title': 'OK'}) with pytest.raises(ResolverError): - Api(TEST_FOLDER / "fixtures/missing_op_id/swagger.yaml", "/api/v1.0", + FlaskApi(TEST_FOLDER / "fixtures/missing_op_id/swagger.yaml", "/api/v1.0", {'title': 'OK'}) with pytest.raises(ImportError): - Api(TEST_FOLDER / "fixtures/module_not_implemented/swagger.yaml", "/api/v1.0", + FlaskApi(TEST_FOLDER / "fixtures/module_not_implemented/swagger.yaml", "/api/v1.0", {'title': 'OK'}) with pytest.raises(ValueError): - Api(TEST_FOLDER / "fixtures/user_module_loading_error/swagger.yaml", "/api/v1.0", + FlaskApi(TEST_FOLDER / "fixtures/user_module_loading_error/swagger.yaml", "/api/v1.0", {'title': 'OK'}) with pytest.raises(ResolverError): - Api(TEST_FOLDER / "fixtures/missing_op_id/swagger.yaml", "/api/v1.0", + FlaskApi(TEST_FOLDER / "fixtures/missing_op_id/swagger.yaml", "/api/v1.0", {'title': 'OK'}) def test_invalid_operation_does_not_stop_application_in_debug_mode(): - api = Api(TEST_FOLDER / "fixtures/op_error_api/swagger.yaml", "/api/v1.0", + api = FlaskApi(TEST_FOLDER / "fixtures/op_error_api/swagger.yaml", "/api/v1.0", {'title': 'OK'}, debug=True) assert api.specification['info']['title'] == 'OK' - api = Api(TEST_FOLDER / "fixtures/missing_op_id/swagger.yaml", "/api/v1.0", + api = FlaskApi(TEST_FOLDER / "fixtures/missing_op_id/swagger.yaml", "/api/v1.0", {'title': 'OK'}, debug=True) assert api.specification['info']['title'] == 'OK' - api = Api(TEST_FOLDER / "fixtures/module_not_implemented/swagger.yaml", "/api/v1.0", + api = FlaskApi(TEST_FOLDER / "fixtures/module_not_implemented/swagger.yaml", "/api/v1.0", {'title': 'OK'}, debug=True) assert api.specification['info']['title'] == 'OK' - api = Api(TEST_FOLDER / "fixtures/user_module_loading_error/swagger.yaml", "/api/v1.0", + api = FlaskApi(TEST_FOLDER / "fixtures/user_module_loading_error/swagger.yaml", "/api/v1.0", {'title': 'OK'}, debug=True) assert api.specification['info']['title'] == 'OK' - api = Api(TEST_FOLDER / "fixtures/missing_op_id/swagger.yaml", "/api/v1.0", + api = FlaskApi(TEST_FOLDER / "fixtures/missing_op_id/swagger.yaml", "/api/v1.0", {'title': 'OK'}, debug=True) assert api.specification['info']['title'] == 'OK' @@ -93,18 +95,18 @@ def test_other_errors_stop_application_to_setup(): # The previous tests were just about operationId not being resolvable. # Other errors should still result exceptions! with pytest.raises(InvalidSpecification): - Api(TEST_FOLDER / "fixtures/bad_specs/swagger.yaml", "/api/v1.0", + FlaskApi(TEST_FOLDER / "fixtures/bad_specs/swagger.yaml", "/api/v1.0", {'title': 'OK'}) # Debug mode should ignore the error - api = Api(TEST_FOLDER / "fixtures/bad_specs/swagger.yaml", "/api/v1.0", + api = FlaskApi(TEST_FOLDER / "fixtures/bad_specs/swagger.yaml", "/api/v1.0", {'title': 'OK'}, debug=True) assert api.specification['info']['title'] == 'OK' def test_invalid_schema_file_structure(): with pytest.raises(SwaggerValidationError): - Api(TEST_FOLDER / "fixtures/invalid_schema/swagger.yaml", "/api/v1.0", + FlaskApi(TEST_FOLDER / "fixtures/invalid_schema/swagger.yaml", "/api/v1.0", {'title': 'OK'}, debug=True) @@ -112,7 +114,7 @@ def test_invalid_encoding(): with tempfile.NamedTemporaryFile(mode='wb') as f: f.write(u"swagger: '2.0'\ninfo:\n title: Foo 整\n version: v1\npaths: {}".encode('gbk')) f.flush() - Api(pathlib.Path(f.name), "/api/v1.0") + FlaskApi(pathlib.Path(f.name), "/api/v1.0") def test_use_of_safe_load_for_yaml_swagger_specs(): @@ -121,7 +123,7 @@ def test_use_of_safe_load_for_yaml_swagger_specs(): f.write('!!python/object:object {}\n'.encode()) f.flush() try: - Api(pathlib.Path(f.name), "/api/v1.0") + FlaskApi(pathlib.Path(f.name), "/api/v1.0") except SwaggerValidationError: pytest.fail("Could load invalid YAML file, use yaml.safe_load!") @@ -131,4 +133,4 @@ def test_validation_error_on_completely_invalid_swagger_spec(): with tempfile.NamedTemporaryFile() as f: f.write('[1]\n'.encode()) f.flush() - Api(pathlib.Path(f.name), "/api/v1.0") + FlaskApi(pathlib.Path(f.name), "/api/v1.0") diff --git a/tests/test_cli.py b/tests/test_cli.py index b5b743d1b..f7a52dabe 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,10 +12,10 @@ @pytest.fixture() def mock_app_run(monkeypatch): - test_server = MagicMock(wraps=connexion.App(__name__)) + test_server = MagicMock(wraps=connexion.FlaskApp(__name__)) test_server.run = MagicMock(return_value=True) test_app = MagicMock(return_value=test_server) - monkeypatch.setattr('connexion.cli.connexion.App', test_app) + monkeypatch.setattr('connexion.cli.connexion.FlaskApp', test_app) return test_app diff --git a/tests/test_produces.py b/tests/test_flask_encoder.py similarity index 54% rename from tests/test_produces.py rename to tests/test_flask_encoder.py index a90cd312a..fb143e3ea 100644 --- a/tests/test_produces.py +++ b/tests/test_flask_encoder.py @@ -4,23 +4,23 @@ from decimal import Decimal -from connexion.decorators.produces import JSONEncoder +from connexion.apps.flask_app import FlaskJSONEncoder def test_json_encoder(): - s = json.dumps({1: 2}, cls=JSONEncoder) + s = json.dumps({1: 2}, cls=FlaskJSONEncoder) assert '{"1": 2}' == s - s = json.dumps(datetime.date.today(), cls=JSONEncoder) + s = json.dumps(datetime.date.today(), cls=FlaskJSONEncoder) assert len(s) == 12 - s = json.dumps(datetime.datetime.utcnow(), cls=JSONEncoder) + s = json.dumps(datetime.datetime.utcnow(), cls=FlaskJSONEncoder) assert s.endswith('Z"') - s = json.dumps(Decimal(1.01), cls=JSONEncoder) + s = json.dumps(Decimal(1.01), cls=FlaskJSONEncoder) assert s == '1.01' - s = json.dumps(math.expm1(1e-10), cls=JSONEncoder) + s = json.dumps(math.expm1(1e-10), cls=FlaskJSONEncoder) assert s == '1.00000000005e-10' @@ -34,5 +34,5 @@ def utcoffset(self, dt): def dst(self, dt): return datetime.timedelta(0) - s = json.dumps(datetime.datetime.now(DummyTimezone()), cls=JSONEncoder) + s = json.dumps(datetime.datetime.now(DummyTimezone()), cls=FlaskJSONEncoder) assert s.endswith('+00:00"') diff --git a/tests/test_flask_utils.py b/tests/test_flask_utils.py new file mode 100644 index 000000000..7adcbfb4c --- /dev/null +++ b/tests/test_flask_utils.py @@ -0,0 +1,28 @@ +import math + +import connexion.apps +import connexion.flask_utils as flask_utils +import connexion.utils as utils +import pytest +from mock import MagicMock + + +def test_flaskify_path(): + assert flask_utils.flaskify_path("{test-path}") == "" + assert flask_utils.flaskify_path("api/{test-path}") == "api/" + assert flask_utils.flaskify_path("my-api/{test-path}") == "my-api/" + assert flask_utils.flaskify_path("foo_bar/{a-b}/{c_d}") == "foo_bar//" + assert flask_utils.flaskify_path("foo/{a}/{b}", {'a': 'integer'}) == "foo//" + assert flask_utils.flaskify_path("foo/{a}/{b}", {'a': 'number'}) == "foo//" + assert flask_utils.flaskify_path("foo/{a}/{b}", {'a': 'path'}) == "foo//" + + +def test_flaskify_endpoint(): + assert flask_utils.flaskify_endpoint("module.function") == "module_function" + assert flask_utils.flaskify_endpoint("function") == "function" + + name = 'module.function' + randlen = 6 + res = flask_utils.flaskify_endpoint(name, randlen) + assert res.startswith('module_function') + assert len(res) == len(name) + 1 + randlen diff --git a/tests/test_metrics.py b/tests/test_metrics.py index d2785fd5a..0c51ae271 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,4 +1,8 @@ +import json + import connexion +import flask +from connexion.apis import FlaskApi from connexion.decorators.metrics import UWSGIMetricsCollector from mock import MagicMock @@ -6,13 +10,14 @@ def test_timer(monkeypatch): wrapper = UWSGIMetricsCollector('/foo/bar/', 'get') - def operation(): + def operation(req): return connexion.problem(418, '', '') op = wrapper(operation) metrics = MagicMock() monkeypatch.setattr('flask.request', MagicMock()) + monkeypatch.setattr('flask.current_app', MagicMock(response_class=flask.Response)) monkeypatch.setattr('connexion.decorators.metrics.uwsgi_metrics', metrics) - op() + op(MagicMock()) assert metrics.timer.call_args[0][:2] == ('connexion.response', '418.GET.foo.bar.{param}') diff --git a/tests/test_mock.py b/tests/test_mock.py index 2d7e4cbee..e8f26189c 100644 --- a/tests/test_mock.py +++ b/tests/test_mock.py @@ -24,7 +24,8 @@ def test_mock_resolver(): } } - operation = Operation(method='GET', + operation = Operation(api=None, + method='GET', path='endpoint', path_parameters=[], operation={ @@ -51,7 +52,8 @@ def test_mock_resolver_no_examples(): '418': {} } - operation = Operation(method='GET', + operation = Operation(api=None, + method='GET', path='endpoint', path_parameters=[], operation={ @@ -74,7 +76,8 @@ def test_mock_resolver_no_examples(): def test_mock_resolver_notimplemented(): resolver = MockResolver(mock_all=False) - operation = Operation(method='GET', + operation = Operation(api=None, + method='GET', path='endpoint', path_parameters=[], operation={ diff --git a/tests/test_operation.py b/tests/test_operation.py index 72a7c6372..04d11ffb8 100644 --- a/tests/test_operation.py +++ b/tests/test_operation.py @@ -1,7 +1,10 @@ import pathlib import types +import mock import pytest + +from connexion.apis.flask_api import Jsonifier from connexion.decorators.security import security_passthrough, verify_oauth from connexion.exceptions import InvalidSpecification from connexion.operation import Operation @@ -234,8 +237,14 @@ 'scopes': {'myscope': 'can do stuff'}}} -def test_operation(): - operation = Operation(method='GET', +@pytest.fixture +def api(): + return mock.MagicMock(jsonifier=Jsonifier) + + +def test_operation(api): + operation = Operation(api=api, + method='GET', path='endpoint', path_parameters=[], operation=OPERATION1, @@ -264,8 +273,9 @@ def test_operation(): assert operation.body_schema == expected_body_schema -def test_operation_array(): - operation = Operation(method='GET', +def test_operation_array(api): + operation = Operation(api=api, + method='GET', path='endpoint', path_parameters=[], operation=OPERATION9, @@ -294,8 +304,9 @@ def test_operation_array(): assert operation.body_schema == expected_body_schema -def test_operation_composed_definition(): - operation = Operation(method='GET', +def test_operation_composed_definition(api): + operation = Operation(api=api, + method='GET', path='endpoint', path_parameters=[], operation=OPERATION10, @@ -323,9 +334,10 @@ def test_operation_composed_definition(): assert operation.body_schema == expected_body_schema -def test_non_existent_reference(): +def test_non_existent_reference(api): with pytest.raises(InvalidSpecification) as exc_info: # type: py.code.ExceptionInfo - operation = Operation(method='GET', + operation = Operation(api=api, + method='GET', path='endpoint', path_parameters=[], operation=OPERATION1, @@ -343,9 +355,10 @@ def test_non_existent_reference(): assert repr(exception) == "" -def test_multi_body(): +def test_multi_body(api): with pytest.raises(InvalidSpecification) as exc_info: # type: py.code.ExceptionInfo - operation = Operation(method='GET', + operation = Operation(api=api, + method='GET', path='endpoint', path_parameters=[], operation=OPERATION2, @@ -363,9 +376,10 @@ def test_multi_body(): assert repr(exception) == "" -def test_invalid_reference(): +def test_invalid_reference(api): with pytest.raises(InvalidSpecification) as exc_info: # type: py.code.ExceptionInfo - operation = Operation(method='GET', + operation = Operation(api=api, + method='GET', path='endpoint', path_parameters=[], operation=OPERATION3, @@ -383,8 +397,9 @@ def test_invalid_reference(): assert repr(exception).startswith("" - assert utils.flaskify_path("api/{test-path}") == "api/" - assert utils.flaskify_path("my-api/{test-path}") == "my-api/" - assert utils.flaskify_path("foo_bar/{a-b}/{c_d}") == "foo_bar//" - assert utils.flaskify_path("foo/{a}/{b}", {'a': 'integer'}) == "foo//" - assert utils.flaskify_path("foo/{a}/{b}", {'a': 'number'}) == "foo//" - assert utils.flaskify_path("foo/{a}/{b}", {'a': 'path'}) == "foo//" - -def test_flaskify_endpoint(): - assert utils.flaskify_endpoint("module.function") == "module_function" - assert utils.flaskify_endpoint("function") == "function" - - name = 'module.function' - randlen = 6 - res = utils.flaskify_endpoint(name, randlen) - assert res.startswith('module_function') - assert len(res) == len(name) + 1 + randlen +from mock import MagicMock def test_get_function_from_name(): @@ -47,8 +28,8 @@ def test_get_function_from_name_attr_error(monkeypatch): def test_get_function_from_name_for_class_method(): - function = utils.get_function_from_name('connexion.app.App.common_error_handler') - assert function == connexion.app.App.common_error_handler + function = utils.get_function_from_name('connexion.apps.FlaskApp.common_error_handler') + assert function == connexion.apps.FlaskApp.common_error_handler def test_boolean(): diff --git a/tests/test_validation.py b/tests/test_validation.py index c19fa1e02..8e97e3454 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,10 +1,13 @@ import json +import flask +from connexion.apis.flask_api import FlaskApi from connexion.decorators.validation import ParameterValidator # we are using "mock" module here for Py 2.7 support from mock import MagicMock + def test_parameter_validator(monkeypatch): request = MagicMock(name='request') request.args = {} @@ -12,13 +15,7 @@ def test_parameter_validator(monkeypatch): request.params = {} app = MagicMock(name='app') - def _response_class(data, mimetype, content_type, headers): - response = MagicMock(name='response') - response.detail = json.loads(''.join(data))['detail'] - response.headers = MagicMock() - return response - - app.response_class = _response_class + app.response_class = flask.Response monkeypatch.setattr('flask.request', request) monkeypatch.setattr('flask.current_app', app) @@ -30,34 +27,39 @@ def orig_handler(*args, **kwargs): {'name': 'q1', 'in': 'query', 'type': 'integer', 'maximum': 3}, {'name': 'a1', 'in': 'query', 'type': 'array', 'minItems': 2, 'maxItems': 3, 'items': {'type': 'integer', 'minimum': 0}}] - validator = ParameterValidator(params) + validator = ParameterValidator(params, FlaskApi) handler = validator(orig_handler) - assert handler().flask_response_object().detail == "Missing path parameter 'p1'" - assert handler(p1='123') == 'OK' - assert handler(p1='').flask_response_object().detail == "Wrong type, expected 'integer' for path parameter 'p1'" - assert handler(p1='foo').flask_response_object().detail == "Wrong type, expected 'integer' for path parameter 'p1'" - assert handler(p1='1.2').flask_response_object().detail == "Wrong type, expected 'integer' for path parameter 'p1'" - - request.args = {'q1': '4'} - assert handler(p1=1).flask_response_object().detail.startswith('4 is greater than the maximum of 3') - request.args = {'q1': '3'} - assert handler(p1=1) == 'OK' - - request.args = {'a1': "1,2"} - assert handler(p1=1) == "OK" - request.args = {'a1': "1,a"} - assert handler(p1=1).flask_response_object().detail.startswith("'a' is not of type 'integer'") - request.args = {'a1': "1,-1"} - assert handler(p1=1).flask_response_object().detail.startswith("-1 is less than the minimum of 0") - request.args = {'a1': "1"} - assert handler(p1=1).flask_response_object().detail.startswith("[1] is too short") - request.args = {'a1': "1,2,3,4"} - assert handler(p1=1).flask_response_object().detail.startswith("[1, 2, 3, 4] is too long") - del request.args['a1'] - - request.headers = {'h1': 'a'} - assert handler(p1='123') == 'OK' - - request.headers = {'h1': 'x'} - assert handler(p1='123').flask_response_object().detail.startswith("'x' is not one of ['a', 'b']") + kwargs = {'query': {}, 'headers': {}} + request = MagicMock(path_params={}, **kwargs) + assert json.loads(handler(request).data.decode())['detail'] == "Missing path parameter 'p1'" + request = MagicMock(path_params={'p1': '123'}, **kwargs) + assert handler(request) == 'OK' + request = MagicMock(path_params={'p1': ''}, **kwargs) + assert json.loads(handler(request).data.decode())['detail'] == "Wrong type, expected 'integer' for path parameter 'p1'" + request = MagicMock(path_params={'p1': 'foo'}, **kwargs) + assert json.loads(handler(request).data.decode())['detail'] == "Wrong type, expected 'integer' for path parameter 'p1'" + request = MagicMock(path_params={'p1': '1.2'}, **kwargs) + assert json.loads(handler(request).data.decode())['detail'] == "Wrong type, expected 'integer' for path parameter 'p1'" + + request = MagicMock(path_params={'p1': 1}, query={'q1': '4'}, headers={}) + assert json.loads(handler(request).data.decode())['detail'].startswith('4 is greater than the maximum of 3') + request = MagicMock(path_params={'p1': 1}, query={'q1': '3'}, headers={}) + assert handler(request) == 'OK' + + request = MagicMock(path_params={'p1': 1}, query={'a1': "1,2"}, headers={}) + assert handler(request) == "OK" + request = MagicMock(path_params={'p1': 1}, query={'a1': "1,a"}, headers={}) + assert json.loads(handler(request).data.decode())['detail'].startswith("'a' is not of type 'integer'") + request = MagicMock(path_params={'p1': 1}, query={'a1': "1,-1"}, headers={}) + assert json.loads(handler(request).data.decode())['detail'].startswith("-1 is less than the minimum of 0") + request = MagicMock(path_params={'p1': 1}, query={'a1': "1"}, headers={}) + assert json.loads(handler(request).data.decode())['detail'].startswith("[1] is too short") + request = MagicMock(path_params={'p1': 1}, query={'a1': "1,2,3,4"}, headers={}) + assert json.loads(handler(request).data.decode())['detail'].startswith("[1, 2, 3, 4] is too long") + + request = MagicMock(path_params={'p1': '123'}, query={}, headers={'h1': 'a'}) + assert handler(request) == 'OK' + + request = MagicMock(path_params={'p1': '123'}, query={}, headers={'h1': 'x'}) + assert json.loads(handler(request).data.decode())['detail'].startswith("'x' is not one of ['a', 'b']") diff --git a/tests/util.py b/tests/util.py index 805ddf52f..0bc9393b5 100644 --- a/tests/util.py +++ b/tests/util.py @@ -3,7 +3,8 @@ import pathlib import pytest -from connexion.app import App +from connexion.apis import FlaskApi +from connexion.apps import FlaskApp logging.basicConfig(level=logging.DEBUG) @@ -53,7 +54,7 @@ def fake_get(url, params=None, timeout=None): @pytest.fixture def app(): - app = App(__name__, 5001, SPEC_FOLDER, debug=True) + app = FlaskApp(__name__, 5001, SPEC_FOLDER, debug=True) app.add_api('api.yaml', validate_responses=True) return app @@ -70,12 +71,12 @@ def problem_api_spec_dir(): @pytest.fixture def simple_app(simple_api_spec_dir): - app = App(__name__, 5001, simple_api_spec_dir, debug=True) + app = FlaskApp(__name__, 5001, simple_api_spec_dir, debug=True) app.add_api('swagger.yaml', validate_responses=True) return app @pytest.fixture def problem_app(problem_api_spec_dir): - app = App(__name__, 5001, problem_api_spec_dir, debug=True) + app = FlaskApp(__name__, 5001, problem_api_spec_dir, debug=True) app.add_api('swagger.yaml', validate_responses=True)