diff --git a/paddle/fluid/operators/lu_op.cc b/paddle/fluid/operators/lu_op.cc index 1f569950dad52..67bc9ba4fe774 100644 --- a/paddle/fluid/operators/lu_op.cc +++ b/paddle/fluid/operators/lu_op.cc @@ -13,6 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/lu_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" + +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -39,39 +44,6 @@ class LUOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *context) const override { - OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "LU"); - OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "LU"); - bool pivots = context->Attrs().Get("pivots"); - auto x_dims = context->GetInputDim("X"); - int x_rank = x_dims.size(); - PADDLE_ENFORCE_GE(x_rank, - 2, - platform::errors::InvalidArgument( - "the rank of input must greater than 2")); - context->SetOutputDim("Out", x_dims); - int m = x_dims[x_rank - 1]; - int n = x_dims[x_rank - 2]; - int min_mn = std::min(m, n); - auto dims_vec = phi::vectorize(x_dims); - OP_INOUT_CHECK(context->HasOutput("Infos"), "Output", "Infos", "LU"); - if (x_rank == 2) { - auto Infos_dim = std::vector(1); - context->SetOutputDim("Infos", phi::make_ddim(Infos_dim)); - } else { - auto Infos_dim = - std::vector(dims_vec.begin(), dims_vec.begin() + x_rank - 2); - context->SetOutputDim("Infos", phi::make_ddim(Infos_dim)); - } - if (pivots) { - OP_INOUT_CHECK(context->HasOutput("Pivots"), "Output", "Pivots", "LU"); - auto Pivots_dim = - std::vector(dims_vec.begin(), dims_vec.begin() + x_rank - 1); - Pivots_dim[x_rank - 2] = min_mn; - context->SetOutputDim("Pivots", phi::make_ddim(Pivots_dim)); - } - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -99,57 +71,6 @@ class LUOpVarTypeInference : public framework::VarTypeInference { } }; -template -class LUKernel : public framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext &ctx) const override { - auto pivots = ctx.Attr("pivots"); - auto *xin = ctx.Input("X"); - auto *out = ctx.Output("Out"); - auto *IpivT = ctx.Output("Pivots"); - auto *InfoT = ctx.Output("Infos"); - PADDLE_ENFORCE_EQ(pivots, - true, - platform::errors::InvalidArgument( - "lu without pivoting is not implemented on the CPU, " - "but got pivots=False")); - - math::DeviceIndependenceTensorOperations helper(ctx); - *out = helper.Transpose(*xin); - - auto outdims = out->dims(); - auto outrank = outdims.size(); - - int m = static_cast(outdims[outrank - 1]); - int n = static_cast(outdims[outrank - 2]); - int lda = std::max(1, m); - - auto ipiv_dims = phi::slice_ddim(outdims, 0, outrank - 1); - ipiv_dims[outrank - 2] = std::min(m, n); - IpivT->Resize(ipiv_dims); - auto ipiv_data = IpivT->mutable_data(ctx.GetPlace()); - - auto info_dims = phi::slice_ddim(outdims, 0, outrank - 2); - if (info_dims.size() == 0) { - info_dims = phi::make_ddim({1}); - } - InfoT->Resize(info_dims); - auto info_data = InfoT->mutable_data(ctx.GetPlace()); - - auto batchsize = product(info_dims); - batchsize = std::max(static_cast(batchsize), 1); - auto out_data = out->mutable_data(ctx.GetPlace()); - for (int b = 0; b < batchsize; b++) { - auto out_data_item = 
&out_data[b * m * n]; - int *info_data_item = &info_data[b]; - int *ipiv_data_item = &ipiv_data[b * std::min(m, n)]; - phi::funcs::lapackLu( - m, n, out_data_item, lda, ipiv_data_item, info_data_item); - } - *out = helper.Transpose(*out); - } -}; - template class LUOpGradMaker : public framework::SingleGradOpMaker { public: @@ -184,23 +105,6 @@ class LUGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lu"); - OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "lu"); - OP_INOUT_CHECK(ctx->HasInput("Pivots"), "Input", "Pivots", "lu"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@GRAD", - "lu"); - - auto x_dims = ctx->GetInputDim("X"); - auto x_grad_name = framework::GradVarName("X"); - - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); - } - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -219,19 +123,21 @@ DECLARE_INPLACE_OP_INFERER(LUGradOpInplaceInferer, namespace ops = paddle::operators; namespace plat = paddle::platform; +DECLARE_INFER_SHAPE_FUNCTOR(lu, + LUInferMetaFunctor, + PD_INFER_META(phi::LUInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(lu_grad, + LUGradInferMetaFunctor, + PD_INFER_META(phi::LUGradInferMeta)); + REGISTER_OPERATOR(lu, ops::LUOp, ops::LUOpMaker, ops::LUOpVarTypeInference, ops::LUOpGradMaker, ops::LUOpGradMaker, - ops::LUOpInplaceInferer); + LUInferMetaFunctor); REGISTER_OPERATOR(lu_grad, ops::LUGradOp, ops::LUGradOpVarTypeInference, - ops::LUGradOpInplaceInferer); - -REGISTER_OP_CPU_KERNEL(lu, ops::LUKernel, ops::LUKernel); -REGISTER_OP_CPU_KERNEL(lu_grad, - ops::LUGradKernel, - ops::LUGradKernel); + LUGradInferMetaFunctor); diff --git a/paddle/fluid/operators/lu_op.cu b/paddle/fluid/operators/lu_op.cu deleted file mode 100644 index 35194b125abed..0000000000000 --- a/paddle/fluid/operators/lu_op.cu +++ /dev/null @@ -1,194 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifndef PADDLE_WITH_HIP -// HIP not support cusolver - -#include "paddle/fluid/operators/lu_op.h" -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/platform/dynload/cusolver.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using CUDADeviceContext = paddle::platform::CUDADeviceContext; - -template -void cusolver_bufferSize(const cusolverDnHandle_t& cusolverH, - int m, - int n, - T* d_A, - int lda, - int* lwork); -template -void cusolver_getrf(const cusolverDnHandle_t& cusolverH, - int m, - int n, - T* d_A, - int lda, - T* d_work, - int* d_Ipiv, - int* d_info); - -template <> -void cusolver_bufferSize(const cusolverDnHandle_t& cusolverH, - int m, - int n, - float* d_A, - int lda, - int* lwork) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgetrf_bufferSize( - cusolverH, m, n, d_A, lda, lwork)); -} - -template <> -void cusolver_bufferSize(const cusolverDnHandle_t& cusolverH, - int m, - int n, - double* d_A, - int lda, - int* lwork) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgetrf_bufferSize( - cusolverH, m, n, d_A, lda, lwork)); -} - -template <> -void cusolver_getrf(const cusolverDnHandle_t& cusolverH, - int m, - int n, - float* d_A, - int lda, - float* d_work, - int* d_Ipiv, - int* d_info) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgetrf( - cusolverH, m, n, d_A, lda, d_work, d_Ipiv, d_info)); -} - -template <> -void cusolver_getrf(const cusolverDnHandle_t& cusolverH, - int m, - int n, - double* d_A, - int lda, - double* d_work, - int* d_Ipiv, - int* d_info) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgetrf( - cusolverH, m, n, d_A, lda, d_work, d_Ipiv, d_info)); -} - -template -void lu_decomposed_kernel(int m, - int n, - T* d_A, - int lda, - int* d_Ipiv, - int* d_info, - const framework::ExecutionContext& ctx) { - /* step 1: get cusolver handle*/ - auto& dev_ctx = ctx.template device_context(); - auto cusolverH = dev_ctx.cusolver_dn_handle(); - - /* step 2: query working space of getrf */ - int lwork; - cusolver_bufferSize(cusolverH, m, n, d_A, lda, &lwork); - - auto work_buff = memory::Alloc(dev_ctx, lwork * sizeof(T)); - T* d_work = reinterpret_cast(work_buff->ptr()); - - /* step 3: LU factorization */ - if (d_Ipiv) { - cusolver_getrf(cusolverH, m, n, d_A, lda, d_work, d_Ipiv, d_info); - } else { - cusolver_getrf(cusolverH, m, n, d_A, lda, d_work, NULL, d_info); - } - PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); -} - -template -class LUCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { -#ifdef __HIPCC__ - const int64_t kMaxBlockDim = 256; -#else - const int64_t kMaxBlockDim = 512; -#endif - auto* xin = ctx.Input("X"); - auto* out = ctx.Output("Out"); - auto* IpivT = ctx.Output("Pivots"); - auto* InfoT = ctx.Output("Infos"); - auto pivots = ctx.Attr("pivots"); - - math::DeviceIndependenceTensorOperations< - paddle::platform::CUDADeviceContext, - T> - helper(ctx); - *out = helper.Transpose(*xin); - - auto outdims = out->dims(); - auto outrank = outdims.size(); - - int m = static_cast(outdims[outrank - 1]); - int n = static_cast(outdims[outrank - 2]); - int lda = std::max(1, m); - if (pivots) { - auto ipiv_dims = phi::slice_ddim(outdims, 0, outrank - 1); - ipiv_dims[outrank - 2] = std::min(m, n); - IpivT->Resize(ipiv_dims); - } - auto ipiv_data = IpivT->mutable_data(ctx.GetPlace()); - - auto info_dims = phi::slice_ddim(outdims, 0, outrank - 2); - if (info_dims.size() == 0) { - info_dims = 
phi::make_ddim({1}); - } - InfoT->Resize(info_dims); - auto info_data = InfoT->mutable_data(ctx.GetPlace()); - - auto batchsize = product(info_dims); - batchsize = std::max(static_cast(batchsize), 1); - auto out_data = out->mutable_data(ctx.GetPlace()); - for (int b = 0; b < batchsize; b++) { - auto out_data_item = &out_data[b * m * n]; - int* info_data_item = &info_data[b]; - if (pivots) { - auto ipiv_data_item = &ipiv_data[b * std::min(m, n)]; - lu_decomposed_kernel( - m, n, out_data_item, lda, ipiv_data_item, info_data_item, ctx); - } else { - lu_decomposed_kernel( - m, n, out_data_item, lda, NULL, info_data_item, ctx); - } - } - *out = helper.Transpose(*out); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL(lu, - ops::LUCUDAKernel, - ops::LUCUDAKernel); -REGISTER_OP_CUDA_KERNEL(lu_grad, - ops::LUGradKernel, - ops::LUGradKernel); - -#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/lu_op.h b/paddle/fluid/operators/lu_op.h index 4215074aeff01..1122937b6efe3 100644 --- a/paddle/fluid/operators/lu_op.h +++ b/paddle/fluid/operators/lu_op.h @@ -524,305 +524,5 @@ void Unpack_Pivot(const DeviceContext& dev_ctx, } } -template -class LUGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto xin = ctx.Input("X"); - auto out = ctx.Input("Out"); - auto P = ctx.Input("Pivots"); - auto dout = ctx.Input(framework::GradVarName("Out")); - auto dx = ctx.Output(framework::GradVarName("X")); - dx->mutable_data(ctx.GetPlace()); - - const auto& dev_ctx = ctx.template device_context(); - math::DeviceIndependenceTensorOperations helper(ctx); - auto blas = phi::funcs::GetBlas(ctx); - - auto xdims = xin->dims(); - int xrank = xdims.size(); - int64_t m = xdims[xrank - 2]; - int64_t n = xdims[xrank - 1]; - int64_t k = std::min(m, n); - - framework::Tensor L, U, L_narrow, U_narrow, L_narrow_mH, U_narrow_mH, - grad_narrow; - LU_Unpack(dev_ctx, out, &L, &U); - - Tensor_narrow(ctx, &L, &L_narrow, 0, k, 0, k); - Tensor_narrow(ctx, &U, &U_narrow, 0, k, 0, k); - Tensor_narrow(ctx, dout, &grad_narrow, 0, k, 0, k); - auto graddims = grad_narrow.dims(); - - Tensor_Conj(dev_ctx, L_narrow, &L_narrow_mH); - Tensor_Conj(dev_ctx, U_narrow, &U_narrow_mH); - L_narrow_mH = helper.Transpose(L_narrow_mH); - U_narrow_mH = helper.Transpose(U_narrow_mH); - - auto LmHdims = L_narrow_mH.dims(); - auto UmHdims = U_narrow_mH.dims(); - - framework::Tensor phi_L, phi_U, phi, psi; - phi_L.Resize(LmHdims); - phi_L.mutable_data(ctx.GetPlace()); - phi_U.Resize(UmHdims); - phi_U.mutable_data(ctx.GetPlace()); - auto mat_dim_l = phi::funcs::CreateMatrixDescriptor(LmHdims, 0, false); - auto mat_dim_u = phi::funcs::CreateMatrixDescriptor(UmHdims, 0, false); - auto mat_dim_g = phi::funcs::CreateMatrixDescriptor(graddims, 0, false); - blas.MatMul(L_narrow_mH, - mat_dim_l, - grad_narrow, - mat_dim_g, - static_cast(1), - &phi_L, - static_cast(0)); - - blas.MatMul(grad_narrow, - mat_dim_g, - U_narrow_mH, - mat_dim_u, - static_cast(1), - &phi_U, - static_cast(0)); - - auto phil_rank = LmHdims.size(); - auto phiu_rank = UmHdims.size(); - platform::ForRange l_for_range(dev_ctx, phi_L.numel()); - phi::funcs::TrilTriuCompute tril_computer(phi_L.data(), - -1, - true, - LmHdims[phil_rank - 2], - LmHdims[phil_rank - 1], - phi_L.data()); - l_for_range(tril_computer); - - platform::ForRange u_for_range(dev_ctx, phi_U.numel()); - phi::funcs::TrilTriuCompute 
triu_computer(phi_U.data(), - 0, - false, - UmHdims[phiu_rank - 2], - UmHdims[phiu_rank - 1], - phi_U.data()); - u_for_range(triu_computer); - - Tensor_Add(dev_ctx, phi_L, phi_U, &phi); - psi.Resize(xdims); - psi.mutable_data(ctx.GetPlace()); - phi::funcs::SetConstant setter; - setter(dev_ctx, &psi, static_cast(0)); - - std::vector axes = {xrank - 2, xrank - 1}; - std::vector slice_starts(2, 0); - std::vector slice_ends(2, 0); - auto valuedims = vectorize(xdims); - - framework::Tensor Pmat; - Unpack_Pivot(dev_ctx, *P, &Pmat, m, k); - - using Context = - typename framework::ConvertToPhiContext::TYPE; - auto& phi_dev_ctx = static_cast(dev_ctx); - - if (m <= n) { - if (k < n) { - framework::Tensor U_complement, U_grad_complement, phi_complement, - phi_complement_l; - Tensor_narrow(ctx, &U, &U_complement, 0, k, k, n); - Tensor_narrow( - ctx, dout, &U_grad_complement, 0, k, k, n); - framework::Tensor U_complement_mH = helper.Transpose(U_complement); - - Tensor_Conj( - dev_ctx, U_complement_mH, &U_complement_mH); - - auto mat_dim_g = phi::funcs::CreateMatrixDescriptor( - U_grad_complement.dims(), 0, false); - auto mat_dim_u = phi::funcs::CreateMatrixDescriptor( - U_complement_mH.dims(), 0, false); - auto phidims = UmHdims; - phidims[UmHdims.size() - 2] = k; - phidims[UmHdims.size() - 1] = k; - phi_complement.Resize(phidims); - phi_complement.mutable_data(ctx.GetPlace()); - blas.MatMul(U_grad_complement, - mat_dim_g, - U_complement_mH, - mat_dim_u, - static_cast(1), - &phi_complement, - static_cast(0)); - - phi_complement_l.Resize(phidims); - phi_complement_l.mutable_data(ctx.GetPlace()); - const auto H = phidims[phidims.size() - 2]; - const auto W = phidims[phidims.size() - 1]; - platform::ForRange x_for_range(dev_ctx, - phi_complement.numel()); - phi::funcs::TrilTriuCompute tril_computer( - phi_complement.data(), - -1, - true, - H, - W, - phi_complement_l.data()); - x_for_range(tril_computer); - - Tensor_Sub(dev_ctx, phi, phi_complement_l, &phi); - - slice_starts[0] = 0; - slice_starts[1] = k; - slice_ends[0] = k; - slice_ends[1] = n; - valuedims[xrank - 2] = k; - valuedims[xrank - 1] = n - k; - SetValueCompute_dispatch(ctx, - &psi, - &U_grad_complement, - &psi, - axes, - &slice_starts, - &slice_ends, - valuedims, - xrank); - } - - framework::Tensor psi_principal, phi_mH, psi_tmp; - Tensor_Conj(dev_ctx, phi, &phi_mH); - phi_mH = helper.Transpose(phi_mH); - - phi::TriangularSolveKernel( - phi_dev_ctx, U_narrow, phi_mH, true, false, false, &psi_principal); - - Tensor_Conj(dev_ctx, psi_principal, &psi_principal); - psi_principal = helper.Transpose(psi_principal); - slice_starts[0] = 0; - slice_starts[1] = 0; - slice_ends[0] = k; - slice_ends[1] = k; - valuedims[xrank - 2] = k; - valuedims[xrank - 1] = k; - - SetValueCompute_dispatch(ctx, - &psi, - &psi_principal, - &psi, - axes, - &slice_starts, - &slice_ends, - valuedims, - xrank); - - phi::TriangularSolveKernel( - phi_dev_ctx, L_narrow_mH, psi, true, false, true, &psi_tmp); - - auto mat_dim_p = - phi::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false); - auto mat_dim_b = - phi::funcs::CreateMatrixDescriptor(psi_tmp.dims(), 0, false); - blas.MatMul(Pmat, - mat_dim_p, - psi_tmp, - mat_dim_b, - static_cast(1), - dx, - static_cast(0)); - } else { - framework::Tensor L_complement, L_grad_complement, phi_complement, - phi_complement_u; - Tensor_narrow(ctx, &L, &L_complement, k, m, 0, k); - Tensor_narrow( - ctx, dout, &L_grad_complement, k, m, 0, k); - framework::Tensor L_complement_mH = helper.Transpose(L_complement); - Tensor_Conj(dev_ctx, 
L_complement_mH, &L_complement_mH); - - auto mat_dim_g = phi::funcs::CreateMatrixDescriptor( - L_grad_complement.dims(), 0, false); - auto mat_dim_u = - phi::funcs::CreateMatrixDescriptor(L_complement_mH.dims(), 0, false); - auto phidims = LmHdims; - phidims[LmHdims.size() - 2] = k; - phidims[LmHdims.size() - 1] = k; - phi_complement.Resize(phidims); - phi_complement.mutable_data(ctx.GetPlace()); - blas.MatMul(L_complement_mH, - mat_dim_u, - L_grad_complement, - mat_dim_g, - static_cast(1), - &phi_complement, - static_cast(0)); - - phi_complement_u.Resize(phidims); - phi_complement_u.mutable_data(ctx.GetPlace()); - const auto H = phidims[phidims.size() - 2]; - const auto W = phidims[phidims.size() - 1]; - platform::ForRange x_for_range(dev_ctx, - phi_complement.numel()); - phi::funcs::TrilTriuCompute triu_computer( - phi_complement.data(), 0, false, H, W, phi_complement_u.data()); - x_for_range(triu_computer); - - Tensor_Sub(dev_ctx, phi, phi_complement_u, &phi); - - slice_starts[0] = k; - slice_starts[1] = 0; - slice_ends[0] = m; - slice_ends[1] = k; - valuedims[xrank - 2] = m - k; - valuedims[xrank - 1] = k; - SetValueCompute_dispatch(ctx, - &psi, - &L_grad_complement, - &psi, - axes, - &slice_starts, - &slice_ends, - valuedims, - xrank); - framework::Tensor psi_principal, phi_mH, psi_tmp, U_narrow_mH; - - phi::TriangularSolveKernel( - phi_dev_ctx, L_narrow_mH, phi, true, false, true, &psi_principal); - - slice_starts[0] = 0; - slice_starts[1] = 0; - slice_ends[0] = k; - slice_ends[1] = k; - valuedims[xrank - 2] = k; - valuedims[xrank - 1] = k; - - SetValueCompute_dispatch(ctx, - &psi, - &psi_principal, - &psi, - axes, - &slice_starts, - &slice_ends, - valuedims, - xrank); - - psi_tmp.Resize(psi.dims()); - psi_tmp.mutable_data(ctx.GetPlace()); - auto mat_dim_p = - phi::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false); - auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(psi.dims(), 0, false); - blas.MatMul(Pmat, - mat_dim_p, - psi, - mat_dim_b, - static_cast(1), - &psi_tmp, - static_cast(0)); - psi_tmp = helper.Transpose(psi_tmp); - - Tensor_Conj(dev_ctx, U_narrow, &U_narrow_mH); - phi::TriangularSolveKernel( - phi_dev_ctx, U_narrow_mH, psi_tmp, true, false, false, &psi); - *dx = helper.Transpose(psi); - } - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 00a68bb0c44f6..d2d1bb0ecc698 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1357,6 +1357,15 @@ func : logsumexp backward : logsumexp_grad +- api : lu + args : (Tensor x, bool pivot) + output : Tensor(out), Tensor(pivots), Tensor(infos) + infer_meta : + func : LUInferMeta + kernel : + func : lu + backward : lu_grad + # masked_select - api : masked_select args : (Tensor x, Tensor mask) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index abbb23cc253fa..7baa2b962a0c9 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1203,6 +1203,15 @@ kernel : func : logsumexp_grad +- backward_api : lu_grad + forward : lu (Tensor x, bool pivot) -> Tensor(out), Tensor(pivots), Tensor(infos) + args : (Tensor x, Tensor out, Tensor pivots, Tensor out_grad, bool pivot) + output : Tensor(x_grad) + infer_meta : + func : LUGradInferMeta + kernel : + func : lu_grad + - backward_api : masked_select_grad forward : masked_select (Tensor x, Tensor mask) -> Tensor(out) args : (Tensor x, Tensor mask, Tensor out_grad) 
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 1eca092a5f22f..21e165b0c0b9e 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -429,6 +429,20 @@ void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx) { dx->share_lod(xshape); } +void LUGradInferMeta(const MetaTensor& x, + const MetaTensor& out, + const MetaTensor& pivots, + const MetaTensor& out_grad, + bool pivot, + MetaTensor* x_grad) { + auto x_dims = x.dims(); + + if (x_grad) { + x_grad->set_dims(x_dims); + x_grad->set_dtype(x.dtype()); + } +} + void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, const MetaTensor& mask, const MetaTensor& dout, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 5551b6bcbf183..1b1c27cc46dd3 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -195,6 +195,13 @@ void InverseGradInferMeta(const MetaTensor& out, void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx); +void LUGradInferMeta(const MetaTensor& x, + const MetaTensor& out, + const MetaTensor& pivots, + const MetaTensor& out_grad, + bool pivot, + MetaTensor* x_grad); + void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, const MetaTensor& mask, const MetaTensor& dout, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index edc455225e4dc..2f2be2fb39594 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1240,15 +1240,59 @@ void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) { out->set_dtype(x.dtype()); } +void LUInferMeta(const MetaTensor& x, + bool pivot, + MetaTensor* out, + MetaTensor* pivots, + MetaTensor* infos) { + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + + PADDLE_ENFORCE_NOT_NULL( + out, phi::errors::InvalidArgument("Output(Out) should not be nullptr.")); + PADDLE_ENFORCE_GE( + x_rank, + 2, + phi::errors::InvalidArgument("The rank of input must greater than 2.")); + out->set_dims(x_dims); + out->set_dtype(x.dtype()); + int m = x_dims[x_rank - 1]; + int n = x_dims[x_rank - 2]; + int min_mn = std::min(m, n); + auto dims_vec = phi::vectorize(x_dims); + PADDLE_ENFORCE_NOT_NULL( + infos, + phi::errors::InvalidArgument("Output(Infos) should not be nullptr.")); + if (x_rank == 2) { + auto Infos_dim = std::vector(1); + infos->set_dims(phi::make_ddim(Infos_dim)); + } else { + auto Infos_dim = + std::vector(dims_vec.begin(), dims_vec.begin() + x_rank - 2); + infos->set_dims(phi::make_ddim(Infos_dim)); + } + infos->set_dtype(DataType::INT32); + if (pivot) { + PADDLE_ENFORCE_NOT_NULL( + pivots, + phi::errors::InvalidArgument("Output(Pivots) should not be nullptr.")); + auto Pivots_dim = + std::vector(dims_vec.begin(), dims_vec.begin() + x_rank - 1); + Pivots_dim[x_rank - 2] = min_mn; + pivots->set_dims(phi::make_ddim(Pivots_dim)); + pivots->set_dtype(DataType::INT32); + } +} + void MatrixRankInferMeta(const MetaTensor& x, bool use_default_tol, bool hermitian, MetaTensor* out) { auto dim_x = x.dims(); - PADDLE_ENFORCE_GE( - dim_x.size(), - 2, - phi::errors::InvalidArgument("The dims of input must be greater than 2")); + PADDLE_ENFORCE_GE(dim_x.size(), + 2, + phi::errors::InvalidArgument( + "The dims of input must be greater than 2.")); if (hermitian) { int rows = dim_x[dim_x.size() - 2]; @@ -1279,11 +1323,11 @@ void MaxOutInferMeta(const MetaTensor& x, axis == 1 || axis == -1 || axis == 3, true, phi::errors::InvalidArgument( - "axis only supported 1, -1 or 3, but recevied axis is: 
%d", axis)); + "axis only supported 1, -1 or 3, but recevied axis is: %d.", axis)); PADDLE_ENFORCE_EQ(in_x_dims.size(), 4, phi::errors::InvalidArgument( - "x's dims should be 4, but received x's dims is: %d", + "x's dims should be 4, but received x's dims is: %d.", in_x_dims.size())); if (axis < 0) { diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 1449e8cfe197d..48be4d8e25298 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -166,6 +166,12 @@ void LogsumexpInferMeta(const MetaTensor& input, bool reduce_all, MetaTensor* out); +void LUInferMeta(const MetaTensor& x, + bool pivot, + MetaTensor* out, + MetaTensor* pivots, + MetaTensor* infos); + void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out); void MatrixRankInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/cpu/lu_grad_kernel.cc b/paddle/phi/kernels/cpu/lu_grad_kernel.cc new file mode 100644 index 0000000000000..6443fb66428d0 --- /dev/null +++ b/paddle/phi/kernels/cpu/lu_grad_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/lu_grad_kernel_impl.h" +#include "paddle/phi/kernels/lu_grad_kernel.h" + +PD_REGISTER_KERNEL(lu_grad, CPU, ALL_LAYOUT, phi::LUGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/cpu/lu_kernel.cc b/paddle/phi/kernels/cpu/lu_kernel.cc new file mode 100644 index 0000000000000..14cbab53663a2 --- /dev/null +++ b/paddle/phi/kernels/cpu/lu_kernel.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" + +#include "paddle/phi/kernels/impl/lu_kernel_impl.h" +#include "paddle/phi/kernels/lu_kernel.h" + +namespace phi { + +template +void LUKernel(const Context& dev_ctx, + const DenseTensor& x, + bool pivot, + DenseTensor* out, + DenseTensor* pivots, + DenseTensor* infos) { + PADDLE_ENFORCE_EQ(pivot, + true, + errors::InvalidArgument( + "lu without pivoting is not implemented on the CPU, " + "but got pivots=False")); + + *out = Transpose2DTo6D(dev_ctx, x); + + auto outdims = out->dims(); + auto outrank = outdims.size(); + + int m = static_cast(outdims[outrank - 1]); + int n = static_cast(outdims[outrank - 2]); + int lda = std::max(1, m); + + auto ipiv_dims = phi::slice_ddim(outdims, 0, outrank - 1); + ipiv_dims[outrank - 2] = std::min(m, n); + pivots->Resize(ipiv_dims); + dev_ctx.template Alloc(pivots); + auto ipiv_data = pivots->data(); + + auto info_dims = phi::slice_ddim(outdims, 0, outrank - 2); + if (info_dims.size() == 0) { + info_dims = phi::make_ddim({1}); + } + infos->Resize(info_dims); + dev_ctx.template Alloc(infos); + auto info_data = infos->data(); + + auto batchsize = product(info_dims); + batchsize = std::max(static_cast(batchsize), 1); + dev_ctx.template Alloc(out); + auto out_data = out->data(); + for (int b = 0; b < batchsize; b++) { + auto out_data_item = &out_data[b * m * n]; + int* info_data_item = &info_data[b]; + int* ipiv_data_item = &ipiv_data[b * std::min(m, n)]; + phi::funcs::lapackLu( + m, n, out_data_item, lda, ipiv_data_item, info_data_item); + } + *out = Transpose2DTo6D(dev_ctx, *out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(lu, CPU, ALL_LAYOUT, phi::LUKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/lu_grad_kernel.cu b/paddle/phi/kernels/gpu/lu_grad_kernel.cu new file mode 100644 index 0000000000000..88a15606d2261 --- /dev/null +++ b/paddle/phi/kernels/gpu/lu_grad_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/lu_grad_kernel_impl.h" +#include "paddle/phi/kernels/lu_grad_kernel.h" + +PD_REGISTER_KERNEL(lu_grad, GPU, ALL_LAYOUT, phi::LUGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/gpu/lu_kernel.cu b/paddle/phi/kernels/gpu/lu_kernel.cu new file mode 100644 index 0000000000000..7f6070a805c8a --- /dev/null +++ b/paddle/phi/kernels/gpu/lu_kernel.cu @@ -0,0 +1,185 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + +#include "paddle/fluid/memory/malloc.h" +#include "paddle/phi/backends/dynload/cusolver.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/lu_kernel_impl.h" +#include "paddle/phi/kernels/lu_kernel.h" + +namespace phi { + +template +void cusolver_bufferSize(const cusolverDnHandle_t& cusolverH, + int m, + int n, + T* d_A, + int lda, + int* lwork); +template +void cusolver_getrf(const cusolverDnHandle_t& cusolverH, + int m, + int n, + T* d_A, + int lda, + T* d_work, + int* d_Ipiv, + int* d_info); + +template <> +void cusolver_bufferSize(const cusolverDnHandle_t& cusolverH, + int m, + int n, + float* d_A, + int lda, + int* lwork) { + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cusolverDnSgetrf_bufferSize(cusolverH, m, n, d_A, lda, lwork)); +} + +template <> +void cusolver_bufferSize(const cusolverDnHandle_t& cusolverH, + int m, + int n, + double* d_A, + int lda, + int* lwork) { + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cusolverDnDgetrf_bufferSize(cusolverH, m, n, d_A, lda, lwork)); +} + +template <> +void cusolver_getrf(const cusolverDnHandle_t& cusolverH, + int m, + int n, + float* d_A, + int lda, + float* d_work, + int* d_Ipiv, + int* d_info) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSgetrf( + cusolverH, m, n, d_A, lda, d_work, d_Ipiv, d_info)); +} + +template <> +void cusolver_getrf(const cusolverDnHandle_t& cusolverH, + int m, + int n, + double* d_A, + int lda, + double* d_work, + int* d_Ipiv, + int* d_info) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDgetrf( + cusolverH, m, n, d_A, lda, d_work, d_Ipiv, d_info)); +} + +template +void lu_decomposed_kernel(const Context& dev_ctx, + int m, + int n, + T* d_A, + int lda, + int* d_Ipiv, + int* d_info) { + /* step 1: get cusolver handle*/ + auto cusolverH = dev_ctx.cusolver_dn_handle(); + + /* step 2: query working space of getrf */ + int lwork; + cusolver_bufferSize(cusolverH, m, n, d_A, lda, &lwork); + + auto work_buff = paddle::memory::Alloc(dev_ctx, lwork * sizeof(T)); + T* d_work = reinterpret_cast(work_buff->ptr()); + + /* step 3: LU factorization */ + if (d_Ipiv) { + cusolver_getrf(cusolverH, m, n, d_A, lda, d_work, d_Ipiv, d_info); + } else { + cusolver_getrf(cusolverH, m, n, d_A, lda, d_work, NULL, d_info); + } + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +} + +template +void LUKernel(const Context& dev_ctx, + const DenseTensor& x, + bool pivot, + DenseTensor* out, + DenseTensor* pivots, + DenseTensor* infos) { +#ifdef __HIPCC__ + const int64_t kMaxBlockDim = 256; +#else + const int64_t kMaxBlockDim = 512; +#endif + + *out = Transpose2DTo6D(dev_ctx, x); + + auto outdims = out->dims(); + auto outrank = outdims.size(); + + int m = static_cast(outdims[outrank - 1]); + int n = static_cast(outdims[outrank - 2]); + int lda = std::max(1, m); + if (pivot) { + auto ipiv_dims = phi::slice_ddim(outdims, 0, outrank - 1); + ipiv_dims[outrank - 2] = std::min(m, n); + pivots->Resize(ipiv_dims); + } + dev_ctx.template Alloc(pivots); + auto ipiv_data = pivots->data(); + + auto 
info_dims = phi::slice_ddim(outdims, 0, outrank - 2); + if (info_dims.size() == 0) { + info_dims = phi::make_ddim({1}); + } + infos->Resize(info_dims); + dev_ctx.template Alloc(infos); + auto info_data = infos->data(); + + auto batchsize = product(info_dims); + batchsize = std::max(static_cast(batchsize), 1); + dev_ctx.template Alloc(out); + auto out_data = out->data(); + for (int b = 0; b < batchsize; b++) { + auto out_data_item = &out_data[b * m * n]; + int* info_data_item = &info_data[b]; + if (pivot) { + auto ipiv_data_item = &ipiv_data[b * std::min(m, n)]; + lu_decomposed_kernel( + dev_ctx, m, n, out_data_item, lda, ipiv_data_item, info_data_item); + } else { + lu_decomposed_kernel( + dev_ctx, m, n, out_data_item, lda, NULL, info_data_item); + } + } + *out = Transpose2DTo6D(dev_ctx, *out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(lu, // cuda_only + GPU, + ALL_LAYOUT, + phi::LUKernel, + float, + double) {} + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/impl/lu_grad_kernel_impl.h b/paddle/phi/kernels/impl/lu_grad_kernel_impl.h new file mode 100644 index 0000000000000..523e33f29f8c1 --- /dev/null +++ b/paddle/phi/kernels/impl/lu_grad_kernel_impl.h @@ -0,0 +1,309 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/triangular_solve_kernel.h" + +#include "paddle/phi/kernels/impl/lu_kernel_impl.h" + +namespace phi { + +template +void LUGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& pivots, + const DenseTensor& out_grad, + bool pivot, + DenseTensor* x_grad) { + dev_ctx.template Alloc(x_grad); + + auto blas = phi::funcs::GetBlas(dev_ctx); + + auto xdims = x.dims(); + int xrank = xdims.size(); + int64_t m = xdims[xrank - 2]; + int64_t n = xdims[xrank - 1]; + int64_t k = std::min(m, n); + + DenseTensor L, U, L_narrow, U_narrow, L_narrow_mH, U_narrow_mH, grad_narrow; + LU_Unpack(dev_ctx, &out, &L, &U); + + Tensor_narrow(dev_ctx, &L, &L_narrow, 0, k, 0, k); + Tensor_narrow(dev_ctx, &U, &U_narrow, 0, k, 0, k); + Tensor_narrow(dev_ctx, &out_grad, &grad_narrow, 0, k, 0, k); + auto graddims = grad_narrow.dims(); + + Tensor_Conj(dev_ctx, L_narrow, &L_narrow_mH); + Tensor_Conj(dev_ctx, U_narrow, &U_narrow_mH); + L_narrow_mH = Transpose2DTo6D(dev_ctx, L_narrow_mH); + U_narrow_mH = Transpose2DTo6D(dev_ctx, U_narrow_mH); + + auto LmHdims = L_narrow_mH.dims(); + auto UmHdims = U_narrow_mH.dims(); + + DenseTensor phi_L, phi_U, phi, psi; + phi_L.Resize(LmHdims); + dev_ctx.template Alloc(&phi_L); + phi_U.Resize(UmHdims); + dev_ctx.template Alloc(&phi_U); + auto mat_dim_l = phi::funcs::CreateMatrixDescriptor(LmHdims, 0, false); + auto mat_dim_u = phi::funcs::CreateMatrixDescriptor(UmHdims, 0, false); + auto mat_dim_g = phi::funcs::CreateMatrixDescriptor(graddims, 0, false); + blas.MatMul(L_narrow_mH, + mat_dim_l, + grad_narrow, + mat_dim_g, + static_cast(1), + &phi_L, + static_cast(0)); + + blas.MatMul(grad_narrow, + mat_dim_g, + U_narrow_mH, + mat_dim_u, + static_cast(1), + &phi_U, + static_cast(0)); + + auto phil_rank = LmHdims.size(); + auto phiu_rank = UmHdims.size(); + phi::funcs::ForRange l_for_range(dev_ctx, phi_L.numel()); + phi::funcs::TrilTriuCompute tril_computer(phi_L.data(), + -1, + true, + LmHdims[phil_rank - 2], + LmHdims[phil_rank - 1], + phi_L.data()); + l_for_range(tril_computer); + + phi::funcs::ForRange u_for_range(dev_ctx, phi_U.numel()); + phi::funcs::TrilTriuCompute triu_computer(phi_U.data(), + 0, + false, + UmHdims[phiu_rank - 2], + UmHdims[phiu_rank - 1], + phi_U.data()); + u_for_range(triu_computer); + + Tensor_Add(dev_ctx, phi_L, phi_U, &phi); + psi.Resize(xdims); + dev_ctx.template Alloc(&psi); + phi::funcs::SetConstant setter; + setter(dev_ctx, &psi, static_cast(0)); + + std::vector axes = {xrank - 2, xrank - 1}; + std::vector slice_starts(2, 0); + std::vector slice_ends(2, 0); + auto valuedims = vectorize(xdims); + + DenseTensor Pmat; + Unpack_Pivot(dev_ctx, pivots, &Pmat, m, k); + + if (m <= n) { + if (k < n) { + DenseTensor U_complement, U_grad_complement, phi_complement, + phi_complement_l; + Tensor_narrow(dev_ctx, &U, &U_complement, 0, k, k, n); + Tensor_narrow( + dev_ctx, &out_grad, &U_grad_complement, 0, k, k, n); + DenseTensor U_complement_mH = + Transpose2DTo6D(dev_ctx, U_complement); + + Tensor_Conj(dev_ctx, U_complement_mH, &U_complement_mH); + + auto mat_dim_g = phi::funcs::CreateMatrixDescriptor( + U_grad_complement.dims(), 0, false); + auto mat_dim_u = + phi::funcs::CreateMatrixDescriptor(U_complement_mH.dims(), 0, false); + auto phidims = UmHdims; + phidims[UmHdims.size() - 2] = k; + phidims[UmHdims.size() - 1] = k; + phi_complement.Resize(phidims); + dev_ctx.template 
Alloc(&phi_complement); + blas.MatMul(U_grad_complement, + mat_dim_g, + U_complement_mH, + mat_dim_u, + static_cast(1), + &phi_complement, + static_cast(0)); + + phi_complement_l.Resize(phidims); + dev_ctx.template Alloc(&phi_complement_l); + const auto H = phidims[phidims.size() - 2]; + const auto W = phidims[phidims.size() - 1]; + phi::funcs::ForRange x_for_range(dev_ctx, + phi_complement.numel()); + phi::funcs::TrilTriuCompute tril_computer( + phi_complement.data(), -1, true, H, W, phi_complement_l.data()); + x_for_range(tril_computer); + + Tensor_Sub(dev_ctx, phi, phi_complement_l, &phi); + + slice_starts[0] = 0; + slice_starts[1] = k; + slice_ends[0] = k; + slice_ends[1] = n; + valuedims[xrank - 2] = k; + valuedims[xrank - 1] = n - k; + SetValueCompute_dispatch(dev_ctx, + &psi, + &U_grad_complement, + &psi, + axes, + &slice_starts, + &slice_ends, + valuedims, + xrank); + } + + DenseTensor psi_principal, phi_mH, psi_tmp; + Tensor_Conj(dev_ctx, phi, &phi_mH); + phi_mH = Transpose2DTo6D(dev_ctx, phi_mH); + + phi::TriangularSolveKernel( + dev_ctx, U_narrow, phi_mH, true, false, false, &psi_principal); + + Tensor_Conj(dev_ctx, psi_principal, &psi_principal); + psi_principal = Transpose2DTo6D(dev_ctx, psi_principal); + slice_starts[0] = 0; + slice_starts[1] = 0; + slice_ends[0] = k; + slice_ends[1] = k; + valuedims[xrank - 2] = k; + valuedims[xrank - 1] = k; + + SetValueCompute_dispatch(dev_ctx, + &psi, + &psi_principal, + &psi, + axes, + &slice_starts, + &slice_ends, + valuedims, + xrank); + + phi::TriangularSolveKernel( + dev_ctx, L_narrow_mH, psi, true, false, true, &psi_tmp); + + auto mat_dim_p = phi::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false); + auto mat_dim_b = + phi::funcs::CreateMatrixDescriptor(psi_tmp.dims(), 0, false); + blas.MatMul(Pmat, + mat_dim_p, + psi_tmp, + mat_dim_b, + static_cast(1), + x_grad, + static_cast(0)); + } else { + DenseTensor L_complement, L_grad_complement, phi_complement, + phi_complement_u; + Tensor_narrow(dev_ctx, &L, &L_complement, k, m, 0, k); + Tensor_narrow( + dev_ctx, &out_grad, &L_grad_complement, k, m, 0, k); + DenseTensor L_complement_mH = + Transpose2DTo6D(dev_ctx, L_complement); + Tensor_Conj(dev_ctx, L_complement_mH, &L_complement_mH); + + auto mat_dim_g = + phi::funcs::CreateMatrixDescriptor(L_grad_complement.dims(), 0, false); + auto mat_dim_u = + phi::funcs::CreateMatrixDescriptor(L_complement_mH.dims(), 0, false); + auto phidims = LmHdims; + phidims[LmHdims.size() - 2] = k; + phidims[LmHdims.size() - 1] = k; + phi_complement.Resize(phidims); + dev_ctx.template Alloc(&phi_complement); + blas.MatMul(L_complement_mH, + mat_dim_u, + L_grad_complement, + mat_dim_g, + static_cast(1), + &phi_complement, + static_cast(0)); + + phi_complement_u.Resize(phidims); + dev_ctx.template Alloc(&phi_complement_u); + const auto H = phidims[phidims.size() - 2]; + const auto W = phidims[phidims.size() - 1]; + phi::funcs::ForRange x_for_range(dev_ctx, phi_complement.numel()); + phi::funcs::TrilTriuCompute triu_computer( + phi_complement.data(), 0, false, H, W, phi_complement_u.data()); + x_for_range(triu_computer); + + Tensor_Sub(dev_ctx, phi, phi_complement_u, &phi); + + slice_starts[0] = k; + slice_starts[1] = 0; + slice_ends[0] = m; + slice_ends[1] = k; + valuedims[xrank - 2] = m - k; + valuedims[xrank - 1] = k; + SetValueCompute_dispatch(dev_ctx, + &psi, + &L_grad_complement, + &psi, + axes, + &slice_starts, + &slice_ends, + valuedims, + xrank); + DenseTensor psi_principal, phi_mH, psi_tmp, U_narrow_mH; + + phi::TriangularSolveKernel( + dev_ctx, 
L_narrow_mH, phi, true, false, true, &psi_principal); + + slice_starts[0] = 0; + slice_starts[1] = 0; + slice_ends[0] = k; + slice_ends[1] = k; + valuedims[xrank - 2] = k; + valuedims[xrank - 1] = k; + + SetValueCompute_dispatch(dev_ctx, + &psi, + &psi_principal, + &psi, + axes, + &slice_starts, + &slice_ends, + valuedims, + xrank); + + psi_tmp.Resize(psi.dims()); + dev_ctx.template Alloc(&psi_tmp); + auto mat_dim_p = phi::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(psi.dims(), 0, false); + blas.MatMul(Pmat, + mat_dim_p, + psi, + mat_dim_b, + static_cast(1), + &psi_tmp, + static_cast(0)); + psi_tmp = Transpose2DTo6D(dev_ctx, psi_tmp); + + Tensor_Conj(dev_ctx, U_narrow, &U_narrow_mH); + phi::TriangularSolveKernel( + dev_ctx, U_narrow_mH, psi_tmp, true, false, false, &psi); + *x_grad = Transpose2DTo6D(dev_ctx, psi); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/lu_kernel_impl.h b/paddle/phi/kernels/impl/lu_kernel_impl.h new file mode 100644 index 0000000000000..ed3cc0801d9af --- /dev/null +++ b/paddle/phi/kernels/impl/lu_kernel_impl.h @@ -0,0 +1,563 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/phi/kernels/elementwise_subtract_kernel.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/slice_utils.h" +#include "paddle/phi/kernels/funcs/tril_triu_compute.h" +#include "paddle/phi/kernels/impl/set_value_kernel_impl.h" + +namespace phi { + +template +using SubFunctor = phi::funcs::SubtractFunctor; + +template +void SetValueCompute(const Context& dev_ctx, + DenseTensor* in, + DenseTensor* value_tensor, + DenseTensor* out, + const std::vector& axes, + std::vector* starts, + std::vector* ends, + const std::vector& shape) { + std::vector steps = {1, 1}; + std::vector decrease_axes = {}; + std::vector none_axes = {}; + + auto dtype = in->dtype(); + + auto in_dims = in->dims(); + phi::funcs::CheckAndUpdateSliceAttrs( + in_dims, axes, starts, ends, &steps); + auto slice_dims = + phi::funcs::GetSliceDims(in_dims, axes, *starts, *ends, &steps); + auto decrease_slice_dims = + phi::funcs::GetDecreasedDims(slice_dims, decrease_axes); + + auto slice_dims_for_assign = decrease_slice_dims; + if (!none_axes.empty()) { + std::vector slice_dims_with_none; + + size_t none_axes_cur = 0, decrease_axes_cur = 0; + for (int i = 0; i < slice_dims.size(); ++i) { + while (none_axes_cur < none_axes.size() && + none_axes[none_axes_cur] <= i) { + slice_dims_with_none.push_back(1); + none_axes_cur++; + } + if (decrease_axes_cur < decrease_axes.size() && + decrease_axes[decrease_axes_cur] == i) { + decrease_axes_cur++; + } else { + slice_dims_with_none.push_back(slice_dims[i]); + } + } + while (none_axes_cur < none_axes.size()) { + slice_dims_with_none.push_back(1); + none_axes_cur++; + } + + slice_dims_for_assign = phi::make_ddim(slice_dims_with_none); + } + + auto place = dev_ctx.GetPlace(); + auto& eigen_place = *dev_ctx.eigen_device(); + + // Here copy data from input to avoid data loss at PE and Graph level. + // TODO(liym27): Speed up in the future version. + // - Q: Why don't call ShareDataWith to speed up? + // - A: Because it's not supported to ShareDataWith on OP's input and output + // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP + // - Q: Why don't delete Input, after all, the input and output are the same + // Tensor at program level? + // - A: If deleting Input, the graph will be complex, such as there will + // be two ops points to the output in graph: op1 -> output <- set_value. + // In this case, we have to find a way to handle the running order of + // set_value is what we want. 
+ phi::Copy(dev_ctx, *in, place, false, out); + + DenseTensor slice_tensor(dtype), pad_tensor(dtype); + slice_tensor.Resize(slice_dims); + dev_ctx.template Alloc(&slice_tensor); + pad_tensor.Resize(in_dims); + dev_ctx.template Alloc(&pad_tensor); + + auto pad_e = EigenTensor::From(pad_tensor, in_dims); + auto out_e = EigenTensor::From(*out); + auto slice_e = EigenTensor::From(slice_tensor, slice_dims); + + // Step 1: Set the value of out at `_index` to zero + slice_e.device(eigen_place) = slice_e.constant(T(0)); + + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); + + for (size_t i = 0; i < D; ++i) { + starts_indices[i] = 0; + ends_indices[i] = slice_dims[i]; + strides_indices[i] = 1; + } + for (size_t i = 0; i < axes.size(); i++) { + int axis_index = axes[i]; + starts_indices[axis_index] = (*starts)[i]; + ends_indices[axis_index] = (*ends)[i]; + strides_indices[axis_index] = steps[i]; + if ((*starts)[i] == + (*ends)[i]) { // slice is empty, data will not be changed + return; + } + } + + out_e.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(eigen_place) = slice_e; + + // Step 2: Set a tensor with the same shape as out tensor. And its data at + // '_index' is the same as value_tensor, and data out of '_index' to zero + + // - Step 2.1 Set slice tensor with value + + // NOTE(liym27): [ Why resize slice_tensor here? ] + // A: When do broadcasting on slice_tensor and value_tensor, the shape of + // slice_tensor should be decreased dims. + // e.g. + // x[:,0] = value_tensor + // x's shape = [3, 4], value_tensor's shape = [3] + // We get slice_dims = [3, 1], decrease_slice_dims = [3] + // If do broadcasting on Tensor with shape [3, 1] and [3], the result's + // shape is [3, 3], which cross the border; + // If do broadcasting on Tensor with shape [3] and [3], the result's shape + // is [3], which is right. 
+ + slice_tensor.Resize(slice_dims_for_assign); + if (value_tensor != nullptr) { + CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims()); + phi::funcs::ElementwiseCompute, T, T>(dev_ctx, + slice_tensor, + *value_tensor, + -1, + SubFunctor(), + &slice_tensor); + } else { + DenseTensor value_t(dtype); + auto value_dims = phi::make_ddim(shape); + CheckIsDimsMatch(slice_dims_for_assign, value_dims); + + value_t.Resize(value_dims); + dev_ctx.template Alloc(&value_t); + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, slice_tensor, value_t, -1, SubFunctor(), &slice_tensor); + } + slice_tensor.Resize(slice_dims); + + // - Step 2.2 Pad slice tensor with 0 + pad_e.device(eigen_place) = pad_e.constant(T(0)); + pad_e.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(eigen_place) = slice_e; + + // Step 3: Set out tensor with value_tensor + out_e.device(eigen_place) = out_e - pad_e; +} + +template +void SetValueCompute_dispatch(const Context& dev_ctx, + DenseTensor* in, + DenseTensor* value_tensor, + DenseTensor* out, + const std::vector& axes, + std::vector* starts, + std::vector* ends, + const std::vector& shape, + int rank) { + switch (rank) { + case 1: + SetValueCompute( + dev_ctx, in, value_tensor, out, axes, starts, ends, shape); + break; + case 2: + SetValueCompute( + dev_ctx, in, value_tensor, out, axes, starts, ends, shape); + break; + case 3: + SetValueCompute( + dev_ctx, in, value_tensor, out, axes, starts, ends, shape); + break; + case 4: + SetValueCompute( + dev_ctx, in, value_tensor, out, axes, starts, ends, shape); + break; + case 5: + SetValueCompute( + dev_ctx, in, value_tensor, out, axes, starts, ends, shape); + break; + case 6: + SetValueCompute( + dev_ctx, in, value_tensor, out, axes, starts, ends, shape); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", rank)); + } +} + +template +void Tensor_Conj(const Context& dev_ctx, + const DenseTensor& tensor, + DenseTensor* out) { + out->Resize(tensor.dims()); + phi::funcs::ForRange out_for_range(dev_ctx, tensor.numel()); + dev_ctx.template Alloc(out); + phi::funcs::ConjFunctor out_functor( + tensor.data(), tensor.numel(), out->data()); + out_for_range(out_functor); +} + +template +void Tensor_Add(const Context& dev_ctx, + const DenseTensor& src1, + const DenseTensor& src2, + DenseTensor* out) { + out->Resize(src1.dims()); + dev_ctx.template Alloc(out); + + phi::AddRawKernel(dev_ctx, src1, src2, -1, out); +} + +template +void Tensor_Sub(const Context& dev_ctx, + const DenseTensor& src1, + const DenseTensor& src2, + DenseTensor* out) { + out->Resize(src1.dims()); + dev_ctx.template Alloc(out); + + phi::SubtractRawKernel(dev_ctx, src1, src2, -1, out); +} + +template +void SliceCompute(const Context& dev_ctx, + const DenseTensor* in, + DenseTensor* out, + const std::vector& axes_int, + const std::vector& starts_int, + const std::vector& ends_int) { + std::vector axes(axes_int.begin(), axes_int.end()); + std::vector starts(starts_int.begin(), starts_int.end()); + std::vector ends(ends_int.begin(), ends_int.end()); + + std::vector decrease_axis = {}; + std::vector infer_flags = {}; + + PADDLE_ENFORCE_EQ( + starts.size(), + axes.size(), + phi::errors::InvalidArgument( + "The size of starts must be equal to the size of axes.")); + PADDLE_ENFORCE_EQ(ends.size(), + axes.size(), + phi::errors::InvalidArgument( + "The size of ends must be equal to the size of axes.")); + + // Step 2: Compute output + + auto in_dims = in->dims(); + auto out_dims = 
out->dims(); + auto slice_dims = out_dims; + + // 2.1 Infer output dims + for (size_t i = 0; i < axes.size(); ++i) { + // when start == -1 && end == start+1 + if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) { + auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); + if (ret != decrease_axis.end()) { + ends[i] = in_dims[axes[i]]; + } + } + } + + phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends); + slice_dims = phi::funcs::GetSliceDims( + in_dims, axes, starts, ends, nullptr, nullptr); + out_dims = phi::funcs::GetDecreasedDims(slice_dims, decrease_axis); + + // 2.2 Get output + auto offsets = Eigen::DSizes(); + auto extents = Eigen::DSizes(); + + for (size_t i = 0; i < D; ++i) { + offsets[i] = 0; + extents[i] = slice_dims[i]; + } + for (size_t i = 0; i < axes.size(); ++i) { + offsets[axes[i]] = starts[i]; + } + + out->Resize(slice_dims); + dev_ctx.template Alloc(out); + + auto in_t = EigenTensor::From(*in, in_dims); + auto out_t = EigenTensor::From(*out, slice_dims); + auto& eigen_place = *dev_ctx.eigen_device(); + + if (in->numel() <= Eigen::NumTraits::highest()) { + // similar to tf.slice: + // if element number less than INT_MAX, change the type of index to int + Eigen::DSizes offsets_32bit, extents_32bit; + for (size_t i = 0; i < D; i++) { + offsets_32bit[i] = offsets[i]; + extents_32bit[i] = extents[i]; + } + funcs::EigenSlice, T, D>::Eval( + eigen_place, + To32BitIndex(out_t), + To32BitIndex(in_t), + offsets_32bit, + extents_32bit); + } else { + funcs::EigenSlice, T, D>::Eval( + eigen_place, out_t, in_t, offsets, extents); + } + + out->Resize(out_dims); + dev_ctx.template Alloc(out); +} + +template +void Tensor_narrow(const Context& dev_ctx, + const DenseTensor* src, + DenseTensor* out, + int row_s, + int row_e, + int col_s, + int col_e) { + auto rank = src->dims().size(); + std::vector axes_int = {rank - 2, rank - 1}; + std::vector starts_int = {row_s, col_s}; + std::vector ends_int = {row_e, col_e}; + switch (rank) { + case 1: + SliceCompute( + dev_ctx, src, out, axes_int, starts_int, ends_int); + break; + case 2: + SliceCompute( + dev_ctx, src, out, axes_int, starts_int, ends_int); + break; + case 3: + SliceCompute( + dev_ctx, src, out, axes_int, starts_int, ends_int); + break; + case 4: + SliceCompute( + dev_ctx, src, out, axes_int, starts_int, ends_int); + break; + case 5: + SliceCompute( + dev_ctx, src, out, axes_int, starts_int, ends_int); + break; + case 6: + SliceCompute( + dev_ctx, src, out, axes_int, starts_int, ends_int); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", rank)); + } +} + +template +void arange(const Context& dev_ctx, + DenseTensor* tmp, + int w, + int batchsize = 1, + int h = 1) { + tmp->Resize(phi::make_ddim({batchsize * w})); + dev_ctx.template HostAlloc(tmp); + auto tmpdata = tmp->data(); + for (int b = 0; b < batchsize; b++) { + for (int i = 0; i < w; i++) { + tmpdata[b * w + i] = static_cast(b * h + i); + } + } +} + +template +struct OneFunctor { + OneFunctor(T* output, int* idtptr, int w, int dim) + : output_(output), idtptr_(idtptr), w_(w), dim_(dim) {} + + HOSTDEVICE void operator()(size_t idx) const { + output_[w_ * idtptr_[idx] + idx % dim_] = static_cast(1); + } + + T* output_; + int* idtptr_; + int w_; + int dim_; +}; + +template +void LU_Unpack(const Context& dev_ctx, + const DenseTensor* LU, + DenseTensor* L, + DenseTensor* U) { + const auto udims = LU->dims(); + L->Resize(udims); + U->Resize(udims); + const auto H 
= udims[udims.size() - 2]; + const auto W = udims[udims.size() - 1]; + dev_ctx.template Alloc(L); + auto L_dataptr = L->data(); + phi::funcs::ForRange x_for_range(dev_ctx, LU->numel()); + phi::funcs::TrilTriuCompute tril_computer( + LU->data(), -1, true, H, W, L_dataptr); + x_for_range(tril_computer); + + dev_ctx.template Alloc(U); + phi::funcs::TrilTriuCompute triu_computer( + LU->data(), 0, false, H, W, U->data()); + x_for_range(triu_computer); + + // set L's diagonal 1 + auto dim = std::min(H, W); + DenseTensor rowtensor, rt_dev; + auto batchsize = product(phi::slice_ddim(udims, 0, udims.size() - 2)); + batchsize = std::max(static_cast(batchsize), 1); + arange(dev_ctx, &rowtensor, dim, batchsize, H); + auto idtptr = rowtensor.data(); + if (phi::AllocationType::GPU == dev_ctx.GetPlace().GetType()) { + phi::Copy(dev_ctx, rowtensor, dev_ctx.GetPlace(), false, &rt_dev); + idtptr = rt_dev.data(); + } + + phi::funcs::ForRange for_range(dev_ctx, rowtensor.numel()); + OneFunctor functor(L_dataptr, idtptr, W, dim); + for_range(functor); +} + +template +void scatterpivot( + const Context& dev_ctx, T* out_data, DenseTensor* idlst, int w, int dim) { + DenseTensor idlst_tmp; + idlst_tmp.Resize(idlst->dims()); + dev_ctx.template Alloc(&idlst_tmp); + phi::Copy(dev_ctx, *idlst, dev_ctx.GetPlace(), false, &idlst_tmp); + auto idtptr = idlst_tmp.data(); + + phi::funcs::ForRange for_range(dev_ctx, idlst_tmp.numel()); + OneFunctor functor(out_data, idtptr, w, dim); + for_range(functor); +} + +template +void Unpack_Pivot(const Context& dev_ctx, + const DenseTensor& Pivot, + DenseTensor* P, + int h, + int w) { + auto dims = Pivot.dims(); + auto Pdimvec = vectorize(dims); + auto prank = Pdimvec.size(); + auto Pnum = dims[prank - 1]; + DenseTensor Pivot_cpu; + phi::CPUPlace cpu; + phi::Copy(dev_ctx, Pivot, cpu, false, &Pivot_cpu); + auto pdataptr = Pivot_cpu.data(); + Pdimvec[prank - 1] = h; + Pdimvec.emplace_back(h); + auto Pdim = phi::make_ddim(Pdimvec); + P->Resize(Pdim); + dev_ctx.template Alloc(P); + auto pdata = P->data(); + phi::funcs::SetConstant setter; + setter(dev_ctx, P, static_cast(0)); + + auto batchsize = product(phi::slice_ddim(dims, 0, prank - 1)); + batchsize = std::max(static_cast(batchsize), 1); + DenseTensor idt; + for (int i = 0; i < batchsize; i++) { + arange(dev_ctx, &idt, h); + auto idlst = idt.data(); + for (int j = 0; j < Pnum; j++) { + if (idlst[pdataptr[i * Pnum + j] - 1] == idlst[j]) continue; + auto temp = idlst[j]; + idlst[j] = idlst[pdataptr[i * Pnum + j] - 1]; + idlst[pdataptr[i * Pnum + j] - 1] = temp; + } + scatterpivot(dev_ctx, &(pdata[i * h * h]), &idt, h, h); + } +} + +template +DenseTensor Transpose2DTo6D(const Context& dev_ctx, const DenseTensor& x) { + // transpose the last two dimision + DenseTensor ret; + auto x_dim = x.dims(); + auto x_vec = phi::vectorize(x_dim); + int rank = x_vec.size(); + std::swap(x_vec[rank - 1], x_vec[rank - 2]); + std::vector out_shape = x_vec; + std::vector axis(rank); + for (int i = 0; i < rank; ++i) { + axis[i] = i; + } + std::swap(axis[rank - 1], axis[rank - 2]); + ret.Resize(phi::make_ddim(x_vec)); + dev_ctx.template Alloc(&ret); + switch (rank) { + case 2: { + phi::funcs::Transpose trans; + trans(dev_ctx, x, &ret, axis); + break; + } + case 3: { + phi::funcs::Transpose trans; + trans(dev_ctx, x, &ret, axis); + break; + } + case 4: { + phi::funcs::Transpose trans; + trans(dev_ctx, x, &ret, axis); + break; + } + case 5: { + phi::funcs::Transpose trans; + trans(dev_ctx, x, &ret, axis); + break; + } + case 6: { + phi::funcs::Transpose 
+template <typename T, typename Context>
+void Unpack_Pivot(const Context& dev_ctx,
+                  const DenseTensor& Pivot,
+                  DenseTensor* P,
+                  int h,
+                  int w) {
+  auto dims = Pivot.dims();
+  auto Pdimvec = vectorize(dims);
+  auto prank = Pdimvec.size();
+  auto Pnum = dims[prank - 1];
+  DenseTensor Pivot_cpu;
+  phi::CPUPlace cpu;
+  phi::Copy(dev_ctx, Pivot, cpu, false, &Pivot_cpu);
+  auto pdataptr = Pivot_cpu.data<int32_t>();
+  Pdimvec[prank - 1] = h;
+  Pdimvec.emplace_back(h);
+  auto Pdim = phi::make_ddim(Pdimvec);
+  P->Resize(Pdim);
+  dev_ctx.template Alloc<T>(P);
+  auto pdata = P->data<T>();
+  phi::funcs::SetConstant<Context, T> setter;
+  setter(dev_ctx, P, static_cast<T>(0));
+
+  auto batchsize = product(phi::slice_ddim(dims, 0, prank - 1));
+  batchsize = std::max(static_cast<int>(batchsize), 1);
+  DenseTensor idt;
+  for (int i = 0; i < batchsize; i++) {
+    arange<int>(dev_ctx, &idt, h);
+    auto idlst = idt.data<int>();
+    for (int j = 0; j < Pnum; j++) {
+      if (idlst[pdataptr[i * Pnum + j] - 1] == idlst[j]) continue;
+      auto temp = idlst[j];
+      idlst[j] = idlst[pdataptr[i * Pnum + j] - 1];
+      idlst[pdataptr[i * Pnum + j] - 1] = temp;
+    }
+    scatterpivot(dev_ctx, &(pdata[i * h * h]), &idt, h, h);
+  }
+}
+
+template <typename Context, typename T>
+DenseTensor Transpose2DTo6D(const Context& dev_ctx, const DenseTensor& x) {
+  // transpose the last two dimensions
+  DenseTensor ret;
+  auto x_dim = x.dims();
+  auto x_vec = phi::vectorize<int>(x_dim);
+  int rank = x_vec.size();
+  std::swap(x_vec[rank - 1], x_vec[rank - 2]);
+  std::vector<int> out_shape = x_vec;
+  std::vector<int> axis(rank);
+  for (int i = 0; i < rank; ++i) {
+    axis[i] = i;
+  }
+  std::swap(axis[rank - 1], axis[rank - 2]);
+  ret.Resize(phi::make_ddim(x_vec));
+  dev_ctx.template Alloc<T>(&ret);
+  switch (rank) {
+    case 2: {
+      phi::funcs::Transpose<Context, T, 2> trans;
+      trans(dev_ctx, x, &ret, axis);
+      break;
+    }
+    case 3: {
+      phi::funcs::Transpose<Context, T, 3> trans;
+      trans(dev_ctx, x, &ret, axis);
+      break;
+    }
+    case 4: {
+      phi::funcs::Transpose<Context, T, 4> trans;
+      trans(dev_ctx, x, &ret, axis);
+      break;
+    }
+    case 5: {
+      phi::funcs::Transpose<Context, T, 5> trans;
+      trans(dev_ctx, x, &ret, axis);
+      break;
+    }
+    case 6: {
+      phi::funcs::Transpose<Context, T, 6> trans;
+      trans(dev_ctx, x, &ret, axis);
+      break;
+    }
+    default: {
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "Invalid Rank number, "
+          "currently only support rank between 2~6"));
+    }
+  }
+  return ret;
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/lu_grad_kernel.h b/paddle/phi/kernels/lu_grad_kernel.h
new file mode 100644
index 0000000000000..23d4ac26840fa
--- /dev/null
+++ b/paddle/phi/kernels/lu_grad_kernel.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LUGradKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& out,
+                  const DenseTensor& pivots,
+                  const DenseTensor& out_grad,
+                  bool pivot,
+                  DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/lu_kernel.h b/paddle/phi/kernels/lu_kernel.h
new file mode 100644
index 0000000000000..636bfd0683f7f
--- /dev/null
+++ b/paddle/phi/kernels/lu_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LUKernel(const Context& dev_ctx,
+              const DenseTensor& x,
+              bool pivot,
+              DenseTensor* out,
+              DenseTensor* pivots,
+              DenseTensor* infos);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/lu_sig.cc b/paddle/phi/ops/compat/lu_sig.cc
new file mode 100644
index 0000000000000..84ae337cae642
--- /dev/null
+++ b/paddle/phi/ops/compat/lu_sig.cc
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature LUOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("lu", {"X"}, {"pivots"}, {"Out", "Pivots", "Infos"});
+}
+
+KernelSignature LUGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "lu_grad", {"X", "Out", "Pivots", "Out@GRAD"}, {"pivots"}, {"X@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(lu, phi::LUOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(lu_grad, phi::LUGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_lu_op.py b/python/paddle/fluid/tests/unittests/test_lu_op.py
index 2989a0307400a..414dc66f84176 100644
--- a/python/paddle/fluid/tests/unittests/test_lu_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lu_op.py
@@ -128,6 +128,8 @@ def set_output(self):
 
     def setUp(self):
         self.op_type = "lu"
+        self.python_api = paddle.tensor.linalg.lu
+        self.python_out_sig = ["Out", "Pivots"]
         self.config()
 
         self.inputs = {'X': np.random.random(self.x_shape).astype(self.dtype)}
@@ -140,10 +142,10 @@ def setUp(self):
         }
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], ['Out'])
+        self.check_grad(['X'], ['Out'], check_eager=True)
 
 
 # m = n 2D
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index a77d6b5a2ad92..01c36a22c3757 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -2099,27 +2099,27 @@ def lu(x, pivot=True, get_infos=False, name=None):
             # one can verify : X = P @ L @ U ;
     """
-    if paddle.in_dynamic_mode():
-        LU, Piv, Info = _C_ops.lu(x, 'pivots', pivot)
-        if get_infos:
-            return LU, Piv, Info
-        else:
-            return LU, Piv
-    check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'lu')
-    helper = LayerHelper('lu', **locals())
-    lu = helper.create_variable_for_type_inference(dtype=x.dtype)
-    p = helper.create_variable_for_type_inference(dtype='int')
-    info = helper.create_variable_for_type_inference(dtype='int')
-    attrs = dict()
-    attrs['pivots'] = pivot
-    helper.append_op(type='lu',
-                     inputs={'X': x},
-                     outputs={
-                         'Out': lu,
-                         'Pivots': p,
-                         'Infos': info
-                     },
-                     attrs=attrs)
+
+    if in_dygraph_mode():
+        lu, p, info = _C_ops.final_state_lu(x, pivot)
+    elif paddle.in_dynamic_mode():
+        lu, p, info = _C_ops.lu(x, 'pivot', pivot)
+    else:
+        check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'lu')
+        helper = LayerHelper('lu', **locals())
+        lu = helper.create_variable_for_type_inference(dtype=x.dtype)
+        p = helper.create_variable_for_type_inference(dtype='int')
+        info = helper.create_variable_for_type_inference(dtype='int')
+        attrs = dict()
+        attrs['pivot'] = pivot
+        helper.append_op(type='lu',
+                         inputs={'X': x},
+                         outputs={
+                             'Out': lu,
+                             'Pivots': p,
+                             'Infos': info
+                         },
+                         attrs=attrs)
     if get_infos:
         return lu, p, info
     else: