Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: Initial work for first class pydantic usage #3268

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
181 changes: 181 additions & 0 deletions strawberry/experimental/pydantic/pydantic_first_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import dataclasses
from typing import Callable, Dict, List, Optional, Sequence, Type

from pydantic import BaseModel

from strawberry.annotation import StrawberryAnnotation
from strawberry.experimental.pydantic._compat import CompatModelField, get_model_fields
from strawberry.experimental.pydantic.conversion_types import PydanticModel
from strawberry.experimental.pydantic.fields import replace_types_recursively
from strawberry.experimental.pydantic.utils import get_default_factory_for_field
from strawberry.field import StrawberryField
from strawberry.object_type import _get_interfaces
from strawberry.types.types import StrawberryObjectDefinition
from strawberry.utils.deprecations import DEPRECATION_MESSAGES, DeprecatedDescriptor
from strawberry.utils.str_converters import to_camel_case


def _get_strawberry_fields_from_basemodel(
model: Type[BaseModel], is_input: bool, use_pydantic_alias: bool
) -> List[StrawberryField]:
"""Get all the strawberry fields off a pydantic BaseModel cls
This function returns a list of StrawberryFields (one for each field item), while
also paying attention the name and typing of the field.
model:
A pure pydantic field. Will not have a StrawberryField; one will need to
be created in this function. Type annotation is required.
"""
fields: list[StrawberryField] = []

# BaseModel already has fields, so we need to get them from there
model_fields: Dict[str, CompatModelField] = get_model_fields(model)
for name, field in model_fields.items():
converted_type = replace_types_recursively(field.outer_type_, is_input=is_input)
if field.allow_none:
converted_type = Optional[converted_type]
graphql_before_case = (
field.alias or field.name if use_pydantic_alias else field.name
)
camel_case_name = to_camel_case(graphql_before_case)
fields.append(
StrawberryField(
python_name=name,
graphql_name=camel_case_name,
# always unset because we use default_factory instead
default=dataclasses.MISSING,
default_factory=get_default_factory_for_field(field),
type_annotation=StrawberryAnnotation.from_annotation(converted_type),
description=field.description,
deprecation_reason=None,
permission_classes=[],
directives=[],
metadata={},
)
)

return fields


def first_class_process_basemodel(
model: Type[BaseModel],
*,
name: Optional[str] = None,
is_input: bool = False,
is_interface: bool = False,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
use_pydantic_alias: bool = True,
):
name = name or to_camel_case(model.__name__)

interfaces = _get_interfaces(model)
fields: List[StrawberryField] = _get_strawberry_fields_from_basemodel(
model, is_input=is_input, use_pydantic_alias=use_pydantic_alias
)
is_type_of = getattr(model, "is_type_of", None)
resolve_type = getattr(model, "resolve_type", None)

model.__strawberry_definition__ = StrawberryObjectDefinition(
name=name,
is_input=is_input,
is_interface=is_interface,
interfaces=interfaces,
description=description,
directives=directives,
origin=model,
extend=extend,
fields=fields,
is_type_of=is_type_of,
resolve_type=resolve_type,
)
# TODO: remove when deprecating _type_definition
DeprecatedDescriptor(
DEPRECATION_MESSAGES._TYPE_DEFINITION,
model.__strawberry_definition__,
"_type_definition",
).inject(model)

return model


def register_first_class(
model: Type[PydanticModel],
*,
name: Optional[str] = None,
is_input: bool = False,
is_interface: bool = False,
description: Optional[str] = None,
use_pydantic_alias: bool = True,
) -> Type[PydanticModel]:
"""A function for registering a pydantic model as a first class strawberry type.
This is useful when your pydantic model is some code that you can't edit
(e.g. from a third party library).

Example:
class User(BaseModel):
id: int
name: str

register_first_class(User)

@strawberry.type
class Query:
@strawberry.field
def user(self) -> User:
return User(id=1, name="Patrick")
"""

first_class_process_basemodel(
model,
name=name or to_camel_case(model.__name__),
is_input=is_input,
is_interface=is_interface,
description=description,
use_pydantic_alias=use_pydantic_alias,
)

