diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index 39cf563f730a6..b3f22e00fda44 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -385,23 +385,34 @@ class DeviceInfo { // TODO(zhiics) Skip annotation of function node for now. } - void VisitExpr_(const ConstantNode* cn) final { - post_dfs_order_.push_back(std::make_pair(cn, has_copy_)); - } + void VisitExpr_(const ConstantNode* cn) final { device_tag_[cn] = dev_type_; } void VisitExpr_(const CallNode* call) final { // Skip annotation nodes. if (!IsOnDeviceNode(call)) { - if (GetDeviceCopyNode(call)) { + if (const auto* node = GetDeviceCopyNode(call)) { + CHECK(node->IsInstance()); + const auto* call_node = static_cast(node); + auto attrs = call_node->attrs.as(); + num_device_copy_ops_++; - bool has_copy_prev = has_copy_; - has_copy_ = true; - ExprVisitor::VisitExpr_(call); - post_dfs_order_.push_back(std::make_pair(call, has_copy_)); - has_copy_ = has_copy_prev; + dev_type_ = attrs->src_dev_type; + for (auto& arg : call->args) { + Visit(arg); + // restore the type for remaining arguments + dev_type_ = attrs->src_dev_type; + } + device_tag_[call] = attrs->dst_dev_type; + // update the out_dev_type_, which should be the dst_dev_type of last copy + out_dev_type_ = attrs->dst_dev_type; } else { - ExprVisitor::VisitExpr_(call); - post_dfs_order_.push_back(std::make_pair(call, has_copy_)); + for (auto& arg : call->args) { + int cur_dev_type = dev_type_; + Visit(arg); + // restore the type for remaining arguments + dev_type_ = cur_dev_type; + } + device_tag_[call] = dev_type_; } } } @@ -413,23 +424,22 @@ class DeviceInfo { void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); } - void VisitExpr_(const VarNode* vn) final { - post_dfs_order_.push_back(std::make_pair(vn, has_copy_)); - } + void VisitExpr_(const VarNode* vn) final { device_tag_[vn] = dev_type_; } void VisitExpr_(const LetNode* ln) final { ExprVisitor::VisitExpr_(ln); - post_dfs_order_.push_back(std::make_pair(ln, has_copy_)); + device_tag_[ln] = dev_type_; } void VisitExpr_(const IfNode* in) final { ExprVisitor::VisitExpr_(in); - post_dfs_order_.push_back(std::make_pair(in, has_copy_)); + device_tag_[in] = dev_type_; } int num_device_copy_ops_{0}; - bool has_copy_ = false; - std::vector> post_dfs_order_; + int dev_type_ = -1; + int out_dev_type_ = -1; + std::unordered_map device_tag_; friend DeviceInfo; }; @@ -455,39 +465,14 @@ class DeviceInfo { } void PropagateDeviceId() { - // Bottom-up propagation. - int out_dev_type = BottomUpPropagation(); - // propagation for remained nodes. - FillPropagation(out_dev_type); - } - - int BottomUpPropagation() { - const CallNode* last_copy_node = nullptr; - int cur_dev_type = -1; - int out_dev_type = -1; - for (auto it = post_visitor_.post_dfs_order_.crbegin(); - it != post_visitor_.post_dfs_order_.crend(); ++it) { - if (const auto* node = GetDeviceCopyNode(it->first)) { - CHECK(node->IsInstance()); - last_copy_node = static_cast(node); - const auto* attrs = last_copy_node->attrs.as(); - cur_dev_type = attrs->src_dev_type; - if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type; - if (it->second) device_map_.Set(GetRef(it->first), attrs->dst_dev_type); - } else if (last_copy_node) { - Expr expr = GetRef(it->first); - CHECK_EQ(device_map_.count(expr), 0U); - if (it->second) device_map_.Set(expr, cur_dev_type); + int out_dev_type = post_visitor_.out_dev_type_; + for (auto& it : post_visitor_.device_tag_) { + if (it.second != -1) { + device_map_.Set(GetRef(it.first), it.second); + } else { + device_map_.Set(GetRef(it.first), out_dev_type); } } - return out_dev_type; - } - - void FillPropagation(int out_dev_type) { - for (const auto& it : post_visitor_.post_dfs_order_) { - Expr expr = GetRef(it.first); - if (!it.second) device_map_.Set(expr, out_dev_type); - } } PostDfsOrderVisitor post_visitor_; diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 0ecb2b5221039..e95708007b955 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -309,6 +309,76 @@ def test_visitor_annotation(): test_visitor_annotation() +def test_propogation(): + R""" The network and device type is as following: + x 1 + | + log 1 + / \ + log2 log10 2 + \ / + add 2 + | + tan 1 + """ + ctx1 = tvm.context(1) + ctx2 = tvm.context(2) + + expected_dev_type = { + 'log': ctx1, + 'log2': ctx2, + 'log10': ctx2, + 'add': ctx2, + 'tan': ctx1 + } + + x = relay.var("x", shape=(3,)) + + def annotated(): + log = relay.log(x) + _log = relay.annotation.on_device(log, expected_dev_type['log']) + log2 = relay.log2(_log) + _log2 = relay.annotation.on_device(log2, expected_dev_type['log2']) + log10 = relay.log10(_log) + _log10 = relay.annotation.on_device(log10, expected_dev_type['log10']) + add = relay.add(_log2, _log10) + _add = relay.annotation.on_device(add, expected_dev_type['add']) + tan = relay.tan(_add) + _tan = relay.annotation.on_device(tan, expected_dev_type['tan']) + + func = run_opt_pass(_tan, transform.RewriteAnnotatedOps(ctx1.device_type)) + return func + + def expected(): + log = relay.log(x) + _log_left = relay.device_copy(log, ctx1, ctx2) + _log_right = relay.device_copy(log, ctx1, ctx2) + log2 = relay.log2(_log_left) + log10 = relay.log10(_log_right) + add = relay.add(log2, log10) + _add = relay.device_copy(add, ctx2, ctx1) + tan = relay.tan(_add) + + func = run_opt_pass(tan, transform.InferType()) + return func + + annotated_expr = annotated() + expected_expr = expected() + assert tvm.ir.structural_equal(annotated_expr, expected_expr) + + smap = relay.backend._backend.GraphPlanMemory(annotated_expr) + for expr, storage_dev_type in smap.items(): + # x is ctx1 as output is ctx1 + if isinstance(expr, tvm.relay.expr.Var): + assert storage_dev_type[1][0] == ctx1.device_type + else: + # device_copy op should be its dst_dev_type + if isinstance(expr.attrs, tvm.relay.op.op_attrs.DeviceCopyAttrs): + assert storage_dev_type[1][0] == expr.attrs.dst_dev_type + else: + assert storage_dev_type[1][0] == expected_dev_type[expr.op.name].device_type + + def run_fusible_network(dev, tgt): R""" The network is as following: x y