Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick91 committed Mar 12, 2021
1 parent c440110 commit fed79be
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 38 deletions.
52 changes: 33 additions & 19 deletions strawberry/decorator.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,47 @@
import functools
from strawberry.types.fields.resolver import StrawberryResolver

from .resolvers import get_resolver_arguments
from .resolvers import get_arguments
from .utils.inspect import get_func_args


def make_strawberry_decorator(func):
def decorator(resolver):
if hasattr(resolver, "_field_definition"):
raise Exception("Can't apply decorator after strawberry.field") # TODO
##################################################################
### aaaaaaaaaaaaaaaaaaaa
return StrawberryResolver(resolver, decorators=[func])

function_args = get_func_args(resolver)
return decorator

@functools.wraps(resolver)
def wrapped_resolver(source, info, **kwargs):
def wrapped(**kwargs):
# If resolver is another strawberry decorator then pass all the
# arguments to it
if getattr(resolver, "_strawberry_decorator", False):
return resolver(source, info, **kwargs)
...

args, extra_kwargs = get_resolver_arguments(function_args, source, info)
return resolver(*args, **extra_kwargs, **kwargs)
# def decorator(resolver):
# if hasattr(resolver, "_field_definition"):
# raise Exception("Can't apply decorator after strawberry.field") # TODO

# Call the decorator body with the original set of kwargs so that it
# has the opportunity to modify them
return func(wrapped, source, info=info, **kwargs)
# function_args = get_func_args(resolver)

wrapped_resolver._strawberry_decorator = True
# @functools.wraps(resolver)
# def wrapped_resolver(root, info, **kwargs):
# def wrapped(**kwargs):
# # If resolver is another strawberry decorator then pass all the
# # arguments to it
# if getattr(resolver, "_strawberry_decorator", False):
# return resolver(root, info, **kwargs)

return wrapped_resolver
# breakpoint()

return decorator
# args, kwargs = get_arguments(
# function_args, kwargs=kwargs, source=root, info=info
# )
# return resolver(*args, **kwargs)

# # Call the decorator body with the original set of kwargs so that it
# # has the opportunity to modify them
# return func(wrapped, root, info=info, **kwargs)

# wrapped_resolver._strawberry_decorator = True

# return wrapped_resolver

# return decorator
12 changes: 3 additions & 9 deletions strawberry/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def get_arguments(
field: FieldDefinition, kwargs: Dict[str, Any], source: Any, info: Any
) -> Tuple[List[Any], Dict[str, Any]]:
actual_resolver = cast(StrawberryResolver, field.base_resolver)
is_decorator = getattr(actual_resolver.wrapped_func, "_strawberry_decorator", False)

kwargs = convert_arguments(kwargs, field.arguments)

Expand All @@ -55,10 +56,10 @@ def get_arguments(
if actual_resolver.has_self_arg:
args.append(source)

if actual_resolver.has_root_arg:
if actual_resolver.has_root_arg or is_decorator:
kwargs["root"] = source

if actual_resolver.has_info_arg:
if actual_resolver.has_info_arg or is_decorator:
kwargs["info"] = info

return args, kwargs
Expand All @@ -83,13 +84,6 @@ def get_result_for_field(
return getattr(source, origin_name)


# TODO: fix and use this
def run_decorators(result: Any, field: FieldDefinition) -> Any:
if field.decorators:
result = "TODO"

return result


def get_resolver(field: FieldDefinition) -> Callable:
# TODO: make sure that info is of type Info, currently it
Expand Down
33 changes: 31 additions & 2 deletions strawberry/types/fields/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,38 @@


class StrawberryResolver(Generic[T]):
def __init__(self, func: Callable[..., T], *, description: Optional[str] = None):
self.wrapped_func = func
def __init__(
self,
func: Callable[..., T],
*,
description: Optional[str] = None,
decorators: List[Callable[..., T]] = None
):
self._description = description
self.decorators = decorators or []

if isinstance(func, StrawberryResolver):
self.wrapped_func = func.wrapped_func
self.decorators = func.decorators + self.decorators

else:
# self.wrapped_func = func

def yad(decorators):
def wrap(resolver):
def decorator(root=resolver):
for d in reversed(decorators):
f = d(root)
return f

return decorator

return wrap

# if func() == "hi":
# breakpoint()

self.wrapped_func = yad(self.decorators)(func)

# TODO: Use this when doing the actual resolving? How to deal with async resolvers?
def __call__(self, *args, **kwargs) -> T:
Expand Down
17 changes: 9 additions & 8 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from strawberry.types.fields.resolver import StrawberryResolver
import strawberry
from strawberry.decorator import make_strawberry_decorator


def test_basic_decorator():
@make_strawberry_decorator
def upper_case(resolver, source, info, **kwargs):
def upper_case(resolver, **kwargs):
result = resolver(**kwargs)
return result.upper()

Expand All @@ -25,7 +26,7 @@ def greeting() -> str:
def test_decorator_with_arguments():
def suffix(value):
@make_strawberry_decorator
def wrapper(resolver, source, info, **kwargs):
def wrapper(resolver, **kwargs):
result = resolver(**kwargs)
return f"{result}{value}"

Expand All @@ -47,13 +48,13 @@ def greeting() -> str:

def test_multiple_decorators():
@make_strawberry_decorator
def upper_case(resolver, source, info, **kwargs):
def upper_case(resolver, **kwargs):
result = resolver(**kwargs)
return result.upper()

def suffix(value):
@make_strawberry_decorator
def wrapper(resolver, source, info, **kwargs):
def wrapper(resolver, **kwargs):
result = resolver(**kwargs)
return f"{result}{value}"

Expand All @@ -76,7 +77,7 @@ def greeting() -> str:

def test_decorator_with_graphql_arguments():
@make_strawberry_decorator
def upper_case(resolver, source, info, **kwargs):
def upper_case(resolver, **kwargs):
result = resolver(**kwargs)
return result.upper()

Expand All @@ -97,7 +98,7 @@ def greeting(self, name: str) -> str:
def test_decorator_modify_argument():
def title_case_argument(argument_name):
@make_strawberry_decorator
def wrapped(resolver, source, info, **kwargs):
def wrapped(resolver, **kwargs):
kwargs[argument_name] = kwargs[argument_name].title()
return resolver(**kwargs)

Expand All @@ -119,13 +120,13 @@ def greeting(self, name: str) -> str:

def test_decorator_simple_field():
@make_strawberry_decorator
def upper_case(resolver, source, info, **kwargs):
def upper_case(resolver, **kwargs):
result = resolver(**kwargs)
return result.upper()

def suffix(value):
@make_strawberry_decorator
def wrapper(resolver, source, info, **kwargs):
def wrapper(resolver, **kwargs):
result = resolver(**kwargs)
return f"{result}{value}"

Expand Down

0 comments on commit fed79be

Please sign in to comment.