Skip to content

Commit

Permalink
pythongh-126835: Move constant unaryop & binop folding to CFG (python…
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframAlph authored Feb 21, 2025
1 parent d88677a commit 38642bf
Show file tree
Hide file tree
Showing 6 changed files with 1,058 additions and 444 deletions.
198 changes: 80 additions & 118 deletions Lib/test/test_ast/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,17 @@ def test_optimization_levels__debug__(self):
self.assertEqual(res.body[0].value.id, expected)

def test_optimization_levels_const_folding(self):
folded = ('Expr', (1, 0, 1, 5), ('Constant', (1, 0, 1, 5), 3, None))
not_folded = ('Expr', (1, 0, 1, 5),
('BinOp', (1, 0, 1, 5),
('Constant', (1, 0, 1, 1), 1, None),
('Add',),
('Constant', (1, 4, 1, 5), 2, None)))
folded = ('Expr', (1, 0, 1, 6), ('Constant', (1, 0, 1, 6), (1, 2), None))
not_folded = ('Expr', (1, 0, 1, 6),
('Tuple', (1, 0, 1, 6),
[('Constant', (1, 1, 1, 2), 1, None),
('Constant', (1, 4, 1, 5), 2, None)], ('Load',)))

cases = [(-1, not_folded), (0, not_folded), (1, folded), (2, folded)]
for (optval, expected) in cases:
with self.subTest(optval=optval):
tree1 = ast.parse("1 + 2", optimize=optval)
tree2 = ast.parse(ast.parse("1 + 2"), optimize=optval)
tree1 = ast.parse("(1, 2)", optimize=optval)
tree2 = ast.parse(ast.parse("(1, 2)"), optimize=optval)
for tree in [tree1, tree2]:
res = to_tuple(tree.body[0])
self.assertEqual(res, expected)
Expand Down Expand Up @@ -3089,27 +3088,6 @@ def test_cli_file_input(self):


class ASTOptimiziationTests(unittest.TestCase):
binop = {
"+": ast.Add(),
"-": ast.Sub(),
"*": ast.Mult(),
"/": ast.Div(),
"%": ast.Mod(),
"<<": ast.LShift(),
">>": ast.RShift(),
"|": ast.BitOr(),
"^": ast.BitXor(),
"&": ast.BitAnd(),
"//": ast.FloorDiv(),
"**": ast.Pow(),
}

unaryop = {
"~": ast.Invert(),
"+": ast.UAdd(),
"-": ast.USub(),
}

def wrap_expr(self, expr):
return ast.Module(body=[ast.Expr(value=expr)])

Expand Down Expand Up @@ -3141,83 +3119,6 @@ def assert_ast(self, code, non_optimized_target, optimized_target):
f"{ast.dump(optimized_tree)}",
)

def create_binop(self, operand, left=ast.Constant(1), right=ast.Constant(1)):
return ast.BinOp(left=left, op=self.binop[operand], right=right)

def test_folding_binop(self):
code = "1 %s 1"
operators = self.binop.keys()

for op in operators:
result_code = code % op
non_optimized_target = self.wrap_expr(self.create_binop(op))
optimized_target = self.wrap_expr(ast.Constant(value=eval(result_code)))

with self.subTest(
result_code=result_code,
non_optimized_target=non_optimized_target,
optimized_target=optimized_target
):
self.assert_ast(result_code, non_optimized_target, optimized_target)

