Skip to content

Commit

Permalink
Use decorators argument
Browse files Browse the repository at this point in the history
  • Loading branch information
jkimbo authored and patrick91 committed Mar 12, 2021
1 parent 333f18c commit c440110
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 26 deletions.
4 changes: 3 additions & 1 deletion strawberry/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def field(
permission_classes: Optional[List[Type[BasePermission]]] = None,
federation: Optional[FederationFieldParams] = None,
deprecation_reason: Optional[str] = None,
) -> StrawberryField:
decorators: Optional[List[Callable]] = None,
):
"""Annotates a method or property as a GraphQL field.
This is normally used inside a type declaration:
Expand All @@ -94,6 +95,7 @@ def field(
arguments=[], # modified by resolver in __call__
federation=federation or FederationFieldParams(),
deprecation_reason=deprecation_reason,
decorators=decorators,
)

field_ = StrawberryField(field_definition)
Expand Down
28 changes: 5 additions & 23 deletions strawberry/resolvers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import enum
import functools
from inspect import iscoroutine
from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union, cast

Expand All @@ -10,6 +9,10 @@
from .types.fields.resolver import StrawberryResolver


def default_field_resolver(field_name: str, source, info):
return getattr(source, field_name)


def is_default_resolver(func: Callable) -> bool:
"""Check whether the function is a default resolver or a user provided one."""
return getattr(func, "_is_default", False)
Expand Down Expand Up @@ -80,25 +83,7 @@ def get_result_for_field(
return getattr(source, origin_name)


def get_result_for_field(
field: FieldDefinition, kwargs: Dict[str, Any], source: Any, info: Any
) -> Union[Awaitable[Any], Any]:
"""
Calls the resolver defined for `field`. If field doesn't have a
resolver defined we default to using getattr on `source`.
"""

actual_resolver = field.base_resolver

if actual_resolver:
args, kwargs = get_arguments(field, kwargs, source=source, info=info)

return actual_resolver(*args, **kwargs)

origin_name = cast(str, field.origin_name)
return getattr(source, origin_name)


# TODO: fix and use this
def run_decorators(result: Any, field: FieldDefinition) -> Any:
if field.decorators:
result = "TODO"
Expand Down Expand Up @@ -132,16 +117,13 @@ async def _resolver_async(source, info: Info, **kwargs):

result = convert_enums_to_values(field, result)

result = run_decorators()

return result

def _resolver(source, info, **kwargs):
_check_permissions(source, info, **kwargs)

result = get_result_for_field(field, kwargs=kwargs, info=info, source=source)
result = convert_enums_to_values(field, result)
result = run_decorators()

return result

Expand Down
1 change: 1 addition & 0 deletions strawberry/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,4 @@ class FieldDefinition:
)
default_value: Any = undefined
deprecation_reason: Optional[str] = None
decorators: Optional[List[Callable]] = None
36 changes: 34 additions & 2 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def upper_case(resolver, source, info, **kwargs):
class Query:
@strawberry.field
@upper_case
def greeting(name: str) -> str:
def greeting(self, name: str) -> str:
return f"hi {name}"

schema = strawberry.Schema(query=Query)
Expand All @@ -107,11 +107,43 @@ def wrapped(resolver, source, info, **kwargs):
class Query:
@strawberry.field
@title_case_argument("name")
def greeting(name: str) -> str:
def greeting(self, name: str) -> str:
return f"hi {name}"

schema = strawberry.Schema(query=Query)
result = schema.execute_sync('query { greeting(name: "patrick") }')

assert not result.errors
assert result.data == {"greeting": "hi Patrick"}


def test_decorator_simple_field():
@make_strawberry_decorator
def upper_case(resolver, source, info, **kwargs):
result = resolver(**kwargs)
return result.upper()

def suffix(value):
@make_strawberry_decorator
def wrapper(resolver, source, info, **kwargs):
result = resolver(**kwargs)
return f"{result}{value}"

return wrapper

@strawberry.type
class Query:
name: str = strawberry.field(decorators=[upper_case, suffix(" 👋")])

schema = strawberry.Schema(query=Query)
result = schema.execute_sync(
"""
query {
name
}
""",
root_value=Query(name="patrick"),
)

assert not result.errors
assert result.data == {"name": "PATRICK 👋"}

0 comments on commit c440110

Please sign in to comment.