Skip to content

Commit

Permalink
add testcast for recursive device propogation
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghaohit committed Aug 6, 2020
1 parent 90fd7c4 commit 61a9bc8
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 6 deletions.
8 changes: 2 additions & 6 deletions src/relay/transforms/device_annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,7 @@ class DeviceInfo {
// TODO(zhiics) Skip annotation of function node for now.
}

void VisitExpr_(const ConstantNode* cn) final {
device_tag_[cn] = dev_type_;
}
void VisitExpr_(const ConstantNode* cn) final { device_tag_[cn] = dev_type_; }

void VisitExpr_(const CallNode* call) final {
// Skip annotation nodes.
Expand Down Expand Up @@ -426,9 +424,7 @@ class DeviceInfo {

void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); }

void VisitExpr_(const VarNode* vn) final {
device_tag_[vn] = dev_type_;
}
void VisitExpr_(const VarNode* vn) final { device_tag_[vn] = dev_type_; }

void VisitExpr_(const LetNode* ln) final {
ExprVisitor::VisitExpr_(ln);
Expand Down
70 changes: 70 additions & 0 deletions tests/python/relay/test_pass_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,76 @@ def test_visitor_annotation():
test_visitor_annotation()


def test_propogation():
R""" The network and device type is as following:
x 1
|
log 1
/ \
log2 log10 2
\ /
add 2
|
tan 1
"""
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)

expected_dev_type = {
'log': ctx1,
'log2': ctx2,
'log10': ctx2,
'add': ctx2,
'tan': ctx1
}

x = relay.var("x", shape=(3,))

def annotated():
log = relay.log(x)
_log = relay.annotation.on_device(log, expected_dev_type['log'])
log2 = relay.log2(_log)
_log2 = relay.annotation.on_device(log2, expected_dev_type['log2'])
log10 = relay.log10(_log)
_log10 = relay.annotation.on_device(log10, expected_dev_type['log10'])
add = relay.add(_log2, _log10)
_add = relay.annotation.on_device(add, expected_dev_type['add'])
tan = relay.tan(_add)
_tan = relay.annotation.on_device(tan, expected_dev_type['tan'])

func = run_opt_pass(_tan, transform.RewriteAnnotatedOps(ctx1.device_type))
return func

def expected():
log = relay.log(x)
_log_left = relay.device_copy(log, ctx1, ctx2)
_log_right = relay.device_copy(log, ctx1, ctx2)
log2 = relay.log2(_log_left)
log10 = relay.log10(_log_right)
add = relay.add(log2, log10)
_add = relay.device_copy(add, ctx2, ctx1)
tan = relay.tan(_add)

func = run_opt_pass(tan, transform.InferType())
return func

annotated_expr = annotated()
expected_expr = expected()
assert tvm.ir.structural_equal(annotated_expr, expected_expr)

smap = relay.backend._backend.GraphPlanMemory(annotated_expr)
for expr, storage_dev_type in smap.items():
# x is ctx1 as output is ctx1
if isinstance(expr, tvm.relay.expr.Var):
assert storage_dev_type[1][0] == ctx1.device_type
else:
# device_copy op should be its dst_dev_type
if isinstance(expr.attrs, tvm.relay.op.op_attrs.DeviceCopyAttrs):
assert storage_dev_type[1][0] == expr.attrs.dst_dev_type
else:
assert storage_dev_type[1][0] == expected_dev_type[expr.op.name].device_type


def run_fusible_network(dev, tgt):
R""" The network is as following:
x y
Expand Down

0 comments on commit 61a9bc8

Please sign in to comment.