From c591c891f7c5c35c3546ae6b4709ee97ef9e1136 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Sat, 16 Mar 2024 12:54:54 +0000 Subject: [PATCH] [mypyc] Implement lowering pass and add primitives for int (in)equality (#17027) Add a new `PrimitiveOp` op which can be transformed into lower-level ops in a lowering pass after reference counting op insertion pass. Higher-level ops in IR make it easier to implement various optimizations, and the output of irbuild test cases will be more compact and readable. Implement the lowering pass. Currently it's pretty minimal, and I will add additional primitives and the direct transformation of various primitives to `CallC` ops in follow-up PRs. Currently primitives that map to C calls generate `CallC` ops in the main irbuild pass, but the long-term goal is to only/mostly generate `PrimitiveOp`s instead of `CallC` ops during the main irbuild pass. Also implement primitives for tagged integer equality and inequality as examples. Lowering of primitives is implemented using decorated handler functions in `mypyc.lower` that are found based on the name of the primitive. The name has no other significance, though it's also used in pretty-printed IR output. Work on mypyc/mypyc#854. The issue describes the motivation in more detail. 
--- mypyc/analysis/dataflow.py | 4 + mypyc/analysis/ircheck.py | 4 + mypyc/analysis/selfleaks.py | 4 + mypyc/codegen/emitfunc.py | 6 + mypyc/codegen/emitmodule.py | 10 +- mypyc/ir/ops.py | 79 ++++++++- mypyc/ir/pprint.py | 17 ++ mypyc/irbuild/ast_helpers.py | 7 +- mypyc/irbuild/expression.py | 4 +- mypyc/irbuild/ll_builder.py | 126 +++++++++++-- mypyc/lower/__init__.py | 0 mypyc/lower/int_ops.py | 15 ++ mypyc/lower/registry.py | 26 +++ mypyc/primitives/int_ops.py | 25 ++- mypyc/primitives/registry.py | 40 +++-- mypyc/test-data/analysis.test | 70 +++----- mypyc/test-data/irbuild-basic.test | 206 +++++++--------------- mypyc/test-data/irbuild-bool.test | 6 +- mypyc/test-data/irbuild-classes.test | 2 +- mypyc/test-data/irbuild-int.test | 28 +-- mypyc/test-data/irbuild-match.test | 45 +++-- mypyc/test-data/irbuild-nested.test | 4 +- mypyc/test-data/irbuild-optional.test | 2 +- mypyc/test-data/irbuild-tuple.test | 87 ++------- mypyc/test-data/lowering-int.test | 126 +++++++++++++ mypyc/test-data/opt-flag-elimination.test | 18 +- mypyc/test-data/refcount.test | 172 ++++++------------ mypyc/test/test_cheader.py | 16 +- mypyc/test/test_emitfunc.py | 2 + mypyc/test/test_lowering.py | 54 ++++++ mypyc/transform/ir_transform.py | 17 +- mypyc/transform/lower.py | 33 ++++ 32 files changed, 772 insertions(+), 483 deletions(-) create mode 100644 mypyc/lower/__init__.py create mode 100644 mypyc/lower/int_ops.py create mode 100644 mypyc/lower/registry.py create mode 100644 mypyc/test-data/lowering-int.test create mode 100644 mypyc/test/test_lowering.py create mode 100644 mypyc/transform/lower.py diff --git a/mypyc/analysis/dataflow.py b/mypyc/analysis/dataflow.py index 57ad2b17fcc5..9babf860fb31 100644 --- a/mypyc/analysis/dataflow.py +++ b/mypyc/analysis/dataflow.py @@ -38,6 +38,7 @@ MethodCall, Op, OpVisitor, + PrimitiveOp, RaiseStandardError, RegisterOp, Return, @@ -234,6 +235,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill[T]: def 
visit_call_c(self, op: CallC) -> GenAndKill[T]: return self.visit_register_op(op) + def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill[T]: + return self.visit_register_op(op) + def visit_truncate(self, op: Truncate) -> GenAndKill[T]: return self.visit_register_op(op) diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py index 127047e02ff5..88737ac208de 100644 --- a/mypyc/analysis/ircheck.py +++ b/mypyc/analysis/ircheck.py @@ -37,6 +37,7 @@ MethodCall, Op, OpVisitor, + PrimitiveOp, RaiseStandardError, Register, Return, @@ -381,6 +382,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None: def visit_call_c(self, op: CallC) -> None: pass + def visit_primitive_op(self, op: PrimitiveOp) -> None: + pass + def visit_truncate(self, op: Truncate) -> None: pass diff --git a/mypyc/analysis/selfleaks.py b/mypyc/analysis/selfleaks.py index 80c2bc348bc2..5d89a9bfc7c6 100644 --- a/mypyc/analysis/selfleaks.py +++ b/mypyc/analysis/selfleaks.py @@ -31,6 +31,7 @@ LoadStatic, MethodCall, OpVisitor, + PrimitiveOp, RaiseStandardError, Register, RegisterOp, @@ -149,6 +150,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill: def visit_call_c(self, op: CallC) -> GenAndKill: return self.check_register_op(op) + def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill: + return self.check_register_op(op) + def visit_truncate(self, op: Truncate) -> GenAndKill: return CLEAN diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py index c08f1f840fa4..12f57b9cee6f 100644 --- a/mypyc/codegen/emitfunc.py +++ b/mypyc/codegen/emitfunc.py @@ -47,6 +47,7 @@ MethodCall, Op, OpVisitor, + PrimitiveOp, RaiseStandardError, Register, Return, @@ -629,6 +630,11 @@ def visit_call_c(self, op: CallC) -> None: args = ", ".join(self.reg(arg) for arg in op.args) self.emitter.emit_line(f"{dest}{op.function_name}({args});") + def visit_primitive_op(self, op: PrimitiveOp) -> None: + raise RuntimeError( + f"unexpected PrimitiveOp 
{op.desc.name}: they must be lowered before codegen" + ) + def visit_truncate(self, op: Truncate) -> None: dest = self.reg(op) value = self.reg(op.src) diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index 9466bc2cea79..6c8f5ac91335 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -59,6 +59,7 @@ from mypyc.transform.copy_propagation import do_copy_propagation from mypyc.transform.exceptions import insert_exception_handling from mypyc.transform.flag_elimination import do_flag_elimination +from mypyc.transform.lower import lower_ir from mypyc.transform.refcount import insert_ref_count_opcodes from mypyc.transform.uninit import insert_uninit_checks @@ -235,6 +236,8 @@ def compile_scc_to_ir( insert_exception_handling(fn) # Insert refcount handling. insert_ref_count_opcodes(fn) + # Switch to lower abstraction level IR. + lower_ir(fn, compiler_options) # Perform optimizations. do_copy_propagation(fn, compiler_options) do_flag_elimination(fn, compiler_options) @@ -423,10 +426,11 @@ def compile_modules_to_c( ) modules = compile_modules_to_ir(result, mapper, compiler_options, errors) - ctext = compile_ir_to_c(groups, modules, result, mapper, compiler_options) + if errors.num_errors > 0: + return {}, [] - if errors.num_errors == 0: - write_cache(modules, result, group_map, ctext) + ctext = compile_ir_to_c(groups, modules, result, mapper, compiler_options) + write_cache(modules, result, group_map, ctext) return modules, [ctext[name] for _, name in groups] diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 04c50d1e2841..3acfb0933e5a 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -576,6 +576,78 @@ def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_method_call(self) +class PrimitiveDescription: + """Description of a primitive op. + + Primitives get lowered into lower-level ops before code generation. + + If c_function_name is provided, a primitive will be lowered into a CallC op. 
+ Otherwise custom logic will need to be implemented to transform the + primitive into lower-level ops. + """ + + def __init__( + self, + name: str, + arg_types: list[RType], + return_type: RType, # TODO: What about generic? + var_arg_type: RType | None, + truncated_type: RType | None, + c_function_name: str | None, + error_kind: int, + steals: StealsDescription, + is_borrowed: bool, + ordering: list[int] | None, + extra_int_constants: list[tuple[int, RType]], + priority: int, + ) -> None: + # Each primitive must have a distinct name, but otherwise they are arbitrary. + self.name: Final = name + self.arg_types: Final = arg_types + self.return_type: Final = return_type + self.var_arg_type: Final = var_arg_type + self.truncated_type: Final = truncated_type + # If non-None, this will map to a call of a C helper function; if None, + # there must be a custom handler function that gets invoked during the lowering + # pass to generate low-level IR for the primitive (in the mypyc.lower package) + self.c_function_name: Final = c_function_name + self.error_kind: Final = error_kind + self.steals: Final = steals + self.is_borrowed: Final = is_borrowed + self.ordering: Final = ordering + self.extra_int_constants: Final = extra_int_constants + self.priority: Final = priority + + def __repr__(self) -> str: + return f"" + + +class PrimitiveOp(RegisterOp): + """A higher-level primitive operation. + + Some of these have special compiler support. These will be lowered + (transformed) into lower-level IR ops before code generation, and after + reference counting op insertion. Others will be transformed into CallC + ops. + + Tagged integer equality is a typical primitive op with non-trivial + lowering. It gets transformed into a tag check, followed by different + code paths for short and long representations. 
+ """ + + def __init__(self, args: list[Value], desc: PrimitiveDescription, line: int = -1) -> None: + self.args = args + self.type = desc.return_type + self.error_kind = desc.error_kind + self.desc = desc + + def sources(self) -> list[Value]: + return self.args + + def accept(self, visitor: OpVisitor[T]) -> T: + return visitor.visit_primitive_op(self) + + class LoadErrorValue(RegisterOp): """Load an error value. @@ -1446,7 +1518,8 @@ class Unborrow(RegisterOp): error_kind = ERR_NEVER - def __init__(self, src: Value) -> None: + def __init__(self, src: Value, line: int = -1) -> None: + super().__init__(line) assert src.is_borrowed self.src = src self.type = src.type @@ -1555,6 +1628,10 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> T: def visit_call_c(self, op: CallC) -> T: raise NotImplementedError + @abstractmethod + def visit_primitive_op(self, op: PrimitiveOp) -> T: + raise NotImplementedError + @abstractmethod def visit_truncate(self, op: Truncate) -> T: raise NotImplementedError diff --git a/mypyc/ir/pprint.py b/mypyc/ir/pprint.py index 5578049256f1..8d6723917ea0 100644 --- a/mypyc/ir/pprint.py +++ b/mypyc/ir/pprint.py @@ -43,6 +43,7 @@ MethodCall, Op, OpVisitor, + PrimitiveOp, RaiseStandardError, Register, Return, @@ -217,6 +218,22 @@ def visit_call_c(self, op: CallC) -> str: else: return self.format("%r = %s(%s)", op, op.function_name, args_str) + def visit_primitive_op(self, op: PrimitiveOp) -> str: + args = [] + arg_index = 0 + type_arg_index = 0 + for arg_type in zip(op.desc.arg_types): + if arg_type: + args.append(self.format("%r", op.args[arg_index])) + arg_index += 1 + else: + assert op.type_args + args.append(self.format("%r", op.type_args[type_arg_index])) + type_arg_index += 1 + + args_str = ", ".join(args) + return self.format("%r = %s %s ", op, op.desc.name, args_str) + def visit_truncate(self, op: Truncate) -> str: return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type) diff --git 
a/mypyc/irbuild/ast_helpers.py b/mypyc/irbuild/ast_helpers.py index 1af1ad611a89..8490eaa03477 100644 --- a/mypyc/irbuild/ast_helpers.py +++ b/mypyc/irbuild/ast_helpers.py @@ -93,7 +93,12 @@ def maybe_process_conditional_comparison( self.add_bool_branch(reg, true, false) else: # "left op right" for two tagged integers - self.builder.compare_tagged_condition(left, right, op, true, false, e.line) + if op in ("==", "!="): + reg = self.builder.binary_op(left, right, op, e.line) + self.flush_keep_alives() + self.add_bool_branch(reg, true, false) + else: + self.builder.compare_tagged_condition(left, right, op, true, false, e.line) return True diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 81e37953809f..021b7a1dbe90 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -756,7 +756,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: set_literal = precompute_set_literal(builder, e.operands[1]) if set_literal is not None: lhs = e.operands[0] - result = builder.builder.call_c( + result = builder.builder.primitive_op( set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive ) if first_op == "not in": @@ -778,7 +778,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: borrow_left = is_borrow_friendly_expr(builder, right_expr) left = builder.accept(left_expr, can_borrow=borrow_left) right = builder.accept(right_expr, can_borrow=True) - return builder.compare_tagged(left, right, first_op, e.line) + return builder.binary_op(left, right, first_op, e.line) # TODO: Don't produce an expression when used in conditional context # All of the trickiness here is due to support for chained conditionals diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 45c06e11befd..f9bacb43bc3e 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -63,6 +63,8 @@ LoadStatic, MethodCall, Op, + PrimitiveDescription, + PrimitiveOp, 
RaiseStandardError, Register, SetMem, @@ -1313,7 +1315,12 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: return self.compare_strings(lreg, rreg, op, line) if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="): return self.compare_bytes(lreg, rreg, op, line) - if is_tagged(ltype) and is_tagged(rtype) and op in int_comparison_op_mapping: + if ( + is_tagged(ltype) + and is_tagged(rtype) + and op in int_comparison_op_mapping + and op not in ("==", "!=") + ): return self.compare_tagged(lreg, rreg, op, line) if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS: if op in ComparisonOp.signed_ops: @@ -1379,13 +1386,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: # Mixed int comparisons if op in ("==", "!="): - op_id = ComparisonOp.signed_ops[op] - if is_tagged(ltype) and is_subtype(rtype, ltype): - rreg = self.coerce(rreg, int_rprimitive, line) - return self.comparison_op(lreg, rreg, op_id, line) - if is_tagged(rtype) and is_subtype(ltype, rtype): - lreg = self.coerce(lreg, int_rprimitive, line) - return self.comparison_op(lreg, rreg, op_id, line) + pass # TODO: Do we need anything here? 
elif op in op in int_comparison_op_mapping: if is_tagged(ltype) and is_subtype(rtype, ltype): rreg = self.coerce(rreg, short_int_rprimitive, line) @@ -1412,8 +1413,8 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: if base_op in float_op_to_id: return self.float_op(lreg, rreg, base_op, line) - call_c_ops_candidates = binary_ops.get(op, []) - target = self.matching_call_c(call_c_ops_candidates, [lreg, rreg], line) + primitive_ops_candidates = binary_ops.get(op, []) + target = self.matching_primitive_op(primitive_ops_candidates, [lreg, rreg], line) assert target, "Unsupported binary operation: %s" % op return target @@ -1432,7 +1433,14 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) - def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: """Compare two tagged integers using given operator (value context).""" # generate fast binary logic ops on short ints - if is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type): + if (is_short_int_rprimitive(lhs.type) or is_short_int_rprimitive(rhs.type)) and op in ( + "==", + "!=", + ): + quick = True + else: + quick = is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type) + if quick: return self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line) op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op] result = Register(bool_rprimitive) @@ -1986,6 +1994,102 @@ def matching_call_c( return target return None + def primitive_op( + self, + desc: PrimitiveDescription, + args: list[Value], + line: int, + result_type: RType | None = None, + ) -> Value: + """Add a primitive op.""" + # Does this primitive map into calling a Python C API + # or an internal mypyc C API function? 
+ if desc.c_function_name: + # TODO: Generate PrimitiveOps here and transform them into CallC + # ops only later in the lowering pass + c_desc = CFunctionDescription( + desc.name, + desc.arg_types, + desc.return_type, + desc.var_arg_type, + desc.truncated_type, + desc.c_function_name, + desc.error_kind, + desc.steals, + desc.is_borrowed, + desc.ordering, + desc.extra_int_constants, + desc.priority, + ) + return self.call_c(c_desc, args, line, result_type) + + # This primitive gets transformed in a lowering pass to + # lower-level IR ops using a custom transform function. + + coerced = [] + # Coerce fixed number arguments + for i in range(min(len(args), len(desc.arg_types))): + formal_type = desc.arg_types[i] + arg = args[i] + assert formal_type is not None # TODO + arg = self.coerce(arg, formal_type, line) + coerced.append(arg) + assert desc.ordering is None + assert desc.var_arg_type is None + assert not desc.extra_int_constants + target = self.add(PrimitiveOp(coerced, desc, line=line)) + if desc.is_borrowed: + # If the result is borrowed, force the arguments to be + # kept alive afterwards, as otherwise the result might be + # immediately freed, at the risk of a dangling pointer. + for arg in coerced: + if not isinstance(arg, (Integer, LoadLiteral)): + self.keep_alives.append(arg) + if desc.error_kind == ERR_NEG_INT: + comp = ComparisonOp(target, Integer(0, desc.return_type, line), ComparisonOp.SGE, line) + comp.error_kind = ERR_FALSE + self.add(comp) + + assert desc.truncated_type is None + result = target + if result_type and not is_runtime_subtype(result.type, result_type): + if is_none_rprimitive(result_type): + # Special case None return. The actual result may actually be a bool + # and so we can't just coerce it. 
+ result = self.none() + else: + result = self.coerce(result, result_type, line, can_borrow=desc.is_borrowed) + return result + + def matching_primitive_op( + self, + candidates: list[PrimitiveDescription], + args: list[Value], + line: int, + result_type: RType | None = None, + can_borrow: bool = False, + ) -> Value | None: + matching: PrimitiveDescription | None = None + for desc in candidates: + if len(desc.arg_types) != len(args): + continue + if all( + # formal is not None and # TODO + is_subtype(actual.type, formal) + for actual, formal in zip(args, desc.arg_types) + ) and (not desc.is_borrowed or can_borrow): + if matching: + assert matching.priority != desc.priority, "Ambiguous:\n1) {}\n2) {}".format( + matching, desc + ) + if desc.priority > matching.priority: + matching = desc + else: + matching = desc + if matching: + return self.primitive_op(matching, args, line=line) + return None + def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int = -1) -> Value: """Generate a native integer binary op. 
diff --git a/mypyc/lower/__init__.py b/mypyc/lower/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/mypyc/lower/int_ops.py b/mypyc/lower/int_ops.py new file mode 100644 index 000000000000..40fba7af4f4d --- /dev/null +++ b/mypyc/lower/int_ops.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from mypyc.ir.ops import Value +from mypyc.irbuild.ll_builder import LowLevelIRBuilder +from mypyc.lower.registry import lower_binary_op + + +@lower_binary_op("int_eq") +def lower_int_eq(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value: + return builder.compare_tagged(args[0], args[1], "==", line) + + +@lower_binary_op("int_ne") +def lower_int_ne(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value: + return builder.compare_tagged(args[0], args[1], "!=", line) diff --git a/mypyc/lower/registry.py b/mypyc/lower/registry.py new file mode 100644 index 000000000000..cc53eb93f4dd --- /dev/null +++ b/mypyc/lower/registry.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import Callable, Final, List + +from mypyc.ir.ops import Value +from mypyc.irbuild.ll_builder import LowLevelIRBuilder + +LowerFunc = Callable[[LowLevelIRBuilder, List[Value], int], Value] + + +lowering_registry: Final[dict[str, LowerFunc]] = {} + + +def lower_binary_op(name: str) -> Callable[[LowerFunc], LowerFunc]: + """Register a handler that generates low-level IR for a primitive binary op.""" + + def wrapper(f: LowerFunc) -> LowerFunc: + assert name not in lowering_registry + lowering_registry[name] = f + return f + + return wrapper + + +# Import various modules that set up global state. 
+import mypyc.lower.int_ops # noqa: F401 diff --git a/mypyc/primitives/int_ops.py b/mypyc/primitives/int_ops.py index 95f9cc5ff43f..4103fe349a74 100644 --- a/mypyc/primitives/int_ops.py +++ b/mypyc/primitives/int_ops.py @@ -12,7 +12,14 @@ from typing import NamedTuple -from mypyc.ir.ops import ERR_ALWAYS, ERR_MAGIC, ERR_MAGIC_OVERLAPPING, ERR_NEVER, ComparisonOp +from mypyc.ir.ops import ( + ERR_ALWAYS, + ERR_MAGIC, + ERR_MAGIC_OVERLAPPING, + ERR_NEVER, + ComparisonOp, + PrimitiveDescription, +) from mypyc.ir.rtypes import ( RType, bit_rprimitive, @@ -101,6 +108,22 @@ ) +def int_binary_primitive( + op: str, primitive_name: str, return_type: RType = int_rprimitive, error_kind: int = ERR_NEVER +) -> PrimitiveDescription: + return binary_op( + name=op, + arg_types=[int_rprimitive, int_rprimitive], + return_type=return_type, + primitive_name=primitive_name, + error_kind=error_kind, + ) + + +int_eq = int_binary_primitive(op="==", primitive_name="int_eq", return_type=bit_rprimitive) +int_ne = int_binary_primitive(op="!=", primitive_name="int_ne", return_type=bit_rprimitive) + + def int_binary_op( name: str, c_function_name: str, diff --git a/mypyc/primitives/registry.py b/mypyc/primitives/registry.py index 11fca7dc2c70..d4768b4df532 100644 --- a/mypyc/primitives/registry.py +++ b/mypyc/primitives/registry.py @@ -39,7 +39,7 @@ from typing import Final, NamedTuple -from mypyc.ir.ops import StealsDescription +from mypyc.ir.ops import PrimitiveDescription, StealsDescription from mypyc.ir.rtypes import RType # Error kind for functions that return negative integer on exception. 
This @@ -76,7 +76,7 @@ class LoadAddressDescription(NamedTuple): function_ops: dict[str, list[CFunctionDescription]] = {} # CallC op for binary ops -binary_ops: dict[str, list[CFunctionDescription]] = {} +binary_ops: dict[str, list[PrimitiveDescription]] = {} # CallC op for unary ops unary_ops: dict[str, list[CFunctionDescription]] = {} @@ -192,8 +192,9 @@ def binary_op( name: str, arg_types: list[RType], return_type: RType, - c_function_name: str, error_kind: int, + c_function_name: str | None = None, + primitive_name: str | None = None, var_arg_type: RType | None = None, truncated_type: RType | None = None, ordering: list[int] | None = None, @@ -201,7 +202,7 @@ def binary_op( steals: StealsDescription = False, is_borrowed: bool = False, priority: int = 1, -) -> CFunctionDescription: +) -> PrimitiveDescription: """Define a c function call op for a binary operation. This will be automatically generated by matching against the AST. @@ -209,22 +210,24 @@ def binary_op( Most arguments are similar to method_op(), but exactly two argument types are expected. 
""" + assert c_function_name is not None or primitive_name is not None + assert not (c_function_name is not None and primitive_name is not None) if extra_int_constants is None: extra_int_constants = [] ops = binary_ops.setdefault(name, []) - desc = CFunctionDescription( - name, - arg_types, - return_type, - var_arg_type, - truncated_type, - c_function_name, - error_kind, - steals, - is_borrowed, - ordering, - extra_int_constants, - priority, + desc = PrimitiveDescription( + name=primitive_name or name, + arg_types=arg_types, + return_type=return_type, + var_arg_type=var_arg_type, + truncated_type=truncated_type, + c_function_name=c_function_name, + error_kind=error_kind, + steals=steals, + is_borrowed=is_borrowed, + ordering=ordering, + extra_int_constants=extra_int_constants, + priority=priority, ) ops.append(desc) return desc @@ -311,11 +314,10 @@ def load_address_op(name: str, type: RType, src: str) -> LoadAddressDescription: return LoadAddressDescription(name, type, src) +# Import various modules that set up global state. import mypyc.primitives.bytes_ops import mypyc.primitives.dict_ops import mypyc.primitives.float_ops - -# Import various modules that set up global state. 
import mypyc.primitives.int_ops import mypyc.primitives.list_ops import mypyc.primitives.misc_ops diff --git a/mypyc/test-data/analysis.test b/mypyc/test-data/analysis.test index efd219cc222a..8e067aed4d79 100644 --- a/mypyc/test-data/analysis.test +++ b/mypyc/test-data/analysis.test @@ -10,40 +10,27 @@ def f(a: int) -> None: [out] def f(a): a, x :: int - r0 :: native_int - r1, r2, r3 :: bit + r0 :: bit y, z :: int L0: x = 2 - r0 = x & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq x, a + if r0 goto L1 else goto L2 :: bool L1: - r2 = CPyTagged_IsEq_(x, a) - if r2 goto L3 else goto L4 :: bool -L2: - r3 = x == a - if r3 goto L3 else goto L4 :: bool -L3: y = 2 - goto L5 -L4: + goto L3 +L2: z = 2 -L5: +L3: return 1 (0, 0) {a} {a, x} (0, 1) {a, x} {a, x} (0, 2) {a, x} {a, x} -(0, 3) {a, x} {a, x} -(1, 0) {a, x} {a, x} -(1, 1) {a, x} {a, x} -(2, 0) {a, x} {a, x} -(2, 1) {a, x} {a, x} -(3, 0) {a, x} {a, x, y} -(3, 1) {a, x, y} {a, x, y} -(4, 0) {a, x} {a, x, z} -(4, 1) {a, x, z} {a, x, z} -(5, 0) {a, x, y, z} {a, x, y, z} +(1, 0) {a, x} {a, x, y} +(1, 1) {a, x, y} {a, x, y} +(2, 0) {a, x} {a, x, z} +(2, 1) {a, x, z} {a, x, z} +(3, 0) {a, x, y, z} {a, x, y, z} [case testSimple_Liveness] def f(a: int) -> int: @@ -58,7 +45,7 @@ def f(a): r0 :: bit L0: x = 2 - r0 = x == 2 + r0 = int_eq x, 2 if r0 goto L1 else goto L2 :: bool L1: return a @@ -124,7 +111,7 @@ def f(a): r0 :: bit y, x :: int L0: - r0 = a == 2 + r0 = int_eq a, 2 if r0 goto L1 else goto L2 :: bool L1: y = 2 @@ -421,40 +408,27 @@ def f(a: int) -> int: [out] def f(a): a :: int - r0 :: native_int - r1, r2, r3 :: bit + r0 :: bit x :: int L0: - r0 = a & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq a, a + if r0 goto L1 else goto L2 :: bool L1: - r2 = CPyTagged_IsEq_(a, a) - if r2 goto L3 else goto L4 :: bool -L2: - r3 = a == a - if r3 goto L3 else goto L4 :: bool -L3: x = 4 a = 2 - goto L5 -L4: + goto L3 +L2: x = 2 -L5: +L3: return x (0, 0) {a} {a} (0, 1) {a} {a} -(0, 2) {a} 
{a} (1, 0) {a} {a} -(1, 1) {a} {a} +(1, 1) {a} {} +(1, 2) {} {} (2, 0) {a} {a} (2, 1) {a} {a} -(3, 0) {a} {a} -(3, 1) {a} {} -(3, 2) {} {} -(4, 0) {a} {a} -(4, 1) {a} {a} -(5, 0) {} {} +(3, 0) {} {} [case testLoop_BorrowedArgument] def f(a: int) -> int: diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index cd952ef2ebfd..981460dae371 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -568,7 +568,7 @@ L3: x = 2 goto L8 L4: - r4 = n == 0 + r4 = int_eq n, 0 if r4 goto L5 else goto L6 :: bool L5: x = 2 @@ -598,7 +598,7 @@ def f(n): r0 :: bit r1 :: int L0: - r0 = n == 0 + r0 = int_eq n, 0 if r0 goto L1 else goto L2 :: bool L1: r1 = 0 @@ -1462,7 +1462,7 @@ L0: r1 = load_mem r0 :: native_int* keep_alive x r2 = r1 << 1 - r3 = r2 != 0 + r3 = int_ne r2, 0 if r3 goto L1 else goto L2 :: bool L1: return 2 @@ -2052,19 +2052,12 @@ def f(): r13 :: bit r14 :: object r15, x :: int - r16 :: native_int - r17, r18 :: bit - r19 :: bool - r20, r21 :: bit - r22 :: native_int - r23, r24 :: bit - r25 :: bool - r26, r27 :: bit - r28 :: int - r29 :: object - r30 :: i32 - r31 :: bit - r32 :: short_int + r16, r17 :: bit + r18 :: int + r19 :: object + r20 :: i32 + r21 :: bit + r22 :: short_int L0: r0 = PyList_New(0) r1 = PyList_New(3) @@ -2086,52 +2079,30 @@ L1: keep_alive r1 r12 = r11 << 1 r13 = r9 < r12 :: signed - if r13 goto L2 else goto L14 :: bool + if r13 goto L2 else goto L8 :: bool L2: r14 = CPyList_GetItemUnsafe(r1, r9) r15 = unbox(int, r14) x = r15 - r16 = x & 1 - r17 = r16 == 0 - if r17 goto L3 else goto L4 :: bool + r16 = int_ne x, 4 + if r16 goto L4 else goto L3 :: bool L3: - r18 = x != 4 - r19 = r18 - goto L5 + goto L7 L4: - r20 = CPyTagged_IsEq_(x, 4) - r21 = r20 ^ 1 - r19 = r21 + r17 = int_ne x, 6 + if r17 goto L6 else goto L5 :: bool L5: - if r19 goto L7 else goto L6 :: bool + goto L7 L6: - goto L13 + r18 = CPyTagged_Multiply(x, x) + r19 = box(int, r18) + r20 = PyList_Append(r0, r19) + r21 = r20 >= 0 
:: signed L7: - r22 = x & 1 - r23 = r22 == 0 - if r23 goto L8 else goto L9 :: bool -L8: - r24 = x != 6 - r25 = r24 - goto L10 -L9: - r26 = CPyTagged_IsEq_(x, 6) - r27 = r26 ^ 1 - r25 = r27 -L10: - if r25 goto L12 else goto L11 :: bool -L11: - goto L13 -L12: - r28 = CPyTagged_Multiply(x, x) - r29 = box(int, r28) - r30 = PyList_Append(r0, r29) - r31 = r30 >= 0 :: signed -L13: - r32 = r9 + 2 - r9 = r32 + r22 = r9 + 2 + r9 = r22 goto L1 -L14: +L8: return r0 [case testDictComprehension] @@ -2151,19 +2122,12 @@ def f(): r13 :: bit r14 :: object r15, x :: int - r16 :: native_int - r17, r18 :: bit - r19 :: bool - r20, r21 :: bit - r22 :: native_int - r23, r24 :: bit - r25 :: bool - r26, r27 :: bit - r28 :: int - r29, r30 :: object - r31 :: i32 - r32 :: bit - r33 :: short_int + r16, r17 :: bit + r18 :: int + r19, r20 :: object + r21 :: i32 + r22 :: bit + r23 :: short_int L0: r0 = PyDict_New() r1 = PyList_New(3) @@ -2185,53 +2149,31 @@ L1: keep_alive r1 r12 = r11 << 1 r13 = r9 < r12 :: signed - if r13 goto L2 else goto L14 :: bool + if r13 goto L2 else goto L8 :: bool L2: r14 = CPyList_GetItemUnsafe(r1, r9) r15 = unbox(int, r14) x = r15 - r16 = x & 1 - r17 = r16 == 0 - if r17 goto L3 else goto L4 :: bool + r16 = int_ne x, 4 + if r16 goto L4 else goto L3 :: bool L3: - r18 = x != 4 - r19 = r18 - goto L5 + goto L7 L4: - r20 = CPyTagged_IsEq_(x, 4) - r21 = r20 ^ 1 - r19 = r21 + r17 = int_ne x, 6 + if r17 goto L6 else goto L5 :: bool L5: - if r19 goto L7 else goto L6 :: bool + goto L7 L6: - goto L13 + r18 = CPyTagged_Multiply(x, x) + r19 = box(int, x) + r20 = box(int, r18) + r21 = CPyDict_SetItem(r0, r19, r20) + r22 = r21 >= 0 :: signed L7: - r22 = x & 1 - r23 = r22 == 0 - if r23 goto L8 else goto L9 :: bool -L8: - r24 = x != 6 - r25 = r24 - goto L10 -L9: - r26 = CPyTagged_IsEq_(x, 6) - r27 = r26 ^ 1 - r25 = r27 -L10: - if r25 goto L12 else goto L11 :: bool -L11: - goto L13 -L12: - r28 = CPyTagged_Multiply(x, x) - r29 = box(int, x) - r30 = box(int, r28) - r31 = 
CPyDict_SetItem(r0, r29, r30) - r32 = r31 >= 0 :: signed -L13: - r33 = r9 + 2 - r9 = r33 + r23 = r9 + 2 + r9 = r23 goto L1 -L14: +L8: return r0 [case testLoopsMultipleAssign] @@ -3011,85 +2953,57 @@ def call_any(l): r0 :: bool r1, r2 :: object r3, i :: int - r4 :: native_int - r5, r6 :: bit - r7 :: bool - r8, r9 :: bit + r4, r5 :: bit L0: r0 = 0 r1 = PyObject_GetIter(l) L1: r2 = PyIter_Next(r1) - if is_error(r2) goto L9 else goto L2 + if is_error(r2) goto L6 else goto L2 L2: r3 = unbox(int, r2) i = r3 - r4 = i & 1 - r5 = r4 == 0 - if r5 goto L3 else goto L4 :: bool + r4 = int_eq i, 0 + if r4 goto L3 else goto L4 :: bool L3: - r6 = i == 0 - r7 = r6 - goto L5 + r0 = 1 + goto L8 L4: - r8 = CPyTagged_IsEq_(i, 0) - r7 = r8 L5: - if r7 goto L6 else goto L7 :: bool + goto L1 L6: - r0 = 1 - goto L11 + r5 = CPy_NoErrOccured() L7: L8: - goto L1 -L9: - r9 = CPy_NoErrOccured() -L10: -L11: return r0 def call_all(l): l :: object r0 :: bool r1, r2 :: object r3, i :: int - r4 :: native_int - r5, r6 :: bit - r7 :: bool - r8 :: bit - r9 :: bool - r10 :: bit + r4, r5, r6 :: bit L0: r0 = 1 r1 = PyObject_GetIter(l) L1: r2 = PyIter_Next(r1) - if is_error(r2) goto L9 else goto L2 + if is_error(r2) goto L6 else goto L2 L2: r3 = unbox(int, r2) i = r3 - r4 = i & 1 - r5 = r4 == 0 + r4 = int_eq i, 0 + r5 = r4 ^ 1 if r5 goto L3 else goto L4 :: bool L3: - r6 = i == 0 - r7 = r6 - goto L5 + r0 = 0 + goto L8 L4: - r8 = CPyTagged_IsEq_(i, 0) - r7 = r8 L5: - r9 = r7 ^ 1 - if r9 goto L6 else goto L7 :: bool + goto L1 L6: - r0 = 0 - goto L11 + r6 = CPy_NoErrOccured() L7: L8: - goto L1 -L9: - r10 = CPy_NoErrOccured() -L10: -L11: return r0 [case testSum] diff --git a/mypyc/test-data/irbuild-bool.test b/mypyc/test-data/irbuild-bool.test index 731d393d69ab..f0b0b480bc0d 100644 --- a/mypyc/test-data/irbuild-bool.test +++ b/mypyc/test-data/irbuild-bool.test @@ -96,7 +96,7 @@ L0: r1 = load_mem r0 :: native_int* keep_alive l r2 = r1 << 1 - r3 = r2 != 0 + r3 = int_ne r2, 0 return r3 def 
always_truthy_instance_to_bool(o): o :: __main__.C @@ -222,7 +222,7 @@ def eq1(x, y): L0: r0 = y << 1 r1 = extend r0: builtins.bool to builtins.int - r2 = x == r1 + r2 = int_eq x, r1 return r2 def eq2(x, y): x :: bool @@ -233,7 +233,7 @@ def eq2(x, y): L0: r0 = x << 1 r1 = extend r0: builtins.bool to builtins.int - r2 = r1 == y + r2 = int_eq r1, y return r2 def neq1(x, y): x :: i64 diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index 55e55dbf3286..8c4743c6a47f 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -1249,7 +1249,7 @@ L0: r0 = x.__getitem__(2) r1 = CPyList_GetItemShortBorrow(r0, 0) r2 = unbox(int, r1) - r3 = r2 == 4 + r3 = int_eq r2, 4 keep_alive r0 if r3 goto L1 else goto L2 :: bool L1: diff --git a/mypyc/test-data/irbuild-int.test b/mypyc/test-data/irbuild-int.test index fbe00aff4040..1489f2f470dd 100644 --- a/mypyc/test-data/irbuild-int.test +++ b/mypyc/test-data/irbuild-int.test @@ -4,24 +4,10 @@ def f(x: int, y: int) -> bool: [out] def f(x, y): x, y :: int - r0 :: native_int - r1, r2 :: bit - r3 :: bool - r4, r5 :: bit + r0 :: bit L0: - r0 = x & 1 - r1 = r0 == 0 - if r1 goto L1 else goto L2 :: bool -L1: - r2 = x != y - r3 = r2 - goto L3 -L2: - r4 = CPyTagged_IsEq_(x, y) - r5 = r4 ^ 1 - r3 = r5 -L3: - return r3 + r0 = int_ne x, y + return r0 [case testShortIntComparisons] def f(x: int) -> int: @@ -43,22 +29,22 @@ def f(x): r4 :: native_int r5, r6, r7 :: bit L0: - r0 = x == 6 + r0 = int_eq x, 6 if r0 goto L1 else goto L2 :: bool L1: return 2 L2: - r1 = x != 8 + r1 = int_ne x, 8 if r1 goto L3 else goto L4 :: bool L3: return 4 L4: - r2 = 10 == x + r2 = int_eq 10, x if r2 goto L5 else goto L6 :: bool L5: return 6 L6: - r3 = 12 != x + r3 = int_ne 12, x if r3 goto L7 else goto L8 :: bool L7: return 8 diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index a078ae0defdb..ab5a19624ba6 100644 --- a/mypyc/test-data/irbuild-match.test +++ 
b/mypyc/test-data/irbuild-match.test @@ -14,7 +14,7 @@ def f(): r6 :: object_ptr r7, r8 :: object L0: - r0 = 246 == 246 + r0 = int_eq 246, 246 if r0 goto L1 else goto L2 :: bool L1: r1 = 'matched' @@ -30,6 +30,7 @@ L2: L3: r8 = box(None, 1) return r8 + [case testMatchOrPattern_python3_10] def f(): match 123: @@ -46,10 +47,10 @@ def f(): r7 :: object_ptr r8, r9 :: object L0: - r0 = 246 == 246 + r0 = int_eq 246, 246 if r0 goto L3 else goto L1 :: bool L1: - r1 = 246 == 912 + r1 = int_eq 246, 912 if r1 goto L3 else goto L2 :: bool L2: goto L4 @@ -67,6 +68,7 @@ L4: L5: r9 = box(None, 1) return r9 + [case testMatchOrPatternManyPatterns_python3_10] def f(): match 1: @@ -83,16 +85,16 @@ def f(): r9 :: object_ptr r10, r11 :: object L0: - r0 = 2 == 2 + r0 = int_eq 2, 2 if r0 goto L5 else goto L1 :: bool L1: - r1 = 2 == 4 + r1 = int_eq 2, 4 if r1 goto L5 else goto L2 :: bool L2: - r2 = 2 == 6 + r2 = int_eq 2, 6 if r2 goto L5 else goto L3 :: bool L3: - r3 = 2 == 8 + r3 = int_eq 2, 8 if r3 goto L5 else goto L4 :: bool L4: goto L6 @@ -110,6 +112,7 @@ L6: L7: r11 = box(None, 1) return r11 + [case testMatchClassPattern_python3_10] def f(): match 123: @@ -200,7 +203,7 @@ def f(): r14 :: object_ptr r15, r16 :: object L0: - r0 = 246 == 246 + r0 = int_eq 246, 246 if r0 goto L1 else goto L2 :: bool L1: r1 = 'matched' @@ -213,7 +216,7 @@ L1: keep_alive r1 goto L5 L2: - r8 = 246 == 912 + r8 = int_eq 246, 912 if r8 goto L3 else goto L4 :: bool L3: r9 = 'no match' @@ -229,6 +232,7 @@ L4: L5: r16 = box(None, 1) return r16 + [case testMatchMultiBodyAndComplexOr_python3_10] def f(): match 123: @@ -265,7 +269,7 @@ def f(): r23 :: object_ptr r24, r25 :: object L0: - r0 = 246 == 2 + r0 = int_eq 246, 2 if r0 goto L1 else goto L2 :: bool L1: r1 = 'here 1' @@ -278,10 +282,10 @@ L1: keep_alive r1 goto L9 L2: - r8 = 246 == 4 + r8 = int_eq 246, 4 if r8 goto L5 else goto L3 :: bool L3: - r9 = 246 == 6 + r9 = int_eq 246, 6 if r9 goto L5 else goto L4 :: bool L4: goto L6 @@ -296,7 +300,7 @@ L5: keep_alive 
r10 goto L9 L6: - r17 = 246 == 246 + r17 = int_eq 246, 246 if r17 goto L7 else goto L8 :: bool L7: r18 = 'here 123' @@ -312,6 +316,7 @@ L8: L9: r25 = box(None, 1) return r25 + [case testMatchWithGuard_python3_10] def f(): match 123: @@ -328,7 +333,7 @@ def f(): r6 :: object_ptr r7, r8 :: object L0: - r0 = 246 == 246 + r0 = int_eq 246, 246 if r0 goto L1 else goto L3 :: bool L1: if 1 goto L2 else goto L3 :: bool @@ -346,6 +351,7 @@ L3: L4: r8 = box(None, 1) return r8 + [case testMatchSingleton_python3_10] def f(): match 123: @@ -449,7 +455,7 @@ def f(): r9 :: object_ptr r10, r11 :: object L0: - r0 = 2 == 2 + r0 = int_eq 2, 2 if r0 goto L3 else goto L1 :: bool L1: r1 = load_address PyLong_Type @@ -472,6 +478,7 @@ L4: L5: r11 = box(None, 1) return r11 + [case testMatchAsPattern_python3_10] def f(): match 123: @@ -487,7 +494,7 @@ def f(): r6 :: object_ptr r7, r8 :: object L0: - r0 = 246 == 246 + r0 = int_eq 246, 246 r1 = object 123 x = r1 if r0 goto L1 else goto L2 :: bool @@ -504,6 +511,7 @@ L2: L3: r8 = box(None, 1) return r8 + [case testMatchAsPatternOnOrPattern_python3_10] def f(): match 1: @@ -521,12 +529,12 @@ def f(): r8 :: object_ptr r9, r10 :: object L0: - r0 = 2 == 2 + r0 = int_eq 2, 2 r1 = object 1 x = r1 if r0 goto L3 else goto L1 :: bool L1: - r2 = 2 == 4 + r2 = int_eq 2, 4 r3 = object 2 x = r3 if r2 goto L3 else goto L2 :: bool @@ -545,6 +553,7 @@ L4: L5: r10 = box(None, 1) return r10 + [case testMatchAsPatternOnClassPattern_python3_10] def f(): match 123: diff --git a/mypyc/test-data/irbuild-nested.test b/mypyc/test-data/irbuild-nested.test index b2b884705366..62ae6eb9ee35 100644 --- a/mypyc/test-data/irbuild-nested.test +++ b/mypyc/test-data/irbuild-nested.test @@ -658,7 +658,7 @@ def baz_f_obj.__call__(__mypyc_self__, n): r6, r7 :: int L0: r0 = __mypyc_self__.__mypyc_env__ - r1 = n == 0 + r1 = int_eq n, 0 if r1 goto L1 else goto L2 :: bool L1: return 0 @@ -796,7 +796,7 @@ def baz(n): r0 :: bit r1, r2, r3 :: int L0: - r0 = n == 0 + r0 = int_eq n, 0 if r0 
goto L1 else goto L2 :: bool L1: return 0 diff --git a/mypyc/test-data/irbuild-optional.test b/mypyc/test-data/irbuild-optional.test index e89018a727da..75c008586999 100644 --- a/mypyc/test-data/irbuild-optional.test +++ b/mypyc/test-data/irbuild-optional.test @@ -222,7 +222,7 @@ def f(y): L0: r0 = box(None, 1) x = r0 - r1 = y == 2 + r1 = int_eq y, 2 if r1 goto L1 else goto L2 :: bool L1: r2 = box(int, y) diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index a47f3db6a725..ab0e2fa09a9d 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -195,70 +195,30 @@ def f(i: int) -> bool: [out] def f(i): i :: int - r0 :: native_int - r1, r2 :: bit + r0 :: bit + r1 :: bool + r2 :: bit r3 :: bool r4 :: bit - r5 :: bool - r6 :: native_int - r7, r8 :: bit - r9 :: bool - r10 :: bit - r11 :: bool - r12 :: native_int - r13, r14 :: bit - r15 :: bool - r16 :: bit L0: - r0 = i & 1 - r1 = r0 == 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq i, 2 + if r0 goto L1 else goto L2 :: bool L1: - r2 = i == 2 - r3 = r2 + r1 = r0 goto L3 L2: - r4 = CPyTagged_IsEq_(i, 2) - r3 = r4 + r2 = int_eq i, 4 + r1 = r2 L3: - if r3 goto L4 else goto L5 :: bool + if r1 goto L4 else goto L5 :: bool L4: - r5 = r3 - goto L9 + r3 = r1 + goto L6 L5: - r6 = i & 1 - r7 = r6 == 0 - if r7 goto L6 else goto L7 :: bool + r4 = int_eq i, 6 + r3 = r4 L6: - r8 = i == 4 - r9 = r8 - goto L8 -L7: - r10 = CPyTagged_IsEq_(i, 4) - r9 = r10 -L8: - r5 = r9 -L9: - if r5 goto L10 else goto L11 :: bool -L10: - r11 = r5 - goto L15 -L11: - r12 = i & 1 - r13 = r12 == 0 - if r13 goto L12 else goto L13 :: bool -L12: - r14 = i == 6 - r15 = r14 - goto L14 -L13: - r16 = CPyTagged_IsEq_(i, 6) - r15 = r16 -L14: - r11 = r15 -L15: - return r11 - + return r3 [case testTupleBuiltFromList] def f(val: int) -> bool: @@ -270,24 +230,11 @@ def test() -> None: [out] def f(val): val, r0 :: int - r1 :: native_int - r2, r3 :: bit - r4 :: bool - r5 :: bit + r1 :: bit L0: r0 = 
CPyTagged_Remainder(val, 4) - r1 = r0 & 1 - r2 = r1 == 0 - if r2 goto L1 else goto L2 :: bool -L1: - r3 = r0 == 0 - r4 = r3 - goto L3 -L2: - r5 = CPyTagged_IsEq_(r0, 0) - r4 = r5 -L3: - return r4 + r1 = int_eq r0, 0 + return r1 def test(): r0 :: list r1, r2, r3 :: object diff --git a/mypyc/test-data/lowering-int.test b/mypyc/test-data/lowering-int.test new file mode 100644 index 000000000000..8c813563d0e6 --- /dev/null +++ b/mypyc/test-data/lowering-int.test @@ -0,0 +1,126 @@ +-- Test cases for converting high-level IR to lower-level IR (lowering). + +[case testLowerIntEq] +def f(x: int, y: int) -> int: + if x == y: + return 1 + else: + return 2 +[out] +def f(x, y): + x, y :: int + r0 :: native_int + r1, r2, r3 :: bit +L0: + r0 = x & 1 + r1 = r0 == 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = x == y + if r2 goto L3 else goto L4 :: bool +L2: + r3 = CPyTagged_IsEq_(x, y) + if r3 goto L3 else goto L4 :: bool +L3: + return 2 +L4: + return 4 + +[case testLowerIntNe] +def f(x: int, y: int) -> int: + if x != y: + return 1 + else: + return 2 +[out] +def f(x, y): + x, y :: int + r0 :: native_int + r1, r2, r3, r4 :: bit +L0: + r0 = x & 1 + r1 = r0 == 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = x != y + if r2 goto L3 else goto L4 :: bool +L2: + r3 = CPyTagged_IsEq_(x, y) + r4 = r3 ^ 1 + if r4 goto L3 else goto L4 :: bool +L3: + return 2 +L4: + return 4 + +[case testLowerIntEqWithConstant] +def f(x: int, y: int) -> int: + if x == 2: + return 1 + elif -1 == x: + return 2 + return 3 +[out] +def f(x, y): + x, y :: int + r0, r1 :: bit +L0: + r0 = x == 4 + if r0 goto L1 else goto L2 :: bool +L1: + return 2 +L2: + r1 = -2 == x + if r1 goto L3 else goto L4 :: bool +L3: + return 4 +L4: + return 6 + +[case testLowerIntNeWithConstant] +def f(x: int, y: int) -> int: + if x != 2: + return 1 + elif -1 != x: + return 2 + return 3 +[out] +def f(x, y): + x, y :: int + r0, r1 :: bit +L0: + r0 = x != 4 + if r0 goto L1 else goto L2 :: bool +L1: + return 2 +L2: + r1 = -2 != x + if r1 
goto L3 else goto L4 :: bool +L3: + return 4 +L4: + return 6 + +[case testLowerIntEqValueContext] +def f(x: int, y: int) -> bool: + return x == y +[out] +def f(x, y): + x, y :: int + r0 :: native_int + r1, r2 :: bit + r3 :: bool + r4 :: bit +L0: + r0 = x & 1 + r1 = r0 == 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = x == y + r3 = r2 + goto L3 +L2: + r4 = CPyTagged_IsEq_(x, y) + r3 = r4 +L3: + return r3 diff --git a/mypyc/test-data/opt-flag-elimination.test b/mypyc/test-data/opt-flag-elimination.test index f047a87dc3fa..337ced70a355 100644 --- a/mypyc/test-data/opt-flag-elimination.test +++ b/mypyc/test-data/opt-flag-elimination.test @@ -29,15 +29,13 @@ L0: if x goto L1 else goto L2 :: bool L1: r0 = c() - if r0 goto L4 else goto L5 :: bool + if r0 goto L3 else goto L4 :: bool L2: r1 = d() - if r1 goto L4 else goto L5 :: bool + if r1 goto L3 else goto L4 :: bool L3: - unreachable -L4: return 2 -L5: +L4: return 4 [case testFlagEliminationOneAssignment] @@ -92,20 +90,18 @@ L0: if x goto L1 else goto L2 :: bool L1: r0 = c(2) - if r0 goto L6 else goto L7 :: bool + if r0 goto L5 else goto L6 :: bool L2: if y goto L3 else goto L4 :: bool L3: r1 = c(4) - if r1 goto L6 else goto L7 :: bool + if r1 goto L5 else goto L6 :: bool L4: r2 = c(6) - if r2 goto L6 else goto L7 :: bool + if r2 goto L5 else goto L6 :: bool L5: - unreachable -L6: return 2 -L7: +L6: return 4 [case testFlagEliminationAssignmentNotLastOp] diff --git a/mypyc/test-data/refcount.test b/mypyc/test-data/refcount.test index 0f2c134ae21e..df980af8a7c7 100644 --- a/mypyc/test-data/refcount.test +++ b/mypyc/test-data/refcount.test @@ -67,7 +67,7 @@ def f(): L0: x = 2 y = 4 - r0 = x == 2 + r0 = int_eq x, 2 if r0 goto L3 else goto L4 :: bool L1: return x @@ -185,34 +185,26 @@ def f(a: int) -> int: [out] def f(a): a :: int - r0 :: native_int - r1, r2, r3 :: bit - x, r4, y :: int + r0 :: bit + x, r1, y :: int L0: - r0 = a & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq a, a + if r0 goto L1 
else goto L2 :: bool L1: - r2 = CPyTagged_IsEq_(a, a) - if r2 goto L3 else goto L4 :: bool -L2: - r3 = a == a - if r3 goto L3 else goto L4 :: bool -L3: a = 2 - goto L5 -L4: + goto L3 +L2: x = 4 dec_ref x :: int - goto L6 -L5: - r4 = CPyTagged_Add(a, 2) + goto L4 +L3: + r1 = CPyTagged_Add(a, 2) dec_ref a :: int - y = r4 + y = r1 return y -L6: +L4: inc_ref a :: int - goto L5 + goto L3 [case testConditionalAssignToArgument2] def f(a: int) -> int: @@ -225,33 +217,25 @@ def f(a: int) -> int: [out] def f(a): a :: int - r0 :: native_int - r1, r2, r3 :: bit - x, r4, y :: int + r0 :: bit + x, r1, y :: int L0: - r0 = a & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq a, a + if r0 goto L1 else goto L2 :: bool L1: - r2 = CPyTagged_IsEq_(a, a) - if r2 goto L3 else goto L4 :: bool -L2: - r3 = a == a - if r3 goto L3 else goto L4 :: bool -L3: x = 4 dec_ref x :: int - goto L6 -L4: + goto L4 +L2: a = 2 -L5: - r4 = CPyTagged_Add(a, 2) +L3: + r1 = CPyTagged_Add(a, 2) dec_ref a :: int - y = r4 + y = r1 return y -L6: +L4: inc_ref a :: int - goto L5 + goto L3 [case testConditionalAssignToArgument3] def f(a: int) -> int: @@ -261,25 +245,17 @@ def f(a: int) -> int: [out] def f(a): a :: int - r0 :: native_int - r1, r2, r3 :: bit + r0 :: bit L0: - r0 = a & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq a, a + if r0 goto L1 else goto L3 :: bool L1: - r2 = CPyTagged_IsEq_(a, a) - if r2 goto L3 else goto L5 :: bool -L2: - r3 = a == a - if r3 goto L3 else goto L5 :: bool -L3: a = 2 -L4: +L2: return a -L5: +L3: inc_ref a :: int - goto L4 + goto L2 [case testAssignRegisterToItself] def f(a: int) -> int: @@ -438,40 +414,32 @@ def f() -> int: [out] def f(): x, y, z :: int - r0 :: native_int - r1, r2, r3 :: bit - a, r4, r5 :: int + r0 :: bit + a, r1, r2 :: int L0: x = 2 y = 4 z = 6 - r0 = z & 1 - r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + r0 = int_eq z, z + if r0 goto L3 else goto L4 :: bool L1: - r2 = CPyTagged_IsEq_(z, z) - if r2 goto L5 else goto 
L6 :: bool -L2: - r3 = z == z - if r3 goto L5 else goto L6 :: bool -L3: return z -L4: +L2: a = 2 - r4 = CPyTagged_Add(x, y) + r1 = CPyTagged_Add(x, y) dec_ref x :: int dec_ref y :: int - r5 = CPyTagged_Subtract(r4, a) - dec_ref r4 :: int + r2 = CPyTagged_Subtract(r1, a) + dec_ref r1 :: int dec_ref a :: int - return r5 -L5: + return r2 +L3: dec_ref x :: int dec_ref y :: int - goto L3 -L6: + goto L1 +L4: dec_ref z :: int - goto L4 + goto L2 [case testLoop] def f(a: int) -> int: @@ -1371,25 +1339,12 @@ class C: def add(c): c :: __main__.C r0, r1 :: int - r2 :: native_int - r3, r4 :: bit - r5 :: bool - r6 :: bit + r2 :: bit L0: r0 = borrow c.x r1 = borrow c.y - r2 = r0 & 1 - r3 = r2 == 0 - if r3 goto L1 else goto L2 :: bool -L1: - r4 = r0 == r1 - r5 = r4 - goto L3 -L2: - r6 = CPyTagged_IsEq_(r0, r1) - r5 = r6 -L3: - return r5 + r2 = int_eq r0, r1 + return r2 [case testBorrowIntLessThan] def add(c: C) -> bool: @@ -1441,24 +1396,11 @@ class C: def add(c): c :: __main__.C r0 :: int - r1 :: native_int - r2, r3 :: bit - r4 :: bool - r5 :: bit + r1 :: bit L0: r0 = borrow c.x - r1 = r0 & 1 - r2 = r1 == 0 - if r2 goto L1 else goto L2 :: bool -L1: - r3 = r0 == 20 - r4 = r3 - goto L3 -L2: - r5 = CPyTagged_IsEq_(r0, 20) - r4 = r5 -L3: - return r4 + r1 = int_eq r0, 20 + return r1 [case testBorrowIntArithmetic] def add(c: C) -> int: @@ -1501,23 +1443,15 @@ class C: def add(c, n): c :: __main__.C n, r0, r1 :: int - r2 :: native_int - r3, r4, r5 :: bit + r2 :: bit L0: r0 = borrow c.x r1 = borrow c.y - r2 = r0 & 1 - r3 = r2 != 0 - if r3 goto L1 else goto L2 :: bool + r2 = int_eq r0, r1 + if r2 goto L1 else goto L2 :: bool L1: - r4 = CPyTagged_IsEq_(r0, r1) - if r4 goto L3 else goto L4 :: bool -L2: - r5 = r0 == r1 - if r5 goto L3 else goto L4 :: bool -L3: return 1 -L4: +L2: return 0 [case testBorrowIntInPlaceOp] diff --git a/mypyc/test/test_cheader.py b/mypyc/test/test_cheader.py index cc0fd9df2b34..f2af41c22ea9 100644 --- a/mypyc/test/test_cheader.py +++ b/mypyc/test/test_cheader.py 
@@ -7,6 +7,7 @@ import re import unittest +from mypyc.ir.ops import PrimitiveDescription from mypyc.primitives import registry from mypyc.primitives.registry import CFunctionDescription @@ -25,17 +26,24 @@ def check_name(name: str) -> None: rf"\b{name}\b", header ), f'"{name}" is used in mypyc.primitives but not declared in CPy.h' - for values in [ + for old_values in [ registry.method_call_ops.values(), registry.function_ops.values(), - registry.binary_ops.values(), registry.unary_ops.values(), ]: + for old_ops in old_values: + if isinstance(old_ops, CFunctionDescription): + old_ops = [old_ops] + for old_op in old_ops: + check_name(old_op.c_function_name) + + for values in [registry.binary_ops.values()]: for ops in values: - if isinstance(ops, CFunctionDescription): + if isinstance(ops, PrimitiveDescription): ops = [ops] for op in ops: - check_name(op.c_function_name) + if op.c_function_name is not None: + check_name(op.c_function_name) primitives_path = os.path.join(os.path.dirname(__file__), "..", "primitives") for fnam in glob.glob(f"{primitives_path}/*.py"): diff --git a/mypyc/test/test_emitfunc.py b/mypyc/test/test_emitfunc.py index ab1586bb22a8..b16387aa40af 100644 --- a/mypyc/test/test_emitfunc.py +++ b/mypyc/test/test_emitfunc.py @@ -859,6 +859,8 @@ def assert_emit_binary_op( args = [left, right] if desc.ordering is not None: args = [args[i] for i in desc.ordering] + # This only supports primitives that map to C calls + assert desc.c_function_name is not None self.assert_emit( CallC( desc.c_function_name, diff --git a/mypyc/test/test_lowering.py b/mypyc/test/test_lowering.py new file mode 100644 index 000000000000..e32dba2e1021 --- /dev/null +++ b/mypyc/test/test_lowering.py @@ -0,0 +1,54 @@ +"""Runner for lowering transform tests.""" + +from __future__ import annotations + +import os.path + +from mypy.errors import CompileError +from mypy.test.config import test_temp_dir +from mypy.test.data import DataDrivenTestCase +from mypyc.common import 
TOP_LEVEL_NAME +from mypyc.ir.pprint import format_func +from mypyc.options import CompilerOptions +from mypyc.test.testutil import ( + ICODE_GEN_BUILTINS, + MypycDataSuite, + assert_test_output, + build_ir_for_single_file, + remove_comment_lines, + use_custom_builtins, +) +from mypyc.transform.exceptions import insert_exception_handling +from mypyc.transform.flag_elimination import do_flag_elimination +from mypyc.transform.lower import lower_ir +from mypyc.transform.refcount import insert_ref_count_opcodes +from mypyc.transform.uninit import insert_uninit_checks + + +class TestLowering(MypycDataSuite): + files = ["lowering-int.test"] + base_path = test_temp_dir + + def run_case(self, testcase: DataDrivenTestCase) -> None: + with use_custom_builtins(os.path.join(self.data_prefix, ICODE_GEN_BUILTINS), testcase): + expected_output = remove_comment_lines(testcase.output) + try: + ir = build_ir_for_single_file(testcase.input) + except CompileError as e: + actual = e.messages + else: + actual = [] + for fn in ir: + if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"): + continue + options = CompilerOptions() + # Lowering happens after exception handling and ref count opcodes have + # been added. Any changes must maintain reference counting semantics. 
+ insert_uninit_checks(fn) + insert_exception_handling(fn) + insert_ref_count_opcodes(fn) + lower_ir(fn, options) + do_flag_elimination(fn, options) + actual.extend(format_func(fn)) + + assert_test_output(testcase, actual, "Invalid source code output", expected_output) diff --git a/mypyc/transform/ir_transform.py b/mypyc/transform/ir_transform.py index 254fe3f7771d..a631bd7352b5 100644 --- a/mypyc/transform/ir_transform.py +++ b/mypyc/transform/ir_transform.py @@ -35,6 +35,7 @@ MethodCall, Op, OpVisitor, + PrimitiveOp, RaiseStandardError, Return, SetAttr, @@ -80,6 +81,7 @@ def transform_blocks(self, blocks: list[BasicBlock]) -> None: """ block_map: dict[BasicBlock, BasicBlock] = {} op_map = self.op_map + empties = set() for block in blocks: new_block = BasicBlock() block_map[block] = new_block @@ -89,7 +91,10 @@ def transform_blocks(self, blocks: list[BasicBlock]) -> None: new_op = op.accept(self) if new_op is not op: op_map[op] = new_op - + # A transform can produce empty blocks which can be removed. + if is_empty_block(new_block) and not is_empty_block(block): + empties.add(new_block) + self.builder.blocks = [block for block in self.builder.blocks if block not in empties] # Update all op/block references to point to the transformed ones. 
patcher = PatchVisitor(op_map, block_map) for block in self.builder.blocks: @@ -170,6 +175,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> Value | None: def visit_call_c(self, op: CallC) -> Value | None: return self.add(op) + def visit_primitive_op(self, op: PrimitiveOp) -> Value | None: + return self.add(op) + def visit_truncate(self, op: Truncate) -> Value | None: return self.add(op) @@ -302,6 +310,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None: def visit_call_c(self, op: CallC) -> None: op.args = [self.fix_op(arg) for arg in op.args] + def visit_primitive_op(self, op: PrimitiveOp) -> None: + op.args = [self.fix_op(arg) for arg in op.args] + def visit_truncate(self, op: Truncate) -> None: op.src = self.fix_op(op.src) @@ -351,3 +362,7 @@ def visit_keep_alive(self, op: KeepAlive) -> None: def visit_unborrow(self, op: Unborrow) -> None: op.src = self.fix_op(op.src) + + +def is_empty_block(block: BasicBlock) -> bool: + return len(block.ops) == 1 and isinstance(block.ops[0], Unreachable) diff --git a/mypyc/transform/lower.py b/mypyc/transform/lower.py new file mode 100644 index 000000000000..b717657095f9 --- /dev/null +++ b/mypyc/transform/lower.py @@ -0,0 +1,33 @@ +"""Transform IR to lower-level ops. + +Higher-level ops are used in earlier compiler passes, as they make +various analyses, optimizations and transforms easier to implement. +Later passes use lower-level ops, as they are easier to generate code +from, and they help with lower-level optimizations. + +Lowering of various primitive ops is implemented in the mypyc.lower +package. 
+""" + +from mypyc.ir.func_ir import FuncIR +from mypyc.ir.ops import PrimitiveOp, Value +from mypyc.irbuild.ll_builder import LowLevelIRBuilder +from mypyc.lower.registry import lowering_registry +from mypyc.options import CompilerOptions +from mypyc.transform.ir_transform import IRTransform + + +def lower_ir(ir: FuncIR, options: CompilerOptions) -> None: + builder = LowLevelIRBuilder(None, options) + visitor = LoweringVisitor(builder) + visitor.transform_blocks(ir.blocks) + ir.blocks = builder.blocks + + +class LoweringVisitor(IRTransform): + def visit_primitive_op(self, op: PrimitiveOp) -> Value: + # The lowering implementation functions of various primitive ops are stored + # in a registry, which is populated using function decorators. The name + # of op (such as "int_eq") is used as the key. + lower_fn = lowering_registry[op.desc.name] + return lower_fn(self.builder, op.args, op.line)