From 0da0bdfc15bf23d604cb3a3cea96c2465bb48728 Mon Sep 17 00:00:00 2001 From: co63oc Date: Wed, 5 Jun 2024 16:36:03 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=206th=20Fundable=20Projects?= =?UTF-8?q?=203=20No.170=E3=80=91fusion=5Flstm=20(#64871)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix * Fix * Fix * Fix * Fix * Fix --- paddle/fluid/operators/fused/CMakeLists.txt | 3 - .../fluid/operators/fused/fusion_lstm_op.cc | 575 ------------------ paddle/fluid/operators/fused/fusion_lstm_op.h | 38 -- .../fused/onednn/fusion_lstm_onednn_op.cc | 476 --------------- .../pir/dialect/op_generator/ops_api_gen.py | 1 + .../fluid/pir/dialect/operator/utils/utils.cc | 1 - .../kernels/fusion/cpu/fusion_lstm_kernel.cc | 443 ++++++++++++++ .../fusion/onednn/fusion_lstm_kernel.cc | 573 +++++++++++++++++ .../fusion}/onednn/fusion_rnn_onednn.h | 47 +- paddle/phi/ops/yaml/fused_ops.yaml | 11 + .../ops/yaml/inconsistent/onednn_static.yaml | 11 - 11 files changed, 1051 insertions(+), 1128 deletions(-) delete mode 100644 paddle/fluid/operators/fused/fusion_lstm_op.cc delete mode 100644 paddle/fluid/operators/fused/fusion_lstm_op.h delete mode 100644 paddle/fluid/operators/fused/onednn/fusion_lstm_onednn_op.cc create mode 100644 paddle/phi/kernels/fusion/cpu/fusion_lstm_kernel.cc create mode 100644 paddle/phi/kernels/fusion/onednn/fusion_lstm_kernel.cc rename paddle/{fluid/operators/fused => phi/kernels/fusion}/onednn/fusion_rnn_onednn.h (87%) diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 8489b9b6c0e28..c15c161a60999 100755 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -6,7 +6,6 @@ endif() register_operators( EXCLUDES fused_bn_activation_op - fusion_lstm_op fused_bn_add_activation_op fused_attention_op fused_transformer_op @@ -18,8 +17,6 @@ register_operators( fused_gate_attention_op resnet_basic_block_op) -op_library(fusion_lstm_op) - if(WITH_XPU) op_library(resnet_basic_block_op) op_library(resnet_unit_op) diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc deleted file mode 100644 index e84c0758402b0..0000000000000 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ /dev/null @@ -1,575 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/fused/fusion_lstm_op.h" - -#include - -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/fc_functor.h" -#include "paddle/phi/kernels/funcs/jit/kernels.h" -#include "paddle/phi/kernels/funcs/sequence2batch.h" - -namespace paddle::operators { - -void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_lstm"); - OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_lstm"); - OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_lstm"); - OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "fusion_lstm"); - OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_lstm"); - OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_lstm"); - OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "fusion_lstm"); - - auto x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ(x_dims.size(), - 2, - phi::errors::InvalidArgument( - "Input(X)'s rank must be 2, but received x's rank " - "is:%d, x dim is:[%s]", - x_dims.size(), - x_dims)); - - if (ctx->HasInput("H0")) { - OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "fusion_lstm"); - auto h_dims = ctx->GetInputDim("H0"); - auto c_dims = ctx->GetInputDim("C0"); - PADDLE_ENFORCE_EQ(h_dims, - c_dims, - phi::errors::InvalidArgument( - "The dimension of Input(H0) and Input(C0) should be " - "same, but received h0 dims is:[%s], c0 dims is:[%s]", - h_dims, - c_dims)); - } - - auto wx_dims = ctx->GetInputDim("WeightX"); - PADDLE_ENFORCE_EQ(wx_dims.size(), - 2, - phi::errors::InvalidArgument( - "The rank of Input(WeightX) should be 2, but received " - "WeightX's rank is:%d, WeightX dim is:[%s]", - wx_dims.size(), - wx_dims)); - PADDLE_ENFORCE_EQ(wx_dims[0], - x_dims[1], - phi::errors::InvalidArgument( - "The first dimension of Input(WeightX) " - "should equal to second dimension of Input(X), but " - "received WeightX first dim is:%d, X second dim is:%d", - wx_dims[0], - x_dims[1])); - - int frame_size = static_cast(wx_dims[1] / 4); - auto wh_dims = ctx->GetInputDim("WeightH"); - - PADDLE_ENFORCE_EQ(wh_dims.size(), - 2, - phi::errors::InvalidArgument( - "The rank of Input(WeightH) should be 2, but received " - "WeightH rank is:%d, WeightH dim is:[%s]", - wh_dims.size(), - wh_dims)); - PADDLE_ENFORCE_EQ(wh_dims[0], - frame_size, - phi::errors::InvalidArgument( - "The first dimension of Input(WeightH) " - "should equal to frame size, but received WeightH " - "first dim is:%d, frame size is:%d.", - wh_dims[0], - frame_size)); - - PADDLE_ENFORCE_EQ(wh_dims[1], - 4 * frame_size, - phi::errors::InvalidArgument( - "The second dimension of Input(WeightH) " - "should equal to 4 * frame_size, but received WeightH " - "second dimension is:%d, frame size is:%d.", - wh_dims[1], - frame_size)); - - auto b_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(b_dims.size(), - 2, - phi::errors::InvalidArgument( - "The rank of Input(Bias) should be 2, but received " - "Bias rank is:%d, Bias dim is:[%s]", - b_dims.size(), - b_dims)); - PADDLE_ENFORCE_EQ(b_dims[0], - 1, - phi::errors::InvalidArgument( - "The first dimension of Input(Bias) should be 1, but " - "received Bias's dimension is:[%s]", - b_dims)); - - if (ctx->Attrs().Get("use_peepholes")) { - PADDLE_ENFORCE_EQ(b_dims[1], - 7 * frame_size, - phi::errors::InvalidArgument( - "The second dimension of Input(Bias) should be " - "7 * %d if enable peepholes connection, but received " - "Bias dim is:[%s]", - frame_size, - b_dims)); - ctx->SetOutputDim("CheckedCell", {2, frame_size}); - } else { - PADDLE_ENFORCE_EQ( - b_dims[1], - 4 * frame_size, - phi::errors::InvalidArgument( - "The second dimension of Input(Bias) should be " - "4 * %d if disable peepholes, but received Bias dim is:[%s]", - frame_size, - b_dims)); - } - - phi::DDim out_dims({x_dims[0], frame_size}); - ctx->SetOutputDim("Hidden", out_dims); - ctx->SetOutputDim("Cell", out_dims); - ctx->ShareLoD("X", "Hidden"); - ctx->ShareLoD("X", "Cell"); - int xx_width = 0; - if (ctx->Attrs().Get("use_seq")) { - xx_width = static_cast(wx_dims[1]); - } else { - xx_width = - static_cast(x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]); - - OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), - "Output", - "BatchedInput", - "fusion_lstm"); - OP_INOUT_CHECK(ctx->HasOutput("BatchedHidden"), - "Output", - "BatchedHidden", - "fusion_lstm"); - OP_INOUT_CHECK( - ctx->HasOutput("BatchedCell"), "Output", "BatchedCell", "fusion_lstm"); - OP_INOUT_CHECK( - ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0", "fusion_lstm"); - OP_INOUT_CHECK( - ctx->HasOutput("ReorderedC0"), "Output", "ReorderedC0", "fusion_lstm"); - - ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); - ctx->SetOutputDim("BatchedHidden", out_dims); - ctx->SetOutputDim("BatchedCell", out_dims); - } - ctx->SetOutputDim("XX", {x_dims[0], xx_width}); - ctx->ShareLoD("X", "XX"); -} - -phi::KernelKey FusionLSTMOp::GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return phi::KernelKey(data_type, ctx.GetPlace()); -} - -void FusionLSTMOpMaker::Make() { - AddInput( - "X", - "(phi::DenseTensor) the input is a LodTensor, which support " - "variable-time length input sequence. The underlying tensor in " - "this phi::DenseTensor is a matrix with shape (T X M), where T is the " - "total time steps in this mini-batch, M is the dim size of x."); - AddInput("WeightX", - "(phi::DenseTensor) the learnable weights of X." - " - The shape is (M x 4D), where M is the dim size of x, D is the " - "hidden size. " - " - Weight = {W_cx, W_ix, W_fx, W_ox}"); - AddInput( - "WeightH", - "(phi::DenseTensor) same as LSTMOp, the learnable hidden-hidden weights." - " - The shape is (D x 4D), where D is the hidden size. " - " - Weight = {W_ch, W_ih, W_fh, W_oh}"); - AddInput("Bias", - "(phi::DenseTensor) the learnable weights. Almost same as LSTMOp" - "Note: we should add the fc bias into this (1x4D) in bias." - "input-hidden bias weight and peephole connections weight if " - "setting `use_peepholes` True. " - "1. `use_peepholes = False` " - " - The shape is (1 x 4D). " - " - Bias = {b_c, b_i, b_f, b_o}." - "2. `use_peepholes = True` " - " - The shape is (1 x 7D). " - " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); - AddInput("H0", - "(phi::DenseTensor, optional) (same as LSTMOp) the initial hidden " - "state is an " - "optional " - "input. This is a tensor with shape (N x D), where N is the " - "batch size and D is the hidden size.") - .AsDispensable(); - AddInput("C0", - "(phi::DenseTensor, optional) (same as LSTMOp) (the initial cell " - "state is an " - "optional " - "input. This is a tensor with shape (N x D), where N is the " - "batch size. `H0` and `C0` can be NULL but only at the same time.") - .AsDispensable(); - AddOutput( - "Hidden", - "(phi::DenseTensor) (same as LSTMOp) the hidden state of LSTM operator. " - "The shape is (T x D), and lod is the same with the `Input`."); - AddOutput( - "Cell", - "(phi::DenseTensor) (same as LSTMOp) the cell state of LSTM operator. " - "The shape is (T x D), and lod is the same with the `Input`."); - AddOutput("XX", - "(phi::DenseTensor) the result after X * WeightX (size is T x 4D)" - " or batched_X (size is T x M), this will be automatically chosen," - " where T is the total time steps in this mini-batch," - " D is the hidden size, M is the dim size of x input.") - .AsIntermediate(); - AddOutput("BatchedInput", "(phi::DenseTensor) (T x 4D).").AsIntermediate(); - AddOutput("BatchedHidden", "(phi::DenseTensor) (T x D).").AsIntermediate(); - AddOutput("BatchedCell", "(phi::DenseTensor) (T x D).").AsIntermediate(); - AddOutput("ReorderedH0", "(phi::DenseTensor) (N x D).").AsIntermediate(); - AddOutput("ReorderedC0", "(phi::DenseTensor) (N x D).").AsIntermediate(); - AddOutput("CheckedCell", "(phi::DenseTensor) (2 x D) only for peephole.") - .AsIntermediate(); - AddAttr("use_peepholes", - "(bool, default: True) " - "whether to enable diagonal/peephole connections.") - .SetDefault(true); - AddAttr("is_reverse", - "(bool, default: False) " - "whether to compute reversed LSTM.") - .SetDefault(false); - AddAttr("use_seq", - "(bool, default: True) " - "whether to use seq mode to compute.") - .SetDefault(true); - AddAttr("gate_activation", - "(string, default: sigmoid)" - "The activation for input gate, forget gate and output " - "gate, `sigmoid` by default.") - .SetDefault("sigmoid") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); - AddAttr("cell_activation", - "(string, default: tanh)" - "The activation for cell output, `tanh` by default.") - .SetDefault("tanh") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); - AddAttr("candidate_activation", - "(string, default: tanh)" - "The activation for candidate hidden state, " - "`tanh` by default.") - .SetDefault("tanh") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); - AddAttr("Scale_data", - "Scale to be used for int8 input/output data." - "Only used with MKL-DNN INT8.") - .SetDefault(1.0f); - AddAttr("Shift_data", - "Shift to be used for int8 input/output data." - "Only used with MKL-DNN INT8.") - .SetDefault(0.0f); - AddAttr>("Scale_weights", - "Scale_weights to be used for int8 weights data." - "Only used with MKL-DNN INT8.") - .SetDefault({1.0f}); - AddAttr("force_fp32_output", - "(bool, default false) Force INT8 kernel output FP32, only " - "used in MKL-DNN INT8") - .SetDefault(false); - AddComment(R"DOC( -Fusion Long-Short Term Memory (LSTM) Operator. -This operator fuse the X into LSTM, more details can refer to LSTM op. -)DOC"); -} - -template -class FusionLSTMKernel : public framework::OpKernel { - public: -#define INIT_BASE_DEFINES \ - auto* x = ctx.Input("X"); \ - auto* h0 = ctx.Input("H0"); \ - auto* c0 = ctx.Input("C0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* bias = ctx.Input("Bias"); \ - auto* xx = ctx.Output("XX"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - auto* cell_out = ctx.Output("Cell"); \ - bool is_reverse = ctx.Attr("is_reverse"); \ - bool use_peepholes = ctx.Attr("use_peepholes"); \ - auto x_dims = x->dims(); /* T x M*/ \ - auto wh_dims = wh->dims(); /* D x 4D*/ \ - const int M = x_dims[1]; \ - const int D = wh_dims[0]; \ - const int D4 = wh_dims[1] - -#define INIT_OTHER_DEFINES \ - const T* x_data = x->data(); \ - const T* wx_data = wx->data(); \ - const T* wh_data = wh->data(); \ - /* diagonal weight*/ \ - const T* wp_data = bias->data() + D4; \ - /* for peephole only*/ \ - T* checked_cell_data = nullptr; \ - auto place = ctx.GetPlace(); \ - if (use_peepholes) { \ - /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ - auto* checked_cell = ctx.Output("CheckedCell"); \ - checked_cell_data = checked_cell->mutable_data(place); \ - } \ - const phi::jit::lstm_attr_t attr( \ - D, \ - phi::jit::to_kerneltype(ctx.Attr("gate_activation")), \ - phi::jit::to_kerneltype(ctx.Attr("candidate_activation")), \ - phi::jit::to_kerneltype(ctx.Attr("cell_activation")), \ - use_peepholes); \ - phi::jit::lstm_t one_step; \ - one_step.wp = wp_data; \ - one_step.checked = checked_cell_data; \ - auto ComputeC1H1 = phi::jit::KernelFuncs, \ - phi::CPUPlace>::Cache() \ - .At(attr); \ - auto ComputeCtHt = phi::jit::KernelFuncs, \ - phi::CPUPlace>::Cache() \ - .At(attr) - -// Wh GEMM -#define GEMM_WH_ADDON(bs, prev, out) \ - blas.GEMM(CblasNoTrans, \ - CblasNoTrans, \ - bs, \ - D4, \ - D, \ - static_cast(1), \ - prev, \ - D, \ - wh_data, \ - D4, \ - static_cast(1), \ - out, \ - D4) - - void SeqCompute(const framework::ExecutionContext& ctx) const { - INIT_BASE_DEFINES; - INIT_OTHER_DEFINES; - auto x_lod = x->lod(); - const int total_T = static_cast(x_dims[0]); - const int N = static_cast(x_lod[0].size() - 1); - const T* h0_data = h0 ? h0->data() : nullptr; - const T* c0_data = c0 ? c0->data() : nullptr; - T* xx_data = xx->mutable_data(place); - T* h_out_data = hidden_out->mutable_data(place); - T* c_out_data = cell_out->mutable_data(place); - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - - phi::funcs::FCFunctor fc; - fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data()); - - int xx_offset = D4; - int gate_offset = D; - if (is_reverse) { - const int offset = (total_T - 1) * D; - xx_data = xx_data + offset * 4; - h_out_data = h_out_data + offset; - c_out_data = c_out_data + offset; - xx_offset = -D4; - gate_offset = -D; - } - - for (int i = 0; i < N; ++i) { - int bid = is_reverse ? N - 1 - i : i; - int seq_len = static_cast(x_lod[0][bid + 1] - x_lod[0][bid]); - const T* prev_c_data = nullptr; - const T* prev_h_data = nullptr; - int tstart = 0; - if (h0_data) { - prev_h_data = h0_data + bid * D; - prev_c_data = c0_data + bid * D; - } else { - one_step.gates = xx_data; - one_step.ct = c_out_data; - one_step.ht = h_out_data; - ComputeC1H1(&one_step, &attr); - tstart = 1; - // move one step - prev_h_data = h_out_data; - prev_c_data = c_out_data; - xx_data = xx_data + xx_offset; - h_out_data = h_out_data + gate_offset; - c_out_data = c_out_data + gate_offset; - } - for (int step = tstart; step < seq_len; ++step) { - GEMM_WH_ADDON(1, prev_h_data, xx_data); - - one_step.gates = xx_data; - one_step.ct_1 = prev_c_data; - one_step.ct = c_out_data; - one_step.ht = h_out_data; - ComputeCtHt(&one_step, &attr); - // move one step - prev_h_data = h_out_data; - prev_c_data = c_out_data; - xx_data = xx_data + xx_offset; - h_out_data = h_out_data + gate_offset; - c_out_data = c_out_data + gate_offset; - } - } - } - - void BatchCompute(const framework::ExecutionContext& ctx) const { - INIT_BASE_DEFINES; - if (x->lod()[0].size() == 2) { - xx->Resize({x_dims[0], D4}); - SeqCompute(ctx); - return; - } - INIT_OTHER_DEFINES; - - auto* reordered_h0 = ctx.Output("ReorderedH0"); - auto* reordered_c0 = ctx.Output("ReorderedC0"); - auto* batched_input = ctx.Output("BatchedInput"); - auto* batched_c_out = ctx.Output("BatchedCell"); - auto* batched_h_out = ctx.Output("BatchedHidden"); - T* xx_data = xx->mutable_data(place); - T* batched_input_data = batched_input->mutable_data(place); - T* batched_c_out_data = batched_c_out->mutable_data(place); - T* batched_h_out_data = batched_h_out->mutable_data(place); - hidden_out->mutable_data(place); - cell_out->mutable_data(place); - - phi::funcs::LoDTensor2BatchFunctor to_batch; - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - phi::funcs::FCFunctor fc; - if (M > D4) { - fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data()); - to_batch(dev_ctx, *xx, batched_input, true, is_reverse); - } else { - to_batch(dev_ctx, *x, xx, true, is_reverse); - batched_input->set_lod(xx->lod()); - fc(dev_ctx, - x_dims[0], - D4, - M, - xx_data, - wx_data, - batched_input_data, - bias->data()); - } - - auto batched_lod = batched_input->lod(); - const auto& seq_order = batched_lod[2]; - const int max_bs = static_cast(seq_order.size()); - reordered_h0->Resize({max_bs, D}); - reordered_c0->Resize({max_bs, D}); - - int tstart = 0; - T* prev_h_data = nullptr; - T* prev_c_data = nullptr; - if (h0) { - // reorder h0, c0 - T* reordered_h0_data = reordered_h0->mutable_data(place); - T* reordered_c0_data = reordered_c0->mutable_data(place); - const T* h0_data = h0->data(); - const T* c0_data = c0->data(); - prev_h_data = reordered_h0_data; - prev_c_data = reordered_c0_data; - size_t sz = D; - for (int i = 0; i < max_bs; ++i) { - blas.VCOPY(sz, h0_data + seq_order[i] * D, reordered_h0_data); - blas.VCOPY(sz, c0_data + seq_order[i] * D, reordered_c0_data); - reordered_h0_data += D; - reordered_c0_data += D; - } - } else { - // compute without h0, c0 - T* cur_in_data = batched_input_data; - T* cur_h_out_data = batched_h_out_data; - T* cur_c_out_data = batched_c_out_data; - for (int i = 0; i < max_bs; ++i) { - one_step.gates = cur_in_data; - one_step.ct = cur_c_out_data; - one_step.ht = cur_h_out_data; - ComputeC1H1(&one_step, &attr); - - cur_in_data += D4; - cur_c_out_data += D; - cur_h_out_data += D; - } - tstart = 1; - prev_h_data = batched_h_out_data; - prev_c_data = batched_c_out_data; - } - - // compute kernel part - const auto& batch_starts = batched_lod[0]; - const int max_seq_len = static_cast(batch_starts.size() - 1); - const int offset = tstart * max_bs * D; - batched_input_data = batched_input_data + offset * 4; - batched_h_out_data = batched_h_out_data + offset; - batched_c_out_data = batched_c_out_data + offset; - for (int step = tstart; step < max_seq_len; ++step) { - const int cur_bs = - static_cast(batch_starts[step + 1] - batch_starts[step]); - GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); - T* cur_in_data = batched_input_data; - T* cur_prev_c_data = prev_c_data; - T* cur_c_out_data = batched_c_out_data; - T* cur_h_out_data = batched_h_out_data; - for (int i = 0; i < cur_bs; ++i) { - one_step.gates = cur_in_data; - one_step.ct_1 = cur_prev_c_data; - one_step.ct = cur_c_out_data; - one_step.ht = cur_h_out_data; - ComputeCtHt(&one_step, &attr); - - // move one batch - cur_in_data += D4; - cur_prev_c_data += D; - cur_c_out_data += D; - cur_h_out_data += D; - } - // move one step - prev_c_data = batched_c_out_data; - prev_h_data = batched_h_out_data; - batched_c_out_data = cur_c_out_data; - batched_h_out_data = cur_h_out_data; - batched_input_data = cur_in_data; - } - - phi::funcs::Batch2LoDTensorFunctor to_seq; - batched_h_out->set_lod(batched_lod); - to_seq(dev_ctx, *batched_h_out, hidden_out); - batched_c_out->set_lod(batched_lod); - to_seq(dev_ctx, *batched_c_out, cell_out); - } - - void Compute(const framework::ExecutionContext& ctx) const override { - if (ctx.Attr("use_seq")) { - SeqCompute(ctx); - } else { - BatchCompute(ctx); - } - } - -#undef GEMM_WH_ADDON -#undef INIT_OTHER_DEFINES -#undef INIT_BASE_DEFINES -}; - -} // namespace paddle::operators - -namespace ops = paddle::operators; -REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker); - -PD_REGISTER_STRUCT_KERNEL( - fusion_lstm, CPU, ALL_LAYOUT, ops::FusionLSTMKernel, float, double) {} diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.h b/paddle/fluid/operators/fused/fusion_lstm_op.h deleted file mode 100644 index c62060d7c225c..0000000000000 --- a/paddle/fluid/operators/fused/fusion_lstm_op.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class FusionLSTMOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override; -}; - -class FusionLSTMOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/onednn/fusion_lstm_onednn_op.cc b/paddle/fluid/operators/fused/onednn/fusion_lstm_onednn_op.cc deleted file mode 100644 index 05c517fd9ac09..0000000000000 --- a/paddle/fluid/operators/fused/onednn/fusion_lstm_onednn_op.cc +++ /dev/null @@ -1,476 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/operators/fused/fusion_lstm_op.h" -#include "paddle/fluid/operators/fused/onednn/fusion_rnn_onednn.h" -#include "paddle/phi/core/expect.h" - -namespace paddle { -namespace operators { - -using phi::OneDNNContext; -using phi::funcs::OneDNNGetDataType; -using phi::funcs::OneDNNMemDesc; -using phi::funcs::RNNReorderType; -using OneDNNMemoryFormat = dnnl::memory::format_tag; - -template -class LSTMMKLDNNHandler - : public RNNMKLDNNHandler { - public: - LSTMMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, - const OneDNNContext& dev_ctx, - const dnnl::engine onednn_engine, - phi::Place cpu_place UNUSED, - const phi::DenseTensor* input, - const phi::DenseTensor* weight_h, - const phi::DenseTensor* h0, - const phi::DenseTensor* c0 UNUSED, - const bool is_reverse, - const int64_t N, - const int64_t Ti, - const int64_t IC, - const int64_t OC, - const std::string& unique_name UNUSED) - : RNNMKLDNNHandler( - ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - input, - weight_h, - h0, - is_reverse, - N, - Ti, - IC, - OC, - 4, - ctx.InputName("X") + ctx.InputName("WeightH")) { - if (unlikely(!this->isCached())) { - const bool is_INT8 = std::is_same::value; - const bool use_peepholes = ctx.Attr("use_peepholes"); - // oneDNN kernel has hardcoded activation functions - PADDLE_ENFORCE_EQ( - ctx.Attr("gate_activation"), - "sigmoid", - phi::errors::Unimplemented("oneDNN fusion_lstm supports only " - "sigmoid as a gate activation.")); - PADDLE_ENFORCE_EQ( - ctx.Attr("cell_activation"), - "tanh", - phi::errors::Unimplemented( - "oneDNN fusion_lstm supports only tanh as a cell activation.")); - PADDLE_ENFORCE_EQ( - ctx.Attr("candidate_activation"), - "tanh", - phi::errors::Unimplemented( - "oneDNN fusion_lstm supports only tanh a candidate activation.")); - - // Weights for int8 kernel are of a type s8 - const auto weights_dt = - is_INT8 ? dnnl::memory::data_type::s8 : OneDNNGetDataType(); - - // oneDNN RNN dimensions - const int64_t D = 1; // Directions - const int64_t L = 1; // Layers (PP supports only 1 stacked layer) - const int64_t G = 4; // Number of Gates, 4 for LSTM - - // Create memory descriptors - auto input_md = OneDNNMemDesc( - {Ti, N, IC}, OneDNNGetDataType(), OneDNNMemoryFormat::tnc); - auto weight_x_md = - OneDNNMemDesc({L, D, IC, G, OC}, weights_dt, OneDNNMemoryFormat::any); - auto weight_h_md = - OneDNNMemDesc({L, D, OC, G, OC}, weights_dt, OneDNNMemoryFormat::any); - auto bias_md = OneDNNMemDesc( - {L, D, G, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldgo); - auto hidden_md = OneDNNMemDesc( - {Ti, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::any); - - auto h0_md = OneDNNMemDesc( - {L, D, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::any); - auto c0_md = OneDNNMemDesc( - {L, D, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::any); - - // Create LSTM oneDNN primitive - const auto direction = - is_reverse ? dnnl::rnn_direction::unidirectional_right2left - : dnnl::rnn_direction::unidirectional_left2right; - if (!use_peepholes) { - this->AcquireForwardPrimitiveDescriptor( - this->attr_, - dnnl::prop_kind::forward_inference, - direction, - input_md, - h0_md, - c0_md, - weight_x_md, - weight_h_md, - bias_md, - hidden_md, - dnnl::memory::desc(), - dnnl::memory::desc()); - } else { - auto weight_peephole_md = OneDNNMemDesc({L, D, 3, OC}, - OneDNNGetDataType(), - OneDNNMemoryFormat::ldgo); - this->AcquireForwardPrimitiveDescriptor( - this->attr_, - dnnl::prop_kind::forward_inference, - direction, - input_md, - h0_md, - c0_md, - weight_x_md, - weight_h_md, - weight_peephole_md, - bias_md, - hidden_md, - dnnl::memory::desc(), - dnnl::memory::desc()); - } - } - } - - // PaddlePaddle has different order of weights than oneDNN, so a reorder is - // needed - // PaddlePaddle: {c, i, f, o} - // oneDNN: {i, f, c, o} - template - void ReorderGates(U* weights, int64_t I) { - size_t inner_block_size = this->OC; - size_t block_size = inner_block_size * this->G; - for (size_t i = 0; i < (size_t)I; ++i) { // NOLINT - size_t offset = i * block_size; - - U* base_pos = weights + offset; - std::swap_ranges(base_pos, - base_pos + inner_block_size, - base_pos + inner_block_size); // c <-> i - std::swap_ranges(base_pos + inner_block_size, - base_pos + 2 * inner_block_size, - base_pos + 2 * inner_block_size); // c <-> f - } - } - - template - std::shared_ptr AcquireWeightXMemory( - const phi::DenseTensor* weight_x) { - const std::string wx_key = this->memory_key_ + "@weight_x"; - auto memory_p = - std::static_pointer_cast(this->dev_ctx_.GetBlob(wx_key)); - - if (!memory_p) { - auto user_md = OneDNNMemDesc({1, 1, this->IC, this->G, this->OC}, - OneDNNGetDataType(), - OneDNNMemoryFormat::ldigo); - auto user_memory = dnnl::memory(user_md, this->engine_); - - auto* weight_x_data = reinterpret_cast(user_memory.get_data_handle()); - memcpy(weight_x_data, - weight_x->data(), - sizeof(U) * this->IC * this->G * this->OC); - - ReorderGates(weight_x_data, this->IC); - - memory_p = std::make_shared( - this->fwd_pd_->weights_layer_desc(), this->engine_); - - auto& astream = OneDNNContext::tls().get_stream(); - dnnl::reorder(user_memory, *memory_p, this->attr_) - .execute(astream, user_memory, *memory_p); - - this->dev_ctx_.SetBlob(wx_key, memory_p); - } - return memory_p; - } - - template - std::shared_ptr AcquireWeightHMemory( - const phi::DenseTensor* weight_h) { - const std::string wh_key = this->memory_key_ + "@weight_h"; - auto memory_p = - std::static_pointer_cast(this->dev_ctx_.GetBlob(wh_key)); - - if (!memory_p) { - auto user_md = OneDNNMemDesc({1, 1, this->OC, this->G, this->OC}, - OneDNNGetDataType(), - OneDNNMemoryFormat::ldigo); - auto user_memory = dnnl::memory(user_md, this->engine_); - - auto* weight_h_data = reinterpret_cast(user_memory.get_data_handle()); - memcpy(weight_h_data, - weight_h->data(), - sizeof(U) * this->OC * this->G * this->OC); - - ReorderGates(weight_h_data, this->OC); - - memory_p = std::make_shared( - this->fwd_pd_->weights_iter_desc(), this->engine_); - - auto& astream = OneDNNContext::tls().get_stream(); - dnnl::reorder(user_memory, *memory_p, this->attr_) - .execute(astream, user_memory, *memory_p); - - this->dev_ctx_.SetBlob(wh_key, memory_p); - } - return memory_p; - } - - std::shared_ptr AcquireBiasMemory( - const phi::DenseTensor* bias) { - const std::string bias_key = this->memory_key_ + "@bias"; - auto memory_p = std::static_pointer_cast( - this->dev_ctx_.GetBlob(bias_key)); - - if (!memory_p) { - memory_p = std::make_shared(this->fwd_pd_->bias_desc(), - this->engine_); - auto* bias_data = reinterpret_cast(memory_p->get_data_handle()); - if (bias) { - const float* user_bias_data = - bias->data(); // Bias in oneDNN is always float - - memcpy(bias_data, user_bias_data, sizeof(float) * this->G * this->OC); - - ReorderGates(bias_data, 1); - } else { - // oneDNN always need bias memory, if it's not provided in PP, let - // oneDNN allocate memory and set it to 0 - memset(bias_data, 0, sizeof(float) * this->G * this->OC); - } - - this->dev_ctx_.SetBlob(bias_key, memory_p); - } - return memory_p; - } - - std::shared_ptr AcquirePeepholeWeights( - const phi::DenseTensor* bias) { - const std::string peepholes_key = this->memory_key_ + "@peepholes_weights"; - auto memory_p = std::static_pointer_cast( - this->dev_ctx_.GetBlob(peepholes_key)); - - if (!memory_p) { - auto user_md = OneDNNMemDesc({1, 1, 3, this->OC}, - OneDNNGetDataType(), - OneDNNMemoryFormat::ldgo); - auto user_memory = dnnl::memory(user_md, this->engine_); - memory_p = std::make_shared( - this->fwd_pd_->weights_peephole_desc(), this->engine_); - auto* peephole_weights_data = - reinterpret_cast(memory_p->get_data_handle()); - - const float* user_bias_data = - bias->data(); // Bias in oneDNN is always float - memcpy(peephole_weights_data, - user_bias_data + 4 * this->OC, - sizeof(float) * 3 * this->OC); - - this->dev_ctx_.SetBlob(peepholes_key, memory_p); - } - return memory_p; - } - - std::shared_ptr AcquireC0Memory(const phi::DenseTensor* c0) { - const std::string c0_key = this->memory_key_ + "@c0"; - auto memory_p = - std::static_pointer_cast(this->dev_ctx_.GetBlob(c0_key)); - - if (!memory_p) { - auto user_c0_memory = dnnl::memory(); - if (c0) { - user_c0_memory = - dnnl::memory({{1, 1, this->N, this->OC}, - OneDNNGetDataType(), - OneDNNMemoryFormat::ldnc}, - this->engine_, - phi::funcs::to_void_cast(c0->data())); - } else { - user_c0_memory = dnnl::memory({{1, 1, this->N, this->OC}, - OneDNNGetDataType(), - OneDNNMemoryFormat::ldnc}, - this->engine_); - memset(user_c0_memory.get_data_handle(), - 0, - sizeof(float) * this->N * this->OC); - } - memory_p = std::make_shared( - this->fwd_pd_->src_iter_c_desc(), this->engine_); - - auto& astream = OneDNNContext::tls().get_stream(); - dnnl::reorder(user_c0_memory, *memory_p) - .execute(astream, user_c0_memory, *memory_p); - - this->dev_ctx_.SetBlob(c0_key, memory_p); - } - return memory_p; - } -}; - -template -class FusionLSTMMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const bool is_bf16 = std::is_same::value; - const bool force_fp32_output = ctx.Attr("force_fp32_output"); - - // BF16 does not support force output - if (!is_bf16 && force_fp32_output) { // NOLINT - RunKernel(ctx); - } else { - RunKernel(ctx); - } - } - - template - void RunKernel(const framework::ExecutionContext& ctx) const { - auto& dev_ctx = ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - // Get Tensors - const auto* input = ctx.Input("X"); - const auto* h0 = ctx.Input("H0"); - const auto* c0 = ctx.Input("C0"); - const auto* weight_x = ctx.Input("WeightX"); - const auto* weight_h = ctx.Input("WeightH"); - const auto* bias = ctx.Input("Bias"); - auto* hidden = ctx.Output("Hidden"); - auto x_dims = input->dims(); - auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) - ? common::flatten_to_2d(x_dims, 1) - : x_dims; - // Get attributes - const bool is_reverse = ctx.Attr("is_reverse"); - const bool use_peepholes = ctx.Attr("use_peepholes"); - - // Get tensor dimensions - const auto x_mat_dims_vec = common::vectorize(x_mat_dims); - const auto weight_h_dims = common::vectorize(weight_h->dims()); - const auto& input_lod = input->lod()[0]; - - // Calculate RNN dimensions - const int64_t N = input_lod.size() - 1; // Number of sentences (batches) - const int64_t Ti = // Max length of the sentence in a batch - [&input_lod]() { - size_t res = 0; - for (size_t i = 0; i < (input_lod.size() - 1); ++i) { - res = std::max(res, input_lod[i + 1] - input_lod[i]); - } - return res; - }(); - const int64_t IC = x_mat_dims_vec[1]; // Input channels - const int64_t OC = weight_h_dims[0]; // Output channels - - LSTMMKLDNNHandler handler( - ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - input, - weight_h, - h0, - c0, - is_reverse, - N, - Ti, - IC, - OC, - ctx.InputName("X") + ctx.InputName("WeightH")); - - auto input_memory_p = - handler.AcquireInputMemoryWithReorder(input, is_reverse); - auto c0_memory_p = handler.AcquireC0Memory(c0); - - std::shared_ptr h0_memory_p, weight_h_memory_p, - weight_x_memory_p; - - if (weight_h->dtype() == phi::DataType::FLOAT32) { - h0_memory_p = handler.template AcquireH0Memory(h0); - weight_x_memory_p = - handler.template AcquireWeightXMemory(weight_x); - weight_h_memory_p = - handler.template AcquireWeightHMemory(weight_h); - } else if (weight_h->dtype() == phi::DataType::BFLOAT16) { - h0_memory_p = handler.template AcquireH0Memory(h0); - weight_x_memory_p = - handler.template AcquireWeightXMemory(weight_x); - weight_h_memory_p = - handler.template AcquireWeightHMemory(weight_h); - } else { - h0_memory_p = handler.template AcquireH0Memory(h0); - weight_x_memory_p = - handler.template AcquireWeightXMemory(weight_x); - weight_h_memory_p = - handler.template AcquireWeightHMemory(weight_h); - } - - auto bias_memory_p = handler.AcquireBiasMemory(bias); - auto hidden_onednn_memory_p = handler.AcquireOutputMemory(); - - std::unordered_map lstm_args = { - {DNNL_ARG_SRC_LAYER, *input_memory_p}, - {DNNL_ARG_SRC_ITER, *h0_memory_p}, - {DNNL_ARG_SRC_ITER_C, *c0_memory_p}, - {DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p}, - {DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p}, - {DNNL_ARG_BIAS, *bias_memory_p}, - {DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}}; - - if (use_peepholes) { - auto peephole_weight_p = handler.AcquirePeepholeWeights(bias); - std::pair peepholes_weights(DNNL_ARG_WEIGHTS_PEEPHOLE, - *peephole_weight_p); - lstm_args.insert(peepholes_weights); - } - - auto lstm_forward_p = handler.AcquireForwardPrimitive(); - - auto& astream = OneDNNContext::tls().get_stream(); - lstm_forward_p->execute(astream, lstm_args); - astream.wait(); - - auto* hidden_onednn_data = hidden_onednn_memory_p->get_data_handle(); - auto* hidden_data = - phi::funcs::to_void_cast(hidden->mutable_data(ctx.GetPlace())); - if (handler.is_NTC()) { - handler.reorderRNNdata(hidden_onednn_data, - hidden_data, - input_lod, - is_reverse, - RNNReorderType::NTC_PP); - } else { - handler.reorderRNNdata(hidden_onednn_data, - hidden_data, - input_lod, - is_reverse, - RNNReorderType::TNC_PP); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -PD_REGISTER_STRUCT_KERNEL(fusion_lstm, - OneDNN, - ONEDNN, - ops::FusionLSTMMKLDNNKernel, - float, - uint8_t, - phi::dtype::bfloat16) {} diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 504a61cc2cbc5..1eb784ed8c0e9 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -191,6 +191,7 @@ 'fused_elementwise_mul', 'fused_elementwise_sub', 'fusion_group', + 'fusion_lstm', 'fusion_seqpool_cvm_concat', 'nce', 'lars_momentum', diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 870a8727289d4..3552cf88a0765 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -83,7 +83,6 @@ const std::unordered_set LegacyOpList = { paddle::onednn::dialect::LrnOp::name(), paddle::onednn::dialect::LrnGradOp::name(), paddle::onednn::dialect::MultiGruOp::name(), - paddle::onednn::dialect::FusionLstmOp::name(), #endif CReduceAvgOp::name(), CReduceAvg_Op::name(), diff --git a/paddle/phi/kernels/fusion/cpu/fusion_lstm_kernel.cc b/paddle/phi/kernels/fusion/cpu/fusion_lstm_kernel.cc new file mode 100644 index 0000000000000..522d7b77b559c --- /dev/null +++ b/paddle/phi/kernels/fusion/cpu/fusion_lstm_kernel.cc @@ -0,0 +1,443 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/fc_functor.h" +#include "paddle/phi/kernels/funcs/jit/kernels.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" + +namespace phi { + +#define INIT_BASE_DEFINES \ + auto *x = &x_in; \ + auto *h0 = h0_in.get_ptr(); \ + auto *c0 = c0_in.get_ptr(); \ + auto *wx = &weight_x_in; \ + auto *wh = &weight_h_in; \ + auto *bias = &bias_in; \ + auto *hidden_out = hidden; \ + auto *cell_out = cell; \ + auto x_dims = x->dims(); /* T x M*/ \ + auto wh_dims = wh->dims(); /* D x 4D*/ \ + const int M = x_dims[1]; \ + const int D = wh_dims[0]; \ + const int D4 = wh_dims[1] + +#define INIT_OTHER_DEFINES \ + const T *x_data = x->data(); \ + const T *wx_data = wx->data(); \ + const T *wh_data = wh->data(); \ + /* diagonal weight*/ \ + const T *wp_data = bias->data() + D4; \ + /* for peephole only*/ \ + T *checked_cell_data = nullptr; \ + if (use_peepholes) { \ + /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ + checked_cell_data = dev_ctx.template Alloc(checked_cell); \ + } \ + const phi::jit::lstm_attr_t attr( \ + D, \ + phi::jit::to_kerneltype(gate_activation), \ + phi::jit::to_kerneltype(candidate_activation), \ + phi::jit::to_kerneltype(cell_activation), \ + use_peepholes); \ + phi::jit::lstm_t one_step; \ + one_step.wp = wp_data; \ + one_step.checked = checked_cell_data; \ + auto ComputeC1H1 = phi::jit::KernelFuncs, \ + phi::CPUPlace>::Cache() \ + .At(attr); \ + auto ComputeCtHt = phi::jit::KernelFuncs, \ + phi::CPUPlace>::Cache() \ + .At(attr) + +// Wh GEMM +#define GEMM_WH_ADDON(bs, prev, out) \ + blas.GEMM(CblasNoTrans, \ + CblasNoTrans, \ + bs, \ + D4, \ + D, \ + static_cast(1), \ + prev, \ + D, \ + wh_data, \ + D4, \ + static_cast(1), \ + out, \ + D4) + +template +void SeqCompute(const Context &dev_ctx, + const DenseTensor &x_in, + const DenseTensor &weight_x_in, + const DenseTensor &weight_h_in, + const DenseTensor &bias_in, + const paddle::optional &h0_in, + const paddle::optional &c0_in, + bool use_peepholes, + bool is_reverse, + bool use_seq, + const std::string &gate_activation, + const std::string &cell_activation, + const std::string &candidate_activation, + float scale_data, + float shift_data, + const std::vector &scale_weights, + bool force_fp32_output, + DenseTensor *hidden, + DenseTensor *cell, + DenseTensor *xx, + DenseTensor *batched_input, + DenseTensor *batched_hidden, + DenseTensor *batched_cell, + DenseTensor *reordered_h0, + DenseTensor *reordered_c0, + DenseTensor *checked_cell) { + INIT_BASE_DEFINES; + INIT_OTHER_DEFINES; + auto x_lod = x->lod(); + const int total_T = static_cast(x_dims[0]); + const int N = static_cast(x_lod[0].size() - 1); + const T *h0_data = h0 ? h0->data() : nullptr; + const T *c0_data = c0 ? c0->data() : nullptr; + T *xx_data = dev_ctx.template Alloc(xx); + T *h_out_data = dev_ctx.template Alloc(hidden_out); + T *c_out_data = dev_ctx.template Alloc(cell_out); + auto blas = phi::funcs::GetBlas(dev_ctx); + + phi::funcs::FCFunctor fc; + fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data()); + + int xx_offset = D4; + int gate_offset = D; + if (is_reverse) { + const int offset = (total_T - 1) * D; + xx_data = xx_data + offset * 4; + h_out_data = h_out_data + offset; + c_out_data = c_out_data + offset; + xx_offset = -D4; + gate_offset = -D; + } + + for (int i = 0; i < N; ++i) { + int bid = is_reverse ? N - 1 - i : i; + int seq_len = static_cast(x_lod[0][bid + 1] - x_lod[0][bid]); + const T *prev_c_data = nullptr; + const T *prev_h_data = nullptr; + int tstart = 0; + if (h0_data) { + prev_h_data = h0_data + bid * D; + prev_c_data = c0_data + bid * D; + } else { + one_step.gates = xx_data; + one_step.ct = c_out_data; + one_step.ht = h_out_data; + ComputeC1H1(&one_step, &attr); + tstart = 1; + // move one step + prev_h_data = h_out_data; + prev_c_data = c_out_data; + xx_data = xx_data + xx_offset; + h_out_data = h_out_data + gate_offset; + c_out_data = c_out_data + gate_offset; + } + for (int step = tstart; step < seq_len; ++step) { + GEMM_WH_ADDON(1, prev_h_data, xx_data); + + one_step.gates = xx_data; + one_step.ct_1 = prev_c_data; + one_step.ct = c_out_data; + one_step.ht = h_out_data; + ComputeCtHt(&one_step, &attr); + // move one step + prev_h_data = h_out_data; + prev_c_data = c_out_data; + xx_data = xx_data + xx_offset; + h_out_data = h_out_data + gate_offset; + c_out_data = c_out_data + gate_offset; + } + } +} + +template +void BatchCompute(const Context &dev_ctx, + const DenseTensor &x_in, + const DenseTensor &weight_x_in, + const DenseTensor &weight_h_in, + const DenseTensor &bias_in, + const paddle::optional &h0_in, + const paddle::optional &c0_in, + bool use_peepholes, + bool is_reverse, + bool use_seq, + const std::string &gate_activation, + const std::string &cell_activation, + const std::string &candidate_activation, + float scale_data, + float shift_data, + const std::vector &scale_weights, + bool force_fp32_output, + DenseTensor *hidden, + DenseTensor *cell, + DenseTensor *xx, + DenseTensor *batched_input, + DenseTensor *batched_hidden, + DenseTensor *batched_cell, + DenseTensor *reordered_h0, + DenseTensor *reordered_c0, + DenseTensor *checked_cell) { + INIT_BASE_DEFINES; + if (x->lod()[0].size() == 2) { + xx->Resize({x_dims[0], D4}); + SeqCompute(dev_ctx, + x_in, + weight_x_in, + weight_h_in, + bias_in, + h0_in, + c0_in, + use_peepholes, + is_reverse, + use_seq, + gate_activation, + cell_activation, + candidate_activation, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + hidden, + cell, + xx, + batched_input, + batched_hidden, + batched_cell, + reordered_h0, + reordered_c0, + checked_cell); + return; + } + INIT_OTHER_DEFINES; + + auto *batched_c_out = batched_cell; + auto *batched_h_out = batched_hidden; + T *xx_data = dev_ctx.template Alloc(xx); + T *batched_input_data = dev_ctx.template Alloc(batched_input); + T *batched_c_out_data = dev_ctx.template Alloc(batched_c_out); + T *batched_h_out_data = dev_ctx.template Alloc(batched_h_out); + dev_ctx.template Alloc(hidden_out); + dev_ctx.template Alloc(cell_out); + + phi::funcs::LoDTensor2BatchFunctor to_batch; + auto blas = phi::funcs::GetBlas(dev_ctx); + phi::funcs::FCFunctor fc; + if (M > D4) { + fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data()); + to_batch(dev_ctx, *xx, batched_input, true, is_reverse); + } else { + to_batch(dev_ctx, *x, xx, true, is_reverse); + batched_input->set_lod(xx->lod()); + fc(dev_ctx, + x_dims[0], + D4, + M, + xx_data, + wx_data, + batched_input_data, + bias->data()); + } + + auto batched_lod = batched_input->lod(); + const auto &seq_order = batched_lod[2]; + const int max_bs = static_cast(seq_order.size()); + reordered_h0->Resize({max_bs, D}); + reordered_c0->Resize({max_bs, D}); + + int tstart = 0; + T *prev_h_data = nullptr; + T *prev_c_data = nullptr; + if (h0) { + // reorder h0, c0 + T *reordered_h0_data = dev_ctx.template Alloc(reordered_h0); + T *reordered_c0_data = dev_ctx.template Alloc(reordered_c0); + const T *h0_data = h0->data(); + const T *c0_data = c0->data(); + prev_h_data = reordered_h0_data; + prev_c_data = reordered_c0_data; + size_t sz = D; + for (int i = 0; i < max_bs; ++i) { + blas.VCOPY(sz, h0_data + seq_order[i] * D, reordered_h0_data); + blas.VCOPY(sz, c0_data + seq_order[i] * D, reordered_c0_data); + reordered_h0_data += D; + reordered_c0_data += D; + } + } else { + // compute without h0, c0 + T *cur_in_data = batched_input_data; + T *cur_h_out_data = batched_h_out_data; + T *cur_c_out_data = batched_c_out_data; + for (int i = 0; i < max_bs; ++i) { + one_step.gates = cur_in_data; + one_step.ct = cur_c_out_data; + one_step.ht = cur_h_out_data; + ComputeC1H1(&one_step, &attr); + + cur_in_data += D4; + cur_c_out_data += D; + cur_h_out_data += D; + } + tstart = 1; + prev_h_data = batched_h_out_data; + prev_c_data = batched_c_out_data; + } + + // compute kernel part + const auto &batch_starts = batched_lod[0]; + const int max_seq_len = static_cast(batch_starts.size() - 1); + const int offset = tstart * max_bs * D; + batched_input_data = batched_input_data + offset * 4; + batched_h_out_data = batched_h_out_data + offset; + batched_c_out_data = batched_c_out_data + offset; + for (int step = tstart; step < max_seq_len; ++step) { + const int cur_bs = + static_cast(batch_starts[step + 1] - batch_starts[step]); + GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); + T *cur_in_data = batched_input_data; + T *cur_prev_c_data = prev_c_data; + T *cur_c_out_data = batched_c_out_data; + T *cur_h_out_data = batched_h_out_data; + for (int i = 0; i < cur_bs; ++i) { + one_step.gates = cur_in_data; + one_step.ct_1 = cur_prev_c_data; + one_step.ct = cur_c_out_data; + one_step.ht = cur_h_out_data; + ComputeCtHt(&one_step, &attr); + + // move one batch + cur_in_data += D4; + cur_prev_c_data += D; + cur_c_out_data += D; + cur_h_out_data += D; + } + // move one step + prev_c_data = batched_c_out_data; + prev_h_data = batched_h_out_data; + batched_c_out_data = cur_c_out_data; + batched_h_out_data = cur_h_out_data; + batched_input_data = cur_in_data; + } + + phi::funcs::Batch2LoDTensorFunctor to_seq; + batched_h_out->set_lod(batched_lod); + to_seq(dev_ctx, *batched_h_out, hidden_out); + batched_c_out->set_lod(batched_lod); + to_seq(dev_ctx, *batched_c_out, cell_out); +} + +template +void FusionLSTMKernel(const Context &dev_ctx, + const DenseTensor &x_in, + const DenseTensor &weight_x_in, + const DenseTensor &weight_h_in, + const DenseTensor &bias_in, + const paddle::optional &h0_in, + const paddle::optional &c0_in, + bool use_peepholes, + bool is_reverse, + bool use_seq, + const std::string &gate_activation, + const std::string &cell_activation, + const std::string &candidate_activation, + float scale_data, + float shift_data, + const std::vector &scale_weights, + bool force_fp32_output, + DenseTensor *hidden, + DenseTensor *cell, + DenseTensor *xx, + DenseTensor *batched_input, + DenseTensor *batched_hidden, + DenseTensor *batched_cell, + DenseTensor *reordered_h0, + DenseTensor *reordered_c0, + DenseTensor *checked_cell) { + if (use_seq) { + SeqCompute(dev_ctx, + x_in, + weight_x_in, + weight_h_in, + bias_in, + h0_in, + c0_in, + use_peepholes, + is_reverse, + use_seq, + gate_activation, + cell_activation, + candidate_activation, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + hidden, + cell, + xx, + batched_input, + batched_hidden, + batched_cell, + reordered_h0, + reordered_c0, + checked_cell); + } else { + BatchCompute(dev_ctx, + x_in, + weight_x_in, + weight_h_in, + bias_in, + h0_in, + c0_in, + use_peepholes, + is_reverse, + use_seq, + gate_activation, + cell_activation, + candidate_activation, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + hidden, + cell, + xx, + batched_input, + batched_hidden, + batched_cell, + reordered_h0, + reordered_c0, + checked_cell); + } +} + +#undef GEMM_WH_ADDON +#undef INIT_OTHER_DEFINES +#undef INIT_BASE_DEFINES + +} // namespace phi + +PD_REGISTER_KERNEL( + fusion_lstm, CPU, ALL_LAYOUT, phi::FusionLSTMKernel, float, double) {} diff --git a/paddle/phi/kernels/fusion/onednn/fusion_lstm_kernel.cc b/paddle/phi/kernels/fusion/onednn/fusion_lstm_kernel.cc new file mode 100644 index 0000000000000..02a3fd7fc3fb9 --- /dev/null +++ b/paddle/phi/kernels/fusion/onednn/fusion_lstm_kernel.cc @@ -0,0 +1,573 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/expect.h" +#include "paddle/phi/kernels/fusion/onednn/fusion_rnn_onednn.h" +#include "paddle/utils/optional.h" + +namespace phi { +namespace fusion { + +using phi::OneDNNContext; +using phi::funcs::OneDNNGetDataType; +using phi::funcs::OneDNNMemDesc; +using phi::funcs::RNNReorderType; +using OneDNNMemoryFormat = dnnl::memory::format_tag; + +template +class LSTMMKLDNNHandler + : public RNNMKLDNNHandler { + public: + LSTMMKLDNNHandler(const OneDNNContext& dev_ctx, + const dnnl::engine onednn_engine, + phi::Place cpu_place UNUSED, + const phi::DenseTensor* input, + const phi::DenseTensor* weight_h, + const phi::DenseTensor* h0, + const phi::DenseTensor* c0 UNUSED, + const bool is_reverse, + const int64_t N, + const int64_t Ti, + const int64_t IC, + const int64_t OC, + const std::string& unique_name UNUSED, + float scale_data, + float shift_data, + std::vector scale_weights, + bool use_peepholes, + std::string gate_activation, + std::string cell_activation, + std::string candidate_activation) + : RNNMKLDNNHandler(dev_ctx, + onednn_engine, + dev_ctx.GetPlace(), + input, + weight_h, + h0, + is_reverse, + N, + Ti, + IC, + OC, + 4, + "x_weight_h", + scale_data, + shift_data, + scale_weights) { + if (unlikely(!this->isCached())) { + const bool is_INT8 = std::is_same::value; + // oneDNN kernel has hardcoded activation functions + PADDLE_ENFORCE_EQ( + gate_activation, + "sigmoid", + phi::errors::Unimplemented("oneDNN fusion_lstm supports only " + "sigmoid as a gate activation.")); + PADDLE_ENFORCE_EQ( + cell_activation, + "tanh", + phi::errors::Unimplemented( + "oneDNN fusion_lstm supports only tanh as a cell activation.")); + PADDLE_ENFORCE_EQ( + candidate_activation, + "tanh", + phi::errors::Unimplemented( + "oneDNN fusion_lstm supports only tanh a candidate activation.")); + + // Weights for int8 kernel are of a type s8 + const auto weights_dt = + is_INT8 ? dnnl::memory::data_type::s8 : OneDNNGetDataType(); + + // oneDNN RNN dimensions + const int64_t D = 1; // Directions + const int64_t L = 1; // Layers (PP supports only 1 stacked layer) + const int64_t G = 4; // Number of Gates, 4 for LSTM + + // Create memory descriptors + auto input_md = OneDNNMemDesc( + {Ti, N, IC}, OneDNNGetDataType(), OneDNNMemoryFormat::tnc); + auto weight_x_md = + OneDNNMemDesc({L, D, IC, G, OC}, weights_dt, OneDNNMemoryFormat::any); + auto weight_h_md = + OneDNNMemDesc({L, D, OC, G, OC}, weights_dt, OneDNNMemoryFormat::any); + auto bias_md = OneDNNMemDesc( + {L, D, G, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldgo); + auto hidden_md = OneDNNMemDesc( + {Ti, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::any); + + auto h0_md = OneDNNMemDesc( + {L, D, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::any); + auto c0_md = OneDNNMemDesc( + {L, D, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::any); + + // Create LSTM oneDNN primitive + const auto direction = + is_reverse ? dnnl::rnn_direction::unidirectional_right2left + : dnnl::rnn_direction::unidirectional_left2right; + if (!use_peepholes) { + this->AcquireForwardPrimitiveDescriptor( + this->attr_, + dnnl::prop_kind::forward_inference, + direction, + input_md, + h0_md, + c0_md, + weight_x_md, + weight_h_md, + bias_md, + hidden_md, + dnnl::memory::desc(), + dnnl::memory::desc()); + } else { + auto weight_peephole_md = OneDNNMemDesc({L, D, 3, OC}, + OneDNNGetDataType(), + OneDNNMemoryFormat::ldgo); + this->AcquireForwardPrimitiveDescriptor( + this->attr_, + dnnl::prop_kind::forward_inference, + direction, + input_md, + h0_md, + c0_md, + weight_x_md, + weight_h_md, + weight_peephole_md, + bias_md, + hidden_md, + dnnl::memory::desc(), + dnnl::memory::desc()); + } + } + } + + // PaddlePaddle has different order of weights than oneDNN, so a reorder is + // needed + // PaddlePaddle: {c, i, f, o} + // oneDNN: {i, f, c, o} + template + void ReorderGates(U* weights, int64_t I) { + size_t inner_block_size = this->OC; + size_t block_size = inner_block_size * this->G; + for (size_t i = 0; i < (size_t)I; ++i) { // NOLINT + size_t offset = i * block_size; + + U* base_pos = weights + offset; + std::swap_ranges(base_pos, + base_pos + inner_block_size, + base_pos + inner_block_size); // c <-> i + std::swap_ranges(base_pos + inner_block_size, + base_pos + 2 * inner_block_size, + base_pos + 2 * inner_block_size); // c <-> f + } + } + + template + std::shared_ptr AcquireWeightXMemory( + const phi::DenseTensor* weight_x) { + const std::string wx_key = this->memory_key_ + "@weight_x"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(wx_key)); + + if (!memory_p) { + auto user_md = OneDNNMemDesc({1, 1, this->IC, this->G, this->OC}, + OneDNNGetDataType(), + OneDNNMemoryFormat::ldigo); + auto user_memory = dnnl::memory(user_md, this->engine_); + + auto* weight_x_data = reinterpret_cast(user_memory.get_data_handle()); + memcpy(weight_x_data, + weight_x->data(), + sizeof(U) * this->IC * this->G * this->OC); + + ReorderGates(weight_x_data, this->IC); + + memory_p = std::make_shared( + this->fwd_pd_->weights_layer_desc(), this->engine_); + + auto& astream = OneDNNContext::tls().get_stream(); + dnnl::reorder(user_memory, *memory_p, this->attr_) + .execute(astream, user_memory, *memory_p); + + this->dev_ctx_.SetBlob(wx_key, memory_p); + } + return memory_p; + } + + template + std::shared_ptr AcquireWeightHMemory( + const phi::DenseTensor* weight_h) { + const std::string wh_key = this->memory_key_ + "@weight_h"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(wh_key)); + + if (!memory_p) { + auto user_md = OneDNNMemDesc({1, 1, this->OC, this->G, this->OC}, + OneDNNGetDataType(), + OneDNNMemoryFormat::ldigo); + auto user_memory = dnnl::memory(user_md, this->engine_); + + auto* weight_h_data = reinterpret_cast(user_memory.get_data_handle()); + memcpy(weight_h_data, + weight_h->data(), + sizeof(U) * this->OC * this->G * this->OC); + + ReorderGates(weight_h_data, this->OC); + + memory_p = std::make_shared( + this->fwd_pd_->weights_iter_desc(), this->engine_); + + auto& astream = OneDNNContext::tls().get_stream(); + dnnl::reorder(user_memory, *memory_p, this->attr_) + .execute(astream, user_memory, *memory_p); + + this->dev_ctx_.SetBlob(wh_key, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquireBiasMemory( + const phi::DenseTensor* bias) { + const std::string bias_key = this->memory_key_ + "@bias"; + auto memory_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(bias_key)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->bias_desc(), + this->engine_); + auto* bias_data = reinterpret_cast(memory_p->get_data_handle()); + if (bias) { + const float* user_bias_data = + bias->data(); // Bias in oneDNN is always float + + memcpy(bias_data, user_bias_data, sizeof(float) * this->G * this->OC); + + ReorderGates(bias_data, 1); + } else { + // oneDNN always need bias memory, if it's not provided in PP, let + // oneDNN allocate memory and set it to 0 + memset(bias_data, 0, sizeof(float) * this->G * this->OC); + } + + this->dev_ctx_.SetBlob(bias_key, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquirePeepholeWeights( + const phi::DenseTensor* bias) { + const std::string peepholes_key = this->memory_key_ + "@peepholes_weights"; + auto memory_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(peepholes_key)); + + if (!memory_p) { + auto user_md = OneDNNMemDesc({1, 1, 3, this->OC}, + OneDNNGetDataType(), + OneDNNMemoryFormat::ldgo); + auto user_memory = dnnl::memory(user_md, this->engine_); + memory_p = std::make_shared( + this->fwd_pd_->weights_peephole_desc(), this->engine_); + auto* peephole_weights_data = + reinterpret_cast(memory_p->get_data_handle()); + + const float* user_bias_data = + bias->data(); // Bias in oneDNN is always float + memcpy(peephole_weights_data, + user_bias_data + 4 * this->OC, + sizeof(float) * 3 * this->OC); + + this->dev_ctx_.SetBlob(peepholes_key, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquireC0Memory(const phi::DenseTensor* c0) { + const std::string c0_key = this->memory_key_ + "@c0"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(c0_key)); + + if (!memory_p) { + auto user_c0_memory = dnnl::memory(); + if (c0) { + user_c0_memory = + dnnl::memory({{1, 1, this->N, this->OC}, + OneDNNGetDataType(), + OneDNNMemoryFormat::ldnc}, + this->engine_, + phi::funcs::to_void_cast(c0->data())); + } else { + user_c0_memory = dnnl::memory({{1, 1, this->N, this->OC}, + OneDNNGetDataType(), + OneDNNMemoryFormat::ldnc}, + this->engine_); + memset(user_c0_memory.get_data_handle(), + 0, + sizeof(float) * this->N * this->OC); + } + memory_p = std::make_shared( + this->fwd_pd_->src_iter_c_desc(), this->engine_); + + auto& astream = OneDNNContext::tls().get_stream(); + dnnl::reorder(user_c0_memory, *memory_p) + .execute(astream, user_c0_memory, *memory_p); + + this->dev_ctx_.SetBlob(c0_key, memory_p); + } + return memory_p; + } +}; + +template +void RunKernel(const Context& dev_ctx, + const DenseTensor& x_in, + const DenseTensor& weight_x_in, + const DenseTensor& weight_h_in, + const DenseTensor& bias_in, + const paddle::optional& h0_in, + const paddle::optional& c0_in, + bool use_peepholes, + bool is_reverse, + bool use_seq, + const std::string& gate_activation, + const std::string& cell_activation, + const std::string& candidate_activation, + float scale_data, + float shift_data, + const std::vector& scale_weights, + bool force_fp32_output, + DenseTensor* hidden, + DenseTensor* cell, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_hidden, + DenseTensor* batched_cell, + DenseTensor* reordered_h0, + DenseTensor* reordered_c0, + DenseTensor* checked_cell) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + // Get Tensors + const auto* input = &x_in; + const auto* h0 = h0_in.get_ptr(); + const auto* c0 = c0_in.get_ptr(); + const auto* weight_x = &weight_x_in; + const auto* weight_h = &weight_h_in; + const auto* bias = &bias_in; + + auto x_dims = input->dims(); + auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) + ? common::flatten_to_2d(x_dims, 1) + : x_dims; + + // Get tensor dimensions + const auto x_mat_dims_vec = common::vectorize(x_mat_dims); + const auto weight_h_dims = common::vectorize(weight_h->dims()); + const auto& input_lod = input->lod()[0]; + + // Calculate RNN dimensions + const int64_t N = input_lod.size() - 1; // Number of sentences (batches) + const int64_t Ti = // Max length of the sentence in a batch + [&input_lod]() { + size_t res = 0; + for (size_t i = 0; i < (input_lod.size() - 1); ++i) { + res = std::max(res, input_lod[i + 1] - input_lod[i]); + } + return res; + }(); + const int64_t IC = x_mat_dims_vec[1]; // Input channels + const int64_t OC = weight_h_dims[0]; // Output channels + + LSTMMKLDNNHandler handler(dev_ctx, + onednn_engine, + dev_ctx.GetPlace(), + input, + weight_h, + h0, + c0, + is_reverse, + N, + Ti, + IC, + OC, + "x_weight_h", + scale_data, + shift_data, + scale_weights, + use_peepholes, + gate_activation, + cell_activation, + candidate_activation); + + auto input_memory_p = + handler.AcquireInputMemoryWithReorder(input, is_reverse); + auto c0_memory_p = handler.AcquireC0Memory(c0); + + std::shared_ptr h0_memory_p, weight_h_memory_p, + weight_x_memory_p; + + if (weight_h->dtype() == phi::DataType::FLOAT32) { + h0_memory_p = handler.template AcquireH0Memory(h0); + weight_x_memory_p = handler.template AcquireWeightXMemory(weight_x); + weight_h_memory_p = handler.template AcquireWeightHMemory(weight_h); + } else if (weight_h->dtype() == phi::DataType::BFLOAT16) { + h0_memory_p = handler.template AcquireH0Memory(h0); + weight_x_memory_p = + handler.template AcquireWeightXMemory(weight_x); + weight_h_memory_p = + handler.template AcquireWeightHMemory(weight_h); + } else { + h0_memory_p = handler.template AcquireH0Memory(h0); + weight_x_memory_p = handler.template AcquireWeightXMemory(weight_x); + weight_h_memory_p = handler.template AcquireWeightHMemory(weight_h); + } + + auto bias_memory_p = handler.AcquireBiasMemory(bias); + auto hidden_onednn_memory_p = handler.AcquireOutputMemory(); + + std::unordered_map lstm_args = { + {DNNL_ARG_SRC_LAYER, *input_memory_p}, + {DNNL_ARG_SRC_ITER, *h0_memory_p}, + {DNNL_ARG_SRC_ITER_C, *c0_memory_p}, + {DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p}, + {DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p}, + {DNNL_ARG_BIAS, *bias_memory_p}, + {DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}}; + + if (use_peepholes) { + auto peephole_weight_p = handler.AcquirePeepholeWeights(bias); + std::pair peepholes_weights(DNNL_ARG_WEIGHTS_PEEPHOLE, + *peephole_weight_p); + lstm_args.insert(peepholes_weights); + } + + auto lstm_forward_p = handler.AcquireForwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + lstm_forward_p->execute(astream, lstm_args); + astream.wait(); + + auto* hidden_onednn_data = hidden_onednn_memory_p->get_data_handle(); + auto* hidden_data = + phi::funcs::to_void_cast(dev_ctx.template Alloc(hidden)); + if (handler.is_NTC()) { + handler.reorderRNNdata(hidden_onednn_data, + hidden_data, + input_lod, + is_reverse, + RNNReorderType::NTC_PP); + } else { + handler.reorderRNNdata(hidden_onednn_data, + hidden_data, + input_lod, + is_reverse, + RNNReorderType::TNC_PP); + } +} + +template +void FusionLSTMMKLDNNKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight_x, + const DenseTensor& weight_h, + const DenseTensor& bias, + const paddle::optional& h0, + const paddle::optional& c0, + bool use_peepholes, + bool is_reverse, + bool use_seq, + const std::string& gate_activation, + const std::string& cell_activation, + const std::string& candidate_activation, + float scale_data, + float shift_data, + const std::vector& scale_weights, + bool force_fp32_output, + DenseTensor* hidden, + DenseTensor* cell, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_hidden, + DenseTensor* batched_cell, + DenseTensor* reordered_h0, + DenseTensor* reordered_c0, + DenseTensor* checked_cell) { + const bool is_bf16 = std::is_same::value; + + // BF16 does not support force output + if (!is_bf16 && force_fp32_output) { // NOLINT + RunKernel(dev_ctx, + x, + weight_x, + weight_h, + bias, + h0, + c0, + use_peepholes, + is_reverse, + use_seq, + gate_activation, + cell_activation, + candidate_activation, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + hidden, + cell, + xx, + batched_input, + batched_hidden, + batched_cell, + reordered_h0, + reordered_c0, + checked_cell); + } else { + RunKernel(dev_ctx, + x, + weight_x, + weight_h, + bias, + h0, + c0, + use_peepholes, + is_reverse, + use_seq, + gate_activation, + cell_activation, + candidate_activation, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + hidden, + cell, + xx, + batched_input, + batched_hidden, + batched_cell, + reordered_h0, + reordered_c0, + checked_cell); + } +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fusion_lstm, + OneDNN, + ONEDNN, + phi::fusion::FusionLSTMMKLDNNKernel, + float, + uint8_t, + phi::dtype::bfloat16) {} diff --git a/paddle/fluid/operators/fused/onednn/fusion_rnn_onednn.h b/paddle/phi/kernels/fusion/onednn/fusion_rnn_onednn.h similarity index 87% rename from paddle/fluid/operators/fused/onednn/fusion_rnn_onednn.h rename to paddle/phi/kernels/fusion/onednn/fusion_rnn_onednn.h index c04dd0cebeec0..d429f0b3944bb 100644 --- a/paddle/fluid/operators/fused/onednn/fusion_rnn_onednn.h +++ b/paddle/phi/kernels/fusion/onednn/fusion_rnn_onednn.h @@ -1,24 +1,24 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "paddle/fluid/framework/op_registry.h" #include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" -namespace paddle { -namespace operators { +namespace phi { +namespace fusion { using phi::funcs::CreateKey; using phi::funcs::OneDNNGetDataType; @@ -28,8 +28,7 @@ using OneDNNMemoryFormat = dnnl::memory::format_tag; template class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT { public: - RNNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, - const phi::OneDNNContext& dev_ctx, + RNNMKLDNNHandler(const phi::OneDNNContext& dev_ctx, const dnnl::engine onednn_engine UNUSED, phi::Place cpu_place, const phi::DenseTensor* input UNUSED, @@ -41,7 +40,10 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT { const int64_t IC, const int64_t OC, const int64_t G, - const std::string& unique_name) + const std::string& unique_name, + float scale_data, + float shift_data, + std::vector scale_weights) : phi::funcs::OneDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), @@ -62,9 +64,6 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT { if (is_INT8) { // Int8 attributes - const float scale_data = ctx.Attr("Scale_data"); - const float shift_data = ctx.Attr("Shift_data"); - const auto scale_weights = ctx.Attr>("Scale_weights"); const int weights_scale_mask = 0 + @@ -237,5 +236,5 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT { std::string memory_key_; dnnl::primitive_attr attr_; }; -} // namespace operators -} // namespace paddle +} // namespace fusion +} // namespace phi diff --git a/paddle/phi/ops/yaml/fused_ops.yaml b/paddle/phi/ops/yaml/fused_ops.yaml index a96fe7416facf..3c244b6f4625d 100644 --- a/paddle/phi/ops/yaml/fused_ops.yaml +++ b/paddle/phi/ops/yaml/fused_ops.yaml @@ -457,6 +457,17 @@ optional : h0, bias intermediate : reordered_h0, xx, batched_input, batched_out +- op : fusion_lstm + args : (Tensor x, Tensor weight_x, Tensor weight_h, Tensor bias, Tensor h0, Tensor c0, bool use_peepholes=true, bool is_reverse=false, bool use_seq=true, str gate_activation="sigmoid", str cell_activation="tanh", str candidate_activation="tanh", float scale_data=1.0, float shift_data=0.0, float[] scale_weights={1.0}, bool force_fp32_output=false) + output : Tensor(hidden), Tensor(cell), Tensor(xx), Tensor(batched_input), Tensor(batched_hidden), Tensor(batched_cell), Tensor(reordered_h0), Tensor(reordered_c0), Tensor(checked_cell) + infer_meta : + func : FusionLstmInferMeta + kernel : + func : fusion_lstm + data_type : x + optional : h0, c0 + intermediate : xx, batched_input, batched_hidden, batched_cell, reordered_h0, reordered_c0, checked_cell + - op : fusion_repeated_fc_relu args : (Tensor x, Tensor[] w, Tensor[] bias) output : Tensor[](relu_out){w.size()-1}, Tensor(out) diff --git a/paddle/phi/ops/yaml/inconsistent/onednn_static.yaml b/paddle/phi/ops/yaml/inconsistent/onednn_static.yaml index 282dd35cb3453..386eadf0c1dc6 100644 --- a/paddle/phi/ops/yaml/inconsistent/onednn_static.yaml +++ b/paddle/phi/ops/yaml/inconsistent/onednn_static.yaml @@ -91,17 +91,6 @@ kernel : func : fused_transpose -- op : fusion_lstm - args : (Tensor x, Tensor weight_x, Tensor weight_h, Tensor bias, Tensor h0, Tensor c0, bool use_peepholes=true, bool is_reverse=false, bool use_seq=true, str gate_activation="sigmoid", str cell_activation="tanh", str candidate_activation="tanh", float scale_data=1.0, float shift_data=0.0, float[] scale_weights={1.0}, bool force_fp32_output=false) - output : Tensor(hidden), Tensor(cell), Tensor(xx), Tensor(batched_input), Tensor(batched_hidden), Tensor(batched_cell), Tensor(reordered_h0), Tensor(reordered_c0), Tensor(checked_cell) - infer_meta : - func : FusionLstmInferMeta - kernel : - func : fusion_lstm - data_type : x - optional : h0, c0 - intermediate : xx, batched_input, batched_hidden, batched_cell, reordered_h0, reordered_c0, checked_cell - - op: multi_gru args: (Tensor x, Tensor[] weight_x, Tensor[] weight_h, Tensor[] bias, Tensor[] scale_weights, str activation="tanh", str gate_activation="sigmoid", int layers=1, bool origin_mode=false, str mkldnn_data_type="float32", float scale_data=1.0, float shift_data=1.0, bool force_fp32_output=false) output: Tensor(hidden)