Skip to content

Commit

Permalink
Constant fold more unary and binary expressions (#15202)
Browse files Browse the repository at this point in the history
Now mypy can constant fold these additional operations:

- Float arithmetic
- Mixed int and float arithmetic
- String multiplication
- Complex plus or minus a literal real (e.g. 1+2j)

While this can be useful with literal types, the main goal is to improve
constant folding in mypyc (mypyc/mypyc#772).

mypyc can also fold bytes addition and multiplication, but mypy cannot,
because bytes values cannot easily be stored anywhere.
  • Loading branch information
ichard26 authored Jun 25, 2023
1 parent 2bb7078 commit cee0030
Show file tree
Hide file tree
Showing 10 changed files with 412 additions and 97 deletions.
118 changes: 90 additions & 28 deletions mypy/constant_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,21 @@
from typing import Union
from typing_extensions import Final

from mypy.nodes import Expression, FloatExpr, IntExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var
from mypy.nodes import (
ComplexExpr,
Expression,
FloatExpr,
IntExpr,
NameExpr,
OpExpr,
StrExpr,
UnaryExpr,
Var,
)

# All possible result types of constant folding
ConstantValue = Union[int, bool, float, str]
CONST_TYPES: Final = (int, bool, float, str)
ConstantValue = Union[int, bool, float, complex, str]
CONST_TYPES: Final = (int, bool, float, complex, str)


def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | None:
Expand All @@ -39,6 +49,8 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non
return expr.value
if isinstance(expr, FloatExpr):
return expr.value
if isinstance(expr, ComplexExpr):
return expr.value
elif isinstance(expr, NameExpr):
if expr.name == "True":
return True
Expand All @@ -56,26 +68,60 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non
elif isinstance(expr, OpExpr):
left = constant_fold_expr(expr.left, cur_mod_id)
right = constant_fold_expr(expr.right, cur_mod_id)
if isinstance(left, int) and isinstance(right, int):
return constant_fold_binary_int_op(expr.op, left, right)
elif isinstance(left, str) and isinstance(right, str):
return constant_fold_binary_str_op(expr.op, left, right)
if left is not None and right is not None:
return constant_fold_binary_op(expr.op, left, right)
elif isinstance(expr, UnaryExpr):
value = constant_fold_expr(expr.expr, cur_mod_id)
if isinstance(value, int):
return constant_fold_unary_int_op(expr.op, value)
if isinstance(value, float):
return constant_fold_unary_float_op(expr.op, value)
if value is not None:
return constant_fold_unary_op(expr.op, value)
return None


def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
def constant_fold_binary_op(
    op: str, left: ConstantValue, right: ConstantValue
) -> ConstantValue | None:
    """Constant fold a binary operation whose operands are already folded.

    Supported combinations:
    - int with int (delegated to constant_fold_binary_int_op)
    - float with float, or mixed int/float (delegated to
      constant_fold_binary_float_op)
    - str + str, str * int and int * str
    - a real (int or float) plus or minus a complex, in either order

    Return the folded value, or None if the operation is unsupported or the
    result cannot be computed (for example, the repeat count is too large or
    an int operand does not fit in a float).
    """
    if isinstance(left, int) and isinstance(right, int):
        return constant_fold_binary_int_op(op, left, right)

    # Float and mixed int/float arithmetic. The int/int case returned above,
    # so at least one operand is a float here.
    if isinstance(left, (int, float)) and isinstance(right, (int, float)):
        return constant_fold_binary_float_op(op, left, right)

    try:
        # String concatenation and multiplication.
        if op == "+" and isinstance(left, str) and isinstance(right, str):
            return left + right
        elif op == "*" and isinstance(left, str) and isinstance(right, int):
            return left * right
        elif op == "*" and isinstance(left, int) and isinstance(right, str):
            return left * right

        # Complex construction.
        if op == "+" and isinstance(left, (int, float)) and isinstance(right, complex):
            return left + right
        elif op == "+" and isinstance(left, complex) and isinstance(right, (int, float)):
            return left + right
        elif op == "-" and isinstance(left, (int, float)) and isinstance(right, complex):
            return left - right
        elif op == "-" and isinstance(left, complex) and isinstance(right, (int, float)):
            return left - right
    except (OverflowError, MemoryError):
        # A huge repeat count (str * int) or an int too large to convert to
        # float (int +/- complex) must not crash folding; just don't fold.
        return None

    return None


def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | float | None:
if op == "+":
return left + right
if op == "-":
return left - right
elif op == "*":
return left * right
elif op == "/":
if right != 0:
return left / right
elif op == "//":
if right != 0:
return left // right
Expand All @@ -102,25 +148,41 @@ def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
return None


def constant_fold_unary_int_op(op: str, value: int) -> int | None:
if op == "-":
return -value
elif op == "~":
return ~value
elif op == "+":
return value
def constant_fold_binary_float_op(op: str, left: int | float, right: int | float) -> float | None:
    """Constant fold a binary arithmetic operation involving at least one float.

    Args:
        op: The operator as source text ("+", "-", "*", "/", "//", "%", "**").
        left: Left operand.
        right: Right operand. Must not be an int if left is also an int.

    Return the folded float, or None if the operation is unsupported, would
    divide by zero, or cannot be represented (the computation overflows).
    """
    assert not (isinstance(left, int) and isinstance(right, int)), (op, left, right)
    try:
        if op == "+":
            return left + right
        elif op == "-":
            return left - right
        elif op == "*":
            return left * right
        elif op == "/":
            if right != 0:
                return left / right
        elif op == "//":
            if right != 0:
                return left // right
        elif op == "%":
            if right != 0:
                return left % right
        elif op == "**":
            # Only fold cases that are known to produce a float: a negative
            # base with a fractional exponent gives a complex result, and a
            # zero base with a negative exponent raises ZeroDivisionError.
            if (left < 0 and isinstance(right, int)) or left > 0:
                ret = left**right
                assert isinstance(ret, float), ret
                return ret
    except OverflowError:
        # Raised by "**" on overflow, and by every mixed int/float operation
        # when the int operand is too large to convert to float.
        return None

    return None


def constant_fold_unary_float_op(op: str, value: float) -> float | None:
if op == "-":
def constant_fold_unary_op(op: str, value: ConstantValue) -> int | float | None:
if op == "-" and isinstance(value, (int, float)):
return -value
elif op == "+":
elif op == "~" and isinstance(value, int):
return ~value
elif op == "+" and isinstance(value, (int, float)):
return value
return None


def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None:
if op == "+":
return left + right
return None
2 changes: 1 addition & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
# If constant value is a simple literal,
# store the literal value (unboxed) for the benefit of
# tools like mypyc.
self.final_value: int | float | bool | str | None = None
self.final_value: int | float | complex | bool | str | None = None
# Where the value was set (only for class attributes)
self.final_unset_in_class = False
self.final_set_in_init = False
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3394,7 +3394,7 @@ def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Typ
return None

value = constant_fold_expr(rvalue, self.cur_mod_id)
if value is None:
if value is None or isinstance(value, complex):
return None

if isinstance(value, bool):
Expand Down
12 changes: 6 additions & 6 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,25 +535,25 @@ def load_final_static(
error_msg=f'value for final name "{error_name}" was not set',
)

def load_final_literal_value(self, val: int | str | bytes | float | bool, line: int) -> Value:
"""Load value of a final name or class-level attribute."""
def load_literal_value(self, val: int | str | bytes | float | complex | bool) -> Value:
"""Load value of a final name, class-level attribute, or constant folded expression."""
if isinstance(val, bool):
if val:
return self.true()
else:
return self.false()
elif isinstance(val, int):
# TODO: take care of negative integer initializers
# (probably easier to fix this in mypy itself).
return self.builder.load_int(val)
elif isinstance(val, float):
return self.builder.load_float(val)
elif isinstance(val, str):
return self.builder.load_str(val)
elif isinstance(val, bytes):
return self.builder.load_bytes(val)
elif isinstance(val, complex):
return self.builder.load_complex(val)
else:
assert False, "Unsupported final literal value"
assert False, "Unsupported literal value"

def get_assignment_target(
self, lvalue: Lvalue, line: int = -1, *, for_read: bool = False
Expand Down Expand Up @@ -1013,7 +1013,7 @@ def emit_load_final(
line: line number where loading occurs
"""
if final_var.final_value is not None: # this is safe even for non-native names
return self.load_final_literal_value(final_var.final_value, line)
return self.load_literal_value(final_var.final_value)
elif native:
return self.load_final_static(fullname, self.mapper.type_to_rtype(typ), line, name)
else:
Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/callable_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


def setup_callable_class(builder: IRBuilder) -> None:
"""Generate an (incomplete) callable class representing function.
"""Generate an (incomplete) callable class representing a function.
This can be a nested function or a function within a non-extension
class. Also set up the 'self' variable for that class.
Expand Down
64 changes: 41 additions & 23 deletions mypyc/irbuild/constant_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@
from typing import Union
from typing_extensions import Final

from mypy.constant_fold import (
constant_fold_binary_int_op,
constant_fold_binary_str_op,
constant_fold_unary_float_op,
constant_fold_unary_int_op,
)
from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op
from mypy.nodes import (
BytesExpr,
ComplexExpr,
Expression,
FloatExpr,
IntExpr,
Expand All @@ -31,10 +28,11 @@
Var,
)
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.util import bytes_from_str

# All possible result types of constant folding
ConstantValue = Union[int, str, float]
CONST_TYPES: Final = (int, str, float)
ConstantValue = Union[int, float, complex, str, bytes]
CONST_TYPES: Final = (int, float, complex, str, bytes)


def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None:
Expand All @@ -44,35 +42,55 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue |
"""
if isinstance(expr, IntExpr):
return expr.value
if isinstance(expr, FloatExpr):
return expr.value
if isinstance(expr, StrExpr):
return expr.value
if isinstance(expr, FloatExpr):
if isinstance(expr, BytesExpr):
return bytes_from_str(expr.value)
if isinstance(expr, ComplexExpr):
return expr.value
elif isinstance(expr, NameExpr):
node = expr.node
if isinstance(node, Var) and node.is_final:
value = node.final_value
if isinstance(value, (CONST_TYPES)):
return value
final_value = node.final_value
if isinstance(final_value, (CONST_TYPES)):
return final_value
elif isinstance(expr, MemberExpr):
final = builder.get_final_ref(expr)
if final is not None:
fn, final_var, native = final
if final_var.is_final:
value = final_var.final_value
if isinstance(value, (CONST_TYPES)):
return value
final_value = final_var.final_value
if isinstance(final_value, (CONST_TYPES)):
return final_value
elif isinstance(expr, OpExpr):
left = constant_fold_expr(builder, expr.left)
right = constant_fold_expr(builder, expr.right)
if isinstance(left, int) and isinstance(right, int):
return constant_fold_binary_int_op(expr.op, left, right)
elif isinstance(left, str) and isinstance(right, str):
return constant_fold_binary_str_op(expr.op, left, right)
if left is not None and right is not None:
return constant_fold_binary_op_extended(expr.op, left, right)
elif isinstance(expr, UnaryExpr):
value = constant_fold_expr(builder, expr.expr)
if isinstance(value, int):
return constant_fold_unary_int_op(expr.op, value)
if isinstance(value, float):
return constant_fold_unary_float_op(expr.op, value)
if value is not None and not isinstance(value, bytes):
return constant_fold_unary_op(expr.op, value)
return None


def constant_fold_binary_op_extended(
    op: str, left: ConstantValue, right: ConstantValue
) -> ConstantValue | None:
    """Like mypy's constant_fold_binary_op(), but includes bytes support.

    mypy cannot use constant folded bytes easily so it's simpler to only support them in mypyc.

    Return the folded value, or None if the operation is unsupported or the
    result cannot be computed (for example, the repeat count is too large).
    """
    if not isinstance(left, bytes) and not isinstance(right, bytes):
        return constant_fold_binary_op(op, left, right)

    try:
        if op == "+" and isinstance(left, bytes) and isinstance(right, bytes):
            return left + right
        elif op == "*" and isinstance(left, bytes) and isinstance(right, int):
            return left * right
        elif op == "*" and isinstance(left, int) and isinstance(right, bytes):
            return left * right
    except (OverflowError, MemoryError):
        # A huge repeat count must not crash the compiler; just don't fold.
        return None

    return None
15 changes: 2 additions & 13 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
Assign,
BasicBlock,
ComparisonOp,
Float,
Integer,
LoadAddress,
LoadLiteral,
Expand Down Expand Up @@ -92,7 +91,6 @@
tokenizer_printf_style,
)
from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization
from mypyc.irbuild.util import bytes_from_str
from mypyc.primitives.bytes_ops import bytes_slice_op
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op
from mypyc.primitives.generic_ops import iter_op
Expand Down Expand Up @@ -575,12 +573,8 @@ def try_constant_fold(builder: IRBuilder, expr: Expression) -> Value | None:
Return None otherwise.
"""
value = constant_fold_expr(builder, expr)
if isinstance(value, int):
return builder.load_int(value)
elif isinstance(value, str):
return builder.load_str(value)
elif isinstance(value, float):
return Float(value)
if value is not None:
return builder.load_literal_value(value)
return None


Expand Down Expand Up @@ -662,10 +656,6 @@ def set_literal_values(builder: IRBuilder, items: Sequence[Expression]) -> list[
values.append(True)
elif item.fullname == "builtins.False":
values.append(False)
elif isinstance(item, (BytesExpr, FloatExpr, ComplexExpr)):
# constant_fold_expr() doesn't handle these (yet?)
v = bytes_from_str(item.value) if isinstance(item, BytesExpr) else item.value
values.append(v)
elif isinstance(item, TupleExpr):
tuple_values = set_literal_values(builder, item.items)
if tuple_values is not None:
Expand All @@ -685,7 +675,6 @@ def precompute_set_literal(builder: IRBuilder, s: SetExpr) -> Value | None:
Supported items:
- Anything supported by irbuild.constant_fold.constant_fold_expr()
- None, True, and False
- Float, byte, and complex literals
- Tuple literals with only items listed above
"""
values = set_literal_values(builder, s.items)
Expand Down
Loading

0 comments on commit cee0030

Please sign in to comment.