Skip to content

Commit 526eb82

Browse files
authored
fix: Prevent a possible security issue when resolving a relay node with multiple possibilities (#3749)
1 parent fc854f1 commit 526eb82

File tree

8 files changed

+203
-3
lines changed

8 files changed

+203
-3
lines changed

RELEASE.md

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
Release type: minor
2+
3+
The common `node: Node` used to resolve relay nodes means we will be relying on
4+
is_type_of to check if the returned object is in fact a subclass of the Node
5+
interface.
6+
7+
However, integrations such as Django, SQLAlchemy and Pydantic will not return
8+
the type itself, but instead an alike object that is later resolved to the
9+
expected type.
10+
11+
In case there are more than one possible type defined for that model that is
12+
being returned, the first one that replies True to `is_type_of` check would be
13+
used in the resolution, meaning that when asking for `"PublicUser:123"`,
14+
strawberry could end up returning `"User:123"`, which can lead to security
15+
issues (such as data leakage).
16+
17+
In here we are introducing a new `strawberry.cast`, which will be used to mark
18+
an object with the already known type by us, and when asking for is_type_of that
19+
mark will be used to check instead, ensuring we will return the correct type.
20+
21+
That `cast` is already in place for the relay node resolution and pydantic.

strawberry/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .schema_directive import schema_directive
1414
from .types.arguments import argument
1515
from .types.auto import auto
16+
from .types.cast import cast
1617
from .types.enum import enum, enum_value
1718
from .types.field import field
1819
from .types.info import Info
@@ -36,6 +37,7 @@
3637
"argument",
3738
"asdict",
3839
"auto",
40+
"cast",
3941
"directive",
4042
"directive_field",
4143
"enum",

strawberry/experimental/pydantic/object_type.py

+4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_private_fields,
3030
)
3131
from strawberry.types.auto import StrawberryAuto
32+
from strawberry.types.cast import get_strawberry_type_cast
3233
from strawberry.types.field import StrawberryField
3334
from strawberry.types.object_type import _process_type, _wrap_dataclass
3435
from strawberry.types.type_resolver import _get_fields
@@ -207,6 +208,9 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
207208
# pydantic objects (not the corresponding strawberry type)
208209
@classmethod # type: ignore
209210
def is_type_of(cls: builtins.type, obj: Any, _info: GraphQLResolveInfo) -> bool:
211+
if (type_cast := get_strawberry_type_cast(obj)) is not None:
212+
return type_cast is cls
213+
210214
return isinstance(obj, (cls, model))
211215

212216
namespace = {"is_type_of": is_type_of}

