diff --git a/.codecov.yml b/.codecov.yml index 39c158742f..80201e96b8 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -16,3 +16,8 @@ comment: layout: "header, diff" behavior: default require_changes: no + +ignore: + - "strawberry/ext/mypy_plugin.py" + - "setup.py" + - "strawberry/experimental/pydantic/conversion_types.py" diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..d278980572 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,20 @@ +Release type: minor + +Adds `to_pydantic` and `from_pydantic` type hints for IDE support. + +Adds mypy extension support as well. + +```python +from pydantic import BaseModel +import strawberry + +class UserPydantic(BaseModel): + age: int + +@strawberry.experimental.pydantic.type(UserPydantic) +class UserStrawberry: + age: strawberry.auto + +reveal_type(UserStrawberry(age=123).to_pydantic()) +``` +Mypy will infer the type as "UserPydantic". Previously it would be "Any" diff --git a/strawberry/experimental/pydantic/conversion_types.py b/strawberry/experimental/pydantic/conversion_types.py new file mode 100644 index 0000000000..4db1afb46b --- /dev/null +++ b/strawberry/experimental/pydantic/conversion_types.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import Any, Dict, TypeVar + +from pydantic import BaseModel +from typing_extensions import Protocol + +from strawberry.types.types import TypeDefinition + + +PydanticModel = TypeVar("PydanticModel", bound=BaseModel) + + +class StrawberryTypeFromPydantic(Protocol[PydanticModel]): + """This class does not exist in runtime. + It only makes the methods below visible for IDEs""" + + def __init__(self, **kwargs): + ... + + @staticmethod + def from_pydantic( + instance: PydanticModel, extra: Dict[str, Any] = None + ) -> StrawberryTypeFromPydantic[PydanticModel]: + ... + + def to_pydantic(self) -> PydanticModel: + ... + + @property + def _type_definition(self) -> TypeDefinition: + ... + + @property + def _pydantic_type(self) -> PydanticModel: + ... diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index 5964dba251..b365e433ac 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -1,8 +1,20 @@ +from __future__ import annotations + import builtins import dataclasses import warnings from functools import partial -from typing import Any, Dict, List, Optional, Sequence, Type, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Type, + cast, +) from pydantic import BaseModel from pydantic.fields import ModelField @@ -74,8 +86,15 @@ def get_type_for_field(field: ModelField): return type_ +if TYPE_CHECKING: + from strawberry.experimental.pydantic.conversion_types import ( + PydanticModel, + StrawberryTypeFromPydantic, + ) + + def type( - model: Type[BaseModel], + model: Type[PydanticModel], *, fields: Optional[List[str]] = None, name: Optional[str] = None, @@ -84,8 +103,8 @@ def type( description: Optional[str] = None, directives: Optional[Sequence[StrawberrySchemaDirective]] = (), all_fields: bool = False, -): - def wrap(cls): +) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: + def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: model_fields = model.__fields__ fields_set = set(fields) if fields else set([]) @@ -178,12 +197,14 @@ def is_type_of(cls: Type, obj: Any, _info: GraphQLResolveInfo) -> bool: model._strawberry_type = cls # type: ignore cls._pydantic_type = model # type: ignore - def from_pydantic(instance: Any, extra: Dict[str, Any] = None) -> Any: + def from_pydantic( + instance: PydanticModel, extra: Dict[str, Any] = None + ) -> StrawberryTypeFromPydantic[PydanticModel]: return convert_pydantic_model_to_strawberry_class( cls=cls, model_instance=instance, extra=extra ) - def to_pydantic(self) -> Any: + def to_pydantic(self) -> PydanticModel: instance_kwargs = dataclasses.asdict(self) return model(**instance_kwargs) diff --git a/strawberry/ext/mypy_plugin.py b/strawberry/ext/mypy_plugin.py index 34e5e9573c..0249a8b97d 100644 --- a/strawberry/ext/mypy_plugin.py +++ b/strawberry/ext/mypy_plugin.py @@ -1,21 +1,27 @@ from decimal import Decimal -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast from typing_extensions import Final from mypy.nodes import ( + ARG_OPT, ARG_POS, + ARG_STAR2, GDEF, MDEF, Argument, AssignmentStmt, + Block, CallExpr, CastExpr, + ClassDef, Context, Expression, + FuncDef, IndexExpr, MemberExpr, NameExpr, + PassStmt, PlaceholderNode, RefExpr, SymbolTableNode, @@ -28,14 +34,16 @@ ) from mypy.plugin import ( AnalyzeTypeContext, + CheckerPluginInterface, ClassDefContext, DynamicClassDefContext, FunctionContext, Plugin, SemanticAnalyzerPluginInterface, ) -from mypy.plugins.common import _get_decorator_bool_argument, add_method +from mypy.plugins.common import _get_argument, _get_decorator_bool_argument, add_method from mypy.plugins.dataclasses import DataclassAttribute +from mypy.semanal_shared import set_callable_name from mypy.server.trigger import make_wildcard_trigger from mypy.types import ( AnyType, @@ -48,6 +56,8 @@ UnionType, get_proper_type, ) +from mypy.typevars import fill_typevars +from mypy.util import get_unique_redefinition_name # Backwards compatible with the removal of `TypeVarDef` in mypy 0.920. @@ -245,9 +255,71 @@ def enum_hook(ctx: DynamicClassDefContext) -> None: ) -def strawberry_pydantic_class_callback(ctx: ClassDefContext): +def add_static_method_to_class( + api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface], + cls: ClassDef, + name: str, + args: List[Argument], + return_type: Type, + tvar_def: Optional[TypeVarType] = None, +) -> None: + """Adds a static method + Edited add_method_to_class to incorporate static method logic + https://github.com/python/mypy/blob/9c05d3d19/mypy/plugins/common.py + """ + info = cls.info + + # First remove any previously generated methods with the same name + # to avoid clashes and problems in the semantic analyzer. + if name in info.names: + sym = info.names[name] + if sym.plugin_generated and isinstance(sym.node, FuncDef): + cls.defs.body.remove(sym.node) + + # For compat with mypy < 0.93 + if MypyVersion.VERSION < Decimal("0.93"): + function_type = api.named_type("__builtins__.function") # type: ignore + else: + if isinstance(api, SemanticAnalyzerPluginInterface): + function_type = api.named_type("builtins.function") + else: + function_type = api.named_generic_type("builtins.function", []) + + arg_types, arg_names, arg_kinds = [], [], [] + for arg in args: + assert arg.type_annotation, "All arguments must be fully typed." + arg_types.append(arg.type_annotation) + arg_names.append(arg.variable.name) + arg_kinds.append(arg.kind) + + signature = CallableType( + arg_types, arg_kinds, arg_names, return_type, function_type + ) + if tvar_def: + signature.variables = [tvar_def] + + func = FuncDef(name, args, Block([PassStmt()])) + + func.is_static = True + func.info = info + func.type = set_callable_name(signature, func) + func._fullname = f"{info.fullname}.{name}" + func.line = info.line + + # NOTE: we would like the plugin generated node to dominate, but we still + # need to keep any existing definitions so they get semantically analyzed. + if name in info.names: + # Get a nice unique name instead. + r_name = get_unique_redefinition_name(name, info.names) + info.names[r_name] = info.names[name] + + info.names[name] = SymbolTableNode(MDEF, func, plugin_generated=True) + info.defn.defs.body.append(func) + + +def strawberry_pydantic_class_callback(ctx: ClassDefContext) -> None: # in future we want to have a proper pydantic plugin, but for now - # let's fallback to any, some resources are here: + # let's fallback to **kwargs for __init__, some resources are here: # https://github.com/samuelcolvin/pydantic/blob/master/pydantic/mypy.py # >>> model_index = ctx.cls.decorators[0].arg_names.index("model") # >>> model_name = ctx.cls.decorators[0].args[model_index].name @@ -255,7 +327,42 @@ def strawberry_pydantic_class_callback(ctx: ClassDefContext): # >>> model_type = ctx.api.named_type("UserModel") # >>> model_type = ctx.api.lookup(model_name, Context()) - ctx.cls.info.fallback_to_any = True + model_expression = _get_argument(call=ctx.reason, name="model") # type: ignore + if model_expression is None: + ctx.api.fail("model argument in decorator failed to be parsed", ctx.reason) + + else: + # Add __init__ + init_args = [ + Argument(Var("kwargs"), AnyType(TypeOfAny.explicit), None, ARG_STAR2) + ] + add_method(ctx, "__init__", init_args, NoneType()) + + model_type = _get_type_for_expr(model_expression, ctx.api) + + # Add to_pydantic + add_method( + ctx, + "to_pydantic", + args=[], + return_type=model_type, + ) + + # Add from_pydantic + model_argument = Argument( + variable=Var(name="instance", type=model_type), + type_annotation=model_type, + initializer=None, + kind=ARG_OPT, + ) + + add_static_method_to_class( + ctx.api, + ctx.cls, + name="from_pydantic", + args=[model_argument], + return_type=fill_typevars(ctx.cls.info), + ) def is_dataclasses_field_or_strawberry_field(expr: Expression) -> bool: diff --git a/tests/mypy/test_pydantic.decorators.yml b/tests/mypy/test_pydantic.decorators.yml new file mode 100644 index 0000000000..f7dd24d3b5 --- /dev/null +++ b/tests/mypy/test_pydantic.decorators.yml @@ -0,0 +1,86 @@ + +- case: test_converted_pydantic_init_any_kwargs + main: | + from pydantic import BaseModel + import strawberry + + class UserPydantic(BaseModel): + age: int + + @strawberry.experimental.pydantic.type(UserPydantic) + class UserStrawberry: + age: strawberry.auto + + reveal_type(UserStrawberry) + reveal_type(UserStrawberry(age=123)) + out: | + main:11: note: Revealed type is "def (**kwargs: Any) -> main.UserStrawberry" + main:12: note: Revealed type is "main.UserStrawberry" + +- case: test_converted_to_pydantic + main: | + from pydantic import BaseModel + import strawberry + + class UserPydantic(BaseModel): + age: int + + @strawberry.experimental.pydantic.type(UserPydantic) + class UserStrawberry: + age: strawberry.auto + + reveal_type(UserStrawberry(age=123).to_pydantic()) + out: | + main:11: note: Revealed type is "main.UserPydantic" + +- case: test_converted_from_pydantic + main: | + from pydantic import BaseModel + import strawberry + + class UserPydantic(BaseModel): + age: int + + @strawberry.experimental.pydantic.type(UserPydantic) + class UserStrawberry: + age: strawberry.auto + + reveal_type(UserStrawberry.from_pydantic(UserPydantic(age=123))) + out: | + main:11: note: Revealed type is "main.UserStrawberry" + + +- case: test_converted_from_pydantic_raise_error_wrong_instance + main: | + from pydantic import BaseModel + import strawberry + + class UserPydantic(BaseModel): + age: int + + @strawberry.experimental.pydantic.type(UserPydantic) + class UserStrawberry: + age: strawberry.auto + + class AnotherModel(BaseModel): + age: int + + UserStrawberry.from_pydantic(AnotherModel(age=123)) + out: | + main:14: error: Argument 1 to "from_pydantic" of "UserStrawberry" has incompatible type "AnotherModel"; expected "UserPydantic" + +- case: test_converted_from_pydantic_chained + main: | + from pydantic import BaseModel + import strawberry + + class UserPydantic(BaseModel): + age: int + + @strawberry.experimental.pydantic.type(UserPydantic) + class UserStrawberry: + age: strawberry.auto + + reveal_type(UserStrawberry.from_pydantic(UserPydantic(age=123)).to_pydantic()) + out: | + main:11: note: Revealed type is "main.UserPydantic"