Skip to content

Commit

Permalink
Merge pull request #101 from iterait/dev
Browse files Browse the repository at this point in the history
Release 0.9.4
  • Loading branch information
Jan Buchar authored Oct 7, 2019
2 parents c9834c7 + f058eab commit fbe164e
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 47 deletions.
2 changes: 1 addition & 1 deletion apistrap/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.3"
__version__ = "0.9.4"
42 changes: 38 additions & 4 deletions apistrap/aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import inspect
import json
import logging
Expand All @@ -6,7 +8,7 @@
from functools import wraps
from os import path
from pathlib import Path
from typing import Any, Callable, Coroutine, Generator, List, Optional, Sequence, Tuple, Type
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, Tuple, Type

import jinja2
from aiohttp import StreamReader, web
Expand All @@ -16,17 +18,22 @@
from aiohttp.web_urldispatcher import AbstractRoute, DynamicResource, PlainResource

from apistrap.errors import ApiClientError
from apistrap.extension import Apistrap, ErrorHandler
from apistrap.extension import Apistrap, ErrorHandler, SecurityScheme
from apistrap.operation_wrapper import OperationWrapper
from apistrap.schemas import ErrorResponse
from apistrap.types import FileResponse
from apistrap.utils import format_exception, resolve_fw_decl

SecurityEnforcer = Callable[[BaseRequest, Sequence[str]], None]


class AioHTTPOperationWrapper(OperationWrapper):
def __init__(self, extension: Apistrap, function: Callable, decorators: Sequence[object], route: AbstractRoute):
def __init__(
self, extension: AioHTTPApistrap, function: Callable, decorators: Sequence[object], route: AbstractRoute
):
self.route = route
super().__init__(extension, function, decorators)
self._extension = extension

def process_metadata(self):
super().process_metadata()
Expand Down Expand Up @@ -58,10 +65,24 @@ def _get_aiohttp_request_param_name(self) -> Optional[str]:

return None

def _enforce_security(self, request):
error = None

for security_scheme, required_scopes in self._get_required_scopes():
try:
# If any enforcer passes without throwing, the user is authenticated
self._extension.security_enforcers[security_scheme](request, required_scopes)
return
except Exception as e:
error = e
else:
if error is not None:
raise error

def get_decorated_function(self):
@wraps(self._wrapped_function)
async def wrapper(request: Request):
self._check_security()
self._enforce_security(request)

kwargs = {}

Expand Down Expand Up @@ -235,6 +256,7 @@ def __init__(self):
super().__init__()
self.app: web.Application = None
self.error_middleware = ErrorHandlerMiddleware(self)
self.security_enforcers: Dict[SecurityScheme, SecurityEnforcer] = {}
self._jinja_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(path.join(path.dirname(__file__), "templates"))
)
Expand Down Expand Up @@ -266,6 +288,18 @@ def init_app(self, app: web.Application) -> None:
for redoc_url in (self.redoc_url, self.redoc_url + "/"):
app.router.add_route("get", redoc_url, self._get_redoc)

def add_security_scheme(self, scheme: SecurityScheme, enforcer: SecurityEnforcer):
"""
Add a security scheme to be used by the API.
:param scheme: a description of the security scheme
:param enforcer: a function that checks the requirements of the security scheme
"""

self.security_schemes.append(scheme)
self.spec.components.security_scheme(scheme.name, scheme.to_openapi_dict())
self.security_enforcers[scheme] = enforcer

def _handle_server_error(self, exception):
"""
Default handler for server errors (500-599).
Expand Down
8 changes: 7 additions & 1 deletion apistrap/decorators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence, Type, Union
from typing import TYPE_CHECKING, Optional, Sequence, Type, Union

from schematics import Model

from apistrap.tags import TagData

if TYPE_CHECKING:
from apistrap.extension import SecurityScheme


class IgnoreDecorator:
"""
Expand Down Expand Up @@ -77,3 +82,4 @@ class SecurityDecorator:
"""

scopes: Sequence[str]
security_scheme: Optional[SecurityScheme] = None
20 changes: 5 additions & 15 deletions apistrap/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ class SecurityScheme(metaclass=ABCMeta):
Description of an authentication method.
"""

