Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] change device annotation from post DFS to recursive #6124

Merged
merged 2 commits into from
Aug 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 34 additions & 49 deletions src/relay/transforms/device_annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,23 +385,34 @@ class DeviceInfo {
// TODO(zhiics) Skip annotation of function node for now.
}

void VisitExpr_(const ConstantNode* cn) final {
post_dfs_order_.push_back(std::make_pair(cn, has_copy_));
}
void VisitExpr_(const ConstantNode* cn) final { device_tag_[cn] = dev_type_; }

void VisitExpr_(const CallNode* call) final {
// Skip annotation nodes.
if (!IsOnDeviceNode(call)) {
if (GetDeviceCopyNode(call)) {
if (const auto* node = GetDeviceCopyNode(call)) {
CHECK(node->IsInstance<CallNode>());
const auto* call_node = static_cast<const CallNode*>(node);
auto attrs = call_node->attrs.as<DeviceCopyAttrs>();

num_device_copy_ops_++;
bool has_copy_prev = has_copy_;
has_copy_ = true;
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
has_copy_ = has_copy_prev;
dev_type_ = attrs->src_dev_type;
for (auto& arg : call->args) {
Visit(arg);
// restore the type for remaining arguments
dev_type_ = attrs->src_dev_type;
}
device_tag_[call] = attrs->dst_dev_type;
// update the out_dev_type_, which should be the dst_dev_type of last copy
out_dev_type_ = attrs->dst_dev_type;
} else {
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
for (auto& arg : call->args) {
int cur_dev_type = dev_type_;
Visit(arg);
// restore the type for remaining arguments
dev_type_ = cur_dev_type;
}
device_tag_[call] = dev_type_;
}
}
}
Expand All @@ -413,23 +424,22 @@ class DeviceInfo {

void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); }

void VisitExpr_(const VarNode* vn) final {
post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
}
void VisitExpr_(const VarNode* vn) final { device_tag_[vn] = dev_type_; }

void VisitExpr_(const LetNode* ln) final {
ExprVisitor::VisitExpr_(ln);
post_dfs_order_.push_back(std::make_pair(ln, has_copy_));
device_tag_[ln] = dev_type_;
}

void VisitExpr_(const IfNode* in) final {
ExprVisitor::VisitExpr_(in);
post_dfs_order_.push_back(std::make_pair(in, has_copy_));
device_tag_[in] = dev_type_;
}

int num_device_copy_ops_{0};
bool has_copy_ = false;
std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
int dev_type_ = -1;
int out_dev_type_ = -1;
Comment on lines +440 to +441
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love making this class state, precisely because you have to do crazy things with maintaining state in your recursive calls. You could do it as a set of recursive arguments, but that kind of requires re-implementing with ExprFunctor...so maybe this is the cleanest solution.

std::unordered_map<const ExprNode*, int> device_tag_;
friend DeviceInfo;
};

Expand All @@ -455,39 +465,14 @@ class DeviceInfo {
}

void PropagateDeviceId() {
// Bottom-up propagation.
int out_dev_type = BottomUpPropagation();
// propagation for remained nodes.
FillPropagation(out_dev_type);
}

int BottomUpPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
int out_dev_type = -1;
for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (const auto* node = GetDeviceCopyNode(it->first)) {
CHECK(node->IsInstance<CallNode>());
last_copy_node = static_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
if (it->second) device_map_.Set(GetRef<Expr>(it->first), attrs->dst_dev_type);
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(it->first);
CHECK_EQ(device_map_.count(expr), 0U);
if (it->second) device_map_.Set(expr, cur_dev_type);
int out_dev_type = post_visitor_.out_dev_type_;
for (auto& it : post_visitor_.device_tag_) {
if (it.second != -1) {
device_map_.Set(GetRef<Expr>(it.first), it.second);
} else {
device_map_.Set(GetRef<Expr>(it.first), out_dev_type);
}
}
return out_dev_type;
}

void FillPropagation(int out_dev_type) {
for (const auto& it : post_visitor_.post_dfs_order_) {
Expr expr = GetRef<Expr>(it.first);
if (!it.second) device_map_.Set(expr, out_dev_type);
}
}

PostDfsOrderVisitor post_visitor_;
Expand Down
70 changes: 70 additions & 0 deletions tests/python/relay/test_pass_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,76 @@ def test_visitor_annotation():
test_visitor_annotation()


def test_propogation():
R""" The network and device type is as following:
x 1
|
log 1
/ \
log2 log10 2
\ /
add 2
|
tan 1
"""
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)

expected_dev_type = {
'log': ctx1,
'log2': ctx2,
'log10': ctx2,
'add': ctx2,
'tan': ctx1
}

x = relay.var("x", shape=(3,))

def annotated():
log = relay.log(x)
_log = relay.annotation.on_device(log, expected_dev_type['log'])
log2 = relay.log2(_log)
_log2 = relay.annotation.on_device(log2, expected_dev_type['log2'])
log10 = relay.log10(_log)
_log10 = relay.annotation.on_device(log10, expected_dev_type['log10'])
add = relay.add(_log2, _log10)
_add = relay.annotation.on_device(add, expected_dev_type['add'])
tan = relay.tan(_add)
_tan = relay.annotation.on_device(tan, expected_dev_type['tan'])

func = run_opt_pass(_tan, transform.RewriteAnnotatedOps(ctx1.device_type))
return func

def expected():
log = relay.log(x)
_log_left = relay.device_copy(log, ctx1, ctx2)
_log_right = relay.device_copy(log, ctx1, ctx2)
log2 = relay.log2(_log_left)
log10 = relay.log10(_log_right)
add = relay.add(log2, log10)
_add = relay.device_copy(add, ctx2, ctx1)
tan = relay.tan(_add)

func = run_opt_pass(tan, transform.InferType())
return func

annotated_expr = annotated()
expected_expr = expected()
assert tvm.ir.structural_equal(annotated_expr, expected_expr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious that master passes this check, but fails on line 377. Why doesn't structural equal properly resolve the error in the device copy op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think master fails on line 379, right? the device type of log2 is not correctly marked.

Up to this line, annotated_expr and expected_expr are exactly the same. The device_copy op is inserted correctly. We haven't go through the device propagation yet.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 You're right, I misread my first test.


smap = relay.backend._backend.GraphPlanMemory(annotated_expr)
for expr, storage_dev_type in smap.items():
# x is ctx1 as output is ctx1
if isinstance(expr, tvm.relay.expr.Var):
assert storage_dev_type[1][0] == ctx1.device_type
else:
# device_copy op should be its dst_dev_type
if isinstance(expr.attrs, tvm.relay.op.op_attrs.DeviceCopyAttrs):
assert storage_dev_type[1][0] == expr.attrs.dst_dev_type
else:
assert storage_dev_type[1][0] == expected_dev_type[expr.op.name].device_type


def run_fusible_network(dev, tgt):
R""" The network is as following:
x y
Expand Down