Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract security to middleware #1514

Merged
merged 12 commits into from
Apr 27, 2022
8 changes: 0 additions & 8 deletions connexion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,16 @@
specified.
"""

import sys

import werkzeug.exceptions as exceptions # NOQA

from .apis import AbstractAPI # NOQA
from .apps import AbstractApp # NOQA
from .decorators.produces import NoContent # NOQA
from .exceptions import ProblemException # NOQA
# add operation for backwards compatibility
from .operations import compat
from .problem import problem # NOQA
from .resolver import Resolution, Resolver, RestyResolver # NOQA
from .utils import not_installed_error # NOQA

full_name = f'{__package__}.operation'
sys.modules[full_name] = sys.modules[compat.__name__]
Ruwann marked this conversation as resolved.
Show resolved Hide resolved


try:
from flask import request # NOQA

Expand Down
32 changes: 3 additions & 29 deletions connexion/apis/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,22 +140,13 @@ def __init__(
self.debug = debug
self.resolver_error_handler = resolver_error_handler

logger.debug('Security Definitions: %s', self.specification.security_definitions)

self.resolver = resolver or Resolver()

logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
self.pass_context_arg_name = pass_context_arg_name

self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name)

self.add_paths()

@staticmethod
@abc.abstractmethod
def make_security_handler_factory(pass_context_arg_name):
""" Create SecurityHandlerFactory to create all security check handlers """

def add_paths(self, paths: t.Optional[dict] = None) -> None:
"""
Adds the paths defined in the specification as endpoints
Expand Down Expand Up @@ -196,8 +187,6 @@ def _add_resolver_error_handler(self, method: str, path: str, err: ResolverError
"""
operation = self.resolver_error_handler(
err,
security=self.specification.security,
security_definitions=self.specification.security_definitions
)
self._add_operation_internal(method, path, operation)

Expand All @@ -221,13 +210,11 @@ class AbstractAPI(AbstractMinimalAPI, metaclass=AbstractAPIMeta):

def __init__(self, specification, base_path=None, arguments=None,
validate_responses=False, strict_validation=False, resolver=None,
auth_all_paths=False, debug=False, resolver_error_handler=None,
validator_map=None, pythonic_params=False, pass_context_arg_name=None, options=None,
):
debug=False, resolver_error_handler=None, validator_map=None,
pythonic_params=False, pass_context_arg_name=None, options=None, **kwargs):
"""
:type validate_responses: bool
:type strict_validation: bool
:type auth_all_paths: bool
:param validator_map: Custom validators for the types "parameter", "body" and "response".
:type validator_map: dict
:type resolver_error_handler: callable | None
Expand All @@ -247,22 +234,9 @@ def __init__(self, specification, base_path=None, arguments=None,
self.pythonic_params = pythonic_params

super().__init__(specification, base_path=base_path, arguments=arguments,
resolver=resolver, auth_all_paths=auth_all_paths,
resolver_error_handler=resolver_error_handler,
resolver=resolver, resolver_error_handler=resolver_error_handler,
debug=debug, pass_context_arg_name=pass_context_arg_name, options=options)

if auth_all_paths:
self.add_auth_on_not_found(
self.specification.security,
self.specification.security_definitions
)

@abc.abstractmethod
def add_auth_on_not_found(self, security, security_definitions):
"""
Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass.
"""

def add_operation(self, path, method):
"""
Adds one operation to the api.
Expand Down
22 changes: 2 additions & 20 deletions connexion/apis/flask_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,19 @@
from typing import Any

import flask
import werkzeug.exceptions
from werkzeug.local import LocalProxy

from connexion.apis import flask_utils
from connexion.apis.abstract import AbstractAPI
from connexion.handlers import AuthErrorHandler
from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.security import FlaskSecurityHandlerFactory
from connexion.utils import is_json_mimetype

logger = logging.getLogger('connexion.apis.flask_api')


class FlaskApi(AbstractAPI):

@staticmethod
def make_security_handler_factory(pass_context_arg_name):
""" Create default SecurityHandlerFactory to create all security check handlers """
return FlaskSecurityHandlerFactory(pass_context_arg_name)

def _set_base_path(self, base_path):
super()._set_base_path(base_path)
self._set_blueprint()
Expand All @@ -39,16 +31,6 @@ def _set_blueprint(self):
self.blueprint = flask.Blueprint(endpoint, __name__, url_prefix=self.base_path,
template_folder=str(self.options.openapi_console_ui_from_dir))

def add_auth_on_not_found(self, security, security_definitions):
RobbeSneyders marked this conversation as resolved.
Show resolved Hide resolved
"""
Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass.
"""
logger.debug('Adding path not found authentication')
not_found_error = AuthErrorHandler(self, werkzeug.exceptions.NotFound(), security=security,
security_definitions=security_definitions)
endpoint_name = f"{self.blueprint.name}_not_found"
self.blueprint.add_url_rule('/<path:invalid_path>', endpoint_name, not_found_error.function)

def _add_operation_internal(self, method, path, operation):
operation_id = operation.operation_id
logger.debug('... Adding %s -> %s', method.upper(), operation_id,
Expand Down Expand Up @@ -156,9 +138,9 @@ def get_request(cls, *args, **params):

:rtype: ConnexionRequest
"""
context_dict = {}
setattr(flask._request_ctx_stack.top, 'connexion_context', context_dict)
flask_request = flask.request
context_dict = flask_request.environ['asgi.scope'].get('context', {})
Ruwann marked this conversation as resolved.
Show resolved Hide resolved
setattr(flask._request_ctx_stack.top, 'connexion_context', context_dict)
request = ConnexionRequest(
flask_request.url,
flask_request.method,
Expand Down
2 changes: 1 addition & 1 deletion connexion/apps/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def add_api(self, specification, base_path=None, arguments=None,

def _resolver_error_handler(self, *args, **kwargs):
from connexion.handlers import ResolverErrorHandler
return ResolverErrorHandler(self.api_cls, self.resolver_error, *args, **kwargs)
return ResolverErrorHandler(self.resolver_error, *args, **kwargs)

def add_url_rule(self, rule, endpoint=None, view_func=None, **options):
"""
Expand Down
29 changes: 28 additions & 1 deletion connexion/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings

from jsonschema.exceptions import ValidationError
from werkzeug.exceptions import Forbidden, Unauthorized
from starlette.exceptions import HTTPException

from .problem import problem

Expand Down Expand Up @@ -123,6 +123,20 @@ def __init__(self, message, reason="Response headers do not conform to specifica
super().__init__(reason=reason, message=message)


class Unauthorized(HTTPException):

description = (
"The server could not verify that you are authorized to access"
" the URL requested. You either supplied the wrong credentials"
" (e.g. a bad password), or your browser doesn't understand"
" how to supply the credentials required."
)

def __init__(self, **kwargs):
kwargs.setdefault('detail', self.description)
RobbeSneyders marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(401, **kwargs)


class OAuthProblem(Unauthorized):
pass

Expand All @@ -133,6 +147,19 @@ def __init__(self, token_response, **kwargs):
super().__init__(**kwargs)


class Forbidden(HTTPException):

description = (
"You don't have the permission to access the requested"
" resource. It is either read-protected or not readable by the"
" server."
)

def __init__(self, **kwargs):
kwargs.setdefault('detail', self.description)
super().__init__(403, **kwargs)


class OAuthScopeProblem(Forbidden):
def __init__(self, token_scopes, required_scopes, **kwargs):
self.required_scopes = required_scopes
Expand Down
52 changes: 3 additions & 49 deletions connexion/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,67 +4,21 @@

import logging

from .exceptions import AuthenticationProblem, ResolverProblem
from .operations.secure import SecureOperation
from .exceptions import ResolverProblem

logger = logging.getLogger('connexion.handlers')

RESOLVER_ERROR_ENDPOINT_RANDOM_DIGITS = 6


class AuthErrorHandler(SecureOperation):
"""
Wraps an error with authentication.
"""

def __init__(self, api, exception, security, security_definitions):
"""
This class uses the exception instance to produce the proper response problem in case the
request is authenticated.

:param exception: the exception to be wrapped with authentication
:type exception: werkzeug.exceptions.HTTPException
:param security: list of security rules the application uses by default
:type security: list
:param security_definitions: `Security Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#security-definitions-object>`_
:type security_definitions: dict
"""
self.exception = exception
super().__init__(api, security, security_definitions)

@property
def function(self):
"""
Configured error auth handler.
"""
security_decorator = self.security_decorator
logger.debug('... Adding security decorator (%r)', security_decorator, extra=vars(self))
function = self.handle
function = security_decorator(function)
function = self._request_response_decorator(function)
return function

def handle(self, *args, **kwargs):
"""
Actual handler for the execution after authentication.
"""
raise AuthenticationProblem(
title=self.exception.name,
detail=self.exception.description,
status=self.exception.code
)


class ResolverErrorHandler(SecureOperation):
class ResolverErrorHandler:
"""
Handler for responding to ResolverError.
"""

def __init__(self, api, status_code, exception, security, security_definitions):
def __init__(self, status_code, exception):
self.status_code = status_code
self.exception = exception
super().__init__(api, security, security_definitions)

@property
def function(self):
Expand Down
27 changes: 26 additions & 1 deletion connexion/middleware/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,38 @@
from starlette.requests import Request
from starlette.responses import Response

from connexion.exceptions import problem
from connexion.exceptions import ProblemException, problem


class ExceptionMiddleware(StarletteExceptionMiddleware):
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
existing connexion behavior."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.add_exception_handler(ProblemException, self.problem_handler)

def problem_handler(self, _, exception: ProblemException):
"""
:type exception: Exception
"""
connexion_response = problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=exception.type,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext
)

return Response(
RobbeSneyders marked this conversation as resolved.
Show resolved Hide resolved
content=json.dumps(connexion_response.body),
status_code=connexion_response.status_code,
media_type=connexion_response.mimetype,
headers=connexion_response.headers
)
RobbeSneyders marked this conversation as resolved.
Show resolved Hide resolved

def http_exception(self, request: Request, exc: HTTPException) -> Response:
try:
headers = exc.headers
Expand Down
2 changes: 2 additions & 0 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from connexion.middleware.abstract import AppMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.routing import RoutingMiddleware
from connexion.middleware.security import SecurityMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware


Expand All @@ -15,6 +16,7 @@ class ConnexionMiddleware:
ExceptionMiddleware,
SwaggerUIMiddleware,
RoutingMiddleware,
SecurityMiddleware,
]

def __init__(
Expand Down
5 changes: 0 additions & 5 deletions connexion/middleware/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,3 @@ def patch_operation_function():

def _add_operation_internal(self, method: str, path: str, operation: AbstractOperation) -> None:
self.router.add_route(path, operation.function, methods=[method])

@staticmethod
def make_security_handler_factory(pass_context_arg_name):
""" Create default SecurityHandlerFactory to create all security check handlers """
pass
Loading