Skip to content

Commit

Permalink
feat(pydantic): Add mypy extension, typing stubs for conversion of py…
Browse files Browse the repository at this point in the history
…dantic models (strawberry-graphql#1544)

* fix(pydantic): Add mypy extension for to_pydantic and from_pydantic for converted models, add stub for IDEs

* fix(pydantic): Fix for mypy==0.93.-

* fix(pydantic): Backwards compat for mypy < 0.93

* fix(pydantic): Add RELEASE.md

* fix(pydantic): Refactor A to become PydanticModel

* fix(pydantic): Add type: ignore

* fix(pydantic): Rename add_method_to_class to add_static_method_to_class so future developers use it only for static methods

* Apply suggestions from code review

Co-authored-by: ignormies <bryce.beagle@gmail.com>

* fix(pydantic): Use Protocol[PydanticModel] instead of Generic[PydanticModel], use from __future__ import annotations

* retrigger checks

* fix(compat): import Protocol from typing_extensions instead of typing for Python <= 3.7 compat

* fix(codecov): ignore strawberry/ext/mypy_plugin.py in codecov

* fix(codecov): ignore setup.py

* Edit comment

* Update RELEASE.md

Co-authored-by: Patrick Arminio <patrick.arminio@gmail.com>

* fix(review): Move StrawberryTypeFromPydantic to new file. Add example to RELEASE.md.

* fix(codecov): Silence codecov up about moving StrawberryTypeFromPydantic type to new file.

* fix(release): Edit RELEASE.md

* fix(typehint): fix from_pydantic type hint

Co-authored-by: James Chua <james@leadiq.com>
Co-authored-by: ignormies <bryce.beagle@gmail.com>
Co-authored-by: Patrick Arminio <patrick.arminio@gmail.com>
  • Loading branch information
4 people authored Jan 22, 2022
1 parent b021481 commit 86e9e77
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 11 deletions.
5 changes: 5 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@ comment:
layout: "header, diff"
behavior: default
require_changes: no

ignore:
- "strawberry/ext/mypy_plugin.py"
- "setup.py"
- "strawberry/experimental/pydantic/conversion_types.py"
20 changes: 20 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
Release type: minor

Adds `to_pydantic` and `from_pydantic` type hints for IDE support.

Adds mypy extension support as well.

```python
from pydantic import BaseModel
import strawberry

class UserPydantic(BaseModel):
age: int

@strawberry.experimental.pydantic.type(UserPydantic)
class UserStrawberry:
age: strawberry.auto

reveal_type(UserStrawberry(age=123).to_pydantic())
```
Mypy will infer the type as "UserPydantic". Previously it would be "Any"
36 changes: 36 additions & 0 deletions strawberry/experimental/pydantic/conversion_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

from typing import Any, Dict, TypeVar

from pydantic import BaseModel
from typing_extensions import Protocol

from strawberry.types.types import TypeDefinition


PydanticModel = TypeVar("PydanticModel", bound=BaseModel)


class StrawberryTypeFromPydantic(Protocol[PydanticModel]):
"""This class does not exist in runtime.
It only makes the methods below visible for IDEs"""

def __init__(self, **kwargs):
...

@staticmethod
def from_pydantic(
instance: PydanticModel, extra: Dict[str, Any] = None
) -> StrawberryTypeFromPydantic[PydanticModel]:
...

def to_pydantic(self) -> PydanticModel:
...

@property
def _type_definition(self) -> TypeDefinition:
...

@property
def _pydantic_type(self) -> PydanticModel:
...
33 changes: 27 additions & 6 deletions strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from __future__ import annotations

import builtins
import dataclasses
import warnings
from functools import partial
from typing import Any, Dict, List, Optional, Sequence, Type, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
cast,
)

from pydantic import BaseModel
from pydantic.fields import ModelField
Expand Down Expand Up @@ -74,8 +86,15 @@ def get_type_for_field(field: ModelField):
return type_


if TYPE_CHECKING:
from strawberry.experimental.pydantic.conversion_types import (
PydanticModel,
StrawberryTypeFromPydantic,
)


def type(
model: Type[BaseModel],
model: Type[PydanticModel],
*,
fields: Optional[List[str]] = None,
name: Optional[str] = None,
Expand All @@ -84,8 +103,8 @@ def type(
description: Optional[str] = None,
directives: Optional[Sequence[StrawberrySchemaDirective]] = (),
all_fields: bool = False,
):
def wrap(cls):
) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]:
def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]:
model_fields = model.__fields__
fields_set = set(fields) if fields else set([])