if is_input:
# TODO: Probably should check if the name clashes with an existing type?
model._strawberry_input_type = model # type: ignore
else:
model._strawberry_type = model # type: ignore

return model


def first_class(
name: Optional[str] = None,
is_input: bool = False,
is_interface: bool = False,
description: Optional[str] = None,
use_pydantic_alias: bool = True,
) -> Callable[[Type[PydanticModel]], Type[PydanticModel]]:
"""A decorator to make a pydantic class work on strawberry without creating
a separate strawberry type.

Example:
@strawberry.experimental.pydantic.first_class()
class User(BaseModel):
id: int
name: str

@strawberry.type
class Query:
@strawberry.field
def user(self) -> User:
return User(id=1, name="Patrick")

"""

def wrap(model: Type[PydanticModel]) -> Type[PydanticModel]:
return register_first_class(
model,
name=name,
is_input=is_input,
is_interface=is_interface,
description=description,
use_pydantic_alias=use_pydantic_alias,
)

return wrap
216 changes: 216 additions & 0 deletions tests/experimental/pydantic_first_class/schema/test_mutation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from typing import Dict, List, Union

import pydantic
import pytest

import strawberry
from strawberry.experimental.pydantic._compat import IS_PYDANTIC_V2
from strawberry.experimental.pydantic.pydantic_first_class import first_class


def test_mutation():
@first_class(is_input=True)
class CreateUserInput(pydantic.BaseModel):
name: pydantic.constr(min_length=2)

@first_class()
class UserType(pydantic.BaseModel):
name: str

@strawberry.type
class Query:
h: str

@strawberry.type
class Mutation:
@strawberry.mutation
def create_user(self, input: CreateUserInput) -> UserType:
return UserType(name=input.name)

schema = strawberry.Schema(query=Query, mutation=Mutation)

query = """
mutation {
createUser(input: { name: "Patrick" }) {
name
}
}
"""

result = schema.execute_sync(query)

assert not result.errors
assert result.data["createUser"]["name"] == "Patrick"


def test_mutation_with_validation():
@first_class(is_input=True)
class CreateUserInput(pydantic.BaseModel):
name: pydantic.constr(min_length=2)

@first_class()
class UserType(pydantic.BaseModel):
name: str

@strawberry.type
class Query:
h: str

@strawberry.type
class Mutation:
@strawberry.mutation
def create_user(self, input: CreateUserInput) -> UserType:
return UserType(name=input.name)

Check warning on line 63 in tests/experimental/pydantic_first_class/schema/test_mutation.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic_first_class/schema/test_mutation.py#L63

Added line #L63 was not covered by tests

schema = strawberry.Schema(query=Query, mutation=Mutation)

query = """
mutation {
createUser(input: { name: "P" }) {
name
}
}
"""

result = schema.execute_sync(query)

if IS_PYDANTIC_V2:
assert result.errors[0].message == (

Check warning on line 78 in tests/experimental/pydantic_first_class/schema/test_mutation.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic_first_class/schema/test_mutation.py#L78

Added line #L78 was not covered by tests
"1 validation error for CreateUserInput\n"
"name\n"
" String should have at least 2 characters [type=string_too_short, "
"input_value='P', input_type=str]\n"
" For further information visit "
"https://errors.pydantic.dev/2.0.3/v/string_too_short"
)
else:
assert result.errors[0].message == (
"1 validation error for CreateUserInput\nname\n ensure this value has at "
"least 2 characters (type=value_error.any_str.min_length; limit_value=2)"
)


def test_mutation_with_validation_of_nested_model():
@first_class(is_input=True)
class HobbyInput(pydantic.BaseModel):
name: pydantic.constr(min_length=2)

@first_class(is_input=True)
class CreateUserInput(pydantic.BaseModel):
hobby: HobbyInput

class UserModel(pydantic.BaseModel):
name: str

@strawberry.experimental.pydantic.type(UserModel)
class UserType:
name: strawberry.auto

@strawberry.type
class Query:
h: str

@strawberry.type
class Mutation:
@strawberry.mutation
def create_user(self, input: CreateUserInput) -> UserType:
return UserType(name=input.hobby.name)

