Skip to content

Commit

Permalink
Make dataclass attr collection no longer worst-case quadratic (#13539)
Browse files Browse the repository at this point in the history
While working on #13531, I noticed that DataclassTransformer's
`collect_attributes` method was doing basically this:

```python
all_attrs = []
known_attrs = set()
for stmt in current_class:
    attr = convert_stmt_to_dataclass_attr(stmt)
    all_attrs.append(attr)
    known_attrs.add(attr.name)

for info in current_class.mro[1:-1]:
    if info is not a dataclass:
        continue

    super_attrs = []
    for attr in info.dataclass_attributes:
        # ...snip...
        if attr.name not in known_attrs:
            super_attrs.append(attr)
            known_attrs.add(attr.name)
        elif all_attrs:
            for other_attr in all_attrs:
                if other_attr.name == attr.name:
                    all_attrs.remove(other_attr)
                    super_attrs.append(other_attr)
                    break
    all_attrs = super_attrs + all_attrs
    all_attrs.sort(key=lambda a: a.kw_only)

validate all_attrs
```

Constantly searching through and removing items from `all_attrs`,
then prepending the superclass attrs, results in worst-case
quadratic behavior in the edge case where a subtype overrides
every attribute defined in its supertype.

This edge case is admittedly pretty unlikely to happen, but I wanted
to clean up the code a bit by reversing the order in which we process
everything so we naturally record attrs in the correct order.

One quirk of the old implementation I found is that it does not sort
the attrs list (moving kw-only attrs to the end) when none of the
supertypes are dataclasses. I tried changing this logic so we
unconditionally sort the list, but this actually broke a few of our
tests for some reason. I didn't want to get too deep in the weeds,
so I opted to preserve this behavior.
  • Loading branch information
Michael0x2a authored Aug 31, 2022
1 parent 840a310 commit 2857736
Showing 1 changed file with 65 additions and 72 deletions.
137 changes: 65 additions & 72 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,51 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
Return None if some dataclass base class hasn't been processed
yet and thus we'll need to ask for another pass.
"""
# First, collect attributes belonging to the current class.
ctx = self._ctx
cls = self._ctx.cls
attrs: list[DataclassAttribute] = []
known_attrs: set[str] = set()

# First, collect attributes belonging to any class in the MRO, ignoring duplicates.
#
# We iterate through the MRO in reverse because attrs defined in the parent must appear
# earlier in the attributes list than attrs defined in the child. See:
# https://docs.python.org/3/library/dataclasses.html#inheritance
#
# However, we also want attributes defined in the subtype to override ones defined
# in the parent. We can implement this via a dict without disrupting the attr order
# because dicts preserve insertion order in Python 3.7+.
found_attrs: dict[str, DataclassAttribute] = {}
found_dataclass_supertype = False
for info in reversed(cls.info.mro[1:-1]):
if "dataclass_tag" in info.metadata and "dataclass" not in info.metadata:
# We haven't processed the base class yet. Need another pass.
return None
if "dataclass" not in info.metadata:
continue

# Each class depends on the set of attributes in its dataclass ancestors.
ctx.api.add_plugin_dependency(make_wildcard_trigger(info.fullname))
found_dataclass_supertype = True

for data in info.metadata["dataclass"]["attributes"]:
name: str = data["name"]

attr = DataclassAttribute.deserialize(info, data, ctx.api)
# TODO: We shouldn't be performing type operations during the main
# semantic analysis pass, since some TypeInfo attributes might
# still be in flux. This should be performed in a later phase.
with state.strict_optional_set(ctx.api.options.strict_optional):
attr.expand_typevar_from_subtype(ctx.cls.info)
found_attrs[name] = attr

sym_node = cls.info.names.get(name)
if sym_node and sym_node.node and not isinstance(sym_node.node, Var):
ctx.api.fail(
"Dataclass attribute may only be overridden by another attribute",
sym_node.node,
)

# Second, collect attributes belonging to the current class.
current_attr_names: set[str] = set()
kw_only = _get_decorator_bool_argument(ctx, "kw_only", False)
for stmt in cls.defs.body:
# Any assignment that doesn't use the new type declaration
Expand Down Expand Up @@ -435,8 +475,6 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
if field_kw_only_param is not None:
is_kw_only = bool(ctx.api.parse_bool(field_kw_only_param))

known_attrs.add(lhs.name)

if sym.type is None and node.is_final and node.is_inferred:
# This is a special case, assignment like x: Final = 42 is classified
# annotated above, but mypy strips the `Final` turning it into x = 42.
Expand All @@ -453,75 +491,27 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
)
node.type = AnyType(TypeOfAny.from_error)

attrs.append(
DataclassAttribute(
name=lhs.name,
is_in_init=is_in_init,
is_init_var=is_init_var,
has_default=has_default,
line=stmt.line,
column=stmt.column,
type=sym.type,
info=cls.info,
kw_only=is_kw_only,
)
current_attr_names.add(lhs.name)
found_attrs[lhs.name] = DataclassAttribute(
name=lhs.name,
is_in_init=is_in_init,
is_init_var=is_init_var,
has_default=has_default,
line=stmt.line,
column=stmt.column,
type=sym.type,
info=cls.info,
kw_only=is_kw_only,
)

# Next, collect attributes belonging to any class in the MRO
# as long as those attributes weren't already collected. This
# makes it possible to overwrite attributes in subclasses.
# copy() because we potentially modify all_attrs below and if this code requires debugging
# we'll have unmodified attrs laying around.
all_attrs = attrs.copy()
known_super_attrs = set()
for info in cls.info.mro[1:-1]:
if "dataclass_tag" in info.metadata and "dataclass" not in info.metadata:
# We haven't processed the base class yet. Need another pass.
return None
if "dataclass" not in info.metadata:
continue

super_attrs = []
# Each class depends on the set of attributes in its dataclass ancestors.
ctx.api.add_plugin_dependency(make_wildcard_trigger(info.fullname))

for data in info.metadata["dataclass"]["attributes"]:
name: str = data["name"]
if name not in known_attrs:
attr = DataclassAttribute.deserialize(info, data, ctx.api)
# TODO: We shouldn't be performing type operations during the main
# semantic analysis pass, since some TypeInfo attributes might
# still be in flux. This should be performed in a later phase.
with state.strict_optional_set(ctx.api.options.strict_optional):
attr.expand_typevar_from_subtype(ctx.cls.info)
known_attrs.add(name)
known_super_attrs.add(name)
super_attrs.append(attr)
elif all_attrs:
# How early in the attribute list an attribute appears is determined by the
# reverse MRO, not simply MRO.
# See https://docs.python.org/3/library/dataclasses.html#inheritance for
# details.
for attr in all_attrs:
if attr.name == name:
all_attrs.remove(attr)
super_attrs.append(attr)
break
all_attrs = super_attrs + all_attrs
all_attrs = list(found_attrs.values())
if found_dataclass_supertype:
all_attrs.sort(key=lambda a: a.kw_only)

for known_super_attr_name in known_super_attrs:
sym_node = cls.info.names.get(known_super_attr_name)
if sym_node and sym_node.node and not isinstance(sym_node.node, Var):
ctx.api.fail(
"Dataclass attribute may only be overridden by another attribute",
sym_node.node,
)

# Ensure that arguments without a default don't follow
# arguments that have a default.
# Third, ensure that arguments without a default don't follow
# arguments that have a default and that the KW_ONLY sentinel
# is only provided once.
found_default = False
# Ensure that the KW_ONLY sentinel is only provided once
found_kw_sentinel = False
for attr in all_attrs:
# If we find any attribute that is_in_init, not kw_only, and that
Expand All @@ -530,17 +520,20 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
if found_default and attr.is_in_init and not attr.has_default and not attr.kw_only:
# If the issue comes from merging different classes, report it
# at the class definition point.
context = Context(line=attr.line, column=attr.column) if attr in attrs else ctx.cls
context: Context = ctx.cls
if attr.name in current_attr_names:
context = Context(line=attr.line, column=attr.column)
ctx.api.fail(
"Attributes without a default cannot follow attributes with one", context
)

found_default = found_default or (attr.has_default and attr.is_in_init)
if found_kw_sentinel and self._is_kw_only_type(attr.type):
context = Context(line=attr.line, column=attr.column) if attr in attrs else ctx.cls
context = ctx.cls
if attr.name in current_attr_names:
context = Context(line=attr.line, column=attr.column)
ctx.api.fail("There may not be more than one field with the KW_ONLY type", context)
found_kw_sentinel = found_kw_sentinel or self._is_kw_only_type(attr.type)

return all_attrs

def _freeze(self, attributes: list[DataclassAttribute]) -> None:
Expand Down

0 comments on commit 2857736

Please sign in to comment.