Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 42 additions & 49 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,24 +703,17 @@ def run(self, mod: ast.Module) -> None:
return
pos = 0
for item in mod.body:
if (
expect_docstring
and isinstance(item, ast.Expr)
and isinstance(item.value, ast.Constant)
and isinstance(item.value.value, str)
):
doc = item.value.value
if self.is_rewrite_disabled(doc):
return
expect_docstring = False
elif (
isinstance(item, ast.ImportFrom)
and item.level == 0
and item.module == "__future__"
):
pass
else:
break
match item:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks interesting - whats the performance impact - i would expect match to be faster but that may be wishful thinking

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately we had mixed results in pylint regarding performance. What do you suggest as a benchmark ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need code fragments that trigger the different cases and then we need a codegen that given a number creates that many instances

Then we can observe/profile

Im not sure whether we should synthesize a ast or just text

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A preliminary benchmark (done by claude) is not very encouraging : https://claude.ai/public/artifacts/3d4158ea-0594-4442-8b4c-975bc5a54ce1 (14% slower on my side, python 3.12 / ubuntu)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems worse than i hoped

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The examples I took where specifically the one where the diff was very favorable (There's a lot of match on raw string that could be done but the performance is worse in this case as there's no index lookup creation), so I also expected better. From what I read it should be a neutral change performance wise except if isinstance were smartly grouped together with and / or to do less checks. Maybe the regression is due to microbenchmarking, maybe match is only a readability change. It's hard to find information that is not slop about this topic. (found this for example : https://discuss.python.org/t/pattern-matching-optimization-comparison-of-values-specified-through/20791)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But comparing bytecode seems like something to potentially use fir understanding our cases

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Marc Mueller dug into the CPython implementation here (pylint specific but still relevant) : pylint-dev/pylint#10544 (comment)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TLDR: the faster you want it to be the shittier it have to look:

- case ast.Expr(value=ast.Constant(value=str(doc))) if expect_docstring:
+ case ast.Expr(value=ast.Constant(value=str() as doc)) if expect_docstring:

... should be faster, etc. (isinstance is very well optimized)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened an issue for CPython: python/cpython#138912

case ast.Expr(value=ast.Constant(value=str() as doc)) if (
expect_docstring
):
if self.is_rewrite_disabled(doc):
return
expect_docstring = False
case ast.ImportFrom(level=0, module="__future__"):
pass
case _:
break
pos += 1
# Special case: for a decorated function, set the lineno to that of the
# first decorator, not the `def`. Issue #4984.
Expand Down Expand Up @@ -1017,20 +1010,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821
self.expl_stmts = fail_inner
# Check if the left operand is a ast.NamedExpr and the value has already been visited
if (
isinstance(v, ast.Compare)
and isinstance(v.left, ast.NamedExpr)
and v.left.target.id
in [
ast_expr.id
for ast_expr in boolop.values[:i]
if hasattr(ast_expr, "id")
]
):
pytest_temp = self.variable()
self.variables_overwrite[self.scope][v.left.target.id] = v.left # type:ignore[assignment]
v.left.target.id = pytest_temp
match v:
# Check if the left operand is an ast.NamedExpr and the value has already been visited
case ast.Compare(
left=ast.NamedExpr(target=ast.Name(id=target_id))
) if target_id in [
e.id for e in boolop.values[:i] if hasattr(e, "id")
]:
pytest_temp = self.variable()
self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment]
# mypy's false positive, we're checking that the 'target' attribute exists.
v.left.target.id = pytest_temp # type:ignore[attr-defined]
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
Expand Down Expand Up @@ -1080,10 +1070,11 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
arg_expls.append(expl)
new_args.append(res)
for keyword in call.keywords:
if isinstance(
keyword.value, ast.Name
) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
keyword.value = self.variables_overwrite[self.scope][keyword.value.id] # type:ignore[assignment]
match keyword.value:
case ast.Name(id=id) if id in self.variables_overwrite.get(
self.scope, {}
):
keyword.value = self.variables_overwrite[self.scope][id] # type:ignore[assignment]
res, expl = self.visit(keyword.value)
new_kwargs.append(ast.keyword(keyword.arg, res))
if keyword.arg:
Expand Down Expand Up @@ -1119,12 +1110,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
if isinstance(
comp.left, ast.Name
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
comp.left = self.variables_overwrite[self.scope][comp.left.id] # type:ignore[assignment]
if isinstance(comp.left, ast.NamedExpr):
self.variables_overwrite[self.scope][comp.left.target.id] = comp.left # type:ignore[assignment]
match comp.left:
case ast.Name(id=name_id) if name_id in self.variables_overwrite.get(
self.scope, {}
):
comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment]
case ast.NamedExpr(target=ast.Name(id=target_id)):
self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment]
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, ast.Compare | ast.BoolOp):
left_expl = f"({left_expl})"
Expand All @@ -1136,13 +1128,14 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
syms: list[ast.expr] = []
results = [left_res]
for i, op, next_operand in it:
if (
isinstance(next_operand, ast.NamedExpr)
and isinstance(left_res, ast.Name)
and next_operand.target.id == left_res.id
):
next_operand.target.id = self.variable()
self.variables_overwrite[self.scope][left_res.id] = next_operand # type:ignore[assignment]
match (next_operand, left_res):
case (
ast.NamedExpr(target=ast.Name(id=target_id)),
ast.Name(id=name_id),
) if target_id == name_id:
next_operand.target.id = self.variable()
self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment]

next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, ast.Compare | ast.BoolOp):
next_expl = f"({next_expl})"
Expand Down
20 changes: 19 additions & 1 deletion testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1552,7 +1552,9 @@ def test_simple_failure():
result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"])


class TestIssue10743:
class TestAssertionRewriteWalrusOperator:
"""See #10743"""

def test_assertion_walrus_operator(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
Expand Down Expand Up @@ -1719,6 +1721,22 @@ def test_walrus_operator_not_override_value():
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_namedexpr_compare_left_overwrite(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_namedexpr_compare_left_overwrite():
a = "Hello"
b = "World"
c = "Test"
assert (a := b) == c and (a := "Test") == "Test"
"""
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert ('World' == 'Test'*"])


class TestIssue11028:
def test_assertion_walrus_operator_in_operand(self, pytester: Pytester) -> None:
Expand Down
Loading