[oneDNN] Refactoring of softmax grad onednn kernel to match common API #32851
Conversation
Thanks for your contribution!
Also lines 18-25 and 30-31 are redundant. Both are getting Tensor and MKLDNNDeviceContext, but in two different ways.
- auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
- auto* dx =
+ auto* out_grad = ctx.template Input<Tensor>(framework::GradVarName("Out"));
+ auto* in_x_grad =
+     ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
Suggested change:
- ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
+ ctx.template Output<Tensor>(framework::GradVarName("X"));
Please stay consistent with the usage of namespaces.
ok
auto dims = out_grad->dims();  // input and output share the same shape
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
auto softmax_tz = paddle::framework::vectorize<int64_t>(dims);
Suggested change:
- auto softmax_tz = paddle::framework::vectorize<int64_t>(dims);
+ auto softmax_tz = framework::vectorize<int64_t>(dims);
ok
- SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
-                      const MKLDNNMemoryFormat fmt,
-                      const MKLDNNMemoryFormat diff_fmt, const int& axis,
+ SoftmaxMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
Suggested change:
- SoftmaxMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
+ SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx,
ok
- SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
-                      const MKLDNNMemoryFormat fmt,
-                      const MKLDNNMemoryFormat diff_fmt, const int& axis,
+ SoftmaxMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
+                      const platform::MKLDNNDeviceContext& dev_ctx,
Suggested change:
- const platform::MKLDNNDeviceContext& dev_ctx,
+ const MKLDNNDeviceContext& dev_ctx,
Since you have "using paddle::platform::MKLDNNDeviceContext;" in line 33, you don't need to spell out the namespace here.
good catch
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
auto softmax_tz = paddle::framework::vectorize<int64_t>(dims);

auto data_softmax_md = platform::MKLDNNMemDesc(
Suggested change:
- auto data_softmax_md = platform::MKLDNNMemDesc(
+ auto data_softmax_md = MKLDNNMemDesc(
ok
auto data_softmax_md = platform::MKLDNNMemDesc(
    softmax_tz, platform::MKLDNNGetDataType<T>(), out->format());
auto diff_softmax_md = platform::MKLDNNMemDesc(
Suggested change:
- auto diff_softmax_md = platform::MKLDNNMemDesc(
+ auto diff_softmax_md = MKLDNNMemDesc(
ok
LGTM
LGTM
auto softmax_tz = framework::vectorize<int64_t>(dims);

auto data_softmax_md = MKLDNNMemDesc(
    softmax_tz, platform::MKLDNNGetDataType<T>(), out->format());
I have one doubt: will out->format and out_grad->format be "NHWC" or "NCHW" (i.e. the Paddle format)? In that case, if the next op is also an mkldnn op, is a reorder needed?
@luotao1 Could you please start your review?
PR types
Function optimization
PR changes
OPs
Describe
This PR modifies the softmax grad oneDNN kernel so that its implementation matches that of other oneDNN grad kernels. This is needed for bigger changes that will come in the next PRs.