Skip to content

Commit

Permalink
Fix AST safety check false negative (#4270)
Browse files Browse the repository at this point in the history
Fixes #4268

Previously we would allow whitespace changes in all strings, now
only in docstrings.

Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
  • Loading branch information
JelleZijlstra and hauntsaninja authored Mar 10, 2024
1 parent f03ee11 commit 6af7d11
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 27 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
<!-- Changes that affect Black's stable style -->

- Don't move comments along with delimiters, which could cause crashes (#4248)
- Strengthen AST safety check to catch more unsafe changes to strings. Previous versions
of Black would incorrectly format the contents of certain unusual f-strings containing
nested strings with the same quote type. Now, Black will crash on such strings until
support for the new f-string syntax is implemented. (#4270)

### Preview style

Expand Down
15 changes: 10 additions & 5 deletions src/black/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,13 @@
syms,
)
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
from black.parsing import InvalidInput # noqa F401
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
from black.parsing import ( # noqa F401
ASTSafetyError,
InvalidInput,
lib2to3_parse,
parse_ast,
stringify_ast,
)
from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges
from black.report import Changed, NothingChanged, Report
from black.trans import iter_fexpr_spans
Expand Down Expand Up @@ -1511,7 +1516,7 @@ def assert_equivalent(src: str, dst: str) -> None:
try:
src_ast = parse_ast(src)
except Exception as exc:
raise AssertionError(
raise ASTSafetyError(
"cannot use --safe with this file; failed to parse source file AST: "
f"{exc}\n"
"This could be caused by running Black with an older Python version "
Expand All @@ -1522,7 +1527,7 @@ def assert_equivalent(src: str, dst: str) -> None:
dst_ast = parse_ast(dst)
except Exception as exc:
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
raise AssertionError(
raise ASTSafetyError(
f"INTERNAL ERROR: Black produced invalid code: {exc}. "
"Please report a bug on https://github.com/psf/black/issues. "
f"This invalid output might be helpful: {log}"
Expand All @@ -1532,7 +1537,7 @@ def assert_equivalent(src: str, dst: str) -> None:
dst_ast_str = "\n".join(stringify_ast(dst_ast))
if src_ast_str != dst_ast_str:
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError(
raise ASTSafetyError(
"INTERNAL ERROR: Black produced code that is not equivalent to the"
" source. Please report a bug on "
f"https://github.com/psf/black/issues. This diff might be helpful: {log}"
Expand Down
42 changes: 34 additions & 8 deletions src/black/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ def lib2to3_unparse(node: Node) -> str:
return code


class ASTSafetyError(Exception):

This comment has been minimized.

Copy link
@spyoungtech

spyoungtech May 11, 2024

This should probably inherit from AssertionError since that is what was being raised where this replaces it.

"""Raised when Black's generated code is not equivalent to the old AST."""


def _parse_single_version(
src: str, version: Tuple[int, int], *, type_comments: bool
) -> ast.AST:
Expand Down Expand Up @@ -154,9 +158,20 @@ def _normalize(lineend: str, value: str) -> str:
return normalized.strip()


def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
def stringify_ast(node: ast.AST) -> Iterator[str]:
"""Simple visitor generating strings to compare ASTs by content."""
return _stringify_ast(node, [])


def _stringify_ast_with_new_parent(
node: ast.AST, parent_stack: List[ast.AST], new_parent: ast.AST
) -> Iterator[str]:
parent_stack.append(new_parent)
yield from _stringify_ast(node, parent_stack)
parent_stack.pop()


def _stringify_ast(node: ast.AST, parent_stack: List[ast.AST]) -> Iterator[str]:
if (
isinstance(node, ast.Constant)
and isinstance(node.value, str)
Expand All @@ -167,7 +182,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
# over the kind
node.kind = None

yield f"{' ' * depth}{node.__class__.__name__}("
yield f"{' ' * len(parent_stack)}{node.__class__.__name__}("

for field in sorted(node._fields): # noqa: F402
# TypeIgnore has only one field 'lineno' which breaks this comparison
Expand All @@ -179,7 +194,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
except AttributeError:
continue

yield f"{' ' * (depth + 1)}{field}="
yield f"{' ' * (len(parent_stack) + 1)}{field}="

if isinstance(value, list):
for item in value:
Expand All @@ -191,20 +206,28 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
and isinstance(item, ast.Tuple)
):
for elt in item.elts:
yield from stringify_ast(elt, depth + 2)
yield from _stringify_ast_with_new_parent(
elt, parent_stack, node
)

elif isinstance(item, ast.AST):
yield from stringify_ast(item, depth + 2)
yield from _stringify_ast_with_new_parent(item, parent_stack, node)

elif isinstance(value, ast.AST):
yield from stringify_ast(value, depth + 2)
yield from _stringify_ast_with_new_parent(value, parent_stack, node)

else:
normalized: object
if (
isinstance(node, ast.Constant)
and field == "value"
and isinstance(value, str)
and len(parent_stack) >= 2
and isinstance(parent_stack[-1], ast.Expr)
and isinstance(
parent_stack[-2],
(ast.FunctionDef, ast.AsyncFunctionDef, ast.Module, ast.ClassDef),
)
):
# Constant strings may be indented across newlines, if they are
# docstrings; fold spaces after newlines when comparing. Similarly,
Expand All @@ -215,6 +238,9 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
normalized = value.rstrip()
else:
normalized = value
yield f"{' ' * (depth + 2)}{normalized!r}, # {value.__class__.__name__}"
yield (
f"{' ' * (len(parent_stack) + 1)}{normalized!r}, #"
f" {value.__class__.__name__}"
)

yield f"{' ' * depth}) # /{node.__class__.__name__}"
yield f"{' ' * len(parent_stack)}) # /{node.__class__.__name__}"
122 changes: 108 additions & 14 deletions tests/test_black.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from black.debug import DebugVisitor
from black.mode import Mode, Preview
from black.output import color_diff, diff
from black.parsing import ASTSafetyError
from black.report import Report

# Import other test classes
Expand Down Expand Up @@ -1473,10 +1474,6 @@ def test_normalize_line_endings(self) -> None:
ff(test_file, write_back=black.WriteBack.YES)
self.assertEqual(test_file.read_bytes(), expected)

def test_assert_equivalent_different_asts(self) -> None:
with self.assertRaises(AssertionError):
black.assert_equivalent("{}", "None")

def test_root_logger_not_used_directly(self) -> None:
def fail(*args: Any, **kwargs: Any) -> None:
self.fail("Record created with root logger")
Expand Down Expand Up @@ -1962,16 +1959,6 @@ def test_for_handled_unexpected_eof_error(self) -> None:

exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")

def test_equivalency_ast_parse_failure_includes_error(self) -> None:
with pytest.raises(AssertionError) as err:
black.assert_equivalent("a«»a = 1", "a«»a = 1")

err.match("--safe")
# Unfortunately the SyntaxError message has changed in newer versions so we
# can't match it directly.
err.match("invalid character")
err.match(r"\(<unknown>, line 1\)")

def test_line_ranges_with_code_option(self) -> None:
code = textwrap.dedent("""\
if a == b:
Expand Down Expand Up @@ -2822,6 +2809,113 @@ def test_format_file_contents(self) -> None:
black.format_file_contents("x = 1\n", fast=True, mode=black.Mode())


class TestASTSafety(BlackBaseTestCase):
def check_ast_equivalence(
self, source: str, dest: str, *, should_fail: bool = False
) -> None:
# If we get a failure, make sure it's not because the code itself
# is invalid, since that will also cause assert_equivalent() to throw
# ASTSafetyError.
source = textwrap.dedent(source)
dest = textwrap.dedent(dest)
black.parse_ast(source)
black.parse_ast(dest)
if should_fail:
with self.assertRaises(ASTSafetyError):
black.assert_equivalent(source, dest)
else:
black.assert_equivalent(source, dest)

def test_assert_equivalent_basic(self) -> None:
self.check_ast_equivalence("{}", "None", should_fail=True)
self.check_ast_equivalence("1+2", "1 + 2")
self.check_ast_equivalence("hi # comment", "hi")

def test_assert_equivalent_del(self) -> None:
self.check_ast_equivalence("del (a, b)", "del a, b")

def test_assert_equivalent_strings(self) -> None:
self.check_ast_equivalence('x = "x"', 'x = " x "', should_fail=True)
self.check_ast_equivalence(
'''
"""docstring """
''',
'''
"""docstring"""
''',
)
self.check_ast_equivalence(
'''
"""docstring """
''',
'''
"""ddocstring"""
''',
should_fail=True,
)
self.check_ast_equivalence(
'''
class A:
"""
docstring
"""
''',
'''
class A:
"""docstring"""
''',
)
self.check_ast_equivalence(
"""
def f():
" docstring "
""",
'''
def f():
"""docstring"""
''',
)
self.check_ast_equivalence(
"""
async def f():
" docstring "
""",
'''
async def f():
"""docstring"""
''',
)

def test_assert_equivalent_fstring(self) -> None:
major, minor = sys.version_info[:2]
if major < 3 or (major == 3 and minor < 12):
pytest.skip("relies on 3.12+ syntax")
# https://github.com/psf/black/issues/4268
self.check_ast_equivalence(
"""print(f"{"|".join([a,b,c])}")""",
"""print(f"{" | ".join([a,b,c])}")""",
should_fail=True,
)
self.check_ast_equivalence(
"""print(f"{"|".join(['a','b','c'])}")""",
"""print(f"{" | ".join(['a','b','c'])}")""",
should_fail=True,
)

def test_equivalency_ast_parse_failure_includes_error(self) -> None:
with pytest.raises(ASTSafetyError) as err:
black.assert_equivalent("a«»a = 1", "a«»a = 1")

err.match("--safe")
# Unfortunately the SyntaxError message has changed in newer versions so we
# can't match it directly.
err.match("invalid character")
err.match(r"\(<unknown>, line 1\)")


try:
with open(black.__file__, "r", encoding="utf-8") as _bf:
black_source_lines = _bf.readlines()
Expand Down

0 comments on commit 6af7d11

Please sign in to comment.