diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index 19d2992c3006..b3f22e00fda4 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -385,9 +385,7 @@ class DeviceInfo { // TODO(zhiics) Skip annotation of function node for now. } - void VisitExpr_(const ConstantNode* cn) final { - device_tag_[cn] = dev_type_; - } + void VisitExpr_(const ConstantNode* cn) final { device_tag_[cn] = dev_type_; } void VisitExpr_(const CallNode* call) final { // Skip annotation nodes. @@ -426,9 +424,7 @@ class DeviceInfo { void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); } - void VisitExpr_(const VarNode* vn) final { - device_tag_[vn] = dev_type_; - } + void VisitExpr_(const VarNode* vn) final { device_tag_[vn] = dev_type_; } void VisitExpr_(const LetNode* ln) final { ExprVisitor::VisitExpr_(ln); diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 0ecb2b522103..e95708007b95 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -309,6 +309,76 @@ def test_visitor_annotation(): test_visitor_annotation() +def test_propogation(): + R""" The network and device type is as following: + x 1 + | + log 1 + / \ + log2 log10 2 + \ / + add 2 + | + tan 1 + """ + ctx1 = tvm.context(1) + ctx2 = tvm.context(2) + + expected_dev_type = { + 'log': ctx1, + 'log2': ctx2, + 'log10': ctx2, + 'add': ctx2, + 'tan': ctx1 + } + + x = relay.var("x", shape=(3,)) + + def annotated(): + log = relay.log(x) + _log = relay.annotation.on_device(log, expected_dev_type['log']) + log2 = relay.log2(_log) + _log2 = relay.annotation.on_device(log2, expected_dev_type['log2']) + log10 = relay.log10(_log) + _log10 = relay.annotation.on_device(log10, expected_dev_type['log10']) + add = relay.add(_log2, _log10) + _add = relay.annotation.on_device(add, expected_dev_type['add']) + tan = relay.tan(_add) + _tan = relay.annotation.on_device(tan, expected_dev_type['tan']) + + func = run_opt_pass(_tan, transform.RewriteAnnotatedOps(ctx1.device_type)) + return func + + def expected(): + log = relay.log(x) + _log_left = relay.device_copy(log, ctx1, ctx2) + _log_right = relay.device_copy(log, ctx1, ctx2) + log2 = relay.log2(_log_left) + log10 = relay.log10(_log_right) + add = relay.add(log2, log10) + _add = relay.device_copy(add, ctx2, ctx1) + tan = relay.tan(_add) + + func = run_opt_pass(tan, transform.InferType()) + return func + + annotated_expr = annotated() + expected_expr = expected() + assert tvm.ir.structural_equal(annotated_expr, expected_expr) + + smap = relay.backend._backend.GraphPlanMemory(annotated_expr) + for expr, storage_dev_type in smap.items(): + # x is ctx1 as output is ctx1 + if isinstance(expr, tvm.relay.expr.Var): + assert storage_dev_type[1][0] == ctx1.device_type + else: + # device_copy op should be its dst_dev_type + if isinstance(expr.attrs, tvm.relay.op.op_attrs.DeviceCopyAttrs): + assert storage_dev_type[1][0] == expr.attrs.dst_dev_type + else: + assert storage_dev_type[1][0] == expected_dev_type[expr.op.name].device_type + + def run_fusible_network(dev, tgt): R""" The network is as following: x y