-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay] change device annotation from post DFS to recursive #6124
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -309,6 +309,76 @@ def test_visitor_annotation(): | |
test_visitor_annotation() | ||
|
||
|
||
def test_propogation(): | ||
R""" The network and device type is as following: | ||
x 1 | ||
| | ||
log 1 | ||
/ \ | ||
log2 log10 2 | ||
\ / | ||
add 2 | ||
| | ||
tan 1 | ||
""" | ||
ctx1 = tvm.context(1) | ||
ctx2 = tvm.context(2) | ||
|
||
expected_dev_type = { | ||
'log': ctx1, | ||
'log2': ctx2, | ||
'log10': ctx2, | ||
'add': ctx2, | ||
'tan': ctx1 | ||
} | ||
|
||
x = relay.var("x", shape=(3,)) | ||
|
||
def annotated(): | ||
log = relay.log(x) | ||
_log = relay.annotation.on_device(log, expected_dev_type['log']) | ||
log2 = relay.log2(_log) | ||
_log2 = relay.annotation.on_device(log2, expected_dev_type['log2']) | ||
log10 = relay.log10(_log) | ||
_log10 = relay.annotation.on_device(log10, expected_dev_type['log10']) | ||
add = relay.add(_log2, _log10) | ||
_add = relay.annotation.on_device(add, expected_dev_type['add']) | ||
tan = relay.tan(_add) | ||
_tan = relay.annotation.on_device(tan, expected_dev_type['tan']) | ||
|
||
func = run_opt_pass(_tan, transform.RewriteAnnotatedOps(ctx1.device_type)) | ||
return func | ||
|
||
def expected(): | ||
log = relay.log(x) | ||
_log_left = relay.device_copy(log, ctx1, ctx2) | ||
_log_right = relay.device_copy(log, ctx1, ctx2) | ||
log2 = relay.log2(_log_left) | ||
log10 = relay.log10(_log_right) | ||
add = relay.add(log2, log10) | ||
_add = relay.device_copy(add, ctx2, ctx1) | ||
tan = relay.tan(_add) | ||
|
||
func = run_opt_pass(tan, transform.InferType()) | ||
return func | ||
|
||
annotated_expr = annotated() | ||
expected_expr = expected() | ||
assert tvm.ir.structural_equal(annotated_expr, expected_expr) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm curious that master passes this check, but fails on line 377. Why doesn't structural equal properly resolve the error in the device copy op? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think master fails on line 379, right? the device type of Up to this line, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 You're right, I misread my first test. |
||
|
||
smap = relay.backend._backend.GraphPlanMemory(annotated_expr) | ||
for expr, storage_dev_type in smap.items(): | ||
# x is ctx1 as output is ctx1 | ||
if isinstance(expr, tvm.relay.expr.Var): | ||
assert storage_dev_type[1][0] == ctx1.device_type | ||
else: | ||
# device_copy op should be its dst_dev_type | ||
if isinstance(expr.attrs, tvm.relay.op.op_attrs.DeviceCopyAttrs): | ||
assert storage_dev_type[1][0] == expr.attrs.dst_dev_type | ||
else: | ||
assert storage_dev_type[1][0] == expected_dev_type[expr.op.name].device_type | ||
|
||
|
||
def run_fusible_network(dev, tgt): | ||
R""" The network is as following: | ||
x y | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't love making this class state, precisely because you have to do crazy things with maintaining state in your recursive calls. You could do it as a set of recursive arguments, but that kind of requires re-implementing with ExprFunctor...so maybe this is the cleanest solution.