diff --git a/paddle/fluid/operators/partial_concat_op_xpu.cc b/paddle/fluid/operators/partial_concat_op_xpu.cc
new file mode 100644
index 0000000000000..c95cfa8aa703d
--- /dev/null
+++ b/paddle/fluid/operators/partial_concat_op_xpu.cc
@@ -0,0 +1,173 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/partial_concat_op.h"
+
+namespace paddle {
+namespace operators {
+using LoDTensor = framework::LoDTensor;
+
+template <typename DeviceContext, typename T>
+class PartialConcatXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto in_vars = ctx.MultiInput<Tensor>("X");
+    Tensor *out = ctx.Output<Tensor>("Out");
+    PADDLE_ENFORCE_EQ(in_vars[0] != nullptr,
+                      true,
+                      platform::errors::InvalidArgument(
+                          "The input of partial concat should not be null."));
+
+    auto input_dim = in_vars[0]->dims();
+    PADDLE_ENFORCE_EQ(input_dim.size(),
+                      2,
+                      platform::errors::InvalidArgument(
+                          "Only supports 2-D array with batch size in the 1st "
+                          "dimension and data in the 2nd."));
+    auto in_size = input_dim[1];
+    // may be negative
+    auto start_index = ctx.Attr<int>("start_index");
+    start_index = ComputeStartIndex(start_index, in_size);
+
+    auto partial_len = ctx.Attr<int>("length");
+    if (partial_len < 0) {
+      partial_len = in_size - start_index;
+    }
+    // TODO: handle the case where partial_len exceeds in_size
+    auto xpu_context =
+        ctx.template device_context<DeviceContext>().x_context();
+
+    int in_num = in_vars.size();
+    int batch_size = input_dim[0];
+
+    std::vector<Tensor> tmp_tensors(in_num);
+    std::vector<const T *> tmp_tensors_data(in_num);
+    std::vector<std::vector<int>> tmp_outs_shape(in_num);
+    // Slice columns [start_index, start_index + partial_len) out of each
+    // input, then concatenate the slices along axis 1.
+    for (size_t i = 0; i < in_vars.size(); i++) {
+      tmp_tensors[i].Resize(phi::make_ddim({batch_size, partial_len}));
+      tmp_tensors_data[i] = tmp_tensors[i].mutable_data<T>(ctx.GetPlace());
+
+      tmp_outs_shape[i] = std::vector<int>({batch_size, partial_len});
+
+      const T *input_data = in_vars[i]->data<T>();
+
+      std::vector<int> xshape = phi::vectorize<int>(in_vars[i]->dims());
+      std::vector<int> starts = {0, start_index};
+      // exclusive end coordinate of the slice in each dimension of x
+      std::vector<int> ends = {batch_size, start_index + partial_len};
+
+      int r = xpu::slice<T>(xpu_context,
+                            input_data,
+                            const_cast<T *>(tmp_tensors_data[i]),
+                            xshape,
+                            starts,
+                            ends);
+      PADDLE_ENFORCE_EQ(
+          r,
+          xpu::Error_t::SUCCESS,
+          platform::errors::External(
+              "The partial_concat XPU OP's slice return wrong value[%d %s]",
+              r,
+              XPUAPIErrorMsg[r]));
+    }
+
+    T *out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    int axis = 1;
+    int r = xpu::concat<T>(xpu_context,
+                           tmp_tensors_data,
+                           out_data,
+                           tmp_outs_shape,
+                           axis);
+    PADDLE_ENFORCE_EQ(
+        r,
+        xpu::Error_t::SUCCESS,
+        platform::errors::External(
+            "The partial_concat XPU OP's concat return wrong value[%d %s]",
+            r,
+            XPUAPIErrorMsg[r]));
+  }
+};
+
+template <typename DeviceContext, typename T>
+class PartialConcatGradXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<LoDTensor>("X");
+    auto xs_grad = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
ctx.MultiOutput(framework::GradVarName("X")); + + PADDLE_ENFORCE_EQ(ins[0] != nullptr, + true, + platform::errors::InvalidArgument( + "The input of partial concat should not be null.")); + // all parameters + int batch_size = ins[0]->dims()[0]; + int in_size = ins[0]->dims()[1]; + // may be negative + auto start_index = ctx.Attr("start_index"); + start_index = ComputeStartIndex(start_index, in_size); + auto partial_len = ctx.Attr("length"); + if (partial_len < 0) { + partial_len = in_size - start_index; + } + + auto in_num = ins.size(); + + auto xpu_context = ctx.template device_context().x_context(); + + std::vector tmp_tensors(in_num); + std::vector tmp_tensors_data(in_num); + + const T* out_grad_data = out_grad->data(); + for (size_t i = 0; i < in_num; i++) { + tmp_tensors[i].Resize(phi::make_ddim({batch_size, partial_len})); + tmp_tensors_data[i] = tmp_tensors[i].mutable_data(ctx.GetPlace()); + + std::vector xshape = phi::vectorize(out_grad->dims()); + std::vector starts = {0, int(partial_len * i)}; + std::vector ends = {batch_size, int(partial_len * i + partial_len + 1)};//要截取的x的每个维度的终止坐标(不包含) + + int r = xpu::slice(xpu_context, + out_grad_data, + const_cast(tmp_tensors_data[i]), + xshape, + starts, + ends); + PADDLE_ENFORCE_EQ( + r, + xpu::Error_t::SUCCESS, + platform::errors::External("The partial_concat_grad XPU OP's slice " + "return wrong value[%d %s]", + r, + XPUAPIErrorMsg[r])); + + std::vector tmp_shape = {batch_size, partial_len}; + std::vector pad_left = {0, start_index}; + std::vector pad_right = {0, in_size - start_index - partial_len}; + T* xs_grad_data = xs_grad[i]->mutable_data(ctx.GetPlace()); + + r = xpu::pad(xpu_context, + tmp_tensors_data[i], + xs_grad_data, + tmp_shape, + pad_left, + pad_right, + T(0)); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The partial_concat_grad XPU OP's pad return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL(partial_concat, ops::PartialConcatXPUKernel) +REGISTER_OP_XPU_KERNEL(partial_concat_grad, ops::PartialConcatGradXPUKernel) diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index df5585b3128ce..d438f0e8d2a2d 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -589,6 +589,10 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"index_select_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"partial_concat", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"partial_concat_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, }; return s_xpu2_kernels; }