Skip to content

Commit

Permalink
feat: Better support for dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Feb 29, 2024
1 parent b065d15 commit e44ebf5
Show file tree
Hide file tree
Showing 5 changed files with 391 additions and 12 deletions.
18 changes: 10 additions & 8 deletions src/griffe/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"Parameter(name={self.name!r}, annotation={self.annotation!r}, kind={self.kind!r}, default={self.default!r})"

def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Parameter):
return NotImplemented
return (
self.name == __value.name
and self.annotation == __value.annotation
and self.kind == __value.kind
and self.default == __value.default
)

@property
def required(self) -> bool:
"""Whether this parameter is required."""
Expand Down Expand Up @@ -1561,14 +1571,6 @@ def parameters(self) -> Parameters:
try:
return self.all_members["__init__"].parameters # type: ignore[union-attr]
except KeyError:
if "dataclass" in self.labels:
return Parameters(
*[
Parameter(attr.name, annotation=attr.annotation, default=attr.value)
for attr in self.attributes.values()
if "property" not in attr.labels
],
)
return Parameters()

@cached_property
Expand Down
10 changes: 8 additions & 2 deletions src/griffe/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ def _expr_as_dict(expression: Expr, **kwargs: Any) -> dict[str, Any]:


# TODO: Merge in decorators once Python 3.9 is dropped.
dataclass_opts: dict[str, bool] = {}
if sys.version_info >= (3, 10):
dataclass_opts["slots"] = True
dataclass_opts = {"slots": True}
else:
dataclass_opts = {}


@dataclass
Expand Down Expand Up @@ -243,6 +244,11 @@ class ExprCall(Expr):
arguments: Sequence[str | Expr]
"""Passed arguments."""

@property
def canonical_path(self) -> str:
"""The canonical path of this subscript's left part."""
return self.function.canonical_path

def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]: # noqa: D102
yield from _yield(self.function, flat=flat)
yield "("
Expand Down
4 changes: 4 additions & 0 deletions src/griffe/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def __init__(self, *extensions: ExtensionType) -> None:
self._extensions: list[Extension] = []
self.add(*extensions)

# TODO: Deprecate and remove at some point?
self.add(*_load_extension("dataclasses"))

def add(self, *extensions: ExtensionType) -> None:
"""Add extensions to this container.
Expand Down Expand Up @@ -347,6 +350,7 @@ def call(self, event: str, **kwargs: Any) -> None:

builtin_extensions: set[str] = {
"hybrid",
"dataclasses",
}


Expand Down
201 changes: 201 additions & 0 deletions src/griffe/extensions/dataclasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Built-in extension adding support for dataclasses.
This extension re-creates `__init__` methods of dataclasses
during static analysis.
"""

from __future__ import annotations

import ast
from functools import lru_cache
from typing import Any, cast

from griffe.dataclasses import Attribute, Class, Decorator, Function, Module, Parameter, Parameters
from griffe.enumerations import ParameterKind
from griffe.expressions import (
Expr,
ExprAttribute,
ExprCall,
ExprDict,
)
from griffe.extensions.base import Extension


def _dataclass_decorator(decorators: list[Decorator]) -> Expr | None:
for decorator in decorators:
if isinstance(decorator.value, Expr) and decorator.value.canonical_path == "dataclasses.dataclass":
return decorator.value
return None


def _expr_args(expr: Expr) -> dict[str, str | Expr]:
args = {}
if isinstance(expr, ExprCall):
for argument in expr.arguments:
try:
args[argument.name] = argument.value
except AttributeError:
# Argument is a unpacked variable.
var = expr.function.parent.modules_collection[argument.value.canonical_path]
args.update(_expr_args(var.value))
elif isinstance(expr, ExprDict):
args.update({ast.literal_eval(key): value for key, value in zip(expr.keys, expr.values)})
return args # type: ignore[union-attr]


def _dataclass_arguments(decorators: list[Decorator]) -> dict[str, Any]:
if (expr := _dataclass_decorator(decorators)) and isinstance(expr, ExprCall):
return _expr_args(expr)
return {}


