Skip to content

Commit

Permalink
fix: Don't turn items annotated as InitVar into dataclass members
Browse files Browse the repository at this point in the history
PR-252: #252
  • Loading branch information
has2k1 authored Mar 12, 2024
1 parent c88b484 commit 6835ea3
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/griffe/extensions/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,22 @@ def _set_dataclass_init(class_: Class) -> None:
class_.set_member("__init__", init)


def _del_members_annotated_as_initvar(class_: Class) -> None:
# Definitions annotated as InitVar are not class members
attributes = [member for member in class_.members.values() if isinstance(member, Attribute)]
for attribute in attributes:
if isinstance(attribute.annotation, Expr) and attribute.annotation.canonical_path == "dataclasses.InitVar":
class_.del_member(attribute.name)


def _apply_recursively(mod_cls: Module | Class, processed: set[str]) -> None:
if mod_cls.canonical_path in processed:
return
processed.add(mod_cls.canonical_path)
if isinstance(mod_cls, Class):
if "__init__" not in mod_cls.members:
_set_dataclass_init(mod_cls)
_del_members_annotated_as_initvar(mod_cls)
for member in mod_cls.members.values():
if not member.is_alias and member.is_class:
_apply_recursively(member, processed) # type: ignore[arg-type]
Expand Down
33 changes: 33 additions & 0 deletions tests/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,36 @@ class Reordered(Base):
assert [p.name for p in params_base] == ["self", "a", "b"]
assert [p.name for p in params_reordered] == ["self", "b", "c", "a"]
assert str(params_reordered["b"].annotation) == "float"


def test_parameters_annotated_as_initvar() -> None:
"""Don't return InitVar annotated fields as class members.
But if __init__ is defined, InitVar has no effect.
"""
code = """
from dataclasses import dataclass, InitVar
@dataclass
class PointA:
x: float
y: float
z: InitVar[float]
@dataclass
class PointB:
x: float
y: float
z: InitVar[float]
def __init__(self, r: float): ...
"""

with temporary_visited_package("package", {"__init__.py": code}) as module:
point_a = module["PointA"]
assert ["self", "x", "y", "z"] == [p.name for p in point_a.parameters]
assert ["x", "y", "__init__"] == list(point_a.members)

point_b = module["PointB"]
assert ["self", "r"] == [p.name for p in point_b.parameters]
assert ["x", "y", "z", "__init__"] == list(point_b.members)

0 comments on commit 6835ea3

Please sign in to comment.