Skip to content

Commit

Permalink
Merge pull request #56 from igieon/global_decorators
Browse files Browse the repository at this point in the history
Enable documented global decorators
  • Loading branch information
mahenzon committed Aug 17, 2021
2 parents 620a96f + d2f45b1 commit 3684d2e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 17 deletions.
5 changes: 5 additions & 0 deletions flask_combo_jsonapi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def route(self, resource, view, *urls, **kwargs):
resource.view = view
url_rule_options = kwargs.get('url_rule_options') or dict()

if hasattr(resource, 'decorators'):
resource.decorators += self.decorators
else:
resource.decorators = self.decorators

view_func = resource.as_view(view)

if 'blueprint' in kwargs:
Expand Down
60 changes: 43 additions & 17 deletions tests/test_sqlalchemy_data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from sqlalchemy import create_engine, Column, Integer, DateTime, String, ForeignKey
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.ext.declarative import declarative_base
from flask import Blueprint, make_response, json
from flask import Blueprint, make_response, json, request
from marshmallow_jsonapi.flask import Schema, Relationship
from marshmallow import Schema as MarshmallowSchema
from marshmallow_jsonapi import fields
from marshmallow import ValidationError
from werkzeug.exceptions import Unauthorized

from flask_combo_jsonapi import Api, ResourceList, ResourceDetail, ResourceRelationship, JsonApiException
from flask_combo_jsonapi.pagination import add_pagination_links
Expand Down Expand Up @@ -231,9 +232,26 @@ def address(session, address_model):


@pytest.fixture(scope="module")
def dummy_decorator():
def custom_auth_decorator():
def deco(f):
def wrapper_f(*args, **kwargs):
auth = request.headers.get("auth", None)
if auth == '123':
raise Unauthorized()
return f(*args, **kwargs)

return wrapper_f

yield deco


@pytest.fixture(scope="module")
def custom_auth_decorator_2():
def deco(f):
def wrapper_f(*args, **kwargs):
auth = request.headers.get("auth", None)
if auth == '1234':
raise Unauthorized()
return f(*args, **kwargs)

return wrapper_f
Expand Down Expand Up @@ -405,16 +423,14 @@ def before_delete_object_(self, obj, view_kwargs):


@pytest.fixture(scope="module")
def person_list(session, person_model, dummy_decorator, person_schema, before_create_object):
def person_list(session, person_model, person_schema, before_create_object):
class PersonList(ResourceList):
schema = person_schema
data_layer = {
"model": person_model,
"session": session,
"methods": {"before_create_object": before_create_object},
}
get_decorators = [dummy_decorator]
post_decorators = [dummy_decorator]
get_schema_kwargs = dict()
post_schema_kwargs = dict()

Expand Down Expand Up @@ -459,7 +475,8 @@ class PersonList(ResourceList):


@pytest.fixture(scope="module")
def person_detail(session, person_model, dummy_decorator, person_schema, before_update_object, before_delete_object):
def person_detail(session, person_model, person_schema, before_update_object, before_delete_object,
custom_auth_decorator_2):
class PersonDetail(ResourceDetail):
schema = person_schema
data_layer = {
Expand All @@ -468,25 +485,19 @@ class PersonDetail(ResourceDetail):
"url_field": "person_id",
"methods": {"before_update_object": before_update_object, "before_delete_object": before_delete_object},
}
get_decorators = [dummy_decorator]
patch_decorators = [dummy_decorator]
delete_decorators = [dummy_decorator]
get_schema_kwargs = dict()
patch_schema_kwargs = dict()
delete_schema_kwargs = dict()
decorators = (custom_auth_decorator_2,)

yield PersonDetail


@pytest.fixture(scope="module")
def person_computers(session, person_model, dummy_decorator, person_schema):
def person_computers(session, person_model, person_schema):
class PersonComputersRelationship(ResourceRelationship):
schema = person_schema
data_layer = {"session": session, "model": person_model, "url_field": "person_id"}
get_decorators = [dummy_decorator]
post_decorators = [dummy_decorator]
patch_decorators = [dummy_decorator]
delete_decorators = [dummy_decorator]

yield PersonComputersRelationship

Expand Down Expand Up @@ -566,7 +577,7 @@ class ComputerList(ResourceList):


@pytest.fixture(scope="module")
def computer_detail(session, computer_model, dummy_decorator, computer_schema):
def computer_detail(session, computer_model, computer_schema):
class ComputerDetail(ResourceDetail):
schema = computer_schema
data_layer = {"model": computer_model, "session": session}
Expand All @@ -576,7 +587,7 @@ class ComputerDetail(ResourceDetail):


@pytest.fixture(scope="module")
def computer_owner(session, computer_model, dummy_decorator, computer_schema):
def computer_owner(session, computer_model, computer_schema):
class ComputerOwnerRelationship(ResourceRelationship):
schema = computer_schema
data_layer = {"session": session, "model": computer_model}
Expand Down Expand Up @@ -628,6 +639,7 @@ def register_routes(
client,
app,
api_blueprint,
custom_auth_decorator,
person_list,
person_detail,
person_computers,
Expand All @@ -643,7 +655,7 @@ def register_routes(
string_json_attribute_person_detail,
string_json_attribute_person_list,
):
api = Api(blueprint=api_blueprint)
api = Api(blueprint=api_blueprint, decorators=(custom_auth_decorator,))
api.route(person_list, "person_list", "/persons")
api.route(person_list_custom_qs_manager, "person_list_custom_qs_manager", "/persons_qs")
api.route(person_detail, "person_detail", "/persons/<int:person_id>")
Expand Down Expand Up @@ -941,6 +953,7 @@ def test_post_list_nested(client, register_routes, computer):
assert response.status_code == 201
assert json.loads(response.get_data())["data"]["attributes"]["tags"][0]["key"] == "k1"


def test_post_list_nested_field(client, register_routes):
"""
Test a schema contains a nested field is correctly serialized and deserialized
Expand Down Expand Up @@ -991,6 +1004,19 @@ def test_get_detail(client, register_routes, person):
assert response.status_code == 200


def test_get_detail_custom_auth_decorator_global(client, register_routes, person):
with client:
response = client.get("/persons/" + str(person.person_id), content_type="application/vnd.api+json",
headers={'auth': '123'})
assert response.status_code == 401

def test_get_detail_custom_auth_decorator_resource_level(client, register_routes, person):
with client:
response = client.get("/persons/" + str(person.person_id), content_type="application/vnd.api+json",
headers={'auth': '1234'})
assert response.status_code == 401


def test_patch_detail(client, register_routes, computer, person):
payload = {
"data": {
Expand Down

0 comments on commit 3684d2e

Please sign in to comment.