From fa561c816d40e3fe7e16028cbf93619bcf604eb3 Mon Sep 17 00:00:00 2001
From: Yuchen Jin
Date: Sun, 5 Feb 2023 00:51:08 -0500
Subject: [PATCH] [Unity] Basic StructInfo Analysis and Expr construction
 (#13916)

[Unity] Basic StructInfo Analysis and Expr construction.

This PR adds struct info analysis and Expr support: the logic to construct
the IR nodes and to perform struct-info-related analysis. Test cases are
added to cover IR node construction and the related struct info analysis
checks.

Co-authored-by: Tianqi Chen
Co-authored-by: Altan Haan
Co-authored-by: Andrew Liu
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Jiawei Liu
Co-authored-by: Junru Shao
Co-authored-by: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com>
Co-authored-by: masahi
Co-authored-by: Prakalp Srivastava
Co-authored-by: Ruihang Lai
Co-authored-by: Siyuan Feng
Co-authored-by: Steven S.
Co-authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com>
Co-authored-by: Yixin Dong
Co-authored-by: Yong Wu
Co-authored-by: Ziheng Jiang
---
 CMakeLists.txt                                           |   1 +
 include/tvm/ir/type.h                                    |   3 +-
 include/tvm/relax/analysis.h                             | 252 ++++++
 include/tvm/relax/expr.h                                 |  43 +-
 include/tvm/relax/expr_functor.h                         | 415 ++++++++++
 include/tvm/relax/struct_info.h                          |   7 +-
 include/tvm/relax/struct_info_functor.h                  | 151 ++++
 python/tvm/ir/expr.py                                    |  11 +
 python/tvm/relax/__init__.py                             |  48 ++
 python/tvm/relax/analysis/__init__.py                    |  20 +
 python/tvm/relax/analysis/_ffi_api.py                    |  19 +
 python/tvm/relax/analysis/analysis.py                    | 135 ++++
 python/tvm/relax/expr.py                                 | 729 ++++++++++++++++++
 python/tvm/relax/struct_info.py                          | 197 +++++
 python/tvm/relax/ty.py                                   |  75 ++
 python/tvm/script/__init__.py                            |   1 +
 python/tvm/script/parser/relax/__init__.py               |  21 +
 src/ir/function.cc                                       |  14 +-
 src/ir/type.cc                                           |   3 +-
 src/relax/analysis/shape_analysis.cc                     |  55 ++
 src/relax/analysis/struct_info_analysis.cc               | 716 +++++++++++++++++
 src/relax/ir/expr.cc                                     | 601 +++++++++++++++
 src/relax/ir/expr_functor.cc                             | 546 +++++++++++++
 src/relax/ir/struct_info.cc                              |  14 +-
 src/relax/ir/struct_info_functor.cc                      | 130 ++++
 src/relax/ir/type.cc                                     |  88 +++
 tests/python/relax/test_analysis_struct_info_analysis.py | 418 ++++++++++
 tests/python/relax/test_expr.py                          | 258 +++++++
 tests/python/relax/test_struct_info.py                   | 241 ++++++
 29 files changed, 5198 insertions(+), 14 deletions(-)
 create mode 100644 include/tvm/relax/analysis.h
 create mode 100644 include/tvm/relax/expr_functor.h
 create mode 100644 include/tvm/relax/struct_info_functor.h
 create mode 100644 python/tvm/relax/analysis/__init__.py
 create mode 100644 python/tvm/relax/analysis/_ffi_api.py
 create mode 100644 python/tvm/relax/analysis/analysis.py
 create mode 100644 python/tvm/relax/expr.py
 create mode 100644 python/tvm/relax/struct_info.py
 create mode 100644 python/tvm/relax/ty.py
 create mode 100644 python/tvm/script/parser/relax/__init__.py
 create mode 100644 src/relax/analysis/shape_analysis.cc
 create mode 100644 src/relax/analysis/struct_info_analysis.cc
 create mode 100644 src/relax/ir/expr.cc
 create mode 100644 src/relax/ir/expr_functor.cc
 create mode 100644 src/relax/ir/struct_info_functor.cc
 create mode 100644 src/relax/ir/type.cc
 create mode 100644 tests/python/relax/test_analysis_struct_info_analysis.py
 create mode 100644 tests/python/relax/test_expr.py
 create mode 100644 tests/python/relax/test_struct_info.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 19f37d06f315..fa38ba6c6c8a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -290,6 +290,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
     src/support/*.cc
     src/script/*.cc
     src/relax/ir/*.cc
+    src/relax/analysis/*.cc
     src/relax/backend/vm/*.cc
     )

diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index c6baf5e08be3..ec13635a2643 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -131,8 +131,9 @@ class PrimType : public Type {
   /*!
    * \brief Constructor
    * \param dtype The corresponding dtype.
+   * \param span The span.
    */
-  TVM_DLL explicit PrimType(runtime::DataType dtype);
+  TVM_DLL explicit PrimType(runtime::DataType dtype, Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
 };

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
new file mode 100644
index 000000000000..82145032f458
--- /dev/null
+++ b/include/tvm/relax/analysis.h
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/analysis.h
+ * \brief The set of Relax-specific analyses on the IR.
+ */
+#ifndef TVM_RELAX_ANALYSIS_H_
+#define TVM_RELAX_ANALYSIS_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace tvm {
+namespace relax {
+//-----------------------------------
+// Shape expression analysis
+//----------------------------------
+/*!
+ * \brief Check if we can prove that the two symbolic shape arrays are equal to each other.
+ *
+ * \param lhs The left operand.
+ * \param rhs The right operand.
+ * \param ana The analyzer used for integer analysis.
+ * \return The proof result.
+ *
+ * \note This function does a best-effort proof, which means that even if the
+ *       result is false, the two shapes may still be equal at runtime.
+ */
+TVM_DLL bool CanProveShapeEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs,
+                                arith::Analyzer* ana);
+
+/*!
+ * \brief Check if we can prove that the two symbolic shape expressions are equal to each other.
+ *
+ * \param lhs The left operand.
+ * \param rhs The right operand.
+ * \param ana The analyzer used for integer analysis.
+ * \return The proof result.
+ *
+ * \note This function does a best-effort proof, which means that even if the
+ *       result is false, the two shapes may still be equal at runtime.
+ */
+TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana);
+
+//-----------------------------------
+// Foundational StructInfo analysis
+//-----------------------------------
+/*!
+ * \brief Get the corresponding static type from a given struct info.
+ * \param info The struct info.
+ * \return The corresponding static type.
+ */
+TVM_DLL Type GetStaticType(const StructInfo& info);
+
+/*!
+ * \brief Get the corresponding struct info from static type.
+ * \param type The input type.
+ * \return The corresponding struct info.
+ */
+TVM_DLL StructInfo StructInfoFromType(const Type& type);
+
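+/*!
+ * \note A hedged usage sketch of the struct info/static type round trip
+ *  above, through the Python bindings added in this PR (illustrative only;
+ *  the exact values are assumptions, not normative output):
+ *
+ * \code{.py}
+ *  import tvm
+ *  import tvm.relax as rx
+ *  from tvm.relax.analysis import get_static_type
+ *
+ *  sinfo = rx.TensorStructInfo([2, 3], "float32")
+ *  # The static type erases the concrete shape but keeps ndim/dtype.
+ *  tvm.ir.assert_structural_equal(get_static_type(sinfo),
+ *                                 rx.DynTensorType(2, "float32"))
+ * \endcode
+ */
+
+/*!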
+ * \brief Erase the info to a corresponding more coarse-grained
+ *        struct info that is still well-defined (with all the vars in scope).
+ *
+ * When we are returning a StructInfo to another scope,
+ * it is important to remember that the StructInfo may carry
+ * dependencies on vars that are not defined in the other scope.
+ *
+ * In such cases, it is important to call EraseToWellDefined to get
+ * another StructInfo that **only** contains the vars that are defined
+ * in the target scope.
+ *
+ * For example, consider the following function
+ *
+ * \code
+ *
+ * @R.function
+ * def f(x: R.Tensor[(n, m)]):
+ *     k = tir.Var("k", "int64")
+ *     v0 = opaque_fn(x)
+ *     v1 = match_cast(v0, R.Tensor[(n, k)])
+ *     v2 : R.Tensor[(n + 1, k + 2)] = pad(v1)
+ *     return v2
+ *
+ * \endcode
+ *
+ * In the above code, the return value v2 has shape `(n + 1, k + 2)`.
+ * However, at the level of the function signature, only n and m are defined;
+ * k is undefined here.
+ *
+ * When we call EraseToWellDefined(R.Tensor[(n + 1, k + 2)], f_shape_var_map={n: n, m: m}),
+ * we will obtain R.Tensor(ndim=2), which is an erased info that does not depend
+ * on k (which is undefined in the parameter signature).
+ *
+ * However, if we call EraseToWellDefined(R.Tensor[(n + 1, m)], f_shape_var_map={n: n, m: m}),
+ * then the return value will be R.Tensor[(n + 1, m)], because both n and m are defined.
+ *
+ * We can also make these var maps return a different expression.
+ * For example, EraseToWellDefined(R.Tensor[(n + 1, m)], f_shape_var_map={n: 2, m: m})
+ * will give us R.Tensor[(3, m)], where n gets replaced by 2.
+ *
+ * Use this function in the following scenarios:
+ * - Decide the struct_info of exprs with sub-scopes, such as If, SeqExpr
+ * - Decide the deduced return struct_info of a function that can be fully decided by params.
+ *
+ * \param info The struct info.
+ * \param f_shape_var_map callback function to specify
+ *        whether a symbolic shape var is defined and the value it maps to,
+ *        return nullopt if var is undefined.
+ * \param f_var_map callback function to specify
+ *        whether a var is defined in the target scope and the value it maps to,
+ *        return nullopt if var is undefined.
+ * \param ana Optional context analyzer to prove symbolic expression equality.
+ *
+ * \return The corresponding erased struct info.
+ */
+TVM_DLL StructInfo
+EraseToWellDefined(const StructInfo& info,
+                   std::function<Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map = nullptr,
+                   std::function<Optional<Expr>(const Var& var)> f_var_map = nullptr,
+                   arith::Analyzer* ana = nullptr);
+
+/*!
+ * \brief EraseToWellDefined variant with maps.
+ * \param info The struct info.
+ * \param shape_var_map map to specify
+ *        whether a symbolic shape var is defined and the value it maps to,
+ *        return nullopt if var is undefined.
+ * \param var_map map to specify
+ *        whether a var is defined in the target scope and the value it maps to,
+ *        return nullopt if var is undefined.
+ * \param ana Optional context analyzer to prove symbolic expression equality.
+ *
+ * \return The corresponding erased struct info.
+ */
+TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, Map<tir::Var, PrimExpr> shape_var_map,
+                                      Map<Var, Expr> var_map, arith::Analyzer* ana = nullptr);
+
+/*!
+ * \brief Fine-grained result of base check.
+ *
+ * This analysis comes with different levels of checking failures
+ * that can help to customize the compilation decisions.
+ *
+ * For a given pair of lhs_struct_info, rhs_struct_info, we adopt
+ * the following terminology:
+ * - LSet = {value | value matches lhs_struct_info}
+ * - RSet = {value | value matches rhs_struct_info}
+ *
+ * See the definition of each level below.
+ */
+enum class BaseCheckResult {
+  /*!
+   * \brief The two value sets have no intersection at all: Intersect(LSet, RSet) = empty
+   */
+  kFailL0 = 0,
+  /*!
+   * \brief LSet is not a superset of RSet by only looking at static information.
+   *
+   * \note This level will trigger a static type checking error when lhs is param and rhs is arg.
+   */
+  kFailL1 = 1,
+  /*!
+   * \brief LSet is not a superset of RSet because of a mismatch in value information.
+   *
+   * If lhs is a FuncStructInfo, then an L1-level mismatch in its params
+   * is categorized as an L2-level mismatch for lhs.
+   *
+   * Design considerations for functions:
+   * - (a) We want to be able to erase type/value in function signature
+   *       when we unify function struct info and preserve simpler representations.
+   * - (b) We automatically insert match_cast at function boundary, so
+   *       we can erase (int)->int argument as (object)->int.
+   *       The input shape/type mismatch will be detected by runtime checks at function boundary.
+   *       This behavior is also consistent with the PackedFunc behavior.
+   *
+   * \note This level means there is no problem about statically known information.
+   *       It is OK for the checker to do best effort and return this value.
+   */
+  kFailL2 = 2,
+  /*! \brief LSet is a superset of RSet. */
+  kPass = 3
+};
+
+/*!
+ * \brief Run a base check to see if base subsumes derived.
+ *
+ * This function returns a fine-grained base-check result on the reasons of failure.
+ *
+ * \param base The base struct info.
+ * \param derived The derived struct info.
+ * \param ana Optional context analyzer to prove symbolic expression equality.
+ * \return The fine-grained check result.
+ *
+ * \sa BaseCheckResult
+ */
+TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived,
+                                            arith::Analyzer* ana = nullptr);
+
+/*!
+ * \brief Check the relation of two struct infos to see if one subsumes the other.
+ *
+ * \param base The base struct info.
+ * \param derived The derived struct info.
+ * \param ana Optional context analyzer to prove symbolic expression equality.
+ * \return Whether the relation holds.
+ */
+TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
+                      arith::Analyzer* ana = nullptr);
+
+/*!
+ * \brief Unify the two struct infos to their least common ancestor.
+ *
+ * \param lhs The left operand.
+ * \param rhs The right operand.
+ * \param ana Optional context analyzer to prove symbolic expression equality.
+ * \return The unified information.
+ */
+TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
+                                 arith::Analyzer* ana = nullptr);
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_ANALYSIS_H_

diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 8154b1dd86de..9e563c7061dc 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -23,7 +23,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -35,7 +34,47 @@ namespace relax {
 using Expr = RelayExpr;
 using ExprNode = RelayExprNode;

-using relay::Id;
+/*!
+ * \brief The unique identifier of variables.
+ *
+ * Id is like a name for the variables,
+ * except that an Id is unique for each Var.
+ *
+ * \note Do not create Id directly; they are created in Var.
+ */
+class IdNode : public Object {
+ public:
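+  /*!
+   * \note Ids are created implicitly when a Var is constructed. A hedged
+   *  Python sketch of the intended identity semantics (illustrative only):
+   *
+   * \code{.py}
+   *  import tvm.relax as rx
+   *
+   *  x0 = rx.Var("x")
+   *  x1 = rx.Var("x")
+   *  # Same name hint, but distinct identities: equality is keyed on vid.
+   *  assert x0.name_hint == x1.name_hint
+   *  assert not x0.vid.same_as(x1.vid)
+   * \endcode
+   */
+
+  /*!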
+   * \brief The name of the variable,
+   *  this only acts as a hint to the user,
+   *  and is not used for equality.
+   */
+  String name_hint;
+
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); }
+
+  bool SEqualReduce(const IdNode* other, SEqualReducer equal) const {
+    return equal.FreeVarEqualImpl(this, other);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.FreeVarHashImpl(this); }
+
+  static constexpr const char* _type_key = "relax.Id";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object);
+};
+
+class Id : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param name_hint The name of the variable.
+   */
+  TVM_DLL explicit Id(String name_hint);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
+};

 /*!
  * \brief Base type of all structure information.

diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h
new file mode 100644
index 000000000000..5735e8661f6f
--- /dev/null
+++ b/include/tvm/relax/expr_functor.h
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/expr_functor.h
+ * \brief A more powerful visitor which enables defining arbitrary function
+ * signatures with type-based dispatch on the first argument.
+ */
+#ifndef TVM_RELAX_EXPR_FUNCTOR_H_
+#define TVM_RELAX_EXPR_FUNCTOR_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief A dynamic functor that dispatches on the first Expr argument.
+ * You can use this as a more powerful visitor, since it allows you to
+ * define the signatures of the visit functions.
+ *
+ * \sa tvm/ir_functor.h
+ *
+ * \tparam FType function signature
+ * This type is only defined for FType with function signature R(const Expr&,
+ * Args...)
+ */
+template <typename FType>
+class ExprFunctor;
+
+// functions to be overridden.
+#define EXPR_FUNCTOR_DEFAULT \
+  { return VisitExprDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAX_EXPR_FUNCTOR_DISPATCH(OP) \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
+    return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+  });
+
+#define PY_EXPR_VISITOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC) \
+  { \
+    if (PY_FUNC != nullptr) \
+      PY_FUNC(N); \
+    else \
+      DEFAULT_FUNC; \
+  }
+
+#define PY_EXPR_MUTATOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC, RET_TYPE) \
+  { \
+    if (PY_FUNC != nullptr) { \
+      RET_TYPE ret = PY_FUNC(N); \
+      return ret; \
+    } else { \
+      return DEFAULT_FUNC; \
+    } \
+  }
+
+#define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
+    if (self->PY_FUNC != nullptr) \
+      self->PY_FUNC(n); \
+    else \
+      self->VisitExpr_(static_cast<const OP*>(n.get())); \
+  });
+
+#define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC) \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
+    if (self->PY_FUNC != nullptr) { \
+      Expr expr = self->PY_FUNC(n); \
+      return expr; \
+    } else { \
+      return self->VisitExpr_(static_cast<const OP*>(n.get())); \
+    } \
+  });
+
+#define PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OP) \
+  post_order_vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
+    return self->VisitExprPostOrder_(static_cast<const OP*>(n.get())); \
+  });
+
+template <typename R, typename... Args>
+class ExprFunctor<R(const Expr&, Args...)> {
+ private:
+  using TSelf = ExprFunctor<R(const Expr&, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~ExprFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitExpr(const Expr& n, Args... args) {
+    ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
+                           "have generated invalid data.";
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overridden by subclass.
+  // NOTE: cross dialect calls are invoked through global var.
+  // We do not expect inline PrimFunc to appear in relax IR.
+  virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const DataflowVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ShapeExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ExternFuncNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const SeqExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const PrimValueNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const DataTypeImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExprDefault_(const Object* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+    throw;
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAX_EXPR_FUNCTOR_DISPATCH(ConstantNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(TupleNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(VarNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(DataflowVarNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(ShapeExprNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(ExternFuncNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(IfNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(OpNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(PrimValueNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(StringImmNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(DataTypeImmNode);
+    return vtable;
+  }
+};
+
+/*!
+ * \brief A simple visitor wrapper around ExprFunctor.
+ * Recursively visit the content.
+ */
+class ExprVisitor : public ExprFunctor<void(const Expr&)> {
+ public:
+  /*!
+   * \brief Generic dispatcher for Expr.
+   * \param expr The expr to be visited.
+   */
+  void VisitExpr(const Expr& expr) override;
+  // specific leaf level visitor functions
+  void VisitExpr_(const ConstantNode* op) override;
+  void VisitExpr_(const TupleNode* op) override;
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const DataflowVarNode* op) override;
+  void VisitExpr_(const ShapeExprNode* op) override;
+  void VisitExpr_(const ExternFuncNode* op) override;
+  void VisitExpr_(const GlobalVarNode* op) override;
+  void VisitExpr_(const FunctionNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const SeqExprNode* op) override;
+  void VisitExpr_(const IfNode* op) override;
+  void VisitExpr_(const OpNode* op) override;
+  void VisitExpr_(const TupleGetItemNode* op) override;
+  void VisitExpr_(const PrimValueNode* op) override;
+  void VisitExpr_(const StringImmNode* op) override;
+  void VisitExpr_(const DataTypeImmNode* op) override;
+
+  /*!
+   * \brief Generic dispatcher for bindings.
+   * \param binding The binding to be visited.
+   */
+  virtual void VisitBinding(const Binding& binding);
+  // specific leaf level visitor functions
+  virtual void VisitBinding_(const VarBindingNode* binding);
+  virtual void VisitBinding_(const MatchCastNode* binding);
+  // second level dispatching based on binding value type.
+  // these dispatching functions get called from first-level dispatch on VarBinding
+  virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val);
+  virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val);
+  /*!
+   * \brief Generic dispatcher for binding blocks.
+   * \param block The binding block to be visited.
+   */
+  virtual void VisitBindingBlock(const BindingBlock& block);
+  // specific leaf level visitor functions
+  virtual void VisitBindingBlock_(const BindingBlockNode* block);
+  virtual void VisitBindingBlock_(const DataflowBlockNode* block);
+
+  /*!
+   * \brief Generic dispatcher for visiting the var definition site.
+   * \param var The var to be visited.
+   * \note VisitExpr_(const VarNode*) will only visit the usage site of a Var.
+   */
+  virtual void VisitVarDef(const Var& var);
+
+  /*!
+   * \brief Visit struct_info that may recursively contain Expr/PrimExpr.
+   *
+   * By default, this function recurses into struct info such as
+   * TensorStructInfo and ShapeStructInfo and calls VisitExpr/VisitPrimExpr
+   * accordingly. It does not recurse into FunctionStructInfo as it does
+   * not contain Expr defined in the current scope.
+   *
+   * Pass writers can overload this function to change the behavior.
+   * For example, if we are not interested in Expr in StructInfo, we can
+   * override this function with a no-op.
+   *
+   * \param struct_info Input struct info field.
+   */
+  virtual void VisitExprDepStructInfoField(const StructInfo& struct_info);
+
+  // specific leaf level visitor functions
+  virtual void VisitVarDef_(const VarNode* var);
+  virtual void VisitVarDef_(const DataflowVarNode* var);
+
+  virtual void VisitSpan(const Span& span);
+  virtual void VisitPrimExpr(const PrimExpr& expr);
+
+ private:
+  using TSelf = ExprVisitor;
+  using VisitBindingVTable =
+      tvm::NodeFunctor<void(const ObjectRef& n, ExprVisitor* self)>;
+  // initialize the vtable.
+  static VisitBindingVTable InitVisitBindingVTable();
+  /*!
+   * \brief Private internal struct info field visitor.
+   *
+   * Supports default visiting of struct info fields and recursing into
+   * their Expr fields.
+   *
+   * We use composition instead of sub-classing so there can be other
+   * joint inheritance between ExprVisitor and StructInfoVisitor.
+   */
+  class DefaultStructInfoFieldVisitor : public StructInfoVisitor {
+   public:
+    explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent);
+
+    // Override defaults in struct info visitor.
+    void VisitStructInfoExprField(const Expr& expr) final;
+    void VisitStructInfoExprField(const PrimExpr& expr) final;
+    void VisitStructInfo_(const FuncStructInfoNode* op) final;
+
+   private:
+    ExprVisitor* parent_;
+  };
+  // This visitor is not visible to child classes and is only
+  // used to support the default visiting behavior.
+  DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this};
+};
+
+void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
+
+/*!
+ * \brief A mutator that works on the unnormalized form.
+ *
+ * ExprMutatorBase expects the input AST to be in the unnormalized form, i.e., checked_type_ and
+ * shape_ of expressions can be nullptr, and the expressions may nest (and as a result the AST is
+ * not in ANF).
+ */
+
+class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {
+ public:
+  Expr VisitExpr(const Expr& expr) override;
+  Expr VisitExpr_(const ConstantNode* op) override;
+  Expr VisitExpr_(const TupleNode* op) override;
+  Expr VisitExpr_(const VarNode* op) override;
+  Expr VisitExpr_(const DataflowVarNode* op) override;
+  Expr VisitExpr_(const ShapeExprNode* op) override;
+  Expr VisitExpr_(const ExternFuncNode* op) override;
+  Expr VisitExpr_(const GlobalVarNode* op) override;
+  Expr VisitExpr_(const FunctionNode* op) override;
+  Expr VisitExpr_(const CallNode* op) override;
+  Expr VisitExpr_(const SeqExprNode* op) override;
+  Expr VisitExpr_(const IfNode* op) override;
+  Expr VisitExpr_(const OpNode* op) override;
+  Expr VisitExpr_(const TupleGetItemNode* op) override;
+  Expr VisitExpr_(const PrimValueNode* op) override;
+  Expr VisitExpr_(const StringImmNode* op) override;
+  Expr VisitExpr_(const DataTypeImmNode* op) override;
+
+  /*!
+   * \brief Mutate BindingBlock.
+   * \param block The binding block to be visited.
+   * \return The binding block after transformation.
+   */
+  virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
+
+  /*!
+   * \brief Used to visit the PrimExpr inside of expressions.
+   *
+   * Can be overloaded to transform the shape expressions.
+   */
+  virtual PrimExpr VisitPrimExpr(const PrimExpr& expr);
+
+  /*!
+   * \brief Visit struct_info that may recursively contain Expr/PrimExpr.
+   *
+   * By default, this function recurses into struct info such as
+   * TensorStructInfo and ShapeStructInfo and calls VisitExpr/VisitPrimExpr
+   * accordingly. It does not recurse into FunctionStructInfo as it does
+   * not contain Expr defined in the current scope.
+   *
+   * Pass writers can overload this function to change the behavior.
+   * For example, if the Expr in StructInfo won't change, we can
+   * override this function with an identity function.
+   *
+   * \param struct_info Input struct info field.
+   * \return The updated struct info.
+   */
+  virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info);
+
+ protected:
+  /*!
+   * \brief Check whether VisitExprDepStructInfoField changes struct_info.
+   * \return Whether the struct info changed.
+   * \note This function is used by mutator implementations to check if
+   *       a previous Expr update will trigger a change in struct_info.
+   *       If a change is detected, the implementation can generate a fresh
+   *       node without struct_info, and trigger the normalizer to re-derive it.
+   */
+  bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef& struct_info) {
+    if (const StructInfoNode* sinfo = struct_info.as<StructInfoNode>()) {
+      return this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)).same_as(struct_info);
+    } else {
+      return true;
+    }
+  }
+
+ private:
+  /*!
+   * \brief Private internal struct info field mutator that supports
+   * default visiting of struct info fields and recursing into their Expr fields.
+   *
+   * We use composition instead of sub-classing so there can be other
+   * joint inheritance between ExprMutator and StructInfoMutator.
+   */
+  class DefaultStructInfoFieldMutator : public StructInfoMutator {
+   public:
+    explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent);
+
+    // Override defaults in struct info mutator.
+    Expr VisitStructInfoExprField(const Expr& expr) final;
+    PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final;
+    StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final;
+
+   private:
+    ExprMutatorBase* parent_;
+  };
+  // This mutator is not visible to child classes and is only
+  // used to support the default visiting behavior.
+  DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this};
+};
+
+}  // namespace relax
+}  // namespace tvm
+#endif  // TVM_RELAX_EXPR_FUNCTOR_H_

diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index d21c8db86b3f..f38a32f6bb83 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -22,13 +22,16 @@
 #include
 #include
 #include
-// #include
 #include
 #include

 namespace tvm {
 namespace relax {

+// TODO(relax-team) replace with real BlockBuilder
+// once it is ready.
+using BlockBuilder = ObjectRef;
+
 /*!
  * \brief Opaque object.
  */
@@ -257,8 +260,6 @@ class TupleStructInfo : public StructInfo {
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode);
 };

-class BlockBuilder;
-
 /*!
  * \brief custom-defined StructInfo derivation function.
  * \param call The call expression to be derived.

diff --git a/include/tvm/relax/struct_info_functor.h b/include/tvm/relax/struct_info_functor.h
new file mode 100644
index 000000000000..382b4ab2c936
--- /dev/null
+++ b/include/tvm/relax/struct_info_functor.h
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/struct_info_functor.h
+ * \brief Functors and visitors for struct info.
+ */
+#ifndef TVM_RELAX_STRUCT_INFO_FUNCTOR_H_
+#define TVM_RELAX_STRUCT_INFO_FUNCTOR_H_
+
+#include
+#include
+
+#include
+
+namespace tvm {
+namespace relax {
+
+template <typename FType>
+class StructInfoFunctor;
+
+// functions to be overridden.
+#define STRUCT_INFO_FUNCTOR_DEFAULT \
+  { return VisitStructInfoDefault_(op, std::forward<Args>(args)...); }
+
+#define TVM_STRUCT_INFO_FUNCTOR_DISPATCH(OP) \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
+    return self->VisitStructInfo_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+  });
+
+template <typename R, typename... Args>
+class StructInfoFunctor<R(const StructInfo& n, Args...)> {
+ private:
+  using TSelf = StructInfoFunctor<R(const StructInfo& n, Args...)>;
+  using FStructInfo = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~StructInfoFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const StructInfo& n, Args... args) {
+    return VisitStructInfo(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitStructInfo(const StructInfo& n, Args... args) {
+    ICHECK(n.defined());
+    static FStructInfo vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overridden by subclass.
+  virtual R VisitStructInfo_(const ObjectStructInfoNode* op,
+                             Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
+  virtual R VisitStructInfo_(const PrimStructInfoNode* op,
+                             Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
+  virtual R VisitStructInfo_(const ShapeStructInfoNode* op,
+                             Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
+  virtual R VisitStructInfo_(const TensorStructInfoNode* op,
+                             Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
+  virtual R VisitStructInfo_(const TupleStructInfoNode* op,
+                             Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
+  virtual R VisitStructInfo_(const FuncStructInfoNode* op,
+                             Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
+  virtual R VisitStructInfoDefault_(const Object* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+    throw;  // unreachable, written to stop compiler warning
+  }
+
+ private:
+  // initialize the vtable.
+  static FStructInfo InitVTable() {
+    FStructInfo vtable;
+    // Set dispatch
+    TVM_STRUCT_INFO_FUNCTOR_DISPATCH(ObjectStructInfoNode);
+    TVM_STRUCT_INFO_FUNCTOR_DISPATCH(PrimStructInfoNode);
+    TVM_STRUCT_INFO_FUNCTOR_DISPATCH(ShapeStructInfoNode);
+    TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TensorStructInfoNode);
+    TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TupleStructInfoNode);
+    TVM_STRUCT_INFO_FUNCTOR_DISPATCH(FuncStructInfoNode);
+    return vtable;
+  }
+};
+
+#undef TVM_STRUCT_INFO_FUNCTOR_DISPATCH
+
+/*!
+ * \brief A struct info visitor.
+ */
+class TVM_DLL StructInfoVisitor : public StructInfoFunctor<void(const StructInfo& n)> {
+ public:
+  void VisitStructInfo_(const ObjectStructInfoNode* op) override;
+  void VisitStructInfo_(const PrimStructInfoNode* op) override;
+  void VisitStructInfo_(const ShapeStructInfoNode* op) override;
+  void VisitStructInfo_(const TensorStructInfoNode* op) override;
+  void VisitStructInfo_(const TupleStructInfoNode* op) override;
+  void VisitStructInfo_(const FuncStructInfoNode* op) override;
+
+ protected:
+  // two functions to override when visiting expr fields in struct info.
+  virtual void VisitStructInfoExprField(const Expr& expr) {}
+  virtual void VisitStructInfoExprField(const PrimExpr& expr) {}
+};
+
+/*!
+ * \brief StructInfoMutator that mutates struct info.
+ */
+class TVM_DLL StructInfoMutator : public StructInfoFunctor<StructInfo(const StructInfo& n)> {
+ public:
+  StructInfo VisitStructInfo_(const ObjectStructInfoNode* op) override;
+  StructInfo VisitStructInfo_(const PrimStructInfoNode* op) override;
+  StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) override;
+  StructInfo VisitStructInfo_(const TensorStructInfoNode* op) override;
+  StructInfo VisitStructInfo_(const TupleStructInfoNode* op) override;
+  StructInfo VisitStructInfo_(const FuncStructInfoNode* op) override;
+
+ protected:
+  // two functions to override when visiting expr fields in struct info.
+  virtual Expr VisitStructInfoExprField(const Expr& expr) { return expr; }
+  virtual PrimExpr VisitStructInfoExprField(const PrimExpr& expr) { return expr; }
+};
+
+}  // namespace relax
+}  // namespace tvm
+#endif  // TVM_RELAX_STRUCT_INFO_FUNCTOR_H_

diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index 3c3fefb6d6c6..f90468de66c6 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -51,6 +51,17 @@ def checked_type(self):
             raise ValueError("The type checker has not populated" " the checked_type for this node")
         return ret

+    @property
+    def struct_info(self) -> "tvm.relax.StructInfo":
+        """Get the struct info field.
+
+        Returns
+        -------
+        struct_info : tvm.relax.StructInfo
+            The struct info if available.
+        """
+        return _ffi_api.ExprStructInfo(self)
+

 @tvm._ffi.register_object("GlobalVar")
 class GlobalVar(RelayExpr):

diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index c070fa479188..01310f6455dd 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -17,8 +17,56 @@
 # pylint: disable=invalid-name, wrong-import-position
 """The Relax IR namespace containing the IR, type, operator, builder, vm, etc."""
 from . import exec_builder
+from . import expr
+from . import ty
+from . import analysis
 from . import vm
+from . import struct_info
+
+# Expr
+from .expr import (
+    Expr,
+    Span,
+    SourceName,
+    Id,
+    GlobalVar,
+    Var,
+    DataflowVar,
+    Binding,
+    MatchCast,
+    VarBinding,
+    BindingBlock,
+    DataflowBlock,
+    SeqExpr,
+    ShapeExpr,
+    Tuple,
+    TupleGetItem,
+    Function,
+    ExternFunc,
+    Call,
+    If,
+    Constant,
+    PrimValue,
+    DataTypeImm,
+    StringImm,
+)
+
+from .expr import const, extern, get_shape_of
+
+# Type
+from .ty import Type, ObjectType, ShapeType, DynTensorType, TupleType, FuncType, PackedFuncType

 # VM
 from .exec_builder import ExecBuilder
 from .vm import VirtualMachine
+
+# StructInfo
+from .struct_info import (
+    StructInfo,
+    ObjectStructInfo,
+    PrimStructInfo,
+    ShapeStructInfo,
+    TensorStructInfo,
+    TupleStructInfo,
+    FuncStructInfo,
+)

diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py
new file mode 100644
index 000000000000..cc0089ff3134
--- /dev/null
+++ b/python/tvm/relax/analysis/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import, redefined-builtin
+"""Relax IR analysis."""
+
+from .analysis import *

diff --git a/python/tvm/relax/analysis/_ffi_api.py b/python/tvm/relax/analysis/_ffi_api.py
new file mode 100644
index 000000000000..40ee05c3960d
--- /dev/null
+++ b/python/tvm/relax/analysis/_ffi_api.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs"""
+import tvm._ffi
+
+tvm._ffi._init_api("relax.analysis", __name__)

diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py
new file mode 100644
index 000000000000..301f3ecc7265
--- /dev/null
+++ b/python/tvm/relax/analysis/analysis.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return
+# pylint: disable=unidiomatic-typecheck
+"""
+This file contains the set of passes for Relax, which exposes an interface for
+configuring the passes and scripting them in Python.
+"""
+
+from typing import Dict
+from enum import IntEnum
+
+from tvm import tir
+from tvm.relax.ty import Type
+from tvm.relax.struct_info import StructInfo
+from tvm.relax.expr import Var, Expr
+from . import _ffi_api
+
+
+def get_static_type(sinfo: StructInfo) -> Type:
+    """Get the corresponding static type from a StructInfo.
+
+    Parameters
+    ----------
+    sinfo : StructInfo
+        The input struct info.
+
+    Returns
+    -------
+    ret : Type
+        The corresponding static type.
+    """
+    return _ffi_api.GetStaticType(sinfo)  # type: ignore
+
+
+def erase_to_well_defined(
+    sinfo: StructInfo,
+    shape_var_map: Dict[tir.Var, tir.PrimExpr] = None,
+    var_map: Dict[Var, Expr] = None,
+) -> StructInfo:
+    """Erase sinfo into a well-defined form.
+
+    This function removes the StructInfo's dependencies on shape and vars that
+    are not defined in the given maps.
+
+    Parameters
+    ----------
+    sinfo : StructInfo
+        The input struct info.
+
+    shape_var_map : Dict[tir.Var, tir.PrimExpr]
+        Specifies the defined shape vars and the values they should map to.
+
+    var_map : Dict[Var, Expr]
+        Specifies the defined vars and the values they should map to.
+
+    Returns
+    -------
+    ret : StructInfo
+        The corresponding erased struct info.
+    """
+    shape_var_map = {} if shape_var_map is None else shape_var_map
+    var_map = {} if var_map is None else var_map
+
+    return _ffi_api.EraseToWellDefined(sinfo, shape_var_map, var_map)  # type: ignore
+
+
+class BaseCheckResult(IntEnum):
+    """Return result of fine-grained base check.
+
+    Note
+    ----
+    Base check comes with fine-grained fail levels.
+
+    - FAIL_L0: The lhs and rhs have no intersection at all.
+    - FAIL_L1: We get the failure by looking at static information.
+    - FAIL_L2: We get the failure due to unknown symbolic variable relations.
+    """
+
+    FAIL_L0 = 0
+    FAIL_L1 = 1
+    FAIL_L2 = 2
+    PASS = 3
+
+
+def struct_info_base_check(base: StructInfo, derived: StructInfo) -> BaseCheckResult:
+    """Run a base check to see if base subsumes derived.
+
+    Parameters
+    ----------
+    base: StructInfo
+        The base struct info.
+
+    derived: StructInfo
+        The derived struct info.
+
+    Returns
+    -------
+    ret : BaseCheckResult
+        The fine-grained check result.
+    """
+    return _ffi_api.StructInfoBaseCheck(base, derived)  # type: ignore
+
+
+def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo:
+    """Unify the two struct info to their least common ancestor.
+
+    Parameters
+    ----------
+    lhs: StructInfo
+        The left operand.
+
+    rhs: StructInfo
+        The right operand.
+
+    Returns
+    -------
+    ret : StructInfo
+        The corresponding lca result.
+    """
+    return _ffi_api.StructInfoLCA(lhs, rhs)  # type: ignore

diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
new file mode 100644
index 000000000000..138724ed0693
--- /dev/null
+++ b/python/tvm/relax/expr.py
@@ -0,0 +1,729 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-import, super-init-not-called
+# pylint: disable=redefined-builtin
+"""The expression nodes of Relax."""
+import typing
+from numbers import Number
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as _np  # type: ignore
+import tvm
+import tvm._ffi
+import tvm.relax
+import tvm.ir
+from tvm import DataType
+from tvm._ffi import base as _base
+from tvm.runtime import ndarray as _nd, Object
+
+from ..ir import BaseFunc, Node, SourceName, Span
+from ..runtime import String
+from ..tir import PrimExpr
+from . import _ffi_api
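+# A hedged usage sketch of the analysis helpers defined in
+# python/tvm/relax/analysis/analysis.py above (illustrative only; the exact
+# results follow the EraseToWellDefined/BaseCheckResult semantics documented
+# in include/tvm/relax/analysis.h):
+#
+#   from tvm import relax as rx, tir
+#   from tvm.relax.analysis import (
+#       erase_to_well_defined, struct_info_base_check, struct_info_lca, BaseCheckResult,
+#   )
+#
+#   n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+#   s_nm = rx.TensorStructInfo([n, m], "float32")
+#
+#   # Only n is defined in the target scope, so the symbolic shape is erased.
+#   erased = erase_to_well_defined(s_nm, shape_var_map={n: n})  # R.Tensor(ndim=2)
+#
+#   base = rx.TensorStructInfo(ndim=2, dtype="float32")
+#   assert struct_info_base_check(base, s_nm) == BaseCheckResult.PASS
+#   assert struct_info_lca(base, s_nm) == base  # expected: the ndim-2 tensor info
+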
+# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370
+# This feature is not supported until python 3.10:
+# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias
+Expr = Union[tvm.ir.RelayExpr]
+Type = Union[tvm.ir.Type]
+GlobalVar = Union[tvm.ir.GlobalVar]
+
+
+@tvm._ffi.register_object("relax.Id")
+class Id(Object):
+    """Unique identifier (name) used in Var.
+    Guaranteed to be stable across all passes.
+    """
+
+    def __init__(self):
+        raise RuntimeError("Cannot directly construct Id")
+
+
+# NOTE: place base struct info in expr to avoid cyclic dep
+# from expr to struct info.
+class StructInfo(Node):
+    """The base class of all StructInfo.
+
+    StructInfo contains both the static type
+    and runtime structural information.
+    """
+
+    def __eq__(self, other):
+        """Compare two struct info for structural equivalence."""
+        return tvm.ir.structural_equal(self, other)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def same_as(self, other):
+        """Overload with structural equality."""
+        return super().__eq__(other)
+
+    def is_base_of(self, derived: "StructInfo") -> bool:
+        """Check if self is the base of another derived struct info.
+
+        Parameters
+        ----------
+        derived : StructInfo
+            The derived struct info to be checked.
+
+        Returns
+        -------
+        result : bool
+            The check result.
+        """
+        return _ffi_api.StructInfoIsBaseOf(self, derived)  # type: ignore
+
+
+# will be registered afterwards in python/tvm/relax/op/__init__.py
+_op_ffi_api = None
+
+
+def _binary_op_helper(lhs: "ExprWithOp", rhs: "ExprWithOp", op: Callable) -> "ExprWithOp":
+    if not isinstance(lhs, Expr):  # type: ignore
+        raise ValueError("lhs must be Expr")
+    if isinstance(rhs, Expr):  # type: ignore
+        return op(lhs, rhs)
+    elif isinstance(rhs, Number):
+        raise TypeError(f"Please convert {rhs} with `const` first")
+    else:
+        raise TypeError(f"type {type(rhs)} not supported")
+
+
+def _binary_rhs_helper(rhs: "ExprWithOp") -> "ExprWithOp":
+    if isinstance(rhs, Number):
+        raise TypeError(f"Please convert {rhs} with `const` first")
+    raise TypeError(f"type {type(rhs)} not supported")
+
+
+class ExprWithOp(Expr):
+    """Base type of all relax expressions that defines op overloading."""
+
+    def astype(self, dtype: Union[str, DataType]) -> "ExprWithOp":
+        """Cast the content type of the current data to dtype.
+
+        Parameters
+        ----------
+        dtype : str
+            The target data type.
+
+        Note
+        ----
+        This function only works for TensorType Exprs.
+
+        Returns
+        -------
+        result : ExprWithOp
+            The result expression.
+ """ + return _op_ffi_api.astype(self, dtype) # type: ignore + + def __neg__(self) -> "ExprWithOp": + raise ValueError("relax.negative is not supported yet.") + + def __lt__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.less) # type: ignore + + def __gt__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.greater) # type: ignore + + def __ge__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.greater_equal) # type: ignore + + def __le__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.less_equal) # type: ignore + + # NOTE: Cannot override __eq__ and __ne__, which will influence object equal + + def __add__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.add) # type: ignore + + def __radd__(self, other: Expr) -> "ExprWithOp": + return self.__add__(other) + + def __sub__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.subtract) # type: ignore + + def __rsub__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __mul__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.multiply) # type: ignore + + def __rmul__(self, other: Expr) -> "ExprWithOp": + return self.__mul__(other) + + def __truediv__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.divide) # type: ignore + + def __rtruediv__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __floordiv__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.floor_divide) # type: ignore + + def __rfloordiv__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __mod__(self, other: Expr) -> "ExprWithOp": + # TODO(siyuan): Support it after mod operator is supported in relax + raise ValueError("relax.mod is not supported yet.") + + def __rmod__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __call__(self, *args: List[Expr], attrs: Optional[Dict[str, Any]] = None) -> "ExprWithOp": + """Call the variable (if it represents a function). + + Parameters + ---------- + args: List[Expr] + The arguments to the call. + + attr: Optional[Dict[str, object]] + The additional attributes to the call. + + Returns + ------- + call: ExprWithOp + A call taking the variable as a function. + """ + return Call(self, args, attrs=attrs) + + def __getitem__(self, index: int) -> "ExprWithOp": + """Get the i-th element of the tuple or Expr with TupleType. + + Parameters + ---------- + index: int + The index of the element to be retrieved. + + Note + ---- + This function will be overridden by Tuple and ShapeExpr + + Returns + ------- + result: ExprWithOp + The result expression. + """ + return TupleGetItem(self, index) + + +@tvm._ffi.register_object("relax.expr.Call") +class Call(ExprWithOp): + """Function call node in Relax. + + Call node corresponds the operator application node + in computational graph terminology. + + Parameters + ---------- + op: tvm.ir.Op or any tvm.relax.Expr with function type. + The operation to be called. + + args: Union[List[Expr], typing.Tuple[Expr, ...]] + The arguments to the call. + + attrs: Optional[tvm.ir.Attrs] + Attributes to the call, can be None + + sinfo_args: Optional[Union[List[StructInfo], typing.Tuple[StructInfo, ...]]] + The structure info arguments of a CallNode. 
+@tvm._ffi.register_object("relax.expr.If")
+class If(ExprWithOp):
+    """A conditional expression in Relax.
+
+    Parameters
+    ----------
+    cond: Expr
+        The condition.
+
+    true_branch: Expr
+        The expression evaluated when condition is true.
+
+    false_branch: Expr
+        The expression evaluated when condition is false.
+
+    span: Optional[Span]
+        Span that points to original source code
+    """
+
+    def __init__(self, cond: Expr, true_branch: Expr, false_branch: Expr, span: Span = None):
+        self.__init_handle_by_constructor__(
+            _ffi_api.If, cond, true_branch, false_branch, span  # type: ignore
+        )
+
+
+@tvm._ffi.register_object("relax.expr.Tuple")
+class Tuple(ExprWithOp):
+    """Tuple expression that groups several fields together.
+
+    Parameters
+    ----------
+    fields : Union[List[Expr], typing.Tuple[Expr, ...]]
+        The fields in the tuple.
+
+    span: Optional[Span]
+        Span that points to original source code
+    """
+
+    def __init__(self, fields: Union[List[Expr], typing.Tuple[Expr, ...]], span: Span = None):
+        self.__init_handle_by_constructor__(_ffi_api.Tuple, fields, span)  # type: ignore
+
+    def __getitem__(self, index: int) -> Expr:
+        if index >= len(self) or index < -len(self):
+            raise IndexError("Tuple index out of range")
+        return self.fields[index]
+
+    def __len__(self) -> int:
+        return len(self.fields)
+
+
+@tvm._ffi.register_object("relax.expr.TupleGetItem")
+class TupleGetItem(ExprWithOp):
+    """Get the index-th item from a tuple.
+
+    Parameters
+    ----------
+    tuple_value: Expr
+        The input tuple expression.
+
+    index: int
+        The index.
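+
+    Examples
+    --------
+    A minimal construction sketch (illustrative only):
+
+    .. code-block:: python
+
+        x = relax.Var("x")
+        tup = relax.Tuple([x, x])
+        elem = relax.TupleGetItem(tup, 1)
+        assert elem.index == 1 and elem.tuple_value.same_as(tup)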
+ """ + + def __init__(self, tuple_value: Expr, index: int): + self.__init_handle_by_constructor__( + _ffi_api.TupleGetItem, tuple_value, index # type: ignore + ) + + +@tvm._ffi.register_object("relax.expr.ShapeExpr") +class ShapeExpr(ExprWithOp): + """A shape expression which allows users to construct a shape containing PrimExpr.""" + + values: List[PrimExpr] + + def __init__( + self, + values: Union[List[PrimExpr], typing.Tuple[PrimExpr, ...], tvm.ir.Array], + span: Span = None, + ) -> None: + self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values, span) # type: ignore + + def __getitem__(self, index): + if index >= len(self) or index < -len(self): + raise IndexError("ShapeExpr index out of range") + return self.values[index] + + def __len__(self): + return len(self.values) + + +def make_shape(shape: Union[List[Any], typing.Tuple[Any, ...]]) -> ShapeExpr: + if isinstance(shape, (list, tuple)): + return ShapeExpr(shape) + raise ValueError("Wrong type") + + +@tvm._ffi.register_object("relax.expr.Constant") +class Constant(ExprWithOp): + def __init__(self, data: tvm.nd.NDArray, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.Constant, data, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.Var") +class Var(ExprWithOp): + """The variable class for all Relax bindings.""" + + vid: Id + struct_info: Optional[StructInfo] + + def __init__( + self, + name_hint: Union[str, Id], + struct_info: Optional[StructInfo] = None, + span: Span = None, + ) -> None: + if struct_info is not None: + struct_info = tvm.runtime.convert_to_object(struct_info) + if not isinstance(struct_info, StructInfo): + raise TypeError( + "struct_info needs to be an instance of StructInfo. " + "If you attempt to pass in shape, " + "use relax.TensorStructInfo(shape, dtype)." + ) + self.__init_handle_by_constructor__( + _ffi_api.Var if isinstance(name_hint, str) else _ffi_api.VarFromId, # type: ignore + name_hint, + struct_info, + span, + ) + + @property + def name_hint(self): + """Get name hint of the current var.""" + name = str(self.vid.name_hint) + return name + + +@tvm._ffi.register_object("relax.expr.DataflowVar") +class DataflowVar(Var): + """A sub-type of the variable node used to mark dataflow variables from + normal visible "function local" bindings.""" + + vid: Id + struct_info: Optional[StructInfo] + + def __init__( + self, + name_hint: Union[str, Id], + struct_info: Optional[StructInfo] = None, + span: Span = None, + ) -> None: + if struct_info is not None: + struct_info = tvm.runtime.convert_to_object(struct_info) + if not isinstance(struct_info, StructInfo): + raise TypeError( + "struct_info needs to be an instance of StructInfo. " + "If you attempt to pass in shape, " + "use relax.TensorStructInfo(shape, dtype)." 
+            )
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.DataflowVar  # type: ignore
+            if isinstance(name_hint, str)
+            else _ffi_api.DataflowVarFromId,  # type: ignore
+            name_hint,
+            struct_info,
+            span,
+        )
+
+
+@tvm._ffi.register_object("relax.expr.PrimValue")
+class PrimValue(Expr):
+    """The prim expr representing the value."""
+
+    value: PrimExpr
+
+    def __init__(self, value: Union[PrimExpr, int], span: Span = None) -> None:
+        if isinstance(value, int):
+            value = tvm.tir.IntImm("int64", value)
+        self.__init_handle_by_constructor__(_ffi_api.PrimValue, value, span)  # type: ignore
+
+
+@tvm._ffi.register_object("relax.expr.StringImm")
+class StringImm(Expr):
+    """Represent a string literal constant."""
+
+    value: str
+
+    def __init__(self, value: str, span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span)  # type: ignore
+
+
+@tvm._ffi.register_object("relax.expr.DataTypeImm")
+class DataTypeImm(Expr):
+    """Represent a data type constant."""
+
+    value: DataType
+
+    def __init__(self, value: Union[DataType, str], span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.DataTypeImm, value, span)  # type: ignore
+
+
+@tvm._ffi.register_object("relax.expr.Binding")
+class Binding(Node):
+    """The base class of a binding in Relax."""
+
+    ...
+
+
+@tvm._ffi.register_object("relax.expr.MatchCast")
+class MatchCast(Binding):
+    """Runtime-match the value to the struct info.
+
+    This operation performs a runtime check, populates the undefined symbolic
+    shape vars and vars in struct_info on their first occurrence, and inserts
+    equality assertions in all other cases.
+
+    Parameters
+    ----------
+    var: Var
+        The return variable that the match cast binds to.
+
+    value: Expr
+        The input value expression.
+
+    struct_info: tvm.relax.StructInfo
+        The struct info to match cast to.
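+
+    Examples
+    --------
+    A minimal sketch (illustrative only, assuming ``tvm`` and ``tvm.relax``
+    are imported): matching a tensor of unknown shape against a symbolic
+    2-d shape populates ``m`` and ``n`` at runtime.
+
+    .. code-block:: python
+
+        m = tvm.tir.Var("m", "int64")
+        n = tvm.tir.Var("n", "int64")
+        x = relax.Var("x", relax.TensorStructInfo(ndim=2, dtype="float32"))
+        y = relax.Var("y", relax.TensorStructInfo([m, n], "float32"))
+        binding = relax.MatchCast(y, x, relax.TensorStructInfo([m, n], "float32"))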
+ """ + + var: Var + struct_info: "tvm.relax.StructInfo" + value: Expr + + def __init__( + self, var: Var, value: Expr, struct_info: "tvm.relax.StructInfo", span: Span = None + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.MatchCast, var, value, struct_info, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.expr.VarBinding") +class VarBinding(Binding): + """Variable binding, bind he variable of the lhs with the rhs.""" + + var: Var + value: Expr + + def __init__(self, var: Var, value: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.BindingBlock") +class BindingBlock(Node): + """base class of binding block, bindings inside can be impure + (with side effect or control flow)""" + + bindings: List[Binding] + + def __init__(self, bindings: List[Binding], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.DataflowBlock") +class DataflowBlock(BindingBlock): + """dataflow block, bindings inside are pure (no side effect and no control flow)""" + + def __init__(self, bindings: List[Binding], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.DataflowBlock, bindings, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.SeqExpr") +class SeqExpr(ExprWithOp): + """A sequence of binding blocks followed by an expression.""" + + blocks: List[BindingBlock] + body: Expr + + def __init__(self, blocks: List[BindingBlock], body: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.Function") +class Function(BaseFunc): + """A Relax function.""" + + params: List[Var] + body: Expr + ret_struct_info: StructInfo + attrs: Optional[tvm.ir.DictAttrs] + + def __init__( + self, + params: List[Var], + body: Expr, + ret_struct_info: Optional[StructInfo] = None, + attrs: Optional[tvm.ir.DictAttrs] = None, + span: Optional[Span] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.Function, params, body, ret_struct_info, attrs, span # type: ignore + ) + + @staticmethod + def create_empty( + params: List[Var], + ret_struct_info: StructInfo, + attrs: Optional[tvm.ir.DictAttrs] = None, + span: Optional[Span] = None, + ): + """Construct a relax.Function but without body""" + return _ffi_api.FunctionCreateEmpty(params, ret_struct_info, attrs, span) # type: ignore + + def __call__(self, *args): + """Invoke the global function. + + Parameters + ---------- + args: List[relax.Expr] + Arguments. + """ + return Call(self, args, None, None) + + def script(self, show_meta: bool = False) -> str: + """Print relax.Function into TVMScript + + Parameters + ---------- + show_meta : bool + Whether to show meta information + + Returns + ------- + script : str + The TVM Script of the relax.Function + """ + return tvm._ffi.get_global_func("script.AsRelaxScript")(self, show_meta) # type: ignore + + def show(self, style: str = "light") -> None: + """ + A sugar for print highlighted TVM script. 
+
+        Parameters
+        ----------
+        style : str, optional
+            Pygments styles extended by "light" (default) and "dark", by default "light"
+        """
+        from tvm.script.highlight import cprint  # pylint: disable=import-outside-toplevel
+
+        # Use deferred import to avoid circular import while keeping cprint under tvm/script
+        cprint(self, style=style)
+
+
+@tvm._ffi.register_object("relax.expr.ExternFunc")
+class ExternFunc(BaseFunc):
+    """Extern function, which can represent a TIR PrimFunc or a PackedFunc."""
+
+    global_symbol: String
+
+    def __init__(self, global_symbol: String, span: Span = None) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.ExternFunc, global_symbol, span  # type: ignore
+        )
+
+
+def extern(name: str, span: Span = None):
+    """Create an extern function."""
+    return ExternFunc(name, span)
+
+
+def const(
+    value: Union[bool, int, float, _np.ndarray, tvm.nd.NDArray], dtype: Optional[str] = None
+) -> Constant:
+    """Create a constant value.
+
+    Parameters
+    ----------
+    value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray]
+        The constant value.
+
+    dtype: Optional[str]
+        The data type of the resulting constant.
+
+    Note
+    ----
+    When dtype is None, we use the following rule:
+
+    - int maps to "int32"
+    - float maps to "float32"
+    - bool maps to "bool"
+    - otherwise, the same default rule as numpy applies.
+    """
+    if isinstance(value, (_base.numeric_types, (bool, list))):
+        value = _np.array(value, dtype=dtype)
+
+    if not dtype:
+        # when dtype is None: int maps to "int32", float maps to "float32"
+        dtype = {  # type: ignore
+            _np.dtype("int64"): _np.int32,  # type: ignore
+            _np.dtype("float64"): _np.float32,  # type: ignore
+        }.get(
+            value.dtype, None  # type: ignore
+        )
+
+    if isinstance(value, (_np.ndarray, _np.generic)):
+        if dtype is not None:
+            value = value.astype(dtype)
+        value = _nd.array(value)
+
+    if not isinstance(value, _nd.NDArray):
+        raise ValueError("value has to be scalar or NDArray")
+
+    return Constant(value)
+
+
+def te_tensor(
+    value: Expr, tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr], name: str = "rxplaceholder"
+):
+    """Create a TE tensor from a relax expression, with the TIR variables in the
+    tensor shape substituted by the given mapping.
+
+    Parameters
+    ----------
+    value : Expr
+        The relax expression, which is required to have TensorStructInfo.
+
+    tir_var_map : Dict[tvm.tir.Var, tvm.tir.PrimExpr]
+        The mapping to substitute the TIR variables that appear in the
+        shape of the input Expr.
+
+    name : str
+        The name of the created tensor.
+    """
+    return _ffi_api.TETensor(value, tir_var_map, name)  # type: ignore
+
+
+def get_shape_of(expr: Expr) -> Expr:
+    """Get shape of expr.
+
+    Parameters
+    ----------
+    expr: Expr
+        The input expr.
+
+    Returns
+    -------
+    shape: Expr
+        The shape expression
+
+    Note
+    ----
+    This function requires expr to be normalized.
+    The function will report an error if expr's StructInfo is not TensorStructInfo.
+    It will try to return a symbolic shape when possible. If the tensor does not
+    have a compile-time symbolic shape, the function will instead return
+    `Call(relax.op.shape_of, [expr])`.
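+
+    Examples
+    --------
+    A minimal sketch (illustrative only): for a normalized var annotated
+    with a symbolic tensor shape, the symbolic shape is returned directly.
+
+    .. code-block:: python
+
+        n = tvm.tir.Var("n", "int64")
+        x = relax.Var("x", relax.TensorStructInfo([n, 4], "float32"))
+        s = relax.get_shape_of(x)  # the symbolic ShapeExpr (n, 4)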
+ """ + return _ffi_api.GetShapeOf(expr) # type: ignore + + +def _update_struct_info(expr: Expr, struct_info: Optional[StructInfo]) -> None: + _ffi_api.UpdateStructInfo(expr, struct_info) # type: ignore diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py new file mode 100644 index 000000000000..2ff027b22924 --- /dev/null +++ b/python/tvm/relax/struct_info.py @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import +"""The struct info nodes of the Relax language.""" +from typing import List, Optional, Tuple, Union + +import tvm._ffi +import tvm + +from tvm.ir import Span, Node, EnvFunc, Array, Type +from tvm.tir import PrimExpr +from .expr import StructInfo, Var, Expr, ShapeExpr + +from . import _ffi_api, ty, expr + + +@tvm._ffi.register_object("relax.ObjectStructInfo") +class ObjectStructInfo(StructInfo): + """StructInfo of an Object.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ObjectStructInfo, span) # type: ignore + + +@tvm._ffi.register_object("relax.PrimStructInfo") +class PrimStructInfo(StructInfo): + """StructInfo of a primitive POD value. + + Parameters + ---------- + dtype : str + The data type of the prim value. + """ + + dtype: str + + def __init__(self, dtype: str, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.PrimStructInfo, dtype, span) # type: ignore + + +@tvm._ffi.register_object("relax.ShapeStructInfo") +class ShapeStructInfo(StructInfo): + """StructInfo of a shape value. + + Parameters + ---------- + values : Optional[List[PrimExpr]] + The symbolic shape values if known. + + ndim : Optional[int] + The size of the shape. + + Note + ---- + Do not specify values and ndim at the same time. + """ + + values: Optional[List[PrimExpr]] + ndim: int + span: Span + + def __init__( + self, values: Optional[List[PrimExpr]] = None, ndim: int = -1, span: Span = None + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ShapeStructInfo, values, ndim, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.TensorStructInfo") +class TensorStructInfo(StructInfo): + """StructInfo of a Tensor value. + + Parameters + ---------- + shape : Optional[Expr] + The shape expression. + + dtype : Optional[str] + The content data type. + + ndim : Optional[int] + The number of dimensions of the tensor. + + Note + ---- + Do not specify shape and ndim at the same time. 
+ """ + + shape: Optional[Expr] + dtype: str + ndim: int + span: Span + + def __init__( + self, + shape: Union[Optional[Expr], List[PrimExpr]] = None, + dtype: str = "float32", + ndim: int = -1, + span: Span = None, + ) -> None: + if isinstance(shape, (list, tuple, Array)): + shape = ShapeExpr(shape) + + self.__init_handle_by_constructor__( + _ffi_api.TensorStructInfo, shape, dtype, ndim, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.TupleStructInfo") +class TupleStructInfo(StructInfo): + """StructInfo of a Tuple value. + + Parameters + ---------- + fields: List[StructInfo] + The struct info of the fields. + """ + + fields: List[StructInfo] + span: Span + + def __init__(self, fields: List[StructInfo], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.TupleStructInfo, fields, span) # type: ignore + + +@tvm._ffi.register_object("relax.FuncStructInfo") +class FuncStructInfo(StructInfo): + """StructInfo of a function value. + + Parameters + ---------- + params: List[StructInfo] + The struct info of the fields. + + ret: StructInfo + The struct info of return value + """ + + params: Optional[List[StructInfo]] + ret: StructInfo + derive_func: Optional[EnvFunc] + span: Span + + def __init__(self, params: List[StructInfo], ret: StructInfo, span: Span = None) -> None: + self.__init_handle_by_constructor__( + _ffi_api.FuncStructInfo, params, ret, span # type: ignore + ) + + @staticmethod + def opaque_func( + *, + ret: Optional[StructInfo] = None, + derive_func: Optional[EnvFunc] = None, + span: Span = None, + ) -> "FuncStructInfo": + """ + Create an opaque FuncStructInfo. + + The opaque function takes either a ret + that specificies the struct info of the return value + or a derive_func that provides a customized derivation rule. + + Parameters + ---------- + ret: Optional[StructInfo] + The struct info of the the function return value. + + derive_func: Optional[EnvFunc] + The environment function used for derivation + + span: Optional[Span] + Optional span information of the ast. + + Returns + ------- + info: FuncStructInfo + + Note + ---- + We cannot specify ret and derive_func simultaneously. + """ + return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, span) # type: ignore diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py new file mode 100644 index 000000000000..05492d6a9c34 --- /dev/null +++ b/python/tvm/relax/ty.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import +"""The type nodes of the Relax language.""" +import tvm._ffi +from tvm.ir import Type, TensorType, TupleType, FuncType, Span + +from . import _ffi_api + + +@tvm._ffi.register_object("relax.ShapeType") +class ShapeType(Type): + """The type of shape in Relax. 
+
+    Parameters
+    ----------
+    ndim : Optional[int]
+        The size of the shape.
+    """
+
+    # TODO(relax-team): consider making ndim mandatory
+    def __init__(self, ndim: int = -1, span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span)  # type: ignore
+
+
+@tvm._ffi.register_object("relax.ObjectType")
+class ObjectType(Type):
+    """A type that corresponds to tvm::runtime::Object; the base of all
+    possible object values in TVM."""
+
+    def __init__(self, span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.ObjectType, span)  # type: ignore
+
+
+@tvm._ffi.register_object("relax.DynTensorType")
+class DynTensorType(Type):
+    """A dynamic tensor type in Relax.
+
+    This is the type assigned to tensors with a known dtype and unknown shape.
+
+    Parameters
+    ----------
+    ndim : Optional[int]
+        The ndim of the Tensor
+
+    dtype : Optional[str]
+        The content data type.
+    """
+
+    def __init__(self, ndim: int = -1, dtype: str = "float32", span: Span = None) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.DynTensorType, ndim, dtype, span  # type: ignore
+        )
+
+
+@tvm._ffi.register_object("relax.PackedFuncType")
+class PackedFuncType(Type):
+    """The type of ExternFunc in Relax."""
+
+    def __init__(self, span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.PackedFuncType, span)  # type: ignore
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 9283727ad41a..6d92c68367b3 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -18,3 +18,4 @@
 from .parser import ir, ir_module
 from .parser import parse as from_source
 from .parser import tir
+from .parser import relax
diff --git a/python/tvm/script/parser/relax/__init__.py b/python/tvm/script/parser/relax/__init__.py
new file mode 100644
index 000000000000..feb8e683401c
--- /dev/null
+++ b/python/tvm/script/parser/relax/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Initial implementation of the relax parser sugars."""
+from tvm.relax import TensorStructInfo, ShapeStructInfo
+
+Tensor = TensorStructInfo
+Shape = ShapeStructInfo
diff --git a/src/ir/function.cc b/src/ir/function.cc
index ce294708b2a9..69752f529a3c 100644
--- a/src/ir/function.cc
+++ b/src/ir/function.cc
@@ -22,6 +22,8 @@
 * \brief The function data structure.
*/ #include +#include +#include #include #include @@ -35,13 +37,13 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> BaseFunc { if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); } - if (const auto* f = runtime::Registry::Get("relay.ir.FuncWithAttr")) { - if (Optional ret = (*f)(func, key, value)) { - return ret.value(); - } - } - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); }); } // namespace tvm diff --git a/src/ir/type.cc b/src/ir/type.cc index d965406e8bb0..b61a3df09107 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -25,9 +25,10 @@ #include namespace tvm { -PrimType::PrimType(runtime::DataType dtype) { +PrimType::PrimType(runtime::DataType dtype, Span span) { ObjectPtr n = make_object(); n->dtype = dtype; + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/relax/analysis/shape_analysis.cc b/src/relax/analysis/shape_analysis.cc new file mode 100644 index 000000000000..70ce5ac06e90 --- /dev/null +++ b/src/relax/analysis/shape_analysis.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file shape_analysis.cc + * + * \brief Utilities for shape analysis. + */ + +#include +#include + +namespace tvm { +namespace relax { + +bool CanProveShapeEqual(const Array& lhs, const Array& rhs, + arith::Analyzer* ana) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!ana->CanProveEqual(lhs[i], rhs[i])) return false; + } + return true; +} + +bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana) { + if (lhs.same_as(rhs)) return true; + auto* lhs_shape = lhs.as(); + auto* rhs_shape = rhs.as(); + + if (lhs_shape && rhs_shape) { + return CanProveShapeEqual(lhs_shape->values, rhs_shape->values, ana); + } else { + return false; + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc new file mode 100644 index 000000000000..d9b139753455 --- /dev/null +++ b/src/relax/analysis/struct_info_analysis.cc @@ -0,0 +1,716 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file struct_info_analysis.cc + * \brief Implementations of foundation struct info analysis + * + * \note Update this file when you added a new StructInfo. + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +//-------------------------- +// GetStaticType +//-------------------------- +class StaticTypeDeriver : public StructInfoFunctor { + public: + Type VisitStructInfo_(const ObjectStructInfoNode* op) final { return ObjectType(op->span); } + + Type VisitStructInfo_(const PrimStructInfoNode* op) final { + return PrimType(op->dtype, op->span); + } + + Type VisitStructInfo_(const ShapeStructInfoNode* op) final { + return ShapeType(op->ndim, op->span); + } + + Type VisitStructInfo_(const TensorStructInfoNode* op) final { + return DynTensorType(op->ndim, op->dtype); + } + + Type VisitStructInfo_(const TupleStructInfoNode* op) final { + Array fields = + op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + return TupleType(fields, op->span); + } + + Type VisitStructInfo_(const FuncStructInfoNode* op) final { + if (op->IsOpaque()) return PackedFuncType(op->span); + Array params = op->params.value().Map( + [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + Type ret = this->VisitStructInfo(op->ret); + return FuncType(params, ret, {}, {}, op->span); + } +}; + +Type GetStaticType(const StructInfo& info) { return StaticTypeDeriver()(info); } + +TVM_REGISTER_GLOBAL("relax.analysis.GetStaticType").set_body_typed([](const StructInfo& info) { + return GetStaticType(info); +}); + +//-------------------------- +// StructInfoFromType +//-------------------------- + +StructInfo StructInfoFromType(const Type& type) { + if (type.as()) { + return ObjectStructInfo(type->span); + } else if (const PrimTypeNode* prim_type = type.as()) { + return PrimStructInfo(prim_type->dtype, prim_type->span); + } else if (const ShapeTypeNode* shape_type = type.as()) { + return ShapeStructInfo(shape_type->ndim, type->span); + } else if (const DynTensorTypeNode* tensor_type = type.as()) { + return TensorStructInfo(tensor_type->dtype, tensor_type->ndim); + } else if (const TupleTypeNode* tuple_type = type.as()) { + Array fields; + for (const Type& field : tuple_type->fields) { + fields.push_back(StructInfoFromType(field)); + } + return TupleStructInfo(fields, type->span); + } else if (const FuncTypeNode* func_type = type.as()) { + Array params = + func_type->arg_types.Map([](const Type& param) { return StructInfoFromType(param); }); + StructInfo ret = StructInfoFromType(func_type->ret_type); + return FuncStructInfo(params, ret, func_type->span); + } else { + LOG(FATAL) << "Unsupported type: " << type; + return StructInfo(); + } +} + +//-------------------------- +// EraseToWellDefined +//-------------------------- +class WellDefinedEraser : public StructInfoMutator, + public ExprMutatorBase, + public tir::ExprMutator { + public: + 
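+  // Remap callbacks: f_shape_var_map/f_var_map return the replacement for a
+  // symbolic var, or NullOpt to mark it undefined; the analyzer is used to
+  // eagerly simplify any rewritten shape expressions.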
+  WellDefinedEraser(std::function<Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map,
+                    std::function<Optional<Expr>(const Var& var)> f_var_map, arith::Analyzer* ana)
+      : f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {}
+
+  StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) final {
+    bool has_undefined = false;
+    Optional<Array<PrimExpr>> values;
+
+    if (op->values.defined()) {
+      std::swap(has_undefined_, has_undefined);
+      values = op->values.value().Map([&](PrimExpr val) { return this->VisitPrimExpr(val); });
+      std::swap(has_undefined_, has_undefined);
+    }
+    // erase symbolic shape if we have undefined vars.
+    if (!has_undefined) {
+      if (values.same_as(op->values)) {
+        return GetRef<StructInfo>(op);
+      } else {
+        return ShapeStructInfo(values.value(), op->span);
+      }
+    } else {
+      return ShapeStructInfo(op->ndim, op->span);
+    }
+  }
+
+  StructInfo VisitStructInfo_(const TensorStructInfoNode* op) final {
+    bool has_undefined = false;
+    Optional<Expr> shape;
+
+    if (op->shape.defined()) {
+      std::swap(has_undefined_, has_undefined);
+      shape = relax::ExprMutatorBase::VisitExpr(op->shape.value());
+      std::swap(has_undefined_, has_undefined);
+    }
+
+    // erase symbolic shape if we have undefined vars.
+    if (!has_undefined) {
+      if (shape.same_as(op->shape)) {
+        return GetRef<StructInfo>(op);
+      } else {
+        if (shape.defined()) {
+          return TensorStructInfo(shape.value(), op->dtype, op->span);
+        } else {
+          return TensorStructInfo(op->dtype, op->ndim, op->span);
+        }
+      }
+    } else {
+      return TensorStructInfo(op->dtype, op->ndim, op->span);
+    }
+  }
+
+  StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final {
+    // NOTE: we always require func struct info to be well-defined.
+    //
+    // All occurring symbolic variables are defined in the parameters'
+    // struct info annotations, so there is no need to erase.
+    return GetRef<StructInfo>(op);
+  }
+
+  using relax::ExprMutatorBase::VisitExpr_;
+  using tir::ExprMutator::VisitExpr_;
+
+  // connect things up
+  PrimExpr VisitPrimExpr(const PrimExpr& expr) {
+    // apply eager simplification
+    PrimExpr val = tir::ExprMutator::VisitExpr(expr);
+    if (!val.same_as(expr)) {
+      return ana_->Simplify(val);
+    } else {
+      return val;
+    }
+  }
+
+  Expr VisitExpr_(const DataflowVarNode* var) final {
+    return VisitExpr_(static_cast<const VarNode*>(var));
+  }
+
+  Expr VisitExpr_(const VarNode* var) final {
+    Optional<Expr> ret;
+    if (f_var_map_ != nullptr) {
+      ret = f_var_map_(GetRef<Var>(var));
+    }
+    has_undefined_ = has_undefined_ || !ret.defined();
+    if (ret.defined()) {
+      ICHECK(ret.as<ShapeExprNode>() || ret.as<VarNode>())
+          << "Only allow Expr in StructInfo to be ShapeExpr or Var";
+    }
+    return ret.value_or(GetRef<Expr>(var));
+  }
+
+  PrimExpr VisitExpr_(const tir::VarNode* var) final {
+    Optional<PrimExpr> ret;
+    if (f_shape_var_map_ != nullptr) {
+      ret = f_shape_var_map_(GetRef<tir::Var>(var));
+    }
+    has_undefined_ = has_undefined_ || !ret.defined();
+
+    if (ret.defined()) {
+      PrimExpr value = ret.value();
+      if (value->IsInstance<IntImmNode>()) {
+        return tvm::cast(DataType::Int(64), value);
+      }
+      ICHECK(value.dtype() == DataType::Int(64)) << "Can only provide i64 expressions in shape";
+      return value;
+    } else {
+      return GetRef<PrimExpr>(var);
+    }
+  }
+
+ private:
+  bool has_undefined_ = false;
+  std::function<Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map_;
+  std::function<Optional<Expr>(const Var& var)> f_var_map_;
+  arith::Analyzer* ana_;
+};
+
+StructInfo EraseToWellDefined(
+    const StructInfo& info, std::function<Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map,
+    std::function<Optional<Expr>(const Var& var)> f_var_map, arith::Analyzer* ana) {
+  if (ana == nullptr) {
+    arith::Analyzer inst;
+    return WellDefinedEraser(f_shape_var_map, f_var_map, &inst).VisitStructInfo(info);
+  }
else { + return WellDefinedEraser(f_shape_var_map, f_var_map, ana).VisitStructInfo(info); + } +} + +StructInfo EraseToWellDefined(const StructInfo& info, Map shape_var_map, + Map var_map, arith::Analyzer* ana) { + std::function(const tir::Var& var)> f_shape_var_map = nullptr; + std::function(const Var& var)> f_var_map = nullptr; + + if (!shape_var_map.empty()) { + f_shape_var_map = [&](const tir::Var& var) -> Optional { + auto it = shape_var_map.find(var); + if (it != shape_var_map.end()) return (*it).second; + return NullOpt; + }; + } + + if (!var_map.empty()) { + f_var_map = [&](const Var& var) -> Optional { + auto it = var_map.find(var); + if (it != var_map.end()) return (*it).second; + return NullOpt; + }; + } + + return EraseToWellDefined(info, f_shape_var_map, f_var_map, ana); +} + +TVM_REGISTER_GLOBAL("relax.analysis.EraseToWellDefined") + .set_body_typed([](const StructInfo& info, Map shape_var_map, + Map var_map) { + return EraseToWellDefined(info, shape_var_map, var_map); + }); + +//-------------------------- +// IsBaseOf +//-------------------------- +class StructInfoBaseChecker + : public StructInfoFunctor { + public: + explicit StructInfoBaseChecker(arith::Analyzer* ana) : analyzer_(ana) {} + + BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) override { + // quick path + // Note: subclass may disable this quick path if we need to go over all struct info. + if (lhs.same_as(other)) return BaseCheckResult::kPass; + return StructInfoFunctor::VisitStructInfo(lhs, other); + } + + // Object is base of everything + BaseCheckResult VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { + return BaseCheckResult::kPass; + } + + BaseCheckResult VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + return lhs->dtype == rhs->dtype ? BaseCheckResult::kPass : BaseCheckResult::kFailL0; + } + + BaseCheckResult VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + // lhs have unknown ndim + if (lhs->IsUnknownNdim()) return BaseCheckResult::kPass; + + // ndim must match + if (lhs->ndim != rhs->ndim) { + if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + + // lhs does not have symbolic value + if (!lhs->values.defined()) return BaseCheckResult::kPass; + // rhs does not have symbolic value but lhs do. 
+    if (!rhs->values.defined()) return BaseCheckResult::kFailL2;
+
+    // shape match check
+    return ShapeMatchCheck(lhs->values.value(), rhs->values.value());
+  }
+
+  BaseCheckResult VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final {
+    auto* rhs = other.as<TensorStructInfoNode>();
+    if (rhs == nullptr) {
+      if (other.as<ObjectStructInfoNode>()) return BaseCheckResult::kFailL1;
+      return BaseCheckResult::kFailL0;
+    }
+    // dtype mismatch
+    if (!lhs->IsUnknownDtype() && lhs->dtype != rhs->dtype) {
+      if (rhs->IsUnknownDtype()) return BaseCheckResult::kFailL1;
+      return BaseCheckResult::kFailL0;
+    }
+
+    // ndim mismatch
+    if (!lhs->IsUnknownNdim() && lhs->ndim != rhs->ndim) {
+      if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1;
+      return BaseCheckResult::kFailL0;
+    }
+    // lhs does not have a defined shape and everything else matches
+    if (!lhs->shape.defined()) return BaseCheckResult::kPass;
+    // rhs does not have a defined shape but lhs does
+    if (!rhs->shape.defined()) return BaseCheckResult::kFailL2;
+
+    // shape match check
+    return ShapeMatchCheck(lhs->shape.value(), rhs->shape.value());
+  }
+
+  BaseCheckResult VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final {
+    auto* rhs = other.as<TupleStructInfoNode>();
+    if (rhs == nullptr) {
+      if (other.as<ObjectStructInfoNode>()) return BaseCheckResult::kFailL1;
+      return BaseCheckResult::kFailL0;
+    }
+    return ArrayCheck(lhs->fields, rhs->fields);
+  }
+
+  BaseCheckResult VisitStructInfo_(const FuncStructInfoNode* lhs,
+                                   const StructInfo& other) override {
+    auto* rhs = other.as<FuncStructInfoNode>();
+    if (rhs == nullptr) {
+      if (other.as<ObjectStructInfoNode>()) return BaseCheckResult::kFailL1;
+      return BaseCheckResult::kFailL0;
+    }
+
+    // lhs opaque handling
+    if (lhs->IsOpaque()) {
+      if (lhs->derive_func.defined()) {
+        // function proving is best effort.
+        return lhs->derive_func.same_as(rhs->derive_func) ? BaseCheckResult::kPass
+                                                          : BaseCheckResult::kFailL2;
+      }
+      // no derivation function, only depends on ret
+      return this->VisitStructInfo(lhs->ret, rhs->ret);
+    }
+
+    // Function check is best effort.
+    // rhs is opaque but lhs is not
+    if (rhs->IsOpaque()) return BaseCheckResult::kFailL2;
+
+    // NOTE: lhs->params, rhs->params may contain different symbolic
+    // vars that need to be re-mapped to each other.
+    // This can only be done through structural equality check and not ArrayCheck.
+    //
+    // So we check structural equality here and if the two are structurally
+    // equal return true.
+    //
+    // otherwise we do best effort BaseArrayCheck.
+    //
+    // This still does not handle cases where some arguments are sub of another
+    // while other parameters need to get remapped.
+    //
+    // Given we only do best effort checking in these cases, and such cases
+    // are likely not a primary concern atm, we take this approach here.
+    if (struct_equal_(GetRef<StructInfo>(lhs), other)) return BaseCheckResult::kPass;
+
+    auto param_check = FuncParamsCheck(lhs->params.value(), rhs->params.value());
+    auto ret_check = this->VisitStructInfo(lhs->ret, rhs->ret);
+    return CombineCheck(param_check, ret_check);
+  }
+
+ protected:
+  // analyzer
+  arith::Analyzer* analyzer_;
+  // struct equal checker
+  StructuralEqual struct_equal_;
+
+  // customizable functions.
+  /*!
+   * \brief Check symbolic shape value equivalence.
+   * \param lhs The left hand shape.
+   * \param rhs The right hand shape.
+   * \return CheckResult.
+   */
+  virtual BaseCheckResult PrimValueMatchCheck(const PrimExpr& lhs, const PrimExpr& rhs) {
+    // get static shape checking right.
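+    // Two static ints either match exactly (kPass) or conflict (kFailL0);
+    // otherwise defer to the arithmetic analyzer, where an unprovable
+    // equality is only a soft failure (kFailL2) since it may still hold at runtime.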
+    auto* int_lhs = lhs.as<IntImmNode>();
+    auto* int_rhs = rhs.as<IntImmNode>();
+    if (int_lhs && int_rhs) {
+      if (int_lhs->value == int_rhs->value) {
+        return BaseCheckResult::kPass;
+      } else {
+        return BaseCheckResult::kFailL0;
+      }
+    }
+    return analyzer_->CanProveEqual(lhs, rhs) ? BaseCheckResult::kPass : BaseCheckResult::kFailL2;
+  }
+  /*!
+   * \brief Check shape value equivalence.
+   * \param lhs The left hand shape.
+   * \param rhs The right hand shape.
+   * \return CheckResult.
+   */
+  virtual BaseCheckResult ShapeMatchCheck(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+    if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0;
+
+    BaseCheckResult ret = BaseCheckResult::kPass;
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      auto cmp_ret = PrimValueMatchCheck(lhs[i], rhs[i]);
+      if (ret == BaseCheckResult::kFailL0) return ret;
+      ret = CombineCheck(cmp_ret, ret);
+    }
+    return ret;
+  }
+
+  /*!
+   * \brief Check shape expression equivalence.
+   * \param lhs The left hand shape.
+   * \param rhs The right hand shape.
+   * \return Check result.
+   */
+  virtual BaseCheckResult ShapeMatchCheck(const Expr& lhs, const Expr& rhs) {
+    if (lhs.same_as(rhs)) return BaseCheckResult::kPass;
+    auto* lhs_shape = lhs.as<ShapeExprNode>();
+    auto* rhs_shape = rhs.as<ShapeExprNode>();
+    if (lhs_shape && rhs_shape) {
+      return ShapeMatchCheck(lhs_shape->values, rhs_shape->values);
+    } else {
+      return BaseCheckResult::kFailL2;
+    }
+  }
+
+  /*!
+   * \brief Check function parameters.
+   * \param lhs The left hand params.
+   * \param rhs The right hand params.
+   * \return Check result.
+   */
+  virtual BaseCheckResult FuncParamsCheck(const Array<StructInfo>& lhs,
+                                          const Array<StructInfo>& rhs) {
+    auto res = ArrayCheck(lhs, rhs);
+    // treat L1 failures in params checking as L2.
+    if (res == BaseCheckResult::kFailL1) res = BaseCheckResult::kFailL2;
+    return res;
+  }
+  // helper functions
+  /*!
+   * \brief Combine check results.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return The check result.
+   */
+  static BaseCheckResult CombineCheck(BaseCheckResult lhs, BaseCheckResult rhs) {
+    if (lhs == BaseCheckResult::kFailL0 || rhs == BaseCheckResult::kFailL0) {
+      return BaseCheckResult::kFailL0;
+    }
+    if (lhs == BaseCheckResult::kFailL1 || rhs == BaseCheckResult::kFailL1) {
+      return BaseCheckResult::kFailL1;
+    }
+    if (lhs == BaseCheckResult::kFailL2 || rhs == BaseCheckResult::kFailL2) {
+      return BaseCheckResult::kFailL2;
+    }
+    return BaseCheckResult::kPass;
+  }
+
+  /*!
+   * \brief Generic helper function to check arrays.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
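+   * \return The combined check result.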
+   */
+  BaseCheckResult ArrayCheck(const Array<StructInfo>& lhs, const Array<StructInfo>& rhs) {
+    if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0;
+    BaseCheckResult ret = BaseCheckResult::kPass;
+
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      auto cmp_ret = this->VisitStructInfo(lhs[i], rhs[i]);
+      if (ret == BaseCheckResult::kFailL0) return ret;
+      ret = CombineCheck(cmp_ret, ret);
+    }
+    return ret;
+  }
+};
+
+BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived,
+                                    arith::Analyzer* ana) {
+  if (ana == nullptr) {
+    arith::Analyzer inst;
+    return StructInfoBaseChecker(&inst)(base, derived);
+  } else {
+    return StructInfoBaseChecker(ana)(base, derived);
+  }
+}
+
+TVM_REGISTER_GLOBAL("relax.analysis.StructInfoBaseCheck")
+    .set_body_typed([](const StructInfo& base, const StructInfo& derived) -> int {
+      return static_cast<int>(StructInfoBaseCheck(base, derived));
+    });
+
+bool IsBaseOf(const StructInfo& base, const StructInfo& derived, arith::Analyzer* ana) {
+  return StructInfoBaseCheck(base, derived, ana) == BaseCheckResult::kPass;
+}
+
+TVM_REGISTER_GLOBAL("relax.StructInfoIsBaseOf")
+    .set_body_typed([](const StructInfo& base, const StructInfo& derived) {
+      return IsBaseOf(base, derived);
+    });
+
+//--------------------------
+// UnifyToLCA
+//--------------------------
+class StructInfoLCAFinder
+    : public StructInfoFunctor<StructInfo(const StructInfo&, const StructInfo&)> {
+ public:
+  explicit StructInfoLCAFinder(arith::Analyzer* ana) : analyzer_(ana) {}
+
+  StructInfo VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final {
+    // quick path
+    if (lhs.same_as(other)) return lhs;
+    return StructInfoFunctor::VisitStructInfo(lhs, other);
+  }
+
+  // Object is base of everything; unify to object.
+  StructInfo VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final {
+    return GetRef<StructInfo>(lhs);
+  }
+
+  StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final {
+    auto* rhs = other.as<PrimStructInfoNode>();
+    if (rhs == nullptr) return ObjectStructInfo(lhs->span);
+    if (lhs->dtype == rhs->dtype) return GetRef<StructInfo>(lhs);
+    // PrimTypes will be treated as their boxed (object) values,
+    // as a result we can unify to object.
+    return ObjectStructInfo(lhs->span);
+  }
+
+  StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final {
+    auto* rhs = other.as<ShapeStructInfoNode>();
+    if (rhs == nullptr) return ObjectStructInfo(lhs->span);
+
+    int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim;
+    if (lhs->ndim != rhs->ndim || !lhs->values.defined() || !rhs->values.defined() ||
+        !CanProveShapeEqual(lhs->values.value(), rhs->values.value(), analyzer_)) {
+      // prefer to return the same object when possible
+      if (!lhs->values.defined() && lhs->ndim == ndim) {
+        return GetRef<StructInfo>(lhs);
+      } else {
+        return ShapeStructInfo(ndim, lhs->span);
+      }
+    }
+    // equal to each other
+    return GetRef<StructInfo>(lhs);
+  }
+
+  StructInfo VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final {
+    auto* rhs = other.as<TensorStructInfoNode>();
+    if (rhs == nullptr) return ObjectStructInfo(lhs->span);
+
+    // find the target dtype and ndim.
+    DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void();
+    int ndim = lhs->ndim == rhs->ndim ?
lhs->ndim : kUnknownNDim; + // if ndim mismatch or one side of shape is missing + // then we cannot keep in symbolic shape + if (lhs->ndim != rhs->ndim || !lhs->shape.defined() || !rhs->shape.defined() || + !CanProveShapeEqual(lhs->shape.value(), rhs->shape.value(), analyzer_)) { + // reuse lhs when possible + if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim) { + return GetRef(lhs); + } else { + return TensorStructInfo(dtype, ndim, lhs->span); + } + } + // symbolic shape match but dtype mismatch + if (lhs->dtype != dtype) { + return TensorStructInfo(lhs->shape.value(), dtype, lhs->span); + } else { + return GetRef(lhs); + } + } + + StructInfo VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + Optional> fields = UnifyArray(lhs->fields, rhs->fields); + // tuple length not the same. + if (!fields.defined()) return ObjectStructInfo(lhs->span); + + // same length tuple. + if (!fields.same_as(lhs->fields)) { + return TupleStructInfo(fields.value(), lhs->span); + } else { + return GetRef(lhs); + } + } + + StructInfo VisitStructInfo_(const FuncStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + + // lhs opaque handling + if (lhs->IsOpaque()) { + if (lhs->derive_func.defined()) { + if (lhs->derive_func.same_as(rhs->derive_func)) { + return GetRef(lhs); + } else { + // Create a new opaque with object return + return FuncStructInfo::OpaqueFunc(ObjectStructInfo(), lhs->span); + } + } else { + // no derivation function, only depends on ret + StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); + if (ret.same_as(lhs->ret)) return GetRef(lhs); + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } + } + // rhs is opaque, lhs is not + if (rhs->IsOpaque()) { + // unify ret value, note that rhs's ret is context free(because it is opaque) + // so result of the unify is also context-free. + StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } + + // Both lhs and rhs are not opaque + // NOTE: lhs->params, rhs->params may contain different symbolic + // vars that needs to be re-mapped to each other. + // This can only be done through structural equality check. + // + // So we check structural equality here and if two are structurally + // equal return true. + // + // otherwise we do best effort of unify types without considering var remap. + // + // This still does not handle cases where some arguments are sub of another + // while other parameters needs to get remapped. + // + // Given we only do best effort checking in these cases, and such cases + // are likely not a primary concern atm, we take this approach here. 
+ if (struct_equal_(GetRef(lhs), GetRef(rhs))) { + return GetRef(lhs); + } + + auto params = UnifyArray(lhs->params.value(), rhs->params.value()); + auto ret = this->VisitStructInfo(lhs->ret, rhs->ret); + + if (params.same_as(lhs->params) && ret.same_as(lhs->ret)) { + return GetRef(lhs); + } else { + // fail to unify the params + if (!params.defined()) { + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } else { + return FuncStructInfo(params.value(), ret, lhs->span); + } + } + } + + private: + // analyzer + arith::Analyzer* analyzer_; + // struct equal checker + StructuralEqual struct_equal_; + + // check arrays + Optional> UnifyArray(const Array& lhs, + const Array& rhs) { + if (lhs.same_as(rhs)) return lhs; + if (lhs.size() != rhs.size()) return NullOpt; + size_t index = 0; + return lhs.Map([&](const StructInfo& a) { return this->VisitStructInfo(a, rhs[index++]); }); + } +}; + +StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return StructInfoLCAFinder(&inst)(lhs, rhs); + } else { + return StructInfoLCAFinder(ana)(lhs, rhs); + } +} + +TVM_REGISTER_GLOBAL("relax.analysis.StructInfoLCA") + .set_body_typed([](const StructInfo& lhs, const StructInfo& rhs) { + return StructInfoLCA(lhs, rhs); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc new file mode 100644 index 000000000000..45868a488a36 --- /dev/null +++ b/src/relax/ir/expr.cc @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using tvm::ReprPrinter; +using tvm::runtime::Optional; + +TVM_REGISTER_NODE_TYPE(IdNode); + +Id::Id(String name_hint) { + ObjectPtr n = make_object(); + n->name_hint = std::move(name_hint); + data_ = std::move(n); +} + +Call::Call(Expr op, Array args, Attrs attrs, Array sinfo_args, Span span) { + ObjectPtr n = make_object(); + n->op = std::move(op); + n->args = std::move(args); + n->attrs = std::move(attrs); + n->sinfo_args = std::move(sinfo_args); + n->span = std::move(span); + data_ = std::move(n); +} + +Call WithFields(Call call, Optional opt_op, Optional> opt_args, + Optional opt_attrs, Optional> opt_sinfo_args, + Optional opt_span) { + // Collect new values for fields. + Expr op = opt_op.value_or(call->op); + Array args = opt_args.value_or(call->args); + Attrs attrs = opt_attrs.value_or(call->attrs); + Array sinfo_args = opt_sinfo_args.value_or(call->sinfo_args); + Span span = opt_span.value_or(call->span); + + // Check if anything changed. 
+  bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span);
+  if (unchanged) {
+    if (args.size() == call->args.size()) {
+      for (size_t i = 0; i < args.size(); i++) {
+        unchanged &= args[i].same_as(call->args[i]);
+      }
+    } else {
+      unchanged = false;
+    }
+  }
+  if (unchanged) {
+    if (sinfo_args.size() == call->sinfo_args.size()) {
+      for (size_t i = 0; i < sinfo_args.size(); i++) {
+        unchanged &= sinfo_args[i].same_as(call->sinfo_args[i]);
+      }
+    } else {
+      unchanged = false;
+    }
+  }
+
+  if (!unchanged) {
+    // If this is the only reference to the call, update it in place;
+    // otherwise copy and update.
+    CallNode* cow_call_node = call.CopyOnWrite();
+    cow_call_node->op = op;
+    cow_call_node->args = args;
+    cow_call_node->attrs = attrs;
+    cow_call_node->sinfo_args = sinfo_args;
+    cow_call_node->span = span;
+  }
+  return call;
+}
+
+TVM_REGISTER_NODE_TYPE(CallNode);
+
+TVM_REGISTER_GLOBAL("relax.Call")
+    .set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<StructInfo> sinfo_args,
+                       Span span) { return Call(op, args, attrs, sinfo_args, span); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const CallNode*>(ref.get());
+      p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", "
+                << node->sinfo_args << ")";
+    });
+
+If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) {
+  ObjectPtr<IfNode> n = make_object<IfNode>();
+  n->cond = std::move(cond);
+  n->true_branch = std::move(true_branch);
+  n->false_branch = std::move(false_branch);
+  n->span = std::move(span);
+  data_ = std::move(n);
+}
+
+If WithFields(If if_expr, Optional<Expr> opt_cond, Optional<Expr> opt_true_branch,
+              Optional<Expr> opt_false_branch, Optional<Span> opt_span) {
+  Expr cond = opt_cond.value_or(if_expr->cond);
+  Expr true_branch = opt_true_branch.value_or(if_expr->true_branch);
+  Expr false_branch = opt_false_branch.value_or(if_expr->false_branch);
+  Span span = opt_span.value_or(if_expr->span);
+
+  bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) &&
+                   false_branch.same_as(if_expr->false_branch) && span.same_as(if_expr->span);
+
+  if (!unchanged) {
+    IfNode* cow_if_node = if_expr.CopyOnWrite();
+    cow_if_node->cond = cond;
+    cow_if_node->true_branch = true_branch;
+    cow_if_node->false_branch = false_branch;
+    cow_if_node->span = span;
+  }
+  return if_expr;
+}
+
+TVM_REGISTER_NODE_TYPE(IfNode);
+
+TVM_REGISTER_GLOBAL("relax.If")
+    .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) {
+      return If(cond, true_branch, false_branch, span);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const IfNode*>(ref.get());
+      p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", "
+                << node->false_branch << ")";
+    });
+
+Tuple::Tuple(tvm::Array<Expr> fields, Span span) {
+  ObjectPtr<TupleNode> n = make_object<TupleNode>();
+  n->fields = std::move(fields);
+  n->span = std::move(span);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TupleNode);
+
+TVM_REGISTER_GLOBAL("relax.Tuple").set_body_typed([](tvm::Array<Expr> fields, Span span) {
+  return Tuple(fields, span);
+});
+
+Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields, Optional<Span> opt_span) {
+  Array<Expr> fields = opt_fields.value_or(tuple->fields);
+  Span span = opt_span.value_or(tuple->span);
+
+  bool all_fields_unchanged = true;
+  if (fields.size() == tuple->fields.size()) {
+    for (size_t i = 0; i < fields.size(); i++) {
+      all_fields_unchanged &= fields[i].same_as(tuple->fields[i]);
+ } + } else { + all_fields_unchanged = false; + } + + all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span); + if (!all_fields_unchanged) { + TupleNode* cow_tuple_node = tuple.CopyOnWrite(); + cow_tuple_node->fields = fields; + cow_tuple_node->span = span; + } + return tuple; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Tuple(" << node->fields << ")"; + }); + +TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { + ObjectPtr n = make_object(); + n->tuple = std::move(tuple); + n->index = index; + n->span = std::move(span); + data_ = std::move(n); +} + +TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, + Optional opt_index, Optional opt_span) { + Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); + Integer index = opt_index.value_or(tuple_get_item->index); + Span span = opt_span.value_or(tuple_get_item->span); + + bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && + span.same_as(tuple_get_item->span); + if (!unchanged) { + TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); + cow_tuple_get_item_node->tuple = tuple; + cow_tuple_get_item_node->index = index.IntValue(); + cow_tuple_get_item_node->span = span; + } + return tuple_get_item; +} + +TVM_REGISTER_NODE_TYPE(TupleGetItemNode); + +TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int index) { + return TupleGetItem(tuple, index); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; + }); + +TVM_REGISTER_NODE_TYPE(ShapeExprNode); + +ShapeExpr::ShapeExpr(Array values, Span span) { + ObjectPtr n = make_object(); + + n->values = values.Map([](PrimExpr value) { + if (value->IsInstance()) { + return tvm::cast(DataType::Int(64), value); + } + ICHECK(value.dtype() == DataType::Int(64)) + << "the value in ShapeStructInfo can only have dtype of int64"; + return value; + }); + n->span = span; + n->checked_type_ = ShapeType(values.size()); + n->struct_info_ = ShapeStructInfo(values, span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, Span span) { + return ShapeExpr(values, span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const ShapeExprNode* node = static_cast(ref.get()); + p->stream << "ShapeExpr("; + for (auto it = node->values.begin(); it != node->values.end(); it++) { + if (it != node->values.begin()) { + p->stream << ", "; + } + p->stream << *it; + } + p->stream << ")"; + }); + +TVM_REGISTER_NODE_TYPE(VarNode); + +Var::Var(Id vid, Optional struct_info_annotation, Span span) { + ObjectPtr n = make_object(); + n->vid = std::move(vid); + if (struct_info_annotation) { + n->checked_type_ = GetStaticType(struct_info_annotation.value()); + } + n->struct_info_ = std::move(struct_info_annotation); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.Var") + .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { + return Var(name_hint, struct_info_annotation, span); + }); + +TVM_REGISTER_GLOBAL("relax.VarFromId") + .set_body_typed([](Id vid, Optional struct_info_annotation, Span span) { + return Var(vid, struct_info_annotation, span); + }); + 
+TVM_REGISTER_NODE_TYPE(DataflowVarNode);
+
+DataflowVar::DataflowVar(Id vid, Optional<StructInfo> struct_info_annotation, Span span) {
+  ObjectPtr<DataflowVarNode> n = make_object<DataflowVarNode>();
+  n->vid = std::move(vid);
+  if (struct_info_annotation) {
+    n->checked_type_ = GetStaticType(struct_info_annotation.value());
+  }
+  n->struct_info_ = std::move(struct_info_annotation);
+  n->span = std::move(span);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("relax.DataflowVar")
+    .set_body_typed([](String name_hint, Optional<StructInfo> struct_info_annotation, Span span) {
+      return DataflowVar(name_hint, struct_info_annotation, span);
+    });
+
+TVM_REGISTER_GLOBAL("relax.DataflowVarFromId")
+    .set_body_typed([](Id vid, Optional<StructInfo> struct_info_annotation, Span span) {
+      return DataflowVar(vid, struct_info_annotation, span);
+    });
+
+Constant::Constant(runtime::NDArray data, Span span) {
+  ObjectPtr<ConstantNode> n = make_object<ConstantNode>();
+  n->data = std::move(data);
+  n->span = std::move(span);
+
+  // set struct info.
+  Array<PrimExpr> values;
+  auto shape_tuple = n->data.Shape();
+  for (size_t dim = 0; dim < shape_tuple.size(); ++dim) {
+    values.push_back(IntImm(DataType::Int(64), shape_tuple[dim]));
+  }
+  TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), span);
+
+  n->struct_info_ = tinfo;
+  n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(ConstantNode);
+
+TVM_REGISTER_GLOBAL("relax.Constant").set_body_typed([](runtime::NDArray data, Span span = Span()) {
+  return Constant(data, span);
+});
+
+PrimValue::PrimValue(PrimExpr value, Span span) {
+  ObjectPtr<PrimValueNode> n = make_object<PrimValueNode>();
+  n->checked_type_ = PrimType(value.dtype());
+  n->struct_info_ = PrimStructInfo(value.dtype());
+  n->value = std::move(value);
+  n->span = std::move(span);
+  data_ = std::move(n);
+}
+
+PrimValue PrimValue::Int64(int64_t value, Span span) {
+  return PrimValue(IntImm(DataType::Int(64), value), span);
+}
+
+TVM_REGISTER_NODE_TYPE(PrimValueNode);
+
+TVM_REGISTER_GLOBAL("relax.PrimValue").set_body_typed([](PrimExpr value, Span span) {
+  return PrimValue(value, span);
+});
+
+StringImm::StringImm(String value, Span span) {
+  ObjectPtr<StringImmNode> n = make_object<StringImmNode>();
+  n->value = std::move(value);
+  n->span = std::move(span);
+  // use the base struct info for now
+  // we can choose to introduce more fine-grained struct info later if necessary.
+  n->checked_type_ = ObjectType();
+  n->struct_info_ = ObjectStructInfo();
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(StringImmNode);
+
+TVM_REGISTER_GLOBAL("relax.StringImm").set_body_typed([](String value, Span span) {
+  return StringImm(value, span);
+});
+
+DataTypeImm::DataTypeImm(DataType value, Span span) {
+  ObjectPtr<DataTypeImmNode> n = make_object<DataTypeImmNode>();
+  n->value = std::move(value);
+  n->span = std::move(span);
+  // use the base struct info for now
+  // we can choose to introduce more fine-grained struct info later if necessary.
+ n->checked_type_ = ObjectType(); + n->struct_info_ = ObjectStructInfo(); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(DataTypeImmNode); + +TVM_REGISTER_GLOBAL("relax.DataTypeImm").set_body_typed([](DataType value, Span span) { + return DataTypeImm(value, span); +}); + +TVM_REGISTER_NODE_TYPE(MatchCastNode); + +MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { + ObjectPtr n = make_object(); + ICHECK(var.defined()) << "MatchCast requires var to be defined"; + n->var = std::move(var); + n->value = std::move(value); + n->struct_info = std::move(struct_info); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.MatchCast") + .set_body_typed([](Var var, Expr value, StructInfo struct_info, Span span) { + return MatchCast(var, value, struct_info, span); + }); + +TVM_REGISTER_NODE_TYPE(VarBindingNode); + +VarBinding::VarBinding(Var var, Expr value, Span span) { + ObjectPtr n = make_object(); + n->var = std::move(var); + n->value = std::move(value); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, Span span) { + return VarBinding(var, value, span); +}); + +TVM_REGISTER_NODE_TYPE(BindingBlockNode); + +BindingBlock::BindingBlock(Array bindings, Span span) { + ObjectPtr n = make_object(); + n->bindings = std::move(bindings); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array bindings, Span span) { + return BindingBlock(bindings, span); +}); + +TVM_REGISTER_NODE_TYPE(DataflowBlockNode); + +DataflowBlock::DataflowBlock(Array bindings, Span span) { + ObjectPtr n = make_object(); + n->bindings = std::move(bindings); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array bindings, Span span) { + return DataflowBlock(bindings, span); +}); + +TVM_REGISTER_NODE_TYPE(SeqExprNode); + +SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { + ObjectPtr n = make_object(); + n->blocks = std::move(blocks); + n->body = std::move(body); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.SeqExpr") + .set_body_typed([](Array blocks, Expr body, Span span) { + return SeqExpr(blocks, body, span); + }); + +TVM_REGISTER_NODE_TYPE(FunctionNode); + +Function::Function(Array params, Expr body, Optional ret_struct_info, + DictAttrs attrs, Span span) { + // Set the function type. + // For function, we take a conservative approach and require the function type + // to be known at construction time. + Array param_sinfo; + + for (const Var& param : params) { + CHECK(param->struct_info_.defined()) + << "relax.Function requires params to contain struct_info_"; + param_sinfo.push_back(GetStructInfo(param)); + } + + Optional body_sinfo; + + if (body->struct_info_.defined()) { + body_sinfo = GetStructInfo(body); + } + + if (ret_struct_info.defined()) { + // allow body to override ret if body is more fine-grained. 
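+    // i.e. when the annotated return is a base of the body's struct info,
+    // adopt the more precise struct info derived from the body.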
+    if (body_sinfo.defined()) {
+      if (IsBaseOf(ret_struct_info.value(), body_sinfo.value())) {
+        ret_struct_info = body_sinfo;
+      }
+    }
+  } else {
+    CHECK(body_sinfo.defined())
+        << "Function does not have a return signature and body is not normalized";
+    ret_struct_info = body_sinfo;
+  }
+
+  FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value());
+
+  // set the fields
+  ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
+  n->params = std::move(params);
+  n->body = std::move(body);
+  n->ret_struct_info = std::move(ret_struct_info.value());
+  n->checked_type_ = GetStaticType(func_sinfo);
+  n->struct_info_ = std::move(func_sinfo);
+  n->attrs = std::move(attrs);
+  n->span = std::move(span);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("relax.Function")
+    .set_body_typed([](Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
+                       DictAttrs attrs,
+                       Span span) { return Function(params, body, ret_struct_info, attrs, span); });
+
+Function Function::CreateEmpty(Array<Var> params, StructInfo ret_struct_info, DictAttrs attrs,
+                               Span span) {
+  Array<StructInfo> param_sinfo;
+  for (const Var& param : params) {
+    ICHECK(param->checked_type_.defined())
+        << "relax.Function requires params to contain checked_type_.";
+    param_sinfo.push_back(GetStructInfo(param));
+  }
+  FuncStructInfo finfo(param_sinfo, ret_struct_info);
+
+  // set the fields
+  ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
+  n->params = std::move(params);
+  n->body = Expr();
+  n->checked_type_ = GetStaticType(finfo);
+  n->struct_info_ = std::move(finfo);
+  n->ret_struct_info = std::move(ret_struct_info);
+  n->attrs = std::move(attrs);
+  n->span = std::move(span);
+  return Function(std::move(n));
+}
+
+TVM_REGISTER_GLOBAL("relax.FunctionCreateEmpty")
+    .set_body_typed([](Array<Var> params, StructInfo ret_struct_info, DictAttrs attrs, Span span) {
+      return Function::CreateEmpty(params, ret_struct_info, attrs, span);
+    });
+
+// Special opaque derivation function for ExternFunc
+// Take a look at sinfo_args to figure out the return StructInfo.
+TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_by_sinfo_args")
+    .set_body_typed([](const Call& call, const BlockBuilder& ctx) -> StructInfo {
+      ICHECK(call->sinfo_args.defined()) << "sinfo_args field of CallNode should always be defined";
+      if (call->sinfo_args.empty()) {
+        return ObjectStructInfo();
+      } else if (call->sinfo_args.size() == 1) {
+        return call->sinfo_args[0];
+      } else {
+        return TupleStructInfo(call->sinfo_args);
+      }
+    });
+
+// Get the derive function.
+FuncStructInfo GetExternFuncStructInfo() {
+  EnvFunc fn = EnvFunc::Get("tvm.relax.struct_info.infer_by_sinfo_args");
+  StructInfoDeriveFunc derive;
+  derive = fn;
+  return FuncStructInfo::OpaqueFunc(derive);
+}
+
+TVM_REGISTER_NODE_TYPE(ExternFuncNode);
+
+ExternFunc::ExternFunc(String global_symbol, Span span) {
+  ObjectPtr<ExternFuncNode> n = make_object<ExternFuncNode>();
+  n->global_symbol = std::move(global_symbol);
+  n->span = span;
+  static auto sinfo = GetExternFuncStructInfo();
+  n->struct_info_ = sinfo;
+  n->checked_type_ = GetStaticType(sinfo);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, Span span) {
+  return ExternFunc(global_symbol, span);
+});
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ExternFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      const auto* node = static_cast<const ExternFuncNode*>(ref.get());
+      p->stream << "ExternFunc(\"" << node->global_symbol << "\")";
+    });
+
+Expr GetShapeOf(const Expr& expr) {
+  // default case, to be normalized.
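+  // Fast path: if the tensor's struct info already records a shape, return it
+  // directly; otherwise fall back to emitting a relax.shape_of call below.
+  // For example, a Var annotated with TensorStructInfo([m, n], "float32")
+  // yields its ShapeExpr (m, n) without inserting any call node.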
+  ICHECK(expr->struct_info_.defined()) << "GetShapeOf can only be applied to normalized expr";
+  auto* tinfo = GetStructInfoAs<TensorStructInfoNode>(expr);
+
+  ICHECK(tinfo != nullptr) << "ShapeOf can only be applied to expr with TensorStructInfo";
+  if (tinfo->shape.defined()) return tinfo->shape.value();
+
+  static const Op& op = Op::Get("relax.shape_of");
+  // default case, call shape of, eagerly normalize the expr.
+  relax::Call call_shape_of(op, {expr}, {}, {});
+  UpdateStructInfo(call_shape_of, ShapeStructInfo(tinfo->ndim));
+  return call_shape_of;
+}
+
+TVM_REGISTER_GLOBAL("relax.GetShapeOf").set_body_typed([](const Expr& expr) {
+  return GetShapeOf(expr);
+});
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
new file mode 100644
index 000000000000..048de7950f97
--- /dev/null
+++ b/src/relax/ir/expr_functor.cc
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/ir/expr_functor.cc
+ * \brief A wrapper around ExprFunctor which functionally updates the AST.
+ *
+ * ExprMutator uses memoization and self return in order to amortize
+ * the cost of using functional updates.
+ */
+#include
+#include
+#include
+#include
+
+// functions to be overridden.
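+// The macro-generated vtable below dispatches VisitBinding_ on the node type
+// of the bound value: VisitBinding_(VarBindingNode*) looks up the type key of
+// binding->value and forwards to the matching two-argument overload, so a
+// subclass can override the handler for a single value kind (e.g. CallNode
+// bindings) without re-implementing the dispatch logic.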
+#define RELAX_VISIT_BINDING_DISPATCH(OP)                                    \
+  vtable.template set_dispatch<OP>(                                         \
+      [](const ObjectRef& n, TSelf* self, const VarBindingNode* binding) {  \
+        self->VisitBinding_(binding, static_cast<const OP*>(n.get()));      \
+      });
+
+#define RELAX_VAR_BINDING_DISPATCH_IMPL(Type)                                    \
+  Type::VisitBindingVTable Type::InitVisitBindingVTable() {                      \
+    VisitBindingVTable vtable;                                                   \
+    RELAX_VISIT_BINDING_DISPATCH(ConstantNode);                                  \
+    RELAX_VISIT_BINDING_DISPATCH(TupleNode);                                     \
+    RELAX_VISIT_BINDING_DISPATCH(VarNode);                                       \
+    RELAX_VISIT_BINDING_DISPATCH(DataflowVarNode);                               \
+    RELAX_VISIT_BINDING_DISPATCH(ShapeExprNode);                                 \
+    RELAX_VISIT_BINDING_DISPATCH(ExternFuncNode);                                \
+    RELAX_VISIT_BINDING_DISPATCH(GlobalVarNode);                                 \
+    RELAX_VISIT_BINDING_DISPATCH(FunctionNode);                                  \
+    RELAX_VISIT_BINDING_DISPATCH(CallNode);                                      \
+    RELAX_VISIT_BINDING_DISPATCH(SeqExprNode);                                   \
+    RELAX_VISIT_BINDING_DISPATCH(IfNode);                                        \
+    RELAX_VISIT_BINDING_DISPATCH(OpNode);                                        \
+    RELAX_VISIT_BINDING_DISPATCH(TupleGetItemNode);                              \
+    RELAX_VISIT_BINDING_DISPATCH(PrimValueNode);                                 \
+    RELAX_VISIT_BINDING_DISPATCH(StringImmNode);                                 \
+    RELAX_VISIT_BINDING_DISPATCH(DataTypeImmNode);                               \
+    return vtable;                                                               \
+  }                                                                              \
+  void Type::VisitBinding_(const VarBindingNode* binding) {                      \
+    static VisitBindingVTable vtable = InitVisitBindingVTable();                 \
+    const Expr& value = binding->value;                                          \
+    ICHECK(value.defined()) << "Found null pointer node while traversing AST."; \
+    ICHECK(vtable.can_dispatch(value))                                           \
+        << "VisitVarBinding does not allow binding value of type "               \
+        << value->GetTypeKey();                                                  \
+    vtable(value, this, binding);                                                \
+  }
+
+// functions to be overridden.
+#define RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(OP)                                   \
+  void ExprVisitor::VisitBinding_(const VarBindingNode* binding, const OP* value) { \
+    this->VisitExpr(binding->value);                                               \
+    this->VisitVarDef(binding->var);                                               \
+  }
+
+// functions to be overridden.
+#define RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(OP)                                   \
+  void ExprMutator::VisitBinding_(const VarBindingNode* binding, const OP* value) { \
+    Expr new_value = this->VisitExpr(binding->value);                              \
+    this->ReEmitBinding(binding, new_value);                                       \
+  }
+
+namespace tvm {
+namespace relax {
+
+// ==================
+// ExprVisitor
+
+void ExprVisitor::VisitExprDepStructInfoField(const StructInfo& struct_info) {
+  // recurse into struct info in case they depend on value
+  // under the current scope.
+  default_struct_info_field_visitor_.VisitStructInfo(struct_info);
+}
+
+ExprVisitor::DefaultStructInfoFieldVisitor::DefaultStructInfoFieldVisitor(ExprVisitor* parent)
+    : parent_(parent) {}
+
+void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const Expr& expr) {
+  parent_->VisitExpr(expr);
+}
+
+void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const PrimExpr& expr) {
+  parent_->VisitPrimExpr(expr);
+}
+
+void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfo_(const FuncStructInfoNode* op) {
+  // Do not recurse into function struct info
+  // as they won't contain ref to values in current scope.
+}
+
+void ExprVisitor::VisitExpr(const Expr& expr) { ExprFunctor::VisitExpr(expr); }
+
+void ExprVisitor::VisitExpr_(const ConstantNode* op) {
+  this->VisitSpan(op->span);
+  // Constant's StructInfo does not depend on Expr.
+} + +void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { + this->VisitSpan(op->span); + // FuncStructInfo is not value-dep +} + +void ExprVisitor::VisitExpr_(const TupleNode* op) { + this->VisitSpan(op->span); + for (Expr field : op->fields) { + this->VisitExpr(field); + } + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +// Visit the use-site of a defined Var +void ExprVisitor::VisitExpr_(const VarNode* op) { + this->VisitSpan(op->span); + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +// Visit the use-site of a defined DataflowVar +void ExprVisitor::VisitExpr_(const DataflowVarNode* op) { + this->VisitSpan(op->span); + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const FunctionNode* op) { + this->VisitSpan(op->span); + for (Var param : op->params) { + this->VisitVarDef(param); + } + + this->VisitExpr(op->body); + // FuncStructInfo does not depend on Expr. +} + +void ExprVisitor::VisitExpr_(const CallNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->op); + + for (StructInfo sinfo_arg : op->sinfo_args) { + this->VisitExprDepStructInfoField(sinfo_arg); + } + + for (Expr arg : op->args) { + this->VisitExpr(arg); + } + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const IfNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->cond); + this->VisitExpr(op->true_branch); + this->VisitExpr(op->false_branch); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const OpNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->tuple); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { + for (PrimExpr val : op->values) { + this->VisitPrimExpr(val); + } + this->VisitSpan(op->span); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { + this->VisitSpan(op->span); + // FuncStructInfo does not depend on Expr. 
+} + +void ExprVisitor::VisitExpr_(const SeqExprNode* op) { + this->VisitSpan(op->span); + for (BindingBlock block : op->blocks) { + this->VisitBindingBlock(block); + } + this->VisitExpr(op->body); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const PrimValueNode* op) { + this->VisitPrimExpr(op->value); + this->VisitSpan(op->span); +} + +void ExprVisitor::VisitExpr_(const StringImmNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitExpr_(const DataTypeImmNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitSpan(const Span& span) {} + +void ExprVisitor::VisitPrimExpr(const PrimExpr& expr) {} + +// implementations of binding visitor dispatch +RELAX_VAR_BINDING_DISPATCH_IMPL(ExprVisitor); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ConstantNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(TupleNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(VarNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataflowVarNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ShapeExprNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ExternFuncNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(GlobalVarNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(FunctionNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(CallNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(SeqExprNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(IfNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(OpNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(TupleGetItemNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(PrimValueNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(StringImmNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataTypeImmNode); + +void ExprVisitor::VisitBinding_(const MatchCastNode* binding) { + this->VisitExpr(binding->value); + this->VisitVarDef(binding->var); +} + +void ExprVisitor::VisitBindingBlock_(const BindingBlockNode* block) { + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } +} + +void ExprVisitor::VisitBindingBlock_(const DataflowBlockNode* block) { + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } +} + +void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) { this->VisitSpan(var->span); } + +void ExprVisitor::VisitVarDef_(const VarNode* var) { this->VisitSpan(var->span); } + +void ExprVisitor::VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } +} + +void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { + if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } +} + +void ExprVisitor::VisitVarDef(const Var& var) { + if (const auto* node = var.as()) { + VisitVarDef_(node); + } else if (const auto* node = var.as()) { + VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } +} + +class ExprApplyVisit : public ExprVisitor { + public: + explicit ExprApplyVisit(std::function f) : f_(f) {} + + void VisitExpr(const Expr& e) final { + ExprVisitor::VisitExpr(e); + f_(e); + } + + private: + std::function f_; +}; + +void PostOrderVisit(const Expr& e, std::function fvisit) { + ExprApplyVisit(fvisit).VisitExpr(e); +} + 
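+// Usage sketch (illustrative only, not part of the API added here): collect
+// every Var reachable from an expression in post-order.
+//
+//   Array<Var> vars;
+//   PostOrderVisit(expr, [&](const Expr& e) {
+//     if (const auto* v = e.as<VarNode>()) vars.push_back(GetRef<Var>(v));
+//   });
+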
+TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) { + PostOrderVisit(expr, [f](const Expr& n) { f(n); }); +}); + +// ================== +// ExprMutatorBase + +StructInfo ExprMutatorBase::VisitExprDepStructInfoField(const StructInfo& struct_info) { + // recurse into struct info in case they depend on value + // under the current scope. + return default_struct_info_field_mutator_.VisitStructInfo(struct_info); +} + +ExprMutatorBase::DefaultStructInfoFieldMutator::DefaultStructInfoFieldMutator( + ExprMutatorBase* parent) + : parent_(parent) {} + +Expr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField(const Expr& expr) { + return parent_->VisitExpr(expr); +} + +PrimExpr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField( + const PrimExpr& expr) { + return parent_->VisitPrimExpr(expr); +} + +StructInfo ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfo_( + const FuncStructInfoNode* op) { + // Do not recurse into function struct info + // as they won't contain ref to values in current scope. + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } + +Expr ExprMutatorBase::VisitExpr_(const ConstantNode* op) { + // Constant' struct info won't be affected by Expr/PrimExpr change. + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const GlobalVarNode* op) { + // FuncStructInfo won't be affected by Expr/PrimExpr change. + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { + bool unchanged = true; + tvm::Array fields; + for (Expr field : op->fields) { + Expr new_field = this->VisitExpr(field); + fields.push_back(new_field); + unchanged &= new_field.same_as(field); + } + + if (unchanged) { + // If tuple's struct info change it means that + // one of its fields' struct info will change + // so un-changed already implies that struct info won't change + return GetRef(op); + } else { + // when there is a change return a new tuple node + return Tuple(fields, op->span); + } +} + +// Visit the use-site of a defined Var +Expr ExprMutatorBase::VisitExpr_(const VarNode* op) { + // struct info of var-use should remain stable + // or the var itself will get replaced + return GetRef(op); +} + +// Visit the use-site of a defined DataflowVar +Expr ExprMutatorBase::VisitExpr_(const DataflowVarNode* op) { + // struct info of var-use should remain stable + // or the var itself will get replaced + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) { + // struct info of function is not value dependent + // so no need to check struct_info field + Expr body = this->VisitExpr(op->body); + + if (body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->params, body, op->ret_struct_info, op->attrs); + } +} + +Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { + Expr new_op = this->VisitExpr(call_node->op); + bool unchanged = call_node->op.same_as(new_op); + + Array sinfo_args; + for (StructInfo sinfo_arg : call_node->sinfo_args) { + StructInfo new_sinfo_arg = this->VisitExprDepStructInfoField(sinfo_arg); + sinfo_args.push_back(new_sinfo_arg); + unchanged &= new_sinfo_arg.same_as(sinfo_arg); + } + + tvm::Array call_args; + for (Expr arg : call_node->args) { + Expr new_arg = this->VisitExpr(arg); + call_args.push_back(new_arg); + unchanged &= new_arg.same_as(arg); + } + + if (unchanged && VisitAndCheckStructInfoFieldUnchanged(call_node->struct_info_)) { + return 
GetRef(call_node); + } else { + return Call(new_op, call_args, call_node->attrs, sinfo_args, call_node->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const IfNode* op) { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitExpr(op->true_branch); + Expr false_b = this->VisitExpr(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const OpNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { + auto t = this->VisitExpr(op->tuple); + if (op->tuple.same_as(t)) { + // struct info can be deterministically derived by tuple and index + // if t does not change, then struct info won't change. + return GetRef(op); + } else { + return TupleGetItem(t, op->index, op->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const PrimValueNode* op) { + auto value = this->VisitPrimExpr(op->value); + if (op->value.same_as(value)) { + // struct info can be deterministically derived by value + // if value does not change, then struct info won't change. + return GetRef(op); + } + return PrimValue(value, op->span); +} + +Expr ExprMutatorBase::VisitExpr_(const StringImmNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const DataTypeImmNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { + auto values = op->values.Map([this](const PrimExpr& e) { return this->VisitPrimExpr(e); }); + + if (values.same_as(op->values)) { + // If values does not change, struct info won't change. + return GetRef(op); + } else { + return ShapeExpr(values, op->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const ExternFuncNode* op) { + // StructInfo of function remains value independent. 
+ return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + all_blocks_unchanged &= block.same_as(new_block); + } + + Expr body = this->VisitExpr(op->body); + + if (all_blocks_unchanged && body.same_as(op->body) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } + return SeqExpr(blocks, body); +} + +BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { + Array bindings; + if (const auto* node = block.as()) { + for (auto binding : node->bindings) { + if (auto var_binding = binding.as()) { + Expr new_value = this->VisitExpr(var_binding->value); + bindings.push_back(VarBinding(var_binding->var, new_value)); + } else if (auto match_cast = binding.as()) { + Expr new_value = this->VisitExpr(match_cast->value); + bindings.push_back(MatchCast(match_cast->var, new_value, match_cast->struct_info)); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + } + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + + if (block.as()) { + return DataflowBlock(bindings); + } else { + return BindingBlock(bindings); + } +} + +PrimExpr ExprMutatorBase::VisitPrimExpr(const PrimExpr& expr) { return expr; } + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 88046ed81f10..9db7cea6725d 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -21,7 +21,9 @@ * \file src/relax/ir/struct_info.cc * \brief Relax struct info. */ +#include #include +#include #include namespace tvm { @@ -228,7 +230,17 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Helper functions -// TODO(unity-team): add UpdateStructInfo once analysis.cc is upstreamed +void UpdateStructInfo(Expr expr, StructInfo struct_info) { + ICHECK(!expr->struct_info_.defined()) + << "the struct_info_ of the Expr to be updated must be nullptr for idempotency"; + expr->struct_info_ = struct_info; + // also set checked type + expr->checked_type_ = GetStaticType(struct_info); +} + +TVM_REGISTER_GLOBAL("relax.UpdateStructInfo").set_body_typed([](Expr expr, StructInfo struct_info) { + UpdateStructInfo(expr, struct_info); +}); TVM_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { return GetStructInfo(expr); diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc new file mode 100644 index 000000000000..199491e3c63f --- /dev/null +++ b/src/relax/ir/struct_info_functor.cc @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file struct_info_functor.cc + * \brief Implementations of struct info functors. + */ +#include + +namespace tvm { +namespace relax { + +void StructInfoVisitor::VisitStructInfo_(const ObjectStructInfoNode* op) {} + +void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) {} + +void StructInfoVisitor::VisitStructInfo_(const ShapeStructInfoNode* op) { + if (op->values.defined()) { + for (PrimExpr value : op->values.value()) { + this->VisitStructInfoExprField(value); + } + } +} + +void StructInfoVisitor::VisitStructInfo_(const TensorStructInfoNode* op) { + if (op->shape.defined()) { + this->VisitStructInfoExprField(op->shape.value()); + } +} + +void StructInfoVisitor::VisitStructInfo_(const TupleStructInfoNode* op) { + for (StructInfo field : op->fields) { + this->VisitStructInfo(field); + } +} + +void StructInfoVisitor::VisitStructInfo_(const FuncStructInfoNode* op) { + if (op->params.defined()) { + for (StructInfo param : op->params.value()) { + this->VisitStructInfo(param); + } + } + this->VisitStructInfo(op->ret); +} + +StructInfo StructInfoMutator::VisitStructInfo_(const ObjectStructInfoNode* op) { + return GetRef(op); +} + +StructInfo StructInfoMutator::VisitStructInfo_(const PrimStructInfoNode* op) { + return GetRef(op); +} + +StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) { + Optional> values; + + if (op->values.defined()) { + // if no changes are made the original array will be returned. + values = op->values.value().Map( + [this](const PrimExpr& expr) { return this->VisitStructInfoExprField(expr); }); + } + + if (values.same_as(op->values)) { + return GetRef(op); + } else { + return ShapeStructInfo(values.value(), op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) { + Optional shape; + + if (op->shape.defined()) { + shape = this->VisitStructInfoExprField(op->shape.value()); + } + + if (shape.same_as(op->shape)) { + return GetRef(op); + } else { + return TensorStructInfo(shape.value(), op->dtype, op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const TupleStructInfoNode* op) { + Array fields = + op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + + if (fields.same_as(op->fields)) { + return GetRef(op); + } else { + return TupleStructInfo(fields, op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const FuncStructInfoNode* op) { + Optional> params; + + if (op->params.defined()) { + params = op->params.value().Map( + [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + } + + StructInfo ret = this->VisitStructInfo(op->ret); + + if (params.same_as(op->params) && ret.same_as(op->ret)) { + return GetRef(op); + } else { + ICHECK(ret.defined()) << "FuncStructInfo that contains params must contain ret"; + return FuncStructInfo(params.value(), ret, op->span); + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc new file mode 100644 index 000000000000..49ef1d7163f1 --- /dev/null +++ b/src/relax/ir/type.cc @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/type.cc + * \brief Relax type system. + */ +#include +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(ShapeTypeNode); + +ShapeType::ShapeType(int ndim, Span span) { + ObjectPtr n = make_object(); + n->ndim = ndim; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](int ndim, Span span) { + return ShapeType(ndim, span); +}); + +ObjectType::ObjectType(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ObjectTypeNode); + +TVM_REGISTER_GLOBAL("relax.ObjectType").set_body_typed([](Span span) { return ObjectType(span); }); + +DynTensorType::DynTensorType(int ndim, DataType dtype, Span span) { + ObjectPtr n = make_object(); + n->ndim = std::move(ndim); + n->dtype = std::move(dtype); + n->span = span; + data_ = std::move(n); +} + +DynTensorType DynTensorType::CreateUnknownNDim(DataType dtype, Span span) { + ObjectPtr n = make_object(); + n->ndim = -1; + n->dtype = std::move(dtype); + n->span = std::move(span); + return DynTensorType(std::move(n)); +} + +TVM_REGISTER_NODE_TYPE(DynTensorTypeNode); + +TVM_REGISTER_GLOBAL("relax.DynTensorType").set_body_typed([](int ndim, DataType dtype, Span span) { + return DynTensorType(ndim, dtype, span); +}); + +PackedFuncType::PackedFuncType(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PackedFuncTypeNode); + +TVM_REGISTER_GLOBAL("relax.PackedFuncType").set_body_typed([](Span span) { + return PackedFuncType(span); +}); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py new file mode 100644 index 000000000000..faf8fedcf4bf --- /dev/null +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -0,0 +1,418 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Tests analysis functions of struct info""" + +import pytest +import tvm +import tvm.testing +from tvm import relax as rx, TVMError +from tvm import tir + + +def test_get_static_type_basic(): + # object + s0 = rx.ObjectStructInfo() + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s0), rx.ObjectType()) + + # prim + s1 = rx.PrimStructInfo("float32") + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1), tvm.ir.PrimType("float32")) + + +def test_get_static_type_shape(): + # shape + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s3 = rx.ShapeStructInfo(ndim=2) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s2), rx.ShapeType(ndim=3)) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s3), rx.ShapeType(ndim=2)) + + +def test_get_static_type_tensor(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + + tvm.ir.assert_structural_equal( + rx.analysis.get_static_type(s4), rx.DynTensorType(ndim=3, dtype="int64") + ) + + +def test_get_static_type_tuple(): + # tuple + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s0 = rx.ObjectStructInfo() + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + t0 = rx.TupleStructInfo([s4, s0]) + t1 = rx.TupleStructInfo([t0, s2]) + + tvm.ir.assert_structural_equal( + rx.analysis.get_static_type(t1), + rx.TupleType( + [ + rx.TupleType([rx.DynTensorType(ndim=3, dtype="int64"), rx.ObjectType()]), + rx.ShapeType(ndim=3), + ] + ), + ) + + +def test_get_static_type_func(): + # tuple + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + def fn_type(): + x = rx.DynTensorType(ndim=3, dtype="float32") + y = rx.DynTensorType(ndim=3, dtype="float32") + z = rx.DynTensorType(ndim=2, dtype="float32") + return rx.FuncType([x, y], z) + + f0 = fn_info(1) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(fn_info(1)), fn_type()) + + +def test_erase_to_well_defined_basic(): + s0 = rx.ObjectStructInfo() + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s0), s0) + + # prim + s1 = rx.PrimStructInfo("float32") + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1), s1) + + +def test_erase_to_well_defined_shape(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s3 = rx.ShapeStructInfo(ndim=2) + # have undefined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2), rx.ShapeStructInfo(ndim=3) + ) + # all defined + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2, {n: n, m: m}), s2) + + # replacement + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2, {n: 2, m: m + 1}), rx.ShapeStructInfo([1, 3, m + 1]) + ) + + # partial defined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2, {n: n}), rx.ShapeStructInfo(ndim=3) + ) + + +def test_erase_to_well_defined_tensor(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + rshape = rx.Var("shape", rx.ShapeStructInfo(ndim=2)) + s0 = rx.TensorStructInfo(rshape, dtype="int32") + + # undefined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s0, None, None), + rx.TensorStructInfo(ndim=2, dtype="int32"), + ) + + # defined + tvm.ir.assert_structural_equal( + 
rx.analysis.erase_to_well_defined(s0, None, {rshape: rshape}), s0 + ) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s0, None, {rshape: rx.ShapeExpr([1, 2])}), + rx.TensorStructInfo([1, 2], dtype="int32"), + ) + + s1 = rx.TensorStructInfo([m + 1, n], dtype="float32") + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1, {n: n, m: m}), s1) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s1, {n: 2, m: 3}), + rx.TensorStructInfo([4, 2], dtype="float32"), + ) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s1, {m: m}, {rshape: rshape}), + rx.TensorStructInfo(ndim=2, dtype="float32"), + ) + + s2 = rx.TensorStructInfo([1, 2], dtype="float32") + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2), s2) + + +def test_erase_to_well_defined_tuple(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s0 = rx.ObjectStructInfo() + s2 = rx.ShapeStructInfo([1, m]) + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + t0 = rx.TupleStructInfo([s4, s0]) + t1 = rx.TupleStructInfo([t0, s2]) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(t1, {m: m + 1}), + rx.TupleStructInfo( + [ + rx.TupleStructInfo( + [rx.TensorStructInfo(ndim=3, dtype="int64"), rx.ObjectStructInfo()] + ), + rx.ShapeStructInfo([1, m + 1]), + ] + ), + ) + + +def test_erase_to_well_defined_func(): + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + f0 = fn_info(1) + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(f0), f0) + + +def test_base_check(): + BR = rx.analysis.BaseCheckResult + bcheck = rx.analysis.struct_info_base_check + + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + obj0 = rx.ObjectStructInfo() + prim0 = rx.PrimStructInfo("int32") + prim1 = rx.PrimStructInfo("float32") + + shape0 = rx.ShapeStructInfo(ndim=-1) + shape1 = rx.ShapeStructInfo(ndim=2) + shape2 = rx.ShapeStructInfo(ndim=3) + shape3 = rx.ShapeStructInfo([1, 2, 3]) + shape4 = rx.ShapeStructInfo([1, n, 3]) + + tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32") + tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32") + tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32") + tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32") + tensor4 = rx.TensorStructInfo([n, m], "int32") + tensor5 = rx.TensorStructInfo([n, m, 1], "int32") + tensor6 = rx.TensorStructInfo([n, m, 2], "int32") + + # obj + assert bcheck(obj0, prim0) == BR.PASS + assert bcheck(obj0, shape1) == BR.PASS + assert bcheck(obj0, tensor2) == BR.PASS + assert obj0.is_base_of(tensor2) + + # prim + assert prim0.is_base_of(prim0) + assert not prim0.is_base_of(prim1) + assert bcheck(prim0, obj0) == BR.FAIL_L1 + assert bcheck(prim0, prim0) == BR.PASS + assert bcheck(prim0, prim1) == BR.FAIL_L0 + + # shape + assert bcheck(shape0, obj0) == BR.FAIL_L1 + assert bcheck(shape0, prim0) == BR.FAIL_L0 + + # unknown dim + assert bcheck(shape0, shape1) == BR.PASS + assert bcheck(shape1, shape0) == BR.FAIL_L1 + + # ndim mismatch + assert bcheck(shape1, shape2) == BR.FAIL_L0 + + # lhs do not have symbolic value but ndim match + assert bcheck(shape2, shape3) == BR.PASS + + # rhs do not symbolic but lhs do + assert bcheck(shape3, shape2) == BR.FAIL_L2 + + # shape mismatch + assert bcheck(shape3, shape4) == BR.FAIL_L2 + assert shape4.is_base_of(rx.ShapeStructInfo([1, n, 
3])) + + # tensor + assert bcheck(tensor0, obj0) == BR.FAIL_L1 + assert bcheck(tensor0, prim0) == BR.FAIL_L0 + assert bcheck(tensor0, shape0) == BR.FAIL_L0 + + # dtype mismatch + assert bcheck(tensor0, tensor1) == BR.FAIL_L0 + assert bcheck(tensor0, tensor3) == BR.FAIL_L0 + assert bcheck(tensor3, tensor4) == BR.FAIL_L0 + assert bcheck(tensor1, tensor2) == BR.FAIL_L0 + + # ndim mismatch + assert bcheck(tensor2, tensor5) == BR.FAIL_L0 + + # static shape mismatch + assert bcheck(tensor5, tensor6) == BR.FAIL_L0 + + # match + assert tensor0.is_base_of(rx.TensorStructInfo(ndim=-1, dtype="int32")) + assert tensor0.is_base_of(tensor2) + assert tensor0.is_base_of(tensor4) + assert tensor0.is_base_of(tensor5) + assert tensor0.is_base_of(tensor6) + assert tensor2.is_base_of(tensor4) + assert tensor4.is_base_of(rx.TensorStructInfo([n, m], dtype="int32")) + + # tuple + t0 = rx.TupleStructInfo([obj0, tensor0]) + t1 = rx.TupleStructInfo([prim0, tensor4]) + t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) + t3 = rx.TupleStructInfo([tensor0, obj0]) + + assert t0.is_base_of(t1) + + assert bcheck(t0, t2) == BR.FAIL_L0 + assert bcheck(t0, t3) == BR.FAIL_L1 + + assert rx.TupleStructInfo([t0, t1]).is_base_of(rx.TupleStructInfo([t1, t1])) + assert bcheck(rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0])) == BR.FAIL_L1 + + def fn_info_shape(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + def fn_info_erased(): + x = rx.TensorStructInfo(ndim=3, dtype="float32") + y = rx.TensorStructInfo(ndim=3, dtype="float32") + z = rx.TensorStructInfo(ndim=2, dtype="float32") + return rx.FuncStructInfo([x, y], z) + + assert fn_info_shape(1).is_base_of(fn_info_shape(1)) + assert fn_info_erased().is_base_of(fn_info_shape(1)) + assert bcheck(fn_info_shape(1), fn_info_erased()) == BR.FAIL_L2 + + fopaque = rx.FuncStructInfo.opaque_func() + assert fopaque.is_base_of(fn_info_shape(1)) + + +def _check_lca(lhs, rhs, target): + tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(lhs, rhs), target) + tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(rhs, lhs), target) + + +def test_struct_info_lca(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + obj0 = rx.ObjectStructInfo() + prim0 = rx.PrimStructInfo("int32") + prim1 = rx.PrimStructInfo("float32") + + shape0 = rx.ShapeStructInfo(ndim=-1) + shape1 = rx.ShapeStructInfo(ndim=2) + shape2 = rx.ShapeStructInfo(ndim=3) + shape3 = rx.ShapeStructInfo([1, 2, 3]) + shape4 = rx.ShapeStructInfo([1, n, 3]) + + tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32") + tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32") + tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32") + tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32") + tensor4 = rx.TensorStructInfo([n, m], "int32") + tensor5 = rx.TensorStructInfo([n, m, 1], "int32") + tensor6 = rx.TensorStructInfo([n, m, 2], "int32") + + # obj + _check_lca(obj0, prim0, obj0) + _check_lca(obj0, prim1, obj0) + + # shape + _check_lca(shape0, tensor0, obj0) + _check_lca(shape0, shape1, shape0) + _check_lca(shape1, shape2, shape0) + _check_lca(shape1, shape3, shape0) + + _check_lca(shape2, shape3, shape2) + _check_lca(shape3, shape4, shape2) + _check_lca(shape4, rx.ShapeStructInfo([1, n, 3]), shape4) + + # tensor + _check_lca(tensor0, prim0, obj0) + _check_lca(tensor0, tensor1, rx.TensorStructInfo(ndim=-1, dtype=None)) + _check_lca(tensor0, 
tensor2, tensor0) + _check_lca(tensor0, tensor4, tensor0) + + _check_lca(tensor2, tensor4, tensor2) + _check_lca(tensor5, tensor6, rx.TensorStructInfo(ndim=3, dtype="int32")) + _check_lca(tensor4, tensor5, rx.TensorStructInfo(ndim=-1, dtype="int32")) + _check_lca(tensor4, rx.TensorStructInfo([n, m], dtype="int32"), tensor4) + + # tuple + t0 = rx.TupleStructInfo([obj0, tensor0]) + t1 = rx.TupleStructInfo([prim0, tensor4]) + t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) + t3 = rx.TupleStructInfo([tensor0, obj0]) + + _check_lca(t0, t1, t0) + _check_lca(t0, t2, obj0) + _check_lca(t0, t3, rx.TupleStructInfo([obj0, obj0])) + + t5 = rx.TupleStructInfo([t0, t1]) + t6 = rx.TupleStructInfo([t1, t2]) + + _check_lca(t5, t6, rx.TupleStructInfo([t0, obj0])) + + t7 = rx.TupleStructInfo([]) + _check_lca(t7, rx.TupleStructInfo([]), t7) + + def fn_info_shape(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + def fn_info_erased(): + x = rx.TensorStructInfo(ndim=3, dtype="float32") + y = rx.TensorStructInfo(ndim=3, dtype="float32") + z = rx.TensorStructInfo(ndim=2, dtype="float32") + return rx.FuncStructInfo([x, y], z) + + fopaque0 = lambda: rx.FuncStructInfo.opaque_func() + fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0) + fopaque2 = lambda: rx.FuncStructInfo.opaque_func( + ret=rx.TensorStructInfo(ndim=2, dtype="float32") + ) + + _check_lca(fn_info_shape(1), fn_info_shape(2), fn_info_erased()) + _check_lca(fn_info_shape(2), fn_info_shape(2), fn_info_shape(2)) + + _check_lca(fopaque0(), fopaque1(), fopaque0()) + _check_lca(fopaque0(), fn_info_shape(1), fopaque0()) + _check_lca(fopaque2(), fn_info_shape(1), fopaque2()) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py new file mode 100644 index 000000000000..4eeaed1e0b50 --- /dev/null +++ b/tests/python/relax/test_expr.py @@ -0,0 +1,258 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
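+
+"""Tests basic Relax expression and binding construction."""
+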
+import numpy as np
+import pytest
+import tvm
+from tvm import relax as rx
+from tvm import tir
+from tvm.script import relax as R
+
+
+def _check_equal(x, y, map_free_vars=False):
+    tvm.ir.assert_structural_equal(x, y, map_free_vars)
+    tvm.ir.assert_structural_equal(y, x, map_free_vars)
+
+    xhash = tvm.ir.structural_hash(x, map_free_vars)
+    yhash = tvm.ir.structural_hash(y, map_free_vars)
+
+    assert xhash == yhash
+
+
+def _check_json_roundtrip(x):
+    xret = tvm.ir.load_json(tvm.ir.save_json(x))
+    _check_equal(x, xret, map_free_vars=True)
+    return xret
+
+
+def test_var() -> None:
+    v0 = rx.Var("v0")
+    assert v0.name_hint == "v0"
+    assert v0._checked_type_ is None
+    assert v0.struct_info_ is None
+    shape = [54, 96]
+    v1 = rx.Var("v1", R.Tensor(shape, "float32"))
+    assert v1.name_hint == "v1"
+    for s0, s1 in zip(v1.struct_info.shape, shape):
+        assert s0 == s1
+    assert v1.checked_type == rx.DynTensorType(2, "float32")
+    tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float32"))
+
+
+def test_dataflow_var() -> None:
+    v0 = rx.DataflowVar("v0")
+    assert v0.name_hint == "v0"
+    assert v0._checked_type_ is None
+    assert v0.struct_info_ is None
+
+    shape = [54, 96]
+    v1 = rx.DataflowVar("v1", R.Tensor(shape, "float16"))
+    assert v1.name_hint == "v1"
+
+    assert v1._checked_type_ == rx.DynTensorType(2, "float16")
+    assert isinstance(v1, rx.DataflowVar)
+    tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float16"))
+
+
+def test_tuple() -> None:
+    v0 = rx.Var("v0")
+    v1 = rx.Var("v1")
+    t = rx.Tuple((v0, v1))
+
+    assert t.fields[0] == v0
+    assert t.fields[1] == v1
+    assert t[0] == v0
+    assert t[1] == v1
+    assert t[-1] == v1
+    assert t[-2] == v0
+
+    with pytest.raises(IndexError, match="Tuple index out of range"):
+        t[2]
+
+    with pytest.raises(IndexError, match="Tuple index out of range"):
+        t[-3]
+
+
+def test_match_cast() -> None:
+    # match_cast([16, 8], R.Tensor([m, n], "int32"))
+    m = tir.Var("m", dtype="int64")
+    n = tir.Var("n", dtype="int64")
+    shape = rx.const([16, 8], "int32")
+    var = rx.Var("v0", R.Shape())
+    b0 = rx.MatchCast(var, shape, R.Tensor([m, n], "int32"))
+    assert b0.value == shape
+    # MatchCast carries the target struct info rather than a shape pattern.
+    tvm.ir.assert_structural_equal(b0.struct_info, R.Tensor([m, n], "int32"))
+    assert b0.var is not None
+    assert b0.var.checked_type == rx.ShapeType()
+
+    # var1: R.Tensor((m, n), "float32") =
+    #   match_cast(var0: R.Tensor("float32", ndim=-1), R.Tensor((m, n), "float32"))
+    value = rx.Var("value", R.Tensor("float32", ndim=-1))
+
+    var = rx.Var("v1", R.Tensor([m, n], "float32"))
+    b1 = rx.MatchCast(var, value, R.Tensor([m, n], "float32"))
+    assert b1.value == value
+    tvm.ir.assert_structural_equal(b1.struct_info, R.Tensor([m, n], "float32"))
+    assert b1.var is not None
+    assert b1.var.checked_type == rx.DynTensorType(2, "float32")
+
+
+def test_match_cast_struct_info() -> None:
+    m = tir.Var("m", dtype="int64")
+    n = tir.Var("n", dtype="int64")
+    ivalue = rx.Var("input_value")
+    sinfo = rx.TensorStructInfo([n, m], "float32")
+    b0 = rx.MatchCast(rx.Var("v"), ivalue, sinfo)
+    assert b0.value.same_as(ivalue)
+    assert b0.struct_info == sinfo
+    _check_json_roundtrip(b0)
+
+
+def test_var_binding() -> None:
+    v0 = rx.Var("v0")
+    val = rx.const(np.random.rand(24, 56))
+    b0 = rx.VarBinding(v0, val)
+    assert b0.var.name_hint == "v0"
+    assert b0.value == val
+
+
+def test_binding_block() -> None:
+    m = tir.Var("m", dtype="int64")
+    n = tir.Var("n", dtype="int64")
+    shape = rx.const([16, 8], "int32")
+    b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32"))
+
+    v0 = rx.Var("v0")
+    val = rx.const(np.random.rand(24, 56))
+    b1 =
rx.VarBinding(v0, val) + + block0 = rx.BindingBlock([b0, b1]) + assert block0.bindings[0] == b0 + assert block0.bindings[1] == b1 + + +def test_dataflow_block() -> None: + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.DataflowBlock([b0, b1]) + assert block0.bindings[0] == b0 + assert block0.bindings[1] == b1 + assert isinstance(block0, rx.DataflowBlock) + + +def test_seq_expr() -> None: + x = rx.Var("foo") + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) + assert seqe.blocks[0] == blocks[0] + assert seqe.body == x + + +def test_func(): + x = rx.Var("foo", R.Tensor(dtype="float32", ndim=2)) + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + + seqe = rx.SeqExpr(blocks, x) + ret_struct_info = R.Tensor(dtype="float32", ndim=-1) + func = rx.Function([x], seqe, ret_struct_info) + func = func.with_attr("global_symbol", "func") + assert func.params[0] == x + assert func.body == seqe + assert func.ret_struct_info == ret_struct_info + assert func.attrs["global_symbol"] == "func" + + +def test_shape_of(): + shape = [96, 54] + v1 = rx.Var("v1", R.Tensor(shape)) + s1 = rx.get_shape_of(v1) + for x, y in zip(shape, s1): + assert x == y + + +def test_shape_expr(): + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + s = rx.ShapeExpr([m, n]) + assert s.values[0] == m + assert s.values[1] == n + assert s[0] == m + assert s[1] == n + assert s[-1] == n + assert s[-2] == m + assert isinstance(s.struct_info, rx.ShapeStructInfo) + + with pytest.raises(IndexError, match="ShapeExpr index out of range"): + s[2] + + with pytest.raises(IndexError, match="ShapeExpr index out of range"): + s[-3] + + shape_expr = rx.ShapeExpr([10, 20]) + assert shape_expr.values[0] == 10 + assert shape_expr.values[1] == 20 + assert shape_expr.checked_type == rx.ShapeType(ndim=2) + tvm.ir.assert_structural_equal(shape_expr.struct_info, R.Shape((10, 20))) + + x = rx.Var("v0", R.Tensor((10, 20), "float32")) + assert x.struct_info.shape[0] == 10 + assert x.struct_info.shape[1] == 20 + assert x.struct_info.shape.checked_type == rx.ShapeType(ndim=2) + tvm.ir.assert_structural_equal(x.struct_info.shape.struct_info, R.Shape((10, 20))) + + m = tir.Var("m", "int32") + with pytest.raises( + tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64" + ): + rx.ShapeExpr([m, 3]) + + +def test_prim_value(): + pv = rx.PrimValue(tir.IntImm("int64", 1)) + assert pv.value.value == 1 + _check_equal(pv, rx.PrimValue(tir.IntImm("int64", 1))) + _check_json_roundtrip(pv) + + +def test_string_imm(): + s0 = rx.StringImm("hello") + s1 = rx.StringImm("hello") + assert s0.value == "hello" + _check_equal(s0, s1) + _check_json_roundtrip(s0) + + +def test_datatype_imm(): + d0 = rx.DataTypeImm("int32") + d1 = rx.DataTypeImm("int32") + assert d0.value == "int32" + _check_equal(d0, d1) + _check_json_roundtrip(d0) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_struct_info.py new file mode 100644 index 000000000000..80ebc3cb182a --- /dev/null +++ b/tests/python/relax/test_struct_info.py @@ -0,0 +1,241 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +import pytest + +from tvm import relax as rx, TVMError, tir + + +def _check_equal(x, y, map_free_vars=False): + tvm.ir.assert_structural_equal(x, y, map_free_vars) + tvm.ir.assert_structural_equal(y, x, map_free_vars) + + xhash = tvm.ir.structural_hash(x, map_free_vars) + yhash = tvm.ir.structural_hash(y, map_free_vars) + + assert xhash == yhash + + +def _check_json_roundtrip(x): + xret = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, xret, map_free_vars=True) + return xret + + +def test_object_struct_info(): + s0 = rx.ObjectStructInfo() + s1 = rx.ObjectStructInfo() + + # can turn into str + str(s0) + _check_equal(s0, s1) + + assert isinstance(s0, rx.ObjectStructInfo) + _check_json_roundtrip(s0) + + +def test_shape_type(): + t0 = rx.ShapeType() + t1 = rx.ShapeType() + assert t0 == t1 + + +def test_dyn_tensor_type(): + t0 = rx.DynTensorType() + assert t0.ndim == -1 + t1 = rx.DynTensorType(3, "int32") + assert t1.ndim == 3 + assert t1.dtype == "int32" + + +def test_prim_struct_info(): + s0 = rx.PrimStructInfo("float32") + s1 = rx.PrimStructInfo("float32") + s2 = rx.PrimStructInfo("int32") + + _check_equal(s0, s1) + + # can turn into str + str(s0) + + assert s0 == s1 + assert s0 != s2 + + assert isinstance(s0, rx.PrimStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + assert s1.dtype == "float32" + assert s2.dtype == "int32" + + # wrong API constructors + with pytest.raises(TVMError): + rx.PrimStructInfo(1) + + +def test_shape_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.ShapeStructInfo([1, n + 1, m]) + s1 = rx.ShapeStructInfo([1, n + 1, m]) + + _check_equal(s0, s1) + + assert s0 == s1 + assert s0.ndim == 3 + assert s1.ndim == 3 + + assert s0.values[2] == m + + assert isinstance(s0, rx.ShapeStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + s2 = rx.ShapeStructInfo(ndim=2) + + assert s2.ndim == 2 + assert s2.values is None + _check_json_roundtrip(s2) + assert s0 != s2 + + # can turn into str + str(s0) + + # wrong argument type + with pytest.raises(TVMError): + rx.ShapeStructInfo(1) + + # cannot pass both ndim and values + with pytest.raises(ValueError): + rx.ShapeStructInfo([1, 2], ndim=3) + + # cannot pass both ndim and values even if they are consistent + with pytest.raises(ValueError): + rx.ShapeStructInfo([1, 2], ndim=2) + + +def test_tensor_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.TensorStructInfo([1, n + 1, m], "float32") + s1 = rx.TensorStructInfo(rx.ShapeExpr([1, n + 1, m]), "float32") + + _check_equal(s0, s1) + + assert s0 == s1 + assert s0.ndim == 3 + assert s1.ndim == 3 + + assert isinstance(s0, rx.TensorStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + s2 = rx.TensorStructInfo(ndim=2, dtype="int32") + + assert s2.ndim == 2 + 
assert s2.dtype == "int32" + assert s2.shape is None + _check_json_roundtrip(s2) + assert s0 != s2 + + # take in opaque var + rshape = rx.Var("shape", rx.ShapeStructInfo(ndim=2)) + + s3 = rx.TensorStructInfo(rshape, dtype="int32") + assert s3.dtype == "int32" + assert s3.shape == rshape + assert s3.ndim == 2 + _check_json_roundtrip(s3) + + # can turn into str + str(s0) + + # cannot pass both ndim and values + with pytest.raises(ValueError): + rx.TensorStructInfo([1, 2], ndim=3) + + # cannot pass both ndim and values even if they are consistent + with pytest.raises(ValueError): + rx.TensorStructInfo([1, 2], ndim=2) + + +def test_tuple_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.TensorStructInfo([1, 2, m + n], "float32") + s1 = rx.ObjectStructInfo() + + t0 = rx.TupleStructInfo([s0, s1]) + t1 = rx.TupleStructInfo([s0, rx.ObjectStructInfo()]) + t2 = rx.TupleStructInfo([s0, s0]) + + _check_equal(t0, t1) + + assert t0 == t1 + + assert isinstance(t0, rx.TupleStructInfo) + t0 = _check_json_roundtrip(t0) + t1 = _check_json_roundtrip(t1) + t2 = _check_json_roundtrip(t2) + + # can turn into str + str(t0) + + # wrong argument type + with pytest.raises(TVMError): + rx.TupleStructInfo(1) + + +def test_func_struct_info(): + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n, m], "float32") + return rx.FuncStructInfo([x, y], z) + + f0 = fn_info(1) + f1 = fn_info(1) + f2 = fn_info(2) + f3 = rx.FuncStructInfo.opaque_func() + + _check_equal(f0, f1) + + assert f0 == f1 + assert f0 != f2 + + assert len(f0.params) == 2 + assert isinstance(f0.ret, rx.TensorStructInfo) + assert f2.derive_func is None + assert f3.params is None + assert f3.derive_func is None + _check_equal(f3.ret, rx.ObjectStructInfo()) + + assert isinstance(f0, rx.FuncStructInfo) + f0 = _check_json_roundtrip(f0) + f1 = _check_json_roundtrip(f1) + f2 = _check_json_roundtrip(f2) + f3 = _check_json_roundtrip(f3) + + # can turn into str + str(f3) + + +if __name__ == "__main__": + tvm.testing.main()