diff --git a/Makefile b/Makefile index 8eb4a82a9d63..b347d16c842b 100644 --- a/Makefile +++ b/Makefile @@ -532,6 +532,7 @@ SOURCE_FILES = \ Simplify_And.cpp \ Simplify_Call.cpp \ Simplify_Cast.cpp \ + Simplify_Reinterpret.cpp \ Simplify_Div.cpp \ Simplify_EQ.cpp \ Simplify_Exprs.cpp \ diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 51465515a80e..eb2f9138268b 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -243,6 +243,20 @@ class Bounds : public IRVisitor { interval = Interval::single_point(op); } + void visit(const Reinterpret *op) override { + TRACK_BOUNDS_INTERVAL; + + Type t = op->type.element_of(); + + if (t.is_handle()) { + interval = Interval::everything(); + return; + } + + // Just use the bounds of the type + bounds_of_type(t); + } + void visit(const Cast *op) override { TRACK_BOUNDS_INTERVAL; op->value.accept(this); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 56035abe7391..e0015a65551a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -299,6 +299,7 @@ set(SOURCE_FILES Simplify_Add.cpp Simplify_And.cpp Simplify_Call.cpp + Simplify_Reinterpret.cpp Simplify_Cast.cpp Simplify_Div.cpp Simplify_EQ.cpp diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index e9ab248757fa..de4cd010a7d3 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -2071,6 +2071,10 @@ void CodeGen_C::visit(const Cast *op) { id = print_cast_expr(op->type, op->value); } +void CodeGen_C::visit(const Reinterpret *op) { + id = print_assignment(op->type, print_reinterpret(op->type, op->value)); +} + void CodeGen_C::visit_binop(Type t, const Expr &a, const Expr &b, const char *op) { string sa = print_expr(a); string sb = print_expr(b); @@ -2294,9 +2298,6 @@ void CodeGen_C::visit(const Call *op) { } else if (op->is_intrinsic(Call::bitwise_not)) { internal_assert(op->args.size() == 1); rhs << "~" << print_expr(op->args[0]); - } else if (op->is_intrinsic(Call::reinterpret)) { - internal_assert(op->args.size() == 1); - rhs << print_reinterpret(op->type, op->args[0]); } else if (op->is_intrinsic(Call::shift_left)) { internal_assert(op->args.size() == 2); if (op->args[1].type().is_uint()) { diff --git a/src/CodeGen_C.h b/src/CodeGen_C.h index 401c5e753f61..9c06d4bb5630 100644 --- a/src/CodeGen_C.h +++ b/src/CodeGen_C.h @@ -196,6 +196,7 @@ class CodeGen_C : public IRPrinter { void visit(const StringImm *) override; void visit(const FloatImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Add *) override; void visit(const Sub *) override; void visit(const Mul *) override; diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 60b53f60b1f8..74dda7e5fd34 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1476,6 +1476,62 @@ void CodeGen_LLVM::visit(const Cast *op) { } } +void CodeGen_LLVM::visit(const Reinterpret *op) { + Type dst = op->type; + Type src = op->value.type(); + llvm::Type *llvm_dst = llvm_type_of(dst); + value = codegen(op->value); + if (src.is_handle() && !dst.is_handle()) { + internal_assert(dst.is_uint() && dst.bits() == 64); + + // Handle -> UInt64 + llvm::DataLayout d(module.get()); + if (d.getPointerSize() == 4) { + llvm::Type *intermediate = llvm_type_of(UInt(32, dst.lanes())); + value = builder->CreatePtrToInt(value, intermediate); + value = builder->CreateZExt(value, llvm_dst); + } else if (d.getPointerSize() == 8) { + value = builder->CreatePtrToInt(value, llvm_dst); + } else { + internal_error << "Pointer size is neither 4 nor 8 bytes\n"; + } + + } else if (dst.is_handle() && !src.is_handle()) { + internal_assert(src.is_uint() && src.bits() == 64); + + // UInt64 -> Handle + llvm::DataLayout d(module.get()); + if (d.getPointerSize() == 4) { + llvm::Type *intermediate = llvm_type_of(UInt(32, src.lanes())); + value = builder->CreateTrunc(value, intermediate); + value = builder->CreateIntToPtr(value, llvm_dst); + } else if (d.getPointerSize() == 8) { + value = builder->CreateIntToPtr(value, llvm_dst); + } else { + internal_error << "Pointer size is neither 4 nor 8 bytes\n"; + } + + } else { + if (src.is_scalar() && dst.is_vector()) { + // If the source type is a scalar, we promote it to an + // equivalent vector of width one before doing the + // bitcast, because llvm's bitcast operator doesn't + // want to convert between scalars and vectors. + value = create_broadcast(value, 1); + } + if (src.is_vector() && dst.is_scalar()) { + // Similarly, if we're converting from a vector to a + // scalar, convert to a vector of width 1 first, and + // then extract the first lane. + llvm_dst = get_vector_type(llvm_dst, 1); + } + value = builder->CreateBitCast(value, llvm_dst); + if (src.is_vector() && dst.is_scalar()) { + value = builder->CreateExtractElement(value, (uint64_t)0); + } + } +} + void CodeGen_LLVM::visit(const Variable *op) { value = sym_get(op->name); } @@ -2593,61 +2649,6 @@ void CodeGen_LLVM::visit(const Call *op) { internal_assert(op->args.size() == 1); Value *a = codegen(op->args[0]); value = builder->CreateNot(a); - } else if (op->is_intrinsic(Call::reinterpret)) { - internal_assert(op->args.size() == 1); - Type dst = op->type; - Type src = op->args[0].type(); - llvm::Type *llvm_dst = llvm_type_of(dst); - value = codegen(op->args[0]); - if (src.is_handle() && !dst.is_handle()) { - internal_assert(dst.is_uint() && dst.bits() == 64); - - // Handle -> UInt64 - llvm::DataLayout d(module.get()); - if (d.getPointerSize() == 4) { - llvm::Type *intermediate = llvm_type_of(UInt(32, dst.lanes())); - value = builder->CreatePtrToInt(value, intermediate); - value = builder->CreateZExt(value, llvm_dst); - } else if (d.getPointerSize() == 8) { - value = builder->CreatePtrToInt(value, llvm_dst); - } else { - internal_error << "Pointer size is neither 4 nor 8 bytes\n"; - } - - } else if (dst.is_handle() && !src.is_handle()) { - internal_assert(src.is_uint() && src.bits() == 64); - - // UInt64 -> Handle - llvm::DataLayout d(module.get()); - if (d.getPointerSize() == 4) { - llvm::Type *intermediate = llvm_type_of(UInt(32, src.lanes())); - value = builder->CreateTrunc(value, intermediate); - value = builder->CreateIntToPtr(value, llvm_dst); - } else if (d.getPointerSize() == 8) { - value = builder->CreateIntToPtr(value, llvm_dst); - } else { - internal_error << "Pointer size is neither 4 nor 8 bytes\n"; - } - - } else { - if (src.is_scalar() && dst.is_vector()) { - // If the source type is a scalar, we promote it to an - // equivalent vector of width one before doing the - // bitcast, because llvm's bitcast operator doesn't - // want to convert between scalars and vectors. - value = create_broadcast(value, 1); - } - if (src.is_vector() && dst.is_scalar()) { - // Similarly, if we're converting from a vector to a - // scalar, convert to a vector of width 1 first, and - // then extract the first lane. - llvm_dst = get_vector_type(llvm_dst, 1); - } - value = builder->CreateBitCast(value, llvm_dst); - if (src.is_vector() && dst.is_scalar()) { - value = builder->CreateExtractElement(value, (uint64_t)0); - } - } } else if (op->is_intrinsic(Call::shift_left)) { internal_assert(op->args.size() == 2); if (op->args[1].type().is_uint()) { diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index da31f5a1cd19..5982aa1672fd 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -329,6 +329,7 @@ class CodeGen_LLVM : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Variable *) override; void visit(const Add *) override; void visit(const Sub *) override; diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 5d46d60bf09e..e368d851d615 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -322,6 +322,15 @@ class Deinterleaver : public IRGraphMutator { } } + Expr visit(const Reinterpret *op) override { + if (op->type.is_scalar()) { + return op; + } else { + Type t = op->type.with_lanes(new_lanes); + return Reinterpret::make(t, mutate(op->value)); + } + } + Expr visit(const Call *op) override { Type t = op->type.with_lanes(new_lanes); diff --git a/src/Derivative.cpp b/src/Derivative.cpp index db914e36e143..c536eeea92ae 100644 --- a/src/Derivative.cpp +++ b/src/Derivative.cpp @@ -55,6 +55,7 @@ class ReverseAccumulationVisitor : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *op) override; + void visit(const Reinterpret *op) override; void visit(const Variable *op) override; void visit(const Add *op) override; void visit(const Sub *op) override; @@ -836,6 +837,14 @@ void ReverseAccumulationVisitor::visit(const Cast *op) { } } +void ReverseAccumulationVisitor::visit(const Reinterpret *op) { + internal_assert(expr_adjoints.find(op) != expr_adjoints.end()); + Expr adjoint = expr_adjoints[op]; + + // bit manipulation -- has zero derivative. + accumulate(op->value, make_zero(op->type)); +} + void ReverseAccumulationVisitor::visit(const Variable *op) { internal_assert(expr_adjoints.find(op) != expr_adjoints.end()); Expr adjoint = expr_adjoints[op]; @@ -1169,8 +1178,7 @@ void ReverseAccumulationVisitor::visit(const Call *op) { accumulate(op->args[1], adjoint); } else if (op->is_intrinsic(Call::undef)) { // do nothing - } else if (op->is_intrinsic(Call::reinterpret) || - op->is_intrinsic(Call::bitwise_and) || + } else if (op->is_intrinsic(Call::bitwise_and) || op->is_intrinsic(Call::bitwise_not) || op->is_intrinsic(Call::bitwise_or) || op->is_intrinsic(Call::bitwise_xor) || diff --git a/src/EliminateBoolVectors.cpp b/src/EliminateBoolVectors.cpp index 2e63382f644c..cebfe0f0019b 100644 --- a/src/EliminateBoolVectors.cpp +++ b/src/EliminateBoolVectors.cpp @@ -136,6 +136,8 @@ class EliminateBoolVectors : public IRMutator { } } + // FIXME: what about Reinterpret? + Stmt visit(const Store *op) override { Expr predicate = op->predicate; if (!is_const_one(predicate)) { diff --git a/src/Expr.h b/src/Expr.h index b70d608d290b..ac0ec6521d68 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -33,6 +33,7 @@ enum class IRNodeType { StringImm, Broadcast, Cast, + Reinterpret, Variable, Add, Sub, diff --git a/src/IR.cpp b/src/IR.cpp index fc448466b66b..740234b8e31f 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -20,6 +20,22 @@ Expr Cast::make(Type t, Expr v) { return node; } +Expr Reinterpret::make(Type t, Expr v) { + user_assert(v.defined()) << "reinterpret of undefined Expr\n"; + int from_bits = v.type().bits() * v.type().lanes(); + int to_bits = t.bits() * t.lanes(); + user_assert(from_bits == to_bits) + << "Reinterpret cast from type " << v.type() + << " which has " << from_bits + << " bits, to type " << t + << " which has " << to_bits << " bits\n"; + + Reinterpret *node = new Reinterpret; + node->type = t; + node->value = std::move(v); + return node; +} + Expr Add::make(Expr a, Expr b) { internal_assert(a.defined()) << "Add of undefined\n"; internal_assert(b.defined()) << "Add of undefined\n"; @@ -628,7 +644,6 @@ const char *const intrinsic_op_names[] = { "promise_clamped", "random", "register_destructor", - "reinterpret", "require", "require_mask", "return_second", @@ -970,6 +985,10 @@ void ExprNode::accept(IRVisitor *v) const { v->visit((const Cast *)this); } template<> +void ExprNode::accept(IRVisitor *v) const { + v->visit((const Reinterpret *)this); +} +template<> void ExprNode::accept(IRVisitor *v) const { v->visit((const Variable *)this); } @@ -1155,6 +1174,10 @@ Expr ExprNode::mutate_expr(IRMutator *v) const { return v->visit((const Cast *)this); } template<> +Expr ExprNode::mutate_expr(IRMutator *v) const { + return v->visit((const Reinterpret *)this); +} +template<> Expr ExprNode::mutate_expr(IRMutator *v) const { return v->visit((const Variable *)this); } diff --git a/src/IR.h b/src/IR.h index 45d4913b9078..c6085614b59d 100644 --- a/src/IR.h +++ b/src/IR.h @@ -34,6 +34,15 @@ struct Cast : public ExprNode { static const IRNodeType _node_type = IRNodeType::Cast; }; +/** Reinterpret a node as another type, without affecting any of the bits. */ +struct Reinterpret : public ExprNode { + Expr value; + + static Expr make(Type t, Expr v); + + static const IRNodeType _node_type = IRNodeType::Reinterpret; +}; + /** The sum of two expressions */ struct Add : public ExprNode { Expr a, b; @@ -537,7 +546,6 @@ struct Call : public ExprNode { promise_clamped, random, register_destructor, - reinterpret, require, require_mask, return_second, diff --git a/src/IREquality.cpp b/src/IREquality.cpp index 9e87b6950553..20cb616d2c32 100644 --- a/src/IREquality.cpp +++ b/src/IREquality.cpp @@ -57,6 +57,7 @@ class IRComparer : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Variable *) override; void visit(const Add *) override; void visit(const Sub *) override; @@ -354,6 +355,10 @@ void IRComparer::visit(const Cast *op) { compare_expr(expr.as()->value, op->value); } +void IRComparer::visit(const Reinterpret *op) { + compare_expr(expr.as()->value, op->value); +} + void IRComparer::visit(const Variable *op) { const Variable *e = expr.as(); compare_names(e->name, op->name); diff --git a/src/IRMatch.cpp b/src/IRMatch.cpp index e769ae65d038..6aba3155777f 100644 --- a/src/IRMatch.cpp +++ b/src/IRMatch.cpp @@ -116,6 +116,16 @@ class IRMatch : public IRVisitor { } } + void visit(const Reinterpret *op) override { + const Reinterpret *e = expr.as(); + if (result && e && types_match(op->type, e->type)) { + expr = e->value; + op->value.accept(this); + } else { + result = false; + } + } + void visit(const Variable *op) override { if (!result) { return; @@ -432,6 +442,11 @@ bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept { // that the types of the values match, so use equal rather // than equal_helper. return equal(((const Cast &)a).value, ((const Cast &)b).value); + case IRNodeType::Reinterpret: + // While we know a and b have matching type, we don't know + // that the types of the values match, so use equal rather + // than equal_helper. + return equal(((const Reinterpret &)a).value, ((const Reinterpret &)b).value); case IRNodeType::Variable: return ((const Variable &)a).name == ((const Variable &)b).name; case IRNodeType::Add: diff --git a/src/IRMutator.cpp b/src/IRMutator.cpp index 5272f3051577..005937a17008 100644 --- a/src/IRMutator.cpp +++ b/src/IRMutator.cpp @@ -37,6 +37,14 @@ Expr IRMutator::visit(const Cast *op) { return Cast::make(op->type, std::move(value)); } +Expr IRMutator::visit(const Reinterpret *op) { + Expr value = mutate(op->value); + if (value.same_as(op->value)) { + return op; + } + return Reinterpret::make(op->type, std::move(value)); +} + namespace { template Expr mutate_binary_operator(IRMutator *mutator, const T *op) { diff --git a/src/IRMutator.h b/src/IRMutator.h index 04613a495930..c7a1984269d3 100644 --- a/src/IRMutator.h +++ b/src/IRMutator.h @@ -56,6 +56,7 @@ class IRMutator { virtual Expr visit(const FloatImm *); virtual Expr visit(const StringImm *); virtual Expr visit(const Cast *); + virtual Expr visit(const Reinterpret *); virtual Expr visit(const Variable *); virtual Expr visit(const Add *); virtual Expr visit(const Sub *); diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 977913f063c7..4693060a8d45 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -2361,15 +2361,7 @@ Expr fract(const Expr &x) { } Expr reinterpret(Type t, Expr e) { - user_assert(e.defined()) << "reinterpret of undefined Expr\n"; - int from_bits = e.type().bits() * e.type().lanes(); - int to_bits = t.bits() * t.lanes(); - user_assert(from_bits == to_bits) - << "Reinterpret cast from type " << e.type() - << " which has " << from_bits - << " bits, to type " << t - << " which has " << to_bits << " bits\n"; - return Internal::Call::make(t, Internal::Call::reinterpret, {std::move(e)}, Internal::Call::PureIntrinsic); + return Internal::Reinterpret::make(t, std::move(e)); } Expr operator&(Expr x, Expr y) { diff --git a/src/IROperator.h b/src/IROperator.h index e22f90c62cb1..ed0b11bb4fef 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -1059,7 +1059,7 @@ Expr reinterpret(Type t, Expr e); template Expr reinterpret(Expr e) { - return reinterpret(type_of(), e); + return reinterpret(type_of(), std::move(e)); } /** Return the bitwise and of two expressions (which need not have the diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index dec52fc28f7f..38f57e46649e 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -519,6 +519,12 @@ void IRPrinter::visit(const Cast *op) { stream << ")"; } +void IRPrinter::visit(const Reinterpret *op) { + stream << "reinterpret<" << op->type << ">("; + print(op->value); + stream << ")"; +} + void IRPrinter::visit(const Variable *op) { if (!known_type.contains(op->name) && (op->type != Int(32))) { diff --git a/src/IRPrinter.h b/src/IRPrinter.h index e0cb4cab5968..666235988cd7 100644 --- a/src/IRPrinter.h +++ b/src/IRPrinter.h @@ -155,6 +155,7 @@ class IRPrinter : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Variable *) override; void visit(const Add *) override; void visit(const Sub *) override; diff --git a/src/IRVisitor.cpp b/src/IRVisitor.cpp index bde0799bdcee..7f9993987200 100644 --- a/src/IRVisitor.cpp +++ b/src/IRVisitor.cpp @@ -22,6 +22,10 @@ void IRVisitor::visit(const Cast *op) { op->value.accept(this); } +void IRVisitor::visit(const Reinterpret *op) { + op->value.accept(this); +} + void IRVisitor::visit(const Variable *) { } @@ -293,6 +297,10 @@ void IRGraphVisitor::visit(const Cast *op) { include(op->value); } +void IRGraphVisitor::visit(const Reinterpret *op) { + include(op->value); +} + void IRGraphVisitor::visit(const Variable *op) { } diff --git a/src/IRVisitor.h b/src/IRVisitor.h index f29bedc182bc..4e1650ff22be 100644 --- a/src/IRVisitor.h +++ b/src/IRVisitor.h @@ -34,6 +34,7 @@ class IRVisitor { virtual void visit(const FloatImm *); virtual void visit(const StringImm *); virtual void visit(const Cast *); + virtual void visit(const Reinterpret *); virtual void visit(const Variable *); virtual void visit(const Add *); virtual void visit(const Sub *); @@ -104,6 +105,7 @@ class IRGraphVisitor : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Variable *) override; void visit(const Add *) override; void visit(const Sub *) override; @@ -174,6 +176,8 @@ class VariadicVisitor { return ((T *)this)->visit((const Broadcast *)node, std::forward(args)...); case IRNodeType::Cast: return ((T *)this)->visit((const Cast *)node, std::forward(args)...); + case IRNodeType::Reinterpret: + return ((T *)this)->visit((const Reinterpret *)node, std::forward(args)...); case IRNodeType::Variable: return ((T *)this)->visit((const Variable *)node, std::forward(args)...); case IRNodeType::Add: @@ -258,6 +262,7 @@ class VariadicVisitor { case IRNodeType::StringImm: case IRNodeType::Broadcast: case IRNodeType::Cast: + case IRNodeType::Reinterpret: case IRNodeType::Variable: case IRNodeType::Add: case IRNodeType::Sub: diff --git a/src/LICM.cpp b/src/LICM.cpp index 5cfbdfedb0bc..386a05bd1808 100644 --- a/src/LICM.cpp +++ b/src/LICM.cpp @@ -89,6 +89,10 @@ class LiftLoopInvariants : public IRMutator { return false; } } + if (const Reinterpret *reinterpret = e.as()) { + // Don't lift Reinterpret nodes. They're free. + return should_lift(reinterpret->value); + } if (const Add *add = e.as()) { if (add->type == Int(32) && is_const(add->b)) { @@ -97,8 +101,7 @@ class LiftLoopInvariants : public IRMutator { } } if (const Call *call = e.as()) { - if (Call::as_tag(call) || - call->is_intrinsic(Call::reinterpret)) { + if (Call::as_tag(call)) { // Don't lift these intrinsics. They're free. return should_lift(call->args[0]); } @@ -209,6 +212,8 @@ class LICM : public IRMutator { int cost(const Expr &e, const set &vars) { if (is_const(e)) { return 0; + } else if (const Reinterpret *reinterpret = e.as()) { + return cost(reinterpret->value, vars); } else if (const Variable *var = e.as()) { if (vars.count(var->name)) { // We're loading this already @@ -223,13 +228,6 @@ class LICM : public IRMutator { return cost(sub->a, vars) + cost(sub->b, vars) + 1; } else if (const Mul *mul = e.as()) { return cost(mul->a, vars) + cost(mul->b, vars) + 1; - } else if (const Call *call = e.as()) { - if (call->is_intrinsic(Call::reinterpret)) { - internal_assert(call->args.size() == 1); - return cost(call->args[0], vars); - } else { - return 100; - } } else { return 100; } diff --git a/src/ModulusRemainder.cpp b/src/ModulusRemainder.cpp index 1e7d49aa3e04..34a598e4c7e3 100644 --- a/src/ModulusRemainder.cpp +++ b/src/ModulusRemainder.cpp @@ -35,6 +35,7 @@ class ComputeModulusRemainder : public IRVisitor { void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *) override; + void visit(const Reinterpret *) override; void visit(const Variable *) override; void visit(const Add *) override; void visit(const Sub *) override; @@ -103,6 +104,10 @@ void ComputeModulusRemainder::visit(const Cast *) { result = ModulusRemainder{}; } +void ComputeModulusRemainder::visit(const Reinterpret *) { + result = ModulusRemainder{}; +} + void ComputeModulusRemainder::visit(const Variable *op) { if (scope.contains(op->name)) { result = scope.get(op->name); diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index 62910355f5ff..ae8978b2cb57 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -259,6 +259,10 @@ class DerivativeBounds : public IRVisitor { } } + void visit(const Reinterpret *op) override { + result = ConstantInterval::everything(); + } + void visit(const Variable *op) override { if (op->name == var) { result = ConstantInterval::single_point(1); diff --git a/src/RegionCosts.cpp b/src/RegionCosts.cpp index ddc6b45c4b2d..b7760b54bb8d 100644 --- a/src/RegionCosts.cpp +++ b/src/RegionCosts.cpp @@ -80,6 +80,11 @@ class ExprCost : public IRVisitor { arith += 1; } + void visit(const Reinterpret *op) override { + op->value.accept(this); + // `Reinterpret` is a no-op and does *not* incur any cost. + } + template void visit_binary_operator(const T *op, int op_cost) { op->a.accept(this); @@ -219,7 +224,7 @@ class ExprCost : public IRVisitor { // TODO: Improve the cost model. In some architectures (e.g. ARM or // NEON), count_leading_zeros should be as cheap as bitwise ops. // div_round_to_zero and mod_round_to_zero can also get fairly expensive. - if (call->is_intrinsic(Call::reinterpret) || call->is_intrinsic(Call::bitwise_and) || + if (call->is_intrinsic(Call::bitwise_and) || call->is_intrinsic(Call::bitwise_not) || call->is_intrinsic(Call::bitwise_xor) || call->is_intrinsic(Call::bitwise_or) || call->is_intrinsic(Call::shift_left) || call->is_intrinsic(Call::shift_right) || call->is_intrinsic(Call::div_round_to_zero) || diff --git a/src/RemoveUndef.cpp b/src/RemoveUndef.cpp index d96c503b7085..a4889f6cc3b5 100644 --- a/src/RemoveUndef.cpp +++ b/src/RemoveUndef.cpp @@ -59,6 +59,18 @@ class RemoveUndef : public IRMutator { } } + Expr visit(const Reinterpret *op) override { + Expr value = mutate(op->value); + if (!value.defined()) { + return Expr(); + } + if (value.same_as(op->value)) { + return op; + } else { + return Reinterpret::make(op->type, std::move(value)); + } + } + Expr visit(const Add *op) override { return mutate_binary_operator(op); } diff --git a/src/Simplify_Call.cpp b/src/Simplify_Call.cpp index 573d4381a52b..a1ff4c5130fe 100644 --- a/src/Simplify_Call.cpp +++ b/src/Simplify_Call.cpp @@ -282,25 +282,6 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } else { return a ^ b; } - } else if (op->is_intrinsic(Call::reinterpret)) { - Expr a = mutate(op->args[0], nullptr); - - int64_t ia; - uint64_t ua; - bool vector = op->type.is_vector() || a.type().is_vector(); - if (op->type == a.type()) { - return a; - } else if (const_int(a, &ia) && op->type.is_uint() && !vector) { - // int -> uint - return make_const(op->type, (uint64_t)ia); - } else if (const_uint(a, &ua) && op->type.is_int() && !vector) { - // uint -> int - return make_const(op->type, (int64_t)ua); - } else if (a.same_as(op->args[0])) { - return op; - } else { - return reinterpret(op->type, a); - } } else if (op->is_intrinsic(Call::abs)) { // Constant evaluate abs(x). ExprInfo a_bounds; diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index 0be084d61154..a510e5c51f64 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -309,6 +309,7 @@ class Simplify : public VariadicVisitor { Expr visit(const StringImm *op, ExprInfo *bounds); Expr visit(const Broadcast *op, ExprInfo *bounds); Expr visit(const Cast *op, ExprInfo *bounds); + Expr visit(const Reinterpret *op, ExprInfo *bounds); Expr visit(const Variable *op, ExprInfo *bounds); Expr visit(const Add *op, ExprInfo *bounds); Expr visit(const Sub *op, ExprInfo *bounds); diff --git a/src/Simplify_Reinterpret.cpp b/src/Simplify_Reinterpret.cpp new file mode 100644 index 000000000000..c5d8d07ce233 --- /dev/null +++ b/src/Simplify_Reinterpret.cpp @@ -0,0 +1,28 @@ +#include "Simplify_Internal.h" + +namespace Halide { +namespace Internal { + +Expr Simplify::visit(const Reinterpret *op, ExprInfo *bounds) { + Expr a = mutate(op->value, nullptr); + + int64_t ia; + uint64_t ua; + bool vector = op->type.is_vector() || a.type().is_vector(); + if (op->type == a.type()) { + return a; + } else if (const_int(a, &ia) && op->type.is_uint() && !vector) { + // int -> uint + return make_const(op->type, (uint64_t)ia); + } else if (const_uint(a, &ua) && op->type.is_int() && !vector) { + // uint -> int + return make_const(op->type, (int64_t)ua); + } else if (a.same_as(op->value)) { + return op; + } else { + return reinterpret(op->type, a); + } +} + +} // namespace Internal +} // namespace Halide diff --git a/src/Solve.cpp b/src/Solve.cpp index 0b632ac52e45..d8ff919bb56c 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -1124,6 +1124,10 @@ class SolveForInterval : public IRVisitor { fail(); } + void visit(const Reinterpret *op) override { + fail(); + } + void visit(const Load *op) override { fail(); } diff --git a/src/StmtToHtml.cpp b/src/StmtToHtml.cpp index a77183e31d11..21bc74dd20ac 100644 --- a/src/StmtToHtml.cpp +++ b/src/StmtToHtml.cpp @@ -222,6 +222,19 @@ class StmtToHtml : public IRVisitor { stream << close_span(); } + void visit(const Reinterpret *op) override { + stream << open_span("Reinterpret"); + + stream << open_span("Matched"); + stream << open_span("Type") << op->type << close_span(); + stream << "("; + stream << close_span(); + print(op->value); + stream << matched(")"); + + stream << close_span(); + } + void visit_binary_op(const Expr &a, const Expr &b, const char *op) { stream << open_span("BinaryOp"); diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index e2960739ed86..7dcd79d24664 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -527,6 +527,16 @@ class VectorSubs : public IRMutator { } } + Expr visit(const Reinterpret *op) override { + Expr value = mutate(op->value); + if (value.same_as(op->value)) { + return op; + } else { + Type t = op->type.with_lanes(value.type().lanes()); + return Reinterpret::make(t, value); + } + } + string get_widened_var_name(const string &name) { return name + ".widened." + vectorized_vars.back().name; } diff --git a/src/autoschedulers/adams2019/FunctionDAG.cpp b/src/autoschedulers/adams2019/FunctionDAG.cpp index 7f4c41045804..0c1126b3c6c8 100644 --- a/src/autoschedulers/adams2019/FunctionDAG.cpp +++ b/src/autoschedulers/adams2019/FunctionDAG.cpp @@ -816,6 +816,11 @@ FunctionDAG::FunctionDAG(const vector &outputs, const MachineParams &p check_type(op->type); } + void visit(const Reinterpret *op) override { + IRVisitor::visit(op); + check_type(op->type); + } + void check_type(Type t) { if (t.bits() > 1 && (!narrowest_type.bits() || diff --git a/test/correctness/host_alignment.cpp b/test/correctness/host_alignment.cpp index 4dc7bf40a376..20c6644ff968 100644 --- a/test/correctness/host_alignment.cpp +++ b/test/correctness/host_alignment.cpp @@ -74,10 +74,9 @@ class CountHostAlignmentAsserts : public IRVisitor { left = call->args[0]; right = call->args[1]; } - const Call *reinterpret_call = left.as(); - if (!reinterpret_call || - !reinterpret_call->is_intrinsic(Call::reinterpret)) return; - Expr name = reinterpret_call->args[0]; + const Reinterpret *reinterpret = left.as(); + if (!reinterpret) return; + Expr name = reinterpret->value; const Variable *V = name.as(); string name_host_ptr = V->name; int expected_alignment = alignments_needed[name_host_ptr];