-
-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add signature for dataclasses.replace #14849
Changes from 38 commits
6dbaf9c
4dcbe44
7ed3741
89257b5
32b1d47
1f08816
c456a5f
8118d29
789cb2b
9f0974c
a37e406
367c0e9
0e84c4f
9cfc081
7b907cf
985db60
15dbb7b
3227fde
2dbf249
b32881c
26056a4
40315b7
c005895
d71bc21
d914b94
2cb6dee
5735726
d38897e
04f0ee3
f402b86
306c3f3
9b491f5
283fe3d
29780e9
9c43ab6
94024ba
70240c4
99bf973
8fe75a7
bea50e8
cd35951
d71a7c0
957744e
6abf6ff
4c4fc94
be4a290
ee0ae21
3f656f8
65d6e89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,8 @@ | |
|
||
from mypy import errorcodes, message_registry | ||
from mypy.expandtype import expand_type, expand_type_by_instance | ||
from mypy.meet import meet_types | ||
from mypy.messages import format_type_bare | ||
from mypy.nodes import ( | ||
ARG_NAMED, | ||
ARG_NAMED_OPT, | ||
|
@@ -38,7 +40,7 @@ | |
TypeVarExpr, | ||
Var, | ||
) | ||
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface | ||
from mypy.plugin import ClassDefContext, FunctionSigContext, SemanticAnalyzerPluginInterface | ||
from mypy.plugins.common import ( | ||
_get_callee_type, | ||
_get_decorator_bool_argument, | ||
|
@@ -56,10 +58,13 @@ | |
Instance, | ||
LiteralType, | ||
NoneType, | ||
ProperType, | ||
TupleType, | ||
Type, | ||
TypeOfAny, | ||
TypeVarType, | ||
UninhabitedType, | ||
UnionType, | ||
get_proper_type, | ||
) | ||
from mypy.typevars import fill_typevars | ||
|
@@ -76,6 +81,7 @@ | |
frozen_default=False, | ||
field_specifiers=("dataclasses.Field", "dataclasses.field"), | ||
) | ||
_INTERNAL_REPLACE_SYM_NAME = "__mypy-replace" | ||
|
||
|
||
class DataclassAttribute: | ||
|
@@ -335,13 +341,47 @@ def transform(self) -> bool: | |
|
||
self._add_dataclass_fields_magic_attribute() | ||
|
||
if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES: | ||
self._add_internal_replace_method(attributes) | ||
|
||
info.metadata["dataclass"] = { | ||
"attributes": [attr.serialize() for attr in attributes], | ||
"frozen": decorator_arguments["frozen"], | ||
} | ||
|
||
return True | ||
|
||
def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) -> None:
    """
    Record the signature that 'dataclasses.replace(...)' should have for this dataclass.

    Each attribute becomes a keyword-only parameter of a synthetic callable, which is
    stored in the class symbol table under _INTERNAL_REPLACE_SYM_NAME so that it can be
    looked up later whenever 'dataclasses.replace' is called for this dataclass.
    """
    cls_info = self._cls.info

    param_names: list[str | None] = []
    param_types: list[Type] = []
    param_kinds = []
    for attribute in attributes:
        expanded = attribute.expand_type(cls_info)
        assert expanded is not None
        # Only an InitVar without a default is a required keyword; every other
        # attribute may be omitted from the call.
        if attribute.is_init_var and not attribute.has_default:
            kind = ARG_NAMED
        else:
            kind = ARG_NAMED_OPT
        param_types.append(expanded)
        param_kinds.append(kind)
        param_names.append(attribute.name)

    # The return type is a placeholder (NoneType); this callable only carries the
    # keyword parameters.
    replace_sig = CallableType(
        arg_types=param_types,
        arg_kinds=param_kinds,
        arg_names=param_names,
        ret_type=NoneType(),
        fallback=self._api.named_type("builtins.function"),
    )

    cls_info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
        kind=MDEF, node=FuncDef(typ=replace_sig), plugin_generated=True
    )
