forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added elementwise_sub_mkldnn operator (PaddlePaddle#35662)
* Add elementwise_sub_mkldnn_op without grad * Add test to static_mode_white_list * Refactor code, change license years * Remove invalid grad implementation * Fix element_wise_sub_op test * Fix CI Approval error * Remove unnecessary EltwiseSubMKLDNNGradKernel class * Fix CI Approval 2 * Fix CI Approval 3 * Fix CI Approval Attempt #4 * Fix CI Approve Attempt #5 * Fix CI Approval Attempt #6 * Fix CI Approval Attemt #7 * Change test names containing add to sub * Fix old tests testing add instead of sub * Copy grad implementation from elementwise_add_mkldnn * CI test fix attempt * Revert "CI test fix attempt" This reverts commit c647cacf41e6a87c715385a185de5cbf65fc8900. * Fix CI attempt 2 * Fix elementwise_sub tests, temporary mkldnn broadcast test disable * Add working implementation of elementwise_sub grad * Fix build errors caused by pull * Fix format error * Fix format error 2 * Disable elementwise_sub_mkldnn test on GPU * Apply fix for paddle.fluid import * Revert changes of test_elementwise_sub and Fix mkldnn test * Revert "Apply fix for paddle.fluid import" This reverts commit fc3b122. * fix bug of module 'paddle' has no attribute 'fluid' for python3.6 (PaddlePaddle#35862) * Add changes suggested by reviewers * Change @unittest.skipIf... to @OpTestTool.skip_if_not_cpu_bf16() to satisfy Approval CI * Remove check_dygraph=False to satisify CI Approval Co-authored-by: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
- Loading branch information
1 parent
1691dc7
commit 787273e
Showing
4 changed files
with
380 additions
and
5 deletions.
There are no files selected for viewing
132 changes: 132 additions & 0 deletions
132
paddle/fluid/operators/elementwise/mkldnn/elementwise_sub_mkldnn_op.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
|
||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" | ||
namespace paddle { | ||
namespace framework { | ||
class ExecutionContext; | ||
} // namespace framework | ||
namespace platform { | ||
class CPUDeviceContext; | ||
struct CPUPlace; | ||
} // namespace platform | ||
} // namespace paddle | ||
|
||
namespace paddle { | ||
namespace operators { | ||
template <typename T> | ||
class EltwiseSubMKLDNNGradKernel : public ElemwiseGradKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
ElemwiseGradKernel<T>::Compute(ctx); | ||
using Tensor = framework::Tensor; | ||
|
||
auto& dev_ctx = | ||
ctx.template device_context<platform::MKLDNNDeviceContext>(); | ||
const auto& onednn_engine = dev_ctx.GetEngine(); | ||
|
||
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); | ||
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); | ||
|
||
auto tz = framework::vectorize<int64_t>(dout->dims()); | ||
memory::data_type dout_type = framework::ToMKLDNNDataType(dout->type()); | ||
platform::ReorderMKLDNNHandler handler(tz, dout->type(), dout_type, | ||
onednn_engine); | ||
|
||
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); | ||
auto reorder_src_memory_p = handler.AcquireSrcMemory( | ||
dout->format(), platform::to_void_cast(dout->data<T>())); | ||
|
||
if (dx) { | ||
auto reorder_dst_memory_p = | ||
handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace()); | ||
auto reorder_p = | ||
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); | ||
platform::RecordEvent record_reorder("int_reorder", | ||
platform::EventRole::kUniqueOp); | ||
|
||
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); | ||
astream.wait(); | ||
|
||
dx->set_layout(DataLayout::kMKLDNN); | ||
dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); | ||
} | ||
|
||
if (dy) { | ||
// Direct copy | ||
if (dout->dims() == dy->dims()) { | ||
auto reorder_dst_memory_p = | ||
handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace()); | ||
|
||
dnnl::primitive_attr reorder_attr; | ||
std::vector<float> scales = {-1}; | ||
reorder_attr.set_output_scales(0, scales); | ||
auto reorder_p = std::make_shared<dnnl::reorder>( | ||
*(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr); | ||
platform::RecordEvent record_reorder("int_reorder", | ||
platform::EventRole::kUniqueOp); | ||
reorder_p->execute(astream, *reorder_src_memory_p, | ||
*reorder_dst_memory_p); | ||
astream.wait(); | ||
|
||
dy->set_layout(DataLayout::kMKLDNN); | ||
dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); | ||
} else { | ||
// Broadcasting | ||
|
||
dnnl::post_ops po; | ||
po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0); | ||
dnnl::primitive_attr attr; | ||
attr.set_post_ops(po); | ||
|
||
platform::ReductionMKLDNNHandler<T> handler_sum( | ||
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine, | ||
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy), attr); | ||
|
||
auto dy_memory_p = handler_sum.AcquireDstMemory(dy); | ||
auto reduction_p = handler_sum.AcquireForwardPrimitive(); | ||
|
||
reduction_p->execute(astream, { | ||
{DNNL_ARG_SRC, *reorder_src_memory_p}, | ||
{DNNL_ARG_DST, *dy_memory_p}, | ||
}); | ||
astream.wait(); | ||
|
||
dy->set_layout(DataLayout::kMKLDNN); | ||
dy->set_format( | ||
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape( | ||
paddle::framework::vectorize<int64_t>(dy->dims())))); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
|
||
REGISTER_OP_KERNEL( | ||
elementwise_sub, MKLDNN, paddle::platform::CPUPlace, | ||
ops::EltwiseMKLDNNKernel<float, dnnl::algorithm::binary_sub>, | ||
ops::EltwiseMKLDNNKernel<paddle::platform::bfloat16, | ||
dnnl::algorithm::binary_sub>, | ||
ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_sub>, | ||
ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_sub>) | ||
|
||
REGISTER_OP_KERNEL(elementwise_sub_grad, MKLDNN, ::paddle::platform::CPUPlace, | ||
ops::EltwiseSubMKLDNNGradKernel<paddle::platform::bfloat16>, | ||
ops::EltwiseSubMKLDNNGradKernel<float>) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.