Skip to content

Commit

Permalink
Remove decorator function
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick91 committed Mar 15, 2021
1 parent 25f7419 commit eb6f014
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 141 deletions.
47 changes: 0 additions & 47 deletions strawberry/decorator.py

This file was deleted.

4 changes: 1 addition & 3 deletions strawberry/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def field(
permission_classes: Optional[List[Type[BasePermission]]] = None,
federation: Optional[FederationFieldParams] = None,
deprecation_reason: Optional[str] = None,
decorators: Optional[List[Callable]] = None,
):
) -> StrawberryField:
"""Annotates a method or property as a GraphQL field.
This is normally used inside a type declaration:
Expand All @@ -95,7 +94,6 @@ def field(
arguments=[], # modified by resolver in __call__
federation=federation or FederationFieldParams(),
deprecation_reason=deprecation_reason,
decorators=decorators,
)

field_ = StrawberryField(field_definition)
Expand Down
9 changes: 2 additions & 7 deletions strawberry/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@
from .types.fields.resolver import StrawberryResolver


def default_field_resolver(field_name: str, source, info):
return getattr(source, field_name)


def is_default_resolver(func: Callable) -> bool:
"""Check whether the function is a default resolver or a user provided one."""
return getattr(func, "_is_default", False)
Expand Down Expand Up @@ -41,7 +37,6 @@ def get_arguments(
field: FieldDefinition, kwargs: Dict[str, Any], source: Any, info: Any
) -> Tuple[List[Any], Dict[str, Any]]:
actual_resolver = cast(StrawberryResolver, field.base_resolver)
is_decorator = getattr(actual_resolver.wrapped_func, "_strawberry_decorator", False)

kwargs = convert_arguments(kwargs, field.arguments)

Expand All @@ -56,10 +51,10 @@ def get_arguments(
if actual_resolver.has_self_arg:
args.append(source)

if actual_resolver.has_root_arg or is_decorator:
if actual_resolver.has_root_arg:
kwargs["root"] = source

if actual_resolver.has_info_arg or is_decorator:
if actual_resolver.has_info_arg:
kwargs["info"] = info

return args, kwargs
Expand Down
33 changes: 2 additions & 31 deletions strawberry/types/fields/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,9 @@


class StrawberryResolver(Generic[T]):
def __init__(
self,
func: Callable[..., T],
*,
description: Optional[str] = None,
decorators: List[Callable[..., T]] = None
):
def __init__(self, func: Callable[..., T], *, description: Optional[str] = None):
self.wrapped_func = func
self._description = description
self.decorators = decorators or []

if isinstance(func, StrawberryResolver):
self.wrapped_func = func.wrapped_func
self.decorators = func.decorators + self.decorators

else:
# self.wrapped_func = func

def yad(decorators):
def wrap(resolver):
def decorator(root=resolver):
for d in reversed(decorators):
f = d(root)
return f

return decorator

return wrap

# if func() == "hi":
# breakpoint()

self.wrapped_func = yad(self.decorators)(func)

# TODO: Use this when doing the actual resolving? How to deal with async resolvers?
def __call__(self, *args, **kwargs) -> T:
Expand Down
1 change: 0 additions & 1 deletion strawberry/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,3 @@ class FieldDefinition:
)
default_value: Any = undefined
deprecation_reason: Optional[str] = None
decorators: Optional[List[Callable]] = None
115 changes: 63 additions & 52 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from strawberry.types.fields.resolver import StrawberryResolver
import functools
from functools import wraps

import strawberry
from strawberry.decorator import make_strawberry_decorator
from strawberry.types import Info


def test_basic_decorator():
@make_strawberry_decorator
def upper_case(resolver, **kwargs):
result = resolver(**kwargs)
return result.upper()
def upper_case(resolver):
@wraps(resolver)
def wrapped(*args, **kwargs):
return resolver(*args, **kwargs).upper()

return wrapped

@strawberry.type
class Query:
@strawberry.field
@upper_case
def greeting() -> str:
def greeting(self) -> str:
return "hi"

schema = strawberry.Schema(query=Query)
Expand All @@ -25,12 +29,16 @@ def greeting() -> str:

def test_decorator_with_arguments():
def suffix(value):
@make_strawberry_decorator
def wrapper(resolver, **kwargs):
result = resolver(**kwargs)
return f"{result}{value}"
def decorator(resolver):
@wraps(resolver)
def wrapper(*args, **kwargs):
result = resolver(*args, **kwargs)

return f"{result}{value}"

return wrapper
return wrapper

return decorator

@strawberry.type
class Query:
Expand All @@ -47,18 +55,25 @@ def greeting() -> str:


def test_multiple_decorators():
@make_strawberry_decorator
def upper_case(resolver, **kwargs):
result = resolver(**kwargs)
return result.upper()
def upper_case(resolver):
@functools.wraps(resolver)
def wrap(*args, **kwargs):
result = resolver(*args, **kwargs)
return result.upper()

return wrap

def suffix(value):
@make_strawberry_decorator
def wrapper(resolver, **kwargs):
result = resolver(**kwargs)
return f"{result}{value}"
def decorator(resolver):
@wraps(resolver)
def wrapper(*args, **kwargs):
result = resolver(*args, **kwargs)

return f"{result}{value}"

return wrapper

return wrapper
return decorator

@strawberry.type
class Query:
Expand All @@ -76,10 +91,13 @@ def greeting() -> str:


def test_decorator_with_graphql_arguments():
@make_strawberry_decorator
def upper_case(resolver, **kwargs):
result = resolver(**kwargs)
return result.upper()
def upper_case(resolver):
@functools.wraps(resolver)
def wrap(*args, **kwargs):
result = resolver(*args, **kwargs)
return result.upper()

return wrap

@strawberry.type
class Query:
Expand All @@ -97,12 +115,15 @@ def greeting(self, name: str) -> str:

def test_decorator_modify_argument():
def title_case_argument(argument_name):
@make_strawberry_decorator
def wrapped(resolver, **kwargs):
kwargs[argument_name] = kwargs[argument_name].title()
return resolver(**kwargs)
def decorator(resolver):
@functools.wraps(resolver)
def wrapped(*args, **kwargs):
kwargs[argument_name] = kwargs[argument_name].title()
return resolver(*args, **kwargs)

return wrapped
return wrapped

return decorator

@strawberry.type
class Query:
Expand All @@ -118,33 +139,23 @@ def greeting(self, name: str) -> str:
assert result.data == {"greeting": "hi Patrick"}


def test_decorator_simple_field():
@make_strawberry_decorator
def upper_case(resolver, **kwargs):
result = resolver(**kwargs)
return result.upper()

def suffix(value):
@make_strawberry_decorator
def wrapper(resolver, **kwargs):
result = resolver(**kwargs)
return f"{result}{value}"
def test_decorator_with_info():
def upper_case(resolver):
@wraps(resolver)
def wrapped(*args, **kwargs):
return resolver(*args, **kwargs).upper()

return wrapper
return wrapped

@strawberry.type
class Query:
name: str = strawberry.field(decorators=[upper_case, suffix(" 👋")])
@strawberry.field
@upper_case
def greeting(self, info: Info) -> str:
return str(info.context)

schema = strawberry.Schema(query=Query)
result = schema.execute_sync(
"""
query {
name
}
""",
root_value=Query(name="patrick"),
)
result = schema.execute_sync("query { greeting }")

assert not result.errors
assert result.data == {"name": "PATRICK 👋"}
assert result.data == {"greeting": "NONE"}

0 comments on commit eb6f014

Please sign in to comment.