Amendments for gradients #5941
- We fix the dtype handling of consts in generated gradients.
- We add a collapse_sum_to instruction mirroring collapse_sum_like. For general definitions (potentially dynamic shapes), collapse_sum_like is the first choice, but once shapes are static, using collapse_sum_to will greatly simplify the graph. (This simplification is not part of the PR.)
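To make the difference between the two ops concrete, here is a minimal NumPy sketch of the intended semantics. It is not the Relay implementation: the function names only mirror the op names, and the assumption is that `collapse_sum_to` receives a static target shape while `collapse_sum_like` reads the shape from a reference tensor.

```python
import numpy as np


def collapse_sum_to(data, shape):
    """Sum `data` over its broadcast axes so the result has `shape`.

    NumPy stand-in for illustration only, not the Relay op.
    """
    shape = tuple(shape)
    extra = data.ndim - len(shape)
    if extra < 0:
        raise ValueError("target shape has more dimensions than data")
    # Sum away the leading axes that broadcasting would have prepended.
    if extra:
        data = data.sum(axis=tuple(range(extra)))
    # Sum (keeping the axis) wherever the target has size 1 but data does not.
    axes = tuple(
        i for i, (d, s) in enumerate(zip(data.shape, shape)) if s == 1 and d != 1
    )
    if axes:
        data = data.sum(axis=axes, keepdims=True)
    if data.shape != shape:
        raise ValueError("target shape cannot be broadcast to the data shape")
    return data


def collapse_sum_like(data, like):
    """Same reduction, but the target shape comes from a reference tensor."""
    return collapse_sum_to(data, like.shape)


# Gradient of `x + b` with respect to `b`, where x is (2, 3) and b is (3,):
# the upstream gradient has x's shape and is collapsed back to b's shape.
upstream = np.ones((2, 3))
grad_b = collapse_sum_to(upstream, (3,))             # shape (3,)
grad_b_like = collapse_sum_like(upstream, np.zeros(3))  # same values
```

With a static target shape, the reference tensor is no longer needed as an input, which is the kind of graph simplification the description alludes to.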
@MarisaKirisame @tqchen If I can interest you in this. I have more gradient work coming up.
@junrushao1994 too. :)
src/relay/op/tensor/transform.cc
@@ -1713,6 +1713,54 @@ RELAY_REGISTER_OP("collapse_sum_like")
.set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

// CollapseSumTo: <A, B> -> B where CollapseSumTo(A, B) = B
I am confused by this line. BroadCast(A, B) = B?
Oh, right. But Broadcast(A, B) = A right? Thanks for spotting this. I must admit that I'm not 100% sure I understand the notation.
BroadCast is a symmetric function that takes two tensor types A and B and returns the broadcast type (I am conflating the type level and the term level a bit). The equation after "where" just lists the constraint between the two elements.
What constraint exists between the two arguments of CollapseSumTo? If I understand it correctly, I can 'collapse' A into B, which means the type B can be broadcast to A, right?
See my comment above, but I think it is good.
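One reading of the constraint discussed in this thread, sketched on static shapes in plain Python: broadcasting is symmetric in its two arguments, and A can be collapsed to B exactly when Broadcast(A, B) = A. The helper names below are hypothetical and are not part of the TVM code base.

```python
def broadcast_shape(a, b):
    """NumPy-style broadcast of two static shapes.

    Raises ValueError when the shapes are incompatible.
    """
    result = []
    # Walk both shapes from the trailing dimension, padding the shorter with 1s.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"cannot broadcast {tuple(a)} with {tuple(b)}")
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


def can_collapse_sum_to(a, b):
    """True when shape `b` broadcasts to shape `a`, i.e. Broadcast(a, b) == a."""
    try:
        return broadcast_shape(a, b) == tuple(a)
    except ValueError:
        return False


# Broadcast is symmetric; the collapse constraint is not.
assert broadcast_shape((2, 3), (3,)) == broadcast_shape((3,), (2, 3))
assert can_collapse_sum_to((2, 3), (3,))
assert not can_collapse_sum_to((3,), (2, 3))
```

The asymmetry in the last two assertions is the point of the "where" clause: the first argument must already have the broadcast shape.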
Sounds cool. I assume you have a plan for a PR that implements a pass to turn
* Amendments for gradients
  - We fix the dtype handling of consts in generated gradients.
  - We add a collapse_sum_to instruction mirroring the collapse_sum_like. While for general definitions (potentially dynamic shapes), collapse_sum_like is the first choice, when moving to static, using collapse_sum_to will greatly simplify the graph. (This simplification is not part of the PR.)
* Fix Broadcast rel description in comment
  Thank you, @MarisaKirisame