-
-
Notifications
You must be signed in to change notification settings - Fork 542
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
wip: Initial work for first class pydantic usage #3268
Open
thejaminator
wants to merge
12
commits into
strawberry-graphql:main
Choose a base branch
from
thejaminator:first-class-init
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
683c14a
initial commit
thejaminator b983039
add tests
thejaminator 1139e92
add more tests
thejaminator 4cf484c
stuck at both input and output
thejaminator 4811d74
xfail failing tests
thejaminator edd56bb
refactor
thejaminator 172c1d8
refactor
thejaminator 1318e86
fix issue with camelcase
thejaminator bc3de8e
add schema basic
thejaminator a4bc9ae
add xfail for mutation
thejaminator 11d5e4f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b4f838e
comment with issue
thejaminator File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
181 changes: 181 additions & 0 deletions
181
strawberry/experimental/pydantic/pydantic_first_class.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
import dataclasses | ||
from typing import Callable, Dict, List, Optional, Sequence, Type | ||
|
||
from pydantic import BaseModel | ||
|
||
from strawberry.annotation import StrawberryAnnotation | ||
from strawberry.experimental.pydantic._compat import CompatModelField, get_model_fields | ||
from strawberry.experimental.pydantic.conversion_types import PydanticModel | ||
from strawberry.experimental.pydantic.fields import replace_types_recursively | ||
from strawberry.experimental.pydantic.utils import get_default_factory_for_field | ||
from strawberry.field import StrawberryField | ||
from strawberry.object_type import _get_interfaces | ||
from strawberry.types.types import StrawberryObjectDefinition | ||
from strawberry.utils.deprecations import DEPRECATION_MESSAGES, DeprecatedDescriptor | ||
from strawberry.utils.str_converters import to_camel_case | ||
|
||
|
||
def _get_strawberry_fields_from_basemodel( | ||
model: Type[BaseModel], is_input: bool, use_pydantic_alias: bool | ||
) -> List[StrawberryField]: | ||
"""Get all the strawberry fields off a pydantic BaseModel cls | ||
This function returns a list of StrawberryFields (one for each field item), while | ||
also paying attention the name and typing of the field. | ||
model: | ||
A pure pydantic field. Will not have a StrawberryField; one will need to | ||
be created in this function. Type annotation is required. | ||
""" | ||
fields: list[StrawberryField] = [] | ||
|
||
# BaseModel already has fields, so we need to get them from there | ||
model_fields: Dict[str, CompatModelField] = get_model_fields(model) | ||
for name, field in model_fields.items(): | ||
converted_type = replace_types_recursively(field.outer_type_, is_input=is_input) | ||
if field.allow_none: | ||
converted_type = Optional[converted_type] | ||
graphql_before_case = ( | ||
field.alias or field.name if use_pydantic_alias else field.name | ||
) | ||
camel_case_name = to_camel_case(graphql_before_case) | ||
fields.append( | ||
StrawberryField( | ||
python_name=name, | ||
graphql_name=camel_case_name, | ||
# always unset because we use default_factory instead | ||
default=dataclasses.MISSING, | ||
default_factory=get_default_factory_for_field(field), | ||
type_annotation=StrawberryAnnotation.from_annotation(converted_type), | ||
description=field.description, | ||
deprecation_reason=None, | ||
permission_classes=[], | ||
directives=[], | ||
metadata={}, | ||
) | ||
) | ||
|
||
return fields | ||
|
||
|
||
def first_class_process_basemodel( | ||
model: Type[BaseModel], | ||
*, | ||
name: Optional[str] = None, | ||
is_input: bool = False, | ||
is_interface: bool = False, | ||
description: Optional[str] = None, | ||
directives: Optional[Sequence[object]] = (), | ||
extend: bool = False, | ||
use_pydantic_alias: bool = True, | ||
): | ||
name = name or to_camel_case(model.__name__) | ||
|
||
interfaces = _get_interfaces(model) | ||
fields: List[StrawberryField] = _get_strawberry_fields_from_basemodel( | ||
model, is_input=is_input, use_pydantic_alias=use_pydantic_alias | ||
) | ||
is_type_of = getattr(model, "is_type_of", None) | ||
resolve_type = getattr(model, "resolve_type", None) | ||
|
||
model.__strawberry_definition__ = StrawberryObjectDefinition( | ||
name=name, | ||
is_input=is_input, | ||
is_interface=is_interface, | ||
interfaces=interfaces, | ||
description=description, | ||
directives=directives, | ||
origin=model, | ||
extend=extend, | ||
fields=fields, | ||
is_type_of=is_type_of, | ||
resolve_type=resolve_type, | ||
) | ||
# TODO: remove when deprecating _type_definition | ||
DeprecatedDescriptor( | ||
DEPRECATION_MESSAGES._TYPE_DEFINITION, | ||
model.__strawberry_definition__, | ||
"_type_definition", | ||
).inject(model) | ||
|
||
return model | ||
|
||
|
||
def register_first_class( | ||
model: Type[PydanticModel], | ||
*, | ||
name: Optional[str] = None, | ||
is_input: bool = False, | ||
is_interface: bool = False, | ||
description: Optional[str] = None, | ||
use_pydantic_alias: bool = True, | ||
) -> Type[PydanticModel]: | ||
"""A function for registering a pydantic model as a first class strawberry type. | ||
This is useful when your pydantic model is some code that you can't edit | ||
(e.g. from a third party library). | ||
|
||
Example: | ||
class User(BaseModel): | ||
id: int | ||
name: str | ||
|
||
register_first_class(User) | ||
|
||
@strawberry.type | ||
class Query: | ||
@strawberry.field | ||
def user(self) -> User: | ||
return User(id=1, name="Patrick") | ||
""" | ||
|
||
first_class_process_basemodel( | ||
model, | ||
name=name or to_camel_case(model.__name__), | ||
is_input=is_input, | ||
is_interface=is_interface, | ||
description=description, | ||
use_pydantic_alias=use_pydantic_alias, | ||
) | ||
|
||
if is_input: | ||
# TODO: Probably should check if the name clashes with an existing type? | ||
model._strawberry_input_type = model # type: ignore | ||
else: | ||
model._strawberry_type = model # type: ignore | ||
|
||
return model | ||
|
||
|
||
def first_class( | ||
name: Optional[str] = None, | ||
is_input: bool = False, | ||
is_interface: bool = False, | ||
description: Optional[str] = None, | ||
use_pydantic_alias: bool = True, | ||
) -> Callable[[Type[PydanticModel]], Type[PydanticModel]]: | ||
"""A decorator to make a pydantic class work on strawberry without creating | ||
a separate strawberry type. | ||
|
||
Example: | ||
@strawberry.experimental.pydantic.first_class() | ||
class User(BaseModel): | ||
id: int | ||
name: str | ||
|
||
@strawberry.type | ||
class Query: | ||
@strawberry.field | ||
def user(self) -> User: | ||
return User(id=1, name="Patrick") | ||
|
||
""" | ||
|
||
def wrap(model: Type[PydanticModel]) -> Type[PydanticModel]: | ||
return register_first_class( | ||
model, | ||
name=name, | ||
is_input=is_input, | ||
is_interface=is_interface, | ||
description=description, | ||
use_pydantic_alias=use_pydantic_alias, | ||
) | ||
|
||
return wrap |
216 changes: 216 additions & 0 deletions
216
tests/experimental/pydantic_first_class/schema/test_mutation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
from typing import Dict, List, Union | ||
|
||
import pydantic | ||
import pytest | ||
|
||
import strawberry | ||
from strawberry.experimental.pydantic._compat import IS_PYDANTIC_V2 | ||
from strawberry.experimental.pydantic.pydantic_first_class import first_class | ||
|
||
|
||
def test_mutation(): | ||
@first_class(is_input=True) | ||
class CreateUserInput(pydantic.BaseModel): | ||
name: pydantic.constr(min_length=2) | ||
|
||
@first_class() | ||
class UserType(pydantic.BaseModel): | ||
name: str | ||
|
||
@strawberry.type | ||
class Query: | ||
h: str | ||
|
||
@strawberry.type | ||
class Mutation: | ||
@strawberry.mutation | ||
def create_user(self, input: CreateUserInput) -> UserType: | ||
return UserType(name=input.name) | ||
|
||
schema = strawberry.Schema(query=Query, mutation=Mutation) | ||
|
||
query = """ | ||
mutation { | ||
createUser(input: { name: "Patrick" }) { | ||
name | ||
} | ||
} | ||
""" | ||
|
||
result = schema.execute_sync(query) | ||
|
||
assert not result.errors | ||
assert result.data["createUser"]["name"] == "Patrick" | ||
|
||
|
||
def test_mutation_with_validation(): | ||
@first_class(is_input=True) | ||
class CreateUserInput(pydantic.BaseModel): | ||
name: pydantic.constr(min_length=2) | ||
|
||
@first_class() | ||
class UserType(pydantic.BaseModel): | ||
name: str | ||
|
||
@strawberry.type | ||
class Query: | ||
h: str | ||
|
||
@strawberry.type | ||
class Mutation: | ||
@strawberry.mutation | ||
def create_user(self, input: CreateUserInput) -> UserType: | ||
return UserType(name=input.name) | ||
|
||
schema = strawberry.Schema(query=Query, mutation=Mutation) | ||
|
||
query = """ | ||
mutation { | ||
createUser(input: { name: "P" }) { | ||
name | ||
} | ||
} | ||
""" | ||
|
||
result = schema.execute_sync(query) | ||
|
||
if IS_PYDANTIC_V2: | ||
assert result.errors[0].message == ( | ||
"1 validation error for CreateUserInput\n" | ||
"name\n" | ||
" String should have at least 2 characters [type=string_too_short, " | ||
"input_value='P', input_type=str]\n" | ||
" For further information visit " | ||
"https://errors.pydantic.dev/2.0.3/v/string_too_short" | ||
) | ||
else: | ||
assert result.errors[0].message == ( | ||
"1 validation error for CreateUserInput\nname\n ensure this value has at " | ||
"least 2 characters (type=value_error.any_str.min_length; limit_value=2)" | ||
) | ||
|
||
|
||
def test_mutation_with_validation_of_nested_model(): | ||
@first_class(is_input=True) | ||
class HobbyInput(pydantic.BaseModel): | ||
name: pydantic.constr(min_length=2) | ||
|
||
@first_class(is_input=True) | ||
class CreateUserInput(pydantic.BaseModel): | ||
hobby: HobbyInput | ||
|
||
class UserModel(pydantic.BaseModel): | ||
name: str | ||
|
||
@strawberry.experimental.pydantic.type(UserModel) | ||
class UserType: | ||
name: strawberry.auto | ||
|
||
@strawberry.type | ||
class Query: | ||
h: str | ||
|
||
@strawberry.type | ||
class Mutation: | ||
@strawberry.mutation | ||
def create_user(self, input: CreateUserInput) -> UserType: | ||
return UserType(name=input.hobby.name) | ||
|
||
schema = strawberry.Schema(query=Query, mutation=Mutation) | ||
|
||
query = """ | ||
mutation { | ||
createUser(input: { hobby: { name: "P" } }) { | ||
name | ||
} | ||
} | ||
""" | ||
|
||
result = schema.execute_sync(query) | ||
|
||
if IS_PYDANTIC_V2: | ||
assert result.errors[0].message == ( | ||
"1 validation error for HobbyInput\n" | ||
"name\n" | ||
" String should have at least 2 characters [type=string_too_short, " | ||
"input_value='P', input_type=str]\n" | ||
" For further information visit " | ||
"https://errors.pydantic.dev/2.0.3/v/string_too_short" | ||
) | ||
|
||
else: | ||
assert result.errors[0].message == ( | ||
"1 validation error for HobbyInput\nname\n" | ||
" ensure this value has at least 2 characters " | ||
"(type=value_error.any_str.min_length; limit_value=2)" | ||
) | ||
|
||
|
||
@pytest.mark.xfail( | ||
reason="""No way to manually handle errors | ||
the validation goes boom in convert_argument, not in the create_user resolver""" | ||
) | ||
def test_mutation_with_validation_and_error_type(): | ||
@first_class(is_input=True) | ||
class CreateUserInput(pydantic.BaseModel): | ||
name: pydantic.constr(min_length=2) | ||
|
||
@first_class() | ||
class UserType(pydantic.BaseModel): | ||
name: str | ||
|
||
@first_class() | ||
class UserError(pydantic.BaseModel): | ||
name: str | ||
|
||
@strawberry.type | ||
class Query: | ||
h: str | ||
|
||
@strawberry.type | ||
class Mutation: | ||
@strawberry.mutation | ||
def create_user(self, input: CreateUserInput) -> Union[UserType, UserError]: | ||
try: | ||
data = input | ||
except pydantic.ValidationError as e: | ||
# issue: the error will never be thrown here because the validation | ||
# happens in convert_argument | ||
args: Dict[str, List[str]] = {} | ||
for error in e.errors(): | ||
field = error["loc"][0] # currently doesn't support nested errors | ||
field_errors = args.get(field, []) | ||
field_errors.append(error["msg"]) | ||
args[field] = field_errors | ||
return UserError(**args) | ||
else: | ||
return UserType(name=data.name) | ||
|
||
schema = strawberry.Schema(query=Query, mutation=Mutation) | ||
|
||
query = """ | ||
mutation { | ||
createUser(input: { name: "P" }) { | ||
... on UserType { | ||
name | ||
} | ||
... on UserError { | ||
nameErrors: name | ||
} | ||
} | ||
} | ||
""" | ||
|
||
result = schema.execute_sync(query) | ||
|
||
assert result.errors is None | ||
assert result.data["createUser"].get("name") is None | ||
|
||
if IS_PYDANTIC_V2: | ||
assert result.data["createUser"]["nameErrors"] == [ | ||
("String should have at least 2 characters") | ||
] | ||
else: | ||
assert result.data["createUser"]["nameErrors"] == [ | ||
("ensure this value has at least 2 characters") | ||
] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue 2: no way to manually handle validation errors for now