Skip to content

Commit

Permalink
[mypyc] Generate faster code for bool comparisons and arithmetic (#14489
Browse files Browse the repository at this point in the history
)

Generate specialized, efficient IR for various operations on bools.
These are covered:
* Bool comparisons
* Mixed bool/integer comparisons
* Bool arithmetic (binary and unary)
* Mixed bool/integer arithmetic and bitwise ops

Mixed operations where the left operand is a `bool` and the right
operand is a native int still have some unnecessary conversions between
native int and `int`. This would be a bit trickier to fix and is seems
rare, so it doesn't seem urgent to fix this.

Fixes mypyc/mypyc#968.
  • Loading branch information
JukkaL authored Feb 5, 2023
1 parent 27f51fc commit 5614ffa
Show file tree
Hide file tree
Showing 6 changed files with 533 additions and 26 deletions.
6 changes: 5 additions & 1 deletion mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None:
source=op, desc=f"Cannot coerce source type {src.name} to dest type {dest.name}"
)

def check_compatibility(self, op: Op, t: RType, s: RType) -> None:
if not can_coerce_to(t, s) or not can_coerce_to(s, t):
self.fail(source=op, desc=f"{t.name} and {s.name} are not compatible")

def visit_goto(self, op: Goto) -> None:
self.check_control_op_targets(op)

Expand Down Expand Up @@ -375,7 +379,7 @@ def visit_int_op(self, op: IntOp) -> None:
pass

def visit_comparison_op(self, op: ComparisonOp) -> None:
pass
self.check_compatibility(op, op.lhs.type, op.rhs.type)

def visit_load_mem(self, op: LoadMem) -> None:
pass
Expand Down
82 changes: 57 additions & 25 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@
">>=",
}

# Binary operations on bools that are specialized and don't just promote operands to int
BOOL_BINARY_OPS: Final = {"&", "&=", "|", "|=", "^", "^=", "==", "!=", "<", "<=", ">", ">="}


class LowLevelIRBuilder:
def __init__(self, current_module: str, mapper: Mapper, options: CompilerOptions) -> None:
Expand Down Expand Up @@ -326,13 +329,13 @@ def coerce(
):
# Equivalent types
return src
elif (
is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)
) and is_int_rprimitive(target_type):
elif (is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)) and is_tagged(
target_type
):
shifted = self.int_op(
bool_rprimitive, src, Integer(1, bool_rprimitive), IntOp.LEFT_SHIFT
)
return self.add(Extend(shifted, int_rprimitive, signed=False))
return self.add(Extend(shifted, target_type, signed=False))
elif (
is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)
) and is_fixed_width_rtype(target_type):
Expand Down Expand Up @@ -1245,48 +1248,45 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
return self.compare_bytes(lreg, rreg, op, line)
if is_tagged(ltype) and is_tagged(rtype) and op in int_comparison_op_mapping:
return self.compare_tagged(lreg, rreg, op, line)
if (
is_bool_rprimitive(ltype)
and is_bool_rprimitive(rtype)
and op in ("&", "&=", "|", "|=", "^", "^=")
):
return self.bool_bitwise_op(lreg, rreg, op[0], line)
if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS:
if op in ComparisonOp.signed_ops:
return self.bool_comparison_op(lreg, rreg, op, line)
else:
return self.bool_bitwise_op(lreg, rreg, op[0], line)
if isinstance(rtype, RInstance) and op in ("in", "not in"):
return self.translate_instance_contains(rreg, lreg, op, line)
if is_fixed_width_rtype(ltype):
if op in FIXED_WIDTH_INT_BINARY_OPS:
if op.endswith("="):
op = op[:-1]
if op != "//":
op_id = int_op_to_id[op]
else:
op_id = IntOp.DIV
if is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
rreg = self.coerce(rreg, ltype, line)
rtype = ltype
if is_fixed_width_rtype(rtype) or is_tagged(rtype):
if op != "//":
op_id = int_op_to_id[op]
else:
op_id = IntOp.DIV
return self.fixed_width_int_op(ltype, lreg, rreg, op_id, line)
if isinstance(rreg, Integer):
# TODO: Check what kind of Integer
if op != "//":
op_id = int_op_to_id[op]
else:
op_id = IntOp.DIV
return self.fixed_width_int_op(
ltype, lreg, Integer(rreg.value >> 1, ltype), op_id, line
)
elif op in ComparisonOp.signed_ops:
if is_int_rprimitive(rtype):
rreg = self.coerce_int_to_fixed_width(rreg, ltype, line)
elif is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
rreg = self.coerce(rreg, ltype, line)
op_id = ComparisonOp.signed_ops[op]
if is_fixed_width_rtype(rreg.type):
return self.comparison_op(lreg, rreg, op_id, line)
if isinstance(rreg, Integer):
return self.comparison_op(lreg, Integer(rreg.value >> 1, ltype), op_id, line)
elif is_fixed_width_rtype(rtype):
if (
isinstance(lreg, Integer) or is_tagged(ltype)
) and op in FIXED_WIDTH_INT_BINARY_OPS:
if op in FIXED_WIDTH_INT_BINARY_OPS:
if op.endswith("="):
op = op[:-1]
# TODO: Support comparison ops (similar to above)
if op != "//":
op_id = int_op_to_id[op]
else:
Expand All @@ -1296,15 +1296,38 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
return self.fixed_width_int_op(
rtype, Integer(lreg.value >> 1, rtype), rreg, op_id, line
)
else:
if is_tagged(ltype):
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
if is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
lreg = self.coerce(lreg, rtype, line)
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
elif op in ComparisonOp.signed_ops:
if is_int_rprimitive(ltype):
lreg = self.coerce_int_to_fixed_width(lreg, rtype, line)
elif is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
lreg = self.coerce(lreg, rtype, line)
op_id = ComparisonOp.signed_ops[op]
if isinstance(lreg, Integer):
return self.comparison_op(Integer(lreg.value >> 1, rtype), rreg, op_id, line)
if is_fixed_width_rtype(lreg.type):
return self.comparison_op(lreg, rreg, op_id, line)

