-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Phi] add the infer shape meta for the graph_send_recv #40320
[Phi] add the infer shape meta for the graph_send_recv #40320
Conversation
Thanks for your contribution! |
paddle/phi/infermeta/multiary.h
Outdated
@@ -75,4 +75,10 @@ void AdadeltaInferMeta(const MetaTensor& param, | |||
MetaTensor* avg_squared_grad_out, | |||
MetaTensor* avg_squared_update_out); | |||
|
|||
void GraphSendRecvInferMeta(const MetaTensor& x, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
三个输入的infermeta 函数放到 ternary.h 文件中。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
三个输入的infermeta 函数放到 ternary.h 文件中。
done
paddle/phi/infermeta/multiary.cc
Outdated
"Src_index and Dst_index should have the same shape.")); | ||
|
||
auto dims = x.dims(); | ||
out->set_dims(dims); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out 也需要设置一下dtype
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out 也需要设置一下dtype
done
paddle/phi/infermeta/multiary.cc
Outdated
out->set_dims(dims); | ||
|
||
if (pool_type == "MEAN") { | ||
dst_count->set_dims({dims[0]}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dst_count 也需要设置一下dtype
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dst_count 也需要设置一下dtype
done
… add_infer_meta_graph_send_recv
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Performance optimization
PR changes
OPs
Describe
add the infer shape meta for the graph_send_recv