From 6bcfec7b16a8410ac6b347378307b855a3f44b65 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 28 May 2024 12:22:29 +0200 Subject: [PATCH] Report useful error to user if the promise_clamp all fails to losslessly cast. --- src/IROperator.cpp | 29 +++++++++++++++++++--- src/IRPrinter.cpp | 62 +++++++++++++++++++++++++++++++++++++++++++++- src/IRPrinter.h | 5 ++++ 3 files changed, 91 insertions(+), 5 deletions(-) diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 3492c9e828c3..2011fdfa06bf 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -1587,10 +1587,31 @@ Tuple mux(const Expr &id, const std::vector &values) { return Tuple{result}; } +namespace { +void cast_bounds_for_promise_clamped(const Expr &value, const Expr &min, const Expr &max, Expr &casted_min, Expr &casted_max, const char *call_name) { + { + Expr n_min_val = lossless_cast(value.type(), min); + if (min.defined()) { + user_assert(n_min_val.defined()) + << call_name << " min argument (type " << min.node_type() << " " << min.type() << ") could not be cast losslessly to " << value.type(); + } + casted_min = n_min_val.defined() ? n_min_val : value.type().min(); + } + { + Expr n_max_val = lossless_cast(value.type(), max); + if (max.defined()) { + user_assert(n_max_val.defined()) + << call_name << " max argument (type " << max.node_type() << " " << max.type() << ") could not be cast losslessly to " << value.type(); + } + casted_max = n_max_val.defined() ? n_max_val : value.type().max(); + } +} +} // namespace + Expr unsafe_promise_clamped(const Expr &value, const Expr &min, const Expr &max) { user_assert(value.defined()) << "unsafe_promise_clamped with undefined value.\n"; - Expr n_min_val = min.defined() ? lossless_cast(value.type(), min) : value.type().min(); - Expr n_max_val = max.defined() ? lossless_cast(value.type(), max) : value.type().max(); + Expr n_min_val, n_max_val; + cast_bounds_for_promise_clamped(value, min, max, n_min_val, n_max_val, "unsafe_promise_clamped"); // Min and max are allowed to be undefined with the meaning of no bound on that side. return Call::make(value.type(), @@ -1602,8 +1623,8 @@ Expr unsafe_promise_clamped(const Expr &value, const Expr &min, const Expr &max) namespace Internal { Expr promise_clamped(const Expr &value, const Expr &min, const Expr &max) { internal_assert(value.defined()) << "promise_clamped with undefined value.\n"; - Expr n_min_val = min.defined() ? lossless_cast(value.type(), min) : value.type().min(); - Expr n_max_val = max.defined() ? lossless_cast(value.type(), max) : value.type().max(); + Expr n_min_val, n_max_val; + cast_bounds_for_promise_clamped(value, min, max, n_min_val, n_max_val, "promise_clamped"); // Min and max are allowed to be undefined with the meaning of no bound on that side. return Call::make(value.type(), diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index f14e45a335f6..a42431f232d0 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -7,6 +7,7 @@ #include "Associativity.h" #include "Closure.h" #include "ConstantInterval.h" +#include "Expr.h" #include "IROperator.h" #include "Interval.h" #include "Module.h" @@ -48,7 +49,6 @@ ostream &operator<<(ostream &out, const Type &type) { } return out; } - ostream &operator<<(ostream &stream, const Expr &ir) { if (!ir.defined()) { stream << "(undefined)"; @@ -270,6 +270,66 @@ void IRPrinter::test() { std::cout << "IRPrinter test passed\n"; } +std::ostream &operator<<(std::ostream &stream, IRNodeType type) { +#define CASE(e) \ + case IRNodeType::e: \ + stream << #e; \ + break; + switch (type) { + CASE(IntImm) + CASE(UIntImm) + CASE(FloatImm) + CASE(StringImm) + CASE(Broadcast) + CASE(Cast) + CASE(Reinterpret) + CASE(Variable) + CASE(Add) + CASE(Sub) + CASE(Mod) + CASE(Mul) + CASE(Div) + CASE(Min) + CASE(Max) + CASE(EQ) + CASE(NE) + CASE(LT) + CASE(LE) + CASE(GT) + CASE(GE) + CASE(And) + CASE(Or) + CASE(Not) + CASE(Select) + CASE(Load) + CASE(Ramp) + CASE(Call) + CASE(Let) + CASE(Shuffle) + CASE(VectorReduce) + // Stmts + CASE(LetStmt) + CASE(AssertStmt) + CASE(ProducerConsumer) + CASE(For) + CASE(Acquire) + CASE(Store) + CASE(Provide) + CASE(Allocate) + CASE(Free) + CASE(Realize) + CASE(Block) + CASE(Fork) + CASE(IfThenElse) + CASE(Evaluate) + CASE(Prefetch) + CASE(Atomic) + CASE(HoistedStorage) + } +#undef CASE + return stream; +} + ostream &operator<<(ostream &stream, const AssociativePattern &p) { stream << "{\n"; for (size_t i = 0; i < p.ops.size(); ++i) { diff --git a/src/IRPrinter.h b/src/IRPrinter.h index 6addbbd7c771..48afef8603d3 100644 --- a/src/IRPrinter.h +++ b/src/IRPrinter.h @@ -61,6 +61,11 @@ class Closure; struct Interval; struct ConstantInterval; struct ModulusRemainder; +enum class IRNodeType; + +/** Emit a halide node type on an output stream (such as std::cout) in + * human-readable form */ +std::ostream &operator<<(std::ostream &stream, IRNodeType); /** Emit a halide associative pattern on an output stream (such as std::cout) * in a human-readable form */