Skip to content

Commit

Permalink
[Relay] Mix mode type inference (#6704)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiaoquan authored Nov 6, 2020
1 parent a6c29b2 commit 3cf997a
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 77 deletions.
68 changes: 68 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <tvm/relay/function.h>
#include <tvm/relay/op.h>

#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
Expand Down Expand Up @@ -408,6 +409,73 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
*/
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);

/*!
* \brief A function to iteratively traverse dataflow regions of a graph
*
* ExpandDataflow manually manages a stack and performs DFS to determine the processing
* order of nodes in an input graph.
*
* If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
* need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
* and continues iteratively to process the top of the stack. When it finds a node that doesn't
* match the dataflow types, or a node who's inputs have all been processed, it visits the current
* leaf via fvisit_leaf.
*
* This function should be used internally to other classes to implement mixed-mode traversals. The
* expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
* hits a non-dataflow node.
*
* fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
*/
template <typename FCheckVisited, typename FVisitLeaf>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
std::stack<std::pair<Expr, bool>> stack;
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
// The second state of the stack indicate whether the child has been
// expanded in the pre-order.
// NOTE: function will be inlined.
if (!fcheck_visited(expr)) {
stack.push({expr, false});
}
};
fpush_to_stack(expr);
while (stack.size() > 0) {
auto node = stack.top().first;
if (fcheck_visited(node)) {
// if this node was visited through another path
// after being added to the stack ignore it.
stack.pop();
} else if (stack.top().second) {
// all the children have already been expanded.
// we can just run post order visit on it.
fvisit_leaf(node);
stack.pop();
} else if (const CallNode* op = node.as<CallNode>()) {
// mark expanded = true
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
fpush_to_stack(*it);
}
fpush_to_stack(op->op);
} else if (const TupleNode* op = node.as<TupleNode>()) {
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
fpush_to_stack(*it);
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
stack.top().second = true;
fpush_to_stack(op->tuple);
} else {
// No need to expand the children directly run visit.
fvisit_leaf(node);
stack.pop();
}
}
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_FUNCTOR_H_
68 changes: 0 additions & 68 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,74 +33,6 @@

namespace tvm {
namespace relay {
/*!
 * \brief A function to iteratively traverse dataflow regions of a graph
 *
 * ExpandDataflow manually manages a stack and performs DFS to determine the processing
 * order of nodes in an input graph.
 *
 * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
 * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
 * and continues iteratively to process the top of the stack. When it finds a node that doesn't
 * match the dataflow types, or a node whose inputs have all been processed, it visits the current
 * leaf via fvisit_leaf.
 *
 * This function should be used internally to other classes to implement mixed-mode traversals. The
 * expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
 * hits a non-dataflow node.
 *
 * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
 */
template <typename FCheckVisited, typename FVisitLeaf>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
  std::stack<std::pair<Expr, bool>> stack;
  auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
    // The second state of the stack indicates whether the child has been
    // expanded in the pre-order.
    // NOTE: function will be inlined.
    if (!fcheck_visited(expr)) {
      stack.push({expr, false});
    }
  };
  fpush_to_stack(expr);
  while (stack.size() > 0) {
    auto node = stack.top().first;
    if (fcheck_visited(node)) {
      // if this node was visited through another path
      // after being added to the stack ignore it.
      stack.pop();
    } else if (stack.top().second) {
      // all the children have already been expanded.
      // we can just run post order visit on it.
      fvisit_leaf(node);
      stack.pop();
    } else if (const CallNode* op = node.as<CallNode>()) {
      // mark expanded = true
      stack.top().second = true;
      // push the children to the stack in reverse order
      // to match recursive processing order
      for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
        fpush_to_stack(*it);
      }
      fpush_to_stack(op->op);
    } else if (const TupleNode* op = node.as<TupleNode>()) {
      stack.top().second = true;
      // push the children to the stack in reverse order
      // to match recursive processing order
      for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
        fpush_to_stack(*it);
      }
    } else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
      stack.top().second = true;
      fpush_to_stack(op->tuple);
    } else {
      // No need to expand the children directly run visit.
      fvisit_leaf(node);
      stack.pop();
    }
  }
}

