From e60003c20095f1e40d2e492fd06267f7293c6764 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 1 Apr 2020 17:14:49 -0700 Subject: [PATCH] [REFACTOR][TIR] Introduce ExprDeepEqual, Remove IRDeepCompare (#5206) * [REFACTOR][TIR] Introduce ExprDeepEqual, Remove IRDeepCompare This PR introduces ExprDeepEqual which reuses the StructuralEqual infra. We migrated the usecases of ir_pass::Equal to ExprDeepEqual and StructuralEqual. * Address comments --- docs/api/python/tir.rst | 9 +- include/tvm/tir/analysis.h | 54 ++ include/tvm/tir/expr.h | 11 + include/tvm/tir/ir_pass.h | 29 -- python/tvm/hybrid/calls.py | 3 +- python/tvm/hybrid/parser.py | 6 +- python/tvm/ir/base.py | 4 + python/tvm/tir/__init__.py | 1 + python/tvm/tir/analysis/__init__.py | 20 + python/tvm/tir/analysis/_ffi_api.py | 21 + python/tvm/tir/analysis/analysis.py | 57 +++ src/arith/canonical_simplify.cc | 4 +- src/arith/const_int_bound.cc | 3 +- src/arith/pattern_match.h | 3 +- src/arith/rewrite_simplify.cc | 9 +- src/arith/stmt_simplify.cc | 4 +- src/node/structural_equal.cc | 1 - src/relay/op/nn/convolution.h | 9 +- src/relay/qnn/op/convolution.cc | 7 +- src/te/operation/hybrid_op.cc | 5 +- src/te/operation/tensorize.cc | 4 +- src/tir/analysis/deep_equal.cc | 75 +++ src/tir/ir/buffer.cc | 10 +- src/tir/pass/ffi_api.cc | 9 - src/tir/pass/ir_deep_compare.cc | 460 ------------------ src/tir/pass/make_api.cc | 3 +- src/tir/pass/storage_rewrite.cc | 3 +- src/tir/pass/storage_sync.cc | 5 +- src/tir/transforms/combine_context_call.cc | 21 +- tests/cpp/pattern_match_test.cc | 12 +- .../unittest/test_arith_canonical_simplify.py | 6 +- .../test_arith_detect_linear_equation.py | 2 +- tests/python/unittest/test_arith_intset.py | 2 +- .../unittest/test_arith_rewrite_simplify.py | 2 +- tests/python/unittest/test_hybrid_script.py | 4 +- .../unittest/test_te_schedule_tensorize.py | 55 ++- ...y => test_tir_analysis_expr_deep_equal.py} | 36 +- tests/python/unittest/test_tir_buffer.py | 18 +- tests/python/unittest/test_tir_ops.py | 12 +- tests/python/unittest/test_tir_pass_basic.py | 6 +- .../unittest/test_tir_pass_loop_partition.py | 2 +- .../test_tir_structural_equal_hash.py | 25 + .../test_tir_transform_prim_func_pass.py | 2 +- topi/include/topi/detail/constant_utils.h | 6 +- vta/python/vta/ir_pass.py | 20 +- 45 files changed, 419 insertions(+), 641 deletions(-) create mode 100644 include/tvm/tir/analysis.h create mode 100644 python/tvm/tir/analysis/__init__.py create mode 100644 python/tvm/tir/analysis/_ffi_api.py create mode 100644 python/tvm/tir/analysis/analysis.py create mode 100644 src/tir/analysis/deep_equal.cc delete mode 100644 src/tir/pass/ir_deep_compare.cc rename tests/python/unittest/{test_tir_pass_equal.py => test_tir_analysis_expr_deep_equal.py} (51%) diff --git a/docs/api/python/tir.rst b/docs/api/python/tir.rst index dd08758e92ce..8ef247aff2f7 100644 --- a/docs/api/python/tir.rst +++ b/docs/api/python/tir.rst @@ -24,10 +24,17 @@ tvm.tir :autosummary: - tvm.tir.transform ----------------- .. automodule:: tvm.tir.transform :members: :imported-members: :autosummary: + + +tvm.tir.analysis +---------------- +.. automodule:: tvm.tir.analysis + :members: + :imported-members: + :autosummary: diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h new file mode 100644 index 000000000000..6bab44e2355b --- /dev/null +++ b/include/tvm/tir/analysis.h @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/tir/analysis.h + * \brief Analysis utilitie and passes for TIR. + */ +#ifndef TVM_TIR_ANALYSIS_H_ +#define TVM_TIR_ANALYSIS_H_ + +#include +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Compare two expressions recursively and check if they are equal + * to each other without var remapping. + * + * This function does not remap variable bindings, it will not + * return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y). + * + * Use StructuralEqual for such cases. + * + * Due to the restriction of not remapping variables, this function can run + * faster than StructuralEqual and can be used as a utility function during arithmetic + * simplifications. + * + * \sa StructuralEqual + */ +struct ExprDeepEqual { + public: + TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const; +}; +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_ANALYSIS_H_ diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 985c67137385..7b8ab44036fd 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -920,6 +920,17 @@ class FunctionBaseNode : public Object { virtual const std::string& func_name() const = 0; /*! \return the number of outputs of this function */ virtual int num_outputs() const = 0; + + // fall back to pointer equality now before refactor. + bool SEqualReduce(const FunctionBaseNode* other, SEqualReducer equal) const { + return this == other; + } + + void SHashReduce(SHashReducer hash_reduce) const { + } + + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; }; /*! \brief reference to a function */ diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 6e9a631fab4d..d54e094afe71 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -76,35 +76,6 @@ Stmt CanonicalSimplify(Stmt stmt, TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr, Map vrange = Map()); -/*! - * \brief Deep compare lhs and rhs - * \param lhs The left operand - * \param rhs The right operand - * \return The comparison result. - */ -TVM_DLL bool Equal(const PrimExpr& lhs, const PrimExpr& rhs); - -/*! - * \brief Deep compare lhs and rhs - * \param lhs The left operand - * \param rhs The right operand - * \return The comparison result. - */ -bool Equal(const Stmt& lhs, const Stmt& rhs); - -/*! - * \brief Deep compare lhs and rhs. - * - * If you only want equality comparison, use Equal - * which will also tie definitions. The compare mode - * will give order of expression in total order. - * - * \param lhs The left operand - * \param rhs The right operand - * \return The comparison result. - */ -int Compare(const PrimExpr& lhs, const PrimExpr& rhs); - /*! * \brief verifies whether the IR stmt or Expr is in SSA form. * That is: each VarExpr is defined and assigned once(in Let/For) diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 5b5c34d5cb0f..dfbb185a7eb4 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -22,7 +22,6 @@ from tvm.ir.container import Array from tvm import target as _tgt from tvm.tir import expr as _expr -from tvm.tir import ir_pass from tvm.tir import call_pure_intrin from tvm.tir.stmt import For @@ -47,7 +46,7 @@ def _range(annotation, args): else: _internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!") low, ext = args[0], args[1] - if not ir_pass.Equal(low, const(0, dtype='int32')): + if not tvm.tir.analysis.expr_deep_equal(low, const(0, dtype='int32')): ext = ext - low for_type = LOOP_INTRIN[annotation] iter_var = None diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 0f8f3dd2ad01..107f51b8bbcc 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -56,7 +56,7 @@ def concat_list_to_block(lst): def visit_list_to_block(visit, lst): """Visit and concatenate a list of Python IR nodes to HalideIR Block""" lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)] - lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())] + lst = [stmt for stmt in lst if not tvm.ir.structural_equal(stmt, util.make_nop())] if not lst: return util.make_nop() return concat_list_to_block(lst) @@ -178,7 +178,7 @@ def add_symbol(self, key, ty, val): #pylint: disable=invalid-name self.binds[val.var.name] = val return val_ = self.binds[val.var.name] - _internal_assert(_ir_pass.Equal(val_.dom.extent, val.dom.extent), + _internal_assert(tvm.tir.analysis.expr_deep_equal(val_.dom.extent, val.dom.extent), "Thread extents should be uniform!") self.symbols[key] = ty, val_ @@ -525,7 +525,7 @@ def visit_For(self, node): if iter_var is None: _internal_assert(for_type is not None, "The loop iterating function parse error!") offset = iter_var = tvm.te.var(_name) - if not _ir_pass.Equal(low, tvm.runtime.const(0, 'int32')): + if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, 'int32')): offset = iter_var + low self.add_symbol(_name, Symbol.LoopVar, offset) _body = visit_list_to_block(self.visit, node.body) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 3c6e6ff1515f..bab98382e713 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -198,6 +198,8 @@ def structural_equal(lhs, rhs, map_free_vars=False): structural_hash assert_strucural_equal """ + lhs = tvm.runtime.convert(lhs) + rhs = tvm.runtime.convert(rhs) return bool(tvm.runtime._ffi_node_api.StructuralEqual( lhs, rhs, False, map_free_vars)) @@ -225,6 +227,8 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): -------- structural_equal """ + lhs = tvm.runtime.convert(lhs) + rhs = tvm.runtime.convert(rhs) tvm.runtime._ffi_node_api.StructuralEqual( lhs, rhs, True, map_free_vars) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index f0d4d931a370..653c3954f489 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -46,3 +46,4 @@ from . import ir_builder from . import ir_pass from . import transform +from . import analysis diff --git a/python/tvm/tir/analysis/__init__.py b/python/tvm/tir/analysis/__init__.py new file mode 100644 index 000000000000..c142485c5307 --- /dev/null +++ b/python/tvm/tir/analysis/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Namespace of all TIR analysis utils.""" +# pylint: disable=wildcard-import, invalid-name + +from .analysis import * diff --git a/python/tvm/tir/analysis/_ffi_api.py b/python/tvm/tir/analysis/_ffi_api.py new file mode 100644 index 000000000000..6c1687e8a520 --- /dev/null +++ b/python/tvm/tir/analysis/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir.analysis""" +import tvm._ffi + + +tvm._ffi._init_api("tir.analysis", __name__) diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py new file mode 100644 index 000000000000..84eeaac370c2 --- /dev/null +++ b/python/tvm/tir/analysis/analysis.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Wrapping existing analysis utils.""" +# pylint: disable=invalid-name + +from . import _ffi_api + + +def expr_deep_equal(lhs, rhs): + """Deeply compare two nested expressions. + + Parameters + ---------- + lhs : PrimExpr + The left operand. + + rhs : PrimExpr + The right operand. + + Returns + ------- + result : bool + The comparison result + + Note + ---- + + This function does not remap variable bindings, it will not + return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y). + Use py:func:`tvm.ir.structural_equal` to handle structural variable remapping. + + Due to the restriction of not remapping variables, this function can run + faster than StructuralEqual and can be used as a utility function during arithmetic + simplifications. + + Always consider py:func:`tvm.ir.structural_equal` first, which handles + the structural remapping. + + See Also + -------- + tvm.ir.structural_equal + """ + return _ffi_api.expr_deep_equal(lhs, rhs) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 3580cddf8d2e..7a6e772c2935 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -23,6 +23,8 @@ */ #include #include +#include + #include "const_fold.h" #include "pattern_match.h" #include "rewrite_simplify.h" @@ -157,7 +159,7 @@ class SplitExpr : public PrimExpr { inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const { if (index.same_as(other->index)) return true; - return tir::Equal(index, other->index); + return tir::ExprDeepEqual()(index, other->index); } inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const { diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 9ef5723e153e..702e77532d95 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -138,10 +138,11 @@ class ConstIntBoundAnalyzer::Impl : Entry VisitExpr(const PrimExpr& expr) final { Entry res = ExprFunctor::VisitExpr(expr); + tir::ExprDeepEqual equal; // a linear search over additional info // assume we won't have a lot of conditions for (const BoundInfo& info : additional_info_) { - if (tir::Equal(expr, info.expr)) { + if (equal(expr, info.expr)) { res = Intersect(res, info.bound); } } diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 8a2df5043ffb..e81b0881f927 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -66,6 +66,7 @@ #define TVM_ARITH_PATTERN_MATCH_H_ #include +#include #include #include "const_fold.h" @@ -135,7 +136,7 @@ class PEqualChecker { public: bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { if (lhs.same_as(rhs)) return true; - return tir::Equal(lhs, rhs); + return tir::ExprDeepEqual()(lhs, rhs); } }; diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 39b87ef1b056..126310813cc4 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -101,11 +101,11 @@ TryCompare(const PrimExpr& x, int64_t val) { } void RewriteSimplifier::Impl:: -Update(const Var& var, const PrimExpr& info, bool override) { - if (!override) { +Update(const Var& var, const PrimExpr& info, bool can_override) { + if (!can_override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - CHECK(Equal(it->second, info)) + CHECK(ExprDeepEqual()(it->second, info)) << "Trying to update var \'" << var << "\'" << " with a different value: " << "original=" << it->second @@ -1716,10 +1716,11 @@ VisitExpr_(const CallNode* op) { return op->args[0] & op->args[1]; } } + ExprDeepEqual expr_equal; if (op->is_intrinsic(CallNode::likely)) { for (const auto& constraint : literal_constraints_) { // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } - if (Equal(constraint, op->args[0])) { + if (expr_equal(constraint, op->args[0])) { return make_const(op->dtype, true); } } diff --git a/src/arith/stmt_simplify.cc b/src/arith/stmt_simplify.cc index c0bc0c4787f1..6c3dd022565c 100644 --- a/src/arith/stmt_simplify.cc +++ b/src/arith/stmt_simplify.cc @@ -23,7 +23,9 @@ */ #include #include +#include #include + #include #include #include "ir_mutator_with_analyzer.h" @@ -83,7 +85,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { op = stmt.as(); if (const LoadNode* load = op->value.as()) { if (load->buffer_var.same_as(op->buffer_var) && - Equal(load->index, op->index)) { + tir::ExprDeepEqual()(load->index, op->index)) { return EvaluateNode::make(0); } } diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index b2191c1f890c..0078781ebf4f 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -225,7 +225,6 @@ class RemapVarSEqualHandler : std::unordered_map equal_map_rhs_; }; - TVM_REGISTER_GLOBAL("node.StructuralEqual") .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 6a69178f49b1..05c11719c320 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -25,6 +25,8 @@ #define TVM_RELAY_OP_NN_CONVOLUTION_H_ #include +#include + #include #include @@ -158,8 +160,8 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK(weight && weight->shape.defined()) << "Weight shape must be specified when groups is greater than 1."; Array wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape); - if (tvm::tir::Equal(param->groups, dshape_nchw[1]) && - tvm::tir::Equal(param->groups, wshape_oihw[0])) { + if (tvm::tir::ExprDeepEqual()(param->groups, dshape_nchw[1]) && + tvm::tir::ExprDeepEqual()(param->groups, wshape_oihw[0])) { is_depthwise = true; } } @@ -279,8 +281,9 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_EQ(param->kernel_size.size(), 3); CHECK_EQ(param->dilation.size(), 3); Array wshape; + tvm::tir::ExprDeepEqual expr_equal; - if (tvm::tir::Equal(param->channels, param->groups) && !tvm::tir::Equal(param->channels, 1)) { + if (expr_equal(param->channels, param->groups) && !expr_equal(param->channels, 1)) { // infer weight's shape for depthwise convolution wshape = {{dshape_ncdhw[1], indexdiv(param->groups, dshape_ncdhw[1]), param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]}}; diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index de0aae3195f8..37186283ba51 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -27,6 +27,8 @@ #include #include #include +#include + #include "../../op/nn/convolution.h" #include "../../transforms/pattern_util.h" #include "../util.h" @@ -86,8 +88,9 @@ Array> QnnConvInferCorrectLayout(const Attrs& attrs, } bool is_depthwise(const Conv2DAttrs* param) { - return param->channels.defined() && tvm::tir::Equal(param->channels, param->groups) && - param->groups != 1; + return param->channels.defined() && + tvm::tir::ExprDeepEqual()(param->channels, param->groups) && + param->groups != 1; } // Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index dcd09f9f1fa8..4da127ea0a85 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -338,12 +339,14 @@ Stmt ApplyLoopAnnotations(const Stage &stage, LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {} Stmt VisitStmt_(const ForNode *op) final { + tir::ExprDeepEqual expr_equal; + if (op->loop_var.get() == var) { if (attr->bind_thread.defined()) { const auto &iter_var = attr->bind_thread; if (iter_var->dom.defined()) { CHECK(is_const_int(iter_var->dom->min, 0)); - CHECK(Equal(iter_var->dom->extent, op->extent)) + CHECK(expr_equal(iter_var->dom->extent, op->extent)) << "Thread extent and loop extent mismatch!\n"; } std::unordered_map rmap; diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index ba84a9088553..6064f5c4e008 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include "op_util.h" @@ -330,6 +331,7 @@ void VerifyTensorizeBody( const std::unordered_map& out_dom, const std::unordered_map >& in_region, const TensorIntrin& intrin) { + StructuralEqual expr_equal; Map compute_intrin_iter_space; Array body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin, &compute_intrin_iter_space); @@ -349,7 +351,7 @@ void VerifyTensorizeBody( << " provided=" << lhs.dtype() << ", intrin=" << rhs.dtype(); } - CHECK(Equal(lhs, rhs)) + CHECK(expr_equal(lhs, rhs)) << "Failed to match the compute with TensorIntrin " << intrin->name << "'s declaration " << " provided= " << lhs diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc new file mode 100644 index 000000000000..763e3eb7cdae --- /dev/null +++ b/src/tir/analysis/deep_equal.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/deep_equal.cc + * \brief Deep equality checking. + */ +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class DeepCmpSEqualHandler : + public SEqualReducer::Handler { + public: + // use direct recursion. + bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final { + if (lhs.same_as(rhs)) return true; + if (!lhs.defined() && rhs.defined()) return false; + if (!rhs.defined() && lhs.defined()) return false; + if (lhs->type_index() != rhs->type_index()) return false; + return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, false)); + } + + ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { + return ObjectRef(nullptr); + } + + void MarkGraphNode() final { + } + + private: + // reflection vtable + ReflectionVTable* vtable_ = ReflectionVTable::Global(); +}; + +bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { + // quick path + if (lhs.same_as(rhs)) return true; + if (!lhs.defined() && rhs.defined()) return false; + if (!rhs.defined() && lhs.defined()) return false; + if (lhs->type_index() != rhs->type_index()) return false; + if (auto* plhs = lhs.as()) { + auto* prhs = rhs.as(); + return plhs->dtype == prhs->dtype && plhs->value == prhs->value; + } + return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false); +} + +TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal") +.set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) { + return ExprDeepEqual()(lhs, rhs); +}); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 19e32d6681ae..d663c30abb8c 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -24,7 +24,9 @@ #include #include #include +#include #include + #include #include #include "../../arith/compute_expr.h" @@ -112,6 +114,8 @@ inline std::pair MergeMulModInner(const PrimExpr &mult_expr, const PrimExpr* search_ptr = inner; PrimExpr mult_inner; // The inner multiplication factor PrimExpr no_opt_sum; // Sum of the exprs that cannot be optimized + tir::ExprDeepEqual expr_equal; + while (true) { auto inner_div_ptr = search_ptr->as(); auto inner_mult_ptr = search_ptr->as(); @@ -120,9 +124,9 @@ inline std::pair MergeMulModInner(const PrimExpr &mult_expr, return std::make_pair(false, PrimExpr()); } else if (inner_div_ptr) { PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer; - if (Equal(overall_mult, inner_div_ptr->b) - && Equal(overall_mult, mod_r_expr) - && Equal(inner_div_ptr->a, mod_l_expr)) { + if (expr_equal(overall_mult, inner_div_ptr->b) + && expr_equal(overall_mult, mod_r_expr) + && expr_equal(inner_div_ptr->a, mod_l_expr)) { // Found! PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr; return std::make_pair(true, ret); diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 46d0f67c6d51..f4d8193ac483 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -75,15 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute") } }); -TVM_REGISTER_GLOBAL("ir_pass.Equal") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsObjectRef()) { - *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt()); - } else { - *ret = Equal(args[0].operator PrimExpr(), args[1].operator PrimExpr()); - } - }); - TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args.size() <= 3) { diff --git a/src/tir/pass/ir_deep_compare.cc b/src/tir/pass/ir_deep_compare.cc deleted file mode 100644 index e45251fe8a4a..000000000000 --- a/src/tir/pass/ir_deep_compare.cc +++ /dev/null @@ -1,460 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ir_deep_compare.cc - */ -#include -#include -#include - -namespace tvm { -namespace tir { - -using ExprComparator = ExprFunctor; -using StmtComparator = StmtFunctor; - -#define DEFINE_BIOP_EXPR_CMP_(OP) \ - void VisitExpr_(const OP* op, const PrimExpr& other) final { \ - const OP* rhs = other.as(); \ - if (CompareExpr(op->a, rhs->a) != 0) return; \ - if (CompareExpr(op->b, rhs->b) != 0) return; \ - } - -// Deep comparison to check if two IR graph are equivalent -class IRDeepCompare : - public ExprComparator, public StmtComparator { - public: - // Equality comparison - bool Equal(const Stmt& lhs, const Stmt& rhs) { - tie_def_ = true; - VisitStmt(lhs, rhs); - return order_ == 0; - } - - bool Equal(const PrimExpr& lhs, const PrimExpr& rhs) { - tie_def_ = true; - VisitExpr(lhs, rhs); - return order_ == 0; - } - - int Compare(const PrimExpr& lhs, const PrimExpr& rhs) { - tie_def_ = false; - VisitExpr(lhs, rhs); - return order_; - } - - void VisitExpr(const PrimExpr& n, const PrimExpr& other) override { - if (order_ != 0) return; - if (n.same_as(other)) return; - if (CompareValue(n->type_index(), other->type_index()) != 0) return; - if (CompareType(n.dtype(), other.dtype()) != 0) return; - ExprComparator::VisitExpr(n, other); - } - - void VisitStmt(const Stmt& n, const Stmt& other) override { - if (order_ != 0) return; - if (n.same_as(other)) return; - if (CompareValue(n->type_index(), other->type_index()) != 0) return; - StmtComparator::VisitStmt(n, other); - } - // Stmt - void VisitStmt_(const LetStmtNode* op, const Stmt& other) final { - const LetStmtNode* rhs = other.as(); - if (CompareExpr(op->value, rhs->value) != 0) return; - if (tie_def_) { - vmap_[op->var.get()] = rhs->var.get(); - } else { - if (CompareExpr(op->var, rhs->var) != 0) return; - } - if (CompareStmt(op->body, rhs->body) != 0) return; - } - - void VisitStmt_(const AttrStmtNode* op, const Stmt& other) final { - const AttrStmtNode* rhs = other.as(); - if (CompareString(op->attr_key, rhs->attr_key) != 0) return; - if (CompareNodeRef(op->node, rhs->node) != 0) return; - if (CompareExpr(op->value, rhs->value) != 0) return; - if (CompareStmt(op->body, rhs->body) != 0) return; - } - - void VisitStmt_(const IfThenElseNode* op, const Stmt& other) final { - const IfThenElseNode* rhs = other.as(); - if (CompareExpr(op->condition, rhs->condition) != 0) return; - if (CompareStmt(op->then_case, rhs->then_case) != 0) return; - if (CompareStmt(op->else_case, rhs->else_case) != 0) return; - } - - void VisitStmt_(const ForNode* op, const Stmt& other) final { - const ForNode* rhs = other.as(); - if (CompareExpr(op->min, rhs->min) != 0) return; - if (CompareExpr(op->extent, rhs->extent) != 0) return; - if (tie_def_) { - vmap_[op->loop_var.get()] = rhs->loop_var.get(); - } else { - if (CompareExpr(op->loop_var, rhs->loop_var) != 0) return; - } - if (CompareStmt(op->body, rhs->body) != 0) return; - } - - void VisitStmt_(const AllocateNode* op, const Stmt& other) final { - const AllocateNode* rhs = other.as(); - if (tie_def_) { - vmap_[op->buffer_var.get()] = rhs->buffer_var.get(); - } else { - if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return; - } - if (CompareType(op->dtype, rhs->dtype) != 0) return; - if (CompareArray(op->extents, rhs->extents) != 0) return; - if (CompareExpr(op->condition, rhs->condition) != 0) return; - if (CompareStmt(op->body, rhs->body) != 0) return; - if (CompareExpr(op->new_expr, rhs->new_expr) != 0) return; - if (CompareString(op->free_function, rhs->free_function) != 0) return; - } - - void VisitStmt_(const StoreNode* op, const Stmt& other) final { - const StoreNode* rhs = other.as(); - if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return; - if (CompareExpr(op->value, rhs->value) != 0) return; - if (CompareExpr(op->index, rhs->index) != 0) return; - if (CompareExpr(op->predicate, rhs->predicate) != 0) return; - } - - void VisitStmt_(const FreeNode* op, const Stmt& other) final { - const FreeNode* rhs = other.as(); - if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return; - } - - void VisitStmt_(const AssertStmtNode* op, const Stmt& other) final { - const AssertStmtNode* rhs = other.as(); - if (CompareExpr(op->condition, rhs->condition) != 0) return; - if (CompareExpr(op->message, rhs->message) != 0) return; - if (CompareStmt(op->body, rhs->body) != 0) return; - } - - void VisitStmt_(const ProducerConsumerNode* op, const Stmt& other) final { - const ProducerConsumerNode* rhs = other.as(); - if (CompareNodeRef(op->func, rhs->func) != 0) return; - if (CompareValue(op->is_producer, rhs->is_producer) != 0) return; - if (CompareStmt(op->body, rhs->body) != 0) return; - } - - void VisitStmt_(const ProvideNode* op, const Stmt& other) final { - const ProvideNode* rhs = other.as(); - if (CompareNodeRef(op->func, rhs->func) != 0) return; - if (CompareValue(op->value_index, rhs->value_index) != 0) return; - if (CompareExpr(op->value, rhs->value) != 0) return; - if (CompareArray(op->args, rhs->args) != 0) return; - } - - void VisitStmt_(const RealizeNode* op, const Stmt& other) final { - const RealizeNode* rhs = other.as(); - if (CompareNodeRef(op->func, rhs->func) != 0) return; - if (CompareValue(op->value_index, rhs->value_index) != 0) return; - if (CompareType(op->dtype, rhs->dtype) != 0) return; - if (CompareRegion(op->bounds, rhs->bounds) != 0) return; - if (CompareStmt(op->body, rhs->body) != 0) return; - } - - void VisitStmt_(const PrefetchNode* op, const Stmt& other) final { - const PrefetchNode* rhs = other.as(); - if (CompareNodeRef(op->func, rhs->func) != 0) return; - if (CompareValue(op->value_index, rhs->value_index) != 0) return; - if (CompareType(op->dtype, rhs->dtype) != 0) return; - if (CompareRegion(op->bounds, rhs->bounds) != 0) return; - } - - void VisitStmt_(const SeqStmtNode* op, const Stmt& other) final { - const SeqStmtNode* rhs = other.as(); - if (CompareValue(op->size(), rhs->size()) != 0) return; - for (size_t i = 0; i < op->size(); ++i) { - if (CompareStmt(op->seq[i], rhs->seq[i]) != 0) return; - } - } - - void VisitStmt_(const EvaluateNode* op, const Stmt& other) final { - const EvaluateNode* rhs = other.as(); - CompareExpr(op->value, rhs->value); - } - - // Exprs - void VisitExpr_(const VarNode* op, const PrimExpr& other) final { - const VarNode* rhs = other.as(); - auto it = vmap_.find(op); - if (it != vmap_.end()) op = it->second; - if (op < rhs) { - order_ = -1; - } else if (op > rhs) { - order_ = +1; - } - } - void VisitExpr_(const LoadNode* op, const PrimExpr& other) final { - const LoadNode* rhs = other.as(); - if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return; - if (CompareExpr(op->index, rhs->index) != 0) return; - if (CompareExpr(op->predicate, rhs->predicate) != 0) return; - } - - void VisitExpr_(const LetNode* op, const PrimExpr& other) final { - const LetNode* rhs = other.as(); - if (tie_def_) { - vmap_[op->var.get()] = rhs->var.get(); - } else { - if (CompareExpr(op->var, rhs->var) != 0) return; - } - if (CompareExpr(op->value, rhs->value) != 0) return; - if (CompareExpr(op->body, rhs->body) != 0) return; - } - - void VisitExpr_(const CallNode* op, const PrimExpr& other) final { - const CallNode* rhs = other.as(); - if (CompareString(op->name, rhs->name)) return; - if (CompareArray(op->args, rhs->args)) return; - if (CompareValue(op->call_type, rhs->call_type) != 0) return; - if (CompareNodeRef(op->func, rhs->func) != 0) return; - if (CompareValue(op->value_index, rhs->value_index) != 0) return; - } - - void VisitExpr_(const ReduceNode *op, const PrimExpr& other) final { - const ReduceNode* rhs = other.as(); - if (CompareCommReducer(op->combiner, rhs->combiner) != 0) return; - if (CompareValue(op->axis.size(), rhs->axis.size()) != 0) return; - if (CompareValue(op->value_index, rhs->value_index) != 0) return; - for (size_t i = 0; i < op->axis.size(); ++i) { - if (CompareExpr(op->axis[i]->dom->min, rhs->axis[i]->dom->min) != 0) return; - if (CompareExpr(op->axis[i]->dom->extent, rhs->axis[i]->dom->extent) != 0) return; - if (tie_def_) { - vmap_[op->axis[i]->var.get()] = rhs->axis[i]->var.get(); - } else { - if (CompareExpr(op->axis[i]->var, rhs->axis[i]->var) != 0) return; - } - } - if (CompareExpr(op->condition, rhs->condition) != 0) return; - if (CompareArray(op->source, rhs->source) != 0) return; - } - - void VisitExpr_(const IntImmNode *op, const PrimExpr& other) final { - CompareValue(op->value, other.as()->value); - } - - void VisitExpr_(const FloatImmNode *op, const PrimExpr& other) final { - CompareValue(op->value, other.as()->value); - } - - void VisitExpr_(const StringImmNode *op, const PrimExpr& other) final { - CompareString(op->value, other.as()->value); - } - - void VisitExpr_(const CastNode *op, const PrimExpr& other) final { - CompareExpr(op->value, other.as()->value); - } - - void VisitExpr_(const NotNode *op, const PrimExpr& other) final { - CompareExpr(op->a, other.as()->a); - } - - void VisitExpr_(const SelectNode *op, const PrimExpr& other) final { - const SelectNode* rhs = other.as(); - if (CompareExpr(op->condition, rhs->condition) != 0) return; - if (CompareExpr(op->true_value, rhs->true_value) != 0) return; - if (CompareExpr(op->false_value, rhs->false_value) != 0) return; - } - - void VisitExpr_(const RampNode *op, const PrimExpr& other) final { - const RampNode* rhs = other.as(); - if (CompareExpr(op->base, rhs->base) != 0) return; - if (CompareExpr(op->stride, rhs->stride) != 0) return; - if (CompareValue(op->lanes, rhs->lanes) != 0) return; - } - - void VisitExpr_(const BroadcastNode *op, const PrimExpr& other) final { - const BroadcastNode* rhs = other.as(); - if (CompareExpr(op->value, rhs->value) != 0) return; - if (CompareValue(op->lanes, rhs->lanes) != 0) return; - } - - void VisitExpr_(const ShuffleNode *op, const PrimExpr& other) final { - const ShuffleNode* rhs = other.as(); - if (CompareArray(op->vectors, rhs->vectors) != 0) return; - if (CompareArray(op->indices, rhs->indices) != 0) return; - } - - DEFINE_BIOP_EXPR_CMP_(AddNode) - DEFINE_BIOP_EXPR_CMP_(SubNode) - DEFINE_BIOP_EXPR_CMP_(MulNode) - DEFINE_BIOP_EXPR_CMP_(DivNode) - DEFINE_BIOP_EXPR_CMP_(ModNode) - DEFINE_BIOP_EXPR_CMP_(FloorDivNode) - DEFINE_BIOP_EXPR_CMP_(FloorModNode) - DEFINE_BIOP_EXPR_CMP_(MinNode) - DEFINE_BIOP_EXPR_CMP_(MaxNode) - DEFINE_BIOP_EXPR_CMP_(EQNode) - DEFINE_BIOP_EXPR_CMP_(NENode) - DEFINE_BIOP_EXPR_CMP_(LTNode) - DEFINE_BIOP_EXPR_CMP_(LENode) - DEFINE_BIOP_EXPR_CMP_(GTNode) - DEFINE_BIOP_EXPR_CMP_(GENode) - DEFINE_BIOP_EXPR_CMP_(AndNode) - DEFINE_BIOP_EXPR_CMP_(OrNode) - - private: - int CompareExpr(const PrimExpr& lhs, const PrimExpr& rhs) { - if (order_ != 0) return order_; - if (!lhs.defined() && rhs.defined()) { - order_ = -1; return order_; - } - if (!rhs.defined() && lhs.defined()) { - order_ = +1; return order_; - } - VisitExpr(lhs, rhs); - return order_; - } - - int CompareStmt(const Stmt& lhs, const Stmt& rhs) { - if (order_ != 0) return order_; - if (!lhs.defined() && rhs.defined()) { - order_ = -1; return order_; - } - if (!rhs.defined() && lhs.defined()) { - order_ = +1; return order_; - } - VisitStmt(lhs, rhs); - return order_; - } - - int CompareArray(const Array& lhs, const Array& rhs) { - if (order_ != 0) return order_; - if (CompareValue(lhs.size(), rhs.size()) != 0) return order_; - for (size_t i = 0; i < lhs.size(); ++i) { - if (CompareExpr(lhs[i], rhs[i]) != 0) return order_; - } - return order_; - } - - int CompareRegion(const Region& lhs, const Region& rhs) { - if (order_ != 0) return order_; - if (CompareValue(lhs.size(), rhs.size()) != 0) return order_; - for (size_t i = 0; i < lhs.size(); ++i) { - if (CompareExpr(lhs[i]->min, rhs[i]->min) != 0) return order_; - if (CompareExpr(lhs[i]->extent, rhs[i]->extent) != 0) return order_; - } - return order_; - } - - int CompareNodeRef(const ObjectRef& lhs, const ObjectRef& rhs) { - if (order_ != 0) return order_; - if (lhs.get() < rhs.get()) { - order_ = -1; return order_; - } - if (lhs.get() > rhs.get()) { - order_ = +1; return order_; - } - return order_; - } - - int CompareType(const DataType& lhs, const DataType& rhs) { - if (order_ != 0) return order_; - if (lhs == rhs) return order_; - if (CompareValue(lhs.code(), rhs.code()) != 0) return order_; - if (CompareValue(lhs.bits(), rhs.bits()) != 0) return order_; - if (CompareValue(lhs.lanes(), rhs.lanes()) != 0) return order_; - return order_; - } - - int CompareString(const std::string& lhs, const std::string& rhs) { - if (order_ != 0) return order_; - order_ = lhs.compare(rhs); - return order_; - } - - template - int CompareValue(const T& lhs, const T& rhs) { - if (order_ != 0) return order_; - if (lhs < rhs) { - order_ = -1; return order_; - } else if (lhs > rhs) { - order_ = +1; return order_; - } - return order_; - } - - int CompareCommReducer(const CommReducer& lhs, const CommReducer& rhs) { - if (order_ != 0) return order_; - if (lhs == rhs) return order_; - if (CompareValue(lhs->lhs.size(), rhs->lhs.size()) != 0) return order_; - if (CompareValue(lhs->rhs.size(), rhs->rhs.size()) != 0) return order_; - IRDeepCompare cmp; - if (tie_def_) { - for (size_t i = 0; i < lhs->lhs.size(); ++i) { - cmp.vmap_[lhs->lhs[i].get()] = rhs->lhs[i].get(); - } - for (size_t i = 0; i < lhs->rhs.size(); ++i) { - cmp.vmap_[lhs->rhs[i].get()] = rhs->rhs[i].get(); - } - } else { - for (size_t i = 0; i < lhs->lhs.size(); ++i) { - if (CompareExpr(lhs->lhs[i], rhs->lhs[i]) != 0) return order_; - } - for (size_t i = 0; i < lhs->lhs.size(); ++i) { - if (CompareExpr(lhs->rhs[i], rhs->rhs[i]) != 0) return order_; - } - } - order_ = cmp.CompareArray(lhs->result, rhs->result); - return order_; - } - // The order flag, smaller, -1, bigger: +1, equal: 0 - int order_{0}; - // Whether tie intermediate definitions. - // This allows use to tie definitions of two variables together. - // This enables us to assert equal between (let x in x + 1), (let y in y + 1) - // However, the comparison is no longer in total order. - // Only equality/non-equality information is valid. - bool tie_def_{false}; - // varaible remap if any - std::unordered_map vmap_; -}; - - -bool Equal(const Stmt& lhs, const Stmt& rhs) { - return IRDeepCompare().Equal(lhs, rhs); -} - -bool Equal(const PrimExpr& lhs, const PrimExpr& rhs) { - // quick pass for constant expressions. - if (const int64_t *a = as_const_int(lhs)) { - if (const int64_t *b = as_const_int(rhs)) { - return a[0] == b[0]; - } - } - if (!lhs.defined()) { - if (rhs.defined()) return false; - if (!rhs.defined()) return true; - } else { - if (!rhs.defined()) return false; - } - // deep comparison. - return IRDeepCompare().Equal(lhs, rhs); -} - -int Compare(const PrimExpr& lhs, const PrimExpr& rhs) { - return IRDeepCompare().Compare(lhs, rhs); -} - -} // namespace tir -} // namespace tvm diff --git a/src/tir/pass/make_api.cc b/src/tir/pass/make_api.cc index 70ea2a21a869..f8eae645a044 100644 --- a/src/tir/pass/make_api.cc +++ b/src/tir/pass/make_api.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -255,7 +256,7 @@ class DeviceTypeBinder: public StmtExprMutator { // eager check NE for device check PrimExpr res = StmtExprMutator::VisitExpr_(op); op = res.as(); - if (tir::Equal(op->a, op->b)) { + if (tir::ExprDeepEqual()(op->a, op->b)) { return make_const(op->dtype, false); } return res; diff --git a/src/tir/pass/storage_rewrite.cc b/src/tir/pass/storage_rewrite.cc index 39f71dd629b4..b4e6061a35d0 100644 --- a/src/tir/pass/storage_rewrite.cc +++ b/src/tir/pass/storage_rewrite.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -311,7 +312,7 @@ class InplaceOpVerifier : public StmtExprVisitor { if (src_ == buf) { if (store_ == nullptr || store_->value.dtype() != op->dtype || - !tir::Equal(store_->index, op->index)) { + !tir::ExprDeepEqual()(store_->index, op->index)) { result_ = false; return; } } diff --git a/src/tir/pass/storage_sync.cc b/src/tir/pass/storage_sync.cc index 0f9af3ca48db..7e81ba613cda 100644 --- a/src/tir/pass/storage_sync.cc +++ b/src/tir/pass/storage_sync.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -179,8 +180,8 @@ class ThreadSyncPlanner : public StorageAccessVisitor { // TODO(tqchen) more standard set based testing. if (e.touched.is_single_point() && x.touched.is_single_point()) { - if (Equal(e.touched.point_value(), - x.touched.point_value())) continue; + if (ExprDeepEqual()(e.touched.point_value(), + x.touched.point_value())) continue; } if (x.double_buffer_write && e.type == kRead && diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 069de571dc38..324c1704aa63 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -26,11 +26,13 @@ #include #include #include +#include +#include #include #include -#include +#include namespace tvm { namespace tir { @@ -39,12 +41,6 @@ namespace tir { // These information are needed during codegen. class ContextCallCombiner final : public StmtExprMutator { public: - struct CompareExpr { - bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { - return Compare(lhs, rhs) < 0; - } - }; - PrimExpr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_thread_context)) { CHECK_EQ(op->args.size(), 1U); @@ -73,7 +69,7 @@ class ContextCallCombiner final : public StmtExprMutator { if (op->attr_key == attr::thread_extent || op->attr_key == attr::coproc_uop_scope) { // Map of comparison expression to variable - std::map temp; + std::unordered_map temp; std::swap(temp, ctx_map_); Stmt stmt = StmtExprMutator::VisitStmt_(op); std::swap(temp, ctx_map_); @@ -86,7 +82,7 @@ class ContextCallCombiner final : public StmtExprMutator { Stmt VisitStmt_(const ForNode* op) final { if (op->for_type == ForType::Parallel) { // Map of comparison expression to variable - std::map temp; + std::unordered_map temp; std::swap(temp, ctx_map_); Stmt stmt = StmtExprMutator::VisitStmt_(op); std::swap(temp, ctx_map_); @@ -101,15 +97,16 @@ class ContextCallCombiner final : public StmtExprMutator { } private: - static Stmt BuildContext(const std::map& cmap, - Stmt body) { + static Stmt BuildContext( + const std::unordered_map& cmap, + Stmt body) { for (const auto& kv : cmap) { body = LetStmtNode::make(kv.second, kv.first, body); } return body; } // Map of comparison expression to variable - std::map ctx_map_; + std::unordered_map ctx_map_; }; LoweredFunc CombineContextCall(LoweredFunc f) { diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 5176a5d6f6f6..5cb79101a05e 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -18,6 +18,7 @@ */ #include +#include #include "../src/arith/pattern_match.h" TEST(Pattern, Basic) { @@ -39,12 +40,13 @@ TEST(Pattern, Basic) { { CHECK((px + (py + px)).Match(r)); auto rr = (px + py).Eval(); - CHECK(tir::Equal(rr, 1 + y)); - CHECK(tir::Equal(px.Eval() + py.Eval(), 1 + y)); + + CHECK(tir::ExprDeepEqual()(rr, 1 + y)); + CHECK(tir::ExprDeepEqual()(px.Eval() + py.Eval(), 1 + y)); } { CHECK((px + max(py, px)).Match((x + 1) + max(y, (x + 1)))); - CHECK(tir::Equal(px.Eval(), x + 1)); + CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); } CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1)))); CHECK((px + min(py, px)).Match(z + min(y, z))); @@ -64,7 +66,7 @@ TEST(Pattern, Basic) { { CHECK(select(px >= pz, py, py + pz).Match( tir::SelectNode::make((x + 1) >= 1, y, y + 1))); - CHECK(tir::Equal(px.Eval(), x + 1)); + CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); } // bit intrinsics { @@ -90,7 +92,7 @@ TEST(Pattern, Basic) { { CHECK(select(px, py, pz).Match( tir::SelectNode::make(x > 2, y, y + 1))); - CHECK(tir::Equal(pz.Eval(), y + 1)); + CHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1)); } // if_then_else { diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index b4649a4ba75e..0dcf1fb5344c 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -23,7 +23,9 @@ def __init__(self): def verify(self, data, expected): res = self.analyzer.canonical_simplify(data) - assert tvm.tir.ir_pass.Equal(res, expected), "\ndata={}\nres={}\nexpected={}".format(data, res, expected) + expected = tvm.runtime.convert(expected) + assert tvm.ir.structural_equal( + res, expected), "\ndata={}\nres={}\nexpected={}".format(data, res, expected) def test_mul_sum_simplify(): @@ -197,7 +199,7 @@ def test_reduce_combiner_simplify(): # Check that the remaining components are the expected ones. for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]): - assert tvm.tir.ir_pass.Equal(lhs, rhs) + assert tvm.ir.structural_equal(lhs, rhs) # Test that components with side effects are not removed side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic, None, 0) diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index c6e6b753a692..278581d0cacd 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -45,7 +45,7 @@ def test_multivariate(): v = [te.var("v%d" % i) for i in range(4)] b = te.var("b") m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v) - assert(tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.Simplify(m[0]), b + 5)) + assert(tvm.tir.analysis.expr_deep_equal(tvm.tir.ir_pass.Simplify(m[0]), b + 5)) assert(m[1].value == 8) m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v) diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 8352d9cf22dd..e57dcef75994 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -28,7 +28,7 @@ def err_msg(): return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected) def equal(x, y): res = self.analyzer.canonical_simplify(x - y) - return tvm.tir.ir_pass.Equal(res, 0) + return tvm.tir.analysis.expr_deep_equal(res, 0) assert equal(res.min_value, expected[0]), err_msg() assert equal(res.max_value, expected[1]), err_msg() diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index c8c3b0bd9a3b..dbfdde3ac883 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -23,7 +23,7 @@ def __init__(self): def verify(self, data, expected): res = self.analyzer.rewrite_simplify(data) - assert tvm.tir.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected) + assert tvm.ir.structural_equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected) def test_vector_simplify(): diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 3e90442d6ee8..5a56cc332ad2 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -182,7 +182,7 @@ def fanout(n, a): assert isinstance(ir, tvm.tir.For) assert ir.loop_var.name == 'i' assert ir.min.value == 0 - assert tvm.tir.ir_pass.Equal(ir.extent, n - 3) + assert tvm.ir.structural_equal(ir.extent, n - 3) #Check loopbody ibody = ir.body assert isinstance(ibody, tvm.tir.AttrStmt) @@ -215,7 +215,7 @@ def fanout(n, a): assert value.a.args[0].value == 0 assert value.b.name == 'a' assert len(value.b.args) == 1 - assert tvm.tir.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var) + assert tvm.ir.structural_equal(value.b.args[0], ir.loop_var + jloop.loop_var) divide= rbody[2] assert isinstance(divide, tvm.tir.Provide) assert len(divide.args) == 1 diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index 28a3ae875fc7..7dceaefd9761 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -100,13 +100,14 @@ def check(factor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[z], dom_map) - assert tvm.tir.ir_pass.Equal(out_dom[z.op.axis[0]].extent, factor) - assert tvm.tir.ir_pass.Equal(out_dom[z.op.axis[0]].min, xo * factor) - assert tvm.tir.ir_pass.Equal(in_dom.items()[0][1][0].extent, factor) + assert tvm.ir.structural_equal(out_dom[z.op.axis[0]].extent, factor) + assert tvm.ir.structural_equal(out_dom[z.op.axis[0]].min, xo * factor) + assert tvm.ir.structural_equal(in_dom.items()[0][1][0].extent, factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[z], out_dom, in_dom, vadd) - assert tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(vadd.op.body[0])) + assert tvm.ir.structural_equal( + tvm.tir.ir_pass.CanonicalSimplify(body[0]), + tvm.tir.ir_pass.CanonicalSimplify(vadd.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [x, y, z]) @@ -133,13 +134,14 @@ def check(factor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.tir.ir_pass.Equal(out_dom[x].extent, 1) - assert tvm.tir.ir_pass.Equal(out_dom[y].extent, factor) - assert tvm.tir.ir_pass.Equal(out_dom[y].min, yo * factor) + assert tvm.ir.structural_equal(out_dom[x].extent, 1) + assert tvm.ir.structural_equal(out_dom[y].extent, factor) + assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) - assert tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + assert tvm.ir.structural_equal( + tvm.tir.ir_pass.CanonicalSimplify(body[0]), + tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -157,13 +159,14 @@ def check_rfactor(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.tir.ir_pass.Equal(out_dom[x].extent, 1) - assert tvm.tir.ir_pass.Equal(out_dom[y].extent, factor) - assert tvm.tir.ir_pass.Equal(out_dom[y].min, yo * factor) + assert tvm.ir.structural_equal(out_dom[x].extent, 1) + assert tvm.ir.structural_equal(out_dom[y].extent, factor) + assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) - assert tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + assert tvm.ir.structural_equal( + tvm.tir.ir_pass.CanonicalSimplify(body[0]), + tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -180,13 +183,14 @@ def check_rfactor_no_reset(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.tir.ir_pass.Equal(out_dom[x].extent, 1) - assert tvm.tir.ir_pass.Equal(out_dom[y].extent, factor) - assert tvm.tir.ir_pass.Equal(out_dom[y].min, yo * factor) + assert tvm.ir.structural_equal(out_dom[x].extent, 1) + assert tvm.ir.structural_equal(out_dom[y].extent, factor) + assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) - assert tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + assert tvm.ir.structural_equal( + tvm.tir.ir_pass.CanonicalSimplify(body[0]), + tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -204,13 +208,14 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.tir.ir_pass.Equal(out_dom[x].extent, 1) - assert tvm.tir.ir_pass.Equal(out_dom[y].extent, factor) - assert tvm.tir.ir_pass.Equal(out_dom[y].min, yo * factor) + assert tvm.ir.structural_equal(out_dom[x].extent, 1) + assert tvm.ir.structural_equal(out_dom[y].extent, factor) + assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) - assert tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + assert tvm.ir.structural_equal( + tvm.tir.ir_pass.CanonicalSimplify(body[0]), + tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) diff --git a/tests/python/unittest/test_tir_pass_equal.py b/tests/python/unittest/test_tir_analysis_expr_deep_equal.py similarity index 51% rename from tests/python/unittest/test_tir_pass_equal.py rename to tests/python/unittest/test_tir_analysis_expr_deep_equal.py index 873cb7be447c..86a1ed727805 100644 --- a/tests/python/unittest/test_tir_pass_equal.py +++ b/tests/python/unittest/test_tir_analysis_expr_deep_equal.py @@ -27,40 +27,10 @@ def func1(): def func2(): return te.exp(tvm.tir.truncdiv((x + y + 1) * y, 4)) - assert tvm.tir.ir_pass.Equal(func1(), func1()) - assert tvm.tir.ir_pass.Equal(func2(), func2()) - assert not tvm.tir.ir_pass.Equal(func2(), func1()) - - -def test_equal_compute(): - x = te.var('x') - y = te.var('y') - n = 128 - A = te.placeholder((n, n), name='A') - B = te.placeholder((n, n), name='B') - ii = te.var('i') - jj = te.var('j') - - def func1(): - k = te.reduce_axis((0, n), name='k') - return te.sum(A[ii, k] * B[jj, k], axis=k) - - Ab = tvm.tir.decl_buffer((n,), name='A') - n = te.var("n") - def func2(): - ib = tvm.tir.ir_builder.create() - A = ib.buffer_ptr(Ab) - with ib.for_range(0, n, name="i") as i: - A[i] = A[i] + 1 - with ib.for_range(0, 10, name="j") as j: - A[j] = A[j] + 2 - A[j] = A[j] + 2 - return ib.get() - - assert tvm.tir.ir_pass.Equal(func1(), func1()) - assert tvm.tir.ir_pass.Equal(func2(), func2()) + assert tvm.tir.analysis.expr_deep_equal(func1(), func1()) + assert tvm.tir.analysis.expr_deep_equal(func2(), func2()) + assert not tvm.tir.analysis.expr_deep_equal(func2(), func1()) if __name__ == "__main__": test_equal_expr() - test_equal_compute() diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 9203fb1c7b34..fe23955017a0 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -36,7 +36,7 @@ def test_buffer_access_ptr(): n = te.size_var('n') Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1 , 1]) aptr = Ab.access_ptr("rw") - assert tvm.tir.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m) + assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m) assert aptr.args[0].dtype == Ab.dtype assert aptr.args[4].value == Buffer.READ | Buffer.WRITE aptr = Ab.access_ptr("w") @@ -49,16 +49,16 @@ def test_buffer_access_ptr_offset(): Ab = tvm.tir.decl_buffer((m, n), "float32") aptr = Ab.access_ptr("rw", offset=100) offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.tir.ir_pass.Equal(offset, 100) + assert tvm.ir.structural_equal(offset, 100) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE v = te.size_var('int32') aptr = Ab.access_ptr("rw", offset=100 + 100 + v) offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.tir.ir_pass.Equal(offset, 200 + v) + assert tvm.ir.structural_equal(offset, 200 + v) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE aptr = Ab.access_ptr("rw", offset=tvm.tir.call_extern('int32', "test_call", 100 + 100 + v)) offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.tir.ir_pass.Equal(offset, tvm.tir.call_extern('int32', "test_call", 200 + v)) + assert tvm.ir.structural_equal(offset, tvm.tir.call_extern('int32', "test_call", 200 + v)) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE @@ -67,12 +67,12 @@ def test_buffer_access_ptr_extent(): n = te.size_var('n') Ab = tvm.tir.decl_buffer((m, n), "float32") aptr = Ab.access_ptr("rw") - assert tvm.tir.ir_pass.Equal(aptr.args[3], m * n) + assert tvm.ir.structural_equal(aptr.args[3], m * n) aptr = Ab.access_ptr("rw", offset=100) - assert tvm.tir.ir_pass.Equal(aptr.args[3], m * n - 100) + assert tvm.ir.structural_equal(aptr.args[3], m * n - 100) Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1 , 1]) aptr = Ab.access_ptr("rw", offset=100) - assert tvm.tir.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m - 100) + assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m - 100) def test_buffer_vload(): @@ -81,7 +81,7 @@ def test_buffer_vload(): Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) load = Ab.vload([2, 3]) offset = tvm.tir.ir_pass.Simplify(load.index) - assert tvm.tir.ir_pass.Equal(offset, n * 2 + 103) + assert tvm.ir.structural_equal(offset, n * 2 + 103) def test_buffer_index_merge_mult_mod(): @@ -93,7 +93,7 @@ def test_buffer_index_merge_mult_mod(): A = tvm.tir.decl_buffer((m, n), "float32") A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1)) def assert_simplified_equal(index_simplified, index_direct): - assert tvm.tir.ir_pass.Equal(index_simplified, index_direct),\ + assert tvm.ir.structural_equal(index_simplified, index_direct),\ "index_simplified=%s, index_direct=%s" %(index_simplified, index_direct) idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py index 23c594022faf..65d87be96181 100644 --- a/tests/python/unittest/test_tir_ops.py +++ b/tests/python/unittest/test_tir_ops.py @@ -71,7 +71,7 @@ def test_const_fold3(): for tvm_func, py_func in [(tvm.tir.all, lambda a, b: a and b), (tvm.tir.any, lambda a, b: a or b)]: for v1 in [0, 1]: for v2 in [0, 1]: - assert tvm.tir.ir_pass.Equal(tvm_func(tvm.tir.const(v1, 'uint1'), tvm.tir.const(v2, 'uint1')), + assert tvm.ir.structural_equal(tvm_func(tvm.tir.const(v1, 'uint1'), tvm.tir.const(v2, 'uint1')), tvm.tir.const(py_func(v1, v2), 'uint1')) x = te.var("x", 'uint1') @@ -170,13 +170,13 @@ def test_if_then_else(): out = tvm.tir.if_then_else(cond, lhs, rhs) out2 = tvm.tir.if_then_else(not cond, rhs, lhs) out3 = tvm.tir.if_then_else(not cond, lhs, rhs) - assert tvm.tir.ir_pass.Equal(out, out2) == 1 + assert tvm.ir.structural_equal(out, out2) == 1 if cond: - assert tvm.tir.ir_pass.Equal(out, lhs.astype(out_dtype)) == 1 - assert tvm.tir.ir_pass.Equal(out3, rhs.astype(out_dtype)) == 1 + assert tvm.ir.structural_equal(out, lhs.astype(out_dtype)) == 1 + assert tvm.ir.structural_equal(out3, rhs.astype(out_dtype)) == 1 else: - assert tvm.tir.ir_pass.Equal(out, rhs.astype(out_dtype)) == 1 - assert tvm.tir.ir_pass.Equal(out3, lhs.astype(out_dtype)) == 1 + assert tvm.ir.structural_equal(out, rhs.astype(out_dtype)) == 1 + assert tvm.ir.structural_equal(out3, lhs.astype(out_dtype)) == 1 elif cond.dtype == 'bool': out = tvm.tir.if_then_else(cond, lhs, rhs) assert out.dtype == out_dtype diff --git a/tests/python/unittest/test_tir_pass_basic.py b/tests/python/unittest/test_tir_pass_basic.py index f7eaa217683b..228e0c52c435 100644 --- a/tests/python/unittest/test_tir_pass_basic.py +++ b/tests/python/unittest/test_tir_pass_basic.py @@ -22,11 +22,11 @@ def test_simplify(): tmod = tvm.tir.truncmod x = te.var('x') e1 = tvm.tir.ir_pass.Simplify(x + 2 + 1) - assert(tvm.tir.ir_pass.Equal(e1, x + 3)) + assert(tvm.ir.structural_equal(e1, x + 3)) e2 = tvm.tir.ir_pass.Simplify(x * 3 + 5 * x) - assert(tvm.tir.ir_pass.Equal(e2, x * 8)) + assert(tvm.ir.structural_equal(e2, x * 8)) e3 = tvm.tir.ir_pass.Simplify(x - tdiv(x, 3) * 3) - assert(tvm.tir.ir_pass.Equal(e3, tmod(x, 3))) + assert(tvm.ir.structural_equal(e3, tmod(x, 3))) def test_verify_ssa(): diff --git a/tests/python/unittest/test_tir_pass_loop_partition.py b/tests/python/unittest/test_tir_pass_loop_partition.py index 7ec35e618aa3..7e383ddf7810 100644 --- a/tests/python/unittest/test_tir_pass_loop_partition.py +++ b/tests/python/unittest/test_tir_pass_loop_partition.py @@ -444,7 +444,7 @@ def test_simple_rfactor(): stmt2 = tvm.tir.ir_pass.Simplify(stmt2) #make sure loop partition actually did something - assert not tvm.tir.ir_pass.Equal(stmt1.body, stmt2.body) + assert not tvm.ir.structural_equal(stmt1.body, stmt2.body) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py index 39a3a199e127..3fcdc65c30ce 100644 --- a/tests/python/unittest/test_tir_structural_equal_hash.py +++ b/tests/python/unittest/test_tir_structural_equal_hash.py @@ -142,9 +142,34 @@ def test_attrs(): assert not consistent_equal(y, z) +def test_stmt(): + x = te.var('x') + y = te.var('y') + n = 128 + A = te.placeholder((n, n), name='A') + B = te.placeholder((n, n), name='B') + ii = te.var('i') + jj = te.var('j') + + Ab = tvm.tir.decl_buffer((n,), name='A') + n = te.var("n") + def func2(): + ib = tvm.tir.ir_builder.create() + A = ib.buffer_ptr(Ab) + with ib.for_range(0, n, name="i") as i: + A[i] = A[i] + 1 + with ib.for_range(0, 10, name="j") as j: + A[j] = A[j] + 2 + A[j] = A[j] + 2 + return ib.get() + + assert consistent_equal(func2(), func2()) + + if __name__ == "__main__": test_exprs() test_prim_func() test_attrs() test_array() test_env_func() + test_stmt() diff --git a/tests/python/unittest/test_tir_transform_prim_func_pass.py b/tests/python/unittest/test_tir_transform_prim_func_pass.py index 87aecd178909..1695cbc39dec 100644 --- a/tests/python/unittest/test_tir_transform_prim_func_pass.py +++ b/tests/python/unittest/test_tir_transform_prim_func_pass.py @@ -43,7 +43,7 @@ def transform_function(self, func, mod, ctx): mod = tvm.IRModule({"main": func}) mod = TestReplaceFunc(new_func)(mod) - assert tvm.tir.ir_pass.Equal(mod["main"].body, new_func.body) + assert tvm.ir.structural_equal(mod["main"].body, new_func.body) if __name__ == "__main__": diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index 4da11d80a483..74be9453ae61 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -114,10 +115,11 @@ inline std::vector GetConstInt64Values( * \return result True if both expressions are equal, else false */ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { - bool result = tvm::tir::Equal(lhs, rhs); + tvm::tir::ExprDeepEqual expr_equal; + bool result = expr_equal(lhs, rhs); if (!result) { PrimExpr zero(0); - result = tvm::tir::Equal(tvm::tir::CanonicalSimplify(lhs-rhs), zero); + result = expr_equal(tvm::tir::CanonicalSimplify(lhs-rhs), zero); } return result; } diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index 4f8deff285a6..5924cdd6ca82 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -83,7 +83,7 @@ def _post_order(op): fail[0] = True return op if gemm_offsets[i] is not None: - if not tvm.tir.ir_pass.Equal(m[0], gemm_offsets[i]): + if not tvm.ir.structural_equal(m[0], gemm_offsets[i]): fail[0] = True return op args.append(m[1]) @@ -775,7 +775,7 @@ def inject_alu_intrin(stmt_in): def _do_fold(stmt): def _equal(x, y): - return tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.Simplify(x - y), 0) + return tvm.ir.structural_equal(tvm.tir.ir_pass.Simplify(x - y), 0) def _flatten_loop(src_coeff, dst_coeff, extents): src_coeff = list(src_coeff) @@ -895,9 +895,9 @@ def _flatten_loop(src_coeff, dst_coeff, extents): lhs_equal = True rhs_equal = True for i, coef in enumerate(dst_coeff): - if not tvm.tir.ir_pass.Equal(coef, src_lhs_coeff[i]): + if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]): lhs_equal = False - if not tvm.tir.ir_pass.Equal(coef, src_rhs_coeff[i]): + if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]): rhs_equal = False # Make sure at least one of the source is identical to the # destination (in-place computation) @@ -916,20 +916,20 @@ def _flatten_loop(src_coeff, dst_coeff, extents): assert len(src_coeff) > 1 assert len(dst_coeff) > 1 assert len(extents) != 0 - assert tvm.tir.ir_pass.Equal( + assert tvm.ir.structural_equal( tvm.tir.ir_pass.Simplify( idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) - assert tvm.tir.ir_pass.Equal( + assert tvm.ir.structural_equal( tvm.tir.ir_pass.Simplify( idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) - assert tvm.tir.ir_pass.Equal(src_coeff[-2], 1) - assert tvm.tir.ir_pass.Equal(dst_coeff[-2], 1) + assert tvm.ir.structural_equal(src_coeff[-2], 1) + assert tvm.ir.structural_equal(dst_coeff[-2], 1) if env.BATCH > 1: assert len(src_coeff) > 2 assert len(dst_coeff) > 2 assert len(extents) > 1 - assert tvm.tir.ir_pass.Equal(src_coeff[-3], env.BLOCK_OUT) - assert tvm.tir.ir_pass.Equal(dst_coeff[-3], env.BLOCK_OUT) + assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT) + assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT) # Apply tensorization of the loop coefficients src_offset = src_coeff[-1]