Skip to content

Commit

Permalink
[PASS] Shape lowering (#16)
Browse files Browse the repository at this point in the history
* [PASS] Shape lowering.

* Update to IRModule based.

* TIR function generation.

* Improve.

* Improve.

* Improve test.

* Improve.

* Address comment.
  • Loading branch information
ZihengJiang authored and junrushao committed Feb 5, 2023
1 parent 30b5482 commit ff61039
Show file tree
Hide file tree
Showing 9 changed files with 304 additions and 34 deletions.
10 changes: 5 additions & 5 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,22 +208,22 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
virtual void VisitBinding(const Binding& binding);
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder);
virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder);
virtual void VisitBinding(const Binding& binding, IRBuilder& builder);
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder);
virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& builder);
virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);

protected:
LazyIRBuilder irbuilder_;
IRBuilder builder_;
};

/*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes
*/
class DataflowMutator : public ExprMutator {
public:
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder);
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder);

protected:
/*! \brief Look up the value binded to a var. */
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
from tvm import IRModule
from . import _ffi_api

def fma_rewrite(expr):
Expand All @@ -37,3 +38,13 @@ def explicit_memory_rewrite(expr):
The input expression.
"""
return _ffi_api.explicit_memory_rewrite(expr)

def shape_lower(mod: IRModule) -> IRModule:
"""Lower the shape expression in relax to shape heap and TIR functions.
Parameters
----------
expr : tvm.IRModule
The input module.
"""
return _ffi_api.shape_lower(mod)
54 changes: 30 additions & 24 deletions src/relax/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ void ExprVisitor::VisitVarBinding(const VarBinding& binding) {

void ExprVisitor::VisitMatchShape(const MatchShape& binding) {
this->VisitExpr(binding->value);
// TODO(ziheng): should we change pattern from
// Array<PrimExpr> to ShapeExpr?
this->VisitExpr(ShapeExpr(binding->pattern));
}

void ExprVisitor::VisitBindingBlock(const BindingBlock& block) {
Expand Down Expand Up @@ -321,50 +324,53 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) {

Type ExprMutator::VisitType(const Type& t) { return t; }

void ExprMutator::VisitBinding(const Binding& binding) {
void ExprMutator::VisitBinding(const Binding& binding, IRBuilder& builder) {
Binding new_binding;
if (binding.as<VarBindingNode>()) {
this->VisitVarBinding(Downcast<VarBinding>(binding), this->irbuilder_);
this->VisitVarBinding(Downcast<VarBinding>(binding), builder);
} else if (binding.as<MatchShapeNode>()) {
this->VisitMatchShape(Downcast<MatchShape>(binding), this->irbuilder_);
this->VisitMatchShape(Downcast<MatchShape>(binding), builder);
} else {
LOG(FATAL) << "Wrong type.";
}
}

Var ExprMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) {
Var ExprMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& builder) {
Expr new_value = this->Mutate(binding->value);
if (!binding->var.as<DataflowVarNode>()) {
return ir_builder->EmitOutput(new_value);
return builder->EmitOutput(new_value);
} else {
return ir_builder->Emit(Downcast<Call>(new_value));
return builder->Emit(Downcast<Call>(new_value));
}
}

void ExprMutator::VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder) {
void ExprMutator::VisitMatchShape(const MatchShape& binding, IRBuilder& builder) {
this->Mutate(binding->value);
this->Mutate(ShapeExpr(binding->pattern));
}

BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) {
if (block.as<DataflowBlockNode>()) {
return this->VisitDataflowBlock(Downcast<DataflowBlock>(block));
} else{
// TODO
return block;
this->builder_ = IRBuilderNode::Create();
for (auto binding : block->bindings) {
this->VisitBinding(binding, this->builder_);
}
auto blocks = this->builder_->GetBlocks();
return blocks.back();
}
}

BindingBlock ExprMutator::VisitDataflowBlock(const DataflowBlock& block) {
this->irbuilder_ = LazyIRBuilderNode::Create(block);
this->builder_ = LazyIRBuilderNode::Create(block);
{
With<DataflowScope> scope(this->irbuilder_);
With<DataflowScope> scope(this->builder_);
for (auto binding : block->bindings) {
if (binding.as<VarBindingNode>()) {
this->VisitVarBinding(Downcast<VarBinding>(binding), this->irbuilder_);
}
this->VisitBinding(binding, this->builder_);
}
}
return this->irbuilder_->GetBlocks().back();
return this->builder_->GetBlocks().back();
}

Expr ExprMutator::VisitExpr(const Expr& expr) {
Expand All @@ -377,27 +383,27 @@ Expr ExprMutator::VisitExpr(const Expr& expr) {
// DataflowMutator

BindingBlock DataflowMutator::VisitDataflowBlock(const DataflowBlock& block) {
this->irbuilder_ = LazyIRBuilderNode::Create(block);
this->builder_ = LazyIRBuilderNode::Create(block);
{
With<DataflowScope> scope(this->irbuilder_);
With<DataflowScope> scope(this->builder_);
for (auto binding : block->bindings) {
if (auto* var_binding = binding.as<VarBindingNode>()) {
Var var = this->VisitVarBinding(Downcast<VarBinding>(binding), this->irbuilder_);
Var var = this->VisitVarBinding(Downcast<VarBinding>(binding), this->builder_);
this->pre_post_var_map_[var_binding->var] = var;
}
}
}
return this->irbuilder_->GetBlocks().back();
return this->builder_->GetBlocks().back();
}

Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) {
Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& builder) {
Expr new_value = this->Mutate(binding->value);
Var new_var;
if (new_value.as<CallNode>()) {
new_var = ir_builder->Emit(Downcast<Call>(new_value));
new_var = builder->Emit(Downcast<Call>(new_value));
}
if (!binding->var.as<DataflowVarNode>()) {
new_var = ir_builder->EmitOutput(new_value);
new_var = builder->EmitOutput(new_value);
}
pre_post_var_map_[binding->var] = new_var;
return new_var;
Expand All @@ -406,9 +412,9 @@ Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_bu
Expr DataflowMutator::LookupVar(Var var) {
auto it = pre_post_var_map_.find(var);
if (it != pre_post_var_map_.end()) {
return irbuilder_->LookupVar(it->first);
return builder_->LookupVar(it->first);
} else {
return irbuilder_->LookupVar(var);
return builder_->LookupVar(var);
}
}
} // namespace relax
Expand Down
10 changes: 8 additions & 2 deletions src/relax/ir/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ Var IRBuilderNode::EmitMatchShape(const Expr& value, const Array<PrimExpr>& patt
}

Var IRBuilderNode::Emit(const VarBinding& binding) {
// FIXME(yuchen or ziheng): consider binding in normal block)
if (!binding->var.as<DataflowVarNode>()) {
return EmitOutput(binding->value);
} else {
Expand Down Expand Up @@ -192,9 +193,14 @@ Expr IRBuilderNode::LookupVar(const Var& var) {
return it->second;
}

Function IRBuilderNode::Get() { return this->func_.func; }
Function IRBuilderNode::Get() {
return this->func_.func;
}

std::vector<BindingBlock> IRBuilderNode::GetBlocks() { return this->func_.binding_blocks; }
std::vector<BindingBlock> IRBuilderNode::GetBlocks() {
this->BuildBlock();
return this->func_.binding_blocks;
}

bool IRBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) {
if (lhs == rhs) {
Expand Down
3 changes: 2 additions & 1 deletion src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ Expr MakeCallDPS(Expr shape, Expr func, Tuple args) {
return Call(op, {shape, func, args}, {}, {});
}

TVM_REGISTER_GLOBAL("relax.op.call_dps").set_body_typed(MakeCallDPS);
TVM_REGISTER_GLOBAL("relax.op.call_dps")
.set_body_typed(MakeCallDPS);

// shape_of

Expand Down
Loading

0 comments on commit ff61039

Please sign in to comment.