Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
Improve IRMutator interface; add Normalize and CanProveShapeEqual to …
Browse files Browse the repository at this point in the history
…IRBuilder
  • Loading branch information
YuchenJin committed Sep 27, 2021
1 parent c762389 commit ec1ab88
Show file tree
Hide file tree
Showing 11 changed files with 288 additions and 259 deletions.
19 changes: 5 additions & 14 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
*/
virtual Type VisitType(const Type& t);
virtual void VisitBinding(const Binding& binding);
virtual void VisitVarBinding(const VarBinding& binding);
virtual void VisitMatchShape(const MatchShape& binding);
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder);
virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder);
virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);

Expand All @@ -223,20 +223,11 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
class DataflowMutator : public ExprMutator {
public:
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
virtual void VisitVarBinding(const VarBinding& binding);
/*! \brief Insert a call node, a new binding will be created. */
Var Insert(Call value);
/*! \brief Emit a call node, the var could be reused depending on if the shape/type of \p value is
* changed or not. */
Var Emit(Var var, Call value);
/*! \brief Emit a binding. */
Var Emit(VarBinding binding);
/*! \brief Look up the value binded to a var. */
Expr LookupVar(Var var);
/*! \brief Switch from building a DataflowBlock to building a BindingBlock. */
void EmitBindingBlock();
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder);

protected:
/*! \brief Look up the value binded to a var. */
Expr LookupVar(Var var);
// A remapping table: pre var -> post var
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> pre_post_var_map_;
};
Expand Down
56 changes: 38 additions & 18 deletions include/tvm/relax/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,29 +75,30 @@ class IRBuilderNode : public Object {
*/
virtual void BuildBlock();
/*!
* \brief Emit a call node.
* \param call The CallNode to be emitted.
* \brief Emit a Call, and return a newly created Var binded to the Call.
* \param call The Call to be emitted.
* \return The variable being created and binded to \p call.
*/
virtual Var Emit(const Call& call);
/*!
* \brief Emit a MatchShape.
* \param match_shape The MatchShape to be emitted.
*/
void Emit(const MatchShape& match_shape);
/*!
* \brief Emit a var binding.
* \param binding The VarBinding to be emitted.
* \return The VarNode of the VarBinding \p binding.
*/
virtual Var Emit(const VarBinding& binding);
/*!
* \brief Emit a call node, and bind it to a Var.
* \param var The VarNode to be binded with. \p var is reused implicitly.
* \param call The CallNode to be emitted.
* \return The VarNode to be binded with \p var.
* \brief Emit a Call, and bind it to a Var.
* \param var The Var to be binded with. \p var is reused implicitly if the shape
* and type of \p call matches \p var. Otherwise a new Var is created.
* \param call The Call to be emitted.
* \return The Var to be binded with \p var.
*/
virtual Var Emit(const Var& var, const Call& call);
/*!
* \brief Emit a MatchShape.
* \param match_shape The MatchShape to be emitted.
*/
void Emit(const MatchShape& match_shape);
/*!
* \brief Generate an output for the current dataflow block or function.
* \param output The output variable of the block/function.
Expand All @@ -116,6 +117,19 @@ class IRBuilderNode : public Object {
* \brief Get binding blocks being built.
*/
std::vector<BindingBlock> GetBlocks();
/*!
* \brief Check if two shape expressions can be proven equal at compile time.
* \param lhs The input lhs shape.
* \param rhs The input rhs shape.
* \return Whether we can prove lhs shape == rhs shape.
*/
bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs);
/*!
* \brief Normalize an Expr to complete its shape and type.
* \param expr The input expr.
* \return The expr with normalized shape and type.
*/
Expr Normalize(const Expr& expr);
/*!
* \brief Create a IRBuilder.
* \return The created IRBuilder.
Expand Down Expand Up @@ -225,27 +239,33 @@ class DataflowScope : public ObjectRef {
class LazyIRBuilderNode : public IRBuilderNode {
public:
/*!
* \brief Emit a call node in a copy-on-write way.
* \brief Emit a Call in a copy-on-write way.
* If no bindings in a dataflow block need to be rewritten, reuse the original variable instead of
* emiting one. If any binding in the block needs to be rewritten, reconstruct the whole block
* from scratch by emiting all previous bindings.
* \param call The CallNode to be emitted.
* \param call The Call to be emitted.
* \return The variable being created and binded to \p call.
*/
virtual Var Emit(const Call& call);
/*!
* \brief Emit a var binding in a copy-on-write way.
* \param binding The VarBinding to be emitted.
* \return The VarNode of the VarBinding \p binding.
* \return The Var of the \p binding.
*/
virtual Var Emit(const VarBinding& binding);
/*!
* \brief Emit a call node, and bind it to a Var in a copy-on-write way.
* \param var The VarNode to be binded with.
* \param call The CallNode to be emitted.
* \return The VarNode to be binded with \p var.
* \brief Emit a Call, and bind it to a Var in a copy-on-write way.
* \param var The Var to be binded with.
* \param call The Call to be emitted.
* \return The Var to be binded with \p var.
*/
virtual Var Emit(const Var& var, const Call& call);
/*!
* \brief Emit an output for the current dataflow block or function in a copy-on-write way.
* \param binding The VarBinding to be emitted.
* \return The variable being binded to \p output.
*/
virtual Var EmitOutput(const VarBinding& binding);
/*!
* \brief Build a binding block.
*/
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def checked_type(self):
"""
ret = self.checked_type_
if ret is None:
raise ValueError("The type checker has not populated" " the checked_type for this node")
raise ValueError("The type checker has not populated the checked_type for this node")
return ret

@property
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relax/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,22 @@ def emit_output(self,
output = Tuple(output)
return _ffi_api.IRBuilderEmitOutput(self, output)

def normalize(self,
expr: Expr) -> Expr:
"""Normalize an Expr to complete its shape and type.
Parameters
----------
expr : Expr
The input expr.
Returns
-------
ret : Expr
The expr with normalized shape and type.
"""
return _ffi_api.IRBuilderNormalize(self, expr)

def get(self) -> Function:
"""Return the function being built.
Expand Down
98 changes: 34 additions & 64 deletions src/relax/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,26 +328,24 @@ Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprMutator::VisitBinding(const Binding& binding) {
Binding new_binding;
if (binding.as<VarBindingNode>()) {
this->VisitVarBinding(Downcast<VarBinding>(binding));
this->VisitVarBinding(Downcast<VarBinding>(binding), this->irbuilder_);
} else if (binding.as<MatchShapeNode>()) {
this->VisitMatchShape(Downcast<MatchShape>(binding));
this->VisitMatchShape(Downcast<MatchShape>(binding), this->irbuilder_);
} else {
LOG(FATAL) << "Wrong type.";
}
}

