Skip to content

Commit

Permalink
TreeNode refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
wweic committed Jul 8, 2019
1 parent 3d88c70 commit 4f1381b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,9 @@ struct TagCompare : ConditionNode {
~TagCompare() {}
};

using TreeNodePtr = relay::TreeNodePtr;
using TreeNodePtr = typename relay::TreeNode<ConditionNodePtr>::pointer;
using TreeLeafNode = relay::TreeLeafNode<ConditionNodePtr>;
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionNodePtr>;
using TreeBranchNode = relay::TreeBranchNode<ConditionNodePtr>;

void CompileTreeNode(TreeNodePtr tree, VMCompiler* compiler) {
Expand Down
17 changes: 13 additions & 4 deletions src/relay/pass/pass_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,16 @@ inline bool IsAtomic(const Expr& e) {
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
}

template<typename ConditionNodePtr>
struct TreeNode {
typedef std::shared_ptr<TreeNode<ConditionNodePtr>> pointer;
virtual ~TreeNode() {}
};
using TreeNodePtr = std::shared_ptr<TreeNode>;

struct TreeLeafNode : TreeNode {
template<typename ConditionNodePtr>
struct TreeLeafNode : TreeNode<ConditionNodePtr> {
using TreeNodePtr = typename TreeNode<ConditionNodePtr>::pointer;

Expr body;

explicit TreeLeafNode(Expr body): body(body) {}
Expand All @@ -126,7 +130,10 @@ struct TreeLeafNode : TreeNode {
~TreeLeafNode() {}
};

struct TreeLeafFatalNode : TreeNode {
template<typename ConditionNodePtr>
struct TreeLeafFatalNode : TreeNode<ConditionNodePtr> {
using TreeNodePtr = typename TreeNode<ConditionNodePtr>::pointer;

TreeLeafFatalNode() = default;

static TreeNodePtr Make() {
Expand All @@ -137,7 +144,9 @@ struct TreeLeafFatalNode : TreeNode {
};

template<typename ConditionNodePtr>
struct TreeBranchNode : TreeNode {
struct TreeBranchNode : TreeNode<ConditionNodePtr> {
using TreeNodePtr = typename TreeNode<ConditionNodePtr>::pointer;

ConditionNodePtr cond;
TreeNodePtr then_branch;
TreeNodePtr else_branch;
Expand Down

0 comments on commit 4f1381b

Please sign in to comment.