|
||
def add_slots( | ||
self, info: TypeInfo, attributes: list[DataclassAttribute], *, correct_version: bool | ||
) -> None: | ||
|
@@ -884,3 +924,128 @@ def _has_direct_dataclass_transform_metaclass(info: TypeInfo) -> bool: | |
info.declared_metaclass is not None | ||
and info.declared_metaclass.type.dataclass_transform_spec is not None | ||
) | ||
|
||
|
||
def _fail_not_dataclass(ctx: FunctionSigContext, t: Type, parent_t: Type) -> None:
    """
    Report that argument 1 of "replace" is not (bound to) a dataclass.

    't' is the offending type. 'parent_t' is the type it was found in; when it is the
    very same object as 't', the message names only 't', otherwise it names 'parent_t'
    and points at 't' as the offending item. Type variables get their own wording.
    """
    t_name = format_type_bare(t, ctx.api.options)
    is_type_var = isinstance(t, TypeVarType)

    if parent_t is t:
        if is_type_var:
            msg = f'Argument 1 to "replace" has a variable type "{t_name}" not bound to a dataclass'
        else:
            msg = f'Argument 1 to "replace" has incompatible type "{t_name}"; expected a dataclass'
    else:
        pt_name = format_type_bare(parent_t, ctx.api.options)
        if is_type_var:
            msg = f'Argument 1 to "replace" has type "{pt_name}" whose item "{t_name}" is not bound to a dataclass'
        else:
            msg = f'Argument 1 to "replace" has incompatible type "{pt_name}" whose item "{t_name}" is not a dataclass'

    ctx.api.fail(msg, ctx.context)
|
||
|
||
def _get_expanded_dataclasses_fields(
    ctx: FunctionSigContext, typ: ProperType, display_typ: ProperType, parent_typ: ProperType
) -> list[CallableType] | None:
    """
    Collect the stashed 'replace' signatures of every dataclass that 'typ' can be.

    Unions are explored item by item and type variables through their upper bound.
    For generic classes, the signature is expanded by the given instance.

    'display_typ' and 'parent_typ' are only forwarded to the error reporting.

    Returns None if the type contains Any or a non-dataclass; in the latter case an
    error is reported as well.
    """
    if isinstance(typ, AnyType):
        return None

    if isinstance(typ, UnionType):
        collected: list[CallableType] = []
        all_items_ok = True
        for member in typ.relevant_items():
            proper_member = get_proper_type(member)
            member_sigs = _get_expanded_dataclasses_fields(
                ctx, proper_member, proper_member, parent_typ
            )
            if member_sigs is None:
                # Remember the failure, but keep iterating to emit all errors.
                all_items_ok = False
            elif all_items_ok:
                collected += member_sigs
        return collected if all_items_ok else None

    if isinstance(typ, TypeVarType):
        bound = get_proper_type(typ.upper_bound)
        return _get_expanded_dataclasses_fields(ctx, bound, display_typ, parent_typ)

    if isinstance(typ, Instance):
        stashed = typ.type.get_method(_INTERNAL_REPLACE_SYM_NAME)
        if stashed is not None:
            stashed_sig = get_proper_type(stashed.type)
            assert isinstance(stashed_sig, CallableType)
            return [expand_type_by_instance(stashed_sig, typ)]

    # Either not an Instance at all, or an Instance without the stashed signature.
    _fail_not_dataclass(ctx, display_typ, parent_typ)
    return None
