diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index b4a143727a9f..fc7d6f0a5229 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -208,14 +208,14 @@ class ExprMutator : public ExprFunctor { * visitor for types which transform them appropriately. */ virtual Type VisitType(const Type& t); - virtual void VisitBinding(const Binding& binding); - virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder); - virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder); + virtual void VisitBinding(const Binding& binding, IRBuilder& builder); + virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder); + virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& builder); virtual BindingBlock VisitBindingBlock(const BindingBlock& block); virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); protected: - LazyIRBuilder irbuilder_; + IRBuilder builder_; }; /*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes @@ -223,7 +223,7 @@ class ExprMutator : public ExprFunctor { class DataflowMutator : public ExprMutator { public: virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); - virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder); + virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder); protected: /*! \brief Look up the value binded to a var. */ diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 2a581865b3db..c205a5f41214 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=no-else-return # pylint: disable=unidiomatic-typecheck +from tvm import IRModule from . import _ffi_api def fma_rewrite(expr): @@ -37,3 +38,13 @@ def explicit_memory_rewrite(expr): The input expression. """ return _ffi_api.explicit_memory_rewrite(expr) + +def shape_lower(mod: IRModule) -> IRModule: + """Lower the shape expression in relax to shape heap and TIR functions. + + Parameters + ---------- + expr : tvm.IRModule + The input module. + """ + return _ffi_api.shape_lower(mod) diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index c48c85e0edcf..a7359613c28e 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -550,4 +550,4 @@ String AsRelaxScript(const ObjectRef& mod) { TVM_REGISTER_GLOBAL("script.AsRelaxScript").set_body_typed(AsRelaxScript); } // namespace relax -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index c277a991fd37..a630f7945c32 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -140,6 +140,9 @@ void ExprVisitor::VisitVarBinding(const VarBinding& binding) { void ExprVisitor::VisitMatchShape(const MatchShape& binding) { this->VisitExpr(binding->value); + // TODO(ziheng): should we change pattern from + // Array to ShapeExpr? + this->VisitExpr(ShapeExpr(binding->pattern)); } void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { @@ -321,50 +324,53 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { Type ExprMutator::VisitType(const Type& t) { return t; } -void ExprMutator::VisitBinding(const Binding& binding) { +void ExprMutator::VisitBinding(const Binding& binding, IRBuilder& builder) { Binding new_binding; if (binding.as()) { - this->VisitVarBinding(Downcast(binding), this->irbuilder_); + this->VisitVarBinding(Downcast(binding), builder); } else if (binding.as()) { - this->VisitMatchShape(Downcast(binding), this->irbuilder_); + this->VisitMatchShape(Downcast(binding), builder); } else { LOG(FATAL) << "Wrong type."; } } -Var ExprMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) { +Var ExprMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& builder) { Expr new_value = this->Mutate(binding->value); if (!binding->var.as()) { - return ir_builder->EmitOutput(new_value); + return builder->EmitOutput(new_value); } else { - return ir_builder->Emit(Downcast(new_value)); + return builder->Emit(Downcast(new_value)); } } -void ExprMutator::VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder) { +void ExprMutator::VisitMatchShape(const MatchShape& binding, IRBuilder& builder) { this->Mutate(binding->value); + this->Mutate(ShapeExpr(binding->pattern)); } BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) { if (block.as()) { return this->VisitDataflowBlock(Downcast(block)); } else{ - // TODO - return block; + this->builder_ = IRBuilderNode::Create(); + for (auto binding : block->bindings) { + this->VisitBinding(binding, this->builder_); + } + auto blocks = this->builder_->GetBlocks(); + return blocks.back(); } } BindingBlock ExprMutator::VisitDataflowBlock(const DataflowBlock& block) { - this->irbuilder_ = LazyIRBuilderNode::Create(block); + this->builder_ = LazyIRBuilderNode::Create(block); { - With scope(this->irbuilder_); + With scope(this->builder_); for (auto binding : block->bindings) { - if (binding.as()) { - this->VisitVarBinding(Downcast(binding), this->irbuilder_); - } + this->VisitBinding(binding, this->builder_); } } - return this->irbuilder_->GetBlocks().back(); + return this->builder_->GetBlocks().back(); } Expr ExprMutator::VisitExpr(const Expr& expr) { @@ -377,27 +383,27 @@ Expr ExprMutator::VisitExpr(const Expr& expr) { // DataflowMutator BindingBlock DataflowMutator::VisitDataflowBlock(const DataflowBlock& block) { - this->irbuilder_ = LazyIRBuilderNode::Create(block); + this->builder_ = LazyIRBuilderNode::Create(block); { - With scope(this->irbuilder_); + With scope(this->builder_); for (auto binding : block->bindings) { if (auto* var_binding = binding.as()) { - Var var = this->VisitVarBinding(Downcast(binding), this->irbuilder_); + Var var = this->VisitVarBinding(Downcast(binding), this->builder_); this->pre_post_var_map_[var_binding->var] = var; } } } - return this->irbuilder_->GetBlocks().back(); + return this->builder_->GetBlocks().back(); } -Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) { +Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& builder) { Expr new_value = this->Mutate(binding->value); Var new_var; if (new_value.as()) { - new_var = ir_builder->Emit(Downcast(new_value)); + new_var = builder->Emit(Downcast(new_value)); } if (!binding->var.as()) { - new_var = ir_builder->EmitOutput(new_value); + new_var = builder->EmitOutput(new_value); } pre_post_var_map_[binding->var] = new_var; return new_var; @@ -406,9 +412,9 @@ Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_bu Expr DataflowMutator::LookupVar(Var var) { auto it = pre_post_var_map_.find(var); if (it != pre_post_var_map_.end()) { - return irbuilder_->LookupVar(it->first); + return builder_->LookupVar(it->first); } else { - return irbuilder_->LookupVar(var); + return builder_->LookupVar(var); } } } // namespace relax diff --git a/src/relax/ir/ir_builder.cc b/src/relax/ir/ir_builder.cc index 46e7d590a60c..864afdf5421f 100644 --- a/src/relax/ir/ir_builder.cc +++ b/src/relax/ir/ir_builder.cc @@ -142,6 +142,7 @@ Var IRBuilderNode::EmitMatchShape(const Expr& value, const Array& patt } Var IRBuilderNode::Emit(const VarBinding& binding) { + // FIXME(yuchen or ziheng): consider binding in normal block) if (!binding->var.as()) { return EmitOutput(binding->value); } else { @@ -192,9 +193,14 @@ Expr IRBuilderNode::LookupVar(const Var& var) { return it->second; } -Function IRBuilderNode::Get() { return this->func_.func; } +Function IRBuilderNode::Get() { + return this->func_.func; +} -std::vector IRBuilderNode::GetBlocks() { return this->func_.binding_blocks; } +std::vector IRBuilderNode::GetBlocks() { + this->BuildBlock(); + return this->func_.binding_blocks; +} bool IRBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) { if (lhs == rhs) { diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 27c8201752da..b40d41b4c75f 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -57,7 +57,8 @@ Expr MakeCallDPS(Expr shape, Expr func, Tuple args) { return Call(op, {shape, func, args}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.call_dps").set_body_typed(MakeCallDPS); +TVM_REGISTER_GLOBAL("relax.op.call_dps") +.set_body_typed(MakeCallDPS); // shape_of diff --git a/src/relax/transform/shape_lower.cc b/src/relax/transform/shape_lower.cc new file mode 100644 index 000000000000..f60d9bdb4f8a --- /dev/null +++ b/src/relax/transform/shape_lower.cc @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/shape_lower.cc + * \brief + */ +#include +#include +#include +#include +#include +#include "../../printer/text_printer.h" + +namespace tvm { +namespace relax { + +// Replace ShapeExpr with corresponding Var +class ShapeReplacer : public ExprMutator { + public: + explicit ShapeReplacer(Map mapping) { + mapping_ = mapping; + } + Expr VisitExpr_(const ShapeExprNode* op) override { + return mapping_.at(GetRef(op)); + } + + private: + Map mapping_; +}; + + +class ShapeLowerMutator : public ExprMutator { + public: + static DataType ShapeDType() { + return DataType::Int(32); + }; + + explicit ShapeLowerMutator(IRModule mod) { + mod_ = mod; + } + + IRModule Lower() { + ret_mod_ = IRModule(); + for (auto& p : mod_->functions) { + if (!p.second->IsInstance()) { + continue; + } + // prepare mapping and heap var + expr2slot_ = PrepareExpr2Slot(Downcast(p.second)); + // LOG(INFO) << "mapping: " << expr2slot_; + heap_size_ = IntImm(ShapeDType(), expr2slot_.size()); + DynTensorType heap_type(1, ShapeDType()); + shape_heap_ = Var("shape_heap", ShapeExpr({heap_size_}), heap_type); + + // mutate + Expr new_func = this->Mutate(p.second); + ret_mod_->Add(p.first, Downcast(new_func)); + } + return ret_mod_; + } + + void VisitMatchShape(const MatchShape& binding, + IRBuilder& builder) override { + Expr value = binding->value; + Array pattern = binding->pattern; + Array indexes; + for (size_t i = 0; i < pattern.size(); ++i) { + IntImm idx = expr2slot_.at(pattern[i]); + indexes.push_back(idx); + } + ShapeExpr indexes_(indexes); + Call call(ExternFunc("decode_shape"), {value, shape_heap_, indexes_}); + builder->Emit(call); + } + + Expr VisitExpr_(const FunctionNode* node) override { + Expr visited_func = ExprMutator::VisitExpr_(node); + const auto* visited = visited_func.as(); + ICHECK(visited); + const auto* seq = visited->body.as(); + ICHECK(seq); + + // prologue block: allocate shape heap + ShapeExpr heap_size({heap_size_}); + Call alloc_heap_call(ExternFunc("relax.alloc_shape_heap"), {heap_size}); + VarBinding binding(shape_heap_, alloc_heap_call); + BindingBlock prologue({binding}); + + // process body + IRBuilder ib = IRBuilderNode::Create(); + Array shapes = CollectShapeExpr(seq->body); + Map mapping; + for (ShapeExpr shape : shapes) { + // generate tir shape function + tir::PrimFunc func = CalculateShape(shape); + GlobalVar shape_func_var("shape_func" + std::to_string(shape_func_counter_++)); + ib->Emit(Call(shape_func_var, {shape_heap_})); + ret_mod_->Add(shape_func_var, func); + + // construct shape + Array indexes; + for (PrimExpr e : shape->values) { + indexes.push_back(expr2slot_.at(e)); + } + ShapeExpr indexes_(indexes); + Call call(ExternFunc("construct_shape"), {shape_heap_, indexes_}); + Var shape_var = ib->Emit(call); + mapping.Set(shape, shape_var); + } + Expr new_body = ShapeReplacer(mapping).Mutate(seq->body); + + // epilogue block: kill the shape heap + Call free_heap_call(ExternFunc("relax.free_shape_heap"), {shape_heap_}); + ib->Emit(free_heap_call); + + // process blocks + Array blocks; + blocks.push_back(prologue); + blocks.insert(blocks.end(), seq->blocks.begin(), seq->blocks.end()); + blocks.push_back(ib->GetBlocks().back()); + + + SeqExpr new_seq(blocks, new_body); + return Function(visited->name, visited->params, new_seq, visited->ret_type); + } + + tir::PrimFunc CalculateShape(ShapeExpr s) { + // TODO(ziheng): avoid generating shape func for known value + tir::Var heap("heap", DataType::Handle()); + Array buffer_shape{heap_size_}; + tir::Buffer buffer = tir::decl_buffer(buffer_shape, ShapeDType(), "H"); + Map buffer_map; + buffer_map.Set(heap, buffer); + + Array seq; + for (PrimExpr e : s->values) { + Map var_mapping = BuildVarMapping(e, buffer); + PrimExpr value = tir::Substitute(e, var_mapping); + IntImm idx = expr2slot_.at(e); + seq.push_back(tir::Store(buffer->data, value, idx, tir::const_true())); + } + tir::Stmt body = tir::SeqStmt(seq); + Array params{heap}; + Type ret_type = VoidType(); + return tir::PrimFunc(params, body, ret_type, buffer_map); + } + + Map BuildVarMapping(PrimExpr expr, tir::Buffer buffer) { + Map ret; + auto func = [&](const ObjectRef& e) { + if (e->IsInstance()) { + PrimExpr prim_e = Downcast(e); + tir::Load load(ShapeDType(), buffer->data, expr2slot_.at(prim_e), tir::const_true()); + ret.Set(Downcast(e), load); + } + }; + tir::PostOrderVisit(expr, func); + return ret; + } + + Array CollectShapeExpr(Expr expr) const { + Array ret; + auto func = [&ret](const Expr& e) { + if (e->IsInstance()) { + ret.push_back(Downcast(e)); + } + }; + PostOrderVisit(expr, func); + return ret; + } + + + Map PrepareExpr2Slot(Function expr) const { + int cnt = 0; + Map ret; + auto func = [&](const Expr& e) { + if (e->IsInstance()) { + ShapeExpr shape = Downcast(e); + for (auto prim_e: shape->values) { + if (ret.count(prim_e) == 0) { + IntImm idx(ShapeDType(), cnt++); + ret.Set(prim_e, idx); + } + } + } + }; + PostOrderVisit(expr, func); + return ret; + } + + private: + IRModule mod_; + IRModule ret_mod_; + int shape_func_counter_{0}; + + // function-wise members + IntImm heap_size_; + Var shape_heap_; + Map expr2slot_; +}; + + +TVM_REGISTER_GLOBAL("relax.transform.shape_lower") +.set_body_typed([](IRModule mod) { + return ShapeLowerMutator(mod).Lower(); +}); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_op.py b/tests/python/relax/test_op.py index 9f708b89a42b..c9ac6fbe09ae 100644 --- a/tests/python/relax/test_op.py +++ b/tests/python/relax/test_op.py @@ -21,7 +21,7 @@ @tvm.register_func("test.op.identity") def identity_packed(a): - return tvm.nd.array(a.asnumpy) + return tvm.nd.array(a.asnumpy()) @tvm.script.tir def identity_tir(a: ty.handle, b: ty.handle) -> None: diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 304ced55d2de..3b3248a43ca8 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations # must import to defer parsing of annotations import tvm from tvm import tir from tvm import relax as rx @@ -99,6 +100,26 @@ def test_explicit_memory_rewrite(): s2 = block.bindings[1].value assert s2.op.global_symbol == "test.op.identity" + +@rx.script +class Mod: + def foo(x: Tensor[_, "float32"]) -> Shape: + relax.match_shape(x.shape, (n, m)) + return (n*2, m*3) + +def test_shape_lowering(): + mod = Mod() + new_mod = rx.transform.shape_lower(mod) + assert isinstance(new_mod, tvm.IRModule) + assert isinstance(new_mod["shape_func0"], tvm.tir.function.PrimFunc) + assert isinstance(new_mod["foo"], tvm.relax.expr.Function) + code = rx.parser.astext(new_mod) + assert "alloc_shape_heap" in code + assert "decode_shape" in code + assert "construct_shape" in code + + if __name__ == "__main__": test_fma_rewrite() test_explicit_memory_rewrite() + test_shape_lowering()