From 919ae889638555b82de2d124d5f3e08d76bf789b Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Sun, 29 Mar 2020 09:58:58 -0700 Subject: [PATCH] [REFACTOR][IR] alpha_equal to structural_equal (#5161) --- include/tvm/ir/type.h | 4 +- include/tvm/relay/analysis.h | 55 -- python/tvm/ir/type.py | 3 +- python/tvm/relay/__init__.py | 1 - python/tvm/relay/analysis/analysis.py | 72 -- src/ir/module.cc | 19 +- src/relay/analysis/alpha_equal.cc | 628 ------------------ src/relay/analysis/type_solver.cc | 7 +- src/relay/backend/compile_engine.h | 3 +- src/relay/backend/vm/lambda_lift.cc | 4 +- src/relay/op/tensor/transform.cc | 7 +- src/relay/transforms/lazy_gradient_init.cc | 22 +- src/relay/transforms/pattern_util.h | 3 +- src/relay/transforms/to_cps.cc | 2 +- tests/cpp/relay_pass_alpha_equal.cc | 67 -- tests/cpp/relay_pass_type_infer_test.cc | 3 +- tests/cpp/relay_transform_sequential.cc | 3 +- tests/python/frontend/caffe2/test_graph.py | 3 +- tests/python/frontend/mxnet/test_graph.py | 2 +- .../test_analysis_extract_fused_functions.py | 2 +- tests/python/relay/test_annotate_target.py | 2 +- tests/python/relay/test_call_graph.py | 2 +- tests/python/relay/test_ir_bind.py | 4 +- tests/python/relay/test_ir_nodes.py | 3 +- .../relay/test_ir_structural_equal_hash.py | 2 +- tests/python/relay/test_ir_text_printer.py | 6 +- tests/python/relay/test_op_level10.py | 4 +- .../python/relay/test_pass_alter_op_layout.py | 46 +- tests/python/relay/test_pass_annotation.py | 12 +- .../relay/test_pass_canonicalize_cast.py | 2 +- .../test_pass_combine_parallel_conv2d.py | 8 +- .../relay/test_pass_combine_parallel_dense.py | 6 +- .../relay/test_pass_convert_op_layout.py | 22 +- .../relay/test_pass_dead_code_elimination.py | 6 +- .../test_pass_eliminate_common_subexpr.py | 4 +- tests/python/relay/test_pass_eta_expand.py | 6 +- tests/python/relay/test_pass_fold_constant.py | 14 +- .../python/relay/test_pass_fold_scale_axis.py | 22 +- tests/python/relay/test_pass_fuse_ops.py | 30 +- tests/python/relay/test_pass_gradient.py | 4 +- tests/python/relay/test_pass_inline.py | 32 +- tests/python/relay/test_pass_legalize.py | 8 +- tests/python/relay/test_pass_manager.py | 6 +- .../python/relay/test_pass_merge_composite.py | 23 +- tests/python/relay/test_pass_partial_eval.py | 9 +- .../python/relay/test_pass_partition_graph.py | 10 +- tests/python/relay/test_pass_qnn_legalize.py | 6 +- .../test_pass_remove_unused_functions.py | 2 +- .../relay/test_pass_simplify_inference.py | 4 +- .../relay/test_pass_to_a_normal_form.py | 2 +- tests/python/relay/test_type_functor.py | 4 +- tests/python/unittest/test_ir_type.py | 3 +- 52 files changed, 208 insertions(+), 1016 deletions(-) delete mode 100644 src/relay/analysis/alpha_equal.cc delete mode 100644 tests/cpp/relay_pass_alpha_equal.cc diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 6f6c66ad2107..0e65758a2e1c 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -498,7 +498,9 @@ class IncompleteTypeNode : public TypeNode { } bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const { - return equal(kind, other->kind); + return + equal(kind, other->kind) && + equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index fe8fae5ef788..51eae5a9ab7d 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -64,61 +64,6 @@ TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod); */ TVM_DLL bool ConstantCheck(const Expr& e); -/*! - * \brief Compare two expressions for structural equivalence. - * - * This comparison operator respects scoping and compares - * expressions without regard to variable choice. - * - * For example: `let x = 1 in x` is equal to `let y = 1 in y`. - * - * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence - * for more details. - * - * \param e1 The left hand expression. - * \param e2 The right hand expression. - * - * \return true if equal, otherwise false - */ -TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2); - -/*! - * \brief Compare two types for structural equivalence. - * - * This comparison operator respects scoping and compares - * expressions without regard to variable choice. - * - * For example: `forall s, Tensor[f32, s]` is equal to - * `forall w, Tensor[f32, w]`. - * - * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence - * for more details. - * - * \param t1 The left hand type. - * \param t2 The right hand type. - * - * \return true if equal, otherwise false - */ -TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); - -/*! - * \brief Compare two patterns for structural equivalence. - * - * This comparison operator respects scoping and compares - * patterns without regard to variable choice. - * - * For example: `A(x, _, y)` is equal to `A(z, _, a)`. - * - * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence - * for more details. - * - * \param t1 The left hand pattern. - * \param t2 The right hand pattern. - * - * \return true if equal, otherwise false - */ -TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2); - /*! * \brief Check that each Var is only bound once. * diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index a61c6e4d58d1..e980011319c2 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -16,6 +16,7 @@ # under the License. """Unified type system in the project.""" from enum import IntEnum +import tvm import tvm._ffi from .base import Node @@ -26,7 +27,7 @@ class Type(Node): """The base class of all types.""" def __eq__(self, other): """Compare two types for structural equivalence.""" - return bool(_ffi_api.type_alpha_equal(self, other)) + return bool(tvm.ir.structural_equal(self, other)) def __ne__(self, other): return not self.__eq__(other) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 95545c8cd559..1517cf9484cf 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -33,7 +33,6 @@ from . import transform from . import analysis -from .analysis import alpha_equal from .build_module import build, create_executor, optimize from .transform import build_config from . import debug diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 722f3b0630de..b09a40bb9957 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -220,78 +220,6 @@ def all_type_vars(expr, mod=None): return _ffi_api.all_type_vars(expr, use_mod) -def alpha_equal(lhs, rhs): - """Compare two Relay expr for structural equivalence (alpha equivalence). - - Parameters - ---------- - lhs : tvm.relay.Expr - One of the input Expression. - - rhs : tvm.relay.Expr - One of the input Expression. - - Returns - ------- - result : bool - True iff lhs is alpha equal to rhs. - """ - return bool(_ffi_api._alpha_equal(lhs, rhs)) - - -def assert_alpha_equal(lhs, rhs): - """Assert that two Relay expr is structurally equivalent. (alpha equivalence). - - Parameters - ---------- - lhs : tvm.relay.Expr - One of the input Expression. - - rhs : tvm.relay.Expr - One of the input Expression. - """ - _ffi_api._assert_alpha_equal(lhs, rhs) - - -def graph_equal(lhs, rhs): - """Compare two Relay expr for data-flow equivalence. - The difference between this and alpha-equality is that - variables are not expected to match between lhs and rhs; - they are treated as sources and are mapped between each other. - - Parameters - ---------- - lhs : tvm.relay.Expr - One of the input Expression. - - rhs : tvm.relay.Expr - One of the input Expression. - - Returns - ------- - result : bool - True iff lhs is data-flow equivalent to rhs. - """ - return bool(_ffi_api._graph_equal(lhs, rhs)) - - -def assert_graph_equal(lhs, rhs): - """Compare two Relay expr for data-flow equivalence. - The difference between this and alpha-equality is that - variables are not expected to match between lhs and rhs; - they are treated as sources and are mapped between each other. - - Parameters - ---------- - lhs : tvm.relay.Expr - One of the input Expression. - - rhs : tvm.relay.Expr - One of the input Expression. - """ - _ffi_api._assert_graph_equal(lhs, rhs) - - def collect_device_info(expr): """Collect the device allocation map for the given expression. The device ids are propagated from the `device_copy` operators. diff --git a/src/ir/module.cc b/src/ir/module.cc index de093144b38a..c7474dee95f7 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -23,6 +23,7 @@ */ #include #include +#include // NOTE: reverse dependency on relay. // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. @@ -194,12 +195,11 @@ relay::Function RunTypeCheck(const IRModule& mod, << AsText(func, false) << std::endl; } - func = - relay::Function(concat(func->params, fv), - func->body, - func->ret_type, - concat(func->type_params, ftv), - func->attrs); + func = relay::Function(concat(func->params, fv), + func->body, + func->ret_type, + concat(func->type_params, ftv), + func->attrs); // Type check the item before we add it to the module. relay::Function checked_func = InferType(func, mod, var); return checked_func; @@ -222,7 +222,7 @@ void IRModuleNode::Add(const GlobalVar& var, CHECK(update) << "Already have definition for " << var->name_hint; auto old_type = functions[var]->checked_type(); - CHECK(relay::AlphaEqual(type, old_type)) + CHECK(tvm::StructuralEqual()(type, old_type)) << "Module#update changes type, not possible in this mode."; } var->checked_type_ = type; @@ -353,9 +353,8 @@ IRModule IRModule::FromExpr( if (auto* func_node = expr.as()) { func = GetRef(func_node); } else { - func = relay::Function( - relay::FreeVars(expr), expr, Type(), - relay::FreeTypeVars(expr, mod), {}); + func = relay::Function(relay::FreeVars(expr), expr, Type(), + relay::FreeTypeVars(expr, mod), {}); } auto main_gv = GlobalVar("main"); mod->Add(main_gv, func); diff --git a/src/relay/analysis/alpha_equal.cc b/src/relay/analysis/alpha_equal.cc deleted file mode 100644 index 28c768138be3..000000000000 --- a/src/relay/analysis/alpha_equal.cc +++ /dev/null @@ -1,628 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relay/analysis/alpha_equal.cc - * \brief Alpha equality check by deep comparing two nodes. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include "../../ir/attr_functor.h" - - -namespace tvm { -namespace relay { - -// Alpha Equal handler for Relay. -class AlphaEqualHandler: - public AttrsEqualHandler, - public TypeFunctor, - public ExprFunctor, - public PatternFunctor { - public: - explicit AlphaEqualHandler(bool map_free_var, bool assert_mode) - : map_free_var_(map_free_var), assert_mode_(assert_mode) { } - - /*! - * Check equality of two nodes. - * \param lhs The left hand operand. - * \param rhs The right hand operand. - * \return The comparison result. - */ - bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) { - return VisitAttr(lhs, rhs); - } - /*! - * Check equality of two attributes. - * \param lhs The left hand operand. - * \param rhs The right hand operand. - * \return The comparison result. - */ - bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) { - auto compute = [&]() { - return VisitAttr(lhs, rhs); - }; - return Compare(compute(), lhs, rhs); - } - /*! - * Check equality of two types. - * \param lhs The left hand operand. - * \param rhs The right hand operand. - * \return the comparison result. - */ - bool TypeEqual(const Type& lhs, const Type& rhs) { - auto compute = [&]() { - if (lhs.same_as(rhs)) return true; - if (!lhs.defined() || !rhs.defined()) return false; - return this->VisitType(lhs, rhs); - }; - return Compare(compute(), lhs, rhs); - } - - bool Compare(bool result, const ObjectRef& lhs, const ObjectRef& rhs) { - if (assert_mode_) { - CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true); - } - return result; - } - /*! - * Check equality of two expressions. - * - * \note We run graph structural equality checking when comparing two Exprs. - * This means that AlphaEqualHandler can only be used once for each pair. - * The equality checker checks data-flow equvalence of the Expr DAG. - * This function also runs faster as it memomizes equal_map. - * - * \param lhs The left hand operand. - * \param rhs The right hand operand. - * \return The comparison result. - */ - bool ExprEqual(const Expr& lhs, const Expr& rhs) { - auto compute = [&]() { - if (lhs.same_as(rhs)) return true; - if (!lhs.defined() || !rhs.defined()) return false; - auto it = equal_map_.find(lhs); - if (it != equal_map_.end()) { - return it->second.same_as(rhs); - } - if (this->VisitExpr(lhs, rhs)) { - equal_map_[lhs] = rhs; - return true; - } else { - return false; - } - }; - return Compare(compute(), lhs, rhs); - } - - protected: - // So that the new definition of equality in relay can be handled directly. - // Specifically, if a DictAttr contains a value defined by a relay AST. - // We want to able to recursively check the equality in the attr defined by the relay AST. - bool VisitAttr(const ObjectRef& lhs, const ObjectRef& rhs) final { - if (lhs.same_as(rhs)) return true; - if (!lhs.defined() && rhs.defined()) return false; - if (!rhs.defined() && lhs.defined()) return false; - if (lhs->IsInstance() || rhs->IsInstance()) { - if (!rhs->IsInstance() || !lhs->IsInstance()) return false; - return TypeEqual(Downcast(lhs), Downcast(rhs)); - } - if (lhs->IsInstance() || rhs->IsInstance()) { - if (!rhs->IsInstance() || !lhs->IsInstance()) return false; - return ExprEqual(Downcast(lhs), Downcast(rhs)); - } - if (const auto lhsm = lhs.as()) { - auto rhsm = rhs.as(); - if (!rhsm) return false; - if (lhsm->functions.size() != rhsm->functions.size()) return false; - for (const auto& p : lhsm->functions) { - if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false; - } - if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; - for (const auto& p : lhsm->type_definitions) { - if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) || - !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) { - return false; - } - } - return true; - } - // Fall back to the object equal case. - return AttrsEqualHandler::VisitAttr(lhs, rhs); - } - /*! - * \brief Check if data type equals each other. - * \param lhs The left hand operand. - * \param rhs The right hand operand. - * \return The compare result. - */ - bool DataTypeEqual(const DataType& lhs, const DataType& rhs) { - return lhs == rhs; - } - /*! - * \brief Check Equality of leaf node of the graph. - * if map_free_var_ is set to true, try to map via equal node. - * \param lhs The left hand operand. - * \param rhs The right hand operand. - * \return The compare result. - */ - bool LeafObjectEqual(const ObjectRef& lhs, const ObjectRef& rhs) { - if (lhs.same_as(rhs)) return true; - auto it = equal_map_.find(lhs); - if (it != equal_map_.end()) { - return it->second.same_as(rhs); - } else { - if (map_free_var_) { - if (lhs->type_index() != rhs->type_index()) return false; - equal_map_[lhs] = rhs; - return true; - } else { - return false; - } - } - } - using AttrsEqualHandler::VisitAttr_; - bool VisitAttr_(const tvm::tir::VarNode* lhs, const ObjectRef& other) final { - return LeafObjectEqual(GetRef(lhs), other); - } - - // Type equality - bool VisitType_(const TensorTypeNode* lhs, const Type& other) final { - if (const TensorTypeNode* rhs = other.as()) { - return (lhs->dtype == rhs->dtype && - AttrEqual(lhs->shape, rhs->shape)); - } else { - return false; - } - } - - bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final { - return LeafObjectEqual(GetRef(lhs), other); - } - - bool VisitType_(const PrimTypeNode* lhs, const Type& other) final { - if (const PrimTypeNode* rhs = other.as()) { - return lhs->dtype == rhs->dtype; - } else { - return false; - } - } - - bool VisitType_(const PointerTypeNode* lhs, const Type& other) final { - if (const PointerTypeNode* rhs = other.as()) { - return TypeEqual(lhs->element_type, rhs->element_type); - } else { - return false; - } - } - - bool VisitType_(const TypeVarNode* lhs, const Type& other) final { - if (const TypeVarNode* rhs = other.as()) { - if (lhs->kind != rhs->kind) return false; - return LeafObjectEqual(GetRef(lhs), other); - } else { - return false; - } - } - - bool VisitType_(const FuncTypeNode* lhs, const Type& other) final { - if (const FuncTypeNode* rhs = other.as()) { - if (lhs->arg_types.size() != rhs->arg_types.size()) return false; - if (lhs->type_params.size() != rhs->type_params.size()) return false; - if (lhs->type_constraints.size() != rhs->type_constraints.size()) return false; - for (size_t i = 0; i < lhs->type_params.size(); ++i) { - if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) { - return false; - } - equal_map_[lhs->type_params[i]] = rhs->type_params[i]; - } - for (size_t i = 0; i < lhs->arg_types.size(); i++) { - if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false; - } - if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false; - for (size_t i = 0; i < lhs->type_constraints.size(); i++) { - if (!TypeEqual(lhs->type_constraints[i], - rhs->type_constraints[i])) { - return false; - } - } - return true; - } else { - return false; - } - } - - bool VisitType_(const TypeRelationNode* lhs, const Type& other) final { - if (const TypeRelationNode* rhs = other.as()) { - if (lhs->func->name != rhs->func->name) return false; - if (lhs->num_inputs != rhs->num_inputs) return false; - if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false; - if (lhs->args.size() != rhs->args.size()) return false; - for (size_t i = 0; i < lhs->args.size(); ++i) { - if (!TypeEqual(lhs->args[i], rhs->args[i])) return false; - } - return true; - } else { - return false; - } - } - - bool VisitType_(const TupleTypeNode* lhs, const Type& other) final { - if (const TupleTypeNode* rhs = other.as()) { - if (lhs->fields.size() != rhs->fields.size()) return false; - for (size_t i = 0; i < lhs->fields.size(); ++i) { - if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false; - } - return true; - } else { - return false; - } - } - - bool VisitType_(const RelayRefTypeNode* lhs, const Type& other) final { - if (const RelayRefTypeNode* rhs = other.as()) { - return TypeEqual(lhs->value, rhs->value); - } - return false; - } - - bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final { - return LeafObjectEqual(GetRef(lhs), other); - } - - bool VisitType_(const TypeCallNode* lhs, const Type& other) final { - const TypeCallNode* rhs = other.as(); - if (rhs == nullptr - || lhs->args.size() != rhs->args.size() - || !TypeEqual(lhs->func, rhs->func)) { - return false; - } - - for (size_t i = 0; i < lhs->args.size(); ++i) { - if (!TypeEqual(lhs->args[i], rhs->args[i])) { - return false; - } - } - return true; - } - - bool VisitType_(const TypeDataNode* lhs, const Type& other) final { - const TypeDataNode* rhs = other.as(); - if (rhs == nullptr - || lhs->type_vars.size() != rhs->type_vars.size() - || !TypeEqual(lhs->header, rhs->header)) { - return false; - } - for (size_t i = 0; i < lhs->type_vars.size(); ++i) { - if (!TypeEqual(lhs->type_vars[i], rhs->type_vars[i])) { - return false; - } - } - for (size_t i = 0; i < lhs->constructors.size(); ++i) { - if (!ExprEqual(lhs->constructors[i], rhs->constructors[i])) { - return false; - } - } - return true; - } - - // Expr equal checking. - bool NDArrayEqual(const runtime::NDArray& lhs, - const runtime::NDArray& rhs) { - if (lhs.defined() != rhs.defined()) { - return false; - } else if (lhs.same_as(rhs)) { - return true; - } else { - auto ldt = lhs->dtype; - auto rdt = rhs->dtype; - CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { - size_t data_size = runtime::GetDataSize(*lhs.operator->()); - return std::memcmp(lhs->data, rhs->data, data_size) == 0; - } else { - return false; - } - } - } - // merge declaration of two variables together. - bool MergeVarDecl(const Var& lhs, const Var& rhs) { - if (lhs.same_as(rhs)) return true; - if (!lhs.defined() || !rhs.defined()) return false; - if (!TypeEqual(lhs->type_annotation, - rhs->type_annotation)) return false; - CHECK(!equal_map_.count(lhs)) - << "Duplicated declaration of variable " << lhs; - equal_map_[lhs] = rhs; - return true; - } - - bool VisitExpr_(const VarNode* lhs, const Expr& other) final { - // This function will only be triggered if we are matching free variables. - if (const VarNode* rhs = other.as()) { - if (lhs->name_hint() != rhs->name_hint()) return false; - if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false; - return LeafObjectEqual(GetRef(lhs), other); - } else { - return false; - } - } - - bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final { - if (const GlobalVarNode* rhs = other.as()) { - // use name equality for global var for now. - return lhs->name_hint == rhs->name_hint; - } - return false; - } - - bool VisitExpr_(const TupleNode* lhs, const Expr& other) final { - if (const TupleNode* rhs = other.as()) { - if (lhs->fields.size() != rhs->fields.size()) return false; - for (size_t i = 0; i < lhs->fields.size(); ++i) { - if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false; - } - return true; - } else { - return false; - } - } - - bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final { - if (const FunctionNode* rhs = other.as()) { - if (lhs->params.size() != rhs->params.size()) return false; - if (lhs->type_params.size() != rhs->type_params.size()) return false; - // map type parameter to be the same - for (size_t i = 0; i < lhs->type_params.size(); ++i) { - if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return false; - equal_map_[lhs->type_params[i]] = rhs->type_params[i]; - } - // check parameter type annotations - for (size_t i = 0; i < lhs->params.size(); ++i) { - if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false; - } - // check return types. - if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false; - if (!AttrEqual(lhs->attrs, rhs->attrs)) return false; - return ExprEqual(lhs->body, rhs->body); - } else { - return false; - } - } - - bool VisitExpr_(const CallNode* lhs, const Expr& other) final { - if (const CallNode* rhs = other.as()) { - if (!ExprEqual(lhs->op, rhs->op)) return false; - if (lhs->args.size() != rhs->args.size()) return false; - // skip type_args check for primitive ops. - bool is_primitive = IsPrimitiveOp(lhs->op); - if (!is_primitive) { - if (lhs->type_args.size() != rhs->type_args.size()) { - return false; - } - } - for (size_t i = 0; i < lhs->args.size(); ++i) { - if (!ExprEqual(lhs->args[i], rhs->args[i])) { - return false; - } - } - - if (!is_primitive) { - for (size_t i = 0; i < lhs->type_args.size(); ++i) { - if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false; - } - } - return AttrEqual(lhs->attrs, rhs->attrs); - } else { - return false; - } - } - - bool VisitExpr_(const LetNode* lhs, const Expr& other) final { - if (const LetNode* rhs = other.as()) { - if (!MergeVarDecl(lhs->var, rhs->var)) return false; - if (!ExprEqual(lhs->value, rhs->value)) return false; - return ExprEqual(lhs->body, rhs->body); - } else { - return false; - } - } - - bool VisitExpr_(const IfNode* lhs, const Expr& other) final { - if (const IfNode* rhs = other.as()) { - return ExprEqual(lhs->cond, rhs->cond) && - ExprEqual(lhs->true_branch, rhs->true_branch) && - ExprEqual(lhs->false_branch, rhs->false_branch); - } else { - return false; - } - } - - bool VisitExpr_(const OpNode* lhs, const Expr& other) final { - return lhs == other.get(); - } - - bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final { - if (const ConstantNode* rhs = other.as()) { - return NDArrayEqual(lhs->data, rhs->data); - } else { - return false; - } - } - - bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final { - if (const TupleGetItemNode* rhs = other.as()) { - return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index; - } else { - return false; - } - } - - bool VisitExpr_(const RefCreateNode* lhs, const Expr& other) final { - if (const RefCreateNode* rhs = other.as()) { - return ExprEqual(lhs->value, rhs->value); - } else { - return false; - } - } - - bool VisitExpr_(const RefReadNode* lhs, const Expr& other) final { - if (const RefReadNode* rhs = other.as()) { - return ExprEqual(lhs->ref, rhs->ref); - } else { - return false; - } - } - - bool VisitExpr_(const RefWriteNode* lhs, const Expr& other) final { - if (const RefWriteNode* rhs = other.as()) { - return ExprEqual(lhs->ref, rhs->ref) && ExprEqual(lhs->value, rhs->value); - } else { - return false; - } - } - - bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final { - if (const ConstructorNode* rhs = other.as()) { - return lhs->name_hint == rhs->name_hint; - } - return false; - } - - bool ClauseEqual(const Clause& lhs, const Clause& rhs) { - return PatternEqual(lhs->lhs, rhs->lhs) && ExprEqual(lhs->rhs, rhs->rhs); - } - - bool PatternEqual(const Pattern& lhs, const Pattern& rhs) { - return Compare(VisitPattern(lhs, rhs), lhs, rhs); - } - - bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final { - return other.as(); - } - - bool VisitPattern_(const PatternVarNode* lhs, const Pattern& other) final { - if (const auto* rhs = other.as()) { - return MergeVarDecl(lhs->var, rhs->var); - } - return false; - } - - bool VisitPattern_(const PatternConstructorNode* lhs, const Pattern& other) final { - const auto* rhs = other.as(); - if (rhs == nullptr - || !ExprEqual(lhs->constructor, rhs->constructor) - || lhs->patterns.size() != rhs->patterns.size()) { - return false; - } - - for (size_t i = 0; i < lhs->patterns.size(); i++) { - if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) { - return false; - } - } - return true; - } - - bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final { - const auto* rhs = other.as(); - if (rhs == nullptr - || lhs->patterns.size() != rhs->patterns.size()) { - return false; - } - - for (size_t i = 0; i < lhs->patterns.size(); i++) { - if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) { - return false; - } - } - return true; - } - - bool VisitExpr_(const MatchNode* lhs, const Expr& other) final { - const MatchNode* rhs = other.as(); - - if (rhs == nullptr - || !ExprEqual(lhs->data, rhs->data) - || lhs->clauses.size() != rhs->clauses.size() - || lhs->complete != rhs->complete) { - return false; - } - - for (size_t i = 0; i < lhs->clauses.size(); ++i) { - if (!ClauseEqual(lhs->clauses[i], rhs->clauses[i])) { - return false; - } - } - return true; - } - - private: - // whether to map open terms. - bool map_free_var_; - // if in assert mode, must return true, and will throw error otherwise. - bool assert_mode_; - // renaming of NodeRef to indicate two nodes equals to each other - std::unordered_map equal_map_; -}; - -bool AlphaEqual(const Type& lhs, const Type& rhs) { - return AlphaEqualHandler(false, false).TypeEqual(lhs, rhs); -} - -bool AlphaEqual(const Expr& lhs, const Expr& rhs) { - return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs); -} - -TVM_REGISTER_GLOBAL("relay.analysis._alpha_equal") -.set_body_typed([](ObjectRef a, ObjectRef b) { - return AlphaEqualHandler(false, false).Equal(a, b); -}); - -TVM_REGISTER_GLOBAL("ir.type_alpha_equal") -.set_body_typed([](Type a, Type b) { - return AlphaEqual(a, b); -}); - -TVM_REGISTER_GLOBAL("relay.analysis._assert_alpha_equal") -.set_body_typed([](ObjectRef a, ObjectRef b) { - bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); - CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal"; -}); - -TVM_REGISTER_GLOBAL("relay.analysis._graph_equal") -.set_body_typed([](ObjectRef a, ObjectRef b) { - return AlphaEqualHandler(true, false).Equal(a, b); -}); - -TVM_REGISTER_GLOBAL("relay.analysis._assert_graph_equal") -.set_body_typed([](ObjectRef a, ObjectRef b) { - bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); - CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal"; -}); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index c39df9d50f58..650403ca5267 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -21,6 +21,7 @@ * \file type_solver.cc * \brief Type solver implementations. */ +#include #include #include #include @@ -151,11 +152,11 @@ class TypeSolver::Unifier : public TypeFunctor { return rc.Check(t); } - // default: unify only if alpha-equal + // default: unify only if structural-equal Type VisitTypeDefault_(const Object* op, const Type& tn) final { ObjectRef nr = GetRef(op); Type t1 = GetRef(nr.as()); - if (!AlphaEqual(t1, tn)) { + if (!tvm::StructuralEqual()(t1, tn)) { return Type(nullptr); } return t1; @@ -216,7 +217,7 @@ class TypeSolver::Unifier : public TypeFunctor { auto tt1 = GetRef(op); auto tt2 = GetRef(tt_node); - if (AlphaEqual(tt1, tt2)) { + if (tvm::StructuralEqual()(tt1, tt2)) { return std::move(tt1); } diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 098211e7ea86..eec2bd344f15 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -25,6 +25,7 @@ #ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ +#include #include #include #include @@ -268,7 +269,7 @@ inline bool CCacheKeyNode::Equal( const CCacheKeyNode* other) const { if (Hash() != other->Hash()) return false; return this->target->str() == other->target->str() && - AlphaEqual(this->source_func, other->source_func); + tvm::StructuralEqual()(this->source_func, other->source_func); } } // namespace relay diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 7e7622ca96cb..398760ffa789 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -22,6 +22,7 @@ * \brief Lift all local functions into global functions. */ +#include #include #include #include @@ -161,7 +162,8 @@ class LambdaLifter : public ExprMutator { if (module_->ContainGlobalVar(name)) { const auto existing_func = module_->Lookup(name); - CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision"; + CHECK(tvm::StructuralEqual()(lifted_func, existing_func)) + << "lifted function hash collision"; // If an identical function already exists, use its global var. global = module_->GetGlobalVar(name); } else { diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 3d03b4af8720..87b4602095a2 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2142,7 +2142,12 @@ Expr MakeSplit(Expr data, TVM_REGISTER_GLOBAL("relay.op._make.split") .set_body([](const TVMArgs& args, TVMRetValue* rv) { if (args.type_codes[1] == kDLInt) { - *rv = MakeSplit(args[0], tir::make_const(DataType::Int(64), int64_t(args[1])), args[2]); + // Note: we change it from Int(64) to Int(32) for now as + // combine_parallel_dense will transform the graph with Int(32). + // More invetigation is needs to check which one we should use. + *rv = MakeSplit(args[0], + tir::make_const(DataType::Int(32), static_cast(args[1])), + args[2]); } else { *rv = MakeSplit(args[0], args[1], args[2]); } diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc index ba6ca05663bb..e6248f11a00e 100644 --- a/src/relay/transforms/lazy_gradient_init.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -59,6 +59,7 @@ * Thus, it is necessary to wrap this outer function so that the input/output types remain the same */ +#include #include #include #include @@ -93,7 +94,7 @@ class InputVisitor: public ExprFunctor { Expr WrapExpr(const Expr expr, const Type& type) { if (type.as()) { return Call(module_->GetConstructor("GradCell", "Raw"), - {expr}, Attrs(), {type}); + {expr}, Attrs(), {type}); } else if (auto* type_anno = type.as()) { tvm::Array fields; for (size_t i = 0; i < type_anno->fields.size(); i++) { @@ -185,7 +186,7 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { Expr VisitExpr_(const ConstantNode* op) final { return Call(module_->GetConstructor("GradCell", "Raw"), - {GetRef(op)}, Attrs(), {op->checked_type()}); + {GetRef(op)}, Attrs(), {op->checked_type()}); } Expr VisitExpr_(const CallNode* call_node) final { @@ -207,26 +208,25 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { // call appropriate GradCell constructor std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero"; return Call(module_->GetConstructor("GradCell", constructor_name), - {func}, Attrs(), {call_node->checked_type()}); + {func}, Attrs(), {call_node->checked_type()}); } if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) { // ones_like and zeros_like need TensorType input Expr result = CallPrimitiveOp(call_node); // fn() -> T, function returns result of operation - Expr func = Function({}, result, - {call_node->checked_type()}, Array()); + Expr func = Function({}, result, {call_node->checked_type()}, Array()); // call appropriate GradCell constructor std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero"; return Call(module_->GetConstructor("GradCell", "One"), - {func}, Attrs(), {call_node->checked_type()}); + {func}, Attrs(), {call_node->checked_type()}); } // handle all other ops Expr result = CallPrimitiveOp(call_node); // wrap result with Raw constructor return Call(module_->GetConstructor("GradCell", "Raw"), {result}, - Attrs(), {call_node->checked_type()}); + Attrs(), {call_node->checked_type()}); } // not an op return ExprMutator::VisitExpr_(call_node); @@ -253,10 +253,11 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) { // can only use overloaded functions if 2 arguments of same type if (call_node->args.size() != 2 || - !AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { + !tvm::StructuralEqual()(call_node->args[0]->checked_type(), + call_node->args[1]->checked_type())) { Expr result = CallPrimitiveOp(call_node); return Call(module_->GetConstructor("GradCell", "Raw"), {result}, - Attrs(), {call_node->checked_type()}); + Attrs(), {call_node->checked_type()}); } tvm::Array args; @@ -266,8 +267,7 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { Var("rhs", paramType)}; // use primitive op in this case Expr callOp = Call(call_node->op, {params[0], params[1]}); - Expr func = Function(params, callOp, paramType, - Array()); + Expr func = Function(params, callOp, paramType, Array()); // pass "fallback" function and tensors as arguments args.push_back(func); diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index e86fcdcc23aa..8ce42a2023d8 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -27,6 +27,7 @@ #define TVM_RELAY_TRANSFORMS_PATTERN_UTIL_H_ #include +#include #include #include #include @@ -300,7 +301,7 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) { return false; } - return AlphaEqual(a, b); + return tvm::StructuralEqual()(a, b); } inline Expr GetField(Expr t, size_t i) { diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index 1039a1b6272d..e6c83928b098 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -353,7 +353,7 @@ Function UnCPS(const Function& f) { auto answer_type = new_type_params.back(); new_type_params.pop_back(); // TODO(@M.K.): make alphaequal work on free term - // CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type))); + // CHECK(tvm::StructuralEqual()(cont_type, Arrow(new_ret_type, answer_type))); auto x = Var("x", new_ret_type); auto cont = Function({x}, x, new_ret_type, {}, {}); tvm::Array args; diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc deleted file mode 100644 index 0207fca00cf7..000000000000 --- a/tests/cpp/relay_pass_alpha_equal.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include - -using namespace tvm; - -class TestAlphaEquals { - runtime::PackedFunc *_packed_func; - public: - TestAlphaEquals(const char* func_name) { - _packed_func = new runtime::PackedFunc(); - TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); - } - - void UpdatePackedFunc(const char* func_name) { - TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); - } - - bool operator()(ObjectRef input_1, ObjectRef input_2) { - TVMRetValue rv; - std::vector values(2); - std::vector codes(2); - runtime::TVMArgsSetter setter(values.data(), codes.data()); - setter(0, input_1); - setter(1, input_2); - _packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); - return bool(rv); - }; - -}; - -TEST(Relay, AlphaTestEmptyTypeNodes) { - auto x = TypeVar("x", kTypeData); - auto y = TypeVar(); - EXPECT_FALSE(relay::AlphaEqual(x, y)); - - TestAlphaEquals test_equals("relay._make._alpha_equal"); - EXPECT_FALSE(test_equals(x, y)); -} - -int main(int argc, char ** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index f951a8f386a6..3c416918e441 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include #include @@ -38,7 +39,7 @@ TEST(Relay, SelfReference) { auto type_fx = mod->Lookup("main"); auto expected = relay::FuncType(tvm::Array{ tensor_type }, tensor_type, {}, {}); - CHECK(relay::AlphaEqual(type_fx->checked_type(), expected)); + CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected)); } int main(int argc, char ** argv) { diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index 756468c9b110..d974f023d74b 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -102,7 +103,7 @@ TEST(Relay, Sequential) { auto mod1 = IRModule::FromExpr(expected_func); mod1 = relay::transform::InferType()(mod1); auto expected = mod1->Lookup("main"); - CHECK(relay::AlphaEqual(f, expected)); + CHECK(tvm::StructuralEqual()(f, expected)); } int main(int argc, char** argv) { diff --git a/tests/python/frontend/caffe2/test_graph.py b/tests/python/frontend/caffe2/test_graph.py index 35914ec1f9bf..d64b133bfd5e 100644 --- a/tests/python/frontend/caffe2/test_graph.py +++ b/tests/python/frontend/caffe2/test_graph.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Test graph equality of caffe2 models.""" +import tvm from tvm import relay from tvm.relay import transform from model_zoo import c2_squeezenet, relay_squeezenet @@ -23,7 +24,7 @@ def compare_graph(lhs_mod, rhs_mod): lhs_mod = transform.InferType()(lhs_mod) rhs_mod = transform.InferType()(rhs_mod) - assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"]) + assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"]) def test_squeeze_net(): diff --git a/tests/python/frontend/mxnet/test_graph.py b/tests/python/frontend/mxnet/test_graph.py index 0008799caebb..b7c01a5722e7 100644 --- a/tests/python/frontend/mxnet/test_graph.py +++ b/tests/python/frontend/mxnet/test_graph.py @@ -25,7 +25,7 @@ def compare_graph(lhs_mod, rhs_mod): lhs_mod = transform.InferType()(lhs_mod) rhs_mod = transform.InferType()(rhs_mod) - assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"]) + assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"]) def test_mlp(): shape = {"data": (1, 1, 28, 28)} diff --git a/tests/python/relay/test_analysis_extract_fused_functions.py b/tests/python/relay/test_analysis_extract_fused_functions.py index 1a70ef174233..dab481ccd290 100644 --- a/tests/python/relay/test_analysis_extract_fused_functions.py +++ b/tests/python/relay/test_analysis_extract_fused_functions.py @@ -77,7 +77,7 @@ def test_extract_identity(): mod["main"] = mod["main"].with_attr( "Primitive", tvm.tir.IntImm("int32", 1)) - relay.analysis.assert_graph_equal(list(items.values())[0], mod["main"]) + tvm.ir.structural_equal(list(items.values())[0], mod["main"]) def test_extract_conv_net(): diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py index f4e602a3973b..12a15dcb2c3a 100644 --- a/tests/python/relay/test_annotate_target.py +++ b/tests/python/relay/test_annotate_target.py @@ -136,7 +136,7 @@ def test_annotate(): mod = annotated(dtype, ishape, w1shape) mod = transform.AnnotateTarget("dnnl")(mod) ref_mod = expected(dtype, ishape, w1shape) - assert relay.analysis.alpha_equal(mod, ref_mod) + assert tvm.ir.structural_equal(mod, ref_mod) def test_run(): if not tvm.get_global_func("relay.ext.dnnl", True): diff --git a/tests/python/relay/test_call_graph.py b/tests/python/relay/test_call_graph.py index 849f01546788..0af55d282b8f 100644 --- a/tests/python/relay/test_call_graph.py +++ b/tests/python/relay/test_call_graph.py @@ -27,7 +27,7 @@ def test_callgraph_construct(): mod["g1"] = relay.Function([x, y], x + y) call_graph = relay.analysis.CallGraph(mod) assert "g1" in str(call_graph) - assert relay.alpha_equal(mod, call_graph.module) + assert tvm.ir.structural_equal(mod, call_graph.module) def test_print_element(): diff --git a/tests/python/relay/test_ir_bind.py b/tests/python/relay/test_ir_bind.py index 45474b6cc426..8ba4644d0436 100644 --- a/tests/python/relay/test_ir_bind.py +++ b/tests/python/relay/test_ir_bind.py @@ -29,11 +29,11 @@ def test_bind_params(): fexpected =relay.Function( [y], relay.add(relay.const(1, "float32"), y)) - assert relay.analysis.alpha_equal(fbinded, fexpected) + assert tvm.ir.structural_equal(fbinded, fexpected) zbinded = relay.bind(z, {y: x}) zexpected = relay.add(x, x) - assert relay.analysis.alpha_equal(zbinded, zexpected) + assert tvm.ir.structural_equal(zbinded, zexpected) if __name__ == "__main__": diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 968a3bbd779f..6d4a685ecebb 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -21,13 +21,12 @@ from tvm import relay from tvm.tir.expr import * from tvm.relay import op -from tvm.relay.analysis import graph_equal import numpy as np def check_json_roundtrip(node): json_str = tvm.ir.save_json(node) back = tvm.ir.load_json(json_str) - assert graph_equal(back, node) + assert tvm.ir.structural_equal(back, node, map_free_vars=True) # Span diff --git a/tests/python/relay/test_ir_structural_equal_hash.py b/tests/python/relay/test_ir_structural_equal_hash.py index cf626d702f4c..5295e17d2e8b 100644 --- a/tests/python/relay/test_ir_structural_equal_hash.py +++ b/tests/python/relay/test_ir_structural_equal_hash.py @@ -107,7 +107,7 @@ def test_func_type_sequal(): ft = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1, tvm.runtime.convert([tp1, tp3]), tvm.runtime.convert([tr1])) - translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1, + translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp2, tvm.runtime.convert([tp2, tp4]), tvm.runtime.convert([tr2])) assert ft == translate_vars diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 49518a8d198c..61dbca33ca7a 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -20,7 +20,7 @@ import tvm.relay.testing import numpy as np from tvm.relay import Expr -from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equal, free_vars +from tvm.relay.analysis import free_vars do_print = [False] @@ -32,9 +32,9 @@ def astext(p, unify_free_vars=False): return txt x = relay.fromtext(txt) if unify_free_vars: - assert_graph_equal(x, p) + tvm.ir.assert_structural_equal(x, p, map_free_vars=True) else: - assert_alpha_equal(x, p) + tvm.ir.assert_structural_equal(x, p) return txt def show(text): diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 953760c66b0a..30e25067fb01 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -99,7 +99,7 @@ def test_checkpoint_alpha_equal(): """ ) - relay.analysis.assert_alpha_equal(df, df_parsed) + tvm.ir.assert_structural_equal(df, df_parsed) def test_checkpoint_alpha_equal_tuple(): xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)] @@ -146,7 +146,7 @@ def test_checkpoint_alpha_equal_tuple(): """ ) - relay.analysis.assert_alpha_equal(df, df_parsed) + tvm.ir.assert_structural_equal(df, df_parsed) def test_collapse_sum_like(): shape = (3, 4, 5, 6) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index eabe7584f013..a30492f11634 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -66,7 +66,7 @@ def expected(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_return_none(): @@ -88,7 +88,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(before(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) assert(called[0]) @@ -151,7 +151,7 @@ def expected(): transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_dual_path(): @@ -214,7 +214,7 @@ def expected(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_resnet(): """Test alternating the layout of a residual block @@ -271,7 +271,7 @@ def expected(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_broadcast_op(): @@ -318,7 +318,7 @@ def expected(): transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_broadcast_scalar_op(): @@ -381,7 +381,7 @@ def expected(): transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_scalar(): @@ -424,7 +424,7 @@ def expected(): transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_concatenate(): @@ -478,7 +478,7 @@ def expected_nchw(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nchw(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) # NHWC layout transformation. def before_nhwc(): @@ -524,7 +524,7 @@ def expected_nhwc(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nhwc(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_nchw_upsamping_op(): @@ -561,7 +561,7 @@ def expected(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_strided_slice(): @@ -597,7 +597,7 @@ def expected(): transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_depthwise_conv2d(): """Test depthwise_conv2d operator""" @@ -632,7 +632,7 @@ def expected(): transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert(analysis.alpha_equal(a, b)) + assert(tvm.ir.structural_equal(a, b)) def test_alter_layout_prelu(): """Test PRelu operator""" @@ -672,7 +672,7 @@ def expected(): a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert(analysis.alpha_equal(a, b)) + assert(tvm.ir.structural_equal(a, b)) def test_alter_layout_pad(): @@ -715,7 +715,7 @@ def expected_nchw(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nchw(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) # Check NHWC conversion. def before_nhwc(): @@ -749,7 +749,7 @@ def expected_nhwc(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nhwc(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) # Check that conversion does not happen when padding along split axis. def before(): @@ -782,7 +782,7 @@ def expected(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_pool(): @@ -825,7 +825,7 @@ def expected_nchw(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nchw(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) # Check NHWC conversion. def before_nhwc(): @@ -859,7 +859,7 @@ def expected_nhwc(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nhwc(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_sum(): @@ -902,7 +902,7 @@ def expected_nchw(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nchw(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) # Check NHWC conversion. def before_nhwc(): @@ -937,7 +937,7 @@ def expected_nhwc(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nhwc(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) # TODO(@anijain2305, @icemelon9): We should fix this. This doesn't seem to be the @@ -999,7 +999,7 @@ def expected_nhwc(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nhwc(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_op_with_global_var(): """Test directly replacing an operator with a new one""" @@ -1041,7 +1041,7 @@ def expected(): a = transform.AlterOpLayout()(a) b = transform.InferType()(expected()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + str(a) if __name__ == "__main__": test_alter_op() diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 3e7d916c96fa..ea92546fa1d2 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -64,7 +64,7 @@ def expected(): annotated_func = annotated() expected_func = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.alpha_equal(annotated_func, expected_func) + assert tvm.ir.structural_equal(annotated_func, expected_func) def test_annotate_expr(): @@ -91,7 +91,7 @@ def expected(): annotated_expr = annotated() expected_expr = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.graph_equal(annotated_expr, expected_expr) + assert tvm.ir.structural_equal(annotated_expr, expected_expr) def test_annotate_all(): @@ -120,7 +120,7 @@ def expected(): annotated_func = annotated() expected_func = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.graph_equal(annotated_func, expected_func) + assert tvm.ir.structural_equal(annotated_func, expected_func) def test_annotate_none(): @@ -146,13 +146,13 @@ def expected(): annotated_func = annotated() expected_func = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.graph_equal(annotated_func, expected_func) + assert tvm.ir.structural_equal(annotated_func, expected_func) def check_annotated_graph(annotated_func, expected_func): annotated_func = run_opt_pass(annotated_func, transform.InferType()) expected_func = run_opt_pass(expected_func, transform.InferType()) - assert relay.analysis.alpha_equal(annotated_func, expected_func) + assert tvm.ir.structural_equal(annotated_func, expected_func) def test_conv_network(): @@ -596,7 +596,7 @@ def annotated(): annotated_func = annotated() expected_func = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.graph_equal(annotated_func, expected_func) + assert tvm.ir.structural_equal(annotated_func, expected_func) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_canonicalize_cast.py b/tests/python/relay/test_pass_canonicalize_cast.py index e9ab67ff5166..7b6617a3c7f4 100644 --- a/tests/python/relay/test_pass_canonicalize_cast.py +++ b/tests/python/relay/test_pass_canonicalize_cast.py @@ -64,7 +64,7 @@ def check(shape): mod[gv] = y_expected mod = _transform.InferType()(mod) y_expected = mod["expected"] - assert relay.analysis.alpha_equal(y, y_expected) + assert tvm.ir.structural_equal(y, y_expected) check((1, 16, 7, 7)) diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index ec9bcd9f2bc4..345f068e39d2 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -72,7 +72,7 @@ def check(x_shape, channels1, channels2, channels3, channels4): transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) y_expected = run_opt_pass(y_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y, y_expected) + assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True) check((1, 4, 16, 16), 4, 4, 4, 4) check((1, 4, 16, 16), 4, 8, 4, 7) @@ -118,7 +118,7 @@ def check(x_shape, channels1, channels2): transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2) y_expected = run_opt_pass(y_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y, y_expected) + assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True) check((1, 4, 16, 16), 4, 8) @@ -157,7 +157,7 @@ def check(x_shape, channels1, channels2): transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2) y_expected = run_opt_pass(y_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y, y_expected) + assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True) check((1, 4, 16, 16), 4, 8) @@ -193,7 +193,7 @@ def check(x_shape, repeat): transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w, out_c, repeat) y_expected = run_opt_pass(y_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y, y_expected) + assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True) check((1, 4, 16, 16), 4) diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index 84d8211666d8..f0f2e1858fb1 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -75,7 +75,7 @@ def check(i, j, k): transform.CombineParallelDense(min_num_branches=2)) y_expected = expected(x, w1, w2, w3, w4) y_expected = run_opt_pass(y_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y, y_expected) + tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True) check(3, 5, 4) check(100, 200, 300) @@ -127,7 +127,7 @@ def check(i, j, k, is_2d_bias): transform.CombineParallelDense(min_num_branches=2)) y_expected = expected(x, w1, w2, b1, b2, is_2d_bias) y_expected = run_opt_pass(y_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y, y_expected) + tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True) check(3, 5, 4, False) check(100, 200, 300, False) @@ -184,7 +184,7 @@ def check(i, j, k, scale1, scale2, newshape): transform.CombineParallelDense(min_num_branches=2)) y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape) y_expected = run_opt_pass(y_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y, y_expected) + tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True) check(3, 5, 4, 0.5, 0.25, (1, 1, 15)) check(100, 200, 300, 0.5, 0.25, (1, 1, 200)) diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 9e8f66249bd3..c783971c0568 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -52,7 +52,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout('NCHW')) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_conv_convert_layout(): @@ -87,7 +87,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout('NCHW')) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_conv_bias_pool_convert_layout(): @@ -132,7 +132,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout('NCHW')) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_conv_concat_convert_layout(): @@ -180,7 +180,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout('NCHW')) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_dual_path_convert_layout(): @@ -235,7 +235,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout('NCHW')) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_bn_convert_layout(): @@ -315,7 +315,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout('NCHW')) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_scalar_convert_layout(): @@ -347,7 +347,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout('NCHW')) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_conv_bn_convert_layout(): @@ -395,7 +395,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout('NCHW')) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_qnn_conv_requantize_convert_layout(): @@ -451,7 +451,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout('NCHW')) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_qnn_conv_concat_convert_layout(): @@ -529,7 +529,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout('NCHW')) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_qnn_conv_add_convert_layout(): @@ -609,7 +609,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout('NCHW')) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index 3a0bf1feccc2..60dfa622ba8b 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -18,7 +18,7 @@ from tvm import te from tvm import relay from tvm.relay import Function, transform -from tvm.relay.analysis import alpha_equal, graph_equal, free_vars, assert_alpha_equal +from tvm.relay.analysis import free_vars from tvm.relay.op import log, add, equal, subtract from tvm.relay.testing import inception_v3 @@ -69,7 +69,7 @@ def test_used_let(): def test_inline(): orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c)) orig = run_opt_pass(orig, transform.DeadCodeElimination(True)) - assert_alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d)) + tvm.ir.assert_structural_equal(Function(free_vars(orig), orig), Function([e.d], e.d)) def test_chain_unused_let(): @@ -105,7 +105,7 @@ def test_recursion(): orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)])) dced = run_opt_pass(orig, transform.DeadCodeElimination()) orig = run_opt_pass(orig, transform.InferType()) - assert_alpha_equal(dced, orig) + tvm.ir.assert_structural_equal(dced, orig) def test_recursion_dead(): x = relay.Let(e.a, e.one, e.three) diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py b/tests/python/relay/test_pass_eliminate_common_subexpr.py index dddbef73e564..89e3b6784a70 100644 --- a/tests/python/relay/test_pass_eliminate_common_subexpr.py +++ b/tests/python/relay/test_pass_eliminate_common_subexpr.py @@ -52,7 +52,7 @@ def expected(): z = before() z = run_opt_pass(z, transform.EliminateCommonSubexpr()) - assert analysis.alpha_equal(z, expected()) + assert tvm.ir.structural_equal(z, expected()) def test_callback(): @@ -82,7 +82,7 @@ def fskip(expr): z = before() z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip)) - assert analysis.alpha_equal(z, expected()) + assert tvm.ir.structural_equal(z, expected()) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_eta_expand.py b/tests/python/relay/test_pass_eta_expand.py index ad04e413b21b..84ff54a3b21e 100644 --- a/tests/python/relay/test_pass_eta_expand.py +++ b/tests/python/relay/test_pass_eta_expand.py @@ -47,7 +47,8 @@ def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { } } """) - relay.analysis.assert_graph_equal(mod['main'], expected['main']) + tvm.ir.assert_structural_equal(mod['main'], expected['main'], + map_free_vars=True) def test_eta_expand_constructor(): @@ -76,7 +77,8 @@ def @main[A]() -> (fn(A, List[A]) -> List[A]) { } } """) - relay.analysis.assert_graph_equal(mod['main'], expected['main']) + tvm.ir.assert_structural_equal(mod['main'], expected['main'], + map_free_vars=True) if __name__ == '__main__': diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index cc362a266aa5..3ddafd73b114 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -59,7 +59,7 @@ def fail(x): with tvm.target.create("cuda"): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.alpha_equal(zz, zexpected) + assert tvm.ir.structural_equal(zz, zexpected) def test_fold_let(): @@ -84,7 +84,7 @@ def expected(): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.graph_equal(zz, zexpected) + assert tvm.ir.structural_equal(zz, zexpected) def test_fold_tuple(): @@ -106,7 +106,7 @@ def expected(): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.graph_equal(zz, zexpected) + assert tvm.ir.structural_equal(zz, zexpected) def test_fold_concat(): @@ -125,7 +125,7 @@ def expected(): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.graph_equal(zz, zexpected) + assert tvm.ir.structural_equal(zz, zexpected) def test_fold_shape_of(): @@ -146,7 +146,7 @@ def expected(dtype): for dtype in ["int32", "float32"]: zz = run_opt_pass(before(dtype), transform.FoldConstant()) zexpected = run_opt_pass(expected(dtype), transform.InferType()) - assert relay.analysis.graph_equal(zz, zexpected) + assert tvm.ir.structural_equal(zz, zexpected) def test_fold_full(): @@ -161,7 +161,7 @@ def expected(): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.graph_equal(zz, zexpected) + assert tvm.ir.structural_equal(zz, zexpected) def test_fold_batch_norm(): @@ -202,7 +202,7 @@ def initializer(_, param): mod = remove_bn_pass(mod) expect = run_infer_type(expected()) - assert relay.analysis.graph_equal(mod["main"], expect) + assert tvm.ir.structural_equal(mod["main"], expect) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py index 4c094fb3e6e7..bf2a708ceea9 100644 --- a/tests/python/relay/test_pass_fold_scale_axis.py +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -79,7 +79,7 @@ def check(shape, channels): y1_folded = run_opt_pass(y1_folded, transform.InferType()) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y1_folded, y1_expected) + assert tvm.ir.structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 2) @@ -148,7 +148,7 @@ def check(dshape, channels): weight = relay.var("weight", type_dict["weight"]) y1_expected = expected(x, weight, in_bias, in_scale, channels) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y1_folded, y1_expected) + assert tvm.ir.structural_equal(y1_folded, y1_expected) check((2, 4, 10, 3), 3) @@ -177,7 +177,7 @@ def check(shape, channels): y1 = before(x, weight, in_bias, in_scale, channels) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) - assert relay.analysis.alpha_equal(y1, y1_folded) + assert tvm.ir.structural_equal(y1, y1_folded) check((2, 11, 10, 4), 4) @@ -205,7 +205,7 @@ def check(shape, channels, in_scale): y1 = before(x, weight, in_bias, in_scale, channels) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) - assert relay.analysis.alpha_equal(y1, y1_folded) + assert tvm.ir.structural_equal(y1, y1_folded) in_scale = relay.var("in_scale", shape=(4,)) check((2, 11, 10, 4), 4, in_scale) @@ -249,7 +249,7 @@ def check(shape, channels): y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) y1_expected = expected(x, weight, in_scale, channels) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y1_folded, y1_expected) + assert tvm.ir.structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 4) @@ -300,7 +300,7 @@ def check(shape, channels): y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_bias, out_scale, channels) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y1_folded, y1_expected) + assert tvm.ir.structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 8) @@ -359,7 +359,7 @@ def check(shape, channels): y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_bias, out_scale, channels) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y1_folded, y1_expected) + assert tvm.ir.structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 8) @@ -431,7 +431,7 @@ def check(shape, channels): y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_bias, out_scale, channels) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y1_folded, y1_expected) + assert tvm.ir.structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 4) @@ -480,7 +480,7 @@ def check(shape, channels, fbefore): y1 = fbefore(x, weight, out_bias, out_scale, channels) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - assert relay.analysis.alpha_equal(y1_folded, y1) + assert tvm.ir.structural_equal(y1_folded, y1) check((4, 4, 10, 10), 4, fail1) check((4, 4, 10, 10), 4, fail2) @@ -505,7 +505,7 @@ def check(shape, channels, out_scale): y1 = before(x, weight, out_scale, channels) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - assert relay.analysis.alpha_equal(y1, y1_folded) + assert tvm.ir.structural_equal(y1, y1_folded) out_scale = relay.var("in_scale", shape=(4, 1, 1)) check((4, 4, 10, 10), 4, out_scale) @@ -547,7 +547,7 @@ def check(shape, channels): y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_scale, channels) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert relay.analysis.alpha_equal(y1_folded, y1_expected) + assert tvm.ir.structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 8) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 108c91bd2bcb..6b7d297541c7 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -45,7 +45,7 @@ def expected(): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) zz = run_opt_pass(z, transform.FuseOps()) after = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.alpha_equal(zz, after) + assert tvm.ir.structural_equal(zz, after) def test_conv2d_fuse(): @@ -127,7 +127,7 @@ def expected(dshape): z = before(dshape) zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) after = run_opt_pass(expected(dshape), transform.InferType()) - assert relay.analysis.alpha_equal(zz, after) + assert tvm.ir.structural_equal(zz, after) def test_concatenate(): @@ -167,7 +167,7 @@ def expected(dshape): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) assert not relay.analysis.free_vars(zz) after = run_opt_pass(expected(dshape), transform.InferType()) - assert relay.analysis.alpha_equal(zz, after) + assert tvm.ir.structural_equal(zz, after) def test_tuple_root(): @@ -204,7 +204,7 @@ def expected(dshape): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) assert not relay.analysis.free_vars(zz) after = run_opt_pass(expected(dshape), transform.InferType()) - assert relay.analysis.alpha_equal(zz, after) + assert tvm.ir.structural_equal(zz, after) def test_stop_fusion(): @@ -235,7 +235,7 @@ def expected(dshape): z = before(dshape) zz = run_opt_pass(z, transform.FuseOps()) after = run_opt_pass(expected(dshape), transform.InferType()) - assert relay.analysis.alpha_equal(zz, after) + assert tvm.ir.structural_equal(zz, after) def test_fuse_myia_regression(): @@ -271,7 +271,7 @@ def expected(dshape, dtype): f = before(dshape, dtype) zz = run_opt_pass(f, transform.FuseOps()) after = run_opt_pass(expected(dshape, dtype), transform.InferType()) - assert relay.analysis.alpha_equal(zz, after) + assert tvm.ir.structural_equal(zz, after) def test_fuse_tuple_get_elemwise(): @@ -309,7 +309,7 @@ def expected(dim): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) assert not relay.analysis.free_vars(zz) after = run_opt_pass(expected(dim), transform.InferType()) - assert relay.analysis.alpha_equal(zz, after) + assert tvm.ir.structural_equal(zz, after) def test_tuple_get_root(): @@ -346,7 +346,7 @@ def expected(dim): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) assert not relay.analysis.free_vars(zz) after = run_opt_pass(expected(dim), transform.InferType()) - assert relay.analysis.alpha_equal(zz, after) + assert tvm.ir.structural_equal(zz, after) fuse0 = relay.transform.FuseOps(fuse_opt_level=0) @@ -379,7 +379,7 @@ def expected(p0): m = fuse2(tvm.IRModule.from_expr(orig)) relay.build(m, 'llvm') after = run_opt_pass(expected(x), transform.InferType()) - assert relay.analysis.alpha_equal(m["main"], after) + assert tvm.ir.structural_equal(m["main"], after) def test_tuple_consecutive(): @@ -437,7 +437,7 @@ def expected(dshape): m = fuse2(tvm.IRModule.from_expr(orig)) relay.build(m, 'llvm') after = run_opt_pass(expected(dshape), transform.InferType()) - assert relay.analysis.alpha_equal(m["main"], after) + assert tvm.ir.structural_equal(m["main"], after) def test_inception_like(): @@ -510,7 +510,7 @@ def expected(dshape): m = fuse2(tvm.IRModule.from_expr(orig)) relay.build(m, 'llvm') after = run_opt_pass(expected(dshape), transform.InferType()) - assert relay.analysis.alpha_equal(m["main"], after) + assert tvm.ir.structural_equal(m["main"], after) def test_fuse_parallel_injective(): @@ -541,7 +541,7 @@ def expected(): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) assert not relay.analysis.free_vars(zz) after = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.alpha_equal(zz, after) + assert tvm.ir.structural_equal(zz, after) def test_immutable(): @@ -570,8 +570,8 @@ def expected(): mod = before() new_mod = transform.FuseOps(fuse_opt_level=2)(mod) - assert relay.analysis.alpha_equal(mod, before()) - assert relay.analysis.alpha_equal(new_mod, expected()) + assert tvm.ir.structural_equal(mod, before()) + assert tvm.ir.structural_equal(new_mod, expected()) def test_split(): @@ -619,7 +619,7 @@ def expected(): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) zz = run_opt_pass(z, transform.FuseOps()) after = run_opt_pass(expected(), transform.InferType()) - assert relay.analysis.alpha_equal(zz, after) + assert tvm.ir.structural_equal(zz, after) if __name__ == "__main__": test_fuse_simple() diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 6f2a12589fb5..efd01cbe1a6b 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -19,7 +19,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal +from tvm.relay.analysis import free_vars, free_type_vars from tvm.relay import create_executor, transform from tvm.relay.transform import gradient from tvm.relay.prelude import Prelude @@ -292,7 +292,7 @@ def test_concat(): func = relay.Function([x], y) func = run_infer_type(func) back_func = run_infer_type(gradient(func)) - assert_alpha_equal(back_func.checked_type, relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])]))) + tvm.ir.assert_structural_equal(back_func.checked_type, relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])]))) # no value validation as concatenate has dummy gradient right now. diff --git a/tests/python/relay/test_pass_inline.py b/tests/python/relay/test_pass_inline.py index f4943ab6851b..0f6d539768fe 100644 --- a/tests/python/relay/test_pass_inline.py +++ b/tests/python/relay/test_pass_inline.py @@ -115,7 +115,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) def test_call_chain_inline_multiple_levels(): @@ -188,7 +188,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) def test_call_chain_inline_multiple_levels_extern_compiler(): @@ -266,7 +266,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) def test_recursive_call_with_global(): @@ -321,7 +321,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) def test_recursive_called(): @@ -330,7 +330,7 @@ def test_recursive_called(): mod["main"] = relay.Function([iarg], sum_up(iarg)) ref_mod = mod mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, ref_mod) + assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True) def test_recursive_not_called(): @@ -356,7 +356,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) ref_mod = expected() - assert relay.analysis.alpha_equal(mod, ref_mod) + assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True) def test_recursive_not_called_extern_compiler(): @@ -387,7 +387,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) ref_mod = expected() - assert relay.analysis.alpha_equal(mod, ref_mod) + assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True) def test_globalvar_as_call_arg(): @@ -434,7 +434,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) def test_globalvar_as_call_arg_extern_compiler(): @@ -500,7 +500,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) def test_inline_globalvar_without_args(): @@ -531,7 +531,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) def test_inline_globalvar_without_args_extern_compiler(): @@ -566,7 +566,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) def test_globalvar_called_by_multiple_functions(): @@ -644,7 +644,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) def test_entry_with_inline(): @@ -674,7 +674,7 @@ def get_mod(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, get_mod()) + assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True) def test_callee_not_inline(): @@ -707,7 +707,7 @@ def get_mod(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, get_mod()) + assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True) def test_callee_not_inline_leaf_inline(): @@ -765,7 +765,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) def test_callee_not_inline_leaf_inline_extern_compiler(): @@ -830,7 +830,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert relay.analysis.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) if __name__ == '__main__': diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index 9976eca28b29..1456700c4627 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -68,7 +68,7 @@ def expected(): a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_legalize_none(): """Test doing nothing by returning 'None' """ @@ -89,7 +89,7 @@ def legalize_conv2d(attrs, inputs, types): a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(before(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) assert(called[0]) def test_legalize_multiple_ops(): @@ -134,7 +134,7 @@ def expected(): a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_legalize_multi_input(): @@ -170,7 +170,7 @@ def expected(): a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index aed026996a21..0a6555b5c5be 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -111,7 +111,7 @@ def get_rand(shape, dtype='float32'): def check_func(func, ref_func): func = run_infer_type(func) ref_func = run_infer_type(ref_func) - assert analysis.graph_equal(func, ref_func) + assert tvm.ir.structural_equal(func, ref_func) def test_module_pass(): @@ -211,7 +211,7 @@ def transform_function(self, func, mod, ctx): mod = fpass(mod) # wrap in expr mod2 = tvm.IRModule.from_expr(f1) - assert relay.alpha_equal(mod["main"], mod2["main"]) + assert tvm.ir.structural_equal(mod["main"], mod2["main"]) def test_function_pass(): @@ -496,7 +496,7 @@ def expected(): zz = mod["main"] zexpected = run_infer_type(expected()) - assert analysis.alpha_equal(zz, zexpected) + assert tvm.ir.structural_equal(zz, zexpected) def test_print_ir(capfd): diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index 72ed3fcef6ff..3c70cf237c94 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Unit tests for merge composite.""" +import tvm from tvm import relay from tvm import tir from tvm.relay.testing import run_opt_pass @@ -192,7 +193,7 @@ def expected(): result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) expected = run_opt_pass(expected(), relay.transform.InferType()) - assert relay.analysis.alpha_equal(result, expected) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) def test_branch_merge(): @@ -270,7 +271,7 @@ def expected(): result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) expected = run_opt_pass(expected(), relay.transform.InferType()) - assert relay.analysis.alpha_equal(result, expected) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) def test_reuse_call_merge(): @@ -329,7 +330,7 @@ def expected(): result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) expected = run_opt_pass(expected(), relay.transform.InferType()) - assert relay.analysis.alpha_equal(result, expected) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) def test_multiple_patterns(): @@ -422,7 +423,7 @@ def expected(): result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) expected = run_opt_pass(expected(), relay.transform.InferType()) - assert relay.analysis.alpha_equal(result, expected) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) def test_merge_order(): @@ -494,7 +495,7 @@ def after_A_priority(composite_name): result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) expected = run_opt_pass(after_A_priority("A"), relay.transform.InferType()) - assert relay.analysis.alpha_equal(result, expected) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) # check B highest priority pattern_table = [ @@ -505,7 +506,7 @@ def after_A_priority(composite_name): result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) expected = run_opt_pass(after_A_priority("B"), relay.transform.InferType()) - assert relay.analysis.alpha_equal(result, expected) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) # check C highest priority pattern_table = [ @@ -516,7 +517,7 @@ def after_A_priority(composite_name): result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) expected = run_opt_pass(after_A_priority("C"), relay.transform.InferType()) - assert relay.analysis.alpha_equal(result, expected) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) def test_parallel_merge(): @@ -563,7 +564,7 @@ def after(): result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) expected = run_opt_pass(after(), relay.transform.InferType()) - assert relay.analysis.alpha_equal(result, expected) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) def test_multiple_input_subgraphs(): @@ -676,13 +677,13 @@ def after_B(): result = run_opt_pass(before()['A'], relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) expected = run_opt_pass(after_A(), relay.transform.InferType()) - assert relay.analysis.alpha_equal(result, expected) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) # check case 'B' result = run_opt_pass(before()['B'], relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) expected = run_opt_pass(after_B(), relay.transform.InferType()) - assert relay.analysis.alpha_equal(result, expected) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) def test_tuple_get_item_merge(): @@ -728,7 +729,7 @@ def expected(): result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) expected = run_opt_pass(expected(), relay.transform.InferType()) - assert relay.analysis.alpha_equal(result, expected) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index 1299084ef740..0f3eea663f69 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -19,7 +19,6 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.analysis import alpha_equal, assert_alpha_equal from tvm.relay.prelude import Prelude from tvm.relay import op, create_executor, transform from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate @@ -124,7 +123,7 @@ def test_ad(): body = relay.Let(x1, o, body) expected = Function([d], relay.Let(x, m, body)) expected = run_opt_pass(expected, transform.InferType()) - assert_alpha_equal(g, expected) + tvm.ir.assert_structural_equal(g, expected) def test_if_ref(): @@ -312,7 +311,7 @@ def test_concat(): x = Var("x", t) y = Var("x", t) orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0))) - assert_alpha_equal(dcpe(orig), orig) + tvm.ir.assert_structural_equal(dcpe(orig), orig) def test_triangle_number(): @@ -321,7 +320,7 @@ def test_triangle_number(): f_var = Var("f") f = Function([x], If(op.equal(x, const(0)), const(0), x + f_var(x - const(1)))) orig = run_infer_type(Let(f_var, f, f_var(const(10)))) - assert_alpha_equal(dcpe(orig), const(55)) + tvm.ir.assert_structural_equal(dcpe(orig), const(55)) def test_nat_update(): @@ -337,7 +336,7 @@ def test_tuple_match(): b = relay.Var("b") clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b) x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause]) - assert_alpha_equal(dcpe(x), const(2)) + tvm.ir.assert_structural_equal(dcpe(x), const(2)) if __name__ == '__main__': diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 1f37ab84d4a5..fc8dfb619124 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -339,7 +339,7 @@ def expected(): fused_mod = transform.FuseOps(2)(mod) expected_mod = expected() - assert relay.alpha_equal(fused_mod, expected_mod) + assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True) x_data = np.random.rand(8, 8).astype('float32') y_data = np.random.rand(8, 8).astype('float32') @@ -427,7 +427,7 @@ def get_func(): mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func()) mod = transform.PartitionGraph()(mod) - assert relay.alpha_equal(mod, expected()) + assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) ref_mod = tvm.IRModule() ref_mod["main"] = get_func() @@ -561,7 +561,7 @@ def expected(): partitioned = partition() ref_mod = expected() - assert relay.analysis.alpha_equal(partitioned, ref_mod) + assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) def test_function_lifting_inline(): @@ -631,7 +631,7 @@ def expected(): partitioned = partition() ref_mod = expected() - assert relay.analysis.alpha_equal(partitioned, ref_mod) + assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) def test_constant_propagation(): @@ -671,7 +671,7 @@ def expected(): mod = transform.PartitionGraph()(mod) expected_mod = expected() - assert relay.alpha_equal(mod, expected_mod) + assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True) y_data = np.random.rand(8, 8).astype('float32') np_add = ones + y_data diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index b1648211002c..e7980e712035 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -31,7 +31,7 @@ def alpha_equal(x, y): """ x = x['main'] y = y['main'] - return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y) + return tvm.ir.structural_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y) def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] @@ -85,12 +85,12 @@ def expected(): # Check that Relay Legalize does not change the graph. a = run_opt_pass(a, relay.transform.Legalize()) b = run_opt_pass(before(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) # Check that QNN Legalize modifies the graph. a = run_opt_pass(a, relay.qnn.transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_qnn_legalize_qnn_conv2d(): diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py index 33816344f562..43b54e9e6efe 100644 --- a/tests/python/relay/test_pass_remove_unused_functions.py +++ b/tests/python/relay/test_pass_remove_unused_functions.py @@ -110,7 +110,7 @@ def get_mod(): mod = get_mod() ref_mod = get_mod() mod = relay.transform.RemoveUnusedFunctions()(mod) - assert relay.alpha_equal(mod, ref_mod) + assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True) if __name__ == '__main__': diff --git a/tests/python/relay/test_pass_simplify_inference.py b/tests/python/relay/test_pass_simplify_inference.py index bb398939156e..3a8c90b331d1 100644 --- a/tests/python/relay/test_pass_simplify_inference.py +++ b/tests/python/relay/test_pass_simplify_inference.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from tvm.ir import IRModule +from tvm.ir import IRModule, structural_equal from tvm import relay as rly from tvm.relay.transform import SimplifyInference @@ -56,7 +56,7 @@ def check(dim, axis, nstep): mod = simplify(mod) y1 = mod["main"].body - assert rly.analysis.graph_equal(y1, y2) + assert structural_equal(y1, y2, map_free_vars=True) check(2, 1, 1) check(4, 1, 1) diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 29818f8e136a..d7babf37ed2a 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -18,7 +18,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.analysis import alpha_equal, detect_feature +from tvm.relay.analysis import detect_feature from tvm.relay import op, create_executor, transform from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count diff --git a/tests/python/relay/test_type_functor.py b/tests/python/relay/test_type_functor.py index 9e023bc6b1e4..b90a6887c18d 100644 --- a/tests/python/relay/test_type_functor.py +++ b/tests/python/relay/test_type_functor.py @@ -18,7 +18,6 @@ from tvm import te from tvm import relay from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor -from tvm.relay.analysis import assert_graph_equal from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType, TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall) from tvm.relay.adt import TypeData @@ -34,7 +33,8 @@ def check_visit(typ): ev = TypeVisitor() ev.visit(typ) - assert_graph_equal(TypeMutator().visit(typ), typ) + tvm.ir.assert_structural_equal(TypeMutator().visit(typ), typ, + map_free_vars=True) def test_type_var(): diff --git a/tests/python/unittest/test_ir_type.py b/tests/python/unittest/test_ir_type.py index f919f92aa305..a0e7d2b46ad6 100644 --- a/tests/python/unittest/test_ir_type.py +++ b/tests/python/unittest/test_ir_type.py @@ -18,10 +18,9 @@ import tvm def check_json_roundtrip(node): - from tvm.relay.analysis import graph_equal json_str = tvm.ir.save_json(node) back = tvm.ir.load_json(json_str) - assert graph_equal(back, node) + assert tvm.ir.structural_equal(back, node, map_free_vars=True) def test_prim_type():