def __init__(self, name: str, enforcer: Callable[[List[str]], None]):
def __init__(self, name: str):
"""
:param name: Name of the scheme (used as the name in the OpenAPI specification)
:param enforcer: A function that takes a list of scopes and raises an error if the user doesn't have them
:param enforcer: An object invoked by an extension that takes a list of scopes and raises an error if the user
doesn't have them
"""
self.name = name
self.enforcer = enforcer

@abc.abstractmethod
def to_openapi_dict(self):
Expand Down Expand Up @@ -83,11 +83,11 @@ class OAuthSecurity(SecurityScheme):
A description of an OAuth security scheme with an arbitrary list of OAuth 2 flows
"""

def __init__(self, name: str, enforcer: Callable, *flows: OAuthFlowDefinition):
def __init__(self, name: str, *flows: OAuthFlowDefinition):
"""
:param flows: A list of OAuth 2 flows allowed by the security scheme
"""
super().__init__(name, enforcer)
super().__init__(name)
self.flows = flows

def to_openapi_dict(self):
Expand Down Expand Up @@ -362,16 +362,6 @@ def add_schema_definition(self, name: str, schema: dict):

return f"#/components/schemas/{name}"

def add_security_scheme(self, scheme: SecurityScheme):
"""
Add a security scheme to be used by the API.
:param scheme: a description of the security scheme
"""

self.security_schemes.append(scheme)
self.spec.components.security_scheme(scheme.name, scheme.to_openapi_dict())

def add_tag_data(self, tag: TagData) -> None:
"""
Add information about a tag to the specification.
Expand Down
35 changes: 32 additions & 3 deletions apistrap/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
import re
from functools import wraps
from os import path
from typing import Callable, Generator, List, Optional, Sequence, Tuple, Type
from typing import Callable, Dict, Generator, List, Optional, Sequence, Tuple, Type

from flask import Blueprint, Flask, Response, jsonify, render_template, request, send_file
from werkzeug.exceptions import HTTPException

from apistrap.errors import ApiClientError, ApiServerError
from apistrap.extension import Apistrap, ErrorHandler
from apistrap.extension import Apistrap, ErrorHandler, SecurityScheme
from apistrap.operation_wrapper import OperationWrapper
from apistrap.schemas import ErrorResponse
from apistrap.types import FileResponse
from apistrap.utils import format_exception, resolve_fw_decl

SecurityEnforcer = Callable[[Sequence[str]], None]


