diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 2ad364b96c..c5c86dc471 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -34,9 +34,15 @@ namespace relax { using Expr = RelayExpr; using ExprNode = RelayExprNode; using relay::Call; +using relay::CallNode; +using relay::ConstantNode; using relay::Id; +using relay::If; +using relay::IfNode; using relay::Tuple; using relay::TupleGetItem; +using relay::TupleGetItemNode; +using relay::TupleNode; /*! \brief A shape expression which allows users to construct a shape containing PrimExpr. */ diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h new file mode 100644 index 0000000000..b4a143727a --- /dev/null +++ b/include/tvm/relax/expr_functor.h @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/expr_functor.h + * \brief A more powerful visitor which enables defining arbitrary function + * signatures with type based dispatch on first argument. + */ +#ifndef TVM_RELAX_EXPR_FUNCTOR_H_ +#define TVM_RELAX_EXPR_FUNCTOR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +namespace tvm { +namespace relax { + +/*! + * \brief A dynamical functor that dispatches on in the first Expr argument. + * You can use this as a more powerful Visitor, since it allows you to + * define function signatures of Visit Function. + * + * \sa tvm/ir_functor.h + * + * \tparam FType function signiture + * This type is only defined for FType with function signature R(const Expr&, + * Args...) + */ +template +class ExprFunctor; + +// functions to be overriden. +#define EXPR_FUNCTOR_DEFAULT \ + { return VisitExprDefault_(op, std::forward(args)...); } + +#define RELAX_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class ExprFunctor { + private: + using TSelf = ExprFunctor; + using FType = tvm::NodeFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~ExprFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward(args)...); } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitExpr(const Expr& n, Args... args) { + ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may " + "have generated invalid data."; + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const DataflowVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ShapeExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ExternFuncNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const SeqExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExprDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAX_EXPR_FUNCTOR_DISPATCH(ConstantNode); + RELAX_EXPR_FUNCTOR_DISPATCH(TupleNode); + RELAX_EXPR_FUNCTOR_DISPATCH(VarNode); + RELAX_EXPR_FUNCTOR_DISPATCH(DataflowVarNode); + RELAX_EXPR_FUNCTOR_DISPATCH(ShapeExprNode); + RELAX_EXPR_FUNCTOR_DISPATCH(ExternFuncNode); + RELAX_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); + RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode); + RELAX_EXPR_FUNCTOR_DISPATCH(CallNode); + RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode); + RELAX_EXPR_FUNCTOR_DISPATCH(IfNode); + RELAX_EXPR_FUNCTOR_DISPATCH(OpNode); + RELAX_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); + return vtable; + } +}; + +/*! + * \brief A simple visitor wrapper around ExprFunctor. + * Recursively visit the content. + * + * ExprVisitor treats Expr as dataflow graph, + * and only visit each Expr node once. + */ +class ExprVisitor : public ExprFunctor { + public: + void VisitExpr(const Expr& expr) override; + void VisitExpr_(const ConstantNode* op) override; + void VisitExpr_(const TupleNode* op) override; + void VisitExpr_(const VarNode* op) override; + void VisitExpr_(const DataflowVarNode* op) override; + void VisitExpr_(const ShapeExprNode* op) override; + void VisitExpr_(const ExternFuncNode* op) override; + void VisitExpr_(const GlobalVarNode* op) override; + void VisitExpr_(const FunctionNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const SeqExprNode* op) override; + void VisitExpr_(const IfNode* op) override; + void VisitExpr_(const OpNode* op) override; + void VisitExpr_(const TupleGetItemNode* op) override; + + virtual void VisitType(const Type& t); + virtual void VisitSpan(const Span& span); + virtual void VisitBinding(const Binding& binding); + virtual void VisitVarBinding(const VarBinding& binding); + virtual void VisitMatchShape(const MatchShape& binding); + virtual void VisitBindingBlock(const BindingBlock& block); + virtual void VisitDataflowBlock(const DataflowBlock& block); +}; + +void PostOrderVisit(const Expr& node, std::function fvisit); + +/*! + * \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once. + * The mutated results are memoized in a map and reused so that + * local transformation on the dataflow preserves the graph structure. + */ +class ExprMutator : public ExprFunctor { + public: + /*! + * \brief Mutate is alias for VisitExpr + * \return expr. + */ + Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); } + Expr VisitExpr(const Expr& expr) override; + Expr VisitExpr_(const ConstantNode* op) override; + Expr VisitExpr_(const TupleNode* op) override; + Expr VisitExpr_(const VarNode* op) override; + Expr VisitExpr_(const DataflowVarNode* op) override; + Expr VisitExpr_(const ShapeExprNode* op) override; + Expr VisitExpr_(const ExternFuncNode* op) override; + Expr VisitExpr_(const GlobalVarNode* op) override; + Expr VisitExpr_(const FunctionNode* op) override; + Expr VisitExpr_(const CallNode* op) override; + Expr VisitExpr_(const SeqExprNode* op) override; + Expr VisitExpr_(const IfNode* op) override; + Expr VisitExpr_(const OpNode* op) override; + Expr VisitExpr_(const TupleGetItemNode* op) override; + + /*! + * \brief Used to visit the types inside of expressions. + * + * Can be overloaded to transform the types in arbitrary + * ways, one way would be to define a sub-class of type + * visitor for types which transform them appropriately. + */ + virtual Type VisitType(const Type& t); + virtual void VisitBinding(const Binding& binding); + virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder); + virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder); + virtual BindingBlock VisitBindingBlock(const BindingBlock& block); + virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); + + protected: + LazyIRBuilder irbuilder_; +}; + +/*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes + */ +class DataflowMutator : public ExprMutator { + public: + virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); + virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder); + + protected: + /*! \brief Look up the value binded to a var. */ + Expr LookupVar(Var var); + // A remapping table: pre var -> post var + std::unordered_map pre_post_var_map_; +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_EXPR_FUNCTOR_H_ diff --git a/include/tvm/relax/ir_builder.h b/include/tvm/relax/ir_builder.h index f97981a6cf..9d3ec6e689 100644 --- a/include/tvm/relax/ir_builder.h +++ b/include/tvm/relax/ir_builder.h @@ -37,6 +37,7 @@ namespace relax { using relay::Call; class IRBuilder; +class LazyIRBuilder; /*! * \brief The state of Relax function node being built. @@ -72,19 +73,44 @@ class IRBuilderNode : public Object { /*! * \brief Build a binding block. */ - void BuildBlock(); + virtual void BuildBlock(); /*! - * \brief Emit a call node. - * \param call The CallNode to be emitted. + * \brief Emit a Call, and return a newly created Var binded to the Call. + * \param call The Call to be emitted. * \return The variable being created and binded to \p call. */ - Var Emit(const Call& call); + virtual Var Emit(const Call& call); + /*! + * \brief Emit a var binding. + * \param binding The VarBinding to be emitted. + * \return The VarNode of the VarBinding \p binding. + */ + virtual Var Emit(const VarBinding& binding); + /*! + * \brief Emit a Call, and bind it to a Var. + * \param var The Var to be binded with. \p var is reused implicitly if the shape + * and type of \p call matches \p var. Otherwise a new Var is created. + * \param call The Call to be emitted. + * \return The Var to be binded with \p var. + */ + virtual Var Emit(const Var& var, const Call& call); + /*! + * \brief Emit a MatchShape. + * \param value The value of the MatchShape to be emitted. + * \param pattern The pattern of the MatchShape to be emitted. + * \return The variable being binded to the MatchShape. + */ + Var EmitMatchShape(const Expr& value, const Array& pattern); /*! * \brief Generate an output for the current dataflow block or function. * \param output The output variable of the block/function. * \return The variable being binded to \p output. */ Var EmitOutput(const Expr& output); + /*! + * \brief Lookup a var in the binding table \p var_map_. + */ + Expr LookupVar(const Var& var); /*! * \brief Get the function being built. */ @@ -93,29 +119,45 @@ class IRBuilderNode : public Object { * \brief Get binding blocks being built. */ std::vector GetBlocks(); + /*! + * \brief Check if two shape expressions can be proven equal at compile time. + * \param lhs The input lhs shape. + * \param rhs The input rhs shape. + * \return Whether we can prove lhs shape == rhs shape. + */ + bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs); + /*! + * \brief Normalize an Expr to complete its shape and type. + * \param expr The input expr. + * \return The expr with normalized shape and type. + */ + Expr Normalize(const Expr& expr); /*! * \brief Create a IRBuilder. * \return The created IRBuilder. */ TVM_DLL static IRBuilder Create(); + /*! \brief A flag tracking if currently inside a dataflow block or not. */ + bool is_dataflow_ = false; + void VisitAttrs(AttrVisitor* v) {} static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.IRBuilder"; - TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderNode, Object); - private: + protected: /*! \brief The state of the function currently being built. */ RelaxFunction func_; - /*! \brief A flag tracking if currently inside a dataflow block or not. */ - bool is_dataflow_ = false; /*! \brief A global variable counter for naming global variables. */ int global_var_counter_ = 0; /*! \brief A dataflow variable counter for naming dataflow variables. */ int dataflow_var_counter_ = 0; /*! \brief A diagnostic context for reporting errors. */ DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {})); + /*! \brief A binding table that maps var to value. */ + std::unordered_map var_map_; }; class IRBuilder : public ObjectRef { @@ -193,6 +235,70 @@ class DataflowScope : public ObjectRef { TVM_DLL void ExitWithScope(); }; +/*! + * \brief A lazy builder to construct dataflow block in a copy-on-write fashion. + */ +class LazyIRBuilderNode : public IRBuilderNode { + public: + /*! + * \brief Emit a Call in a copy-on-write way. + * If no bindings in a dataflow block need to be rewritten, reuse the original variable instead of + * emiting one. If any binding in the block needs to be rewritten, reconstruct the whole block + * from scratch by emiting all previous bindings. + * \param call The Call to be emitted. + * \return The variable being created and binded to \p call. + */ + virtual Var Emit(const Call& call); + /*! + * \brief Emit a var binding in a copy-on-write way. + * \param binding The VarBinding to be emitted. + * \return The Var of the \p binding. + */ + virtual Var Emit(const VarBinding& binding); + /*! + * \brief Emit a Call, and bind it to a Var in a copy-on-write way. + * \param var The Var to be binded with. + * \param call The Call to be emitted. + * \return The Var to be binded with \p var. + */ + virtual Var Emit(const Var& var, const Call& call); + /*! + * \brief Emit an output for the current dataflow block or function in a copy-on-write way. + * \param binding The VarBinding to be emitted. + * \return The variable being binded to \p output. + */ + virtual Var EmitOutput(const VarBinding& binding); + /*! + * \brief Build a binding block. + */ + virtual void BuildBlock(); + /*! + * \brief Create a LazyIRBuilder. + * \return The created LazyIRBuilder. + */ + TVM_DLL static LazyIRBuilder Create(const DataflowBlock& block); + + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.LazyIRBuilder"; + TVM_DECLARE_FINAL_OBJECT_INFO(LazyIRBuilderNode, IRBuilderNode); + + private: + /*! \brief Original DataflowBlock before rewriting. */ + DataflowBlock df_block_; + /*! \brief index in the \p bindings. */ + int64_t index_ = 0; + /*! \brief A flag tracking if current dataflow block needs to be rewritten or not. */ + bool is_rewrite_ = false; +}; + +class LazyIRBuilder : public IRBuilder { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LazyIRBuilder, IRBuilder, LazyIRBuilderNode); +}; + + } // namespace relax } // namespace tvm diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 9249b7579c..6f0b78df7d 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -47,7 +47,7 @@ def checked_type(self): """ ret = self.checked_type_ if ret is None: - raise ValueError("The type checker has not populated" " the checked_type for this node") + raise ValueError("The type checker has not populated the checked_type for this node") return ret @property diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index e3b3c61c5e..58f0836329 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -22,6 +22,7 @@ from . import ir_builder from . import op from . import parser +from . import analysis # Expr diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py new file mode 100644 index 0000000000..cc0089ff31 --- /dev/null +++ b/python/tvm/relax/analysis/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax IR analysis. """ + +from .analysis import * diff --git a/python/tvm/relax/analysis/_ffi_api.py b/python/tvm/relax/analysis/_ffi_api.py new file mode 100644 index 0000000000..c3e9ae4eb7 --- /dev/null +++ b/python/tvm/relax/analysis/_ffi_api.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +import tvm._ffi + +tvm._ffi._init_api("relax.analysis", __name__) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py new file mode 100644 index 0000000000..3e13f0ad13 --- /dev/null +++ b/python/tvm/relax/analysis/analysis.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return +# pylint: disable=unidiomatic-typecheck +""" +This file contains the set of passes for Relax, which exposes an interface for +configuring the passes and scripting them in Python. +""" +from . import _ffi_api + + +def post_order_visit(expr, fvisit): + """Recursively visit the ir in post DFS order node, + apply fvisit. Each node is guaranteed to be visited + only once. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + fvisit : function + The visitor function to be applied. + """ + return _ffi_api.post_order_visit(expr, fvisit) + +def fma_rewrite(expr): + """Perform fused multiply add rewriting in dataflow blocks. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + """ + return _ffi_api.fma_rewrite(expr) + +def explicit_memory_rewrite(expr): + """Perform explicit memory allocation for call_dps in dataflow blocks. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + """ + return _ffi_api.explicit_memory_rewrite(expr) \ No newline at end of file diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 1cc067d734..2e570eb8ab 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -51,7 +51,7 @@ def make_shape(shape: List[PrimExpr]) -> ShapeExpr: if isinstance(shape, (list, tuple)): return ShapeExpr(shape) else: - raise ValueError + raise ValueError("Wrong type") @tvm._ffi.register_object("relax.expr.Var") diff --git a/python/tvm/relax/ir_builder.py b/python/tvm/relax/ir_builder.py index 426b1ccabe..34fb139859 100644 --- a/python/tvm/relax/ir_builder.py +++ b/python/tvm/relax/ir_builder.py @@ -131,6 +131,26 @@ def emit(self, A newly created variable that gets binded to the call code. """ return _ffi_api.IRBuilderEmit(self, call) + + def match_shape(self, + value: Expr, + pattern: List[PrimExpr]): + """Emit a MatchShape. + + Parameters + ---------- + value : tvm.relay.Expr + The value of the MatchShape to be emitted. + + pattern : List[PrimExpr] + The pattern of the MatchShape to be emitted. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that gets binded to the call code. + """ + return _ffi_api.IRBuilderEmitMatchShape(self, value, pattern) def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None: @@ -150,6 +170,22 @@ def emit_output(self, output = Tuple(output) return _ffi_api.IRBuilderEmitOutput(self, output) + def normalize(self, + expr: Expr) -> Expr: + """Normalize an Expr to complete its shape and type. + + Parameters + ---------- + expr : Expr + The input expr. + + Returns + ------- + ret : Expr + The expr with normalized shape and type. + """ + return _ffi_api.IRBuilderNormalize(self, expr) + def get(self) -> Function: """Return the function being built. diff --git a/src/relax/expr_functor.cc b/src/relax/expr_functor.cc new file mode 100644 index 0000000000..5068b423b4 --- /dev/null +++ b/src/relax/expr_functor.cc @@ -0,0 +1,528 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/expr_functor.cc + * \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator uses memoization and self return in order to amortize + * the cost of using functional updates. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "../relay/transforms/pattern_utils.h" + +namespace tvm { +namespace relax { + +void ExprVisitor::VisitExpr_(const ConstantNode* op) { + this->VisitSpan(op->span); +} + +void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { + this->VisitSpan(op->span); +} + +void ExprVisitor::VisitExpr_(const TupleNode* op) { + this->VisitSpan(op->span); + for (auto field : op->fields) { + this->VisitExpr(field); + } +} + +void ExprVisitor::VisitExpr_(const VarNode* op) { + this->VisitSpan(op->span); + if (op->type_annotation.defined()) { + this->VisitType(op->type_annotation.value()); + } +} + +void ExprVisitor::VisitExpr_(const DataflowVarNode* op) { + this->VisitSpan(op->span); + if (op->type_annotation.defined()) { + this->VisitType(op->type_annotation.value()); + } +} + +void ExprVisitor::VisitExpr_(const FunctionNode* op) { + this->VisitSpan(op->span); + for (auto param : op->params) { + this->VisitExpr(param); + } + + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitExpr_(const CallNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->op); + + for (auto ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + + for (auto arg : op->args) { + this->VisitExpr(arg); + } +} + +void ExprVisitor::VisitExpr_(const IfNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->cond); + this->VisitExpr(op->true_branch); + this->VisitExpr(op->false_branch); +} + +void ExprVisitor::VisitExpr_(const OpNode* op) { + return; +} + +void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->tuple); +} + +void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { + this->VisitSpan(op->span); +} + +void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { + this->VisitSpan(op->span); +} + +void ExprVisitor::VisitExpr_(const SeqExprNode* op) { + this->VisitSpan(op->span); + for (auto block : op->blocks) { + this->VisitBindingBlock(block); + } + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitType(const Type& t) { + return; +} + +void ExprVisitor::VisitSpan(const Span& span) { + return; +} + +void ExprVisitor::VisitBinding(const Binding& binding) { + if (binding.as()) { + this->VisitVarBinding(Downcast(binding)); + } else if (binding.as()) { + this->VisitMatchShape(Downcast(binding)); + } else { + LOG(FATAL) << "Wrong type."; + } +} + +void ExprVisitor::VisitVarBinding(const VarBinding& binding) { + this->VisitExpr(binding->value); +} + +void ExprVisitor::VisitMatchShape(const MatchShape& binding) { + this->VisitExpr(binding->value); +} + +void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { + if (block.as()) { + this->VisitDataflowBlock(Downcast(block)); + } else { + for (auto binding : block->bindings) { + this->VisitBinding(binding); + } + } +} + +void ExprVisitor::VisitDataflowBlock(const DataflowBlock& block) { + for (auto binding : block->bindings) { + this->VisitBinding(binding); + } +} + +void ExprVisitor::VisitExpr(const Expr& expr) { + using TParent = ExprFunctor; + TParent::VisitExpr(expr); +} + +class ExprApplyVisit : public ExprVisitor { + public: + explicit ExprApplyVisit(std::function f) : f_(f) {} + + void VisitExpr(const Expr& e) final { + ExprVisitor::VisitExpr(e); + f_(e); + } + + private: + std::function f_; +}; + +void PostOrderVisit(const Expr& e, std::function fvisit) { + ExprApplyVisit(fvisit).VisitExpr(e); +} + +TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit") +.set_body_typed([](Expr expr, PackedFunc f) { + PostOrderVisit(expr, [f](const Expr& n) { f(n); }); +}); + + +// ================== +// ExprMutator + +Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef(op); } + +Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef(op); } + +Expr ExprMutator::VisitExpr_(const TupleNode* op) { + tvm::Array fields; + bool all_fields_unchanged = true; + for (auto field : op->fields) { + auto new_field = this->Mutate(field); + fields.push_back(new_field); + all_fields_unchanged &= new_field.same_as(field); + } + + if (all_fields_unchanged) { + return GetRef(op); + } else { + return Tuple(fields, op->span); + } +} + +Expr ExprMutator::VisitExpr_(const VarNode* op) { + if (op->type_annotation.defined()) { + auto type = this->VisitType(op->type_annotation.value()); + if (!op->type_annotation.same_as(type)) { + return Var(op->vid, Downcast(op->shape()), type, op->span); + } + } + // default case return self. + return GetRef(op); +} + +Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { + if (op->type_annotation.defined()) { + auto type = this->VisitType(op->type_annotation.value()); + if (!op->type_annotation.same_as(type)) { + return DataflowVar(op->vid, Downcast(op->shape()), type, op->span); + } + } + // default case return self. + return GetRef(op); +} + +Expr ExprMutator::VisitExpr_(const FunctionNode* op) { + tvm::Array params; + bool all_params_unchanged = true; + for (auto param : op->params) { + Var new_param = Downcast(this->Mutate(param)); + params.push_back(new_param); + all_params_unchanged &= param.same_as(new_param); + } + + auto ret_type = this->VisitType(op->ret_type); + auto body = this->Mutate(op->body); + + if (all_params_unchanged && ret_type.same_as(op->ret_type) && body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->name, params, body, ret_type); + } +} + +Expr ExprMutator::VisitExpr_(const CallNode* call_node) { + auto new_op = this->Mutate(call_node->op); + bool unchanged = call_node->op.same_as(new_op); + + tvm::Array ty_args; + for (auto ty_arg : call_node->type_args) { + auto new_ty_arg = this->VisitType(ty_arg); + ty_args.push_back(new_ty_arg); + unchanged &= new_ty_arg.same_as(ty_arg); + } + + tvm::Array call_args; + for (auto arg : call_node->args) { + auto new_arg = this->Mutate(arg); + call_args.push_back(new_arg); + unchanged &= new_arg.same_as(arg); + } + + if (unchanged) { + return GetRef(call_node); + } else { + return Call(new_op, call_args, call_node->attrs, ty_args, call_node->span); + } +} + +Expr ExprMutator::VisitExpr_(const IfNode* op) { + auto guard = this->Mutate(op->cond); + auto true_b = this->Mutate(op->true_branch); + auto false_b = this->Mutate(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } +} + +Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef(op); } + +Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) { + auto t = this->Mutate(get_item->tuple); + if (get_item->tuple == t) { + return GetRef(get_item); + } else { + return TupleGetItem(t, get_item->index, get_item->span); + } +} + +Expr ExprMutator::VisitExpr_(const ShapeExprNode* op) { return GetRef(op); } + +Expr ExprMutator::VisitExpr_(const ExternFuncNode* op) { return GetRef(op); } + +Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + blocks.push_back(new_block); + all_blocks_unchanged &= block.same_as(new_block); + } + + Expr body = this->Mutate(op->body); + if (all_blocks_unchanged && body.same_as(op->body)) { + return GetRef(op); + } else { + return SeqExpr(blocks, body); + } +} + +Type ExprMutator::VisitType(const Type& t) { return t; } + +void ExprMutator::VisitBinding(const Binding& binding) { + Binding new_binding; + if (binding.as()) { + this->VisitVarBinding(Downcast(binding), this->irbuilder_); + } else if (binding.as()) { + this->VisitMatchShape(Downcast(binding), this->irbuilder_); + } else { + LOG(FATAL) << "Wrong type."; + } +} + +Var ExprMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) { + Expr new_value = this->Mutate(binding->value); + if (!binding->var.as()) { + return ir_builder->EmitOutput(new_value); + } else { + return ir_builder->Emit(Downcast(new_value)); + } +} + +void ExprMutator::VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder) { + this->Mutate(binding->value); +} + +BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) { + if (block.as()) { + return this->VisitDataflowBlock(Downcast(block)); + } else{ + // TODO + return block; + } +} + +BindingBlock ExprMutator::VisitDataflowBlock(const DataflowBlock& block) { + this->irbuilder_ = LazyIRBuilderNode::Create(block); + { + With scope(this->irbuilder_); + for (auto binding : block->bindings) { + if (binding.as()) { + this->VisitVarBinding(Downcast(binding), this->irbuilder_); + } + } + } + return this->irbuilder_->GetBlocks().back(); +} + +Expr ExprMutator::VisitExpr(const Expr& expr) { + Expr new_expr = ExprFunctor::VisitExpr(expr); + return new_expr; +} + + +// ================== +// DataflowMutator + +BindingBlock DataflowMutator::VisitDataflowBlock(const DataflowBlock& block) { + this->irbuilder_ = LazyIRBuilderNode::Create(block); + { + With scope(this->irbuilder_); + for (auto binding : block->bindings) { + if (auto* var_binding = binding.as()) { + Var var = this->VisitVarBinding(Downcast(binding), this->irbuilder_); + this->pre_post_var_map_[var_binding->var] = var; + } + } + } + return this->irbuilder_->GetBlocks().back(); +} + +Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) { + Expr new_value = this->Mutate(binding->value); + Var new_var; + if (new_value.as()) { + new_var = ir_builder->Emit(Downcast(new_value)); + } + if (!binding->var.as()) { + new_var = ir_builder->EmitOutput(new_value); + } + pre_post_var_map_[binding->var] = new_var; + return new_var; +} + +Expr DataflowMutator::LookupVar(Var var) { + auto it = pre_post_var_map_.find(var); + if (it != pre_post_var_map_.end()) { + return irbuilder_->LookupVar(it->first); + } else { + return irbuilder_->LookupVar(var); + } +} + + +// ================== +// EwiseFMARewriter +// Example: +// x0 = mul(a, b) +// z0 = add(x0, c) +// --> +// z0 = ewise_fma(a, b, c) + +// Example 2: +// Question: do we want to support this? +// x0 = mul(a, add(k, b)) +// z0 = add(x0, c) +// --> +// lv0 = add(k, b) +// z0 = ewise_fma(a, lv0, c) + +class EwiseFMARewriter : public DataflowMutator { + Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override { + static const Op& add_op = Op::Get("relax.add"); + static const Op& multiply_op = Op::Get("relax.multiply"); + static const Op& ewise_fma_op = Op::Get("relax.ewise_fma"); + + // TODO: shape & dtype check + const CallNode* op1 = binding->value.as(); + if (op1 && (op1->op == add_op)) { + Expr value = LookupVar(Downcast(op1->args[0])); + const CallNode* op2 = value.as(); + if (op2 && op2->op == multiply_op) { + Call fma_call = Call(ewise_fma_op, {op2->args[0], op2->args[1], op1->args[1]}, {}, {}); + return ir_builder->Emit(binding->var, fma_call); + } + } + return ir_builder->Emit(binding); + } +}; + +Expr FMARewrite(const Expr& e) { + return EwiseFMARewriter().Mutate(e); +} + +TVM_REGISTER_GLOBAL("relax.analysis.fma_rewrite") +.set_body_typed([](Expr expr) { + return FMARewrite(expr); +}); + +// ================== +// ExplicitMemMutator +// Example: +// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x)) +// --> +// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m]) +// rx.call_packed(op.identity, x, lv0) + +class ExplicitMemMutator : public DataflowMutator { + Expr ComputeStorageSize(const Expr& shape, const Type& type) const { + DynTensorType tensor_type = Downcast(type); + DataType dtype = DataType(tensor_type->dtype); + // Question: what if the dtype of tensor_type is unknown? + // Symbolic/static shape case + if (auto* shape_expr = shape.as()) { + PrimExpr num = PrimExpr(dtype.bits()) * PrimExpr(dtype.lanes()); + PrimExpr add = num + 7; + PrimExpr ret = 1; + for (PrimExpr dim : shape_expr->values) { + ret = ret * dim; + } + ret = ret * (add / PrimExpr(8)); + return ShapeExpr({ret}); + } + // Fully dynamic shape case + // will need to dedup with ComputeStorageInRelay when we upstream + Expr prod = relay::Prod(shape, Array(nullptr), false, false); + Expr num = relay::MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes()); + Expr add = relay::Add(num, relay::MakeConstantScalar(DataType::Int(64), 7)); + Expr div = relay::MakeConstantScalar(DataType::Int(64), 8); + Expr ret = relay::Multiply(prod, relay::Divide(add, div)); + return ret; + } + + Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override { + static const Op& call_dps_op = Op::Get("relax.call_dps"); + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + + const CallNode* op = binding->value.as(); + if(op && op->op == call_dps_op) { + // switch current DataflowBlock to an impure BindingBlock + ir_builder->is_dataflow_ = false; + ShapeExpr output_shape = Downcast(op->args[0]); + Type arg_type = Downcast(op->args[2])->fields[0]->checked_type(); + Expr output_size = ComputeStorageSize(output_shape, arg_type); + Var tensor = ir_builder->Emit(Call(alloc_tensor_op, {op->args[0]})); + return ir_builder->Emit(binding->var, Call(op->args[1], {op->args[2], tensor})); + } + return ir_builder->Emit(binding); + } +}; + +Expr ExplicitMemRewrite(const Expr& e) { + return ExplicitMemMutator().Mutate(e); +} + +TVM_REGISTER_GLOBAL("relax.analysis.explicit_memory_rewrite") +.set_body_typed([](Expr expr) { + return ExplicitMemRewrite(expr); +}); + + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir_builder.cc b/src/relax/ir_builder.cc index c2cc4fa550..46e7d590a6 100644 --- a/src/relax/ir_builder.cc +++ b/src/relax/ir_builder.cc @@ -24,11 +24,14 @@ #include #include #include +#include +#include namespace tvm { namespace relax { TVM_REGISTER_NODE_TYPE(IRBuilderNode); +TVM_REGISTER_NODE_TYPE(LazyIRBuilderNode); TVM_REGISTER_NODE_TYPE(FunctionScopeNode); TVM_REGISTER_NODE_TYPE(DataflowScopeNode); @@ -41,7 +44,9 @@ void IRBuilderNode::FillFuncNameParam(const Array& params, const std::strin if (!func_name.empty()) { this->func_.func_name = GlobalVar(func_name); } - + for (Var param : params) { + this->var_map_[param] = param; + } this->func_.params = params; } @@ -66,14 +71,24 @@ void IRBuilderNode::BuildBlock() { Optional InferShape(const Call& call, DiagnosticContext diag_ctx) { auto op_map = Op::GetAttrMap("FInferShape"); - Op op = Downcast(call->op); - return op_map[op](call, diag_ctx); + if (call->op.as()) { + Op op = Downcast(call->op); + if (op_map.count(op)) { + return op_map[op](call, diag_ctx); + } + } + return NullOpt; } Type InferType(const Call& call, DiagnosticContext diag_ctx) { auto op_map = Op::GetAttrMap("FInferType"); - Op op = Downcast(call->op); - return op_map[op](call, diag_ctx); + if (call->op.as()) { + Op op = Downcast(call->op); + if (op_map.count(op)) { + return op_map[op](call, diag_ctx); + } + } + return VoidType(); } Var IRBuilderNode::Emit(const Call& call) { @@ -98,9 +113,62 @@ Var IRBuilderNode::Emit(const Call& call) { var->checked_type_ = inferred_type; this->func_.bindings.emplace_back(VarBinding(var, call)); + this->var_map_[var] = call; return var; } +Var IRBuilderNode::EmitMatchShape(const Expr& value, const Array& pattern) { + Var var; + if (is_dataflow_) { + var = DataflowVar(Id("lv" + std::to_string(dataflow_var_counter_++)), NullOpt, NullOpt); + } else { + var = Var(Id("gv" + std::to_string(global_var_counter_++)), NullOpt, NullOpt); + } + if (value->checked_type().as()) { + var->checked_type_ = ShapeType(Span()); + } else if (value->checked_type().as()){ + ShapeExpr shape = ShapeExpr(pattern); + var->shape_ = shape; + DataType dtype = (Downcast(value->checked_type()))->dtype; + var->checked_type_ = DynTensorType(pattern.size(), dtype); + } else { + this->diag_ctx_.EmitFatal(Diagnostic::Error(value->span) + << "The value passed to EmitMatchShape must be of DynTensorType or ShapeType."); + } + + MatchShape match_shape = MatchShape(value, pattern, var); + this->func_.bindings.emplace_back(match_shape); + return var; +} + +Var IRBuilderNode::Emit(const VarBinding& binding) { + if (!binding->var.as()) { + return EmitOutput(binding->value); + } else { + this->func_.bindings.emplace_back(binding); + this->var_map_[binding->var] = binding->value; + return binding->var; + } +} + +Var IRBuilderNode::Emit(const Var& var, const Call& call) { + Expr normalized_call = Normalize(call); + // Reuse the input var if the shape and type of the call matches the var + if (CanProveShapeEqual(var->shape(), call->shape()) && StructuralEqual()(var->checked_type(), call->checked_type())) { + this->func_.bindings.emplace_back(VarBinding(var, normalized_call)); + this->var_map_[var] = normalized_call; + return var; + } else { + Var new_var; + if (normalized_call->shape_.defined()) { + new_var->shape_ = normalized_call->shape_; + } + this->func_.bindings.emplace_back(VarBinding(new_var, normalized_call)); + this->var_map_[new_var] = normalized_call; + return new_var; + } +} + Var IRBuilderNode::EmitOutput(const Expr& output) { Var ret; if (is_dataflow_) { @@ -108,16 +176,69 @@ Var IRBuilderNode::EmitOutput(const Expr& output) { ret->shape_ = output->shape_; ret->checked_type_ = output->checked_type_; this->func_.bindings.emplace_back(VarBinding(ret, output)); + this->var_map_[ret] = output; } else { this->func_.ret = output; } return ret; } +Expr IRBuilderNode::LookupVar(const Var& var) { + auto it = this->var_map_.find(var); + if (it == this->var_map_.end()) { + this->diag_ctx_.EmitFatal(Diagnostic::Error(var->span) + << "The var to be looked up is not in the binding table."); + } + return it->second; +} + Function IRBuilderNode::Get() { return this->func_.func; } std::vector IRBuilderNode::GetBlocks() { return this->func_.binding_blocks; } +bool IRBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) { + if (lhs == rhs) { + return true; + } + const auto* lhs_shape = lhs.as(); + const auto* rhs_shape = rhs.as(); + if (lhs_shape && rhs_shape) { + size_t lhs_ndim = lhs_shape->values.size(); + size_t rhs_ndim = rhs_shape->values.size(); + if (lhs_ndim != rhs_ndim) { + return false; + } + arith::Analyzer analyzer; + for (size_t i = 0; i < lhs_ndim; ++i) { + PrimExpr lhs_dim = lhs_shape->values[i]; + PrimExpr rhs_dim = rhs_shape->values[i]; + if (!analyzer.CanProveEqual(lhs_dim, rhs_dim)) { + return false; + } + } + return true; + } + return false; +} + +Expr IRBuilderNode::Normalize(const Expr& expr) { + if (expr.as()) { + Call call = Downcast(expr); + // Shape inference + auto inferred_shape = InferShape(call, this->diag_ctx_); + if (inferred_shape.defined()) { + if (auto* shape_expr = inferred_shape.value().as()) { + call->shape_ = GetRef(shape_expr); + } + } + // Type inference + auto inferred_type = InferType(call, this->diag_ctx_); + call->checked_type_ = inferred_type; + return call; + } + return expr; +} + class FunctionScope::Internal { public: static void ExitScope(FunctionScope scope) { scope.ExitWithScope(); } @@ -151,6 +272,133 @@ void DataflowScope::EnterWithScope() { this->get()->ir_builder->BuildBlock(); } void DataflowScope::ExitWithScope() { this->get()->ir_builder->BuildBlock(); } +LazyIRBuilder LazyIRBuilderNode::Create(const DataflowBlock& block) { + LazyIRBuilder ret(make_object()); + ret->df_block_ = block; + return ret; +} + +Var LazyIRBuilderNode::Emit(const Call& call) { + if (is_rewrite_) { + index_++; + return IRBuilderNode::Emit(call); + } + Expr expr = Downcast(this->df_block_->bindings[index_])->value; + Call old_call = Downcast(expr); + if (call.same_as(old_call)) { + VarBinding binding = Downcast(this->df_block_->bindings[index_++]); + this->var_map_[binding->var] = binding->value; + return binding->var; + } + else { + is_rewrite_ = true; + for (int i = 0; i < index_; i++) { + Expr expr = Downcast(this->df_block_->bindings[i])->value; + IRBuilderNode::Emit(Downcast(expr)); + } + index_++; + return IRBuilderNode::Emit(call); + } +} + +Var LazyIRBuilderNode::Emit(const VarBinding& binding) { + if (!binding->var.as()) { + return IRBuilderNode::EmitOutput(binding->value); + } + if (is_rewrite_) { + index_++; + return IRBuilderNode::Emit(binding); + } + Binding old_binding = this->df_block_->bindings[index_]; + if (binding.same_as(old_binding)) { + index_++; + this->var_map_[binding->var] = binding->value; + return binding->var; + } + else { + is_rewrite_ = true; + for (int i = 0; i < index_; i++) { + if (!binding->var.as()) { + IRBuilderNode::EmitOutput(binding->value); + } else { + Expr expr = Downcast(this->df_block_->bindings[i])->value; + IRBuilderNode::Emit(Downcast(expr)); + } + } + index_++; + Call call = Downcast(binding->value); + return IRBuilderNode::Emit(call); + } +} + +Var LazyIRBuilderNode::Emit(const Var& var, const Call& call) { + if (is_rewrite_) { + index_++; + return IRBuilderNode::Emit(var, call); + } + Expr expr = Downcast(this->df_block_->bindings[index_])->value; + Call old_call = Downcast(expr); + if (call.same_as(old_call)) { + index_++; + this->var_map_[var] = call; + return var; + } + else { + is_rewrite_ = true; + for (int i = 0; i < index_; i++) { + VarBinding old_binding = Downcast(this->df_block_->bindings[i]); + // Reuse the old bindings + IRBuilderNode::Emit(old_binding); + } + index_++; + return IRBuilderNode::Emit(var, call); + } +} + +Var LazyIRBuilderNode::EmitOutput(const VarBinding& binding) { + if (is_rewrite_) { + index_++; + return IRBuilderNode::EmitOutput(binding->value); + } + Binding old_binding = this->df_block_->bindings[index_]; + if (binding.same_as(old_binding)) { + index_++; + this->var_map_[binding->var] = binding->value; + return binding->var; + } + else { + is_rewrite_ = true; + for (int i = 0; i < index_; i++) { + if (!binding->var.as()) { + IRBuilderNode::EmitOutput(binding->value); + } else { + Expr expr = Downcast(this->df_block_->bindings[i])->value; + IRBuilderNode::Emit(Downcast(expr)); + } + } + index_++; + return IRBuilderNode::EmitOutput(binding->value); + } +} + +void LazyIRBuilderNode::BuildBlock() { + if (!this->func_.bindings.empty()) { + if (is_dataflow_) { + if (is_rewrite_) { + this->func_.binding_blocks.emplace_back(DataflowBlock(this->func_.bindings)); + } + else { + this->func_.binding_blocks.emplace_back(this->df_block_); + } + } else { + this->func_.binding_blocks.emplace_back(BindingBlock(this->func_.bindings)); + } + this->func_.bindings.clear(); + } + this->dataflow_var_counter_ = 0; + this->is_dataflow_ = !this->is_dataflow_; +} + TVM_REGISTER_GLOBAL("relax.IRBuilderCreate").set_body_typed(IRBuilderNode::Create); TVM_REGISTER_GLOBAL("relax.IRBuilderFillFuncNameParam") @@ -166,11 +414,20 @@ TVM_REGISTER_GLOBAL("relax.IRBuilderEmit").set_body_typed([](IRBuilder builder, return builder->Emit(call); }); +TVM_REGISTER_GLOBAL("relax.IRBuilderEmitMatchShape").set_body_typed([](IRBuilder builder, const Expr& value, const Array& pattern) { + return builder->EmitMatchShape(value, pattern); +}); + TVM_REGISTER_GLOBAL("relax.IRBuilderEmitOutput") .set_body_typed([](IRBuilder builder, const Expr& output) { return builder->EmitOutput(output); }); +TVM_REGISTER_GLOBAL("relax.IRBuilderNormalize") + .set_body_typed([](IRBuilder builder, const Expr& expr) { + return builder->Normalize(expr); + }); + TVM_REGISTER_GLOBAL("relax.IRBuilderGet").set_body_typed([](IRBuilder builder) { return builder->Get(); }); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 242bb3249e..27c8201752 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -19,18 +19,40 @@ #include #include +#include "op_common.h" + namespace tvm { namespace relax { +bool EqualConstInt(const PrimExpr& lhs, int64_t value) { + if (const int64_t* pvalue = tir::as_const_int(lhs)) { + return pvalue[0] == value; + } + return false; +} + +bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { + PrimExpr diff = lhs - rhs; + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + tvm::arith::Analyzer ana; + diff = ana.Simplify(diff); + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + return false; +} + // call_dps RELAY_REGISTER_OP("relax.call_dps") .set_num_inputs(3) -.add_argument("shape", "ShapeExpr", "The output shape.") +.add_argument("shape", "Expr", "The output shape.") .add_argument("func", "Expr", "The destination-passing-style function.") .add_argument("args", "Tuple", "The input arguments."); -Expr MakeCallDPS(ShapeExpr shape, Expr func, Tuple args) { +Expr MakeCallDPS(Expr shape, Expr func, Tuple args) { static const Op& op = Op::Get("relax.call_dps"); return Call(op, {shape, func, args}, {}, {}); } @@ -51,5 +73,19 @@ Expr MakeShapeOf(Expr expr) { TVM_REGISTER_GLOBAL("relax.op.shape_of") .set_body_typed(MakeShapeOf); -} // namespace relax -} // namespace tvm +// alloc_tensor + +RELAY_REGISTER_OP("relax.builtin.alloc_tensor") +.set_num_inputs(1) +.add_argument("shape", "Expr", "The shape of the tensor to allocate."); + +Expr MakeAllocTensor(Expr shape) { + static const Op& op = Op::Get("relax.builtin.alloc_tensor"); + return Call(op, {shape}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor") +.set_body_typed(MakeAllocTensor); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 074544cfc2..b2f7b4a45a 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -26,12 +26,17 @@ #define TVM_RELAX_OP_OP_COMMON_H_ #include +#include #include #include namespace tvm { namespace relax { +bool EqualConstInt(const PrimExpr& lhs, int64_t value); + +bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs); + /*! Quick helper macro * - Expose a positional make function to construct the node. * - Register op to the registry. diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 6818ff12de..7d60167878 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -24,8 +24,6 @@ #include "binary.h" -#include "../op_common.h" - namespace tvm { namespace relax { diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index 75957b5305..0684d189f4 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -21,39 +21,14 @@ * \file binary.h * \brief shape and type deduction for binary broadcast operators. */ - -#include #include -#include #include -#include -#include #include "../op_common.h" namespace tvm { namespace relax { -bool EqualConstInt(const PrimExpr& lhs, int64_t value) { - if (const int64_t* pvalue = tir::as_const_int(lhs)) { - return pvalue[0] == value; - } - return false; -} - -bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { - PrimExpr diff = lhs - rhs; - if (const int64_t* pdiff = tir::as_const_int(diff)) { - return pdiff[0] == 0; - } - tvm::arith::Analyzer ana; - diff = ana.Simplify(diff); - if (const int64_t* pdiff = tir::as_const_int(diff)) { - return pdiff[0] == 0; - } - return false; -} - Optional InferShapeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) { if (call->args.size() != 2) { diag_ctx.EmitFatal(Diagnostic::Error(call->span) diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc new file mode 100644 index 0000000000..0f2a9c4db2 --- /dev/null +++ b/src/relax/op/tensor/ternary.cc @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ternary.cc + * \brief ternary operators. + */ + +#include "ternary.h" + +namespace tvm { +namespace relax { + +RELAY_REGISTER_OP("relax.ewise_fma") +.set_num_inputs(3) +.add_argument("e1", "Expr", "The input expression") +.add_argument("e2", "Expr", "The input expression") +.add_argument("e3", "Expr", "The input expression") +.set_attr("FInferShape", InferShapeEwiseFMA) +.set_attr("FInferType", InferTypeEwiseFMA); + +Expr MakeEwiseFma(Expr expr1, Expr expr2, Expr expr3) { + static const Op& op = Op::Get("relax.ewise_fma"); + return Call(op, {expr1, expr2, expr3}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.ewise_fma") +.set_body_typed(MakeEwiseFma); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/ternary.h b/src/relax/op/tensor/ternary.h new file mode 100644 index 0000000000..649cdf83ec --- /dev/null +++ b/src/relax/op/tensor/ternary.h @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ternary.h + * \brief shape and type deduction for ternary operators. + */ +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +Optional InferShapeEwiseFMA(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 3) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "EwiseFMA op should have 3 arguments"); + } + Expr shape0 = call->args[0]->shape(); + Expr shape1 = call->args[1]->shape(); + Expr shape2 = call->args[2]->shape(); + auto* s0 = shape0.as(); + auto* s1 = shape1.as(); + auto* s2 = shape2.as(); + if (s0 && s1 && s2) { + std::vector output_shape; + size_t ndim0 = s0->values.size(); + size_t ndim1 = s1->values.size(); + size_t ndim2 = s2->values.size(); + if (ndim0 != ndim1 || ndim1 != ndim2) { + LOG(INFO) << ndim0; + LOG(INFO) << ndim1; + LOG(INFO) << ndim2; + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The 3 arguments of EwiseFMA must have the same number of dimensions"); + } + for (size_t i = 0; i < ndim0; ++i) { + PrimExpr dim0 = s0->values[i]; + PrimExpr dim1 = s1->values[i]; + PrimExpr dim2 = s2->values[i]; + if (EqualCheck(dim0, dim1) && EqualCheck(dim1, dim2)) { + output_shape.push_back(dim0); + } else { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The 3 arguments of EwiseFMA must have the same shape"); + } + } + return ShapeExpr(Array(output_shape.begin(), output_shape.end())); + } else { + return NullOpt; + } +} + +Type InferTypeEwiseFMA(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 3) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "EwiseFMA op should have 3 arguments"); + } + Type type0 = call->args[0]->checked_type(); + Type type1 = call->args[1]->checked_type(); + Type type2 = call->args[2]->checked_type(); + auto* t0 = type0.as(); + auto* t1 = type1.as(); + auto* t2 = type2.as(); + if (!t0 || !t1 || !t2) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The 3 arguments of EwiseFMA should be DynTensor"); + } + + DataType output_dtype; + if (t0->IsUnknownDtype() || t1->IsUnknownDtype() || t2->IsUnknownDtype()) { + output_dtype = DataType::Void(); + } else if (t0->dtype != t1->dtype || t1->dtype != t2->dtype) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Data types " << t0->dtype << ", " << t1->dtype << ", and " << t2->dtype + << " must be equal for EwiseFMA"); + } else { + output_dtype = t0->dtype; + } + + int output_rank; + if (t0->IsUnknownRank() || t1->IsUnknownRank() || t2->IsUnknownRank()) { + output_rank = -1; + } else { + output_rank = t0->rank; + } + return DynTensorType(output_rank, output_dtype); +} + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py new file mode 100644 index 0000000000..b9d31ab3d3 --- /dev/null +++ b/tests/python/relax/test_analysis.py @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import tir +from tvm import relax as rx +from tvm.ir import structural_equal + +def test_dispatch_var(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=1, dtype="float16") + v0 = rx.Var("v0", [m, n], dtype0) + v1 = rx.DataflowVar("v1", [n], dtype1) + t = None + def fvisit(e): + nonlocal t + t = type(e) + rx.analysis.post_order_visit(v0, fvisit) + assert t == type(v0) + rx.analysis.post_order_visit(v1, fvisit) + assert t == type(v1) + + +def test_post_order_visit(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=1, dtype="float16") + x = rx.Var("x", [m, n], dtype0) + y = rx.Var("y", [n], dtype1) + ib = rx.IRBuilder() + with ib.function([x, y]): + with ib.dataflow() as df: + lv0 = ib.emit(rx.op.add(x, y)) + lv1 = ib.emit(rx.op.multiply(lv0, y)) + gv0 = ib.emit_output(lv1) + ib.emit_output(gv0) + expr = ib.get() + + names = [] + def fvisit(e): + nonlocal names + if isinstance(e, tvm.ir.op.Op): + names.append(e.name) + rx.analysis.post_order_visit(expr.body, fvisit) + assert names == ["relax.add", "relax.multiply"] + + +def test_fma_rewrite(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=2, dtype="float16") + x = rx.Var("x", [m, n], dtype0) + y = rx.Var("y", [m, n], dtype1) + ib = rx.IRBuilder() + with ib.function([x, y]): + with ib.dataflow() as df: + lv0 = ib.emit(rx.op.multiply(x, y)) + lv1 = ib.emit(rx.op.add(lv0, y)) + gv0 = ib.emit_output(lv1) + ib.emit_output(gv0) + expr = ib.get() + + # before rewrite + v0 = expr.body.blocks[0].bindings[1].var + s0 = expr.body.blocks[0].bindings[1].value + assert isinstance(s0, tvm.relay.Call) + assert s0.op.name == "relax.add" + assert structural_equal(v0.shape, rx.ShapeExpr([m, n])) + assert structural_equal(s0.shape, rx.ShapeExpr([m, n])) + assert structural_equal(gv0.shape, rx.ShapeExpr([m, n])) + + # after rewrite + func = rx.analysis.fma_rewrite(expr) + + v1 = func.body.blocks[0].bindings[1].var + s1 = func.body.blocks[0].bindings[1].value + assert isinstance(s1, tvm.relay.Call) + assert s1.op.name == "relax.ewise_fma" + assert structural_equal(v1.shape, rx.ShapeExpr([m, n])) + assert structural_equal(s1.shape, rx.ShapeExpr([m, n])) + + # The var binded to the fma call is reused because the shape + # and type of var are unchanged after rewriting + assert lv1 == v0 + + assert type(func.body.blocks[0].bindings[2].var) == rx.Var + assert type(func.body.blocks[0].bindings[2].value) == rx.DataflowVar + +def test_lazy_irbuilder(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=2, dtype="float16") + x = rx.Var("x", [m, n], dtype0) + y = rx.Var("y", [m, n], dtype1) + ib = rx.IRBuilder() + + # This program should not be rewritten by the fma_rewriter + with ib.function([x, y]): + with ib.dataflow() as df: + lv0 = ib.emit(rx.op.multiply(x, y)) + lv1 = ib.emit(rx.op.multiply(lv0, y)) + gv0 = ib.emit_output(lv1) + ib.emit_output(gv0) + expr = ib.get() + + # before rewrite + block0 = expr.body.blocks[0] + v0 = expr.body.blocks[0].bindings[1].var + s0 = expr.body.blocks[0].bindings[1].value + assert isinstance(s0, tvm.relay.Call) + assert s0.op.name == "relax.multiply" + + # after rewrite (the bindings and the dataflow block are reused) + func = rx.analysis.fma_rewrite(expr) + + block1 = func.body.blocks[0] + v1 = func.body.blocks[0].bindings[1].var + s1 = func.body.blocks[0].bindings[1].value + + # the dataflow block and vars are reused + assert block0 == block1 + assert v1 == v0 + assert s1 == s0 + +def test_explicit_memory_rewrite(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + shape_anno = [m, n] + type_anno = rx.DynTensorType(2, "float32") + x = rx.Var("x", shape_anno, type_anno) + ib = rx.IRBuilder() + with ib.function(x): + with ib.dataflow() as df: + lv0 = rx.call_dps([m, n], rx.extern("test.op.identity"), [x]) + gv0 = ib.emit_output(lv0) + ib.emit_output(gv0) + expr = ib.get() + + # before rewrite + v0 = expr.body.blocks[0].bindings[0].var + s0 = expr.body.blocks[0].bindings[0].value + assert isinstance(s0, tvm.relay.Call) + assert s0.op.name == "relax.call_dps" + + # after rewrite + func = rx.analysis.explicit_memory_rewrite(expr) + + # the dataflow block has changed to binding block due to the rewriting + block = func.body.blocks[0] + assert isinstance(block, rx.BindingBlock) + + s1 = block.bindings[0].value + assert isinstance(s1, tvm.relay.Call) + assert s1.op.name == "relax.builtin.alloc_tensor" + assert isinstance(s1.args[0], rx.ShapeExpr) + assert structural_equal(s1.args[0], rx.ShapeExpr(shape_anno)) + s2 = block.bindings[1].value + assert s2.op.global_symbol == "test.op.identity" + + +if __name__ == "__main__": + test_dispatch_var() + test_post_order_visit() + test_fma_rewrite() + test_lazy_irbuilder() + test_explicit_memory_rewrite() diff --git a/tests/python/relax/test_irbuilder.py b/tests/python/relax/test_irbuilder.py index 4e00049882..715fec905c 100644 --- a/tests/python/relax/test_irbuilder.py +++ b/tests/python/relax/test_irbuilder.py @@ -17,6 +17,7 @@ import tvm from tvm import tir +from tvm import relay from tvm import relax as rx @@ -40,17 +41,23 @@ def test_dataflow_block(): lv1 = ib.emit(rx.op.multiply(lv0, y)) assert lv1.name_hint == "lv1" - gv0 = ib.emit_output(lv1) + + b0 = ib.match_shape(x, [m, n]) + gv0 = ib.emit_output(lv1) assert gv0.name_hint == "gv0" assert gv0.shape[0] == m assert gv0.shape[1] == n assert gv0.checked_type.rank == 2 assert gv0.checked_type.dtype == "float16" + assert isinstance(gv0, rx.Var) blocks = ib.get_blocks() assert len(blocks) == 1 - assert len(blocks[-1].bindings) == 3 + assert len(blocks[-1].bindings) == 4 + for i in [0, 1, 3]: + assert isinstance(blocks[-1].bindings[i], rx.VarBinding) + assert isinstance(blocks[-1].bindings[2], rx.MatchShape) def test_function_single_block(): @@ -122,7 +129,7 @@ def test_function_multi_blocks(): assert len(func.body.blocks[2].bindings) == 2 -def test_binary_shape_deduction(): +def test_binary_shape_type_deduction(): m = tir.Var("m", "int32") n = tir.Var("n", "int32") k = tir.Var("k", "int32") @@ -130,7 +137,7 @@ def test_binary_shape_deduction(): dtype1 = rx.DynTensorType(rank=1, dtype="float16") x = rx.Var("x", [m, 1], dtype0) y = rx.Var("y", [n], dtype1) - z = rx.Var("z", [5], dtype0) + z = rx.Var("z", [5], dtype1) w = rx.Var("w", [k], dtype1) ib = rx.IRBuilder() @@ -139,23 +146,103 @@ def test_binary_shape_deduction(): lv0 = ib.emit(rx.op.add(x, y)) assert lv0.shape[0] == m assert lv0.shape[1] == n + assert isinstance(lv0.checked_type, rx.DynTensorType) + assert lv0.checked_type.rank == 2 + assert lv0.checked_type.dtype == "float16" lv1 = ib.emit(rx.op.multiply(x, z)) assert lv1.shape[0] == m assert lv1.shape[1] == 5 + assert isinstance(lv1.checked_type, rx.DynTensorType) + assert lv1.checked_type.rank == 2 + assert lv1.checked_type.dtype == "float16" lv2 = ib.emit(rx.op.multiply(z, w)) assert isinstance(lv2.shape, tvm.relay.Call) + assert isinstance(lv2.checked_type, rx.DynTensorType) + assert lv2.checked_type.rank == 1 + assert lv2.checked_type.dtype == "float16" lv3 = ib.emit(rx.op.multiply(y, w)) assert isinstance(lv3.shape, tvm.relay.Call) - gv0 = ib.emit_output(lv3) + assert isinstance(lv3.checked_type, rx.DynTensorType) + assert lv3.checked_type.rank == 1 + assert lv3.checked_type.dtype == "float16" + + gv0 = ib.emit_output(lv3) + ib.emit_output(gv0) assert isinstance(gv0.shape, tvm.relay.Call) + assert isinstance(gv0.checked_type, rx.DynTensorType) + assert gv0.checked_type.rank == 1 + assert gv0.checked_type.dtype == "float16" + + +def test_emit_match_shape(): + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + type_anno0 = rx.DynTensorType(-1, "float32") + x = rx.Var("tensor_value", type_annotation=type_anno0) + shape_anno = [16, 8] + y = rx.Var("shape_value", type_annotation=rx.ShapeType(), shape_annotation=shape_anno) + ib = rx.IRBuilder() + + with ib.function([x, y]): + with ib.dataflow() as df: + # lv0: Tensor[(m, n), "float32"] = + # match_shape(x: Tensor[_, "float32"], [m, n]) + lv0 = ib.match_shape(x, [m, n]) + assert isinstance(lv0, rx.DataflowVar) + assert lv0.shape[0] == m + assert lv0.shape[1] == n + assert lv0.checked_type.rank == 2 + assert lv0.checked_type.dtype == "float32" + + # lv1: Shape = match_shape(shape, [m, n]) + lv1 = ib.match_shape(y, [m, n]) + assert lv1.checked_type == rx.ShapeType() + gv0 = ib.emit_output(lv1) + + ib.emit_output(gv0) + + block = ib.get_blocks()[-1] + b0, b1 = block.bindings[:2] + assert isinstance(b0, rx.MatchShape) + assert isinstance(b1, rx.MatchShape) + + assert b0.value == x + assert b0.pattern[0] == m + assert b0.pattern[1] == n + assert b0.var == lv0 + + assert b1.value == y + assert b1.pattern[0] == m + assert b1.pattern[1] == n + assert b1.var == lv1 + + +def test_normalize(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=1, dtype="float16") + x = rx.Var("x", [m, n], dtype0) + y = rx.Var("y", [n], dtype1) + ib = rx.IRBuilder() + + add_call = rx.op.multiply(x, y) + assert isinstance(add_call.shape, relay.Call) + + ib.normalize(add_call) + assert isinstance(add_call.shape, rx.ShapeExpr) + assert add_call.shape[0] == m + assert add_call.shape[1] == n if __name__ == "__main__": test_dataflow_block() test_function_single_block() test_function_multi_blocks() - test_binary_shape_deduction() + test_binary_shape_type_deduction() + test_emit_match_shape() + test_normalize()