Skip to content

Commit 7e4cb12

Browse files
authored
Merge pull request #1 from abravalheri/ast-annotation
Refactor `StaticModule.__getattr__` and fix `ast.AnnAssign` edge case
2 parents 15af535 + 0b1d090 commit 7e4cb12

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

setuptools/config/expand.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,24 +66,23 @@ def __init__(self, name: str, spec: ModuleSpec):
6666
vars(self).update(locals())
6767
del self.self
6868

69+
def _find_assignments(self) -> Iterator[Tuple[ast.AST, ast.AST]]:
70+
for statement in self.module.body:
71+
if isinstance(statement, ast.Assign):
72+
yield from ((target, statement.value) for target in statement.targets)
73+
elif isinstance(statement, ast.AnnAssign) and statement.value:
74+
yield (statement.target, statement.value)
75+
6976
def __getattr__(self, attr):
7077
"""Attempt to load an attribute "statically", via :func:`ast.literal_eval`."""
7178
try:
72-
for statement in self.module.body:
73-
if isinstance(statement, ast.Assign):
74-
targets = statement.targets
75-
value = statement.value
76-
elif isinstance(statement, ast.AnnAssign):
77-
targets = [statement.target]
78-
value = statement.value
79-
else:
80-
continue
81-
for target in targets:
82-
if isinstance(target, ast.Name) and target.id == attr:
83-
return ast.literal_eval(value)
79+
return next(
80+
ast.literal_eval(value)
81+
for target, value in self._find_assignments()
82+
if isinstance(target, ast.Name) and target.id == attr
83+
)
8484
except Exception as e:
8585
raise AttributeError(f"{self.name} has no attribute {attr}") from e
86-
raise AttributeError(f"{self.name} has no attribute {attr}")
8786

8887

8988
def glob_relative(

setuptools/tests/config/test_expand.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,17 @@ def test_read_attr(self, tmp_path, monkeypatch):
8585
values = expand.read_attr('lib.mod.VALUES', {'lib': 'pkg/sub'}, tmp_path)
8686
assert values['c'] == (0, 1, 1)
8787

88-
def test_read_annotated_attr(self, tmp_path):
88+
@pytest.mark.parametrize(
89+
"example",
90+
[
91+
"VERSION: str\nVERSION = '0.1.1'\nraise SystemExit(1)\n",
92+
"VERSION: str = '0.1.1'\nraise SystemExit(1)\n",
93+
]
94+
)
95+
def test_read_annotated_attr(self, tmp_path, example):
8996
files = {
9097
"pkg/__init__.py": "",
91-
"pkg/sub/__init__.py": (
92-
"VERSION: str = '0.1.1'\n"
93-
"raise SystemExit(1)\n"
94-
),
98+
"pkg/sub/__init__.py": example,
9599
}
96100
write_files(files, tmp_path)
97101
# Make sure this attribute can be read statically

0 commit comments

Comments
 (0)