Skip to content

Commit

Permalink
Relax IRVisitor/IRMuator (apache#10)
Browse files Browse the repository at this point in the history
* ExprVisitor/ExprMutator for relax nodes.

* Update Visitor & Mutator.

* Update Mutator.

* DataflowMutator interface.

* EwiseFMARewriter.

* Update fma rewrite and add test.

* Update test.

* Fix dataflow block dispatching.

* Construct new dataflow block with IRBuilder.

* VisitBinding return void and mutate internal IRBuilder.

* Simplify.

* Update emit dataflow output.

* Explicit memeory allocation rewrite.

* LazyIRBuilder.

* Update ExplicitMemMutator.

* Overload IRBuilder::Emit to have 3 styles.

* Update IRBuilder/IRMutator interfaces and passes.

* Add MatchShape binding to IRBuilder.

* Improve IRMutator interface; add Normalize and CanProveShapeEqual to IRBuilder

* Update EmitMatchShape.

Co-authored-by: ZihengJiang <ziheng@apache.org>
  • Loading branch information
2 people authored and junrushao committed Oct 14, 2022
1 parent f8519c1 commit fd8f28d
Show file tree
Hide file tree
Showing 20 changed files with 1,759 additions and 52 deletions.
6 changes: 6 additions & 0 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,15 @@ namespace relax {
using Expr = RelayExpr;
using ExprNode = RelayExprNode;
using relay::Call;
using relay::CallNode;
using relay::ConstantNode;
using relay::Id;
using relay::If;
using relay::IfNode;
using relay::Tuple;
using relay::TupleGetItem;
using relay::TupleGetItemNode;
using relay::TupleNode;

/*! \brief A shape expression which allows users to construct a shape containing PrimExpr.
*/
Expand Down
237 changes: 237 additions & 0 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/expr_functor.h
* \brief A more powerful visitor which enables defining arbitrary function
* signatures with type based dispatch on first argument.
*/
#ifndef TVM_RELAX_EXPR_FUNCTOR_H_
#define TVM_RELAX_EXPR_FUNCTOR_H_

#include <tvm/ir/error.h>
#include <tvm/node/functor.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/ir_builder.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/relay/op.h>

#include <deque>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace tvm {
namespace relax {

/*!
* \brief A dynamical functor that dispatches on in the first Expr argument.
* You can use this as a more powerful Visitor, since it allows you to
* define function signatures of Visit Function.
*
* \sa tvm/ir_functor.h
*
* \tparam FType function signiture
* This type is only defined for FType with function signature R(const Expr&,
* Args...)
*/
template <typename FType>
class ExprFunctor;

// functions to be overriden.
#define EXPR_FUNCTOR_DEFAULT \
{ return VisitExprDefault_(op, std::forward<Args>(args)...); }

#define RELAX_EXPR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
});

template <typename R, typename... Args>
class ExprFunctor<R(const Expr& n, Args...)> {
private:
using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;

public:
/*! \brief the result type of this functor */
using result_type = R;
/*! \brief virtual destructor */
virtual ~ExprFunctor() {}
/*!
* \brief Same as call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); }
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
"have generated invalid data.";
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const DataflowVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ShapeExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ExternFuncNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const SeqExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw;
}

private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
// Set dispatch
RELAX_EXPR_FUNCTOR_DISPATCH(ConstantNode);
RELAX_EXPR_FUNCTOR_DISPATCH(TupleNode);
RELAX_EXPR_FUNCTOR_DISPATCH(VarNode);
RELAX_EXPR_FUNCTOR_DISPATCH(DataflowVarNode);
RELAX_EXPR_FUNCTOR_DISPATCH(ShapeExprNode);
RELAX_EXPR_FUNCTOR_DISPATCH(ExternFuncNode);
RELAX_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode);
RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);
RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode);
RELAX_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAX_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAX_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
return vtable;
}
};

/*!
* \brief A simple visitor wrapper around ExprFunctor.
* Recursively visit the content.
*
* ExprVisitor treats Expr as dataflow graph,
* and only visit each Expr node once.
*/
class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
public:
void VisitExpr(const Expr& expr) override;
void VisitExpr_(const ConstantNode* op) override;
void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const DataflowVarNode* op) override;
void VisitExpr_(const ShapeExprNode* op) override;
void VisitExpr_(const ExternFuncNode* op) override;
void VisitExpr_(const GlobalVarNode* op) override;
void VisitExpr_(const FunctionNode* op) override;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const SeqExprNode* op) override;
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;

virtual void VisitType(const Type& t);
virtual void VisitSpan(const Span& span);
virtual void VisitBinding(const Binding& binding);
virtual void VisitVarBinding(const VarBinding& binding);
virtual void VisitMatchShape(const MatchShape& binding);
virtual void VisitBindingBlock(const BindingBlock& block);
virtual void VisitDataflowBlock(const DataflowBlock& block);
};

void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);

/*!
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once.
* The mutated results are memoized in a map and reused so that
* local transformation on the dataflow preserves the graph structure.
*/
class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
public:
/*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); }
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const ConstantNode* op) override;
Expr VisitExpr_(const TupleNode* op) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const DataflowVarNode* op) override;
Expr VisitExpr_(const ShapeExprNode* op) override;
Expr VisitExpr_(const ExternFuncNode* op) override;
Expr VisitExpr_(const GlobalVarNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override;
Expr VisitExpr_(const CallNode* op) override;
Expr VisitExpr_(const SeqExprNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const OpNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;

/*!
* \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
virtual void VisitBinding(const Binding& binding);
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder);
virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder);
virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);

protected:
LazyIRBuilder irbuilder_;
};

/*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes
*/
class DataflowMutator : public ExprMutator {
public:
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder);

protected:
/*! \brief Look up the value binded to a var. */
Expr LookupVar(Var var);
// A remapping table: pre var -> post var
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> pre_post_var_map_;
};

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_EXPR_FUNCTOR_H_
Loading

0 comments on commit fd8f28d

Please sign in to comment.