37
37
)
38
38
from strawberry .types .arguments import StrawberryArgument , argument
39
39
from strawberry .types .base import StrawberryList , StrawberryOptional
40
+ from strawberry .types .cast import cast as strawberry_cast
40
41
from strawberry .types .field import _RESOLVER_TYPE , StrawberryField , field
41
42
from strawberry .types .fields .resolver import StrawberryResolver
42
43
from strawberry .types .lazy_type import LazyType
@@ -88,12 +89,27 @@ def resolver(
88
89
info : Info ,
89
90
id : Annotated [GlobalID , argument (description = "The ID of the object." )],
90
91
) -> Union [Node , None , Awaitable [Union [Node , None ]]]:
91
- return id .resolve_type (info ).resolve_node (
92
+ node_type = id .resolve_type (info )
93
+ resolved_node = node_type .resolve_node (
92
94
id .node_id ,
93
95
info = info ,
94
96
required = not is_optional ,
95
97
)
96
98
99
+ # We are using `strawberry_cast` here to cast the resolved node to make
100
+ # sure `is_type_of` will not try to find its type again. Very important
101
+ # when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as
102
+ # we could end up resolving to a different type in case more than one
103
+ # are registered.
104
+ if inspect .isawaitable (resolved_node ):
105
+
106
+ async def resolve () -> Any :
107
+ return strawberry_cast (node_type , await resolved_node )
108
+
109
+ return resolve ()
110
+
111
+ return cast (Node , strawberry_cast (node_type , resolved_node ))
112
+
97
113
return resolver
98
114
99
115
def get_node_list_resolver (
@@ -139,6 +155,14 @@ def resolver(
139
155
if inspect .isasyncgen (nodes )
140
156
}
141
157
158
+ # We are using `strawberry_cast` here to cast the resolved node to make
159
+ # sure `is_type_of` will not try to find its type again. Very important
160
+ # when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as
161
+ # we could end up resolving to a different type in case more than one
162
+ # are registered
163
+ def cast_nodes (node_t : type [Node ], nodes : Iterable [Any ]) -> list [Node ]:
164
+ return [cast (Node , strawberry_cast (node_t , node )) for node in nodes ]
165
+
142
166
if awaitable_nodes or asyncgen_nodes :
143
167
144
168
async def resolve (resolved : Any = resolved_nodes ) -> list [Node ]:
@@ -161,7 +185,8 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
161
185
162
186
# Resolve any generator to lists
163
187
resolved = {
164
- node_t : list (nodes ) for node_t , nodes in resolved .items ()
188
+ node_t : cast_nodes (node_t , nodes )
189
+ for node_t , nodes in resolved .items ()
165
190
}
166
191
return [
167
192
resolved [index_map [gid ][0 ]][index_map [gid ][1 ]] for gid in ids
@@ -171,7 +196,7 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
171
196
172
197
# Resolve any generator to lists
173
198
resolved = {
174
- node_t : list ( cast (Iterator [Node ], nodes ))
199
+ node_t : cast_nodes ( node_t , cast (Iterable [Node ], nodes ))
175
200
for node_t , nodes in resolved_nodes .items ()
176
201
}
177
202
return [resolved [index_map [gid ][0 ]][index_map [gid ][1 ]] for gid in ids ]
0 commit comments