class FlaskOperationWrapper(OperationWrapper):
URL_FILTER_MAP = {"string": str, "int": int, "float": float, "path": str}
Expand All @@ -27,10 +29,24 @@ def __init__(
self.method = method
super().__init__(extension, function, decorators)

def _enforce_security(self):
error = None

for security_scheme, required_scopes in self._get_required_scopes():
try:
# If any enforcer passes without throwing, the user is authenticated
self._extension.security_enforcers[security_scheme](required_scopes)
return
except Exception as e:
error = e
else:
if error is not None:
raise error

def get_decorated_function(self):
@wraps(self._wrapped_function)
def wrapper(*args, **kwargs):
self._check_security()
self._enforce_security()

if self.accepts_body:
self._check_request_content_type(request.content_type)
Expand Down Expand Up @@ -100,6 +116,7 @@ def __init__(self):
self._app: Flask = None
self._specs_extracted = False
self._operations: Optional[List[FlaskOperationWrapper]] = None
self.security_enforcers: Dict[SecurityScheme, SecurityEnforcer] = {}

self._default_error_handlers = (
ErrorHandler(HTTPException, lambda exc_type: exc_type.code, self.http_error_handler),
Expand Down Expand Up @@ -134,6 +151,18 @@ def init_app(self, app: Flask):

app.before_first_request_funcs.append(self._decorate_view_handlers)

def add_security_scheme(self, scheme: SecurityScheme, enforcer: SecurityEnforcer):
"""
Add a security scheme to be used by the API.
:param scheme: a description of the security scheme
:param enforcer: a function that checks the requirements of the security scheme
"""

self.security_schemes.append(scheme)
self.spec.components.security_scheme(scheme.name, scheme.to_openapi_dict())
self.security_enforcers[scheme] = enforcer

def _error_handler(self, exception: Exception):
response = self.exception_to_response(exception)

Expand Down
24 changes: 12 additions & 12 deletions apistrap/operation_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from apistrap.utils import resolve_fw_decl, snake_to_camel

if TYPE_CHECKING: # pragma: no cover
from apistrap.extension import Apistrap
from apistrap.extension import Apistrap, SecurityEnforcer, SecurityScheme


@dataclass(frozen=True)
Expand Down Expand Up @@ -249,18 +249,21 @@ def _load_request_body(self, body_primitive) -> Dict[str, Model]:

return {self._request_body_parameter: body}

def _check_security(self):
def _get_required_scopes(self) -> Generator[Tuple[SecurityScheme, Sequence[str]]]:
"""
Ensure security policies are met.
Get a list of scopes required by the endpoint.
"""
for security_decorator in self._find_decorators(SecurityDecorator):
if len(self._extension.security_schemes) > 1 and security_decorator.security_scheme is None:
raise TypeError(
"Multiple security schemes are defined - cannot use security decorator without an explicit scheme"
)

security_decorator = next(self._find_decorators(SecurityDecorator), None)
if len(self._extension.security_schemes) == 0:
raise TypeError("At least one security scheme must be defined in order to use the security decorator")

if security_decorator is None:
return

for scheme in self._extension.security_schemes:
scheme.enforcer(security_decorator.scopes)
scheme = security_decorator.security_scheme or self._extension.security_schemes[0]
yield scheme, security_decorator.scopes

def _postprocess_response(self, response: Union[Model, Tuple[Model, int]]) -> Tuple[Model, int, Optional[str]]:
"""
Expand Down Expand Up @@ -466,9 +469,6 @@ def _get_security_requirements(self) -> Generator[Dict[str, Sequence[str]], None
"""
decorators = [*self._find_decorators(SecurityDecorator)]

if len(decorators) > 1:
raise TypeError("Only one security decorator per view is allowed")

if len(decorators) == 0:
return # No security requirements

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from setuptools import find_packages, setup

setup(name='apistrap',
version='0.9.3',
version='0.9.4',
description='Iterait REST API utilities',
classifiers=[
'Development Status :: 4 - Beta',
Expand Down
12 changes: 6 additions & 6 deletions tests/test_security_enforcement.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import Sequence

import pytest
from flask import jsonify
Expand All @@ -8,7 +8,7 @@
from apistrap.schemas import ErrorResponse


def enforcer(scopes: List[str]):
def enforcer(scopes: Sequence[str]):
user_scopes = ["read", "write"]

if not all((scope in user_scopes) for scope in scopes):
Expand All @@ -25,14 +25,14 @@ def app_with_oauth(app):
oapi.add_security_scheme(
OAuthSecurity(
"oauth",
enforcer,
OAuthFlowDefinition(
"authorization_code",
{"read": "Read stuff", "write": "Write stuff", "frobnicate": "Frobnicate stuff"},
"/auth",
"/token",
),
)
),
enforcer
)

oapi.add_error_handler(ForbiddenRequestError, 403, lambda _: ErrorResponse())
Expand Down Expand Up @@ -62,14 +62,14 @@ def app_with_oauth_and_unsecured_endpoint(app):
oapi.add_security_scheme(
OAuthSecurity(
"oauth",
enforcer,
OAuthFlowDefinition(
"authorization_code",
{"read": "Read stuff", "write": "Write stuff", "frobnicate": "Frobnicate stuff"},
"/auth",
"/token",
),
)
),
enforcer
)

oapi.add_error_handler(ForbiddenRequestError, 403, lambda _: ErrorResponse())
Expand Down
8 changes: 4 additions & 4 deletions tests/test_security_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ def app_with_oauth(app):
oapi.add_security_scheme(
OAuthSecurity(
"oauth",
lambda scopes: None,
OAuthFlowDefinition(
"authorization_code", {"read": "Read stuff", "write": "Write stuff"}, "/auth", "/token"
),
)
),
lambda scopes: None
)

@app.route("/secured", methods=["GET"])
Expand Down Expand Up @@ -61,11 +61,11 @@ def app_with_oauth_non_string_scopes(app):
oapi.add_security_scheme(
OAuthSecurity(
"oauth",
lambda scopes: None,
OAuthFlowDefinition(
"authorization_code", {"read": "Read stuff", "write": "Write stuff"}, "/auth", "/token"
),
)
),
lambda scopes: None
)

@app.route("/secured", methods=["GET"])
Expand Down

0 comments on commit fbe164e

Please sign in to comment.