# Mixed int comparisons
if op in ("==", "!="):
op_id = ComparisonOp.signed_ops[op]
if is_tagged(ltype) and is_subtype(rtype, ltype):
rreg = self.coerce(rreg, int_rprimitive, line)
return self.comparison_op(lreg, rreg, op_id, line)
if is_tagged(rtype) and is_subtype(ltype, rtype):
lreg = self.coerce(lreg, int_rprimitive, line)
return self.comparison_op(lreg, rreg, op_id, line)
elif op in op in int_comparison_op_mapping:
if is_tagged(ltype) and is_subtype(rtype, ltype):
rreg = self.coerce(rreg, short_int_rprimitive, line)
return self.compare_tagged(lreg, rreg, op, line)
if is_tagged(rtype) and is_subtype(ltype, rtype):
lreg = self.coerce(lreg, short_int_rprimitive, line)
return self.compare_tagged(lreg, rreg, op, line)

call_c_ops_candidates = binary_ops.get(op, [])
target = self.matching_call_c(call_c_ops_candidates, [lreg, rreg], line)
Expand Down Expand Up @@ -1509,14 +1532,21 @@ def bool_bitwise_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value
assert False, op
return self.add(IntOp(bool_rprimitive, lreg, rreg, code, line))

def bool_comparison_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
op_id = ComparisonOp.signed_ops[op]
return self.comparison_op(lreg, rreg, op_id, line)

def unary_not(self, value: Value, line: int) -> Value:
mask = Integer(1, value.type, line)
return self.int_op(value.type, value, mask, IntOp.XOR, line)

def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
typ = value.type
if (is_bool_rprimitive(typ) or is_bit_rprimitive(typ)) and expr_op == "not":
return self.unary_not(value, line)
if is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
if expr_op == "not":
return self.unary_not(value, line)
if expr_op == "+":
return value
if is_fixed_width_rtype(typ):
if expr_op == "-":
# Translate to '0 - x'
Expand All @@ -1532,6 +1562,8 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
if is_short_int_rprimitive(typ):
num >>= 1
return Integer(-num, typ, value.line)
if is_tagged(typ) and expr_op == "+":
return value
if isinstance(typ, RInstance):
if expr_op == "-":
method = "__neg__"
Expand Down
Loading

0 comments on commit 5614ffa

Please sign in to comment.