From fecdc8616f3ad391e7c0b3bc60a9612b39d7e3df Mon Sep 17 00:00:00 2001
From: Piotr Paturej
Date: Thu, 15 Sep 2022 17:48:08 +0200
Subject: [PATCH 1/3] Convert stack and sgd oneDNN fluid kernels to PHI

---
 .../fluid/operators/mkldnn/stack_mkldnn_op.cc | 146 ------------------
 .../optimizers/mkldnn/sgd_mkldnn_op.cc        |  90 -----------
 paddle/phi/backends/CMakeLists.txt            |   1 +
 paddle/phi/backends/onednn/axpy_handler.cc    | 133 ++++++++++++++++
 paddle/phi/backends/onednn/axpy_handler.h     |  60 +++++++
 paddle/phi/backends/onednn/onednn_helper.h    |  12 +-
 paddle/phi/backends/onednn/onednn_reuse.h     |   8 +-
 paddle/phi/kernels/onednn/sgd_kernel.cc       |  93 +++++++++++
 paddle/phi/kernels/onednn/stack_kernel.cc     | 128 +++++++++++++++
 9 files changed, 425 insertions(+), 246 deletions(-)
 delete mode 100644 paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc
 delete mode 100644 paddle/fluid/operators/optimizers/mkldnn/sgd_mkldnn_op.cc
 create mode 100644 paddle/phi/backends/onednn/axpy_handler.cc
 create mode 100644 paddle/phi/backends/onednn/axpy_handler.h
 create mode 100644 paddle/phi/kernels/onednn/sgd_kernel.cc
 create mode 100644 paddle/phi/kernels/onednn/stack_kernel.cc

diff --git a/paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc
deleted file mode 100644
index 1e546e44fa241..0000000000000
--- a/paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc
+++ /dev/null
@@ -1,146 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/utils.h"
-#include "paddle/fluid/platform/mkldnn_reuse.h"
-namespace paddle {
-namespace operators {
-
-using dnnl::concat;
-using dnnl::memory;
-using dnnl::primitive;
-using dnnl::stream;
-using framework::DataLayout;
-using framework::LoDTensor;
-using framework::Tensor;
-using platform::to_void_cast;
-
-template <typename T>
-class StackMKLDNNHandler
-    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::concat> {
- public:
-  StackMKLDNNHandler(const framework::ExecutionContext& ctx,
-                     const dnnl::engine mkldnn_engine,
-                     const std::vector<const Tensor*>& inputs,
-                     Tensor* output)
-      : platform::MKLDNNHandlerNoCachingT<T, dnnl::concat>(mkldnn_engine,
-                                                           ctx.GetPlace()) {
-    int stack_axis = ctx.Attr<int>("axis");
-
-    int ndims = inputs[0]->dims().size();
-
-    if (stack_axis < 0) {
-      stack_axis = ndims + 1 + stack_axis;  // +1 to match output's ndims
-    }
-
-    // in stack op all inputs must have same dims
-    auto input_dims = phi::vectorize(inputs[0]->dims());
-
-    memory::data_type dt = framework::ToMKLDNNDataType(
-        framework::TransToProtoVarType(inputs[0]->dtype()));
-    std::vector<memory::desc> srcs_md;
-    memory::desc dst_md;
-    MKLDNNMemoryFormat dst_fmt;
-
-    srcs_md.reserve(inputs.size());
-
-    // if stack is not done on last(non existing) axis, then we can optimize
-    // concat primitive by not adding additional dimension, since it causes
-    // wrong output format deduction and suboptimal performance as a result
-    if (stack_axis != ndims) {
-      for (size_t i = 0; i < inputs.size(); ++i) {
-        srcs_md.push_back(inputs[i]->mem_desc());
-      }
-
-      input_dims[stack_axis] *= inputs.size();
-      dst_md = memory::desc(input_dims, dt, MKLDNNMemoryFormat::any);
-    } else {
-      auto extended_input_dims = phi::vectorize(output->dims());
-      extended_input_dims[stack_axis] = 1;
-
-      for (size_t i = 0; i < inputs.size(); ++i) {
-        srcs_md.push_back(inputs[i]->mem_desc().reshape(extended_input_dims));
-      }
-
-      // concat primitive choses suboptimal format tag because it cannot
-      // distinguish between f.e. abcd and abdc if last dim is equal to 1 so
-      // enforcing is needed for better performance
-      dst_fmt = platform::GetPlainMKLDNNFormat(extended_input_dims.size());
-      dst_md = memory::desc(phi::vectorize(output->dims()), dt, dst_fmt);
-    }
-
-    this->AcquireForwardPrimitiveDescriptor(dst_md, stack_axis, srcs_md);
-  }
-
-  // concat oneDNN prim is not having .desc attribute so we cannot use default
-  // AcquireForwardPrimitiveDescriptor
-  void AcquireForwardPrimitiveDescriptor(
-      const memory::desc& dst_md,
-      const int stack_axis,
-      const std::vector<memory::desc>& srcs_md) {
-    this->fwd_pd_.reset(new dnnl::concat::primitive_desc(
-        dst_md, stack_axis, srcs_md, this->engine_));
-  }
-
-  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const Tensor& input, int i) {
-    const T* input_data = input.data<T>();
-    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
-                                            to_void_cast<T>(input_data));
-  }
-};
-
-template <typename T>
-class StackMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    auto& dev_ctx =
-        ctx.template device_context<platform::MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
-
-    auto multi_input = ctx.MultiInput<Tensor>("X");
-
-    Tensor* output = ctx.Output<Tensor>("Y");
-
-    StackMKLDNNHandler<T> handler(ctx, mkldnn_engine, multi_input, output);
-
-    std::vector<std::shared_ptr<dnnl::memory>> srcs;
-    srcs.reserve(multi_input.size());
-
-    auto dst_mem = handler.AcquireDstMemory(output);
-    auto concat_p = handler.AcquireForwardPrimitive();
-
-    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    std::unordered_map<int, dnnl::memory> args;
-    for (size_t i = 0; i < multi_input.size(); ++i) {
-      srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i));
-      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs.at(i))});
-    }
-    args.insert({DNNL_ARG_DST, *dst_mem});
-
-    concat_p->execute(astream, args);
-    astream.wait();
-
-    output->set_mem_desc(
-        dst_mem->get_desc().reshape(phi::vectorize(output->dims())));
-  }
-};
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-REGISTER_OP_KERNEL(stack,
-                   MKLDNN,
-                   ::paddle::platform::CPUPlace,
-                   ops::StackMKLDNNOpKernel<float>);
diff --git a/paddle/fluid/operators/optimizers/mkldnn/sgd_mkldnn_op.cc b/paddle/fluid/operators/optimizers/mkldnn/sgd_mkldnn_op.cc
deleted file mode 100644
index e332972f7576a..0000000000000
--- a/paddle/fluid/operators/optimizers/mkldnn/sgd_mkldnn_op.cc
+++ /dev/null
@@ -1,90 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <memory>
-
-#include "paddle/fluid/operators/mkldnn/axpy_handler.h"
-#include "paddle/fluid/operators/optimizers/sgd_op.h"
-
-namespace pplat = paddle::platform;
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class SGDOneDNNKernel : public SGDOpKernel<phi::CPUContext, T> {
- protected:
-  void dense_param_and_grad_kernel(
-      const framework::ExecutionContext &ctx) const override {
-    VLOG(4) << "[ONEDNN]: sgd_dense_param_kernel";
-    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
-    const auto *param = ctx.Input<framework::Tensor>("Param");
-    auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
-    const auto *grad = ctx.Input<framework::Tensor>("Grad");
-
-    auto *out_data = param_out->mutable_data<T>(ctx.GetPlace());
-    const T *param_data = param->data<T>();
-    const auto *grad_data = grad->data<T>();
-    const auto *lr = learning_rate->data<T>();
-    // Since denese SGD is not in place operation, first copy params to output
-    // tensor and then update it.
-    std::memcpy(out_data, param_data, param->memory_size());
-    OneDNNAXPYHandler<T>(param_out->numel(), -lr[0])(grad_data, out_data);
-  }
-
-  void dense_param_sparse_grad_kernel(
-      const framework::ExecutionContext &ctx) const override {
-    VLOG(4) << "[ONEDNN]: sgd_dense_param_kernel";
-    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
-    auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
-    const auto *grad = ctx.Input<phi::SelectedRows>("Grad");
-
-    const auto &grad_value = grad->value();
-    const auto &grad_rows = grad->rows();
-    const auto grad_height = grad->height();
-    const int64_t grad_val_height = static_cast<int64_t>(grad_rows.size());
-    const auto grad_width = grad_value.numel() / grad_val_height;
-
-    const auto *grad_data = grad_value.data<T>();
-    auto *out_data = param_out->data<T>();
-    const auto *lr = learning_rate->data<T>();
-
-    OneDNNAXPYHandler<T> axpy_handler(grad_width, -lr[0]);
-
-    for (size_t i = 0; i < grad_rows.size(); ++i) {
-      PADDLE_ENFORCE_LT(
-          grad_rows[i],
-          grad_height,
-          pplat::errors::OutOfRange(
-              "Grad rows index value should be less than grad height."
-              "Got [%s], but expected less than [%s]",
-              grad_rows[i],
-              grad_height));
-      const int64_t row = grad_rows[i];
-      const auto *src = grad_data + i * grad_width;
-      auto *dst = out_data + row * grad_width;
-      axpy_handler(src, dst);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_KERNEL(sgd,
-                   MKLDNN,
-                   pplat::CPUPlace,
-                   ops::SGDOneDNNKernel<float>,
-                   ops::SGDOneDNNKernel<pplat::bfloat16>);
diff --git a/paddle/phi/backends/CMakeLists.txt b/paddle/phi/backends/CMakeLists.txt
index 9a26aed5f341b..9bc9573529241 100644
--- a/paddle/phi/backends/CMakeLists.txt
+++ b/paddle/phi/backends/CMakeLists.txt
@@ -21,6 +21,7 @@ endif()
 
 if(WITH_MKLDNN)
   list(APPEND BACKENDS_SRCS onednn/onednn_context.cc)
+  list(APPEND BACKENDS_SRCS onednn/axpy_handler.cc)
   list(APPEND BACKENDS_DEPS mkldnn)
 endif()
diff --git a/paddle/phi/backends/onednn/axpy_handler.cc b/paddle/phi/backends/onednn/axpy_handler.cc
new file mode 100644
index 0000000000000..7304815909d15
--- /dev/null
+++ b/paddle/phi/backends/onednn/axpy_handler.cc
@@ -0,0 +1,133 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/backends/onednn/axpy_handler.h"
+
+#include <cinttypes>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/phi/backends/onednn/onednn_helper.h"
+
+namespace phi {
+namespace funcs {
+
+template <typename T>
+class AXPYHandler {
+ public:
+  AXPYHandler(const dnnl::engine onednn_engine, int n, float alpha) {
+    OneDNNContext::tls().log_lib_version();
+    auto md = dnnl::memory::desc(
+        {n}, OneDNNGetDataType<T>(), dnnl::memory::format_tag::x);
+    src_mem_ = dnnl::memory(md, onednn_engine, DNNL_MEMORY_NONE);
+    dst_mem_ = dnnl::memory(md, onednn_engine, DNNL_MEMORY_NONE);
+    dnnl::primitive_attr reorder_attr;
+    dnnl::post_ops post_operations;
+    if (alpha != 1.f) {
+      std::vector<float> scales(1, alpha);
+      reorder_attr.set_output_scales(0, scales);
+    }
+    post_operations.append_sum(1.0f);
+
+    reorder_attr.set_post_ops(post_operations);
+    reorder_p_ = dnnl::reorder(src_mem_, dst_mem_, reorder_attr);
+  }
+
+  dnnl::memory &AcquireSrcMemory(const T *x) {
+    src_mem_.set_data_handle(to_void_cast<T>(x));
+    return src_mem_;
+  }
+
+  dnnl::memory &AcquireDstMemory(T *y) {
+    dst_mem_.set_data_handle(y);
+    return dst_mem_;
+  }
+
+  const dnnl::reorder &AcquireReorder() { return reorder_p_; }
+
+ private:
+  dnnl::memory src_mem_;
+  dnnl::memory dst_mem_;
+  dnnl::reorder reorder_p_;
+};
+
+template class AXPYHandler<float>;
+template class AXPYHandler<dtype::bfloat16>;
+
+template <typename T>
+static void naive_axpy(int n, T alpha, const T *x, T *y) {
+  while (n-- > 0) {
+    *y += alpha * *x;
+    ++y;
+    ++x;
+  }
+}
+
+template <typename T>
+class OneDNNAXPYHandler<T>::Impl {
+ public:
+  Impl(int64_t n, T alpha, const dnnl::engine onednn_engine);
+  void operator()(const T *x, T *y);
+
+ private:
+  std::unique_ptr<AXPYHandler<T>> handler_;
+  int64_t n_;
+  T alpha_;
+};
+
+template <typename T>
+OneDNNAXPYHandler<T>::Impl::Impl(int64_t n,
+                                 T alpha,
+                                 const dnnl::engine onednn_engine)
+    : n_{n}, alpha_{alpha} {
+  handler_ = std::make_unique<AXPYHandler<T>>(
+      onednn_engine, n, static_cast<float>(alpha));
+}
+
+template <typename T>
+void OneDNNAXPYHandler<T>::Impl::operator()(const T *x, T *y) {
+  if (this->n_ < 100) {
+    naive_axpy(this->n_, this->alpha_, x, y);
+    return;
+  }
+
+  auto &reorder_src_mem_p = handler_->AcquireSrcMemory(x);
+  auto &reorder_dst_mem_p = handler_->AcquireDstMemory(y);
+  auto reorder_p = handler_->AcquireReorder();
+  auto &astream = OneDNNContext::tls().get_stream();
+  reorder_p.execute(astream, reorder_src_mem_p, reorder_dst_mem_p);
+  astream.wait();
+}
+
+template <typename T>
+OneDNNAXPYHandler<T>::OneDNNAXPYHandler(int64_t n,
+                                        T alpha,
+                                        const dnnl::engine onednn_engine)
+    : pimpl_{new Impl{n, alpha, onednn_engine},
+             [](Impl *impl) { delete impl; }} {
+  VLOG(4) << "[OneDNN] OneDNNAXPYHandler<" << typeid(T).name() << ">, "
+          << "n: " << n << ", alpha: " << alpha;
+}
+
+template <typename T>
+void OneDNNAXPYHandler<T>::operator()(const T *x, T *y) {
+  pimpl_->operator()(x, y);
+}
+
+template class OneDNNAXPYHandler<float>;
+template class OneDNNAXPYHandler<dtype::bfloat16>;
+
+}  // namespace funcs
+}  // namespace phi
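Note on the file above: OneDNNAXPYHandler computes y = alpha * x + y with a single oneDNN reorder. Alpha is applied through the reorder's output scale and the previous contents of y are added back by the sum post-op; sizes below 100 elements fall back to the scalar naive_axpy loop. A minimal usage sketch follows; it is illustrative only and not part of the patch, and it assumes the caller can reach the engine via OneDNNContext::GetEngine(), as the kernels below do.

#include <cstdint>

#include "paddle/phi/backends/onednn/axpy_handler.h"

// y[i] += alpha * x[i] for n elements; the handler can be reused across
// calls with the same n and alpha.
void axpy_sketch(const phi::OneDNNContext& dev_ctx,
                 const float* x,
                 float* y,
                 std::int64_t n) {
  phi::funcs::OneDNNAXPYHandler<float> axpy(n, /*alpha=*/-0.5f,
                                            dev_ctx.GetEngine());
  axpy(x, y);  // executes reorder (or naive loop) and waits on the stream
}
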
diff --git a/paddle/phi/backends/onednn/axpy_handler.h b/paddle/phi/backends/onednn/axpy_handler.h
new file mode 100644
index 0000000000000..81c47689de92f
--- /dev/null
+++ b/paddle/phi/backends/onednn/axpy_handler.h
@@ -0,0 +1,60 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#pragma once
+
+#include <memory>
+#include "dnnl.hpp"  // NOLINT
+
+namespace phi {
+namespace funcs {
+///
+/// @brief Helper class for AXPY execution using oneDNN library.
+///
+/// @tparam T Data type.
+///
+template <typename T>
+class OneDNNAXPYHandler {
+ public:
+  OneDNNAXPYHandler(OneDNNAXPYHandler&) = delete;
+  OneDNNAXPYHandler(OneDNNAXPYHandler&&) = delete;
+  OneDNNAXPYHandler& operator=(OneDNNAXPYHandler&) = delete;
+  OneDNNAXPYHandler& operator=(OneDNNAXPYHandler&&) = delete;
+  ///
+  /// @brief Constructor.
+  ///
+  /// @param[in] n              The number of elements in tensor (assumed 1D
+  ///                           tensor).
+  /// @param[in] alpha          The alpha coefficient.
+  /// @param[in] onednn_engine  The oneDNN engine.
+  ///
+  OneDNNAXPYHandler(int64_t n, T alpha, dnnl::engine onednn_engine);
+  ///
+  /// @brief Executes AXPY.
+  ///
+  /// @param[in]  x  The pointer to input X tensor data.
+  /// @param[out] y  The pointer to output Y tensor data.
+  ///
+  void operator()(const T* x, T* y);
+
+ private:
+  OneDNNAXPYHandler() = delete;
+  // (arogowie-intel) Private implementation idiom to hide dependency
+  // on oneDNN headers.
+  class Impl;
+  // We need a custom deleter, since the compiler is unable to
+  // parameterize the default deleter for an incomplete type.
+  std::unique_ptr<Impl, void (*)(Impl*)> pimpl_;
+};
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/phi/backends/onednn/onednn_helper.h b/paddle/phi/backends/onednn/onednn_helper.h
index aeaecf7491e61..e91e02282ccc0 100644
--- a/paddle/phi/backends/onednn/onednn_helper.h
+++ b/paddle/phi/backends/onednn/onednn_helper.h
@@ -96,29 +96,29 @@ inline dnnl::memory::format_tag GetPlainOneDNNFormat(int tensor_rank) {
 }
 
 template <typename Type>
-dnnl::memory::data_type oneDNNGetDataType() {
+dnnl::memory::data_type OneDNNGetDataType() {
   return dnnl::memory::data_type::undef;
 }
 
 template <>
-inline dnnl::memory::data_type oneDNNGetDataType<float>() {
+inline dnnl::memory::data_type OneDNNGetDataType<float>() {
   return dnnl::memory::data_type::f32;
 }
 template <>
-inline dnnl::memory::data_type oneDNNGetDataType<int32_t>() {
+inline dnnl::memory::data_type OneDNNGetDataType<int32_t>() {
   return dnnl::memory::data_type::s32;
 }
 template <>
-inline dnnl::memory::data_type oneDNNGetDataType<int8_t>() {
+inline dnnl::memory::data_type OneDNNGetDataType<int8_t>() {
   return dnnl::memory::data_type::s8;
 }
 template <>
-inline dnnl::memory::data_type oneDNNGetDataType<uint8_t>() {
+inline dnnl::memory::data_type OneDNNGetDataType<uint8_t>() {
   return dnnl::memory::data_type::u8;
 }
 template <>
-inline dnnl::memory::data_type oneDNNGetDataType<dtype::bfloat16>() {
+inline dnnl::memory::data_type OneDNNGetDataType<dtype::bfloat16>() {
   return dnnl::memory::data_type::bf16;
 }
 
diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h
index 66376dd883543..6b806748d0ef8 100644
--- a/paddle/phi/backends/onednn/onednn_reuse.h
+++ b/paddle/phi/backends/onednn/onednn_reuse.h
@@ -834,7 +834,7 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
       src0_md = src0_md.reshape(dims0_ex);
     }
     const auto dst_md =
-        memory::desc(dst_tz, oneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
+        memory::desc(dst_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
 
     auto attributes =
         CreateAttributes(algo, scale_x, scale_y, scale_out, post_ops);
@@ -905,7 +905,7 @@ class BroadcastDataOneDNNHandler
       : OneDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
     const auto src0_tz = vectorize(out->dims());
     const auto src0_md = dnnl::memory::desc(
-        src0_tz, oneDNNGetDataType<T>(), GetPlainOneDNNFormat(src0_tz.size()));
+        src0_tz, OneDNNGetDataType<T>(), GetPlainOneDNNFormat(src0_tz.size()));
     const auto src1_md = x->mem_desc().reshape(extended_x_dims);
 
     dnnl::primitive_attr attributes;
@@ -940,7 +940,7 @@ class ReductionOneDNNHandler
                          const dnnl::primitive_attr& attrs = NULL)
       : OneDNNHandlerNoCachingT<T, dnnl::reduction>(engine, cpu_place) {
     const auto out_md = memory::desc(
-        out_tz, oneDNNGetDataType<T>(), dnnl::memory::format_tag::any);
+        out_tz, OneDNNGetDataType<T>(), dnnl::memory::format_tag::any);
 
     if (attrs)
       this->AcquireForwardPrimitiveDescriptor(
@@ -1144,7 +1144,7 @@ class PoolingOneDNNHandler
     const auto dt = ToOneDNNDataType(in_x->dtype());
     auto dst_md = dnnl::memory::desc(diff_dst_tz, dt, OneDNNMemoryFormat::any);
     auto diff_src_md = dnnl::memory::desc(
-        diff_src_tz, oneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
+        diff_src_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
 
     auto onednn_paddings = ToOneDNNPadding(copied_paddings);
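The hunks above only rename the helper (oneDNNGetDataType -> OneDNNGetDataType, matching PHI naming); behavior is unchanged. Call sites select the oneDNN data type through the template argument, as in this illustrative sketch (not part of the patch):

#include "paddle/phi/backends/onednn/onednn_helper.h"

// Resolves at compile time via the specializations above.
auto f32 = phi::funcs::OneDNNGetDataType<float>();                   // f32
auto bf16 = phi::funcs::OneDNNGetDataType<phi::dtype::bfloat16>();   // bf16
auto none = phi::funcs::OneDNNGetDataType<double>();  // undef fallback
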
diff --git a/paddle/phi/kernels/onednn/sgd_kernel.cc b/paddle/phi/kernels/onednn/sgd_kernel.cc
new file mode 100644
index 0000000000000..4750da425f40f
--- /dev/null
+++ b/paddle/phi/kernels/onednn/sgd_kernel.cc
@@ -0,0 +1,93 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sgd_kernel.h"
+
+#include "paddle/phi/backends/onednn/axpy_handler.h"
+#include "paddle/phi/backends/onednn/onednn_reuse.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SGDDenseKernel(const Context& dev_ctx,
+                    const DenseTensor& param,
+                    const DenseTensor& learning_rate,
+                    const DenseTensor& grad,
+                    const paddle::optional<DenseTensor>& master_param,
+                    bool multi_precision,
+                    DenseTensor* param_out,
+                    DenseTensor* master_param_out) {
+  auto* out_data = param_out->mutable_data<T>(dev_ctx.GetPlace());
+  const T* param_data = param.data<T>();
+  const auto* grad_data = grad.data<T>();
+  const auto* lr = learning_rate.data<T>();
+  // Since dense SGD is not an in-place operation, first copy params to the
+  // output tensor and then update it.
+  std::memcpy(out_data, param_data, param.memory_size());
+  funcs::OneDNNAXPYHandler<T>(param_out->numel(), -lr[0], dev_ctx.GetEngine())(
+      grad_data, out_data);
+}
+
+template <typename T, typename Context>
+void SGDDenseParamSparseGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    const paddle::optional<DenseTensor>& master_param,
+    bool multi_precision,
+    DenseTensor* param_out,
+    DenseTensor* master_param_out) {
+  const auto& grad_value = grad.value();
+  const auto& grad_rows = grad.rows();
+  const auto grad_height = grad.height();
+  const int64_t grad_val_height = static_cast<int64_t>(grad_rows.size());
+  const auto grad_width = grad_value.numel() / grad_val_height;
+
+  const auto* grad_data = grad_value.data<T>();
+  auto* out_data = param_out->data<T>();
+  const auto* lr = learning_rate.data<T>();
+
+  funcs::OneDNNAXPYHandler<T> axpy_handler(
+      grad_width, -lr[0], dev_ctx.GetEngine());
+
+  for (size_t i = 0; i < grad_rows.size(); ++i) {
+    PADDLE_ENFORCE_LT(
+        grad_rows[i],
+        grad_height,
+        errors::OutOfRange("Grad rows index value should be less than grad "
+                           "height. Got [%s], but expected less than [%s]",
+                           grad_rows[i],
+                           grad_height));
+    const int64_t row = grad_rows[i];
+    const auto* src = grad_data + i * grad_width;
+    auto* dst = out_data + row * grad_width;
+    axpy_handler(src, dst);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    sgd, OneDNN, ALL_LAYOUT, phi::SGDDenseKernel, float, phi::dtype::bfloat16) {
+}
+
+PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
+                   OneDNN,
+                   ALL_LAYOUT,
+                   phi::SGDDenseParamSparseGradKernel,
+                   float,
+                   phi::dtype::bfloat16) {}
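Both kernels above reduce SGD to AXPY with alpha = -lr: the dense path copies param into param_out and adds -lr * grad over the whole buffer, while the sparse path applies the same update per selected row. A reference sketch of the dense semantics, independent of oneDNN and illustrative only:

#include <cstddef>

// Reference semantics of the dense kernel: param_out = param - lr * grad.
// Mirrors naive_axpy from axpy_handler.cc (after the initial memcpy) with
// alpha = -lr.
void sgd_dense_reference(const float* param,
                         const float* grad,
                         float lr,
                         float* param_out,
                         std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    param_out[i] = param[i] - lr * grad[i];
  }
}
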
diff --git a/paddle/phi/kernels/onednn/stack_kernel.cc b/paddle/phi/kernels/onednn/stack_kernel.cc
new file mode 100644
index 0000000000000..a39f48bdc3cb2
--- /dev/null
+++ b/paddle/phi/kernels/onednn/stack_kernel.cc
@@ -0,0 +1,128 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/stack_kernel.h"
+
+#include "paddle/phi/backends/onednn/onednn_reuse.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+namespace funcs {
+template <typename T>
+class StackOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::concat> {
+ public:
+  StackOneDNNHandler(const Place& cpu_place,
+                     int stack_axis,
+                     const dnnl::engine onednn_engine,
+                     const std::vector<const DenseTensor*>& inputs,
+                     DenseTensor* output)
+      : funcs::OneDNNHandlerNoCachingT<T, dnnl::concat>(onednn_engine,
+                                                        cpu_place) {
+    int ndims = inputs[0]->dims().size();
+
+    if (stack_axis < 0) {
+      stack_axis = ndims + 1 + stack_axis;  // +1 to match output's ndims
+    }
+
+    // in the stack op all inputs must have the same dims
+    auto input_dims = vectorize(inputs[0]->dims());
+
+    dnnl::memory::data_type dt = ToOneDNNDataType(inputs[0]->dtype());
+    std::vector<dnnl::memory::desc> srcs_md;
+    dnnl::memory::desc dst_md;
+    OneDNNMemoryFormat dst_fmt;
+
+    srcs_md.reserve(inputs.size());
+
+    // if stacking is not done on the last (non-existing) axis, then we can
+    // optimize the concat primitive by not adding an additional dimension,
+    // since it causes wrong output format deduction and suboptimal
+    // performance as a result
+    if (stack_axis != ndims) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        srcs_md.push_back(inputs[i]->mem_desc());
+      }
+
+      input_dims[stack_axis] *= inputs.size();
+      dst_md = dnnl::memory::desc(input_dims, dt, OneDNNMemoryFormat::any);
+    } else {
+      auto extended_input_dims = vectorize(output->dims());
+      extended_input_dims[stack_axis] = 1;
+
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        srcs_md.push_back(inputs[i]->mem_desc().reshape(extended_input_dims));
+      }
+
+      // the concat primitive chooses a suboptimal format tag because it
+      // cannot distinguish between e.g. abcd and abdc if the last dim is
+      // equal to 1, so enforcing a plain format is needed for better
+      // performance
+      dst_fmt = GetPlainOneDNNFormat(extended_input_dims.size());
+      dst_md = dnnl::memory::desc(vectorize(output->dims()), dt, dst_fmt);
+    }
+
+    this->AcquireForwardPrimitiveDescriptor(dst_md, stack_axis, srcs_md);
+  }
+
+  // the concat oneDNN primitive does not have a .desc attribute, so we
+  // cannot use the default AcquireForwardPrimitiveDescriptor
+  void AcquireForwardPrimitiveDescriptor(
+      const memory::desc& dst_md,
+      const int stack_axis,
+      const std::vector<memory::desc>& srcs_md) {
+    this->fwd_pd_.reset(new dnnl::concat::primitive_desc(
+        dst_md, stack_axis, srcs_md, this->engine_));
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor& input,
+                                                 int i) {
+    const T* input_data = input.data<T>();
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
+                                            to_void_cast<T>(input_data));
+  }
+};
+}  // namespace funcs
+
+template <typename T, typename Context>
+void StackKernel(const Context& dev_ctx,
+                 const std::vector<const DenseTensor*>& multi_input,
+                 int axis,
+                 DenseTensor* output) {
+  const auto& onednn_engine = dev_ctx.GetEngine();
+
+  funcs::StackOneDNNHandler<T> handler(
+      dev_ctx.GetPlace(), axis, onednn_engine, multi_input, output);
+
+  std::vector<std::shared_ptr<dnnl::memory>> srcs;
+  srcs.reserve(multi_input.size());
+
+  auto dst_mem = handler.AcquireDstMemory(output);
+  auto concat_p = handler.AcquireForwardPrimitive();
+
+  auto& astream = OneDNNContext::tls().get_stream();
+  std::unordered_map<int, dnnl::memory> args;
+  for (size_t i = 0; i < multi_input.size(); ++i) {
+    srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i));
+    args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs.at(i))});
+  }
+  args.insert({DNNL_ARG_DST, *dst_mem});
+
+  concat_p->execute(astream, args);
+  astream.wait();
+
+  output->set_mem_desc(dst_mem->get_desc().reshape(vectorize(output->dims())));
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(stack, OneDNN, ALL_LAYOUT, phi::StackKernel, float) {}
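StackKernel maps stack onto concat in two ways: for an axis that already exists in the inputs it concatenates along that axis and fixes the dims with the final mem_desc reshape; for the appended last axis it reshapes every input with a trailing 1 first. A worked shape example, illustrative only:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Stacking 4 tensors of shape {2, 3}.
  const std::vector<std::int64_t> input_dims = {2, 3};
  const std::int64_t n = 4;

  // axis == 0 (existing axis): concat {2, 3} x 4 along dim 0 -> {8, 3};
  // the kernel then reshapes the result to the true output dims {4, 2, 3},
  // which describes the same bytes in the same order.
  std::vector<std::int64_t> concat_dims = input_dims;
  concat_dims[0] *= n;
  assert(concat_dims[0] * concat_dims[1] == 4 * 2 * 3);

  // axis == 2 (appended axis): view each input as {2, 3, 1} and concat
  // along dim 2 -> {2, 3, 4}, which already equals the output dims.
  const std::vector<std::int64_t> extended = {2, 3, 1};
  const std::vector<std::int64_t> out_dims = {2, 3, n};
  assert(extended[0] * extended[1] * n ==
         out_dims[0] * out_dims[1] * out_dims[2]);
  return 0;
}
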
From 63c3c965455fa8112175291b5a96467feef5ccc3 Mon Sep 17 00:00:00 2001
From: Piotr Paturej
Date: Tue, 20 Sep 2022 09:17:31 +0200
Subject: [PATCH 2/3] Change mutable_data to Alloc

---
 paddle/phi/kernels/onednn/sgd_kernel.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/phi/kernels/onednn/sgd_kernel.cc b/paddle/phi/kernels/onednn/sgd_kernel.cc
index 4750da425f40f..bbb02204105d1 100644
--- a/paddle/phi/kernels/onednn/sgd_kernel.cc
+++ b/paddle/phi/kernels/onednn/sgd_kernel.cc
@@ -29,7 +29,7 @@ void SGDDenseKernel(const Context& dev_ctx,
                     bool multi_precision,
                     DenseTensor* param_out,
                     DenseTensor* master_param_out) {
-  auto* out_data = param_out->mutable_data<T>(dev_ctx.GetPlace());
+  auto* out_data = dev_ctx.template Alloc<T>(param_out);
   const T* param_data = param.data<T>();
   const auto* grad_data = grad.data<T>();
   const auto* lr = learning_rate.data<T>();

From b34bf6de068a252c0530a509d3b8585be17b3057 Mon Sep 17 00:00:00 2001
From: Piotr Paturej
Date: Wed, 21 Sep 2022 15:32:24 +0200
Subject: [PATCH 3/3] Refactor licenses

---
 paddle/phi/backends/onednn/axpy_handler.cc | 28 +++++++++++-----------
 paddle/phi/backends/onednn/axpy_handler.h  | 25 +++++++++----------
 2 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/paddle/phi/backends/onednn/axpy_handler.cc b/paddle/phi/backends/onednn/axpy_handler.cc
index 7304815909d15..df61948d62215 100644
--- a/paddle/phi/backends/onednn/axpy_handler.cc
+++ b/paddle/phi/backends/onednn/axpy_handler.cc
@@ -1,16 +1,16 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
 #include "paddle/phi/backends/onednn/axpy_handler.h"
 
@@ -127,7 +127,7 @@ void OneDNNAXPYHandler<T>::operator()(const T *x, T *y) {
 }
 
 template class OneDNNAXPYHandler<float>;
-template class OneDNNAXPYHandler<dtype::bfloat16>;
+template class OneDNNAXPYHandler<phi::dtype::bfloat16>;
 
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/backends/onednn/axpy_handler.h b/paddle/phi/backends/onednn/axpy_handler.h
index 81c47689de92f..dd9a8108f59b0 100644
--- a/paddle/phi/backends/onednn/axpy_handler.h
+++ b/paddle/phi/backends/onednn/axpy_handler.h
@@ -1,16 +1,17 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
 #pragma once
 
 #include <memory>