|
||
|
||
def _meet_replace_sigs(sigs: list[CallableType]) -> CallableType: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks like a hack. Is this because the plugin system expects a There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ooh, that's an interesting idea. It didn't cross my mind I could return a union of callables and let mypy handle it in There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. P.S. I'll be happy to simplify this code, but do you reckon we should first change the plugin API and then rework this PR, or merge as-is, change the plugin API and then refactor? (A bit worried the devil's in the details and that maybe it's not as easy as it sounds, or not equivalent.) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it is better to do it in a follow-up PR, but leave a TODO or open an issue (or both). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
""" | ||
Produces the lowest bound of the 'replace' signatures of multiple dataclasses. | ||
""" | ||
args = { | ||
name: (typ, kind) | ||
for name, typ, kind in zip(sigs[0].arg_names, sigs[0].arg_types, sigs[0].arg_kinds) | ||
} | ||
|
||
for sig in sigs[1:]: | ||
sig_args = { | ||
name: (typ, kind) | ||
for name, typ, kind in zip(sig.arg_names, sig.arg_types, sig.arg_kinds) | ||
} | ||
for name in (*args.keys(), *sig_args.keys()): | ||
sig_typ, sig_kind = args.get(name, (UninhabitedType(), ARG_NAMED_OPT)) | ||
sig2_typ, sig2_kind = sig_args.get(name, (UninhabitedType(), ARG_NAMED_OPT)) | ||
args[name] = ( | ||
meet_types(sig_typ, sig2_typ), | ||
ARG_NAMED_OPT if sig_kind == sig2_kind == ARG_NAMED_OPT else ARG_NAMED, | ||
) | ||
|
||
return sigs[0].copy_modified( | ||
arg_names=list(args.keys()), | ||
arg_types=[typ for typ, _ in args.values()], | ||
arg_kinds=[kind for _, kind in args.values()], | ||
) | ||
|
||
|
||
def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType: | ||
""" | ||
Returns a signature for the 'dataclasses.replace' function that's dependent on the type | ||
of the first positional argument. | ||
""" | ||
if len(ctx.args) != 2: | ||
# Ideally the name and context should be callee's, but we don't have it in FunctionSigContext. | ||
ctx.api.fail(f'"{ctx.default_signature.name}" has unexpected type annotation', ctx.context) | ||
return ctx.default_signature | ||
|
||
if len(ctx.args[0]) != 1: | ||
return ctx.default_signature # leave it to the type checker to complain | ||
|
||
obj_arg = ctx.args[0][0] | ||
|
||
# <hack> | ||
from mypy.checker import TypeChecker | ||
|
||
assert isinstance(ctx.api, TypeChecker) | ||
obj_type = ctx.api.expr_checker.accept(obj_arg) | ||
# </hack> | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. FWIW this kind of hack is quite common (I myself used it several times, for the old sqlalchemy plugin, and internally). I would propose to instead add another plugin hook, e.g. cc @JukkaL what do you think? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This proposal still stands. Do we have an issue for this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
||
obj_type = get_proper_type(obj_type) | ||
inst_type_str = format_type_bare(obj_type, ctx.api.options) | ||
|
||
replace_sigs = _get_expanded_dataclasses_fields(ctx, obj_type, obj_type, obj_type) | ||
if replace_sigs is None: | ||
return ctx.default_signature | ||
replace_sig = _meet_replace_sigs(replace_sigs) | ||
|
||
return replace_sig.copy_modified( | ||
arg_names=[None, *replace_sig.arg_names], | ||
arg_kinds=[ARG_POS, *replace_sig.arg_kinds], | ||
arg_types=[obj_type, *replace_sig.arg_types], | ||
ret_type=obj_type, | ||
fallback=ctx.default_signature.fallback, | ||
name=f"{ctx.default_signature.name} of {inst_type_str}", | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to call
get_proper_type()
here; since we know we added it, you can just add assert isinstance(..., ProperType)
instead. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It ends up looking awkward like this:
Isn't it simpler to do what I did before, even if taking a few more cycles?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it actually looks fine to have two asserts. It has two additional benefits:
get_proper_type()
in the code, so people will hopefully cargo-cult them less (we had a bunch of weird crashes caused by blindly copying some get_proper_type()
calls).