diff --git a/paddle/fluid/operators/lu_op.cc b/paddle/fluid/operators/lu_op.cc index 67bc9ba4fe774..c6831f975c47a 100644 --- a/paddle/fluid/operators/lu_op.cc +++ b/paddle/fluid/operators/lu_op.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/lu_op.h" #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" diff --git a/paddle/fluid/operators/lu_op.h b/paddle/fluid/operators/lu_op.h deleted file mode 100644 index 1122937b6efe3..0000000000000 --- a/paddle/fluid/operators/lu_op.h +++ /dev/null @@ -1,528 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/phi_utils.h" -#include "paddle/fluid/operators/set_value_op.h" -#include "paddle/fluid/operators/svd_helper.h" -#include "paddle/phi/kernels/elementwise_add_kernel.h" -#include "paddle/phi/kernels/elementwise_subtract_kernel.h" -#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" -#include "paddle/phi/kernels/funcs/tril_triu_compute.h" -#include "paddle/phi/kernels/triangular_solve_kernel.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensorArray = framework::LoDTensorArray; - -template -void SetValueCompute(const framework::ExecutionContext& ctx, - framework::Tensor* in, - framework::Tensor* value_tensor, - framework::Tensor* out, - const std::vector& axes, - std::vector* starts, - std::vector* ends, - const std::vector& shape) { - std::vector steps = {1, 1}; - std::vector decrease_axes = {}; - std::vector none_axes = {}; - - auto dtype = framework::TransToProtoVarType(in->dtype()); - - auto in_dims = in->dims(); - phi::funcs::CheckAndUpdateSliceAttrs( - in_dims, axes, starts, ends, &steps); - auto slice_dims = - phi::funcs::GetSliceDims(in_dims, axes, *starts, *ends, &steps); - auto decrease_slice_dims = - phi::funcs::GetDecreasedDims(slice_dims, decrease_axes); - - auto slice_dims_for_assign = decrease_slice_dims; - if (!none_axes.empty()) { - std::vector slice_dims_with_none; - - size_t none_axes_cur = 0, decrease_axes_cur = 0; - for (int i = 0; i < slice_dims.size(); ++i) { - while (none_axes_cur < none_axes.size() && - none_axes[none_axes_cur] <= i) { - slice_dims_with_none.push_back(1); - none_axes_cur++; - } - if (decrease_axes_cur < decrease_axes.size() && - decrease_axes[decrease_axes_cur] == i) { - decrease_axes_cur++; - } else { - slice_dims_with_none.push_back(slice_dims[i]); - } - } - while (none_axes_cur < none_axes.size()) { - slice_dims_with_none.push_back(1); - none_axes_cur++; - } - - slice_dims_for_assign = phi::make_ddim(slice_dims_with_none); - } - - auto place = ctx.GetPlace(); - auto& eigen_place = - *ctx.template device_context().eigen_device(); - - // Here copy data from input to avoid data loss at PE and Graph level. - // TODO(liym27): Speed up in the future version. - // - Q: Why don't call ShareDataWith to speed up? - // - A: Because it's not supported to ShareDataWith on OP's input and output - // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP - // - Q: Why don't delete Input, after all, the input and output are the same - // Tensor at program level? - // - A: If deleting Input, the graph will be complex, such as there will - // be two ops points to the output in graph: op1 -> output <- set_value. - // In this case, we have to find a way to handle the running order of - // set_value is what we want. - paddle::framework::TensorCopy(*in, place, out); - - Tensor slice_tensor(framework::TransToPhiDataType(dtype)), - pad_tensor(framework::TransToPhiDataType(dtype)); - slice_tensor.mutable_data(slice_dims, place); - pad_tensor.mutable_data(in_dims, place); - - auto pad_e = framework::EigenTensor::From(pad_tensor, in_dims); - auto out_e = framework::EigenTensor::From(*out); - auto slice_e = framework::EigenTensor::From(slice_tensor, slice_dims); - - // Step 1: Set the value of out at `_index` to zero - slice_e.device(eigen_place) = slice_e.constant(T(0)); - - auto starts_indices = Eigen::DSizes(); - auto ends_indices = Eigen::DSizes(); - auto strides_indices = Eigen::DSizes(); - - for (size_t i = 0; i < D; ++i) { - starts_indices[i] = 0; - ends_indices[i] = slice_dims[i]; - strides_indices[i] = 1; - } - for (size_t i = 0; i < axes.size(); i++) { - int axis_index = axes[i]; - starts_indices[axis_index] = (*starts)[i]; - ends_indices[axis_index] = (*ends)[i]; - strides_indices[axis_index] = steps[i]; - if ((*starts)[i] == - (*ends)[i]) { // slice is empty, data will not be changed - return; - } - } - - out_e.stridedSlice(starts_indices, ends_indices, strides_indices) - .device(eigen_place) = slice_e; - - // Step 2: Set a tensor with the same shape as out tensor. And its data at - // '_index' is the same as value_tensor, and data out of '_index' to zero - - // - Step 2.1 Set slice tensor with value - - // NOTE(liym27): [ Why resize slice_tensor here? ] - // A: When do broadcasting on slice_tensor and value_tensor, the shape of - // slice_tensor should be decreased dims. - // e.g. - // x[:,0] = value_tensor - // x's shape = [3, 4], value_tensor's shape = [3] - // We get slice_dims = [3, 1], decrease_slice_dims = [3] - // If do broadcasting on Tensor with shape [3, 1] and [3], the result's - // shape is [3, 3], which cross the border; - // If do broadcasting on Tensor with shape [3] and [3], the result's shape - // is [3], which is right. - - slice_tensor.Resize(slice_dims_for_assign); - if (value_tensor != nullptr) { - CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims()); - // ElementwiseComputeEx can do broadcasting - ElementwiseComputeEx, DeviceContext, T>( - ctx, &slice_tensor, value_tensor, -1, SubFunctor(), &slice_tensor); - } else { - Tensor value_t(framework::TransToPhiDataType(dtype)); - auto value_dims = phi::make_ddim(shape); - CheckIsDimsMatch(slice_dims_for_assign, value_dims); - - value_t.mutable_data(value_dims, place); - auto value_name = GetValueName(dtype); - CopyVectorToTensor(value_name.c_str(), &value_t, ctx); - value_t.Resize(value_dims); - ElementwiseComputeEx, DeviceContext, T>( - ctx, &slice_tensor, &value_t, -1, SubFunctor(), &slice_tensor); - } - slice_tensor.Resize(slice_dims); - - // - Step 2.2 Pad slice tensor with 0 - pad_e.device(eigen_place) = pad_e.constant(T(0)); - pad_e.stridedSlice(starts_indices, ends_indices, strides_indices) - .device(eigen_place) = slice_e; - - // Step 3: Set out tensor with value_tensor - out_e.device(eigen_place) = out_e - pad_e; -} - -template -void SetValueCompute_dispatch(const framework::ExecutionContext& ctx, - framework::Tensor* in, - framework::Tensor* value_tensor, - framework::Tensor* out, - const std::vector& axes, - std::vector* starts, - std::vector* ends, - const std::vector& shape, - int rank) { - switch (rank) { - case 1: - SetValueCompute( - ctx, in, value_tensor, out, axes, starts, ends, shape); - break; - case 2: - SetValueCompute( - ctx, in, value_tensor, out, axes, starts, ends, shape); - break; - case 3: - SetValueCompute( - ctx, in, value_tensor, out, axes, starts, ends, shape); - break; - case 4: - SetValueCompute( - ctx, in, value_tensor, out, axes, starts, ends, shape); - break; - case 5: - SetValueCompute( - ctx, in, value_tensor, out, axes, starts, ends, shape); - break; - case 6: - SetValueCompute( - ctx, in, value_tensor, out, axes, starts, ends, shape); - break; - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "The rank of input should be less than 7, but received %d.", rank)); - } -} - -template -void Tensor_Conj(const DeviceContext& dev_ctx, - const framework::Tensor& tensor, - framework::Tensor* out) { - out->Resize(tensor.dims()); - platform::ForRange out_for_range(dev_ctx, tensor.numel()); - phi::funcs::ConjFunctor out_functor( - tensor.data(), - tensor.numel(), - out->mutable_data(dev_ctx.GetPlace())); - out_for_range(out_functor); -} - -template -void Tensor_Add(const DeviceContext& dev_ctx, - const framework::Tensor& src1, - const framework::Tensor& src2, - framework::Tensor* out) { - out->Resize(src1.dims()); - out->mutable_data(dev_ctx.GetPlace()); - - phi::AddRawKernel< - T, - typename paddle::framework::ConvertToPhiContext::TYPE>( - static_cast::TYPE&>(dev_ctx), - src1, - src2, - -1, - out); -} - -template -void Tensor_Sub(const DeviceContext& dev_ctx, - const framework::Tensor& src1, - const framework::Tensor& src2, - framework::Tensor* out) { - out->Resize(src1.dims()); - out->mutable_data(dev_ctx.GetPlace()); - - phi::SubtractRawKernel< - T, - typename paddle::framework::ConvertToPhiContext::TYPE>( - static_cast::TYPE&>(dev_ctx), - src1, - src2, - -1, - out); -} - -template -void SliceCompute(const framework::ExecutionContext& ctx, - const framework::Tensor* in, - framework::Tensor* out, - const std::vector& axes_int, - const std::vector& starts_int, - const std::vector& ends_int) { - std::vector axes(axes_int.begin(), axes_int.end()); - std::vector starts(starts_int.begin(), starts_int.end()); - std::vector ends(ends_int.begin(), ends_int.end()); - - std::vector decrease_axis = {}; - std::vector infer_flags = {}; - - PADDLE_ENFORCE_EQ( - starts.size(), - axes.size(), - platform::errors::InvalidArgument( - "The size of starts must be equal to the size of axes.")); - PADDLE_ENFORCE_EQ(ends.size(), - axes.size(), - platform::errors::InvalidArgument( - "The size of ends must be equal to the size of axes.")); - - // Step 2: Compute output - - auto in_dims = in->dims(); - auto out_dims = out->dims(); - auto slice_dims = out_dims; - - // 2.1 Infer output dims - for (size_t i = 0; i < axes.size(); ++i) { - // when start == -1 && end == start+1 - if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) { - auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); - if (ret != decrease_axis.end()) { - ends[i] = in_dims[axes[i]]; - } - } - } - - phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends); - slice_dims = phi::funcs::GetSliceDims( - in_dims, axes, starts, ends, nullptr, nullptr); - out_dims = phi::funcs::GetDecreasedDims(slice_dims, decrease_axis); - - // 2.2 Get output - auto offsets = Eigen::DSizes(); - auto extents = Eigen::DSizes(); - - for (size_t i = 0; i < D; ++i) { - offsets[i] = 0; - extents[i] = slice_dims[i]; - } - for (size_t i = 0; i < axes.size(); ++i) { - offsets[axes[i]] = starts[i]; - } - - out->Resize(slice_dims); - out->mutable_data(ctx.GetPlace()); - - auto in_t = framework::EigenTensor::From(*in, in_dims); - auto out_t = framework::EigenTensor::From(*out, slice_dims); - auto& eigen_place = - *ctx.template device_context().eigen_device(); - - if (in->numel() <= Eigen::NumTraits::highest()) { - // similar to tf.slice: - // if element number less than INT_MAX, change the type of index to int - Eigen::DSizes offsets_32bit, extents_32bit; - for (size_t i = 0; i < D; i++) { - offsets_32bit[i] = offsets[i]; - extents_32bit[i] = extents[i]; - } - EigenSlice, T, D>::Eval( - eigen_place, - framework::To32BitIndex(out_t), - framework::To32BitIndex(in_t), - offsets_32bit, - extents_32bit); - } else { - EigenSlice, T, D>::Eval( - eigen_place, out_t, in_t, offsets, extents); - } - - out->Resize(out_dims); - out->mutable_data(ctx.GetPlace()); -} - -template -void Tensor_narrow(const framework::ExecutionContext& ctx, - const framework::Tensor* src, - framework::Tensor* out, - int row_s, - int row_e, - int col_s, - int col_e) { - auto rank = src->dims().size(); - std::vector axes_int = {rank - 2, rank - 1}; - std::vector starts_int = {row_s, col_s}; - std::vector ends_int = {row_e, col_e}; - switch (rank) { - case 1: - SliceCompute( - ctx, src, out, axes_int, starts_int, ends_int); - break; - case 2: - SliceCompute( - ctx, src, out, axes_int, starts_int, ends_int); - break; - case 3: - SliceCompute( - ctx, src, out, axes_int, starts_int, ends_int); - break; - case 4: - SliceCompute( - ctx, src, out, axes_int, starts_int, ends_int); - break; - case 5: - SliceCompute( - ctx, src, out, axes_int, starts_int, ends_int); - break; - case 6: - SliceCompute( - ctx, src, out, axes_int, starts_int, ends_int); - break; - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "The rank of input should be less than 7, but received %d.", rank)); - } -} - -template -void arange(const DeviceContext& dev_ctx, - framework::Tensor* tmp, - int w, - int batchsize = 1, - int h = 1) { - tmp->Resize(phi::make_ddim({batchsize * w})); - platform::CPUPlace cpu; - auto tmpdata = tmp->mutable_data(cpu); - for (int b = 0; b < batchsize; b++) { - for (int i = 0; i < w; i++) { - tmpdata[b * w + i] = static_cast(b * h + i); - } - } -} - -template -struct OneFunctor { - OneFunctor(T* output, int* idtptr, int w, int dim) - : output_(output), idtptr_(idtptr), w_(w), dim_(dim) {} - - HOSTDEVICE void operator()(size_t idx) const { - output_[w_ * idtptr_[idx] + idx % dim_] = static_cast(1); - } - - T* output_; - int* idtptr_; - int w_; - int dim_; -}; - -template -void LU_Unpack(const DeviceContext& dev_ctx, - const framework::Tensor* LU, - framework::Tensor* L, - framework::Tensor* U) { - const auto udims = LU->dims(); - L->Resize(udims); - U->Resize(udims); - const auto H = udims[udims.size() - 2]; - const auto W = udims[udims.size() - 1]; - auto L_dataptr = L->mutable_data(dev_ctx.GetPlace()); - platform::ForRange x_for_range(dev_ctx, LU->numel()); - phi::funcs::TrilTriuCompute tril_computer( - LU->data(), -1, true, H, W, L_dataptr); - x_for_range(tril_computer); - - phi::funcs::TrilTriuCompute triu_computer( - LU->data(), 0, false, H, W, U->mutable_data(dev_ctx.GetPlace())); - x_for_range(triu_computer); - - // set L's diagonal 1 - auto dim = std::min(H, W); - framework::Tensor rowtensor, rt_dev; - auto batchsize = product(phi::slice_ddim(udims, 0, udims.size() - 2)); - batchsize = std::max(static_cast(batchsize), 1); - arange(dev_ctx, &rowtensor, dim, batchsize, H); - auto idtptr = rowtensor.data(); - if (platform::is_gpu_place(dev_ctx.GetPlace())) { - framework::TensorCopy(rowtensor, dev_ctx.GetPlace(), &rt_dev); - idtptr = rt_dev.data(); - } - - platform::ForRange for_range(dev_ctx, rowtensor.numel()); - OneFunctor functor(L_dataptr, idtptr, W, dim); - for_range(functor); -} - -template -void scatterpivot(const DeviceContext& dev_ctx, - T* out_data, - framework::Tensor* idlst, - int w, - int dim) { - framework::Tensor idlst_tmp; - idlst_tmp.Resize(idlst->dims()); - idlst_tmp.mutable_data(dev_ctx.GetPlace()); - framework::TensorCopy(*idlst, dev_ctx.GetPlace(), &idlst_tmp); - auto idtptr = idlst_tmp.data(); - - platform::ForRange for_range(dev_ctx, idlst_tmp.numel()); - OneFunctor functor(out_data, idtptr, w, dim); - for_range(functor); -} - -template -void Unpack_Pivot(const DeviceContext& dev_ctx, - const framework::Tensor& Pivot, - framework::Tensor* P, - int h, - int w) { - auto dims = Pivot.dims(); - auto Pdimvec = vectorize(dims); - auto prank = Pdimvec.size(); - auto Pnum = dims[prank - 1]; - framework::Tensor Pivot_cpu; - platform::CPUPlace cpu; - framework::TensorCopy(Pivot, cpu, &Pivot_cpu); - auto pdataptr = Pivot_cpu.data(); - Pdimvec[prank - 1] = h; - Pdimvec.emplace_back(h); - auto Pdim = phi::make_ddim(Pdimvec); - P->Resize(Pdim); - auto pdata = P->mutable_data(dev_ctx.GetPlace()); - phi::funcs::SetConstant setter; - setter(dev_ctx, P, static_cast(0)); - - auto batchsize = product(phi::slice_ddim(dims, 0, prank - 1)); - batchsize = std::max(static_cast(batchsize), 1); - framework::Tensor idt; - for (int i = 0; i < batchsize; i++) { - arange(dev_ctx, &idt, h); - auto idlst = idt.data(); - for (int j = 0; j < Pnum; j++) { - if (idlst[pdataptr[i * Pnum + j] - 1] == idlst[j]) continue; - auto temp = idlst[j]; - idlst[j] = idlst[pdataptr[i * Pnum + j] - 1]; - idlst[pdataptr[i * Pnum + j] - 1] = temp; - } - scatterpivot(dev_ctx, &(pdata[i * h * h]), &idt, h, h); - } -} - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/lu_unpack_op.cc b/paddle/fluid/operators/lu_unpack_op.cc index 4c6b37ed3e55e..988cba43989e9 100644 --- a/paddle/fluid/operators/lu_unpack_op.cc +++ b/paddle/fluid/operators/lu_unpack_op.cc @@ -12,7 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/lu_unpack_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" + +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -42,44 +46,6 @@ class LU_UnpackOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *context) const override { - OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "LU_Unpack"); - OP_INOUT_CHECK(context->HasInput("Pivots"), "Input", "Pivots", "LU_Unpack"); - OP_INOUT_CHECK(context->HasOutput("L"), "Output", "L", "LU_Unpack"); - OP_INOUT_CHECK(context->HasOutput("U"), "Output", "U", "LU_Unpack"); - OP_INOUT_CHECK(context->HasOutput("Pmat"), "Output", "Pmat", "LU_Unpack"); - bool unpack_ludata = context->Attrs().Get("unpack_ludata"); - bool unpack_pivots = context->Attrs().Get("unpack_pivots"); - - auto x_dims = context->GetInputDim("X"); - int x_rank = x_dims.size(); - PADDLE_ENFORCE_GE(x_rank, - 2, - platform::errors::InvalidArgument( - "the rank of input must greater than 2")); - - // context->SetOutputDim("Out", x_dims); - int m = x_dims[x_rank - 1]; - int n = x_dims[x_rank - 2]; - int min_mn = std::min(m, n); - if (unpack_ludata) { - auto ldims = x_dims; - auto udims = x_dims; - if (m >= n) { - udims[x_rank - 2] = min_mn; - } else { - ldims[x_rank - 1] = min_mn; - } - context->SetOutputDim("U", udims); - context->SetOutputDim("L", ldims); - } - if (unpack_pivots) { - auto pdims = x_dims; - pdims[x_rank - 1] = m; - context->SetOutputDim("Pmat", pdims); - } - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -143,25 +109,6 @@ class LU_UnpackGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lu_unpack"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("L")), - "Input", - "L@GRAD", - "lu_unpack"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("U")), - "Input", - "U@GRAD", - "lu_unpack"); - - auto x_dims = ctx->GetInputDim("X"); - auto x_grad_name = framework::GradVarName("X"); - - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); - } - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -175,19 +122,21 @@ class LU_UnpackGradOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; +DECLARE_INFER_SHAPE_FUNCTOR(lu_unpack, + LUUnpackInferMetaFunctor, + PD_INFER_META(phi::LUUnpackInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(lu_unpack_grad, + LUUnpackGradInferMetaFunctor, + PD_INFER_META(phi::LUUnpackGradInferMeta)); + REGISTER_OPERATOR(lu_unpack, ops::LU_UnpackOp, ops::LU_UnpackOpMaker, ops::LU_UnpackOpVarTypeInference, ops::LU_UnpackOpGradMaker, - ops::LU_UnpackOpGradMaker); + ops::LU_UnpackOpGradMaker, + LUUnpackInferMetaFunctor); REGISTER_OPERATOR(lu_unpack_grad, ops::LU_UnpackGradOp, - ops::LU_UnpackGradOpVarTypeInference); - -REGISTER_OP_CPU_KERNEL(lu_unpack, - ops::LU_UnpackKernel, - ops::LU_UnpackKernel); -REGISTER_OP_CPU_KERNEL(lu_unpack_grad, - ops::LU_UnpackGradKernel, - ops::LU_UnpackGradKernel); + ops::LU_UnpackGradOpVarTypeInference, + LUUnpackGradInferMetaFunctor); diff --git a/paddle/fluid/operators/lu_unpack_op.cu b/paddle/fluid/operators/lu_unpack_op.cu deleted file mode 100644 index 18d9c13eceea6..0000000000000 --- a/paddle/fluid/operators/lu_unpack_op.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/lu_unpack_op.h" -#include "paddle/fluid/memory/memory.h" - -namespace paddle { -namespace operators {} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL(lu_unpack, - ops::LU_UnpackKernel, - ops::LU_UnpackKernel); -REGISTER_OP_CUDA_KERNEL( - lu_unpack_grad, - ops::LU_UnpackGradKernel, - ops::LU_UnpackGradKernel); diff --git a/paddle/fluid/operators/lu_unpack_op.h b/paddle/fluid/operators/lu_unpack_op.h deleted file mode 100644 index 559c13c9ee6e2..0000000000000 --- a/paddle/fluid/operators/lu_unpack_op.h +++ /dev/null @@ -1,159 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/lu_op.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/tril_triu_compute.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensorArray = framework::LoDTensorArray; - -template -class LU_UnpackKernel : public framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - auto xin = ctx.Input("X"); - auto P = ctx.Input("Pivots"); - - auto ltensor = ctx.Output("L"); - auto utensor = ctx.Output("U"); - auto ptensor = ctx.Output("Pmat"); - - auto unpack_ludata = ctx.Attr("unpack_ludata"); - auto unpack_pivots = ctx.Attr("unpack_pivots"); - - const auto& dev_ctx = ctx.template device_context(); - - auto xdims = xin->dims(); - int xrank = xdims.size(); - int64_t m = xdims[xrank - 2]; - int64_t n = xdims[xrank - 1]; - int64_t k = std::min(m, n); - - if (unpack_ludata) { - ltensor->mutable_data(ctx.GetPlace()); - utensor->mutable_data(ctx.GetPlace()); - - framework::Tensor L, U; - LU_Unpack(dev_ctx, xin, &L, &U); - - if (m >= n) { - framework::TensorCopy(L, ctx.GetPlace(), ltensor); - Tensor_narrow(ctx, &U, utensor, 0, k, 0, k); - } else { - framework::TensorCopy(U, ctx.GetPlace(), utensor); - Tensor_narrow(ctx, &L, ltensor, 0, k, 0, k); - } - } - - if (unpack_pivots) { - ptensor->mutable_data(ctx.GetPlace()); - Unpack_Pivot(dev_ctx, *P, ptensor, m, k); - } - } -}; - -template -class LU_UnpackGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto dl = ctx.Input(framework::GradVarName("L")); - auto du = ctx.Input(framework::GradVarName("U")); - auto dx = ctx.Output(framework::GradVarName("X")); - dx->mutable_data(ctx.GetPlace()); - - const auto& dev_ctx = ctx.template device_context(); - - framework::Tensor dl_tril, du_triu; - const auto ldims = dl->dims(); - dl_tril.Resize(ldims); - auto H = ldims[ldims.size() - 2]; - auto W = ldims[ldims.size() - 1]; - auto L_dataptr = dl_tril.mutable_data(dev_ctx.GetPlace()); - platform::ForRange l_for_range(dev_ctx, dl->numel()); - phi::funcs::TrilTriuCompute tril_computer( - dl->data(), -1, true, H, W, L_dataptr); - l_for_range(tril_computer); - - const auto udims = du->dims(); - du_triu.Resize(udims); - H = udims[udims.size() - 2]; - W = udims[udims.size() - 1]; - auto U_dataptr = du_triu.mutable_data(dev_ctx.GetPlace()); - platform::ForRange u_for_range(dev_ctx, du->numel()); - phi::funcs::TrilTriuCompute triu_computer( - du->data(), 0, false, H, W, U_dataptr); - u_for_range(triu_computer); - - auto xdims = dx->dims(); - int xrank = xdims.size(); - int64_t m = xdims[xrank - 2]; - int64_t n = xdims[xrank - 1]; - int64_t k = std::min(m, n); - - std::vector axes = {xrank - 2, xrank - 1}; - std::vector slice_starts(2, 0); - std::vector slice_ends(2, 0); - auto valuedims = vectorize(xdims); - - phi::funcs::SetConstant setter; - setter(dev_ctx, dx, static_cast(0)); - if (m <= n) { - slice_starts[0] = 0; - slice_starts[1] = 0; - slice_ends[0] = k; - slice_ends[1] = k; - valuedims[xrank - 2] = k; - valuedims[xrank - 1] = k; - SetValueCompute_dispatch(ctx, - dx, - &dl_tril, - dx, - axes, - &slice_starts, - &slice_ends, - valuedims, - xrank); - - Tensor_Add(dev_ctx, *dx, du_triu, dx); - } else { - slice_starts[0] = 0; - slice_starts[1] = 0; - slice_ends[0] = k; - slice_ends[1] = k; - valuedims[xrank - 2] = k; - valuedims[xrank - 1] = k; - SetValueCompute_dispatch(ctx, - dx, - &du_triu, - dx, - axes, - &slice_starts, - &slice_ends, - valuedims, - xrank); - - Tensor_Add(dev_ctx, *dx, dl_tril, dx); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 2b83f79055a34..99e21d1393b42 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1434,6 +1434,16 @@ func : lu backward : lu_grad +- api : lu_unpack + args : (Tensor x, Tensor pivots, bool unpack_ludata, bool unpack_pivots) + output : Tensor(pmat), Tensor(l), Tensor(u) + infer_meta : + func : LUUnpackInferMeta + kernel : + func : lu_unpack + data_type : x + backward : lu_unpack_grad + # masked_select - api : masked_select args : (Tensor x, Tensor mask) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 39addc9421d80..02c7488dcf5a9 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1254,6 +1254,15 @@ kernel : func : lu_grad +- backward_api : lu_unpack_grad + forward : lu_unpack (Tensor x, Tensor pivots, bool unpack_ludata, bool unpack_pivots) -> Tensor(pmat), Tensor(l), Tensor(u) + args : (Tensor x, Tensor pivots, Tensor l, Tensor u, Tensor pmat, Tensor l_grad, Tensor u_grad, bool unpack_ludata, bool unpack_pivots) + output : Tensor(x_grad) + infer_meta : + func : LUUnpackGradInferMeta + kernel : + func : lu_unpack_grad + - backward_api : masked_select_grad forward : masked_select (Tensor x, Tensor mask) -> Tensor(out) args : (Tensor x, Tensor mask, Tensor out_grad) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index a0f71f3b689b1..26e578107206e 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -456,6 +456,24 @@ void LUGradInferMeta(const MetaTensor& x, } } +void LUUnpackGradInferMeta(const MetaTensor& x, + const MetaTensor& pivots, + const MetaTensor& l, + const MetaTensor& u, + const MetaTensor& pmat, + const MetaTensor& l_grad, + const MetaTensor& u_grad, + bool unpack_ludata, + bool unpack_pivots, + MetaTensor* x_grad) { + auto x_dims = x.dims(); + + if (x_grad) { + x_grad->set_dims(x_dims); + x_grad->set_dtype(x.dtype()); + } +} + void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, const MetaTensor& mask, const MetaTensor& dout, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index ae04764d105a8..bc89d84cf2203 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -207,6 +207,17 @@ void LUGradInferMeta(const MetaTensor& x, bool pivot, MetaTensor* x_grad); +void LUUnpackGradInferMeta(const MetaTensor& x, + const MetaTensor& pivots, + const MetaTensor& l, + const MetaTensor& u, + const MetaTensor& pmat, + const MetaTensor& l_grad, + const MetaTensor& u_grad, + bool unpack_ludata, + bool unpack_pivots, + MetaTensor* x_grad); + void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, const MetaTensor& mask, const MetaTensor& dout, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 1463296664b46..52d12d1bb0fcf 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1486,6 +1486,52 @@ void LogLossInferMeta(const MetaTensor& input, out->share_lod(input); } +void LUUnpackInferMeta(const MetaTensor& x, + const MetaTensor& pivots, + bool unpack_ludata, + bool unpack_pivots, + MetaTensor* pmat, + MetaTensor* l, + MetaTensor* u) { + PADDLE_ENFORCE_NOT_NULL( + pmat, + phi::errors::InvalidArgument("Output(Pmat) should not be nullptr.")); + PADDLE_ENFORCE_NOT_NULL( + l, phi::errors::InvalidArgument("Output(L) should not be nullptr.")); + PADDLE_ENFORCE_NOT_NULL( + u, phi::errors::InvalidArgument("Output(U) should not be nullptr.")); + + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + PADDLE_ENFORCE_GE( + x_rank, + 2, + phi::errors::InvalidArgument("The rank of input must greater than 2.")); + + int m = x_dims[x_rank - 1]; + int n = x_dims[x_rank - 2]; + int min_mn = std::min(m, n); + if (unpack_ludata) { + auto ldims = x_dims; + auto udims = x_dims; + if (m >= n) { + udims[x_rank - 2] = min_mn; + } else { + ldims[x_rank - 1] = min_mn; + } + u->set_dims(udims); + u->set_dtype(x.dtype()); + l->set_dims(ldims); + l->set_dtype(x.dtype()); + } + if (unpack_pivots) { + auto pdims = x_dims; + pdims[x_rank - 1] = m; + pmat->set_dims(pdims); + pmat->set_dtype(x.dtype()); + } +} + void MaskedSelectInferMeta(const MetaTensor& x, const MetaTensor& mask, MetaTensor* out) { diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 85851ee705d20..563936f13609b 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -225,6 +225,14 @@ void LogLossInferMeta(const MetaTensor& input, MetaTensor* out, MetaConfig config = MetaConfig()); +void LUUnpackInferMeta(const MetaTensor& x, + const MetaTensor& pivots, + bool unpack_ludata, + bool unpack_pivots, + MetaTensor* pmat, + MetaTensor* l, + MetaTensor* u); + void MaskedSelectInferMeta(const MetaTensor& x, const MetaTensor& mask, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/lu_unpack_grad_kernel.cc b/paddle/phi/kernels/cpu/lu_unpack_grad_kernel.cc new file mode 100644 index 0000000000000..712c43e97ef1f --- /dev/null +++ b/paddle/phi/kernels/cpu/lu_unpack_grad_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/lu_unpack_grad_kernel_impl.h" +#include "paddle/phi/kernels/lu_unpack_grad_kernel.h" + +PD_REGISTER_KERNEL( + lu_unpack_grad, CPU, ALL_LAYOUT, phi::LUUnpackGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/lu_unpack_kernel.cc b/paddle/phi/kernels/cpu/lu_unpack_kernel.cc new file mode 100644 index 0000000000000..bed7da328ffac --- /dev/null +++ b/paddle/phi/kernels/cpu/lu_unpack_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/lu_unpack_kernel_impl.h" +#include "paddle/phi/kernels/lu_unpack_kernel.h" + +PD_REGISTER_KERNEL( + lu_unpack, CPU, ALL_LAYOUT, phi::LUUnpackKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/lu_unpack_grad_kernel.cu b/paddle/phi/kernels/gpu/lu_unpack_grad_kernel.cu new file mode 100644 index 0000000000000..779c4f3facaf3 --- /dev/null +++ b/paddle/phi/kernels/gpu/lu_unpack_grad_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/lu_unpack_grad_kernel_impl.h" +#include "paddle/phi/kernels/lu_unpack_grad_kernel.h" + +PD_REGISTER_KERNEL( + lu_unpack_grad, GPU, ALL_LAYOUT, phi::LUUnpackGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/lu_unpack_kernel.cu b/paddle/phi/kernels/gpu/lu_unpack_kernel.cu new file mode 100644 index 0000000000000..01a9212a59303 --- /dev/null +++ b/paddle/phi/kernels/gpu/lu_unpack_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/lu_unpack_kernel_impl.h" +#include "paddle/phi/kernels/lu_unpack_kernel.h" + +PD_REGISTER_KERNEL( + lu_unpack, GPU, ALL_LAYOUT, phi::LUUnpackKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/lu_unpack_grad_kernel_impl.h b/paddle/phi/kernels/impl/lu_unpack_grad_kernel_impl.h new file mode 100644 index 0000000000000..648e12bb26a48 --- /dev/null +++ b/paddle/phi/kernels/impl/lu_unpack_grad_kernel_impl.h @@ -0,0 +1,110 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/impl/lu_kernel_impl.h" + +namespace phi { + +template +void LUUnpackGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& pivots, + const DenseTensor& l, + const DenseTensor& u, + const DenseTensor& pmat, + const DenseTensor& l_grad, + const DenseTensor& u_grad, + bool unpack_ludata, + bool unpack_pivots, + DenseTensor* x_grad) { + dev_ctx.template Alloc(x_grad); + + DenseTensor dl_tril, du_triu; + const auto ldims = l_grad.dims(); + dl_tril.Resize(ldims); + auto H = ldims[ldims.size() - 2]; + auto W = ldims[ldims.size() - 1]; + dev_ctx.template Alloc(&dl_tril); + auto L_dataptr = dl_tril.data(); + phi::funcs::ForRange l_for_range(dev_ctx, l_grad.numel()); + phi::funcs::TrilTriuCompute tril_computer( + l_grad.data(), -1, true, H, W, L_dataptr); + l_for_range(tril_computer); + + const auto udims = u_grad.dims(); + du_triu.Resize(udims); + H = udims[udims.size() - 2]; + W = udims[udims.size() - 1]; + dev_ctx.template Alloc(&du_triu); + auto U_dataptr = du_triu.data(); + phi::funcs::ForRange u_for_range(dev_ctx, u_grad.numel()); + phi::funcs::TrilTriuCompute triu_computer( + u_grad.data(), 0, false, H, W, U_dataptr); + u_for_range(triu_computer); + + auto xdims = x_grad->dims(); + int xrank = xdims.size(); + int64_t m = xdims[xrank - 2]; + int64_t n = xdims[xrank - 1]; + int64_t k = std::min(m, n); + + std::vector axes = {xrank - 2, xrank - 1}; + std::vector slice_starts(2, 0); + std::vector slice_ends(2, 0); + auto valuedims = vectorize(xdims); + + phi::funcs::SetConstant setter; + setter(dev_ctx, x_grad, static_cast(0)); + if (m <= n) { + slice_starts[0] = 0; + slice_starts[1] = 0; + slice_ends[0] = k; + slice_ends[1] = k; + valuedims[xrank - 2] = k; + valuedims[xrank - 1] = k; + SetValueCompute_dispatch(dev_ctx, + x_grad, + &dl_tril, + x_grad, + axes, + &slice_starts, + &slice_ends, + valuedims, + xrank); + + Tensor_Add(dev_ctx, *x_grad, du_triu, x_grad); + } else { + slice_starts[0] = 0; + slice_starts[1] = 0; + slice_ends[0] = k; + slice_ends[1] = k; + valuedims[xrank - 2] = k; + valuedims[xrank - 1] = k; + SetValueCompute_dispatch(dev_ctx, + x_grad, + &du_triu, + x_grad, + axes, + &slice_starts, + &slice_ends, + valuedims, + xrank); + + Tensor_Add(dev_ctx, *x_grad, dl_tril, x_grad); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/lu_unpack_kernel_impl.h b/paddle/phi/kernels/impl/lu_unpack_kernel_impl.h new file mode 100644 index 0000000000000..7e77fdd171994 --- /dev/null +++ b/paddle/phi/kernels/impl/lu_unpack_kernel_impl.h @@ -0,0 +1,58 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/impl/lu_kernel_impl.h" + +namespace phi { + +template +void LUUnpackKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& pivots, + bool unpack_ludata, + bool unpack_pivots, + DenseTensor* pmat, + DenseTensor* l, + DenseTensor* u) { + auto xdims = x.dims(); + int xrank = xdims.size(); + int64_t m = xdims[xrank - 2]; + int64_t n = xdims[xrank - 1]; + int64_t k = std::min(m, n); + + if (unpack_ludata) { + dev_ctx.template Alloc(l); + dev_ctx.template Alloc(u); + + DenseTensor L, U; + LU_Unpack(dev_ctx, &x, &L, &U); + + if (m >= n) { + phi::Copy(dev_ctx, L, dev_ctx.GetPlace(), false, l); + Tensor_narrow(dev_ctx, &U, u, 0, k, 0, k); + } else { + phi::Copy(dev_ctx, U, dev_ctx.GetPlace(), false, u); + Tensor_narrow(dev_ctx, &L, l, 0, k, 0, k); + } + } + + if (unpack_pivots) { + dev_ctx.template Alloc(pmat); + Unpack_Pivot(dev_ctx, pivots, pmat, m, k); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/lu_unpack_grad_kernel.h b/paddle/phi/kernels/lu_unpack_grad_kernel.h new file mode 100644 index 0000000000000..056f2096d96e5 --- /dev/null +++ b/paddle/phi/kernels/lu_unpack_grad_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LUUnpackGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& pivots, + const DenseTensor& l, + const DenseTensor& u, + const DenseTensor& pmat, + const DenseTensor& l_grad, + const DenseTensor& u_grad, + bool unpack_ludata, + bool unpack_pivots, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/lu_unpack_kernel.h b/paddle/phi/kernels/lu_unpack_kernel.h new file mode 100644 index 0000000000000..48acc3cc566eb --- /dev/null +++ b/paddle/phi/kernels/lu_unpack_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LUUnpackKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& pivots, + bool unpack_ludata, + bool unpack_pivots, + DenseTensor* pmat, + DenseTensor* l, + DenseTensor* u); + +} // namespace phi diff --git a/paddle/phi/ops/compat/lu_unpack_sig.cc b/paddle/phi/ops/compat/lu_unpack_sig.cc new file mode 100644 index 0000000000000..8baafe4fcb23a --- /dev/null +++ b/paddle/phi/ops/compat/lu_unpack_sig.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LUUnpackOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("lu_unpack", + {"X", "Pivots"}, + {"unpack_ludata", "unpack_pivots"}, + {"Pmat", "L", "U"}); +} + +KernelSignature LUUnpackGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("lu_unpack_grad", + {"X", "Pivots", "L", "U", "Pmat", "L@GRAD", "U@GRAD"}, + {"unpack_ludata", "unpack_pivots"}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(lu_unpack, phi::LUUnpackOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(lu_unpack_grad, phi::LUUnpackGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py b/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py index 97773c70e177a..246587fba7151 100644 --- a/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py +++ b/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py @@ -120,6 +120,8 @@ def set_output(self, A): def setUp(self): self.op_type = "lu_unpack" + self.python_api = paddle.tensor.linalg.lu_unpack + self.python_out_sig = ["Pmat", "L", "U"] self.config() x = np.random.random(self.x_shape).astype(self.dtype) if paddle.in_dynamic_mode(): @@ -156,10 +158,10 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], ['L', 'U']) + self.check_grad(['X'], ['L', 'U'], check_eager=True) # m = n diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 7001b6ec57426..49fabc79c4502 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -2200,6 +2200,11 @@ def lu_unpack(x, y, unpack_ludata=True, unpack_pivots=True, name=None): # one can verify : X = P @ L @ U ; """ + if in_dygraph_mode(): + P, L, U = _C_ops.final_state_lu_unpack(x, y, unpack_ludata, + unpack_pivots) + return P, L, U + if paddle.in_dynamic_mode(): P, L, U = _C_ops.lu_unpack(x, y, 'unpack_ludata', unpack_ludata, 'unpack_pivots', unpack_pivots)