Check warning on line 117 in tests/experimental/pydantic_first_class/schema/test_mutation.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic_first_class/schema/test_mutation.py#L117

Added line #L117 was not covered by tests

schema = strawberry.Schema(query=Query, mutation=Mutation)

query = """
mutation {
createUser(input: { hobby: { name: "P" } }) {
name
}
}
"""

result = schema.execute_sync(query)

if IS_PYDANTIC_V2:
assert result.errors[0].message == (

Check warning on line 132 in tests/experimental/pydantic_first_class/schema/test_mutation.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic_first_class/schema/test_mutation.py#L132

Added line #L132 was not covered by tests
"1 validation error for HobbyInput\n"
"name\n"
" String should have at least 2 characters [type=string_too_short, "
"input_value='P', input_type=str]\n"
" For further information visit "
"https://errors.pydantic.dev/2.0.3/v/string_too_short"
)

else:
assert result.errors[0].message == (
"1 validation error for HobbyInput\nname\n"
" ensure this value has at least 2 characters "
"(type=value_error.any_str.min_length; limit_value=2)"
)


@pytest.mark.xfail(
reason="""No way to manually handle errors
the validation goes boom in convert_argument, not in the create_user resolver"""
)
def test_mutation_with_validation_and_error_type():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue 2: no way to manually handle validation errors for now

@first_class(is_input=True)
class CreateUserInput(pydantic.BaseModel):
name: pydantic.constr(min_length=2)

@first_class()
class UserType(pydantic.BaseModel):
name: str

@first_class()
class UserError(pydantic.BaseModel):
name: str

@strawberry.type
class Query:
h: str

@strawberry.type
class Mutation:
@strawberry.mutation
def create_user(self, input: CreateUserInput) -> Union[UserType, UserError]:
try:
data = input
except pydantic.ValidationError as e:

Check warning on line 176 in tests/experimental/pydantic_first_class/schema/test_mutation.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic_first_class/schema/test_mutation.py#L174-L176

Added lines #L174 - L176 were not covered by tests
# issue: the error will never be thrown here because the validation
# happens in convert_argument
args: Dict[str, List[str]] = {}

Check warning on line 179 in tests/experimental/pydantic_first_class/schema/test_mutation.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic_first_class/schema/test_mutation.py#L179

Added line #L179 was not covered by tests
for error in e.errors():
field = error["loc"][0] # currently doesn't support nested errors
field_errors = args.get(field, [])
field_errors.append(error["msg"])
args[field] = field_errors
return UserError(**args)

Check warning on line 185 in tests/experimental/pydantic_first_class/schema/test_mutation.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic_first_class/schema/test_mutation.py#L181-L185

Added lines #L181 - L185 were not covered by tests
else:
return UserType(name=data.name)

Check warning on line 187 in tests/experimental/pydantic_first_class/schema/test_mutation.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic_first_class/schema/test_mutation.py#L187

Added line #L187 was not covered by tests

schema = strawberry.Schema(query=Query, mutation=Mutation)

query = """
mutation {
createUser(input: { name: "P" }) {
... on UserType {
name
}
... on UserError {
nameErrors: name
}
}
}
"""

result = schema.execute_sync(query)

assert result.errors is None
assert result.data["createUser"].get("name") is None

Check warning on line 207 in tests/experimental/pydantic_first_class/schema/test_mutation.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic_first_class/schema/test_mutation.py#L207

Added line #L207 was not covered by tests

if IS_PYDANTIC_V2:
assert result.data["createUser"]["nameErrors"] == [

Check warning on line 210 in tests/experimental/pydantic_first_class/schema/test_mutation.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic_first_class/schema/test_mutation.py#L210

Added line #L210 was not covered by tests
("String should have at least 2 characters")
]
else:
assert result.data["createUser"]["nameErrors"] == [

Check warning on line 214 in tests/experimental/pydantic_first_class/schema/test_mutation.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic_first_class/schema/test_mutation.py#L214

Added line #L214 was not covered by tests
("ensure this value has at least 2 characters")
]
Loading
Loading