Skip to content

Commit

Permalink
Merge pull request #1 from abravalheri/ast-annotation
Browse files Browse the repository at this point in the history
Refactor `StaticModule.__getattr__` and fix `ast.AnnAssign` edge case
  • Loading branch information
karlotness authored Jun 19, 2022
2 parents 15af535 + 0b1d090 commit 7e4cb12
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
25 changes: 12 additions & 13 deletions setuptools/config/expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,23 @@ def __init__(self, name: str, spec: ModuleSpec):
vars(self).update(locals())
del self.self

def _find_assignments(self) -> Iterator[Tuple[ast.AST, ast.AST]]:
for statement in self.module.body:
if isinstance(statement, ast.Assign):
yield from ((target, statement.value) for target in statement.targets)
elif isinstance(statement, ast.AnnAssign) and statement.value:
yield (statement.target, statement.value)

def __getattr__(self, attr):
"""Attempt to load an attribute "statically", via :func:`ast.literal_eval`."""
try:
for statement in self.module.body:
if isinstance(statement, ast.Assign):
targets = statement.targets
value = statement.value
elif isinstance(statement, ast.AnnAssign):
targets = [statement.target]
value = statement.value
else:
continue
for target in targets:
if isinstance(target, ast.Name) and target.id == attr:
return ast.literal_eval(value)
return next(
ast.literal_eval(value)
for target, value in self._find_assignments()
if isinstance(target, ast.Name) and target.id == attr
)
except Exception as e:
raise AttributeError(f"{self.name} has no attribute {attr}") from e
raise AttributeError(f"{self.name} has no attribute {attr}")


def glob_relative(
Expand Down
14 changes: 9 additions & 5 deletions setuptools/tests/config/test_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,17 @@ def test_read_attr(self, tmp_path, monkeypatch):
values = expand.read_attr('lib.mod.VALUES', {'lib': 'pkg/sub'}, tmp_path)
assert values['c'] == (0, 1, 1)

def test_read_annotated_attr(self, tmp_path):
@pytest.mark.parametrize(
"example",
[
"VERSION: str\nVERSION = '0.1.1'\nraise SystemExit(1)\n",
"VERSION: str = '0.1.1'\nraise SystemExit(1)\n",
]
)
def test_read_annotated_attr(self, tmp_path, example):
files = {
"pkg/__init__.py": "",
"pkg/sub/__init__.py": (
"VERSION: str = '0.1.1'\n"
"raise SystemExit(1)\n"
),
"pkg/sub/__init__.py": example,
}
write_files(files, tmp_path)
# Make sure this attribute can be read statically
Expand Down

0 comments on commit 7e4cb12

Please sign in to comment.