# Multiplication of constant tuples must be folded
code = "(1,) * 3"
non_optimized_target = self.wrap_expr(self.create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
optimized_target = self.wrap_expr(ast.Constant(eval(code)))

self.assert_ast(code, non_optimized_target, optimized_target)

def test_folding_unaryop(self):
code = "%s1"
operators = self.unaryop.keys()

def create_unaryop(operand):
return ast.UnaryOp(op=self.unaryop[operand], operand=ast.Constant(1))

for op in operators:
result_code = code % op
non_optimized_target = self.wrap_expr(create_unaryop(op))
optimized_target = self.wrap_expr(ast.Constant(eval(result_code)))

with self.subTest(
result_code=result_code,
non_optimized_target=non_optimized_target,
optimized_target=optimized_target
):
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_not(self):
code = "not (1 %s (1,))"
operators = {
"in": ast.In(),
"is": ast.Is(),
}
opt_operators = {
"is": ast.IsNot(),
"in": ast.NotIn(),
}

def create_notop(operand):
return ast.UnaryOp(op=ast.Not(), operand=ast.Compare(
left=ast.Constant(value=1),
ops=[operators[operand]],
comparators=[ast.Tuple(elts=[ast.Constant(value=1)])]
))

for op in operators.keys():
result_code = code % op
non_optimized_target = self.wrap_expr(create_notop(op))
optimized_target = self.wrap_expr(
ast.Compare(left=ast.Constant(1), ops=[opt_operators[op]], comparators=[ast.Constant(value=(1,))])
)

with self.subTest(
result_code=result_code,
non_optimized_target=non_optimized_target,
optimized_target=optimized_target
):
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_format(self):
code = "'%s' % (a,)"

Expand Down Expand Up @@ -3247,9 +3148,9 @@ def test_folding_tuple(self):
self.assert_ast(code, non_optimized_target, optimized_target)

def test_folding_type_param_in_function_def(self):
code = "def foo[%s = 1 + 1](): pass"
code = "def foo[%s = (1, 2)](): pass"

unoptimized_binop = self.create_binop("+")
unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
Expand All @@ -3263,23 +3164,23 @@ def test_folding_type_param_in_function_def(self):
name='foo',
args=ast.arguments(),
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=ast.Constant(2))]
type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))]
)
)
non_optimized_target = self.wrap_statement(
ast.FunctionDef(
name='foo',
args=ast.arguments(),
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=unoptimized_binop)]
type_params=[type_param(name=name, default_value=unoptimized_tuple)]
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_type_param_in_class_def(self):
code = "class foo[%s = 1 + 1]: pass"
code = "class foo[%s = (1, 2)]: pass"

unoptimized_binop = self.create_binop("+")
unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
Expand All @@ -3292,22 +3193,22 @@ def test_folding_type_param_in_class_def(self):
ast.ClassDef(
name='foo',
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=ast.Constant(2))]
type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))]
)
)
non_optimized_target = self.wrap_statement(
ast.ClassDef(
name='foo',
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=unoptimized_binop)]
type_params=[type_param(name=name, default_value=unoptimized_tuple)]
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_type_param_in_type_alias(self):
code = "type foo[%s = 1 + 1] = 1"
code = "type foo[%s = (1, 2)] = 1"