strawberry/relay/fields.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838
from strawberry.types.arguments import StrawberryArgument, argument
3939
from strawberry.types.base import StrawberryList, StrawberryOptional
40+
from strawberry.types.cast import cast as strawberry_cast
4041
from strawberry.types.field import _RESOLVER_TYPE, StrawberryField, field
4142
from strawberry.types.fields.resolver import StrawberryResolver
4243
from strawberry.types.lazy_type import LazyType
@@ -88,12 +89,27 @@ def resolver(
8889
info: Info,
8990
id: Annotated[GlobalID, argument(description="The ID of the object.")],
9091
) -> Union[Node, None, Awaitable[Union[Node, None]]]:
91-
return id.resolve_type(info).resolve_node(
92+
node_type = id.resolve_type(info)
93+
resolved_node = node_type.resolve_node(
9294
id.node_id,
9395
info=info,
9496
required=not is_optional,
9597
)
9698

99+
# We are using `strawberry_cast` here to cast the resolved node to make
100+
# sure `is_type_of` will not try to find its type again. Very important
101+
# when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as
102+
# we could end up resolving to a different type in case more than one
103+
# are registered.
104+
if inspect.isawaitable(resolved_node):
105+
106+
async def resolve() -> Any:
107+
return strawberry_cast(node_type, await resolved_node)
108+
109+
return resolve()
110+
111+
return cast(Node, strawberry_cast(node_type, resolved_node))
112+
97113
return resolver
98114

99115
def get_node_list_resolver(
@@ -139,6 +155,14 @@ def resolver(
139155
if inspect.isasyncgen(nodes)
140156
}
141157

158+
# We are using `strawberry_cast` here to cast the resolved node to make
159+
# sure `is_type_of` will not try to find its type again. Very important
160+
# when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as
161+
# we could end up resolving to a different type in case more than one
162+
# are registered
163+
def cast_nodes(node_t: type[Node], nodes: Iterable[Any]) -> list[Node]:
164+
return [cast(Node, strawberry_cast(node_t, node)) for node in nodes]
165+
142166
if awaitable_nodes or asyncgen_nodes:
143167

144168
async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
@@ -161,7 +185,8 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
161185

162186
# Resolve any generator to lists
163187
resolved = {
164-
node_t: list(nodes) for node_t, nodes in resolved.items()
188+
node_t: cast_nodes(node_t, nodes)
189+
for node_t, nodes in resolved.items()
165190
}
166191
return [
167192
resolved[index_map[gid][0]][index_map[gid][1]] for gid in ids
@@ -171,7 +196,7 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
171196

172197
# Resolve any generator to lists
173198
resolved = {
174-
node_t: list(cast(Iterator[Node], nodes))
199+
node_t: cast_nodes(node_t, cast(Iterable[Node], nodes))
175200
for node_t, nodes in resolved_nodes.items()
176201
}
177202
return [resolved[index_map[gid][0]][index_map[gid][1]] for gid in ids]

strawberry/schema/schema_converter.py

+7
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
get_object_definition,
5959
has_object_definition,
6060
)
61+
from strawberry.types.cast import get_strawberry_type_cast
6162
from strawberry.types.enum import EnumDefinition
6263
from strawberry.types.field import UNRESOLVED
6364
from strawberry.types.lazy_type import LazyType
@@ -619,6 +620,9 @@ def _get_is_type_of() -> Optional[Callable[[Any, GraphQLResolveInfo], bool]]:
619620
)
620621

621622
def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool:
623+
if (type_cast := get_strawberry_type_cast(obj)) is not None:
624+
return type_cast in possible_types
625+
622626
if object_type.concrete_of and (
623627
has_object_definition(obj)
624628
and obj.__strawberry_definition__.origin
@@ -898,6 +902,9 @@ def _get_is_type_of(
898902
if object_type.interfaces:
899903

900904
def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool:
905+
if (type_cast := get_strawberry_type_cast(obj)) is not None:
906+
return type_cast is object_type.origin
907+
901908
if object_type.concrete_of and (
902909
has_object_definition(obj)
903910
and obj.__strawberry_definition__.origin

strawberry/types/cast.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, TypeVar, overload
4+
5+
_T = TypeVar("_T", bound=object)
6+
7+
TYPE_CAST_ATTRIBUTE = "__as_strawberry_type__"
8+
9+
10+
@overload
11+
def cast(type_: type, obj: None) -> None: ...
12+
13+
14+
@overload
15+
def cast(type_: type, obj: _T) -> _T: ...
16+
17+
18+
def cast(type_: type, obj: _T | None) -> _T | None:
19+
"""Cast an object to given type.
20+
21+
This is used to mark an object as a cast object, so that the type can be
22+
picked up when resolving unions/interfaces in case of ambiguity, which can
23+
happen when returning an alike object instead of an instance of the type
24+
(e.g. returning a Django, Pydantic or SQLAlchemy object)
25+
"""
26+
if obj is None:
27+
return None
28+
29+
setattr(obj, TYPE_CAST_ATTRIBUTE, type_)
30+
return obj
31+
32+
33+
def get_strawberry_type_cast(obj: Any) -> type | None:
34+
"""Get the type of a cast object."""
35+
return getattr(obj, TYPE_CAST_ATTRIBUTE, None)

tests/relay/test_fields.py

+78
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import dataclasses
12
import textwrap
3+
from collections.abc import Iterable
4+
from typing import Optional, Union
5+
from typing_extensions import Self
26

37
import pytest
48
from pytest_mock import MockerFixture
@@ -1621,3 +1625,77 @@ def test_query_after_error():
16211625

16221626
assert result.errors is not None
16231627
assert "Argument 'after' contains a non-existing value" in str(result.errors)
1628+
1629+
1630+
@pytest.mark.parametrize(
1631+
("type_name", "should_have_name"),
1632+
[("Fruit", False), ("PublicFruit", True)],
1633+
)
1634+
@pytest.mark.django_db(transaction=True)
1635+
def test_correct_model_returned(type_name: str, should_have_name: bool):
1636+
@dataclasses.dataclass
1637+
class FruitModel:
1638+
id: str
1639+
name: str
1640+
1641+
fruits: dict[str, FruitModel] = {"1": FruitModel(id="1", name="Strawberry")}
1642+
1643+
@strawberry.type
1644+
class Fruit(relay.Node):
1645+
id: relay.NodeID[int]
1646+
1647+
@classmethod
1648+
def resolve_nodes(
1649+
cls,
1650+
*,
1651+
info: Optional[strawberry.Info] = None,
1652+
node_ids: Iterable[str],
1653+
required: bool = False,
1654+
) -> Iterable[Optional[Union[Self, FruitModel]]]:
1655+
return [fruits[nid] if required else fruits.get(nid) for nid in node_ids]
1656+
1657+
@strawberry.type
1658+
class PublicFruit(relay.Node):
1659+
id: relay.NodeID[int]
1660+
name: str
1661+
1662+
@classmethod
1663+
def resolve_nodes(
1664+
cls,
1665+
*,
1666+
info: Optional[strawberry.Info] = None,
1667+
node_ids: Iterable[str],
1668+
required: bool = False,
1669+
) -> Iterable[Optional[Union[Self, FruitModel]]]:
1670+
return [fruits[nid] if required else fruits.get(nid) for nid in node_ids]
1671+
1672+
@strawberry.type
1673+
class Query:
1674+
node: relay.Node = relay.node()
1675+
1676+
schema = strawberry.Schema(query=Query, types=[Fruit, PublicFruit])
1677+
1678+
node_id = relay.to_base64(type_name, "1")
1679+
result = schema.execute_sync(
1680+
"""
1681+
query NodeQuery($id: GlobalID!) {
1682+
node(id: $id) {
1683+
__typename
1684+
id
1685+
... on PublicFruit {
1686+
name
1687+
}
1688+
}
1689+
}
1690+
""",
1691+
{"id": node_id},
1692+
)
1693+
assert result.errors is None
1694+
assert isinstance(result.data, dict)
1695+
1696+
assert result.data["node"]["__typename"] == type_name
1697+
assert result.data["node"]["id"] == node_id
1698+
if should_have_name:
1699+
assert result.data["node"]["name"] == "Strawberry"
1700+
else:
1701+
assert "name" not in result.data["node"]

tests/types/test_cast.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import strawberry
2+
from strawberry.types.cast import get_strawberry_type_cast
3+
4+
5+
def test_cast():
6+
@strawberry.type
7+
class SomeType: ...
8+
9+
class OtherType: ...
10+
11+
obj = OtherType
12+
assert get_strawberry_type_cast(obj) is None
13+
14+
cast_obj = strawberry.cast(SomeType, obj)
15+
assert cast_obj is obj
16+
assert get_strawberry_type_cast(cast_obj) is SomeType
17+
18+
19+
def test_cast_none_obj():
20+
@strawberry.type
21+
class SomeType: ...
22+
23+
obj = None
24+
assert get_strawberry_type_cast(obj) is None
25+
26+
cast_obj = strawberry.cast(SomeType, obj)
27+
assert cast_obj is None
28+
assert get_strawberry_type_cast(obj) is None

0 commit comments

Comments
 (0)