MixedModeVisitor::MixedModeVisitor(int visit_limit) {
ICHECK(visit_limit > 0) << "Dataflow visit limit must be greater than 0";
ICHECK(visit_limit < 10) << "Dataflow visit limit must be less than 10";
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/algorithm/topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TopKAttrs* param = attrs.as<TopKAttrs>();
ICHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
ICHECK(data);
if (data == nullptr) return false;
int ndim = data->shape.size();
int axis = param->axis;
if (axis < 0) {
Expand Down
51 changes: 43 additions & 8 deletions src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,37 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
TypeRelationFn tuple_getitem_rel_;
TypeRelationFn make_tuple_rel_;

/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Type, ObjectPtrHash, ObjectPtrEqual> memo_;

void VisitLeaf(const Expr& expr) {
if (!memo_.count(expr)) {
Type ret = this->DispatchVisitExpr(expr);
memo_[expr] = ret;
}
}

/*! \brief Return true iff the type of \p expr has already been memoized. */
bool CheckVisited(const Expr& expr) { return memo_.count(expr) != 0; }

// Dispatch to the per-node-type ExprFunctor visitor, bypassing the memoizing VisitExpr below.
Type DispatchVisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); }

// Memoized entry point: infer the type of expr, traversing dataflow regions
// iteratively via ExpandDataflow so deep graphs don't overflow the C++ stack.
Type VisitExpr(const Expr& expr) final {
  // Fast path: type already inferred through an earlier traversal.
  // Use find() once instead of count() + operator[] double lookup.
  auto it = memo_.find(expr);
  if (it != memo_.end()) {
    return it->second;
  }
  auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
  auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
  // VisitLeaf populates memo_ for every node reached, including expr itself.
  ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
  return memo_[expr];
}

// Perform unification on two types and report the error at the expression
// or the span of the expression.
Type Unify(const Type& t1, const Type& t2, const Span& span) {
Expand Down Expand Up @@ -546,12 +577,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
}
};

class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {
public:
Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, ObjectPtrHash, ObjectPtrEqual>& tmap,
TypeSolver* solver)
: tmap_(tmap), solver_(solver) {}

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef<Var>(op)); }

Expr VisitExpr_(const ConstantNode* op) final { return AttachCheckedType(op); }
Expand All @@ -560,13 +593,15 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {

Expr VisitExpr_(const OpNode* op) final { return ExprMutator::VisitExpr_(op); }

Expr VisitExpr_(const TupleNode* op) final { return AttachCheckedType(op); }
// Attach the solver-resolved checked type to the already-rewritten tuple node (post).
Expr Rewrite_(const TupleNode* op, const Expr& post) final { return AttachCheckedType(op, post); }

Expr VisitExpr_(const TupleGetItemNode* op) final { return AttachCheckedType(op); }
// Attach the solver-resolved checked type to the already-rewritten TupleGetItem node (post).
Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
  return AttachCheckedType(op, post);
}

Expr VisitExpr_(const FunctionNode* op) final { return AttachCheckedType(op); }

Expr VisitExpr_(const CallNode* op) final { return AttachCheckedType(op); }
// Attach the solver-resolved checked type to the already-rewritten call node (post).
Expr Rewrite_(const CallNode* op, const Expr& post) final { return AttachCheckedType(op, post); }

Expr VisitExpr_(const LetNode* op) final { return AttachCheckedType(op); }

Expand All @@ -593,7 +628,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {

// attach checked type to the mutated node.
template <typename T>
Expr AttachCheckedType(const T* op) {
Expr AttachCheckedType(const T* op, const Expr& post = Expr()) {
auto it = tmap_.find(GetRef<Expr>(op));
ICHECK(it != tmap_.end());
Type checked_type = solver_->Resolve(it->second.checked_type);
Expand All @@ -606,7 +641,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
<< " check other reported errors for hints of what may of happened.");
}

Expr new_e = ExprMutator::VisitExpr_(op);
Expr new_e = post.defined() ? post : ExprMutator::VisitExpr_(op);
// new_call and new_var's code is only going to be valid for VarNode/CallNode.
// Compiler optimization will likely fold these away for other nodes.
CallNode* new_call = (std::is_base_of<CallNode, T>::value
Expand Down Expand Up @@ -702,8 +737,8 @@ Expr TypeInferencer::Infer(GlobalVar var, Function function) {
return resolved_expr;
}

struct AllCheckTypePopulated : ExprVisitor {
void VisitExpr(const Expr& e) {
struct AllCheckTypePopulated : MixedModeVisitor {
void DispatchExprVisit(const Expr& e) {
if (e.as<OpNode>()) {
return;
}
Expand Down

0 comments on commit 3cf997a

Please sign in to comment.