unoptimized_binop = self.create_binop("+")
unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
Expand All @@ -3319,19 +3220,80 @@ def test_folding_type_param_in_type_alias(self):
optimized_target = self.wrap_statement(
ast.TypeAlias(
name=ast.Name(id='foo', ctx=ast.Store()),
type_params=[type_param(name=name, default_value=ast.Constant(2))],
type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))],
value=ast.Constant(value=1),
)
)
non_optimized_target = self.wrap_statement(
ast.TypeAlias(
name=ast.Name(id='foo', ctx=ast.Store()),
type_params=[type_param(name=name, default_value=unoptimized_binop)],
type_params=[type_param(name=name, default_value=unoptimized_tuple)],
value=ast.Constant(value=1),
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_match_case_allowed_expressions(self):
def get_match_case_values(node):
result = []
if isinstance(node, ast.Constant):
result.append(node.value)
elif isinstance(node, ast.MatchValue):
result.extend(get_match_case_values(node.value))
elif isinstance(node, ast.MatchMapping):
for key in node.keys:
result.extend(get_match_case_values(key))
elif isinstance(node, ast.MatchSequence):
for pat in node.patterns:
result.extend(get_match_case_values(pat))
else:
self.fail(f"Unexpected node {node}")
return result

tests = [
("-0", [0]),
("-0.1", [-0.1]),
("-0j", [complex(0, 0)]),
("-0.1j", [complex(0, -0.1)]),
("1 + 2j", [complex(1, 2)]),
("1 - 2j", [complex(1, -2)]),
("1.1 + 2.1j", [complex(1.1, 2.1)]),
("1.1 - 2.1j", [complex(1.1, -2.1)]),
("-0 + 1j", [complex(0, 1)]),
("-0 - 1j", [complex(0, -1)]),
("-0.1 + 1.1j", [complex(-0.1, 1.1)]),
("-0.1 - 1.1j", [complex(-0.1, -1.1)]),
("{-0: 0}", [0]),
("{-0.1: 0}", [-0.1]),
("{-0j: 0}", [complex(0, 0)]),
("{-0.1j: 0}", [complex(0, -0.1)]),
("{1 + 2j: 0}", [complex(1, 2)]),
("{1 - 2j: 0}", [complex(1, -2)]),
("{1.1 + 2.1j: 0}", [complex(1.1, 2.1)]),
("{1.1 - 2.1j: 0}", [complex(1.1, -2.1)]),
("{-0 + 1j: 0}", [complex(0, 1)]),
("{-0 - 1j: 0}", [complex(0, -1)]),
("{-0.1 + 1.1j: 0}", [complex(-0.1, 1.1)]),
("{-0.1 - 1.1j: 0}", [complex(-0.1, -1.1)]),
("{-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}", [0, complex(0, 1), complex(0.1, 1)]),
("[-0, -0.1, -0j, -0.1j]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("[[[[-0, -0.1, -0j, -0.1j]]]]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("[[-0, -0.1], -0j, -0.1j]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("[[-0, -0.1], [-0j, -0.1j]]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("(-0, -0.1, -0j, -0.1j)", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("((((-0, -0.1, -0j, -0.1j))))", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("((-0, -0.1), -0j, -0.1j)", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
("((-0, -0.1), (-0j, -0.1j))", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
]
for match_expr, constants in tests:
with self.subTest(match_expr):
src = f"match 0:\n\t case {match_expr}: pass"
tree = ast.parse(src, optimize=1)
match_stmt = tree.body[0]
case = match_stmt.cases[0]
values = get_match_case_values(case.pattern)
self.assertListEqual(constants, values)


if __name__ == '__main__':
if len(sys.argv) > 1 and sys.argv[1] == '--snapshot-update':
Expand Down
2 changes: 1 addition & 1 deletion Lib/test/test_ast/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
def to_tuple(t):
if t is None or isinstance(t, (str, int, complex, float, bytes)) or t is Ellipsis:
if t is None or isinstance(t, (str, int, complex, float, bytes, tuple)) or t is Ellipsis:
return t
elif isinstance(t, list):
return [to_tuple(e) for e in t]
Expand Down
15 changes: 6 additions & 9 deletions Lib/test/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def test_compile_async_generator(self):
self.assertEqual(type(glob['ticker']()), AsyncGeneratorType)

def test_compile_ast(self):
args = ("a*(1+2)", "f.py", "exec")
args = ("a*(1,2)", "f.py", "exec")
raw = compile(*args, flags = ast.PyCF_ONLY_AST).body[0]
opt1 = compile(*args, flags = ast.PyCF_OPTIMIZED_AST).body[0]
opt2 = compile(ast.parse(args[0]), *args[1:], flags = ast.PyCF_OPTIMIZED_AST).body[0]
Expand All @@ -566,17 +566,14 @@ def test_compile_ast(self):
self.assertIsInstance(tree.value.left, ast.Name)
self.assertEqual(tree.value.left.id, 'a')

raw_right = raw.value.right # expect BinOp(1, '+', 2)
self.assertIsInstance(raw_right, ast.BinOp)
self.assertIsInstance(raw_right.left, ast.Constant)
self.assertEqual(raw_right.left.value, 1)
self.assertIsInstance(raw_right.right, ast.Constant)
self.assertEqual(raw_right.right.value, 2)
raw_right = raw.value.right # expect Tuple((1, 2))
self.assertIsInstance(raw_right, ast.Tuple)
self.assertListEqual([elt.value for elt in raw_right.elts], [1, 2])

for opt in [opt1, opt2]:
opt_right = opt.value.right # expect Constant(3)
opt_right = opt.value.right # expect Constant((1,2))
self.assertIsInstance(opt_right, ast.Constant)
self.assertEqual(opt_right.value, 3)
self.assertEqual(opt_right.value, (1, 2))

def test_delattr(self):
sys.spam = 1
Expand Down
Loading

0 comments on commit 38642bf

Please sign in to comment.