Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] change device annotation from post DFS to recursive #6124

Merged
merged 2 commits into from
Aug 19, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 34 additions & 45 deletions src/relay/transforms/device_annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -386,22 +386,35 @@ class DeviceInfo {
}

void VisitExpr_(const ConstantNode* cn) final {
post_dfs_order_.push_back(std::make_pair(cn, has_copy_));
device_tag_[cn] = dev_type_;
}

void VisitExpr_(const CallNode* call) final {
// Skip annotation nodes.
if (!IsOnDeviceNode(call)) {
if (GetDeviceCopyNode(call)) {
if (const auto* node = GetDeviceCopyNode(call)) {
CHECK(node->IsInstance<CallNode>());
const auto* call_node = static_cast<const CallNode*>(node);
auto attrs = call_node->attrs.as<DeviceCopyAttrs>();

num_device_copy_ops_++;
bool has_copy_prev = has_copy_;
has_copy_ = true;
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
has_copy_ = has_copy_prev;
dev_type_ = attrs->src_dev_type;
for (auto& arg : call->args) {
Visit(arg);
// restore the type for remaining arguments
dev_type_ = attrs->src_dev_type;
}
device_tag_[call] = attrs->dst_dev_type;
// update the out_dev_type_, which should be the dst_dev_type of last copy
out_dev_type_ = attrs->dst_dev_type;
} else {
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
for (auto& arg : call->args) {
int cur_dev_type = dev_type_;
Visit(arg);
// restore the type for remaining arguments
dev_type_ = cur_dev_type;
}
device_tag_[call] = dev_type_;
}
}
}
Expand All @@ -414,22 +427,23 @@ class DeviceInfo {
void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); }

void VisitExpr_(const VarNode* vn) final {
post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
device_tag_[vn] = dev_type_;
}

void VisitExpr_(const LetNode* ln) final {
ExprVisitor::VisitExpr_(ln);
post_dfs_order_.push_back(std::make_pair(ln, has_copy_));
device_tag_[ln] = dev_type_;
}

void VisitExpr_(const IfNode* in) final {
ExprVisitor::VisitExpr_(in);
post_dfs_order_.push_back(std::make_pair(in, has_copy_));
device_tag_[in] = dev_type_;
}

int num_device_copy_ops_{0};
bool has_copy_ = false;
std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
int dev_type_ = -1;
int out_dev_type_ = -1;
Comment on lines +440 to +441
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love making this class state, precisely because you have to do crazy things with maintaining state in your recursive calls. You could do it as a set of recursive arguments, but that kind of requires re-implementing with ExprFunctor...so maybe this is the cleanest solution.

std::unordered_map<const ExprNode*, int> device_tag_;
friend DeviceInfo;
};

Expand All @@ -455,39 +469,14 @@ class DeviceInfo {
}

void PropagateDeviceId() {
// Bottom-up propagation.
int out_dev_type = BottomUpPropagation();
// propagation for remained nodes.
FillPropagation(out_dev_type);
}

int BottomUpPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
int out_dev_type = -1;
for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (const auto* node = GetDeviceCopyNode(it->first)) {
CHECK(node->IsInstance<CallNode>());
last_copy_node = static_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
if (it->second) device_map_.Set(GetRef<Expr>(it->first), attrs->dst_dev_type);
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(it->first);
CHECK_EQ(device_map_.count(expr), 0U);
if (it->second) device_map_.Set(expr, cur_dev_type);
int out_dev_type = post_visitor_.out_dev_type_;
for (auto& it : post_visitor_.device_tag_) {
if (it.second != -1) {
device_map_.Set(GetRef<Expr>(it.first), it.second);
} else {
device_map_.Set(GetRef<Expr>(it.first), out_dev_type);
}
}
return out_dev_type;
}

void FillPropagation(int out_dev_type) {
for (const auto& it : post_visitor_.post_dfs_order_) {
Expr expr = GetRef<Expr>(it.first);
if (!it.second) device_map_.Set(expr, out_dev_type);
}
}

PostDfsOrderVisitor post_visitor_;
Expand Down