[Relay] change device annotation from post DFS to recursive #6124
cc @mbrookhart
Can you please add a test showing the problem/desired change in behavior?
Thanks @zhanghaohit for this PR, I second @mbrookhart's request to add a test case.
Thanks @mbrookhart and @tmoreau89 for the suggestion. I've added a test, which would fail in the original code.
Need to fix lint. A couple of nitpicks here, but overall, I think it looks good now. Kind of wondering why the unit test doesn't fail earlier.
int dev_type_ = -1;
int out_dev_type_ = -1;
I don't love making this class state, precisely because you have to do crazy things with maintaining state in your recursive calls. You could do it as a set of recursive arguments, but that kind of requires re-implementing with ExprFunctor...so maybe this is the cleanest solution.
annotated_expr = annotated()
expected_expr = expected()
assert tvm.ir.structural_equal(annotated_expr, expected_expr)
I'm curious that master passes this check, but fails on line 377. Why doesn't structural equal properly resolve the error in the device copy op?
I think master fails on line 379, right? The device type of `log2` is not correctly marked. Up to this line, `annotated_expr` and `expected_expr` are exactly the same. The `device_copy` op is inserted correctly. We haven't gone through the device propagation yet.
👍 You're right, I misread my first test.
LGTM.
Flaky test?
Please trigger the CI again.
LGTM👍
5a47476 to 61a9bc8
@tmoreau89 Could you help merge this PR? Thanks.
LGTM
Thank you @zhanghaohit, @mbrookhart, @junrushao1994, @zhiics; the PR has been merged.
This is related to #5840 and split from PR #5842
Originally, the device type was propagated based on the post-DFS traversal order of the graph, which may not be consistent if the argument order changes. In addition, it may handle some cases incorrectly, e.g., the first residual block in Resnet50. The first few layers in Resnet50 are depicted in the following figure (top to bottom is in DFS order). Basically, we want all the layers to run on the FPGA device, except the first and last few layers. In the original device propagation algorithm, based on the post-DFS order, the conv2d layers in grey are assigned the `CPU` device type because we encounter `copy2` first, after which the three grey conv2d nodes are marked with the source device type of `copy2` (i.e., `CPU`), which is not correct.

By changing the device annotation behaviour, we can support more complex graph structures.
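
To make the failure mode above concrete, here is a minimal Python sketch on a toy graph. It is not the TVM implementation: the `Op`/`Copy` node types, the graph, the node names, and both `annotate_*` functions are hypothetical simplifications made up for illustration. The order-based version assigns every not-yet-marked node the source device of the next copy found in post-DFS order; the recursive version walks from the output towards the inputs and switches device only when it crosses a copy on the actual data path.

```python
from collections import namedtuple

# Hypothetical toy IR (not Relay): an Op lists its inputs; a Copy also records
# the source and destination device of the transfer.
Op = namedtuple("Op", ["inputs"])
Copy = namedtuple("Copy", ["inputs", "src", "dst"])

# Toy residual block: a CPU op "a" feeds three FPGA convs (g1..g3) through
# copy1, and a CPU shortcut "s" joins them through copy2.
GRAPH = {
    "a": Op([]),
    "copy1": Copy(["a"], "cpu", "fpga"),
    "g1": Op(["copy1"]),
    "g2": Op(["g1"]),
    "g3": Op(["g2"]),
    "s": Op(["a"]),
    "copy2": Copy(["s"], "cpu", "fpga"),
    "out": Op(["g3", "copy2"]),
}


def post_dfs(graph, root):
    """Return node names in post-DFS order, visiting inputs left to right."""
    order, seen = [], set()

    def visit(name):
        if name in seen:
            return
        seen.add(name)
        for inp in graph[name].inputs:
            visit(inp)
        order.append(name)

    visit(root)
    return order


def annotate_by_order(graph, root):
    """Order-based sketch: unmarked nodes seen before a copy get its source device."""
    devices, pending, last_dst = {}, [], None
    for name in post_dfs(graph, root):
        node = graph[name]
        if isinstance(node, Copy):
            for p in pending:
                devices[p] = node.src
            pending, last_dst = [], node.dst
        else:
            pending.append(name)
    for p in pending:  # nodes after the last copy get its destination device
        devices[p] = last_dst
    return devices


def annotate_recursive(graph, root, root_device):
    """Recursive sketch: carry the device down from the consumer, switch at copies.

    Assumes the device of the root is known and that annotations never conflict;
    shared nodes are simply revisited, which is fine for a toy graph.
    """
    devices = {}

    def visit(name, device):
        node = graph[name]
        if isinstance(node, Copy):
            for inp in node.inputs:
                visit(inp, node.src)
        else:
            devices[name] = device
            for inp in node.inputs:
                visit(inp, device)

    visit(root, root_device)
    return devices


if __name__ == "__main__":
    print(annotate_by_order(GRAPH, "out"))
    print(annotate_recursive(GRAPH, "out", "fpga"))
```

In this toy graph the order-based sketch marks `g1`, `g2`, and `g3` as `cpu`, because `copy2` is the next copy after them in post-DFS order, while the recursive sketch keeps them on `fpga`. Swapping the two inputs of `out` changes the post-DFS order and with it the order-based result, which is the argument-order sensitivity described above.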