Expand Down Expand Up @@ -178,12 +197,14 @@ def is_type_of(cls: Type, obj: Any, _info: GraphQLResolveInfo) -> bool:
model._strawberry_type = cls # type: ignore
cls._pydantic_type = model # type: ignore

def from_pydantic(instance: Any, extra: Dict[str, Any] = None) -> Any:
def from_pydantic(
instance: PydanticModel, extra: Dict[str, Any] = None
) -> StrawberryTypeFromPydantic[PydanticModel]:
return convert_pydantic_model_to_strawberry_class(
cls=cls, model_instance=instance, extra=extra
)

def to_pydantic(self) -> Any:
def to_pydantic(self) -> PydanticModel:
instance_kwargs = dataclasses.asdict(self)

return model(**instance_kwargs)
Expand Down
117 changes: 112 additions & 5 deletions strawberry/ext/mypy_plugin.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
from decimal import Decimal
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast

from typing_extensions import Final

from mypy.nodes import (
ARG_OPT,
ARG_POS,
ARG_STAR2,
GDEF,
MDEF,
Argument,
AssignmentStmt,
Block,
CallExpr,
CastExpr,
ClassDef,
Context,
Expression,
FuncDef,
IndexExpr,
MemberExpr,
NameExpr,
PassStmt,
PlaceholderNode,
RefExpr,
SymbolTableNode,
Expand All @@ -28,14 +34,16 @@
)
from mypy.plugin import (
AnalyzeTypeContext,
CheckerPluginInterface,
ClassDefContext,
DynamicClassDefContext,
FunctionContext,
Plugin,
SemanticAnalyzerPluginInterface,
)
from mypy.plugins.common import _get_decorator_bool_argument, add_method
from mypy.plugins.common import _get_argument, _get_decorator_bool_argument, add_method
from mypy.plugins.dataclasses import DataclassAttribute
from mypy.semanal_shared import set_callable_name
from mypy.server.trigger import make_wildcard_trigger
from mypy.types import (
AnyType,
Expand All @@ -48,6 +56,8 @@
UnionType,
get_proper_type,
)
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name


# Backwards compatible with the removal of `TypeVarDef` in mypy 0.920.
Expand Down Expand Up @@ -245,17 +255,114 @@ def enum_hook(ctx: DynamicClassDefContext) -> None:
)


def strawberry_pydantic_class_callback(ctx: ClassDefContext):
def add_static_method_to_class(
api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface],
cls: ClassDef,
name: str,
args: List[Argument],
return_type: Type,
tvar_def: Optional[TypeVarType] = None,
) -> None:
"""Adds a static method
Edited add_method_to_class to incorporate static method logic
https://github.com/python/mypy/blob/9c05d3d19/mypy/plugins/common.py
"""
info = cls.info

# First remove any previously generated methods with the same name
# to avoid clashes and problems in the semantic analyzer.
if name in info.names:
sym = info.names[name]
if sym.plugin_generated and isinstance(sym.node, FuncDef):
cls.defs.body.remove(sym.node)

# For compat with mypy < 0.93
if MypyVersion.VERSION < Decimal("0.93"):
function_type = api.named_type("__builtins__.function") # type: ignore
else:
if isinstance(api, SemanticAnalyzerPluginInterface):
function_type = api.named_type("builtins.function")
else:
function_type = api.named_generic_type("builtins.function", [])

arg_types, arg_names, arg_kinds = [], [], []
for arg in args:
assert arg.type_annotation, "All arguments must be fully typed."
arg_types.append(arg.type_annotation)
arg_names.append(arg.variable.name)
arg_kinds.append(arg.kind)

signature = CallableType(
arg_types, arg_kinds, arg_names, return_type, function_type
)
if tvar_def:
signature.variables = [tvar_def]

func = FuncDef(name, args, Block([PassStmt()]))

func.is_static = True
func.info = info
func.type = set_callable_name(signature, func)
func._fullname = f"{info.fullname}.{name}"
func.line = info.line

# NOTE: we would like the plugin generated node to dominate, but we still
# need to keep any existing definitions so they get semantically analyzed.
if name in info.names:
# Get a nice unique name instead.
r_name = get_unique_redefinition_name(name, info.names)
info.names[r_name] = info.names[name]

