3737)
3838from strawberry .types .arguments import StrawberryArgument , argument
3939from strawberry .types .base import StrawberryList , StrawberryOptional
40+ from strawberry .types .cast import cast as strawberry_cast
4041from strawberry .types .field import _RESOLVER_TYPE , StrawberryField , field
4142from strawberry .types .fields .resolver import StrawberryResolver
4243from strawberry .types .lazy_type import LazyType
@@ -88,12 +89,27 @@ def resolver(
8889 info : Info ,
8990 id : Annotated [GlobalID , argument (description = "The ID of the object." )],
9091 ) -> Union [Node , None , Awaitable [Union [Node , None ]]]:
91- return id .resolve_type (info ).resolve_node (
92+ node_type = id .resolve_type (info )
93+ resolved_node = node_type .resolve_node (
9294 id .node_id ,
9395 info = info ,
9496 required = not is_optional ,
9597 )
9698
99+ # We are using `strawberry_cast` here to cast the resolved node to make
100+ # sure `is_type_of` will not try to find its type again. Very important
101+ # when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as
102+ # we could end up resolving to a different type in case more than one
103+ # are registered.
104+ if inspect .isawaitable (resolved_node ):
105+
106+ async def resolve () -> Any :
107+ return strawberry_cast (node_type , await resolved_node )
108+
109+ return resolve ()
110+
111+ return cast (Node , strawberry_cast (node_type , resolved_node ))
112+
97113 return resolver
98114
99115 def get_node_list_resolver (
@@ -139,6 +155,14 @@ def resolver(
139155 if inspect .isasyncgen (nodes )
140156 }
141157
158+ # We are using `strawberry_cast` here to cast the resolved node to make
159+ # sure `is_type_of` will not try to find its type again. Very important
160+ # when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as
161+ # we could end up resolving to a different type in case more than one
162+ # are registered
163+ def cast_nodes (node_t : type [Node ], nodes : Iterable [Any ]) -> list [Node ]:
164+ return [cast (Node , strawberry_cast (node_t , node )) for node in nodes ]
165+
142166 if awaitable_nodes or asyncgen_nodes :
143167
144168 async def resolve (resolved : Any = resolved_nodes ) -> list [Node ]:
@@ -161,7 +185,8 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
161185
162186 # Resolve any generator to lists
163187 resolved = {
164- node_t : list (nodes ) for node_t , nodes in resolved .items ()
188+ node_t : cast_nodes (node_t , nodes )
189+ for node_t , nodes in resolved .items ()
165190 }
166191 return [
167192 resolved [index_map [gid ][0 ]][index_map [gid ][1 ]] for gid in ids
@@ -171,7 +196,7 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
171196
172197 # Resolve any generator to lists
173198 resolved = {
174- node_t : list ( cast (Iterator [Node ], nodes ))
199+ node_t : cast_nodes ( node_t , cast (Iterable [Node ], nodes ))
175200 for node_t , nodes in resolved_nodes .items ()
176201 }
177202 return [resolved [index_map [gid ][0 ]][index_map [gid ][1 ]] for gid in ids ]
0 commit comments