diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index dac2b03c9f..b4a143727a 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -209,8 +209,8 @@ class ExprMutator : public ExprFunctor { */ virtual Type VisitType(const Type& t); virtual void VisitBinding(const Binding& binding); - virtual void VisitVarBinding(const VarBinding& binding); - virtual void VisitMatchShape(const MatchShape& binding); + virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder); + virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder); virtual BindingBlock VisitBindingBlock(const BindingBlock& block); virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); @@ -223,20 +223,11 @@ class ExprMutator : public ExprFunctor { class DataflowMutator : public ExprMutator { public: virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); - virtual void VisitVarBinding(const VarBinding& binding); - /*! \brief Insert a call node, a new binding will be created. */ - Var Insert(Call value); - /*! \brief Emit a call node, the var could be reused depending on if the shape/type of \p value is - * changed or not. */ - Var Emit(Var var, Call value); - /*! \brief Emit a binding. */ - Var Emit(VarBinding binding); - /*! \brief Look up the value binded to a var. */ - Expr LookupVar(Var var); - /*! \brief Switch from building a DataflowBlock to building a BindingBlock. */ - void EmitBindingBlock(); + virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder); protected: + /*! \brief Look up the value bound to a var. */ + Expr LookupVar(Var var); // A remapping table: pre var -> post var std::unordered_map pre_post_var_map_; }; diff --git a/include/tvm/relax/ir_builder.h b/include/tvm/relax/ir_builder.h index 985af62c30..596eae2064 100644 --- a/include/tvm/relax/ir_builder.h +++ b/include/tvm/relax/ir_builder.h @@ -75,16 +75,11 @@ class IRBuilderNode : public Object { */ virtual void BuildBlock(); /*! - * \brief Emit a call node. - * \param call The CallNode to be emitted. + * \brief Emit a Call, and return a newly created Var bound to the Call. + * \param call The Call to be emitted. * \return The variable being created and binded to \p call. */ virtual Var Emit(const Call& call); - /*! - * \brief Emit a MatchShape. - * \param match_shape The MatchShape to be emitted. - */ - void Emit(const MatchShape& match_shape); /*! * \brief Emit a var binding. * \param binding The VarBinding to be emitted. @@ -92,12 +87,18 @@ class IRBuilderNode : public Object { */ virtual Var Emit(const VarBinding& binding); /*! - * \brief Emit a call node, and bind it to a Var. - * \param var The VarNode to be binded with. \p var is reused implicitly. - * \param call The CallNode to be emitted. - * \return The VarNode to be binded with \p var. + * \brief Emit a Call, and bind it to a Var. + * \param var The Var to bind \p call to. \p var is reused implicitly if the shape + * and type of \p call match \p var. Otherwise a new Var is created. + * \param call The Call to be emitted. + * \return The Var bound to \p call. */ virtual Var Emit(const Var& var, const Call& call); + /*! + * \brief Emit a MatchShape. + * \param match_shape The MatchShape to be emitted. + */ + void Emit(const MatchShape& match_shape); /*! * \brief Generate an output for the current dataflow block or function. * \param output The output variable of the block/function.
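Usage sketch (not part of the patch): with the reworked `DataflowMutator::VisitVarBinding(binding, ir_builder)`, a pass such as the FMA rewriter returns the post-rewrite `Var` for every binding and can reuse the bound var when the shape and type are unchanged. The snippet below mirrors `test_fma_rewrite` from the tests in this diff; variable names and binding indices are illustrative only.

```python
import tvm
from tvm import tir
from tvm import relax as rx

# Build a small dataflow function containing add(multiply(x, y), y),
# the pattern EwiseFMARewriter looks for.
m = tir.Var("m", "int32")
n = tir.Var("n", "int32")
dtype = rx.DynTensorType(rank=2, dtype="float16")
x = rx.Var("x", [m, n], dtype)
y = rx.Var("y", [m, n], dtype)

ib = rx.IRBuilder()
with ib.function([x, y]):
    with ib.dataflow() as df:
        lv0 = ib.emit(rx.op.multiply(x, y))
        lv1 = ib.emit(rx.op.add(lv0, y))
        gv0 = ib.emit_output(lv1)
    ib.emit_output(gv0)
expr = ib.get()

# The pass is a DataflowMutator; the add-of-multiply binding is rewritten
# into a single ewise_fma call, and its var is reused when shape/type match.
func = rx.analysis.fma_rewrite(expr)
assert func.body.blocks[0].bindings[1].value.op.name == "relax.ewise_fma"
```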
@@ -116,6 +117,19 @@ class IRBuilderNode : public Object { * \brief Get binding blocks being built. */ std::vector GetBlocks(); + /*! + * \brief Check if two shape expressions can be proven equal at compile time. + * \param lhs The input lhs shape. + * \param rhs The input rhs shape. + * \return Whether we can prove lhs shape == rhs shape. + */ + bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs); + /*! + * \brief Normalize an Expr to complete its shape and type. + * \param expr The input expr. + * \return The expr with normalized shape and type. + */ + Expr Normalize(const Expr& expr); /*! * \brief Create a IRBuilder. * \return The created IRBuilder. @@ -225,27 +239,33 @@ class LazyIRBuilderNode : public IRBuilderNode { public: /*! - * \brief Emit a call node in a copy-on-write way. + * \brief Emit a Call in a copy-on-write way. * If no bindings in a dataflow block need to be rewritten, reuse the original variable instead of * emiting one. If any binding in the block needs to be rewritten, reconstruct the whole block * from scratch by emiting all previous bindings. - * \param call The CallNode to be emitted. + * \param call The Call to be emitted. * \return The variable being created and binded to \p call. */ virtual Var Emit(const Call& call); /*! * \brief Emit a var binding in a copy-on-write way. * \param binding The VarBinding to be emitted. - * \return The VarNode of the VarBinding \p binding. + * \return The Var of \p binding. */ virtual Var Emit(const VarBinding& binding); /*! - * \brief Emit a call node, and bind it to a Var in a copy-on-write way. - * \param var The VarNode to be binded with. - * \param call The CallNode to be emitted. - * \return The VarNode to be binded with \p var. + * \brief Emit a Call, and bind it to a Var in a copy-on-write way. + * \param var The Var to bind \p call to. + * \param call The Call to be emitted. + * \return The Var bound to \p call. */ virtual Var Emit(const Var& var, const Call& call); + /*! + * \brief Emit an output for the current dataflow block or function in a copy-on-write way. + * \param binding The VarBinding to be emitted. + * \return The variable bound to the output of \p binding. + */ + virtual Var EmitOutput(const VarBinding& binding); /*! * \brief Build a binding block. */ diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 9249b7579c..6f0b78df7d 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -47,7 +47,7 @@ def checked_type(self): """ ret = self.checked_type_ if ret is None: - raise ValueError("The type checker has not populated" " the checked_type for this node") + raise ValueError("The type checker has not populated the checked_type for this node") return ret @property diff --git a/python/tvm/relax/ir_builder.py b/python/tvm/relax/ir_builder.py index 7f1d11eeff..7ebc0069c8 100644 --- a/python/tvm/relax/ir_builder.py +++ b/python/tvm/relax/ir_builder.py @@ -161,6 +161,22 @@ def emit_output(self, output = Tuple(output) return _ffi_api.IRBuilderEmitOutput(self, output) + def normalize(self, + expr: Expr) -> Expr: + """Normalize an Expr to complete its shape and type. + + Parameters + ---------- + expr : Expr + The input expr. + + Returns + ------- + ret : Expr + The expr with normalized shape and type. + """ + return _ffi_api.IRBuilderNormalize(self, expr) + def get(self) -> Function: """Return the function being built.
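Usage sketch (not part of the patch): the new `normalize` entry point completes the deferred shape and type of a freshly constructed call, assuming the op has `FInferShape`/`FInferType` registered. This mirrors the `test_normalize` case added later in this diff.

```python
from tvm import tir, relay
from tvm import relax as rx

m = tir.Var("m", "int32")
n = tir.Var("n", "int32")
x = rx.Var("x", [m, n], rx.DynTensorType(rank=2, dtype="float16"))
y = rx.Var("y", [n], rx.DynTensorType(rank=1, dtype="float16"))

ib = rx.IRBuilder()
mul = rx.op.multiply(x, y)
assert isinstance(mul.shape, relay.Call)    # shape is still a symbolic shape_of call

ib.normalize(mul)                           # fills in shape_ and checked_type_
assert isinstance(mul.shape, rx.ShapeExpr)  # now a concrete ShapeExpr([m, n])
assert mul.shape[0] == m and mul.shape[1] == n
```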
diff --git a/src/relax/expr_functor.cc b/src/relax/expr_functor.cc index 747cae1513..5068b423b4 100644 --- a/src/relax/expr_functor.cc +++ b/src/relax/expr_functor.cc @@ -328,26 +328,24 @@ Type ExprMutator::VisitType(const Type& t) { return t; } void ExprMutator::VisitBinding(const Binding& binding) { Binding new_binding; if (binding.as()) { - this->VisitVarBinding(Downcast(binding)); + this->VisitVarBinding(Downcast(binding), this->irbuilder_); } else if (binding.as()) { - this->VisitMatchShape(Downcast(binding)); + this->VisitMatchShape(Downcast(binding), this->irbuilder_); } else { LOG(FATAL) << "Wrong type."; } } -void ExprMutator::VisitVarBinding(const VarBinding& binding) { +Var ExprMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) { Expr new_value = this->Mutate(binding->value); - - if (new_value.as()) { - new_value = this->irbuilder_->Emit(Downcast(new_value)); - } if (!binding->var.as()) { - this->irbuilder_->EmitOutput(new_value); + return ir_builder->EmitOutput(new_value); + } else { + return ir_builder->Emit(Downcast(new_value)); } } -void ExprMutator::VisitMatchShape(const MatchShape& binding) { +void ExprMutator::VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder) { this->Mutate(binding->value); } @@ -366,7 +364,7 @@ BindingBlock ExprMutator::VisitDataflowBlock(const DataflowBlock& block) { With scope(this->irbuilder_); for (auto binding : block->bindings) { if (binding.as()) { - this->VisitVarBinding(Downcast(binding)); + this->VisitVarBinding(Downcast(binding), this->irbuilder_); } } } @@ -383,46 +381,27 @@ Expr ExprMutator::VisitExpr(const Expr& expr) { // DataflowMutator BindingBlock DataflowMutator::VisitDataflowBlock(const DataflowBlock& block) { - return ExprMutator::VisitDataflowBlock(block); -} - -void DataflowMutator::VisitVarBinding(const VarBinding& binding) { - Expr new_value = this->Mutate(binding->value); - Var new_var; - if (new_value.as()) { - new_var = this->irbuilder_->Emit(Downcast(new_value)); - } - if (!binding->var.as()) { - new_var = this->irbuilder_->EmitOutput(new_value); - } - pre_post_var_map_[binding->var] = new_var; -} - -Var DataflowMutator::Insert(Call value) { - Var var = this->irbuilder_->Emit(value); - return var; -} - -Var DataflowMutator::Emit(Var var, Call value) { - Var new_var; - // TODO: make shape and type check right - if (var->shape() == value->shape() && var->checked_type() == value->checked_type()) { - new_var = this->irbuilder_->Emit(var, value); - } else { - new_var = this->irbuilder_->Emit(value); + this->irbuilder_ = LazyIRBuilderNode::Create(block); + { + With scope(this->irbuilder_); + for (auto binding : block->bindings) { + if (auto* var_binding = binding.as()) { + Var var = this->VisitVarBinding(Downcast(binding), this->irbuilder_); + this->pre_post_var_map_[var_binding->var] = var; + } + } } - pre_post_var_map_[var] = new_var; - return new_var; + return this->irbuilder_->GetBlocks().back(); } -Var DataflowMutator::Emit(VarBinding binding) { +Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) { Expr new_value = this->Mutate(binding->value); Var new_var; if (new_value.as()) { - new_var = this->irbuilder_->Emit(binding); + new_var = ir_builder->Emit(Downcast(new_value)); } if (!binding->var.as()) { - new_var = this->irbuilder_->EmitOutput(new_value); + new_var = ir_builder->EmitOutput(new_value); } pre_post_var_map_[binding->var] = new_var; return new_var; @@ -437,9 +416,6 @@ Expr DataflowMutator::LookupVar(Var var) { } } -void 
DataflowMutator::EmitBindingBlock() { - this->irbuilder_->is_dataflow_ = false; -} // ================== // EwiseFMARewriter @@ -458,7 +434,7 @@ void DataflowMutator::EmitBindingBlock() { // z0 = ewise_fma(a, lv0, c) class EwiseFMARewriter : public DataflowMutator { - void VisitVarBinding(const VarBinding& binding) override { + Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override { static const Op& add_op = Op::Get("relax.add"); static const Op& multiply_op = Op::Get("relax.multiply"); static const Op& ewise_fma_op = Op::Get("relax.ewise_fma"); @@ -470,12 +446,10 @@ class EwiseFMARewriter : public DataflowMutator { const CallNode* op2 = value.as(); if (op2 && op2->op == multiply_op) { Call fma_call = Call(ewise_fma_op, {op2->args[0], op2->args[1], op1->args[1]}, {}, {}); - // Complete(fma_call); - Emit(binding->var, fma_call); - return; + return ir_builder->Emit(binding->var, fma_call); } } - Emit(binding); + return ir_builder->Emit(binding); } }; @@ -491,11 +465,10 @@ TVM_REGISTER_GLOBAL("relax.analysis.fma_rewrite") // ================== // ExplicitMemMutator // Example: -// lv1: Tensor[(m*n,)] = rx.call_dps((m*n,), extern_packed("flatten"), [lv0]) +// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x)) // --> -// storage0 = rx.call(extern_packed("alloc_storage"), size=[n*m], device=cpu) -// lv1 = rx.call(extern_packed("alloc_tensor"), storage0, 0, [m*n,], f32) -// rx.call(extern_packed("flatten"), lv0, lv1) +// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m]) +// rx.call_packed(op.identity, x, lv0) class ExplicitMemMutator : public DataflowMutator { Expr ComputeStorageSize(const Expr& shape, const Type& type) const { @@ -523,24 +496,21 @@ class ExplicitMemMutator : public DataflowMutator { return ret; } - void VisitVarBinding(const VarBinding& binding) override { + Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override { static const Op& call_dps_op = Op::Get("relax.call_dps"); - static const Op& alloc_storage_op = Op::Get("relax.alloc_storage"); - static const Op& alloc_tensor_op = Op::Get("relax.alloc_tensor"); + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); const CallNode* op = binding->value.as(); if(op && op->op == call_dps_op) { - // convert current DataflowBlock into an impure BindingBlock - this->EmitBindingBlock(); - + // switch current DataflowBlock to an impure BindingBlock + ir_builder->is_dataflow_ = false; ShapeExpr output_shape = Downcast(op->args[0]); Type arg_type = Downcast(op->args[2])->fields[0]->checked_type(); Expr output_size = ComputeStorageSize(output_shape, arg_type); - Var storage = Insert(Call(alloc_storage_op, {output_size})); - Var tensor = Insert(Call(alloc_tensor_op, {storage, relay::Constant(0), op->args[0]})); - Emit(binding->var, Call(op->args[1], {tensor, op->args[2]})); + Var tensor = ir_builder->Emit(Call(alloc_tensor_op, {op->args[0]})); + return ir_builder->Emit(binding->var, Call(op->args[1], {op->args[2], tensor})); } - Emit(binding); + return ir_builder->Emit(binding); } }; diff --git a/src/relax/ir_builder.cc b/src/relax/ir_builder.cc index 418a7c0425..83c20664d6 100644 --- a/src/relax/ir_builder.cc +++ b/src/relax/ir_builder.cc @@ -24,6 +24,7 @@ #include #include #include +#include namespace tvm { namespace relax { @@ -71,20 +72,22 @@ Optional InferShape(const Call& call, DiagnosticContext diag_ctx) { auto op_map = Op::GetAttrMap("FInferShape"); if (call->op.as()) { Op op = Downcast(call->op); - return op_map[op](call, diag_ctx); - } else { - return NullOpt; 
- } + if (op_map.count(op)) { + return op_map[op](call, diag_ctx); + } + } + return NullOpt; } Type InferType(const Call& call, DiagnosticContext diag_ctx) { auto op_map = Op::GetAttrMap("FInferType"); if (call->op.as()) { Op op = Downcast(call->op); - return op_map[op](call, diag_ctx); - } else { - return VoidType(); + if (op_map.count(op)) { + return op_map[op](call, diag_ctx); + } } + return VoidType(); } Var IRBuilderNode::Emit(const Call& call) { @@ -118,15 +121,31 @@ void IRBuilderNode::Emit(const MatchShape& match_shape) { } Var IRBuilderNode::Emit(const VarBinding& binding) { - this->func_.bindings.emplace_back(binding); - this->var_map_[binding->var] = binding->value; - return binding->var; + if (!binding->var.as()) { + return EmitOutput(binding->value); + } else { + this->func_.bindings.emplace_back(binding); + this->var_map_[binding->var] = binding->value; + return binding->var; + } } Var IRBuilderNode::Emit(const Var& var, const Call& call) { - this->func_.bindings.emplace_back(VarBinding(var, call)); - this->var_map_[var] = call; - return var; + Expr normalized_call = Normalize(call); + // Reuse the input var if the shape and type of the call matches the var + if (CanProveShapeEqual(var->shape(), call->shape()) && StructuralEqual()(var->checked_type(), call->checked_type())) { + this->func_.bindings.emplace_back(VarBinding(var, normalized_call)); + this->var_map_[var] = normalized_call; + return var; + } else { + Var new_var; + if (normalized_call->shape_.defined()) { + new_var->shape_ = normalized_call->shape_; + } + this->func_.bindings.emplace_back(VarBinding(new_var, normalized_call)); + this->var_map_[new_var] = normalized_call; + return new_var; + } } Var IRBuilderNode::EmitOutput(const Expr& output) { @@ -156,6 +175,49 @@ Function IRBuilderNode::Get() { return this->func_.func; } std::vector IRBuilderNode::GetBlocks() { return this->func_.binding_blocks; } +bool IRBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) { + if (lhs == rhs) { + return true; + } + const auto* lhs_shape = lhs.as(); + const auto* rhs_shape = rhs.as(); + if (lhs_shape && rhs_shape) { + size_t lhs_ndim = lhs_shape->values.size(); + size_t rhs_ndim = rhs_shape->values.size(); + if (lhs_ndim != rhs_ndim) { + return false; + } + arith::Analyzer analyzer; + for (size_t i = 0; i < lhs_ndim; ++i) { + PrimExpr lhs_dim = lhs_shape->values[i]; + PrimExpr rhs_dim = rhs_shape->values[i]; + if (!analyzer.CanProveEqual(lhs_dim, rhs_dim)) { + return false; + } + } + return true; + } + return false; +} + +Expr IRBuilderNode::Normalize(const Expr& expr) { + if (expr.as()) { + Call call = Downcast(expr); + // Shape inference + auto inferred_shape = InferShape(call, this->diag_ctx_); + if (inferred_shape.defined()) { + if (auto* shape_expr = inferred_shape.value().as()) { + call->shape_ = GetRef(shape_expr); + } + } + // Type inference + auto inferred_type = InferType(call, this->diag_ctx_); + call->checked_type_ = inferred_type; + return call; + } + return expr; +} + class FunctionScope::Internal { public: static void ExitScope(FunctionScope scope) { scope.ExitWithScope(); } @@ -219,6 +281,9 @@ Var LazyIRBuilderNode::Emit(const Call& call) { } Var LazyIRBuilderNode::Emit(const VarBinding& binding) { + if (!binding->var.as()) { + return IRBuilderNode::EmitOutput(binding->value); + } if (is_rewrite_) { index_++; return IRBuilderNode::Emit(binding); @@ -232,8 +297,12 @@ Var LazyIRBuilderNode::Emit(const VarBinding& binding) { else { is_rewrite_ = true; for (int i = 0; i < index_; i++) { - Expr expr = 
Downcast(this->df_block_->bindings[i])->value; - IRBuilderNode::Emit(Downcast(expr)); + if (!binding->var.as()) { + IRBuilderNode::EmitOutput(binding->value); + } else { + Expr expr = Downcast(this->df_block_->bindings[i])->value; + IRBuilderNode::Emit(Downcast(expr)); + } } index_++; Call call = Downcast(binding->value); @@ -256,14 +325,41 @@ Var LazyIRBuilderNode::Emit(const Var& var, const Call& call) { else { is_rewrite_ = true; for (int i = 0; i < index_; i++) { - Expr expr = Downcast(this->df_block_->bindings[i])->value; - IRBuilderNode::Emit(Downcast(expr)); + VarBinding old_binding = Downcast(this->df_block_->bindings[i]); + // Reuse the old bindings + IRBuilderNode::Emit(old_binding); } index_++; return IRBuilderNode::Emit(var, call); } } +Var LazyIRBuilderNode::EmitOutput(const VarBinding& binding) { + if (is_rewrite_) { + index_++; + return IRBuilderNode::EmitOutput(binding->value); + } + Binding old_binding = this->df_block_->bindings[index_]; + if (binding.same_as(old_binding)) { + index_++; + this->var_map_[binding->var] = binding->value; + return binding->var; + } + else { + is_rewrite_ = true; + for (int i = 0; i < index_; i++) { + if (!binding->var.as()) { + IRBuilderNode::EmitOutput(binding->value); + } else { + Expr expr = Downcast(this->df_block_->bindings[i])->value; + IRBuilderNode::Emit(Downcast(expr)); + } + } + index_++; + return IRBuilderNode::EmitOutput(binding->value); + } +} + void LazyIRBuilderNode::BuildBlock() { if (!this->func_.bindings.empty()) { if (is_dataflow_) { @@ -306,6 +402,11 @@ TVM_REGISTER_GLOBAL("relax.IRBuilderEmitOutput") return builder->EmitOutput(output); }); +TVM_REGISTER_GLOBAL("relax.IRBuilderNormalize") + .set_body_typed([](IRBuilder builder, const Expr& expr) { + return builder->Normalize(expr); + }); + TVM_REGISTER_GLOBAL("relax.IRBuilderGet").set_body_typed([](IRBuilder builder) { return builder->Get(); }); diff --git a/src/relax/op/memory/memory.cc b/src/relax/op/memory/memory.cc deleted file mode 100644 index d97c560a51..0000000000 --- a/src/relax/op/memory/memory.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file memory.cc - * \brief memory related operators. 
- */ - -#include "memory.h" - -namespace tvm { -namespace relax { - -// alloc_storage - -RELAY_REGISTER_OP("relax.alloc_storage") -.set_num_inputs(1) -.add_argument("size", "Expr", "The size of the storage to allocate.") -.set_attr("FInferShape", InferShapeAllocStorage) -.set_attr("FInferType", InferTypeAllocStorage); - -Expr MakeAllocStorage(Expr size) { - static const Op& op = Op::Get("relax.alloc_storage"); - return Call(op, {size}, {}, {}); -} - -TVM_REGISTER_GLOBAL("relax.op.alloc_storage") -.set_body_typed(MakeAllocStorage); - -// alloc_tensor - -RELAY_REGISTER_OP("relax.alloc_tensor") -.set_num_inputs(1) -.add_argument("storage", "Var", "The storage to allocate from.") -.add_argument("offset", "Expr", "The offset into the backing storage.") -.add_argument("shape", "Expr", "The shape of the tensor to allocate.") -.set_attr("FInferShape", InferShapeAllocTensor) -.set_attr("FInferType", InferTypeAllocTensor); - -Expr MakeAllocTensor(Var storage, Expr offset, Expr shape) { - static const Op& op = Op::Get("relax.alloc_tensor"); - return Call(op, {storage, offset, shape}, {}, {}); -} - -TVM_REGISTER_GLOBAL("relax.op.alloc_tensor") -.set_body_typed(MakeAllocTensor); - -} // namespace relax -} // namespace tvm diff --git a/src/relax/op/memory/memory.h b/src/relax/op/memory/memory.h deleted file mode 100644 index 6094914a8d..0000000000 --- a/src/relax/op/memory/memory.h +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file memory.h - * \brief shape and type deduction for memory related operators. 
- */ -#include -#include - -#include "../op_common.h" - -namespace tvm { -namespace relax { - -Optional InferShapeAllocStorage(const Call& call, DiagnosticContext diag_ctx) { - if (call->args.size() != 1) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "AllocStorage op should have 1 argument"); - } - return call->args[0]; -} - -Type InferTypeAllocStorage(const Call& call, DiagnosticContext diag_ctx) { - if (call->args.size() != 1) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "AllocStorage op should have 1 argument"); - } - DataType output_dtype; - return DynTensorType(1, output_dtype); -} - -Optional InferShapeAllocTensor(const Call& call, DiagnosticContext diag_ctx) { - if (call->args.size() != 3) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "AllocTensor op should have 3 argument"); - } - return call->args[2]; -} - -Type InferTypeAllocTensor(const Call& call, DiagnosticContext diag_ctx) { - if (call->args.size() != 3) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "AllocTensor op should have 3 argument"); - } - int output_rank; - if (auto* shape = call->args[2].as()) { - output_rank = shape->values.size(); - } else { - output_rank = -1; - } - DataType output_dtype; - return DynTensorType(output_rank, output_dtype); -} - -} // namespace relax -} // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index f0d010ff07..27c8201752 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -73,5 +73,19 @@ Expr MakeShapeOf(Expr expr) { TVM_REGISTER_GLOBAL("relax.op.shape_of") .set_body_typed(MakeShapeOf); +// alloc_tensor + +RELAY_REGISTER_OP("relax.builtin.alloc_tensor") +.set_num_inputs(1) +.add_argument("shape", "Expr", "The shape of the tensor to allocate."); + +Expr MakeAllocTensor(Expr shape) { + static const Op& op = Op::Get("relax.builtin.alloc_tensor"); + return Call(op, {shape}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor") +.set_body_typed(MakeAllocTensor); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index a21ea69e64..b9d31ab3d3 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -17,6 +17,7 @@ import tvm from tvm import tir from tvm import relax as rx +from tvm.ir import structural_equal def test_dispatch_var(): m = tir.Var("m", "int32") @@ -81,6 +82,9 @@ def test_fma_rewrite(): s0 = expr.body.blocks[0].bindings[1].value assert isinstance(s0, tvm.relay.Call) assert s0.op.name == "relax.add" + assert structural_equal(v0.shape, rx.ShapeExpr([m, n])) + assert structural_equal(s0.shape, rx.ShapeExpr([m, n])) + assert structural_equal(gv0.shape, rx.ShapeExpr([m, n])) # after rewrite func = rx.analysis.fma_rewrite(expr) @@ -89,6 +93,12 @@ def test_fma_rewrite(): s1 = func.body.blocks[0].bindings[1].value assert isinstance(s1, tvm.relay.Call) assert s1.op.name == "relax.ewise_fma" + assert structural_equal(v1.shape, rx.ShapeExpr([m, n])) + assert structural_equal(s1.shape, rx.ShapeExpr([m, n])) + + # The var binded to the fma call is reused because the shape + # and type of var are unchanged after rewriting + assert lv1 == v0 assert type(func.body.blocks[0].bindings[2].var) == rx.Var assert type(func.body.blocks[0].bindings[2].value) == rx.DataflowVar @@ -101,6 +111,8 @@ def test_lazy_irbuilder(): x = rx.Var("x", [m, n], dtype0) y = rx.Var("y", [m, n], dtype1) ib = rx.IRBuilder() + + # This program should not be rewritten by the fma_rewriter with ib.function([x, y]): with ib.dataflow() as 
df: lv0 = ib.emit(rx.op.multiply(x, y)) @@ -123,18 +135,21 @@ def test_lazy_irbuilder(): v1 = func.body.blocks[0].bindings[1].var s1 = func.body.blocks[0].bindings[1].value + # the dataflow block and vars are reused assert block0 == block1 assert v1 == v0 assert s1 == s0 def test_explicit_memory_rewrite(): - shape_anno = [54, 96] + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + shape_anno = [m, n] type_anno = rx.DynTensorType(2, "float32") x = rx.Var("x", shape_anno, type_anno) ib = rx.IRBuilder() with ib.function(x): with ib.dataflow() as df: - lv0 = rx.call_dps([54, 96], rx.extern("test.op.identity"), [x]) + lv0 = rx.call_dps([m, n], rx.extern("test.op.identity"), [x]) gv0 = ib.emit_output(lv0) ib.emit_output(gv0) expr = ib.get() @@ -148,16 +163,17 @@ def test_explicit_memory_rewrite(): # after rewrite func = rx.analysis.explicit_memory_rewrite(expr) + # the dataflow block has changed to binding block due to the rewriting block = func.body.blocks[0] assert isinstance(block, rx.BindingBlock) s1 = block.bindings[0].value assert isinstance(s1, tvm.relay.Call) - assert s1.op.name == "relax.alloc_storage" + assert s1.op.name == "relax.builtin.alloc_tensor" + assert isinstance(s1.args[0], rx.ShapeExpr) + assert structural_equal(s1.args[0], rx.ShapeExpr(shape_anno)) s2 = block.bindings[1].value - assert s2.op.name == "relax.alloc_tensor" - s3 = block.bindings[2].value - assert s3.op.global_symbol == "test.op.identity" + assert s2.op.global_symbol == "test.op.identity" if __name__ == "__main__": diff --git a/tests/python/relax/test_irbuilder.py b/tests/python/relax/test_irbuilder.py index b2bc66bfe2..cd8524948c 100644 --- a/tests/python/relax/test_irbuilder.py +++ b/tests/python/relax/test_irbuilder.py @@ -17,6 +17,7 @@ import tvm from tvm import tir +from tvm import relay from tvm import relax as rx @@ -50,7 +51,7 @@ def test_dataflow_block(): assert gv0.shape[1] == n assert gv0.checked_type.rank == 2 assert gv0.checked_type.dtype == "float16" - isinstance(gv0, rx.Var) + assert isinstance(gv0, rx.Var) blocks = ib.get_blocks() assert len(blocks) == 1 @@ -129,7 +130,7 @@ def test_function_multi_blocks(): assert len(func.body.blocks[2].bindings) == 2 -def test_binary_shape_deduction(): +def test_binary_shape_type_deduction(): m = tir.Var("m", "int32") n = tir.Var("n", "int32") k = tir.Var("k", "int32") @@ -137,7 +138,7 @@ def test_binary_shape_deduction(): dtype1 = rx.DynTensorType(rank=1, dtype="float16") x = rx.Var("x", [m, 1], dtype0) y = rx.Var("y", [n], dtype1) - z = rx.Var("z", [5], dtype0) + z = rx.Var("z", [5], dtype1) w = rx.Var("w", [k], dtype1) ib = rx.IRBuilder() @@ -146,23 +147,57 @@ def test_binary_shape_deduction(): lv0 = ib.emit(rx.op.add(x, y)) assert lv0.shape[0] == m assert lv0.shape[1] == n + assert isinstance(lv0.checked_type, rx.DynTensorType) + assert lv0.checked_type.rank == 2 + assert lv0.checked_type.dtype == "float16" lv1 = ib.emit(rx.op.multiply(x, z)) assert lv1.shape[0] == m assert lv1.shape[1] == 5 + assert isinstance(lv1.checked_type, rx.DynTensorType) + assert lv1.checked_type.rank == 2 + assert lv1.checked_type.dtype == "float16" lv2 = ib.emit(rx.op.multiply(z, w)) assert isinstance(lv2.shape, tvm.relay.Call) + assert isinstance(lv2.checked_type, rx.DynTensorType) + assert lv2.checked_type.rank == 1 + assert lv2.checked_type.dtype == "float16" lv3 = ib.emit(rx.op.multiply(y, w)) assert isinstance(lv3.shape, tvm.relay.Call) - gv0 = ib.emit_output(lv3) + assert isinstance(lv3.checked_type, rx.DynTensorType) + assert lv3.checked_type.rank == 1 + assert 
lv3.checked_type.dtype == "float16" + + gv0 = ib.emit_output(lv3) + ib.emit_output(gv0) assert isinstance(gv0.shape, tvm.relay.Call) + assert isinstance(gv0.checked_type, rx.DynTensorType) + assert gv0.checked_type.rank == 1 + assert gv0.checked_type.dtype == "float16" + +def test_normalize(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=1, dtype="float16") + x = rx.Var("x", [m, n], dtype0) + y = rx.Var("y", [n], dtype1) + ib = rx.IRBuilder() + + add_call = rx.op.multiply(x, y) + assert isinstance(add_call.shape, relay.Call) + ib.normalize(add_call) + assert isinstance(add_call.shape, rx.ShapeExpr) + assert add_call.shape[0] == m + assert add_call.shape[1] == n if __name__ == "__main__": test_dataflow_block() test_function_single_block() test_function_multi_blocks() - test_binary_shape_deduction() + test_binary_shape_type_deduction() + test_normalize()
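Usage sketch (not part of the patch): how the updated `explicit_memory_rewrite` is expected to behave, modeled on `test_explicit_memory_rewrite` above. The dataflow block is lowered to a plain `BindingBlock` that allocates the destination with `relax.builtin.alloc_tensor` before the packed call.

```python
import tvm
from tvm import tir
from tvm import relax as rx

# A call_dps-based identity kernel over a dynamically shaped tensor.
m = tir.Var("m", "int32")
n = tir.Var("n", "int32")
x = rx.Var("x", [m, n], rx.DynTensorType(2, "float32"))

ib = rx.IRBuilder()
with ib.function(x):
    with ib.dataflow() as df:
        lv0 = rx.call_dps([m, n], rx.extern("test.op.identity"), [x])
        gv0 = ib.emit_output(lv0)
    ib.emit_output(gv0)
expr = ib.get()

func = rx.analysis.explicit_memory_rewrite(expr)
block = func.body.blocks[0]
assert isinstance(block, rx.BindingBlock)  # no longer a DataflowBlock
assert block.bindings[0].value.op.name == "relax.builtin.alloc_tensor"
assert block.bindings[1].value.op.global_symbol == "test.op.identity"
```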
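Usage sketch (not part of the patch): the copy-on-write behavior of `LazyIRBuilder`, modeled on `test_lazy_irbuilder` above. When a pass rewrites nothing, the original block and vars should be returned unchanged; the exact equality checks below are an assumption based on that test.

```python
import tvm
from tvm import tir
from tvm import relax as rx

m = tir.Var("m", "int32")
n = tir.Var("n", "int32")
dtype = rx.DynTensorType(rank=2, dtype="float16")
x = rx.Var("x", [m, n], dtype)
y = rx.Var("y", [m, n], dtype)

ib = rx.IRBuilder()
with ib.function([x, y]):
    with ib.dataflow() as df:
        lv0 = ib.emit(rx.op.multiply(x, y))
        lv1 = ib.emit(rx.op.multiply(lv0, y))  # no add-of-multiply pattern to fuse
        gv0 = ib.emit_output(lv1)
    ib.emit_output(gv0)
expr = ib.get()

func = rx.analysis.fma_rewrite(expr)
# Copy-on-write: nothing was rewritten, so the original block and its vars are reused.
assert func.body.blocks[0] == expr.body.blocks[0]
assert func.body.blocks[0].bindings[0].var == expr.body.blocks[0].bindings[0].var
```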