info.names[name] = SymbolTableNode(MDEF, func, plugin_generated=True)
info.defn.defs.body.append(func)


def strawberry_pydantic_class_callback(ctx: ClassDefContext) -> None:
# in future we want to have a proper pydantic plugin, but for now
# let's fallback to any, some resources are here:
# let's fallback to **kwargs for __init__, some resources are here:
# https://github.com/samuelcolvin/pydantic/blob/master/pydantic/mypy.py
# >>> model_index = ctx.cls.decorators[0].arg_names.index("model")
# >>> model_name = ctx.cls.decorators[0].args[model_index].name

# >>> model_type = ctx.api.named_type("UserModel")
# >>> model_type = ctx.api.lookup(model_name, Context())

ctx.cls.info.fallback_to_any = True
model_expression = _get_argument(call=ctx.reason, name="model") # type: ignore
if model_expression is None:
ctx.api.fail("model argument in decorator failed to be parsed", ctx.reason)

else:
# Add __init__
init_args = [
Argument(Var("kwargs"), AnyType(TypeOfAny.explicit), None, ARG_STAR2)
]
add_method(ctx, "__init__", init_args, NoneType())

model_type = _get_type_for_expr(model_expression, ctx.api)

# Add to_pydantic
add_method(
ctx,
"to_pydantic",
args=[],
return_type=model_type,
)

# Add from_pydantic
model_argument = Argument(
variable=Var(name="instance", type=model_type),
type_annotation=model_type,
initializer=None,
kind=ARG_OPT,
)

add_static_method_to_class(
ctx.api,
ctx.cls,
name="from_pydantic",
args=[model_argument],
return_type=fill_typevars(ctx.cls.info),
)


def is_dataclasses_field_or_strawberry_field(expr: Expression) -> bool:
Expand Down
86 changes: 86 additions & 0 deletions tests/mypy/test_pydantic.decorators.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@

- case: test_converted_pydantic_init_any_kwargs
main: |
from pydantic import BaseModel
import strawberry
class UserPydantic(BaseModel):
age: int
@strawberry.experimental.pydantic.type(UserPydantic)
class UserStrawberry:
age: strawberry.auto
reveal_type(UserStrawberry)
reveal_type(UserStrawberry(age=123))
out: |
main:11: note: Revealed type is "def (**kwargs: Any) -> main.UserStrawberry"
main:12: note: Revealed type is "main.UserStrawberry"
- case: test_converted_to_pydantic
main: |
from pydantic import BaseModel
import strawberry
class UserPydantic(BaseModel):
age: int
@strawberry.experimental.pydantic.type(UserPydantic)
class UserStrawberry:
age: strawberry.auto
reveal_type(UserStrawberry(age=123).to_pydantic())
out: |
main:11: note: Revealed type is "main.UserPydantic"
- case: test_converted_from_pydantic
main: |
from pydantic import BaseModel
import strawberry
class UserPydantic(BaseModel):
age: int
@strawberry.experimental.pydantic.type(UserPydantic)
class UserStrawberry:
age: strawberry.auto
reveal_type(UserStrawberry.from_pydantic(UserPydantic(age=123)))
out: |
main:11: note: Revealed type is "main.UserStrawberry"
- case: test_converted_from_pydantic_raise_error_wrong_instance
main: |
from pydantic import BaseModel
import strawberry
class UserPydantic(BaseModel):
age: int
@strawberry.experimental.pydantic.type(UserPydantic)
class UserStrawberry:
age: strawberry.auto
class AnotherModel(BaseModel):
age: int
UserStrawberry.from_pydantic(AnotherModel(age=123))
out: |
main:14: error: Argument 1 to "from_pydantic" of "UserStrawberry" has incompatible type "AnotherModel"; expected "UserPydantic"
- case: test_converted_from_pydantic_chained
main: |
from pydantic import BaseModel
import strawberry
class UserPydantic(BaseModel):
age: int
@strawberry.experimental.pydantic.type(UserPydantic)
class UserStrawberry:
age: strawberry.auto
reveal_type(UserStrawberry.from_pydantic(UserPydantic(age=123)).to_pydantic())
out: |
main:11: note: Revealed type is "main.UserPydantic"

0 comments on commit 86e9e77

Please sign in to comment.