Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement explicit interface type resolution #1406

Merged
9 changes: 9 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Release type: minor

This releases fixes an issue where you were not allowed
to return a non-strawberry type for fields that return
an interface. Now this works as long as each type
implementing the interface implements an `is_type_of`
classmethod. Previous automatic duck typing on types
that implement an interface now requires explicit
resolution using this classmethod.
9 changes: 9 additions & 0 deletions strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from pydantic import BaseModel
from pydantic.fields import ModelField

from graphql import GraphQLResolveInfo

import strawberry
from strawberry.arguments import UNSET
from strawberry.experimental.pydantic.conversion import (
Expand Down Expand Up @@ -147,10 +149,17 @@ def wrap(cls):

sorted_fields = missing_default + has_default

# Implicitly define `is_type_of` to support interfaces/unions that use
# pydantic objects (not the corresponding strawberry type)
@classmethod # type: ignore
patrick91 marked this conversation as resolved.
Show resolved Hide resolved
def is_type_of(cls: Type, obj: Any, _info: GraphQLResolveInfo) -> bool:
return isinstance(obj, (cls, model))

cls = dataclasses.make_dataclass(
cls.__name__,
sorted_fields,
bases=cls.__bases__,
namespace={"is_type_of": is_type_of},
)

_process_type(
Expand Down
2 changes: 2 additions & 0 deletions strawberry/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def _process_type(

interfaces = _get_interfaces(cls)
fields = _get_fields(cls)
is_type_of = getattr(cls, "is_type_of", None)

cls._type_definition = TypeDefinition(
name=name,
Expand All @@ -117,6 +118,7 @@ def _process_type(
origin=cls,
extend=extend,
_fields=fields,
is_type_of=is_type_of,
patrick91 marked this conversation as resolved.
Show resolved Hide resolved
)

# dataclasses removes attributes from the class here:
Expand Down
30 changes: 13 additions & 17 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Callable, Dict, List, Tuple, Type, Union, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

from graphql import (
GraphQLArgument,
Expand Down Expand Up @@ -229,27 +229,11 @@ def get_graphql_fields() -> Dict[str, GraphQLField]:

return graphql_fields

def resolve_type(
obj: Any,
info: GraphQLResolveInfo,
type_: Union[GraphQLInterfaceType, GraphQLUnionType],
) -> GraphQLObjectType:
# TODO: this will probably break when passing dicts
# or even non strawberry types
resolved_type = self.type_map[
obj.__class__._type_definition.name
].implementation

assert isinstance(resolved_type, GraphQLObjectType)

return resolved_type

graphql_interface = GraphQLInterfaceType(
name=interface_name,
fields=get_graphql_fields,
interfaces=list(map(self.from_interface, interface.interfaces)),
description=interface.description,
resolve_type=resolve_type,
)

self.type_map[interface_name] = ConcreteType(
Expand Down Expand Up @@ -295,11 +279,23 @@ def get_graphql_fields() -> Dict[str, GraphQLField]:

return graphql_fields

is_type_of: Optional[Callable[[Any, GraphQLResolveInfo], bool]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GraphQLResolveInfo <- we should use StrawberryInfo, not sure we can do this easily here though, maybe we can do it after we do this issue: #1425

if object_type.is_type_of:
is_type_of = object_type.is_type_of
elif object_type.interfaces:

AlecRosenbaum marked this conversation as resolved.
Show resolved Hide resolved
def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool:
return isinstance(obj, object_type.origin)

AlecRosenbaum marked this conversation as resolved.
Show resolved Hide resolved
else:
is_type_of = None

graphql_object_type = GraphQLObjectType(
name=object_type_name,
fields=get_graphql_fields,
interfaces=list(map(self.from_interface, object_type.interfaces)),
description=object_type.description,
is_type_of=is_type_of,
)

self.type_map[object_type_name] = ConcreteType(
Expand Down
6 changes: 6 additions & 0 deletions strawberry/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import dataclasses
from typing import (
TYPE_CHECKING,
Any,
Callable,
List,
Mapping,
Optional,
Expand All @@ -17,6 +19,8 @@


if TYPE_CHECKING:
from graphql import GraphQLResolveInfo

from strawberry.field import StrawberryField
from strawberry.schema_directive import StrawberrySchemaDirective

Expand All @@ -31,6 +35,7 @@ class TypeDefinition(StrawberryType):
interfaces: List["TypeDefinition"]
extend: bool
directives: Optional[Sequence[StrawberrySchemaDirective]]
is_type_of: Optional[Callable[[Any, GraphQLResolveInfo], bool]]

_fields: List["StrawberryField"]

Expand Down Expand Up @@ -85,6 +90,7 @@ def copy_with(
interfaces=self.interfaces,
description=self.description,
extend=self.extend,
is_type_of=self.is_type_of,
_fields=fields,
concrete_of=self,
type_var_map=type_var_map,
Expand Down
11 changes: 9 additions & 2 deletions strawberry/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,16 @@ def _resolve_union_type(

from strawberry.types.types import TypeDefinition

# Make sure that the type that's passed in is an Object type
# If the type given is not an Object type, try resolving using `is_type_of`
# defined on the union's inner types
if not hasattr(root, "_type_definition"):
# TODO: If root=python dict, this won't work
for inner_type in type_.types:
if inner_type.is_type_of is not None and inner_type.is_type_of(
root, info
):
return inner_type.name

# Couldn't resolve using `is_type_of``
raise WrongReturnTypeForUnion(info.field_name, str(type(root)))

return_type: Optional[GraphQLType]
Expand Down
39 changes: 39 additions & 0 deletions tests/experimental/pydantic/schema/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,45 @@ def user(self) -> UserType:
assert result.data["user"]["unionField"]["fieldB"] == 10


def test_basic_type_with_union_pydantic_types():
class BranchA(pydantic.BaseModel):
field_a: str

class BranchB(pydantic.BaseModel):
field_b: int

class User(pydantic.BaseModel):
union_field: Union[BranchA, BranchB]

@strawberry.experimental.pydantic.type(BranchA, fields=["field_a"])
class BranchAType:
pass

@strawberry.experimental.pydantic.type(BranchB, fields=["field_b"])
class BranchBType:
pass

@strawberry.experimental.pydantic.type(User, fields=["age", "union_field"])
class UserType:
pass

@strawberry.type
class Query:
@strawberry.field
def user(self) -> UserType:
# note that BranchB is a pydantic type, not a strawberry type
return UserType(union_field=BranchB(field_b=10))

schema = strawberry.Schema(query=Query)

query = "{ user { unionField { ... on BranchBType { fieldB } } } }"

result = schema.execute_sync(query)

assert not result.errors
assert result.data["user"]["unionField"]["fieldB"] == 10


def test_basic_type_with_enum():
@strawberry.enum
class UserKind(Enum):
Expand Down
37 changes: 37 additions & 0 deletions tests/schema/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ class Entity:
class Anime(Entity):
name: str

@classmethod
def is_type_of(cls, obj, _info) -> bool:
return isinstance(obj, AnimeORM)
patrick91 marked this conversation as resolved.
Show resolved Hide resolved

@dataclass
class AnimeORM:
id: int
Expand All @@ -130,6 +134,39 @@ def anime(self) -> Anime:
assert result.data == {"anime": {"name": "One Piece"}}


def test_interface_explicit_type_resolution():
@dataclass
class AnimeORM:
id: int
name: str

@strawberry.interface
class Node:
id: int

@strawberry.type
class Anime(Node):
name: str

@classmethod
def is_type_of(cls, obj, _info) -> bool:
return isinstance(obj, AnimeORM)

@strawberry.type
class Query:
@strawberry.field
def node(self) -> Node:
return AnimeORM(id=1, name="One Piece") # type: ignore

schema = strawberry.Schema(query=Query, types=[Anime])

query = "{ node { __typename, id } }"
result = schema.execute_sync(query)

assert not result.errors
assert result.data == {"node": {"__typename": "Anime", "id": 1}}


@pytest.mark.xfail(reason="We don't support returning dictionaries yet")
def test_interface_duck_typing_returning_dict():
@strawberry.interface
Expand Down
35 changes: 35 additions & 0 deletions tests/schema/test_union.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from textwrap import dedent
from typing import Optional, Union

Expand Down Expand Up @@ -380,3 +381,37 @@ class Query:
field2: MyUnion!
}"""
)


def test_union_explicit_type_resolution():
@dataclass
class ADataclass:
a: int

@strawberry.type
class A:
a: int

@classmethod
def is_type_of(cls, obj, _info) -> bool:
return isinstance(obj, ADataclass)

@strawberry.type
class B:
b: int

MyUnion = strawberry.union("MyUnion", types=(A, B))

@strawberry.type
class Query:
@strawberry.field
def my_field(self) -> MyUnion:
return ADataclass(a=1) # type: ignore

schema = strawberry.Schema(query=Query)

query = "{ myField { __typename, ... on A { a }, ... on B { b } } }"
result = schema.execute_sync(query)

assert not result.errors
assert result.data == {"myField": {"__typename": "A", "a": 1}}