[oneDNN] Pool softmax and LRN access to cache optimized #32922
Changes from 1 commit
@@ -14,21 +14,105 @@ limitations under the License. */

#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace framework {
class Tensor;
}  // namespace framework
namespace platform {
class MKLDNNDeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;

template <typename T>
class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
                                                          mkldnn::lrn_backward> {
 public:
  LRNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                   const MKLDNNDeviceContext& dev_ctx,
                   const mkldnn::engine mkldnn_engine,
                   platform::Place cpu_place, const Tensor* input,
                   const std::string& unique_name)
      : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
            dev_ctx, mkldnn_engine, cpu_place,
            platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
                                unique_name)) {
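    // Note (added for clarity, not part of the original patch): the forward
    // descriptors are created only when no entry for this handler's key
    // (device, input dims and unique_name, via CreateKey above) already
    // exists in the cache; otherwise the cached primitives are reused.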
    if (!this->isCachedNonBlocking()) {
      const int n = ctx.Attr<int>("n");
      // MKL-DNN implements LRN in the Caffe way:
      // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
      // where the sum of squares is divided by the size of the normalization
      // window; this is not the case for PaddlePaddle LRN.
      // Hence we need to compensate for this difference by
      // multiplying alpha by the size of the window (n).
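      // Illustrative formulas (added, not in the original patch), assuming
      // the Caffe/oneDNN convention divides by the window size:
      //   oneDNN:        out = in / (k + (alpha / n) * sum(in^2))^beta
      //   PaddlePaddle:  out = in / (k + alpha * sum(in^2))^beta
      // so passing alpha * n below makes oneDNN match Paddle's definition.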
      const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
      const float beta = ctx.Attr<float>("beta");
      const float k = ctx.Attr<float>("k");
      bool is_test = ctx.Attr<bool>("is_test");

      auto dims = paddle::framework::vectorize(input->dims());

      auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
                                         input->format());

      this->AcquireForwardPrimitiveDescriptorNonBlocking(
          is_test ? mkldnn::prop_kind::forward_inference
                  : mkldnn::prop_kind::forward_training,
          mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
    }
  }

  LRNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                   const platform::MKLDNNDeviceContext& dev_ctx,
Review comment: the platform:: qualifier here is redundant, because "using paddle::platform::MKLDNNDeviceContext;" already appears in line 30. Reply: ok
                   platform::Place cpu_place, const Tensor* in_x,
                   const Tensor* out_grad, Tensor* in_x_grad,
                   const std::string& unique_name)
      : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
            dev_ctx, dev_ctx.GetEngine(), cpu_place,
            platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
                                unique_name)) {
    if (!this->isBwdCached()) {
      PADDLE_ENFORCE_EQ(
          ctx.Attr<bool>("is_test"), false,
          platform::errors::PreconditionNotMet(
              "is_test attribute should be set to False in training phase."));

      const int n = ctx.Attr<int>("n");
      const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
      const float beta = ctx.Attr<float>("beta");
      const float k = ctx.Attr<float>("k");

      auto dims = paddle::framework::vectorize<int64_t>(in_x->dims());

      auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
                                         in_x->format());
      auto diff_md = mkldnn::memory::desc(
          dims, platform::MKLDNNGetDataType<T>(), out_grad->format());

      this->AcquireForwardPrimitiveDescriptorNonBlocking(
          mkldnn::prop_kind::forward_training,
          mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);

      this->AcquireBackwardPrimitiveDescriptorNonBlocking(
          mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha,
          beta, k);
    }
  }

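  // Note (added for clarity, not part of the original patch): an LRN forward
  // primitive created with prop_kind::forward_training produces a workspace
  // that the backward primitive consumes; the two helpers below expose it.
  // In this operator the workspace is carried in the MidOut tensor (see the
  // grad kernel below, which passes mid to AcquireBackwardWorkspaceMemory).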
  std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(
      framework::Tensor* workspace) {
    T* ptr = workspace->mutable_data<T>(
        this->place_, this->fwd_pd_->workspace_desc().get_size());
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->workspace_desc(),
                                            ptr, "@wrk_mem_p");
  }

  std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(
      const framework::Tensor* workspace) {
    const T* workspace_data = workspace->data<T>();
    return this->AcquireMemoryFromPrimitive(
        this->fwd_pd_->workspace_desc(),
        platform::to_void_cast<T>(workspace_data), "@bwd-wrk_mem_p");
  }
};

template <typename T>
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:

@@ -48,8 +132,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto out = ctx.Output<Tensor>("Out"); | ||||||
auto mid = ctx.Output<Tensor>("MidOut"); | ||||||
|
||||||
platform::LRNMKLDNNHandler<T> handler( | ||||||
ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, ctx.OutputName("Out")); | ||||||
LRNMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, | ||||||
ctx.OutputName("Out")); | ||||||
|
||||||
auto src_memory = handler.AcquireSrcMemory(x); | ||||||
auto dst_memory = handler.AcquireDstMemory(out); | ||||||
|
@@ -87,34 +171,22 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL LRNGrad must use CPUPlace"));
-    PADDLE_ENFORCE_EQ(
-        ctx.Attr<bool>("is_test"), false,
-        platform::errors::PreconditionNotMet(
-            "is_test attribute should be set to False in training phase."));

-    auto x = ctx.Input<Tensor>("X");
+    auto in_x = ctx.Input<Tensor>("X");
    auto mid = ctx.Input<Tensor>("MidOut");

    auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

-    const int n = ctx.Attr<int>("n");
-    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
-    const float beta = ctx.Attr<float>("beta");
-    const float k = ctx.Attr<float>("k");
+    auto in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();

-    auto dims = paddle::framework::vectorize<int64_t>(x->dims());
+    LRNMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), in_x, out_grad,
+                                in_x_grad, ctx.InputName("Out"));

-    platform::LRNMKLDNNHandler<T> handler(dims, n, alpha, beta, k, x->format(),
-                                          out_grad->format(), dev_ctx,
-                                          ctx.GetPlace(), ctx.InputName("Out"));

-    auto src_memory = handler.AcquireSrcMemory(x);
+    auto src_memory = handler.AcquireSrcMemory(in_x);
    auto workspace = handler.AcquireBackwardWorkspaceMemory(mid);
    auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
-    auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad);
+    auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad);

    auto lrn_bwd = handler.AcquireBackwardPrimitive();

@@ -125,8 +197,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
                              {MKLDNN_ARG_WORKSPACE, *workspace}});
    astream.wait();

-    x_grad->set_layout(framework::DataLayout::kMKLDNN);
-    x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
+    in_x_grad->set_layout(framework::DataLayout::kMKLDNN);
+    in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
  }
};
}  // namespace operators