From 538577a77c34774875f474f093a536afa3a24b54 Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Tue, 16 Apr 2024 17:06:16 -0700
Subject: [PATCH 01/13] Rewrite IREquality to use a more compact stack instead
 of deep recursion

Deletes a bunch of code and speeds up lowering time of local laplacian
with 20 pyramid levels by ~2.5%
---
 src/Bounds.cpp            |    4 +-
 src/CSE.cpp               |   18 +-
 src/IREquality.cpp        | 1139 +++++++++++++++++--------------------
 src/IREquality.h          |   89 +--
 src/ParallelRVar.cpp      |    2 +-
 src/RDom.cpp              |    2 +-
 src/ScheduleFunctions.cpp |    2 +-
 7 files changed, 520 insertions(+), 736 deletions(-)

diff --git a/src/Bounds.cpp b/src/Bounds.cpp
index d7d337dacfdf..a8ed2deba0d2 100644
--- a/src/Bounds.cpp
+++ b/src/Bounds.cpp
@@ -79,9 +79,9 @@ int static_sign(const Expr &x) {
         return -1;
     } else {
         Expr zero = make_zero(x.type());
-        if (equal(const_true(), simplify(x > zero))) {
+        if (is_const_one(simplify(x > zero))) {
             return 1;
-        } else if (equal(const_true(), simplify(x < zero))) {
+        } else if (is_const_one(simplify(x < zero))) {
             return -1;
         }
     }
diff --git a/src/CSE.cpp b/src/CSE.cpp
index d8ecd619db81..0905562c4e63 100644
--- a/src/CSE.cpp
+++ b/src/CSE.cpp
@@ -76,7 +76,7 @@ class GVN : public IRMutator {
     Expr expr;
     int use_count = 0;
     // All consumer Exprs for which this is the last child Expr.
-    map<ExprWithCompareCache, int> uses;
+    map<Expr, int, IRDeepCompare> uses;
     Entry(const Expr &e)
         : expr(e) {
     }
@@ -84,25 +84,15 @@ class GVN : public IRMutator {
     vector<std::unique_ptr<Entry>> entries;
     map<Expr, int> shallow_numbering, output_numbering;
-    map<ExprWithCompareCache, int> leaves;
+    map<Expr, int, IRDeepCompare> leaves;
 
-    int number = -1;
-
-    IRCompareCache cache;
-
-    GVN()
-        : number(0), cache(8) {
-    }
+    int number = 0;
 
     Stmt mutate(const Stmt &s) override {
         internal_error << "Can't call GVN on a Stmt: " << s << "\n";
         return Stmt();
     }
 
-    ExprWithCompareCache with_cache(const Expr &e) {
-        return ExprWithCompareCache(e, &cache);
-    }
-
     Expr mutate(const Expr &e) override {
         // Early out if we've already seen this exact Expr.
         {
@@ -123,7 +113,7 @@ class GVN : public IRMutator {
             // that child has an identical parent to this one.
             auto &use_map = number == -1 ? leaves : entries[number]->uses;
-            auto p = use_map.emplace(with_cache(new_e), (int)entries.size());
+            auto p = use_map.emplace(new_e, (int)entries.size());
             auto iter = p.first;
             bool novel = p.second;
             if (novel) {
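A minimal sketch (not part of the patch) of what the new GVN maps rely on: IRDeepCompare, already exposed by IREquality.h, is an ordering functor over Exprs, so two structurally identical expressions with distinct node pointers collapse onto a single map entry, without the old ExprWithCompareCache wrapper:

#include <map>
#include "Halide.h"
using namespace Halide;
using namespace Halide::Internal;

int main() {
    Var x("x");
    Expr a = x + 1;
    Expr b = Variable::make(Int(32), "x") + 1;  // same structure, distinct nodes
    std::map<Expr, int, IRDeepCompare> m;
    m[a] = 1;
    m[b] += 1;  // ordered by structure, so this hits the entry made via `a`
    return m.size() == 1 ? 0 : 1;  // exits 0: one entry, holding 2
}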
diff --git a/src/IREquality.cpp b/src/IREquality.cpp
index 0d21ca1e26b5..90ec38970f5d 100644
--- a/src/IREquality.cpp
+++ b/src/IREquality.cpp
@@ -10,713 +10,590 @@
 using std::vector;
 
 namespace {
 
-/** The class that does the work of comparing two IR nodes. */
-class IRComparer : public IRVisitor {
-public:
-    /** Different possible results of a comparison. Unknown should
-     * only occur internally due to a cache miss. */
-    enum CmpResult { Unknown,
-                     Equal,
-                     LessThan,
-                     GreaterThan };
-
-    /** The result of the comparison. Should be Equal, LessThan, or GreaterThan. */
-    CmpResult result = Equal;
-
-    /** Compare two expressions or statements and return the
-     * result. Returns the result immediately if it is already
-     * non-zero. */
-    // @{
-    CmpResult compare_expr(const Expr &a, const Expr &b);
-    CmpResult compare_stmt(const Stmt &a, const Stmt &b);
-    // @}
-
-    /** If the expressions you're comparing may contain many repeated
-     * subexpressions, it's worth passing in a cache to use.
-     * Currently this is only done in common-subexpression
-     * elimination. */
-    IRComparer(IRCompareCache *c = nullptr)
-        : cache(c) {
-    }
-
-private:
-    Expr expr;
-    Stmt stmt;
-    IRCompareCache *cache;
-
-    CmpResult compare_names(const std::string &a, const std::string &b);
-    CmpResult compare_types(Type a, Type b);
-    CmpResult compare_expr_vector(const std::vector<Expr> &a, const std::vector<Expr> &b);
-
-    // Compare two things that already have a well-defined operator<
-    template<typename T>
-    CmpResult compare_scalar(T a, T b);
-
-    void visit(const IntImm *) override;
-    void visit(const UIntImm *) override;
-    void visit(const FloatImm *) override;
-    void visit(const StringImm *) override;
-    void visit(const Cast *) override;
-    void visit(const Reinterpret *) override;
-    void visit(const Variable *) override;
-    void visit(const Add *) override;
-    void visit(const Sub *) override;
-    void visit(const Mul *) override;
-    void visit(const Div *) override;
-    void visit(const Mod *) override;
-    void visit(const Min *) override;
-    void visit(const Max *) override;
-    void visit(const EQ *) override;
-    void visit(const NE *) override;
-    void visit(const LT *) override;
-    void visit(const LE *) override;
-    void visit(const GT *) override;
-    void visit(const GE *) override;
-    void visit(const And *) override;
-    void visit(const Or *) override;
-    void visit(const Not *) override;
-    void visit(const Select *) override;
-    void visit(const Load *) override;
-    void visit(const Ramp *) override;
-    void visit(const Broadcast *) override;
-    void visit(const Call *) override;
-    void visit(const Let *) override;
-    void visit(const LetStmt *) override;
-    void visit(const AssertStmt *) override;
-    void visit(const ProducerConsumer *) override;
-    void visit(const For *) override;
-    void visit(const Acquire *) override;
-    void visit(const Store *) override;
-    void visit(const Provide *) override;
-    void visit(const Allocate *) override;
-    void visit(const Free *) override;
-    void visit(const Realize *) override;
-    void visit(const Block *) override;
-    void visit(const Fork *) override;
-    void visit(const IfThenElse *) override;
-    void visit(const Evaluate *) override;
-    void visit(const Shuffle *) override;
-    void visit(const Prefetch *) override;
-    void visit(const Atomic *) override;
-    void visit(const VectorReduce *) override;
-    void visit(const HoistedStorage *) override;
-};
+enum CmpResult { Unknown,
+                 Equal,
+                 LessThan,
+                 GreaterThan };
+
+// A helper class for comparing two pieces of IR with the minimum amount of
+// recursion.
+template<int cache_size>
+struct Comparer {
+
+    // Points to any cache in use for comparing Expr graphs. Will be non-null
+    // exactly when cache_size > 0.
+    const IRNode **cache;
+
+    // The compare method below does the actual work, but it needs to call out
+    // to a variety of template helper functions to compare specific types. We
+    // make the syntax in the giant switch statement in the compare method much
+    // simpler if we just give these helper functions access to the state in the
+    // compare method: the stack pointers, the currently-considered piece of
+    // IR, and the result of the comparison so far.
+    const IRNode **stack_end = nullptr, **stack_ptr = nullptr;
+    const IRNode *next_a = nullptr, *next_b = nullptr;
+    CmpResult result = Equal;
+
+    Comparer(const IRNode **cache)
+        : cache(cache) {
+    }
+
+    // Compare the given member variable of next_a and next_b. If it's an Expr
+    // or Stmt, it's guaranteed to be defined.
+    template<typename Node, typename MemberType>
+    HALIDE_ALWAYS_INLINE void cmp(MemberType Node::*member_ptr) {
+        if (result == Equal) {
+            cmp(((const Node *)next_a)->*member_ptr, ((const Node *)next_b)->*member_ptr);
+        }
+    }
+
+    // The same as above, but with no guarantee.
+    template<typename Node, typename MemberType>
+    HALIDE_ALWAYS_INLINE void cmp_if_defined(MemberType Node::*member_ptr) {
+        if (result == Equal) {
+            cmp_if_defined(((const Node *)next_a)->*member_ptr, ((const Node *)next_b)->*member_ptr);
+        }
+    }
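As an aside, here is a standalone illustration (toy types, not Halide's) of the pointer-to-member trick the two helpers above use: the node class is deduced from the member pointer itself, so a single helper can downcast next_a/next_b and compare any field of any IR node class:

#include <iostream>

struct Base { int kind; };
struct Add : Base { int a, b; };

struct Cmp {
    const Base *x, *y;
    // Node is deduced from the member pointer, so the casts below pick the
    // right derived type automatically.
    template<typename Node, typename Member>
    int cmp(Member Node::*m) const {
        const auto &ma = static_cast<const Node *>(x)->*m;
        const auto &mb = static_cast<const Node *>(y)->*m;
        return ma < mb ? -1 : (mb < ma ? 1 : 0);
    }
};

int main() {
    Add p{{0}, 1, 2}, q{{0}, 1, 3};
    Cmp c{&p, &q};
    std::cout << c.cmp(&Add::a) << " " << c.cmp(&Add::b) << "\n";  // prints: 0 -1
    return 0;
}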
+
+    size_t hash(const IRNode *a, const IRNode *b) {
+        uintptr_t pa = (uintptr_t)a;
+        uintptr_t pb = (uintptr_t)b;
+        uintptr_t h = (((pa * 17) ^ (pb * 13)) >> 4);
+        h ^= h >> 8;
+        h = h & (cache_size - 1);
+        return h;
+    }
+
+    // See if we've already processed this pair of IR nodes.
+    bool cache_contains(const IRNode *a, const IRNode *b) {
+        size_t h = hash(a, b);
+        const IRNode **c = cache + h * 2;
+        return (c[0] == a && c[1] == b);
+    }
+
+    // Mark a pair of IR nodes as already processed. We don't do this until
+    // we're done processing their children, because there aren't going to be
+    // any queries to match a node with one of its children, because nodes can't
+    // be their own ancestors. Inserting it into the cache too soon just means
+    // it's going to be evicted before we need it.
+    void cache_insert(const IRNode *a, const IRNode *b) {
+        size_t h = hash(a, b);
+        const IRNode **c = cache + h * 2;
+        c[0] = a;
+        c[1] = b;
+    }
+
+    // Compare two known-to-be-defined IR nodes. Well... don't actually compare
+    // them, because that would be a recursive call. Just push them onto the
+    // pending tasks stack.
+    void cmp(const IRHandle &a, const IRHandle &b) {
+        if (cache_size > 0 && cache_contains(a.get(), b.get())) {
+            return;
+        }
+
+        if (a.get() == b.get()) {
+        } else if (stack_ptr == stack_end) {
+            // Out of stack space. Make a recursive call to buy some more stack.
+            Comparer<cache_size> sub_comparer(cache);
+            result = sub_comparer.compare(a.get(), b.get());
+        } else {
+            *stack_ptr++ = a.get();
+            *stack_ptr++ = b.get();
+        }
+    }
 
-template<typename T>
-IRComparer::CmpResult IRComparer::compare_scalar(T a, T b) {
-    if (result != Equal) {
-        return result;
-    }
-
-    if constexpr (std::is_floating_point_v<T>) {
-        // NaNs are equal to each other and less than non-nans
-        if (std::isnan(a) && std::isnan(b)) {
-            result = Equal;
-            return result;
-        }
-        if (std::isnan(a)) {
-            result = LessThan;
-            return result;
-        }
-        if (std::isnan(b)) {
-            result = GreaterThan;
-            return result;
-        }
-    }
-
-    if (a < b) {
-        result = LessThan;
-    } else if (a > b) {
-        result = GreaterThan;
-    }
-
-    return result;
-}
-
+    // Compare two IR nodes, which may or may not be defined.
+    HALIDE_ALWAYS_INLINE
+    void cmp_if_defined(const IRHandle &a, const IRHandle &b) {
+        if (a.defined() < b.defined()) {
+            result = LessThan;
+        } else if (a.defined() > b.defined()) {
+            result = GreaterThan;
+        } else if (a.defined() && b.defined()) {
+            cmp(a, b);
+        }
+    }
+
+    template<typename T>
+    void cmp(const std::vector<T> &a, const std::vector<T> &b) {
+        if (a.size() < b.size()) {
+            result = LessThan;
+        } else if (a.size() > b.size()) {
+            result = GreaterThan;
+        } else {
+            for (size_t i = 0; i < a.size() && result == Equal; i++) {
+                cmp(a[i], b[i]);
+            }
+        }
+    }
+
+    HALIDE_ALWAYS_INLINE
+    void cmp(const Range &a, const Range &b) {
+        cmp(a.min, b.min);
+        cmp(a.extent, b.extent);
+    }
+
+    HALIDE_ALWAYS_INLINE
+    void cmp(const ModulusRemainder &a, const ModulusRemainder &b) {
+        cmp(a.modulus, b.modulus);
+        cmp(a.remainder, b.remainder);
+    }
+
+    void cmp(const halide_handle_cplusplus_type *ha,
+             const halide_handle_cplusplus_type *hb) {
+        if (ha == hb) {
+            return;
+        } else if (!ha) {
+            result = LessThan;
+        } else if (!hb) {
+            result = GreaterThan;
+        } else {
+            // They're both non-void handle types with distinct type info
+            // structs. We now need to distinguish between different C++
+            // pointer types (e.g. char * vs const float *). It would be nice
+            // if the structs were unique per C++ type. Then comparing the
+            // pointers above would be sufficient. Unfortunately, different
+            // shared libraries in the same process each create a distinct
+            // struct for the same type. We therefore have to do a deep
+            // comparison of the type info fields.
+            cmp(ha->reference_type, hb->reference_type);
+            cmp(ha->inner_name.name, hb->inner_name.name);
+            cmp(ha->inner_name.cpp_type_type, hb->inner_name.cpp_type_type);
+            cmp(ha->namespaces, hb->namespaces);
+            cmp(ha->enclosing_types, hb->enclosing_types);
+            cmp(ha->cpp_type_modifiers, hb->cpp_type_modifiers);
+        }
+    }
+
+    HALIDE_ALWAYS_INLINE
+    void cmp(const Type &a, const Type &b) {
+        uint32_t ta = ((halide_type_t)a).as_u32();
+        uint32_t tb = ((halide_type_t)b).as_u32();
+        if (ta < tb) {
+            result = LessThan;
+        } else if (ta > tb) {
+            result = GreaterThan;
+        } else {
+            if (a.handle_type || b.handle_type) {
+                cmp(a.handle_type, b.handle_type);
+            }
+        }
+    }
+
+    void cmp(const PrefetchDirective &a, const PrefetchDirective &b) {
+        cmp(a.name, b.name);
+        cmp(a.at, b.at);
+        cmp(a.from, b.from);
+        cmp(a.offset, b.offset);
+        cmp(a.strategy, b.strategy);
+    }
+
+    HALIDE_ALWAYS_INLINE
+    void cmp(double a, double b) {
+        // Floating point scalars need special handling, due to NaNs.
+        if (std::isnan(a) && std::isnan(b)) {
+        } else if (std::isnan(a)) {
+            result = LessThan;
+        } else if (std::isnan(b)) {
+            result = GreaterThan;
+        } else if (a < b) {
+            result = LessThan;
+        } else if (b < a) {
+            result = GreaterThan;
+        }
+    }
+
+    HALIDE_ALWAYS_INLINE
+    void cmp(const std::string &a, const std::string &b) {
+        int r = a.compare(b);
+        if (r < 0) {
+            result = LessThan;
+        } else if (r > 0) {
+            result = GreaterThan;
+        }
+    }
+
+    // The method to use whenever we can just use operator< and get a bool.
+    template<typename T,
+             typename = std::enable_if_t<!std::is_floating_point_v<T> &&
+                                         std::is_same_v<decltype(std::declval<T>() < std::declval<T>()), bool>>>
+    HALIDE_NEVER_INLINE void cmp(const T &a, const T &b) {
+        if (a < b) {
+            result = LessThan;
+        } else if (b < a) {
+            result = GreaterThan;
+        }
+    }
 
-IRComparer::CmpResult IRComparer::compare_expr(const Expr &a, const Expr &b) {
-    if (result != Equal) {
-        return result;
-    }
-
-    if (a.same_as(b)) {
-        result = Equal;
-        return result;
-    }
-
-    // Undefined values are equal to each other and less than defined values
-    if (!a.defined() && !b.defined()) {
-        result = Equal;
-        return result;
-    }
-
-    if (!a.defined()) {
-        result = LessThan;
-        return result;
-    }
-
-    if (!b.defined()) {
-        result = GreaterThan;
-        return result;
-    }
-
-    // If in the future we have hashes for Exprs, this is a good place
-    // to compare the hashes:
-    // if (compare_scalar(a.hash(), b.hash()) != Equal) {
-    //     return result;
-    // }
-
-    if (compare_scalar(a->node_type, b->node_type) != Equal) {
-        return result;
-    }
-
-    if (compare_types(a.type(), b.type()) != Equal) {
-        return result;
-    }
-
-    // Check the cache - perhaps these exprs have already been compared and found equal.
-    if (cache && cache->contains(a, b)) {
-        result = Equal;
-        return result;
-    }
-
-    expr = a;
-    b.accept(this);
-
-    if (cache && result == Equal) {
-        cache->insert(a, b);
-    }
-
-    return result;
-}
-
-IRComparer::CmpResult IRComparer::compare_stmt(const Stmt &a, const Stmt &b) {
-    if (result != Equal) {
-        return result;
-    }
-
-    if (a.same_as(b)) {
-        result = Equal;
-        return result;
-    }
-
-    if (!a.defined() && !b.defined()) {
-        result = Equal;
-        return result;
-    }
-
-    if (!a.defined()) {
-        result = LessThan;
-        return result;
-    }
-
-    if (!b.defined()) {
-        result = GreaterThan;
-        return result;
-    }
-
-    if (compare_scalar(a->node_type, b->node_type) != Equal) {
-        return result;
-    }
-
-    stmt = a;
-    b.accept(this);
-
-    return result;
-}
-
-IRComparer::CmpResult IRComparer::compare_types(Type a, Type b) {
-    if (result != Equal) {
-        return result;
-    }
-
-    compare_scalar(a.code(), b.code());
-    compare_scalar(a.bits(), b.bits());
-    compare_scalar(a.lanes(), b.lanes());
-
-    if (result != Equal) {
-        return result;
-    }
-
-    const halide_handle_cplusplus_type *ha = a.handle_type;
-    const halide_handle_cplusplus_type *hb = b.handle_type;
-
-    if (ha == hb) {
-        // Same handle type, or both not handles, or both void *
-        return result;
-    }
-
-    if (ha == nullptr) {
-        // void* < T*
-        result = LessThan;
-        return result;
-    }
-
-    if (hb == nullptr) {
-        // T* > void*
-        result = GreaterThan;
-        return result;
-    }
-
-    // They're both non-void handle types with distinct type info
-    // structs. We now need to distinguish between different C++
-    // pointer types (e.g. char * vs const float *). If would be nice
-    // if the structs were unique per C++ type. Then comparing the
-    // pointers above would be sufficient. Unfortunately, different
-    // shared libraries in the same process each create a distinct
-    // struct for the same type. We therefore have to do a deep
-    // comparison of the type info fields.
-
-    compare_scalar(ha->reference_type, hb->reference_type);
-    compare_names(ha->inner_name.name, hb->inner_name.name);
-    compare_scalar(ha->inner_name.cpp_type_type, hb->inner_name.cpp_type_type);
-    compare_scalar(ha->namespaces.size(), hb->namespaces.size());
-    compare_scalar(ha->enclosing_types.size(), hb->enclosing_types.size());
-    compare_scalar(ha->cpp_type_modifiers.size(), hb->cpp_type_modifiers.size());
-
-    if (result != Equal) {
-        return result;
-    }
-
-    for (size_t i = 0; i < ha->namespaces.size(); i++) {
-        compare_names(ha->namespaces[i], hb->namespaces[i]);
-    }
-
-    if (result != Equal) {
-        return result;
-    }
-
-    for (size_t i = 0; i < ha->enclosing_types.size(); i++) {
-        compare_scalar(ha->enclosing_types[i].cpp_type_type,
-                       hb->enclosing_types[i].cpp_type_type);
-        compare_names(ha->enclosing_types[i].name,
-                      hb->enclosing_types[i].name);
-    }
-
-    if (result != Equal) {
-        return result;
-    }
-
-    for (size_t i = 0; i < ha->cpp_type_modifiers.size(); i++) {
-        compare_scalar(ha->cpp_type_modifiers[i],
-                       hb->cpp_type_modifiers[i]);
-    }
-
-    return result;
-}
-
-IRComparer::CmpResult IRComparer::compare_names(const string &a, const string &b) {
-    if (result != Equal) {
-        return result;
-    }
-
-    int string_cmp = a.compare(b);
-    if (string_cmp < 0) {
-        result = LessThan;
-    } else if (string_cmp > 0) {
-        result = GreaterThan;
-    }
-
-    return result;
-}
-
+    CmpResult compare(const IRNode *root_a, const IRNode *root_b) {
+        constexpr size_t stack_size = 64;  // 1 kb
+        const IRNode *stack_storage[stack_size * 2];  // Intentionally uninitialized
+
+        stack_ptr = stack_storage;
+        stack_end = stack_storage + stack_size * 2;
+        result = Equal;
+
+        *stack_ptr++ = root_a;
+        *stack_ptr++ = root_b;
+
+        while (result == Equal && stack_ptr > stack_storage) {
+            stack_ptr -= 2;
+            next_a = stack_ptr[0];
+            next_b = stack_ptr[1];
+
+            if (next_a == next_b) {
+                continue;
+            }
+
+            if (cache_size > 0 && (((uintptr_t)next_a) & 1)) {
+                // If we are using a cache, we want to keep the nodes on the
+                // stack while processing their children, but mark them with a
+                // tombstone. We'll flip the low bit to 1 for our tombstone. We
+                // want to insert them into the cache when the tombstone is
+                // handled. This if statement triggers if we just hit a
+                // tombstone.
+                cache_insert((const IRNode *)((uintptr_t)next_a ^ 1), next_b);
+                continue;
+            }
+
+            cmp(next_a->node_type, next_b->node_type);
+            if (result != Equal) {
+                break;
+            }
+
+            if (next_a->node_type < IRNodeType::LetStmt) {
+                cmp(&BaseExprNode::type);
+            }
+
+            if (cache_size > 0) {
+                // Keep the parent nodes on the stack, but mark them with a
+                // tombstone bit.
+                stack_ptr[0] = (const IRNode *)(((uintptr_t)next_a) | 1);
+                stack_ptr += 2;
+            }
+
+            switch (next_a->node_type) {
+            case IRNodeType::IntImm:
+                cmp(&IntImm::value);
+                break;
+            case IRNodeType::UIntImm:
+                cmp(&UIntImm::value);
+                break;
+            case IRNodeType::FloatImm:
+                cmp(&FloatImm::value);
+                break;
+            case IRNodeType::StringImm:
+                cmp(&StringImm::value);
+                break;
+            case IRNodeType::Broadcast:
+                cmp(&Broadcast::value);
+                break;
+            case IRNodeType::Cast:
+                cmp(&Cast::value);
+                break;
+            case IRNodeType::Reinterpret:
+                cmp(&Reinterpret::value);
+                break;
+            case IRNodeType::Variable:
+                cmp(&Variable::name);
+                break;
+            case IRNodeType::Add:
+                cmp(&Add::a);
+                cmp(&Add::b);
+                break;
+            case IRNodeType::Sub:
+                cmp(&Sub::a);
+                cmp(&Sub::b);
+                break;
+            case IRNodeType::Mod:
+                cmp(&Mod::a);
+                cmp(&Mod::b);
+                break;
+            case IRNodeType::Mul:
+                cmp(&Mul::a);
+                cmp(&Mul::b);
+                break;
+            case IRNodeType::Div:
+                cmp(&Div::a);
+                cmp(&Div::b);
+                break;
+            case IRNodeType::Min:
+                cmp(&Min::a);
+                cmp(&Min::b);
+                break;
+            case IRNodeType::Max:
+                cmp(&Max::a);
+                cmp(&Max::b);
+                break;
+            case IRNodeType::EQ:
+                cmp(&EQ::a);
+                cmp(&EQ::b);
+                break;
+            case IRNodeType::NE:
+                cmp(&NE::a);
+                cmp(&NE::b);
+                break;
+            case IRNodeType::LT:
+                cmp(&LT::a);
+                cmp(&LT::b);
+                break;
+            case IRNodeType::LE:
+                cmp(&LE::a);
+                cmp(&LE::b);
+                break;
+            case IRNodeType::GT:
+                cmp(&GT::a);
+                cmp(&GT::b);
+                break;
+            case IRNodeType::GE:
+                cmp(&GE::a);
+                cmp(&GE::b);
+                break;
+            case IRNodeType::And:
+                cmp(&And::a);
+                cmp(&And::b);
+                break;
+            case IRNodeType::Or:
+                cmp(&Or::a);
+                cmp(&Or::b);
+                break;
+            case IRNodeType::Not:
+                cmp(&Not::a);
+                break;
+            case IRNodeType::Select:
+                cmp(&Select::condition);
+                cmp(&Select::true_value);
+                cmp(&Select::false_value);
+                break;
+            case IRNodeType::Load:
+                cmp(&Load::name);
+                cmp(&Load::alignment);
+                cmp(&Load::index);
+                cmp(&Load::predicate);
+                break;
+            case IRNodeType::Ramp:
+                cmp(&Ramp::stride);
+                cmp(&Ramp::base);
+                break;
+            case IRNodeType::Call:
+                cmp(&Call::name);
+                cmp(&Call::call_type);
+                cmp(&Call::value_index);
+                cmp(&Call::args);
+                break;
+            case IRNodeType::Let:
+                cmp(&Let::name);
+                cmp(&Let::value);
+                cmp(&Let::body);
+                break;
+            case IRNodeType::Shuffle:
+                cmp(&Shuffle::indices);
+                cmp(&Shuffle::vectors);
+                break;
+            case IRNodeType::VectorReduce:
+                cmp(&VectorReduce::op);
+                cmp(&VectorReduce::value);
+                break;
+            case IRNodeType::LetStmt:
+                cmp(&LetStmt::name);
+                cmp(&LetStmt::value);
+                cmp(&LetStmt::body);
+                break;
+            case IRNodeType::AssertStmt:
+                cmp(&AssertStmt::condition);
+                cmp(&AssertStmt::message);
+                break;
+            case IRNodeType::ProducerConsumer:
+                cmp(&ProducerConsumer::name);
+                cmp(&ProducerConsumer::is_producer);
+                cmp(&ProducerConsumer::body);
+                break;
+            case IRNodeType::For:
+                cmp(&For::name);
+                cmp(&For::for_type);
+                cmp(&For::device_api);
+                cmp(&For::partition_policy);
+                cmp(&For::min);
+                cmp(&For::extent);
+                cmp(&For::body);
+                break;
+            case IRNodeType::Acquire:
+                cmp(&Acquire::semaphore);
+                cmp(&Acquire::count);
+                cmp(&Acquire::body);
+                break;
+            case IRNodeType::Store:
+                cmp(&Store::name);
+                cmp(&Store::alignment);
+                cmp(&Store::predicate);
+                cmp(&Store::value);
+                cmp(&Store::index);
+                break;
+            case IRNodeType::Provide:
+                cmp(&Provide::name);
+                cmp(&Provide::args);
+                cmp(&Provide::values);
+                break;
+            case IRNodeType::Allocate:
+                cmp(&Allocate::name);
+                cmp(&Allocate::type);
+                cmp(&Allocate::free_function);
+                cmp_if_defined(&Allocate::new_expr);
+                cmp(&Allocate::condition);
+                cmp(&Allocate::extents);
+                cmp(&Allocate::body);
+                break;
+            case IRNodeType::Free:
+                cmp(&Free::name);
+                break;
+            case IRNodeType::Realize:
+                cmp(&Realize::name);
+                cmp(&Realize::types);
+                cmp(&Realize::bounds);
+                cmp(&Realize::body);
+                cmp(&Realize::condition);
+                break;
+            case IRNodeType::Block:
+                cmp(&Block::first);
+                cmp(&Block::rest);
+                break;
+            case IRNodeType::Fork:
+                cmp(&Fork::first);
+                cmp(&Fork::rest);
+                break;
+            case IRNodeType::IfThenElse:
+                cmp(&IfThenElse::condition);
+                cmp(&IfThenElse::then_case);
+                cmp_if_defined(&IfThenElse::else_case);
+                break;
+            case IRNodeType::Evaluate:
+                cmp(&Evaluate::value);
+                break;
+            case IRNodeType::Prefetch:
+                cmp(&Prefetch::name);
+                cmp(&Prefetch::types);
+                cmp(&Prefetch::prefetch);
+                cmp(&Prefetch::bounds);
+                cmp(&Prefetch::condition);
+                cmp(&Prefetch::body);
+                break;
+            case IRNodeType::Atomic:
+                cmp(&Atomic::producer_name);
+                cmp(&Atomic::mutex_name);
+                cmp(&Atomic::body);
+                break;
+            case IRNodeType::HoistedStorage:
+                cmp(&HoistedStorage::name);
+                cmp(&HoistedStorage::body);
+                break;
+            }
+        }
+
+        // Don't hold onto pointers to this stack frame.
+        stack_ptr = stack_end = nullptr;
+        return result;
+    }
+};
 
-IRComparer::CmpResult IRComparer::compare_expr_vector(const vector<Expr> &a, const vector<Expr> &b) {
-    if (result != Equal) {
-        return result;
-    }
-
-    compare_scalar(a.size(), b.size());
-    for (size_t i = 0; (i < a.size()) && result == Equal; i++) {
-        compare_expr(a[i], b[i]);
-    }
-
-    return result;
-}
-
-void IRComparer::visit(const IntImm *op) {
-    const IntImm *e = expr.as<IntImm>();
-    compare_scalar(e->value, op->value);
-}
-
-void IRComparer::visit(const UIntImm *op) {
-    const UIntImm *e = expr.as<UIntImm>();
-    compare_scalar(e->value, op->value);
-}
-
-void IRComparer::visit(const FloatImm *op) {
-    const FloatImm *e = expr.as<FloatImm>();
-    compare_scalar(e->value, op->value);
-}
-
-void IRComparer::visit(const StringImm *op) {
-    const StringImm *e = expr.as<StringImm>();
-    compare_names(e->value, op->value);
-}
-
-void IRComparer::visit(const Cast *op) {
-    compare_expr(expr.as<Cast>()->value, op->value);
-}
-
-void IRComparer::visit(const Reinterpret *op) {
-    compare_expr(expr.as<Reinterpret>()->value, op->value);
-}
-
-void IRComparer::visit(const Variable *op) {
-    const Variable *e = expr.as<Variable>();
-    compare_names(e->name, op->name);
-}
-
-namespace {
-template<typename T>
-void visit_binary_operator(IRComparer *cmp, const T *op, Expr expr) {
-    const T *e = expr.as<T>();
-    cmp->compare_expr(e->a, op->a);
-    cmp->compare_expr(e->b, op->b);
-}
-}  // namespace
-
-void IRComparer::visit(const Add *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const Sub *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const Mul *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const Div *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const Mod *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const Min *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const Max *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const EQ *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const NE *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const LT *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const LE *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const GT *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const GE *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const And *op) {
-    visit_binary_operator(this, op, expr);
-}
-void IRComparer::visit(const Or *op) {
-    visit_binary_operator(this, op, expr);
-}
-
-void IRComparer::visit(const Not *op) {
-    const Not *e = expr.as<Not>();
-    compare_expr(e->a, op->a);
-}
-
-void IRComparer::visit(const Select *op) {
-    const Select *e = expr.as<Select>();
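For intuition, here is a toy version (not Halide code) of the traversal pattern the new Comparer uses in place of the visitor above: child pairs go onto an explicit stack instead of the C++ call stack, so deeply nested expressions cannot overflow it:

#include <vector>

struct Node {
    int value;
    const Node *left = nullptr, *right = nullptr;
};

int compare(const Node *a, const Node *b) {
    std::vector<const Node *> stack = {a, b};
    while (!stack.empty()) {
        const Node *y = stack.back();
        stack.pop_back();
        const Node *x = stack.back();
        stack.pop_back();
        if (x == y) continue;              // same node, or both null
        if (!x || !y) return !x ? -1 : 1;  // null sorts first
        if (x->value != y->value) return x->value < y->value ? -1 : 1;
        stack.push_back(x->left);          // defer children instead of recursing
        stack.push_back(y->left);
        stack.push_back(x->right);
        stack.push_back(y->right);
    }
    return 0;
}

And a sketch of the user-facing behavior the rewritten file still has to provide (equal is the public helper declared in IREquality.h):

#include <cassert>
#include "Halide.h"
using namespace Halide;
using namespace Halide::Internal;

int main() {
    Var x("x");
    Expr a = (x + 1) * 2;
    Expr b = (x + 1) * 2;   // freshly built, structurally identical tree
    assert(!a.same_as(b));  // different nodes...
    assert(equal(a, b));    // ...but structurally equal
    return 0;
}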