void ExprMutator::VisitVarBinding(const VarBinding& binding) {
Var ExprMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) {
Expr new_value = this->Mutate(binding->value);

if (new_value.as<CallNode>()) {
new_value = this->irbuilder_->Emit(Downcast<Call>(new_value));
}
if (!binding->var.as<DataflowVarNode>()) {
this->irbuilder_->EmitOutput(new_value);
return ir_builder->EmitOutput(new_value);
} else {
return ir_builder->Emit(Downcast<Call>(new_value));
}
}

void ExprMutator::VisitMatchShape(const MatchShape& binding) {
void ExprMutator::VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder) {
this->Mutate(binding->value);
}

Expand All @@ -366,7 +364,7 @@ BindingBlock ExprMutator::VisitDataflowBlock(const DataflowBlock& block) {
With<DataflowScope> scope(this->irbuilder_);
for (auto binding : block->bindings) {
if (binding.as<VarBindingNode>()) {
this->VisitVarBinding(Downcast<VarBinding>(binding));
this->VisitVarBinding(Downcast<VarBinding>(binding), this->irbuilder_);
}
}
}
Expand All @@ -383,46 +381,27 @@ Expr ExprMutator::VisitExpr(const Expr& expr) {
// DataflowMutator

BindingBlock DataflowMutator::VisitDataflowBlock(const DataflowBlock& block) {
return ExprMutator::VisitDataflowBlock(block);
}

void DataflowMutator::VisitVarBinding(const VarBinding& binding) {
Expr new_value = this->Mutate(binding->value);
Var new_var;
if (new_value.as<CallNode>()) {
new_var = this->irbuilder_->Emit(Downcast<Call>(new_value));
}
if (!binding->var.as<DataflowVarNode>()) {
new_var = this->irbuilder_->EmitOutput(new_value);
}
pre_post_var_map_[binding->var] = new_var;
}

Var DataflowMutator::Insert(Call value) {
Var var = this->irbuilder_->Emit(value);
return var;
}

Var DataflowMutator::Emit(Var var, Call value) {
Var new_var;
// TODO: make shape and type check right
if (var->shape() == value->shape() && var->checked_type() == value->checked_type()) {
new_var = this->irbuilder_->Emit(var, value);
} else {
new_var = this->irbuilder_->Emit(value);
this->irbuilder_ = LazyIRBuilderNode::Create(block);
{
With<DataflowScope> scope(this->irbuilder_);
for (auto binding : block->bindings) {
if (auto* var_binding = binding.as<VarBindingNode>()) {
Var var = this->VisitVarBinding(Downcast<VarBinding>(binding), this->irbuilder_);
this->pre_post_var_map_[var_binding->var] = var;
}
}
}
pre_post_var_map_[var] = new_var;
return new_var;
return this->irbuilder_->GetBlocks().back();
}

Var DataflowMutator::Emit(VarBinding binding) {
Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) {
Expr new_value = this->Mutate(binding->value);
Var new_var;
if (new_value.as<CallNode>()) {
new_var = this->irbuilder_->Emit(binding);
new_var = ir_builder->Emit(Downcast<Call>(new_value));
}
if (!binding->var.as<DataflowVarNode>()) {
new_var = this->irbuilder_->EmitOutput(new_value);
new_var = ir_builder->EmitOutput(new_value);
}
pre_post_var_map_[binding->var] = new_var;
return new_var;
Expand All @@ -437,9 +416,6 @@ Expr DataflowMutator::LookupVar(Var var) {
}
}

void DataflowMutator::EmitBindingBlock() {
this->irbuilder_->is_dataflow_ = false;
}

// ==================
// EwiseFMARewriter
Expand All @@ -458,7 +434,7 @@ void DataflowMutator::EmitBindingBlock() {
// z0 = ewise_fma(a, lv0, c)

class EwiseFMARewriter : public DataflowMutator {
void VisitVarBinding(const VarBinding& binding) override {
Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override {
static const Op& add_op = Op::Get("relax.add");
static const Op& multiply_op = Op::Get("relax.multiply");
static const Op& ewise_fma_op = Op::Get("relax.ewise_fma");
Expand All @@ -470,12 +446,10 @@ class EwiseFMARewriter : public DataflowMutator {
const CallNode* op2 = value.as<CallNode>();
if (op2 && op2->op == multiply_op) {
Call fma_call = Call(ewise_fma_op, {op2->args[0], op2->args[1], op1->args[1]}, {}, {});
// Complete(fma_call);
Emit(binding->var, fma_call);
return;
return ir_builder->Emit(binding->var, fma_call);
}
}
Emit(binding);
return ir_builder->Emit(binding);
}
};

Expand All @@ -491,11 +465,10 @@ TVM_REGISTER_GLOBAL("relax.analysis.fma_rewrite")
// ==================
// ExplicitMemMutator
// Example:
// lv1: Tensor[(m*n,)] = rx.call_dps((m*n,), extern_packed("flatten"), [lv0])
// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x))
// -->
// storage0 = rx.call(extern_packed("alloc_storage"), size=[n*m], device=cpu)
// lv1 = rx.call(extern_packed("alloc_tensor"), storage0, 0, [m*n,], f32)
// rx.call(extern_packed("flatten"), lv0, lv1)
// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m])
// rx.call_packed(op.identity, x, lv0)

class ExplicitMemMutator : public DataflowMutator {
Expr ComputeStorageSize(const Expr& shape, const Type& type) const {
Expand Down Expand Up @@ -523,24 +496,21 @@ class ExplicitMemMutator : public DataflowMutator {
return ret;
}

void VisitVarBinding(const VarBinding& binding) override {
Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override {
static const Op& call_dps_op = Op::Get("relax.call_dps");
static const Op& alloc_storage_op = Op::Get("relax.alloc_storage");
static const Op& alloc_tensor_op = Op::Get("relax.alloc_tensor");
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");

const CallNode* op = binding->value.as<CallNode>();
if(op && op->op == call_dps_op) {
// convert current DataflowBlock into an impure BindingBlock
this->EmitBindingBlock();

// switch current DataflowBlock to an impure BindingBlock
ir_builder->is_dataflow_ = false;
ShapeExpr output_shape = Downcast<ShapeExpr>(op->args[0]);
Type arg_type = Downcast<Tuple>(op->args[2])->fields[0]->checked_type();
Expr output_size = ComputeStorageSize(output_shape, arg_type);
Var storage = Insert(Call(alloc_storage_op, {output_size}));
Var tensor = Insert(Call(alloc_tensor_op, {storage, relay::Constant(0), op->args[0]}));
Emit(binding->var, Call(op->args[1], {tensor, op->args[2]}));
Var tensor = ir_builder->Emit(Call(alloc_tensor_op, {op->args[0]}));
return ir_builder->Emit(binding->var, Call(op->args[1], {op->args[2], tensor}));
}
Emit(binding);
return ir_builder->Emit(binding);
}
};

Expand Down
Loading

0 comments on commit ec1ab88

Please sign in to comment.