From 33bc95a0c922a7344fb00f2fab27a02d9d96ef1b Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 4 Mar 2023 20:29:46 -0500 Subject: [PATCH 01/14] [mypyc] Constant fold str multiplication --- mypy/constant_fold.py | 29 ++++++++++++++++++++-- mypyc/irbuild/callable_class.py | 2 +- mypyc/irbuild/constant_fold.py | 4 +++ mypyc/test-data/fixtures/ir.py | 2 ++ mypyc/test-data/irbuild-constant-fold.test | 10 +++++++- 5 files changed, 43 insertions(+), 4 deletions(-) diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index a22c1b9ba9e5..7002433596ae 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Union +from typing import Union, overload from typing_extensions import Final from mypy.nodes import Expression, FloatExpr, IntExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var @@ -60,6 +60,10 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non return constant_fold_binary_int_op(expr.op, left, right) elif isinstance(left, str) and isinstance(right, str): return constant_fold_binary_str_op(expr.op, left, right) + elif isinstance(left, str) and isinstance(right, int): + return constant_fold_binary_str_op(expr.op, left, right) + elif isinstance(left, int) and isinstance(right, str): + return constant_fold_binary_str_op(expr.op, left, right) elif isinstance(expr, UnaryExpr): value = constant_fold_expr(expr.expr, cur_mod_id) if isinstance(value, int): @@ -110,7 +114,28 @@ def constant_fold_unary_int_op(op: str, value: int) -> int | None: return None +@overload +def constant_fold_binary_str_op(op: str, left: int, right: str) -> str | None: + ... + + +@overload +def constant_fold_binary_str_op(op: str, left: str, right: int) -> str | None: + ... + + +@overload def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None: + ... + + +def constant_fold_binary_str_op(op: str, left: str | int, right: str | int) -> str | None: if op == "+": - return left + right + if isinstance(left, str) and isinstance(right, str): + return left + right + elif op == "*": + if isinstance(left, int) and isinstance(right, str): + return left * right + if isinstance(left, str) and isinstance(right, int): + return left * right return None diff --git a/mypyc/irbuild/callable_class.py b/mypyc/irbuild/callable_class.py index d3ee54a208cd..599dbb81f767 100644 --- a/mypyc/irbuild/callable_class.py +++ b/mypyc/irbuild/callable_class.py @@ -17,7 +17,7 @@ def setup_callable_class(builder: IRBuilder) -> None: - """Generate an (incomplete) callable class representing function. + """Generate an (incomplete) callable class representing a function. This can be a nested function or a function within a non-extension class. Also set up the 'self' variable for that class. diff --git a/mypyc/irbuild/constant_fold.py b/mypyc/irbuild/constant_fold.py index 4e9eb53b9222..31ca609adfac 100644 --- a/mypyc/irbuild/constant_fold.py +++ b/mypyc/irbuild/constant_fold.py @@ -56,6 +56,10 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | return constant_fold_binary_int_op(expr.op, left, right) elif isinstance(left, str) and isinstance(right, str): return constant_fold_binary_str_op(expr.op, left, right) + elif isinstance(left, str) and isinstance(right, int): + return constant_fold_binary_str_op(expr.op, left, right) + elif isinstance(left, int) and isinstance(right, str): + return constant_fold_binary_str_op(expr.op, left, right) elif isinstance(expr, UnaryExpr): value = constant_fold_expr(builder, expr.expr) if isinstance(value, int): diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 27e225f273bc..3d3562382f07 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -86,6 +86,8 @@ def __init__(self) -> None: pass @overload def __init__(self, x: object) -> None: pass def __add__(self, x: str) -> str: pass + def __mul__(self, x: int) -> str: pass + def __rmul__(self, x: int) -> str: pass def __eq__(self, x: object) -> bool: pass def __ne__(self, x: object) -> bool: pass def __lt__(self, x: str) -> bool: ... diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index 7d9127887aa6..088e27e83605 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -237,16 +237,24 @@ L0: from typing_extensions import Final S: Final = 'z' +N: Final = 2 def f() -> None: x = 'foo' + 'bar' y = 'x' + 'y' + S + mul = "foobar" * 2 + mul2 = N * "foobar" [out] def f(): - r0, x, r1, y :: str + r0, x, r1, y, r2, mul, r3, mul2 :: str L0: r0 = 'foobar' x = r0 r1 = 'xyz' y = r1 + r2 = 'foobarfoobar' + mul = r2 + r3 = 'foobarfoobar' + mul2 = r3 return 1 + From c80e95e4cd8fa445f32affd08f2e6e2811f04cad Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 4 Mar 2023 21:14:05 -0500 Subject: [PATCH 02/14] Refactor constant folding of binary ops --- mypy/constant_fold.py | 56 ++++++++++++---------------------- mypyc/irbuild/constant_fold.py | 33 +++++++++++--------- 2 files changed, 38 insertions(+), 51 deletions(-) diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index 7002433596ae..294a27ffde85 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Union, overload +from typing import Union from typing_extensions import Final from mypy.nodes import Expression, FloatExpr, IntExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var @@ -56,14 +56,9 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non elif isinstance(expr, OpExpr): left = constant_fold_expr(expr.left, cur_mod_id) right = constant_fold_expr(expr.right, cur_mod_id) - if isinstance(left, int) and isinstance(right, int): - return constant_fold_binary_int_op(expr.op, left, right) - elif isinstance(left, str) and isinstance(right, str): - return constant_fold_binary_str_op(expr.op, left, right) - elif isinstance(left, str) and isinstance(right, int): - return constant_fold_binary_str_op(expr.op, left, right) - elif isinstance(left, int) and isinstance(right, str): - return constant_fold_binary_str_op(expr.op, left, right) + value = constant_fold_binary_op(expr.op, left, right) + if value is not None: + return value elif isinstance(expr, UnaryExpr): value = constant_fold_expr(expr.expr, cur_mod_id) if isinstance(value, int): @@ -71,6 +66,22 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non return None +def constant_fold_binary_op( + op: str, left: ConstantValue | None, right: ConstantValue | None +) -> ConstantValue | None: + if isinstance(left, int) and isinstance(right, int): + return constant_fold_binary_int_op(op, left, right) + + if op == "+" and isinstance(left, str) and isinstance(right, str): + return left + right + elif op == "*" and isinstance(left, str) and isinstance(right, int): + return left * right + elif op == "*" and isinstance(left, int) and isinstance(right, str): + return left * right + + return None + + def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None: if op == "+": return left + right @@ -112,30 +123,3 @@ def constant_fold_unary_int_op(op: str, value: int) -> int | None: elif op == "+": return value return None - - -@overload -def constant_fold_binary_str_op(op: str, left: int, right: str) -> str | None: - ... - - -@overload -def constant_fold_binary_str_op(op: str, left: str, right: int) -> str | None: - ... - - -@overload -def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None: - ... - - -def constant_fold_binary_str_op(op: str, left: str | int, right: str | int) -> str | None: - if op == "+": - if isinstance(left, str) and isinstance(right, str): - return left + right - elif op == "*": - if isinstance(left, int) and isinstance(right, str): - return left * right - if isinstance(left, str) and isinstance(right, int): - return left * right - return None diff --git a/mypyc/irbuild/constant_fold.py b/mypyc/irbuild/constant_fold.py index 31ca609adfac..cc9b1bf99d24 100644 --- a/mypyc/irbuild/constant_fold.py +++ b/mypyc/irbuild/constant_fold.py @@ -13,17 +13,23 @@ from typing import Union from typing_extensions import Final -from mypy.constant_fold import ( - constant_fold_binary_int_op, - constant_fold_binary_str_op, - constant_fold_unary_int_op, +from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_int_op +from mypy.nodes import ( + Expression, + FloatExpr, + IntExpr, + MemberExpr, + NameExpr, + OpExpr, + StrExpr, + UnaryExpr, + Var, ) -from mypy.nodes import Expression, IntExpr, MemberExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var from mypyc.irbuild.builder import IRBuilder # All possible result types of constant folding -ConstantValue = Union[int, str] -CONST_TYPES: Final = (int, str) +ConstantValue = Union[int, float, str] +CONST_TYPES: Final = (int, float, str) def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None: @@ -33,6 +39,8 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | """ if isinstance(expr, IntExpr): return expr.value + if isinstance(expr, FloatExpr): + return expr.value if isinstance(expr, StrExpr): return expr.value elif isinstance(expr, NameExpr): @@ -52,14 +60,9 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | elif isinstance(expr, OpExpr): left = constant_fold_expr(builder, expr.left) right = constant_fold_expr(builder, expr.right) - if isinstance(left, int) and isinstance(right, int): - return constant_fold_binary_int_op(expr.op, left, right) - elif isinstance(left, str) and isinstance(right, str): - return constant_fold_binary_str_op(expr.op, left, right) - elif isinstance(left, str) and isinstance(right, int): - return constant_fold_binary_str_op(expr.op, left, right) - elif isinstance(left, int) and isinstance(right, str): - return constant_fold_binary_str_op(expr.op, left, right) + value = constant_fold_binary_op(expr.op, left, right) + if value is not None: + return value elif isinstance(expr, UnaryExpr): value = constant_fold_expr(builder, expr.expr) if isinstance(value, int): From ced8a105956bbd4c213a5dd95aa6157cf22ffe84 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 4 Mar 2023 21:55:16 -0500 Subject: [PATCH 03/14] Support constant folding complex ops --- mypy/constant_fold.py | 23 +++++++++++++++++++--- mypy/nodes.py | 2 +- mypy/semanal.py | 2 +- mypyc/irbuild/builder.py | 12 +++++------ mypyc/irbuild/constant_fold.py | 7 +++++-- mypyc/irbuild/expression.py | 6 ++---- mypyc/test-data/fixtures/ir.py | 1 + mypyc/test-data/irbuild-constant-fold.test | 18 +++++++++++++++++ 8 files changed, 54 insertions(+), 17 deletions(-) diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index 294a27ffde85..5f0306abc381 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -8,11 +8,21 @@ from typing import Union from typing_extensions import Final -from mypy.nodes import Expression, FloatExpr, IntExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var +from mypy.nodes import ( + ComplexExpr, + Expression, + FloatExpr, + IntExpr, + NameExpr, + OpExpr, + StrExpr, + UnaryExpr, + Var, +) # All possible result types of constant folding -ConstantValue = Union[int, bool, float, str] -CONST_TYPES: Final = (int, bool, float, str) +ConstantValue = Union[int, bool, float, complex, str] +CONST_TYPES: Final = (int, bool, float, complex, str) def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | None: @@ -39,6 +49,8 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non return expr.value if isinstance(expr, FloatExpr): return expr.value + if isinstance(expr, ComplexExpr): + return expr.value elif isinstance(expr, NameExpr): if expr.name == "True": return True @@ -79,6 +91,11 @@ def constant_fold_binary_op( elif op == "*" and isinstance(left, int) and isinstance(right, str): return left * right + if op == "+" and isinstance(left, int) and isinstance(right, complex): + return left + right + elif op == "+" and isinstance(left, complex) and isinstance(right, int): + return left + right + return None diff --git a/mypy/nodes.py b/mypy/nodes.py index e4d8514ad6e2..cfb3ef3f0137 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -989,7 +989,7 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None: # If constant value is a simple literal, # store the literal value (unboxed) for the benefit of # tools like mypyc. - self.final_value: int | float | bool | str | None = None + self.final_value: int | float | complex | bool | str | None = None # Where the value was set (only for class attributes) self.final_unset_in_class = False self.final_set_in_init = False diff --git a/mypy/semanal.py b/mypy/semanal.py index 2720d2606e92..a63bdc4d54c2 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -3350,7 +3350,7 @@ def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Typ return None value = constant_fold_expr(rvalue, self.cur_mod_id) - if value is None: + if value is None or isinstance(value, complex): return None if isinstance(value, bool): diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index a49429f1c6ec..7860e3881766 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -548,16 +548,14 @@ def load_final_static( error_msg=f'value for final name "{error_name}" was not set', ) - def load_final_literal_value(self, val: int | str | bytes | float | bool, line: int) -> Value: - """Load value of a final name or class-level attribute.""" + def load_literal_value(self, val: int | str | bytes | float | complex | bool) -> Value: + """Load value of a final name, class-level attribute, or constant folded expression.""" if isinstance(val, bool): if val: return self.true() else: return self.false() elif isinstance(val, int): - # TODO: take care of negative integer initializers - # (probably easier to fix this in mypy itself). return self.builder.load_int(val) elif isinstance(val, float): return self.builder.load_float(val) @@ -565,8 +563,10 @@ def load_final_literal_value(self, val: int | str | bytes | float | bool, line: return self.builder.load_str(val) elif isinstance(val, bytes): return self.builder.load_bytes(val) + elif isinstance(val, complex): + return self.builder.load_complex(val) else: - assert False, "Unsupported final literal value" + assert False, "Unsupported literal value" def get_assignment_target( self, lvalue: Lvalue, line: int = -1, *, for_read: bool = False @@ -1010,7 +1010,7 @@ def emit_load_final( line: line number where loading occurs """ if final_var.final_value is not None: # this is safe even for non-native names - return self.load_final_literal_value(final_var.final_value, line) + return self.load_literal_value(final_var.final_value) elif native: return self.load_final_static(fullname, self.mapper.type_to_rtype(typ), line, name) else: diff --git a/mypyc/irbuild/constant_fold.py b/mypyc/irbuild/constant_fold.py index cc9b1bf99d24..2fc8ad909ec5 100644 --- a/mypyc/irbuild/constant_fold.py +++ b/mypyc/irbuild/constant_fold.py @@ -15,6 +15,7 @@ from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_int_op from mypy.nodes import ( + ComplexExpr, Expression, FloatExpr, IntExpr, @@ -28,8 +29,8 @@ from mypyc.irbuild.builder import IRBuilder # All possible result types of constant folding -ConstantValue = Union[int, float, str] -CONST_TYPES: Final = (int, float, str) +ConstantValue = Union[int, float, complex, str] +CONST_TYPES: Final = (int, float, complex, str) def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None: @@ -43,6 +44,8 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | return expr.value if isinstance(expr, StrExpr): return expr.value + if isinstance(expr, ComplexExpr): + return expr.value elif isinstance(expr, NameExpr): node = expr.node if isinstance(node, Var) and node.is_final: diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 5997bdbd0a43..ab35517287dd 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -562,10 +562,8 @@ def try_constant_fold(builder: IRBuilder, expr: Expression) -> Value | None: Return None otherwise. """ value = constant_fold_expr(builder, expr) - if isinstance(value, int): - return builder.load_int(value) - elif isinstance(value, str): - return builder.load_str(value) + if value is not None: + return builder.load_literal_value(value) return None diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 3d3562382f07..e9fc38126352 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -125,6 +125,7 @@ def __invert__(self) -> float: pass class complex: def __init__(self, x: object, y: object = None) -> None: pass def __add__(self, n: complex) -> complex: pass + def __radd__(self, n: int) -> complex: pass def __sub__(self, n: complex) -> complex: pass def __mul__(self, n: complex) -> complex: pass def __truediv__(self, n: complex) -> complex: pass diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index 088e27e83605..9ebe99b64493 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -258,3 +258,21 @@ L0: mul2 = r3 return 1 +[case testComplexConstantFolding] +from typing_extensions import Final + +N: Final = 1 + +def f() -> None: + x = 1+2j + y = 2j+N +[out] +def f(): + r0, x, r1, y :: object +L0: + r0 = (1+2j) + x = r0 + r1 = (1+2j) + y = r1 + return 1 + From 54a892b6713dc1049729e2eab51f040b69c24649 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 4 Mar 2023 22:21:01 -0500 Subject: [PATCH 04/14] Support constant folding bytes ops Unfortunately mypy can't easily support storing the bytes value for final references. Other than preventing b"foo" + CONST_BYTES from being folded, it also means this commit is mypyc-only. --- mypyc/irbuild/constant_fold.py | 42 +++++++++++++++++----- mypyc/test-data/fixtures/ir.py | 2 ++ mypyc/test-data/irbuild-constant-fold.test | 22 ++++++++++++ 3 files changed, 57 insertions(+), 9 deletions(-) diff --git a/mypyc/irbuild/constant_fold.py b/mypyc/irbuild/constant_fold.py index 2fc8ad909ec5..523ddba02c6b 100644 --- a/mypyc/irbuild/constant_fold.py +++ b/mypyc/irbuild/constant_fold.py @@ -15,6 +15,7 @@ from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_int_op from mypy.nodes import ( + BytesExpr, ComplexExpr, Expression, FloatExpr, @@ -27,10 +28,11 @@ Var, ) from mypyc.irbuild.builder import IRBuilder +from mypyc.irbuild.util import bytes_from_str # All possible result types of constant folding -ConstantValue = Union[int, float, complex, str] -CONST_TYPES: Final = (int, float, complex, str) +ConstantValue = Union[int, float, complex, str, bytes] +CONST_TYPES: Final = (int, float, complex, str, bytes) def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None: @@ -44,26 +46,28 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | return expr.value if isinstance(expr, StrExpr): return expr.value + if isinstance(expr, BytesExpr): + return bytes_from_str(expr.value) if isinstance(expr, ComplexExpr): return expr.value elif isinstance(expr, NameExpr): node = expr.node if isinstance(node, Var) and node.is_final: - value = node.final_value - if isinstance(value, (CONST_TYPES)): - return value + final_value = node.final_value + if isinstance(final_value, (CONST_TYPES)): + return final_value elif isinstance(expr, MemberExpr): final = builder.get_final_ref(expr) if final is not None: fn, final_var, native = final if final_var.is_final: - value = final_var.final_value - if isinstance(value, (CONST_TYPES)): - return value + final_value = final_var.final_value + if isinstance(final_value, (CONST_TYPES)): + return final_value elif isinstance(expr, OpExpr): left = constant_fold_expr(builder, expr.left) right = constant_fold_expr(builder, expr.right) - value = constant_fold_binary_op(expr.op, left, right) + value = constant_fold_binary_op_extended(expr.op, left, right) if value is not None: return value elif isinstance(expr, UnaryExpr): @@ -71,3 +75,23 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | if isinstance(value, int): return constant_fold_unary_int_op(expr.op, value) return None + + +def constant_fold_binary_op_extended( + op: str, left: ConstantValue | None, right: ConstantValue | None +) -> ConstantValue | None: + """Like mypy's constant_fold_binary_op(), but includes bytes support. + + mypy cannot use constant folded bytes easily so it's simpler to only support them in mypyc. + """ + if not isinstance(left, bytes) and not isinstance(right, bytes): + return constant_fold_binary_op(op, left, right) + + if op == "+" and isinstance(left, bytes) and isinstance(right, bytes): + return left + right + elif op == "*" and isinstance(left, bytes) and isinstance(right, int): + return left * right + elif op == "*" and isinstance(left, int) and isinstance(right, bytes): + return left * right + + return None diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index e9fc38126352..a932c5be0e09 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -137,6 +137,8 @@ def __init__(self) -> None: ... @overload def __init__(self, x: object) -> None: ... def __add__(self, x: bytes) -> bytes: ... + def __mul__(self, x: int) -> bytes: ... + def __rmul__(self, x: int) -> bytes: ... def __eq__(self, x: object) -> bool: ... def __ne__(self, x: object) -> bool: ... @overload diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index 9ebe99b64493..e1962b3367b6 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -258,6 +258,28 @@ L0: mul2 = r3 return 1 +[case testBytesConstantFolding] +from typing_extensions import Final + +N: Final = 2 + +def f() -> None: + # Unfortunately, mypy doesn't store the bytes value of final refs. + x = b'foo' + b'bar' + mul = b"foobar" * 2 + mul2 = N * b"foobar" +[out] +def f(): + r0, x, r1, mul, r2, mul2 :: bytes +L0: + r0 = b'foobar' + x = r0 + r1 = b'foobarfoobar' + mul = r1 + r2 = b'foobarfoobar' + mul2 = r2 + return 1 + [case testComplexConstantFolding] from typing_extensions import Final From c057484d6d84d0a6b07292e2bb779eadeddb92c1 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 4 Mar 2023 22:23:45 -0500 Subject: [PATCH 05/14] Simplify irbuild.expression.set_literal_values() --- mypyc/irbuild/expression.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index ab35517287dd..d599dcf1314c 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -90,7 +90,6 @@ tokenizer_printf_style, ) from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization -from mypyc.irbuild.util import bytes_from_str from mypyc.primitives.bytes_ops import bytes_slice_op from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op from mypyc.primitives.generic_ops import iter_op @@ -645,10 +644,6 @@ def set_literal_values(builder: IRBuilder, items: Sequence[Expression]) -> list[ values.append(True) elif item.fullname == "builtins.False": values.append(False) - elif isinstance(item, (BytesExpr, FloatExpr, ComplexExpr)): - # constant_fold_expr() doesn't handle these (yet?) - v = bytes_from_str(item.value) if isinstance(item, BytesExpr) else item.value - values.append(v) elif isinstance(item, TupleExpr): tuple_values = set_literal_values(builder, item.items) if tuple_values is not None: @@ -668,7 +663,6 @@ def precompute_set_literal(builder: IRBuilder, s: SetExpr) -> Value | None: Supported items: - Anything supported by irbuild.constant_fold.constant_fold_expr() - None, True, and False - - Float, byte, and complex literals - Tuple literals with only items listed above """ values = set_literal_values(builder, s.items) From 3b3c4c752f20970bb791dbe5492cb23e9d9d3132 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sun, 5 Mar 2023 12:46:30 -0500 Subject: [PATCH 06/14] Support constant folding float ops --- mypy/constant_fold.py | 56 +++++- mypyc/irbuild/constant_fold.py | 8 +- mypyc/test-data/fixtures/ir.py | 5 +- mypyc/test-data/irbuild-constant-fold.test | 224 +++++++++++++++++---- 4 files changed, 253 insertions(+), 40 deletions(-) diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index 5f0306abc381..0f2e00bdf0dd 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -75,6 +75,8 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non value = constant_fold_expr(expr.expr, cur_mod_id) if isinstance(value, int): return constant_fold_unary_int_op(expr.op, value) + if isinstance(value, float): + return constant_fold_unary_float_op(expr.op, value) return None @@ -84,6 +86,13 @@ def constant_fold_binary_op( if isinstance(left, int) and isinstance(right, int): return constant_fold_binary_int_op(op, left, right) + if isinstance(left, float) and isinstance(right, float): + return constant_fold_binary_float_op(op, left, right) + if isinstance(left, float) and isinstance(right, int): + return constant_fold_binary_float_op(op, left, right) + if isinstance(left, int) and isinstance(right, float): + return constant_fold_binary_float_op(op, left, right) + if op == "+" and isinstance(left, str) and isinstance(right, str): return left + right elif op == "*" and isinstance(left, str) and isinstance(right, int): @@ -91,21 +100,24 @@ def constant_fold_binary_op( elif op == "*" and isinstance(left, int) and isinstance(right, str): return left * right - if op == "+" and isinstance(left, int) and isinstance(right, complex): + if op == "+" and isinstance(left, (int, float)) and isinstance(right, complex): return left + right - elif op == "+" and isinstance(left, complex) and isinstance(right, int): + elif op == "+" and isinstance(left, complex) and isinstance(right, (int, float)): return left + right return None -def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None: +def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | float | None: if op == "+": return left + right if op == "-": return left - right elif op == "*": return left * right + elif op == "/": + if right != 0: + return left / right elif op == "//": if right != 0: return left // right @@ -132,6 +144,36 @@ def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None: return None +def constant_fold_binary_float_op(op: str, left: int | float, right: int | float) -> float | None: + assert not (isinstance(left, int) and isinstance(right, int)), (op, left, right) + if op == "+": + return left + right + if op == "-": + return left - right + elif op == "*": + return left * right + elif op == "/": + if right != 0: + return left / right + elif op == "//": + if right != 0: + return left // right + elif op == "%": + if right != 0: + return left % right + elif op == "**": + if (left < 0 and right >= 1 or right == 0) or (left >= 0 and right >= 0): + try: + ret = left**right + except OverflowError: + return None + else: + assert isinstance(ret, float) + return ret + + return None + + def constant_fold_unary_int_op(op: str, value: int) -> int | None: if op == "-": return -value @@ -140,3 +182,11 @@ def constant_fold_unary_int_op(op: str, value: int) -> int | None: elif op == "+": return value return None + + +def constant_fold_unary_float_op(op: str, value: float) -> float | None: + if op == "-": + return -value + elif op == "+": + return value + return None diff --git a/mypyc/irbuild/constant_fold.py b/mypyc/irbuild/constant_fold.py index 523ddba02c6b..d3e86fb06eba 100644 --- a/mypyc/irbuild/constant_fold.py +++ b/mypyc/irbuild/constant_fold.py @@ -13,7 +13,11 @@ from typing import Union from typing_extensions import Final -from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_int_op +from mypy.constant_fold import ( + constant_fold_binary_op, + constant_fold_unary_int_op, + constant_fold_unary_float_op, +) from mypy.nodes import ( BytesExpr, ComplexExpr, @@ -74,6 +78,8 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | value = constant_fold_expr(builder, expr.expr) if isinstance(value, int): return constant_fold_unary_int_op(expr.op, value) + if isinstance(value, float): + return constant_fold_unary_float_op(expr.op, value) return None diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index a932c5be0e09..3b75a1c995c9 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -113,9 +113,12 @@ def encode(self, x: str=..., y: str=...) -> bytes: ... class float: def __init__(self, x: object) -> None: pass def __add__(self, n: float) -> float: pass + def __radd__(self, n: int) -> float: pass def __sub__(self, n: float) -> float: pass + def __rsub__(self, n: int) -> int: pass def __mul__(self, n: float) -> float: pass def __truediv__(self, n: float) -> float: pass + def __floordiv__(self, n: float) -> float: pass def __pow__(self, n: float) -> float: pass def __neg__(self) -> float: pass def __pos__(self) -> float: pass @@ -125,7 +128,7 @@ def __invert__(self) -> float: pass class complex: def __init__(self, x: object, y: object = None) -> None: pass def __add__(self, n: complex) -> complex: pass - def __radd__(self, n: int) -> complex: pass + def __radd__(self, n: float) -> complex: pass def __sub__(self, n: complex) -> complex: pass def __mul__(self, n: complex) -> complex: pass def __truediv__(self, n: complex) -> complex: pass diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index e1962b3367b6..17d8afb22ee0 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -3,6 +3,7 @@ def bin_ops() -> None: add = 15 + 47 add_mul = (2 + 3) * 5 sub = 7 - 11 + div = 3 / 2 bit_and = 6 & 10 bit_or = 6 | 10 bit_xor = 6 ^ 10 @@ -25,11 +26,15 @@ def pow() -> None: p3 = 0**0 [out] def bin_ops(): - add, add_mul, sub, bit_and, bit_or, bit_xor, lshift, rshift, lshift0, rshift0 :: int + add, add_mul, sub :: int + r0, div :: float + bit_and, bit_or, bit_xor, lshift, rshift, lshift0, rshift0 :: int L0: add = 124 add_mul = 50 sub = -8 + r0 = 1.5 + div = r0 bit_and = 4 bit_or = 28 bit_xor = 24 @@ -117,44 +122,32 @@ L0: [case testIntConstantFoldingUnsupportedCases] def error_cases() -> None: - div_by_zero = 5 // 0 + div_by_zero = 5 / 0 + floor_div_by_zero = 5 // 0 mod_by_zero = 5 % 0 lshift_neg = 6 << -1 rshift_neg = 7 >> -1 -def unsupported_div() -> None: - x = 4 / 6 - y = 10 / 5 def unsupported_pow() -> None: p = 3 ** (-1) [out] def error_cases(): - r0, div_by_zero, r1, mod_by_zero, r2, lshift_neg, r3, rshift_neg :: int -L0: - r0 = CPyTagged_FloorDivide(10, 0) - div_by_zero = r0 - r1 = CPyTagged_Remainder(10, 0) - mod_by_zero = r1 - r2 = CPyTagged_Lshift(12, -2) - lshift_neg = r2 - r3 = CPyTagged_Rshift(14, -2) - rshift_neg = r3 - return 1 -def unsupported_div(): r0, r1, r2 :: object - r3, x :: float - r4, r5, r6 :: object - r7, y :: float + r3, div_by_zero :: float + r4, floor_div_by_zero, r5, mod_by_zero, r6, lshift_neg, r7, rshift_neg :: int L0: - r0 = object 4 - r1 = object 6 + r0 = object 5 + r1 = object 0 r2 = PyNumber_TrueDivide(r0, r1) r3 = cast(float, r2) - x = r3 - r4 = object 10 - r5 = object 5 - r6 = PyNumber_TrueDivide(r4, r5) - r7 = cast(float, r6) - y = r7 + div_by_zero = r3 + r4 = CPyTagged_FloorDivide(10, 0) + floor_div_by_zero = r4 + r5 = CPyTagged_Remainder(10, 0) + mod_by_zero = r5 + r6 = CPyTagged_Lshift(12, -2) + lshift_neg = r6 + r7 = CPyTagged_Rshift(14, -2) + rshift_neg = r7 return 1 def unsupported_pow(): r0, r1, r2 :: object @@ -233,6 +226,155 @@ L0: a = 12 return 1 +[case testFloatConstantFolding] +def bin_ops() -> None: + add = 0.5 + 0.5 + add_mul = (1.5 + 3.5) * 5.0 + sub = 7.0 - 7.5 + div = 3.0 / 2.0 + floor_div = 3.0 // 2.0 +def unary_ops() -> None: + neg1 = -5.5 + neg2 = --1.5 + neg3 = -0.0 + pos = +5.5 +def pow() -> None: + p0 = 16.0**0 + p1 = 16.0**0.5 + p2 = (-5.0)**3 + p3 = 0.0**0.0 +def error_cases() -> None: + div = 2.0 / 0.0 + floor_div = 2.0 // 0.0 + power_imag = (-2.0)**0.5 + power_overflow = 2.0**10000.0 +[out] +def bin_ops(): + r0, add, r1, add_mul, r2, sub, r3, div, r4, floor_div :: float +L0: + r0 = 1.0 + add = r0 + r1 = 25.0 + add_mul = r1 + r2 = -0.5 + sub = r2 + r3 = 1.5 + div = r3 + r4 = 1.0 + floor_div = r4 + return 1 +def unary_ops(): + r0, neg1, r1, neg2, r2, neg3, r3, pos :: float +L0: + r0 = -5.5 + neg1 = r0 + r1 = 1.5 + neg2 = r1 + r2 = -0.0 + neg3 = r2 + r3 = 5.5 + pos = r3 + return 1 +def pow(): + r0, p0, r1, p1, r2, p2, r3, p3 :: float +L0: + r0 = 1.0 + p0 = r0 + r1 = 4.0 + p1 = r1 + r2 = -125.0 + p2 = r2 + r3 = 1.0 + p3 = r3 + return 1 +def error_cases(): + r0, r1 :: float + r2 :: object + r3, div, r4, r5 :: float + r6 :: object + r7, floor_div, r8, r9 :: float + r10 :: object + r11, power_imag, r12, r13 :: float + r14 :: object + r15, power_overflow :: float +L0: + r0 = 2.0 + r1 = 0.0 + r2 = PyNumber_TrueDivide(r0, r1) + r3 = cast(float, r2) + div = r3 + r4 = 2.0 + r5 = 0.0 + r6 = PyNumber_FloorDivide(r4, r5) + r7 = cast(float, r6) + floor_div = r7 + r8 = -2.0 + r9 = 0.5 + r10 = CPyNumber_Power(r8, r9) + r11 = cast(float, r10) + power_imag = r11 + r12 = 2.0 + r13 = 10000.0 + r14 = CPyNumber_Power(r12, r13) + r15 = cast(float, r14) + power_overflow = r15 + return 1 + +[case testMixedFloatIntConstantFolding] +def bin_ops() -> None: + add = 1 + 0.5 + sub = 1 - 0.5 + mul = 0.5 * 5 + div = 5 / 0.5 + floor_div = 9.5 // 5 +def error_cases() -> None: + div = 2.0 / 0 + floor_div = 2.0 / 0 + power_overflow = 2.0**10000 +[out] +def bin_ops(): + r0, add, r1 :: float + r2, sub :: int + r3, mul, r4, div, r5, floor_div :: float +L0: + r0 = 1.5 + add = r0 + r1 = 0.5 + r2 = unbox(int, r1) + sub = r2 + r3 = 2.5 + mul = r3 + r4 = 10.0 + div = r4 + r5 = 1.0 + floor_div = r5 + return 1 +def error_cases(): + r0 :: float + r1, r2 :: object + r3, div, r4 :: float + r5, r6 :: object + r7, floor_div, r8 :: float + r9, r10 :: object + r11, power_overflow :: float +L0: + r0 = 2.0 + r1 = object 0 + r2 = PyNumber_TrueDivide(r0, r1) + r3 = cast(float, r2) + div = r3 + r4 = 2.0 + r5 = object 0 + r6 = PyNumber_TrueDivide(r4, r5) + r7 = cast(float, r6) + floor_div = r7 + r8 = 2.0 + r9 = object 10000 + r10 = CPyNumber_Power(r8, r9) + r11 = cast(float, r10) + power_overflow = r11 + return 1 + [case testStrConstantFolding] from typing_extensions import Final @@ -284,17 +426,29 @@ L0: from typing_extensions import Final N: Final = 1 +FLOAT_N: Final = 1.5 -def f() -> None: - x = 1+2j - y = 2j+N +def integral() -> None: + pos = 1+2j + pos_2 = 2j+N +def floating() -> None: + pos = 1.5+2j + pos_2 = 2j+FLOAT_N [out] -def f(): - r0, x, r1, y :: object +def integral(): + r0, pos, r1, pos_2 :: object L0: r0 = (1+2j) - x = r0 + pos = r0 r1 = (1+2j) - y = r1 + pos_2 = r1 + return 1 +def floating(): + r0, pos, r1, pos_2 :: object +L0: + r0 = (1.5+2j) + pos = r0 + r1 = (1.5+2j) + pos_2 = r1 return 1 From 3bae7d983118318b029940ea22c6cc3111f232ca Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sun, 5 Mar 2023 12:51:57 -0500 Subject: [PATCH 07/14] Also support folding complex with negative real --- mypy/constant_fold.py | 10 +++++++--- mypyc/test-data/fixtures/ir.py | 1 + mypyc/test-data/irbuild-constant-fold.test | 16 ++++++++++++++-- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index 0f2e00bdf0dd..37d3370c22fc 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -88,9 +88,9 @@ def constant_fold_binary_op( if isinstance(left, float) and isinstance(right, float): return constant_fold_binary_float_op(op, left, right) - if isinstance(left, float) and isinstance(right, int): + elif isinstance(left, float) and isinstance(right, int): return constant_fold_binary_float_op(op, left, right) - if isinstance(left, int) and isinstance(right, float): + elif isinstance(left, int) and isinstance(right, float): return constant_fold_binary_float_op(op, left, right) if op == "+" and isinstance(left, str) and isinstance(right, str): @@ -104,6 +104,10 @@ def constant_fold_binary_op( return left + right elif op == "+" and isinstance(left, complex) and isinstance(right, (int, float)): return left + right + elif op == "-" and isinstance(left, (int, float)) and isinstance(right, complex): + return left - right + elif op == "-" and isinstance(left, complex) and isinstance(right, (int, float)): + return left - right return None @@ -148,7 +152,7 @@ def constant_fold_binary_float_op(op: str, left: int | float, right: int | float assert not (isinstance(left, int) and isinstance(right, int)), (op, left, right) if op == "+": return left + right - if op == "-": + elif op == "-": return left - right elif op == "*": return left * right diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 3b75a1c995c9..afae4b82b9c9 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -130,6 +130,7 @@ def __init__(self, x: object, y: object = None) -> None: pass def __add__(self, n: complex) -> complex: pass def __radd__(self, n: float) -> complex: pass def __sub__(self, n: complex) -> complex: pass + def __rsub__(self, n: float) -> complex: pass def __mul__(self, n: complex) -> complex: pass def __truediv__(self, n: complex) -> complex: pass def __neg__(self) -> complex: pass diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index 17d8afb22ee0..348369b6cadb 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -431,24 +431,36 @@ FLOAT_N: Final = 1.5 def integral() -> None: pos = 1+2j pos_2 = 2j+N + neg = 1-2j + neg_2 = 2j-N def floating() -> None: pos = 1.5+2j pos_2 = 2j+FLOAT_N + neg = 1.5-2j + neg_2 = 2j-FLOAT_N [out] def integral(): - r0, pos, r1, pos_2 :: object + r0, pos, r1, pos_2, r2, neg, r3, neg_2 :: object L0: r0 = (1+2j) pos = r0 r1 = (1+2j) pos_2 = r1 + r2 = (1-2j) + neg = r2 + r3 = (-1+2j) + neg_2 = r3 return 1 def floating(): - r0, pos, r1, pos_2 :: object + r0, pos, r1, pos_2, r2, neg, r3, neg_2 :: object L0: r0 = (1.5+2j) pos = r0 r1 = (1.5+2j) pos_2 = r1 + r2 = (1.5-2j) + neg = r2 + r3 = (-1.5+2j) + neg_2 = r3 return 1 From 3670d0d0efae81f80d42808f1cf46dfffedd534d Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sun, 5 Mar 2023 13:22:34 -0500 Subject: [PATCH 08/14] Avoid constant folding in tests where it's unwanted + fix stub --- mypyc/test-data/fixtures/ir.py | 2 +- mypyc/test-data/irbuild-basic.test | 11 ++++++----- mypyc/test-data/irbuild-constant-fold.test | 19 ++++++++----------- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index afae4b82b9c9..3a92db79dee6 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -115,7 +115,7 @@ def __init__(self, x: object) -> None: pass def __add__(self, n: float) -> float: pass def __radd__(self, n: int) -> float: pass def __sub__(self, n: float) -> float: pass - def __rsub__(self, n: int) -> int: pass + def __rsub__(self, n: int) -> float: pass def __mul__(self, n: float) -> float: pass def __truediv__(self, n: float) -> float: pass def __floordiv__(self, n: float) -> float: pass diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index a06977d037b2..c1898984ccf4 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -1036,15 +1036,16 @@ L0: [case testLoadComplex] def load() -> complex: - return 5j+1.0 + real = 1 + return 5j+real [out] def load(): - r0 :: object - r1 :: float - r2 :: object + real :: int + r0, r1, r2 :: object L0: + real = 2 r0 = 5j - r1 = 1.0 + r1 = box(int, real) r2 = PyNumber_Add(r0, r1) return r2 diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index 348369b6cadb..480132ab902b 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -333,21 +333,18 @@ def error_cases() -> None: power_overflow = 2.0**10000 [out] def bin_ops(): - r0, add, r1 :: float - r2, sub :: int - r3, mul, r4, div, r5, floor_div :: float + r0, add, r1, sub, r2, mul, r3, div, r4, floor_div :: float L0: r0 = 1.5 add = r0 r1 = 0.5 - r2 = unbox(int, r1) - sub = r2 - r3 = 2.5 - mul = r3 - r4 = 10.0 - div = r4 - r5 = 1.0 - floor_div = r5 + sub = r1 + r2 = 2.5 + mul = r2 + r3 = 10.0 + div = r3 + r4 = 1.0 + floor_div = r4 return 1 def error_cases(): r0 :: float From 8d24649e4873a2600ba5632bd63b5728e7de4485 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sun, 5 Mar 2023 14:51:36 -0500 Subject: [PATCH 09/14] Refactor to reduce repetition and improve consistency --- mypy/constant_fold.py | 29 +++++++++-------------------- mypyc/irbuild/constant_fold.py | 19 ++++++------------- 2 files changed, 15 insertions(+), 33 deletions(-) diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index 37d3370c22fc..5d6ec7c7cf11 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -68,20 +68,17 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non elif isinstance(expr, OpExpr): left = constant_fold_expr(expr.left, cur_mod_id) right = constant_fold_expr(expr.right, cur_mod_id) - value = constant_fold_binary_op(expr.op, left, right) - if value is not None: - return value + if left is not None and right is not None: + return constant_fold_binary_op(expr.op, left, right) elif isinstance(expr, UnaryExpr): value = constant_fold_expr(expr.expr, cur_mod_id) - if isinstance(value, int): - return constant_fold_unary_int_op(expr.op, value) - if isinstance(value, float): - return constant_fold_unary_float_op(expr.op, value) + if value is not None: + return constant_fold_unary_op(expr.op, value) return None def constant_fold_binary_op( - op: str, left: ConstantValue | None, right: ConstantValue | None + op: str, left: ConstantValue, right: ConstantValue ) -> ConstantValue | None: if isinstance(left, int) and isinstance(right, int): return constant_fold_binary_int_op(op, left, right) @@ -178,19 +175,11 @@ def constant_fold_binary_float_op(op: str, left: int | float, right: int | float return None -def constant_fold_unary_int_op(op: str, value: int) -> int | None: - if op == "-": +def constant_fold_unary_op(op: str, value: ConstantValue) -> int | float | None: + if op == "-" and isinstance(value, (int, float)): return -value - elif op == "~": + elif op == "~" and isinstance(value, int): return ~value - elif op == "+": - return value - return None - - -def constant_fold_unary_float_op(op: str, value: float) -> float | None: - if op == "-": - return -value - elif op == "+": + elif op == "+" and isinstance(value, (int, float)): return value return None diff --git a/mypyc/irbuild/constant_fold.py b/mypyc/irbuild/constant_fold.py index d3e86fb06eba..dc21be4689e2 100644 --- a/mypyc/irbuild/constant_fold.py +++ b/mypyc/irbuild/constant_fold.py @@ -13,11 +13,7 @@ from typing import Union from typing_extensions import Final -from mypy.constant_fold import ( - constant_fold_binary_op, - constant_fold_unary_int_op, - constant_fold_unary_float_op, -) +from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op from mypy.nodes import ( BytesExpr, ComplexExpr, @@ -71,20 +67,17 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | elif isinstance(expr, OpExpr): left = constant_fold_expr(builder, expr.left) right = constant_fold_expr(builder, expr.right) - value = constant_fold_binary_op_extended(expr.op, left, right) - if value is not None: - return value + if left is not None and right is not None: + return constant_fold_binary_op_extended(expr.op, left, right) elif isinstance(expr, UnaryExpr): value = constant_fold_expr(builder, expr.expr) - if isinstance(value, int): - return constant_fold_unary_int_op(expr.op, value) - if isinstance(value, float): - return constant_fold_unary_float_op(expr.op, value) + if value is not None and not isinstance(value, bytes): + return constant_fold_unary_op(expr.op, value) return None def constant_fold_binary_op_extended( - op: str, left: ConstantValue | None, right: ConstantValue | None + op: str, left: ConstantValue, right: ConstantValue ) -> ConstantValue | None: """Like mypy's constant_fold_binary_op(), but includes bytes support. From 85a35c45dae3f5e72d241b852aff20362ebc34b1 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 6 May 2023 12:49:39 -0400 Subject: [PATCH 10/14] Fixup IR tests after native floats landed --- mypyc/test-data/irbuild-constant-fold.test | 200 ++++++++++----------- 1 file changed, 91 insertions(+), 109 deletions(-) diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index 1909b95c9818..e5b3d56eeb17 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -27,14 +27,13 @@ def pow() -> None: [out] def bin_ops(): add, add_mul, sub :: int - r0, div :: float + div :: float bit_and, bit_or, bit_xor, lshift, rshift, lshift0, rshift0 :: int L0: add = 124 add_mul = 50 sub = -8 - r0 = 1.5 - div = r0 + div = 1.5 bit_and = 4 bit_or = 28 bit_xor = 24 @@ -131,23 +130,19 @@ def unsupported_pow() -> None: p = 3 ** (-1) [out] def error_cases(): - r0, r1, r2 :: object - r3, div_by_zero :: float - r4, floor_div_by_zero, r5, mod_by_zero, r6, lshift_neg, r7, rshift_neg :: int + r0, div_by_zero :: float + r1, floor_div_by_zero, r2, mod_by_zero, r3, lshift_neg, r4, rshift_neg :: int L0: - r0 = object 5 - r1 = object 0 - r2 = PyNumber_TrueDivide(r0, r1) - r3 = cast(float, r2) - div_by_zero = r3 - r4 = CPyTagged_FloorDivide(10, 0) - floor_div_by_zero = r4 - r5 = CPyTagged_Remainder(10, 0) - mod_by_zero = r5 - r6 = CPyTagged_Lshift(12, -2) - lshift_neg = r6 - r7 = CPyTagged_Rshift(14, -2) - rshift_neg = r7 + r0 = CPyTagged_TrueDivide(10, 0) + div_by_zero = r0 + r1 = CPyTagged_FloorDivide(10, 0) + floor_div_by_zero = r1 + r2 = CPyTagged_Remainder(10, 0) + mod_by_zero = r2 + r3 = CPyTagged_Lshift(12, -2) + lshift_neg = r3 + r4 = CPyTagged_Rshift(14, -2) + rshift_neg = r4 return 1 def unsupported_pow(): r0, r1, r2 :: object @@ -250,74 +245,59 @@ def error_cases() -> None: power_overflow = 2.0**10000.0 [out] def bin_ops(): - r0, add, r1, add_mul, r2, sub, r3, div, r4, floor_div :: float + add, add_mul, sub, div, floor_div :: float L0: - r0 = 1.0 - add = r0 - r1 = 25.0 - add_mul = r1 - r2 = -0.5 - sub = r2 - r3 = 1.5 - div = r3 - r4 = 1.0 - floor_div = r4 + add = 1.0 + add_mul = 25.0 + sub = -0.5 + div = 1.5 + floor_div = 1.0 return 1 def unary_ops(): - r0, neg1, r1, neg2, r2, neg3, r3, pos :: float + neg1, neg2, neg3, pos :: float L0: - r0 = -5.5 - neg1 = r0 - r1 = 1.5 - neg2 = r1 - r2 = -0.0 - neg3 = r2 - r3 = 5.5 - pos = r3 + neg1 = -5.5 + neg2 = 1.5 + neg3 = -0.0 + pos = 5.5 return 1 def pow(): - r0, p0, r1, p1, r2, p2, r3, p3 :: float + p0, p1, p2, p3 :: float L0: - r0 = 1.0 - p0 = r0 - r1 = 4.0 - p1 = r1 - r2 = -125.0 - p2 = r2 - r3 = 1.0 - p3 = r3 + p0 = 1.0 + p1 = 4.0 + p2 = -125.0 + p3 = 1.0 return 1 def error_cases(): - r0, r1 :: float - r2 :: object - r3, div, r4, r5 :: float - r6 :: object - r7, floor_div, r8, r9 :: float - r10 :: object - r11, power_imag, r12, r13 :: float - r14 :: object - r15, power_overflow :: float + r0 :: bit + r1 :: bool + r2, div, r3, floor_div :: float + r4, r5, r6 :: object + r7, power_imag :: float + r8, r9, r10 :: object + r11, power_overflow :: float L0: - r0 = 2.0 - r1 = 0.0 - r2 = PyNumber_TrueDivide(r0, r1) - r3 = cast(float, r2) - div = r3 - r4 = 2.0 - r5 = 0.0 - r6 = PyNumber_FloorDivide(r4, r5) - r7 = cast(float, r6) - floor_div = r7 - r8 = -2.0 - r9 = 0.5 + r0 = 0.0 == 0.0 + if r0 goto L1 else goto L2 :: bool +L1: + r1 = raise ZeroDivisionError('float division by zero') + unreachable +L2: + r2 = 2.0 / 0.0 + div = r2 + r3 = CPyFloat_FloorDivide(2.0, 0.0) + floor_div = r3 + r4 = box(float, -2.0) + r5 = box(float, 0.5) + r6 = CPyNumber_Power(r4, r5) + r7 = unbox(float, r6) + power_imag = r7 + r8 = box(float, 2.0) + r9 = box(float, 10000.0) r10 = CPyNumber_Power(r8, r9) - r11 = cast(float, r10) - power_imag = r11 - r12 = 2.0 - r13 = 10000.0 - r14 = CPyNumber_Power(r12, r13) - r15 = cast(float, r14) - power_overflow = r15 + r11 = unbox(float, r10) + power_overflow = r11 return 1 [case testMixedFloatIntConstantFolding] @@ -333,43 +313,45 @@ def error_cases() -> None: power_overflow = 2.0**10000 [out] def bin_ops(): - r0, add, r1, sub, r2, mul, r3, div, r4, floor_div :: float + add, sub, mul, div, floor_div :: float L0: - r0 = 1.5 - add = r0 - r1 = 0.5 - sub = r1 - r2 = 2.5 - mul = r2 - r3 = 10.0 - div = r3 - r4 = 1.0 - floor_div = r4 + add = 1.5 + sub = 0.5 + mul = 2.5 + div = 10.0 + floor_div = 1.0 return 1 def error_cases(): - r0 :: float - r1, r2 :: object - r3, div, r4 :: float - r5, r6 :: object - r7, floor_div, r8 :: float - r9, r10 :: object - r11, power_overflow :: float + r0 :: bit + r1 :: bool + r2, div :: float + r3 :: bit + r4 :: bool + r5, floor_div :: float + r6, r7, r8 :: object + r9, power_overflow :: float L0: - r0 = 2.0 - r1 = object 0 - r2 = PyNumber_TrueDivide(r0, r1) - r3 = cast(float, r2) - div = r3 - r4 = 2.0 - r5 = object 0 - r6 = PyNumber_TrueDivide(r4, r5) - r7 = cast(float, r6) - floor_div = r7 - r8 = 2.0 - r9 = object 10000 - r10 = CPyNumber_Power(r8, r9) - r11 = cast(float, r10) - power_overflow = r11 + r0 = 0.0 == 0.0 + if r0 goto L1 else goto L2 :: bool +L1: + r1 = raise ZeroDivisionError('float division by zero') + unreachable +L2: + r2 = 2.0 / 0.0 + div = r2 + r3 = 0.0 == 0.0 + if r3 goto L3 else goto L4 :: bool +L3: + r4 = raise ZeroDivisionError('float division by zero') + unreachable +L4: + r5 = 2.0 / 0.0 + floor_div = r5 + r6 = box(float, 2.0) + r7 = box(float, 10000.0) + r8 = CPyNumber_Power(r6, r7) + r9 = unbox(float, r8) + power_overflow = r9 return 1 [case testStrConstantFolding] From 7fb8f8c024a52106b036992f86f9f2c5ccce15a4 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 6 May 2023 12:59:04 -0400 Subject: [PATCH 11/14] Moar comments --- mypy/constant_fold.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index 5d6ec7c7cf11..005d67b8f961 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -83,6 +83,7 @@ def constant_fold_binary_op( if isinstance(left, int) and isinstance(right, int): return constant_fold_binary_int_op(op, left, right) + # Float and mixed int/float arithmetic. if isinstance(left, float) and isinstance(right, float): return constant_fold_binary_float_op(op, left, right) elif isinstance(left, float) and isinstance(right, int): @@ -90,6 +91,7 @@ def constant_fold_binary_op( elif isinstance(left, int) and isinstance(right, float): return constant_fold_binary_float_op(op, left, right) + # String concatenation and multiplication. if op == "+" and isinstance(left, str) and isinstance(right, str): return left + right elif op == "*" and isinstance(left, str) and isinstance(right, int): @@ -97,6 +99,7 @@ def constant_fold_binary_op( elif op == "*" and isinstance(left, int) and isinstance(right, str): return left * right + # Complex construction. if op == "+" and isinstance(left, (int, float)) and isinstance(right, complex): return left + right elif op == "+" and isinstance(left, complex) and isinstance(right, (int, float)): From 7542d9988a4d61fd0e922e3a2dd0663d4432f834 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 6 May 2023 13:03:34 -0400 Subject: [PATCH 12/14] Fix typo in IRbuild test + remove unused import --- mypyc/irbuild/expression.py | 1 - mypyc/test-data/irbuild-constant-fold.test | 31 ++++++++-------------- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index b5b3651fd7b8..281cbb5cd726 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -54,7 +54,6 @@ Assign, BasicBlock, ComparisonOp, - Float, Integer, LoadAddress, LoadLiteral, diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index e5b3d56eeb17..c7c5c054e7ce 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -309,7 +309,7 @@ def bin_ops() -> None: floor_div = 9.5 // 5 def error_cases() -> None: div = 2.0 / 0 - floor_div = 2.0 / 0 + floor_div = 2.0 // 0 power_overflow = 2.0**10000 [out] def bin_ops(): @@ -324,12 +324,9 @@ L0: def error_cases(): r0 :: bit r1 :: bool - r2, div :: float - r3 :: bit - r4 :: bool - r5, floor_div :: float - r6, r7, r8 :: object - r9, power_overflow :: float + r2, div, r3, floor_div :: float + r4, r5, r6 :: object + r7, power_overflow :: float L0: r0 = 0.0 == 0.0 if r0 goto L1 else goto L2 :: bool @@ -339,19 +336,13 @@ L1: L2: r2 = 2.0 / 0.0 div = r2 - r3 = 0.0 == 0.0 - if r3 goto L3 else goto L4 :: bool -L3: - r4 = raise ZeroDivisionError('float division by zero') - unreachable -L4: - r5 = 2.0 / 0.0 - floor_div = r5 - r6 = box(float, 2.0) - r7 = box(float, 10000.0) - r8 = CPyNumber_Power(r6, r7) - r9 = unbox(float, r8) - power_overflow = r9 + r3 = CPyFloat_FloorDivide(2.0, 0.0) + floor_div = r3 + r4 = box(float, 2.0) + r5 = box(float, 10000.0) + r6 = CPyNumber_Power(r4, r5) + r7 = unbox(float, r6) + power_overflow = r7 return 1 [case testStrConstantFolding] From c32dc3b2a5a2929216fa445333a535130479f9f5 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 24 Jun 2023 18:42:17 -0400 Subject: [PATCH 13/14] Avoid stray complex results and expand tests --- mypy/constant_fold.py | 4 +- mypyc/test-data/irbuild-constant-fold.test | 57 +++++++++++++++++++--- 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index 005d67b8f961..6881ecae9e88 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -166,13 +166,13 @@ def constant_fold_binary_float_op(op: str, left: int | float, right: int | float if right != 0: return left % right elif op == "**": - if (left < 0 and right >= 1 or right == 0) or (left >= 0 and right >= 0): + if (left < 0 and isinstance(right, int)) or left > 0: try: ret = left**right except OverflowError: return None else: - assert isinstance(ret, float) + assert isinstance(ret, float), ret return ret return None diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index c7c5c054e7ce..32e72a78720d 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -222,12 +222,24 @@ L0: return 1 [case testFloatConstantFolding] +from typing_extensions import Final + +N: Final = 1.5 +N2: Final = 1.5 * 2 + def bin_ops() -> None: add = 0.5 + 0.5 add_mul = (1.5 + 3.5) * 5.0 sub = 7.0 - 7.5 div = 3.0 / 2.0 floor_div = 3.0 // 2.0 +def bin_ops_neg() -> None: + add = 0.5 + -0.5 + add_mul = (-1.5 + 3.5) * -5.0 + add_mul2 = (1.5 + -3.5) * -5.0 + sub = 7.0 - -7.5 + div = 3.0 / -2.0 + floor_div = 3.0 // -2.0 def unary_ops() -> None: neg1 = -5.5 neg2 = --1.5 @@ -237,12 +249,19 @@ def pow() -> None: p0 = 16.0**0 p1 = 16.0**0.5 p2 = (-5.0)**3 - p3 = 0.0**0.0 + p3 = 16.0**(-0) + p4 = 16.0**(-0.5) + p5 = (-2.0)**(-1) def error_cases() -> None: div = 2.0 / 0.0 floor_div = 2.0 // 0.0 power_imag = (-2.0)**0.5 + power_imag2 = (-2.0)**(-0.5) power_overflow = 2.0**10000.0 +def final_floats() -> None: + add1 = N + 1.2 + add2 = N + N2 + add3 = -1.2 + N2 [out] def bin_ops(): add, add_mul, sub, div, floor_div :: float @@ -253,6 +272,16 @@ L0: div = 1.5 floor_div = 1.0 return 1 +def bin_ops_neg(): + add, add_mul, add_mul2, sub, div, floor_div :: float +L0: + add = 0.0 + add_mul = -10.0 + add_mul2 = 10.0 + sub = 14.5 + div = -1.5 + floor_div = -2.0 + return 1 def unary_ops(): neg1, neg2, neg3, pos :: float L0: @@ -262,12 +291,14 @@ L0: pos = 5.5 return 1 def pow(): - p0, p1, p2, p3 :: float + p0, p1, p2, p3, p4, p5 :: float L0: p0 = 1.0 p1 = 4.0 p2 = -125.0 p3 = 1.0 + p4 = 0.25 + p5 = -0.5 return 1 def error_cases(): r0 :: bit @@ -276,7 +307,9 @@ def error_cases(): r4, r5, r6 :: object r7, power_imag :: float r8, r9, r10 :: object - r11, power_overflow :: float + r11, power_imag2 :: float + r12, r13, r14 :: object + r15, power_overflow :: float L0: r0 = 0.0 == 0.0 if r0 goto L1 else goto L2 :: bool @@ -293,11 +326,23 @@ L2: r6 = CPyNumber_Power(r4, r5) r7 = unbox(float, r6) power_imag = r7 - r8 = box(float, 2.0) - r9 = box(float, 10000.0) + r8 = box(float, -2.0) + r9 = box(float, -0.5) r10 = CPyNumber_Power(r8, r9) r11 = unbox(float, r10) - power_overflow = r11 + power_imag2 = r11 + r12 = box(float, 2.0) + r13 = box(float, 10000.0) + r14 = CPyNumber_Power(r12, r13) + r15 = unbox(float, r14) + power_overflow = r15 + return 1 +def final_floats(): + add1, add2, add3 :: float +L0: + add1 = 2.7 + add2 = 4.5 + add3 = 1.8 return 1 [case testMixedFloatIntConstantFolding] From be9e755ec2795dca8bdbe876c6f8be5bc072543c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 24 Jun 2023 22:43:03 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypyc/test-data/irbuild-constant-fold.test | 1 - 1 file changed, 1 deletion(-) diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index 32e72a78720d..97b13ab337c7 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -478,4 +478,3 @@ L0: r3 = (-1.5+2j) neg_2 = r3 return 1 -