def _field_arguments(attribute: Attribute) -> dict[str, Any]:
if attribute.value:
value = attribute.value
if isinstance(value, ExprAttribute):
value = value.last
if isinstance(value, ExprCall) and value.canonical_path == "dataclasses.field":
return _expr_args(value)
return {}


@lru_cache(maxsize=None)
def _dataclass_parameters(class_: Class) -> list[Parameter]:
# Fetch `@dataclass` arguments if any.
dec_args = _dataclass_arguments(class_.decorators)

# Parameters not added to `__init__`, return empty list.
if dec_args.get("init") == "False":
return []

# All parameters marked as keyword-only.
kw_only = dec_args.get("kw_only") == "True"

# Iterate on current attributes to find parameters.
parameters = []
for member in class_.members.values():
if member.is_attribute:
member = cast(Attribute, member)

# Start of keyword-only parameters.
if isinstance(member.annotation, Expr) and member.annotation.canonical_path == "dataclasses.KW_ONLY":
kw_only = True
continue

# Fetch `field` arguments if any.
field_args = _field_arguments(member)

# Parameter not added to `__init__`, skip it.
if field_args.get("init") == "False":
continue

# Determine parameter kind.
kind = (
ParameterKind.keyword_only
if kw_only or field_args.get("kw_only") == "True"
else ParameterKind.positional_or_keyword
)

# Determine parameter default.
if "default_factory" in field_args:
default = ExprCall(function=field_args["default_factory"], arguments=[])
else:
default = field_args.get("default", None if field_args else member.value)

# Add parameter to the list.
parameters.append(
Parameter(
member.name,
annotation=member.annotation,
kind=kind,
default=default,
),
)

return parameters


def _reorder_parameters(parameters: list[Parameter]) -> list[Parameter]:
# De-duplicate, overwriting previous parameters.
params_dict = {param.name: param for param in parameters}

# Re-order, putting positional-only in front and keyword-only at the end.
pos_only = []
pos_kw = []
kw_only = []
for param in params_dict.values():
if param.kind is ParameterKind.positional_only:
pos_only.append(param)
elif param.kind is ParameterKind.keyword_only:
kw_only.append(param)
else:
pos_kw.append(param)
return pos_only + pos_kw + kw_only


def _set_dataclass_init(class_: Class) -> None:
parameters = []

# If the class already has an `__init__` method, skip it.
if "__init__" in class_.members:
return

# Retrieve parameters from all parent dataclasses.
for parent in reversed(class_.mro()):
if _dataclass_decorator(parent.decorators):
parameters.extend(_dataclass_parameters(parent))
# At least one parent dataclass makes the current class a dataclass:
# that's how `dataclasses.is_dataclass` works.
class_.labels.add("dataclass")

# If the class is not decorated with `@dataclass`, skip it.
if not _dataclass_decorator(class_.decorators):
return

# Add current class parameters.
parameters.extend(_dataclass_parameters(class_))

# Create `__init__` method with re-ordered parameters.
init = Function(
"__init__",
lineno=0,
endlineno=0,
parent=class_,
parameters=Parameters(
Parameter(name="self", annotation=None, kind=ParameterKind.positional_or_keyword, default=None),
*_reorder_parameters(parameters),
),
returns="None",
)
class_.set_member("__init__", init)


def _apply_recursively(mod_cls: Module | Class, processed: set[str]) -> None:
if mod_cls.canonical_path in processed:
return
processed.add(mod_cls.canonical_path)
if isinstance(mod_cls, Class):
_set_dataclass_init(mod_cls)
for member in mod_cls.members.values():
if not member.is_alias and member.is_class:
_apply_recursively(member, processed)
elif isinstance(mod_cls, Module):
for member in mod_cls.members.values():
if not member.is_alias and (member.is_module or member.is_class):
_apply_recursively(member, processed)


class DataclassesExtension(Extension):
"""Built-in extension adding support for dataclasses.
This extension creates `__init__` methods of dataclasses
if they don't already exist.
"""

def on_package_loaded(self, *, pkg: Module) -> None:
"""Hook for loaded packages.
Parameters:
pkg: The loaded package.
"""
_apply_recursively(pkg, set())
Loading

0 comments on commit e44ebf5

Please sign in to comment.