From 9e4cf0d770919fdf57ce6a9e3e00c0e56340e3b9 Mon Sep 17 00:00:00 2001 From: Zjq9409 <15205085056@163.com> Date: Wed, 22 Sep 2021 08:43:45 +0000 Subject: [PATCH 1/5] CPU forward calculation replaces Eigen with Lapack --- .../operators/math/eigen_values_vectors.h | 167 +++++++----------- .../fluid/operators/math/lapack_function.cc | 47 ++++- paddle/fluid/operators/math/lapack_function.h | 7 +- paddle/fluid/platform/dynload/lapack.h | 21 ++- 4 files changed, 135 insertions(+), 107 deletions(-) diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h index 3c793c8906e6e..980a0d6684b32 100644 --- a/paddle/fluid/operators/math/eigen_values_vectors.h +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -16,6 +16,7 @@ #include "Eigen/Core" #include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/math/lapack_function.h" #include "paddle/fluid/operators/svd_helper.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/dynload/cusolver.h" @@ -25,84 +26,6 @@ namespace paddle { namespace operators { namespace math { -template -using InputMatrixMap = Eigen::Map< - const Eigen::Matrix>; - -template -using OutputMatrixMap = Eigen::Map< - Eigen::Matrix>; - -template -inline void ComputeFloatEigenvaluesAndVectors(ValueType *x_data, - ValueType *eigenvalues_data, - ValueType *eigenvectors_data, - int batches, int rows, int cols, - bool has_vectors) { - int stride = rows * cols; - for (int i = 0; i < batches; i++) { - auto m = InputMatrixMap(x_data + i * stride, rows, cols); - auto eigenvalues = - OutputMatrixMap(eigenvalues_data + i * rows, 1, rows); - auto eigenvectors = - OutputMatrixMap(eigenvectors_data + i * stride, rows, cols); - - Eigen::SelfAdjointEigenSolver> - eigen_solver(m, has_vectors ? Eigen::ComputeEigenvectors - : Eigen::EigenvaluesOnly); - PADDLE_ENFORCE_EQ( - eigen_solver.info(), Eigen::Success, - platform::errors::InvalidArgument( - "Self Adjoint Eigen decomposition is not successful. " - "The %d-th input matrice might not be not be positive definite.", - i)); - - eigenvalues = eigen_solver.eigenvalues().transpose(); - if (has_vectors) { - eigenvectors = eigen_solver.eigenvectors(); - } - } -} - -template -inline void ComputeComplexEigenvaluesAndVectors(T *x_data, - ValueType *eigenvalues_data, - T *eigenvectors_data, - int batches, int rows, int cols, - bool has_vectors) { - using Complex = std::complex; - Complex *input = reinterpret_cast(x_data); - Complex *eigenvectors_data_ = reinterpret_cast(eigenvectors_data); - - int stride = rows * cols; - for (int i = 0; i < batches; i++) { - auto m = InputMatrixMap(input + i * stride, rows, cols); - auto eigenvalues = - OutputMatrixMap(eigenvalues_data + i * rows, 1, rows); - auto eigenvectors = - OutputMatrixMap(eigenvectors_data_ + i * stride, rows, cols); - - Eigen::SelfAdjointEigenSolver< - Eigen::Matrix> - eigen_solver(m, has_vectors ? Eigen::ComputeEigenvectors - : Eigen::EigenvaluesOnly); - PADDLE_ENFORCE_EQ( - eigen_solver.info(), Eigen::Success, - platform::errors::InvalidArgument( - "Self Adjoint Eigen decomposition is not successful. " - "The %d-th input matrice might not be not be positive definite.", - i)); - - eigenvalues = eigen_solver.eigenvalues().transpose(); - if (has_vectors) { - eigenvectors = eigen_solver.eigenvectors(); - } - } -} - inline int64_t GetBatchSize(framework::DDim dims) { int64_t batch_size = 1; auto dim_size = dims.size(); @@ -128,37 +51,75 @@ struct MatrixEighFunctor { void operator()(const framework::ExecutionContext &ctx, const Tensor &input, Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, bool has_vectors) { - auto dims = input.dims(); - auto output_value_dim = eigen_values->dims(); + auto *out_value = eigen_values->mutable_data(ctx.GetPlace()); + auto *out_vector = eigen_vectors->mutable_data(ctx.GetPlace()); - int64_t batch_size = 1; + auto dims = input.dims(); int dim_size = dims.size(); - for (int64_t i = 0; i < dim_size - 2; i++) { - batch_size *= dims[i]; - } + int64_t batch_size = GetBatchSize(dims); + auto dito = - DeviceIndependenceTensorOperations(ctx); - Tensor input_tensor; - TensorCopy(input, ctx.GetPlace(), &input_tensor); - if (!is_lower) { - input_tensor = dito.Transpose(input); + math::DeviceIndependenceTensorOperations( + ctx); + Tensor output_v_var_trans = dito.Transpose(input); + TensorCopy(output_v_var_trans, ctx.GetPlace(), eigen_vectors); + + int vector_stride = dims[dim_size - 1] * dims[dim_size - 2]; + int values_stride = dims[dim_size - 1]; + char uplo = is_lower ? 'L' : 'U'; + char jobz = has_vectors ? 'V' : 'N'; + auto n = dims[dim_size - 1]; + auto lda = std::max(1, n); + + int lwork = -1; + int lrwork = -1; + int liwork = -1; + int iwork_buffer = -1; + T lwork_buffer = static_cast(-1); + ValueType rwork_buffer = static_cast(-1); + + Tensor info_tensor; + auto *infos_data = info_tensor.mutable_data( + framework::make_ddim({batch_size}), ctx.GetPlace()); + + math::lapackEvd(jobz, uplo, n, out_vector, lda, out_value, + &lwork_buffer, lwork, &rwork_buffer, lrwork, + &iwork_buffer, liwork, infos_data); + + lwork = std::max(1, static_cast(lwork_buffer)); + liwork = std::max(1, iwork_buffer); + + Tensor rwork_tensor; + ValueType *rwork_data = nullptr; + + // complex type + if (framework::IsComplexType(eigen_vectors->type())) { + lrwork = std::max(1, static_cast(rwork_buffer)); + rwork_data = rwork_tensor.mutable_data( + framework::make_ddim({lrwork}), ctx.GetPlace()); } - int rows = dims[dims.size() - 2]; - auto *value_data = - eigen_values->mutable_data(output_value_dim, ctx.GetPlace()); + Tensor iwork_tensor, work_tensor; + auto *iwork_data = iwork_tensor.mutable_data( + framework::make_ddim({liwork}), ctx.GetPlace()); + auto *work_data = work_tensor.mutable_data(framework::make_ddim({lwork}), + ctx.GetPlace()); - if (framework::IsComplexType(input_tensor.type())) { - auto *x_data = input_tensor.data(); - auto *vector_data = eigen_vectors->mutable_data(dims, ctx.GetPlace()); - ComputeComplexEigenvaluesAndVectors( - x_data, value_data, vector_data, batch_size, rows, rows, has_vectors); - } else { - auto *x_data = input_tensor.data(); - auto *vector_data = - eigen_vectors->mutable_data(dims, ctx.GetPlace()); - ComputeFloatEigenvaluesAndVectors( - x_data, value_data, vector_data, batch_size, rows, rows, has_vectors); + for (auto i = 0; i < batch_size; i++) { + auto *value_data = out_value + i * values_stride; + auto *vector_data = out_vector + i * vector_stride; + int *info_ptr = &infos_data[i]; + math::lapackEvd(jobz, uplo, n, vector_data, lda, value_data, + work_data, lwork, rwork_data, lrwork, + iwork_data, liwork, info_ptr); + PADDLE_ENFORCE_EQ( + *info_ptr, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: the [%d] argument had an illegal value", i, + *info_ptr)); + } + if (has_vectors) { + *eigen_vectors = dito.Transpose(*eigen_vectors); } } }; diff --git a/paddle/fluid/operators/math/lapack_function.cc b/paddle/fluid/operators/math/lapack_function.cc index 54033a444a654..d4e5cd5de78af 100644 --- a/paddle/fluid/operators/math/lapack_function.cc +++ b/paddle/fluid/operators/math/lapack_function.cc @@ -21,15 +21,58 @@ namespace math { // LU (for example) template <> -void lapackLu(int m, int n, double *a, int lda, int *ipiv, int *info) { +void lapackLu(int m, int n, double* a, int lda, int* ipiv, int* info) { platform::dynload::dgetrf_(&m, &n, a, &lda, ipiv, info); } template <> -void lapackLu(int m, int n, float *a, int lda, int *ipiv, int *info) { +void lapackLu(int m, int n, float* a, int lda, int* ipiv, int* info) { platform::dynload::sgetrf_(&m, &n, a, &lda, ipiv, info); } +template <> +void lapackEvd(char jobz, char uplo, int n, float* a, int lda, + float* w, float* work, int lwork, float* rwork, + int lrwork, int* iwork, int liwork, int* info) { + (void)rwork; // unused + (void)lrwork; // unused + platform::dynload::ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, + &liwork, info); +} + +template <> +void lapackEvd(char jobz, char uplo, int n, double* a, int lda, + double* w, double* work, int lwork, + double* rwork, int lrwork, int* iwork, + int liwork, int* info) { + (void)rwork; // unused + (void)lrwork; // unused + platform::dynload::dsyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, + &liwork, info); +} + +template <> +void lapackEvd, float>( + char jobz, char uplo, int n, paddle::platform::complex* a, int lda, + float* w, paddle::platform::complex* work, int lwork, float* rwork, + int lrwork, int* iwork, int liwork, int* info) { + platform::dynload::cheevd_(&jobz, &uplo, &n, + reinterpret_cast*>(a), &lda, w, + reinterpret_cast*>(work), + &lwork, rwork, &lrwork, iwork, &liwork, info); +} + +template <> +void lapackEvd, double>( + char jobz, char uplo, int n, paddle::platform::complex* a, int lda, + double* w, paddle::platform::complex* work, int lwork, + double* rwork, int lrwork, int* iwork, int liwork, int* info) { + platform::dynload::zheevd_(&jobz, &uplo, &n, + reinterpret_cast*>(a), &lda, + w, reinterpret_cast*>(work), + &lwork, rwork, &lrwork, iwork, &liwork, info); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/lapack_function.h b/paddle/fluid/operators/math/lapack_function.h index 694da4603ba53..977dd12da156b 100644 --- a/paddle/fluid/operators/math/lapack_function.h +++ b/paddle/fluid/operators/math/lapack_function.h @@ -20,7 +20,12 @@ namespace math { // LU (for example) template -void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info); +void lapackLu(int m, int n, T* a, int lda, int* ipiv, int* info); + +template +void lapackEvd(char jobz, char uplo, int n, T* a, int lda, ValueType* w, + T* work, int lwork, ValueType* rwork, int lrwork, int* iwork, + int liwork, int* info); } // namespace math } // namespace operators diff --git a/paddle/fluid/platform/dynload/lapack.h b/paddle/fluid/platform/dynload/lapack.h index ffb3d3e0f6799..e926a5fa1f7d0 100644 --- a/paddle/fluid/platform/dynload/lapack.h +++ b/paddle/fluid/platform/dynload/lapack.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/port.h" @@ -26,6 +27,20 @@ extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info); extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info); +extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex *a, + int *lda, double *w, std::complex *work, + int *lwork, double *rwork, int *lrwork, int *iwork, + int *liwork, int *info); +extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex *a, + int *lda, float *w, std::complex *work, + int *lwork, float *rwork, int *lrwork, int *iwork, + int *liwork, int *info); +extern "C" void dsyevd_(char *jobz, char *uplo, int *n, double *a, int *lda, + double *w, double *work, int *lwork, int *iwork, + int *liwork, int *info); +extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, + float *w, float *work, int *lwork, int *iwork, + int *liwork, int *info); namespace paddle { namespace platform { @@ -58,7 +73,11 @@ extern void *lapack_dso_handle; #define LAPACK_ROUTINE_EACH(__macro) \ __macro(dgetrf_); \ - __macro(sgetrf_); + __macro(sgetrf_); \ + __macro(zheevd_); \ + __macro(cheevd_); \ + __macro(dsyevd_); \ + __macro(ssyevd_); LAPACK_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_LAPACK_WRAP); From 8d53c2ff037eac21288079ba34329ce28cb32363 Mon Sep 17 00:00:00 2001 From: Zjq9409 <15205085056@163.com> Date: Thu, 23 Sep 2021 02:50:50 +0000 Subject: [PATCH 2/5] Remove memory copy operation --- .../operators/math/eigen_values_vectors.h | 40 +++++++++---------- python/paddle/__init__.py | 1 - python/paddle/tensor/linalg.py | 2 +- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h index 980a0d6684b32..11dea089bf6cd 100644 --- a/paddle/fluid/operators/math/eigen_values_vectors.h +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -52,18 +52,17 @@ struct MatrixEighFunctor { Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, bool has_vectors) { auto *out_value = eigen_values->mutable_data(ctx.GetPlace()); + + auto dito = + math::DeviceIndependenceTensorOperations( + ctx); + *eigen_vectors = dito.Transpose(input); auto *out_vector = eigen_vectors->mutable_data(ctx.GetPlace()); auto dims = input.dims(); int dim_size = dims.size(); int64_t batch_size = GetBatchSize(dims); - auto dito = - math::DeviceIndependenceTensorOperations( - ctx); - Tensor output_v_var_trans = dito.Transpose(input); - TensorCopy(output_v_var_trans, ctx.GetPlace(), eigen_vectors); - int vector_stride = dims[dim_size - 1] * dims[dim_size - 2]; int values_stride = dims[dim_size - 1]; char uplo = is_lower ? 'L' : 'U'; @@ -74,27 +73,27 @@ struct MatrixEighFunctor { int lwork = -1; int lrwork = -1; int liwork = -1; - int iwork_buffer = -1; - T lwork_buffer = static_cast(-1); - ValueType rwork_buffer = static_cast(-1); + int iwork_opt = -1; + T lwork_opt = static_cast(-1); + ValueType rwork_opt = static_cast(-1); Tensor info_tensor; auto *infos_data = info_tensor.mutable_data( framework::make_ddim({batch_size}), ctx.GetPlace()); math::lapackEvd(jobz, uplo, n, out_vector, lda, out_value, - &lwork_buffer, lwork, &rwork_buffer, lrwork, - &iwork_buffer, liwork, infos_data); + &lwork_opt, lwork, &rwork_opt, lrwork, + &iwork_opt, liwork, infos_data); - lwork = std::max(1, static_cast(lwork_buffer)); - liwork = std::max(1, iwork_buffer); + lwork = std::max(1, static_cast(lwork_opt)); + liwork = std::max(1, iwork_opt); Tensor rwork_tensor; ValueType *rwork_data = nullptr; // complex type if (framework::IsComplexType(eigen_vectors->type())) { - lrwork = std::max(1, static_cast(rwork_buffer)); + lrwork = std::max(1, static_cast(rwork_opt)); rwork_data = rwork_tensor.mutable_data( framework::make_ddim({lrwork}), ctx.GetPlace()); } @@ -136,6 +135,12 @@ struct MatrixEighFunctor { Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, bool has_vectors) { auto *out_value = eigen_values->mutable_data(ctx.GetPlace()); + + auto &dev_ctx = ctx.template device_context(); + auto dito = + math::DeviceIndependenceTensorOperations(ctx); + *eigen_vectors = dito.Transpose(input); auto *out_vector = eigen_vectors->mutable_data(ctx.GetPlace()); auto &dims = input.dims(); @@ -152,13 +157,6 @@ struct MatrixEighFunctor { auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2]; auto values_stride = dims[dim_size - 1]; - auto &dev_ctx = ctx.template device_context(); - auto dito = - math::DeviceIndependenceTensorOperations(ctx); - Tensor output_v_var_trans = dito.Transpose(input); - TensorCopy(output_v_var_trans, ctx.GetPlace(), eigen_vectors); - int lwork = 0; auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_size); auto *info_ptr = reinterpret_cast(info->ptr()); diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 6bd58ee558f0b..dc9add3e0cf18 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -104,7 +104,6 @@ from .tensor.linalg import multi_dot # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 from .tensor.linalg import svd # noqa: F401 -from .tensor.linalg import eigh # noqa: F401 from .tensor.linalg import pinv # noqa: F401 from .tensor.logic import equal # noqa: F401 from .tensor.logic import greater_equal # noqa: F401 diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index c7862f61894e5..14c3b5deeaac7 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1656,7 +1656,7 @@ def eigh(x, UPLO='L', name=None): x_data = np.array([[1, -2j], [2j, 5]]) x = paddle.to_tensor(x_data) - out_value, out_vector = paddle.eigh(x, UPLO='L') + out_value, out_vector = paddle.linalg.eigh(x, UPLO='L') print(out_value) #[0.17157288, 5.82842712] print(out_vector) From 8733853a6e408ca0a2757bcc48285e4d55a6c01d Mon Sep 17 00:00:00 2001 From: Zjq9409 <15205085056@163.com> Date: Thu, 23 Sep 2021 12:25:16 +0000 Subject: [PATCH 3/5] Modify format --- paddle/fluid/operators/math/eigen_values_vectors.h | 1 - paddle/fluid/platform/dynload/lapack.h | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h index 11dea089bf6cd..101fbb9addf3b 100644 --- a/paddle/fluid/operators/math/eigen_values_vectors.h +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -14,7 +14,6 @@ #pragma once -#include "Eigen/Core" #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/operators/math/lapack_function.h" #include "paddle/fluid/operators/svd_helper.h" diff --git a/paddle/fluid/platform/dynload/lapack.h b/paddle/fluid/platform/dynload/lapack.h index 7354220940cbd..9b4dd3d9e3ce5 100644 --- a/paddle/fluid/platform/dynload/lapack.h +++ b/paddle/fluid/platform/dynload/lapack.h @@ -29,7 +29,7 @@ extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info); -// evd +// evd extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex *a, int *lda, double *w, std::complex *work, int *lwork, double *rwork, int *lrwork, int *iwork, From 7b8aa10b18034524b11c2c3de32bcbc6748e9c38 Mon Sep 17 00:00:00 2001 From: Zjq9409 <15205085056@163.com> Date: Fri, 24 Sep 2021 10:17:43 +0000 Subject: [PATCH 4/5] Modify the incoming type parameters, add comments --- .../operators/math/eigen_values_vectors.h | 74 +++++++++++-------- .../fluid/operators/math/lapack_function.cc | 29 ++++---- paddle/fluid/operators/math/lapack_function.h | 6 +- 3 files changed, 59 insertions(+), 50 deletions(-) diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h index 101fbb9addf3b..223cc10f51182 100644 --- a/paddle/fluid/operators/math/eigen_values_vectors.h +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -34,6 +34,19 @@ inline int64_t GetBatchSize(framework::DDim dims) { return batch_size; } +static void CheckEighResult(const int batch, const int info) { + PADDLE_ENFORCE_LE( + info, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: the [%d] off-diagonal elements of an intermediate" + "tridiagonal form did not converge to zero", + batch, info)); + PADDLE_ENFORCE_GE( + info, 0, platform::errors::PreconditionNotMet( + "For batch [%d]: the [%d] argument had an illegal value", + batch, info)); +} + template struct MatrixEighFunctor { void operator()(const framework::ExecutionContext &ctx, const Tensor &input, @@ -55,7 +68,11 @@ struct MatrixEighFunctor { auto dito = math::DeviceIndependenceTensorOperations( ctx); - *eigen_vectors = dito.Transpose(input); + if (has_vectors) { + // lapack is a column-first storage, transpose make the eigen_vectors to + // have a continuous memory layout + *eigen_vectors = dito.Transpose(input); + } auto *out_vector = eigen_vectors->mutable_data(ctx.GetPlace()); auto dims = input.dims(); @@ -68,21 +85,21 @@ struct MatrixEighFunctor { char jobz = has_vectors ? 'V' : 'N'; auto n = dims[dim_size - 1]; auto lda = std::max(1, n); - - int lwork = -1; - int lrwork = -1; - int liwork = -1; - int iwork_opt = -1; - T lwork_opt = static_cast(-1); - ValueType rwork_opt = static_cast(-1); - - Tensor info_tensor; - auto *infos_data = info_tensor.mutable_data( - framework::make_ddim({batch_size}), ctx.GetPlace()); - - math::lapackEvd(jobz, uplo, n, out_vector, lda, out_value, - &lwork_opt, lwork, &rwork_opt, lrwork, - &iwork_opt, liwork, infos_data); + // if work = -1, it means that you need to use the lapack function to query + // the optimal value + int lwork = -1; // The length of the array work + int lrwork = -1; // The dimension of the array rwork,rwork is REAL array + int liwork = -1; // The dimension of the array iwork + int iwork_opt = -1; // The optimal length of the array liwork + T lwork_opt = static_cast(-1); // The optimal length of the array work + ValueType rwork_opt = + static_cast(-1); // The optimal length of the array rwork + + int info = 0; + // Call lapackEigh to get the optimal size of work data + math::lapackEigh>(jobz, uplo, n, out_vector, lda, + out_value, &lwork_opt, lwork, &rwork_opt, + lrwork, &iwork_opt, liwork, &info); lwork = std::max(1, static_cast(lwork_opt)); liwork = std::max(1, iwork_opt); @@ -106,15 +123,10 @@ struct MatrixEighFunctor { for (auto i = 0; i < batch_size; i++) { auto *value_data = out_value + i * values_stride; auto *vector_data = out_vector + i * vector_stride; - int *info_ptr = &infos_data[i]; - math::lapackEvd(jobz, uplo, n, vector_data, lda, value_data, - work_data, lwork, rwork_data, lrwork, - iwork_data, liwork, info_ptr); - PADDLE_ENFORCE_EQ( - *info_ptr, 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: the [%d] argument had an illegal value", i, - *info_ptr)); + math::lapackEigh>(jobz, uplo, n, vector_data, lda, value_data, + work_data, lwork, rwork_data, lrwork, + iwork_data, liwork, &info); + CheckEighResult(i, info); } if (has_vectors) { *eigen_vectors = dito.Transpose(*eigen_vectors); @@ -139,7 +151,9 @@ struct MatrixEighFunctor { auto dito = math::DeviceIndependenceTensorOperations(ctx); - *eigen_vectors = dito.Transpose(input); + if (has_vectors) { + *eigen_vectors = dito.Transpose(input); + } auto *out_vector = eigen_vectors->mutable_data(ctx.GetPlace()); auto &dims = input.dims(); @@ -199,15 +213,11 @@ struct MatrixEighFunctor { Evd(handle, jobz, uplo, n, vector_data, lda, value_data, work_ptr, lwork, info_ptr); } - int error_info; + int error_info = 0; memory::Copy(platform::CPUPlace(), &error_info, BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), info_ptr, sizeof(int), dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - error_info, 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: the [%d] argument had an illegal value", i, - error_info)); + CheckEighResult(i, error_info); } if (use_syevj) { diff --git a/paddle/fluid/operators/math/lapack_function.cc b/paddle/fluid/operators/math/lapack_function.cc index f9270b9966f9a..3ce2225420e60 100644 --- a/paddle/fluid/operators/math/lapack_function.cc +++ b/paddle/fluid/operators/math/lapack_function.cc @@ -33,9 +33,9 @@ void lapackLu(int m, int n, float *a, int lda, int *ipiv, int *info) { // eigh template <> -void lapackEvd(char jobz, char uplo, int n, float *a, int lda, - float *w, float *work, int lwork, float *rwork, - int lrwork, int *iwork, int liwork, int *info) { +void lapackEigh(char jobz, char uplo, int n, float *a, int lda, float *w, + float *work, int lwork, float *rwork, int lrwork, + int *iwork, int liwork, int *info) { (void)rwork; // unused (void)lrwork; // unused platform::dynload::ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, @@ -43,10 +43,9 @@ void lapackEvd(char jobz, char uplo, int n, float *a, int lda, } template <> -void lapackEvd(char jobz, char uplo, int n, double *a, int lda, - double *w, double *work, int lwork, - double *rwork, int lrwork, int *iwork, - int liwork, int *info) { +void lapackEigh(char jobz, char uplo, int n, double *a, int lda, + double *w, double *work, int lwork, double *rwork, + int lrwork, int *iwork, int liwork, int *info) { (void)rwork; // unused (void)lrwork; // unused platform::dynload::dsyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, @@ -54,10 +53,10 @@ void lapackEvd(char jobz, char uplo, int n, double *a, int lda, } template <> -void lapackEvd, float>( - char jobz, char uplo, int n, paddle::platform::complex *a, int lda, - float *w, paddle::platform::complex *work, int lwork, float *rwork, - int lrwork, int *iwork, int liwork, int *info) { +void lapackEigh, float>( + char jobz, char uplo, int n, platform::complex *a, int lda, float *w, + platform::complex *work, int lwork, float *rwork, int lrwork, + int *iwork, int liwork, int *info) { platform::dynload::cheevd_(&jobz, &uplo, &n, reinterpret_cast *>(a), &lda, w, reinterpret_cast *>(work), @@ -65,10 +64,10 @@ void lapackEvd, float>( } template <> -void lapackEvd, double>( - char jobz, char uplo, int n, paddle::platform::complex *a, int lda, - double *w, paddle::platform::complex *work, int lwork, - double *rwork, int lrwork, int *iwork, int liwork, int *info) { +void lapackEigh, double>( + char jobz, char uplo, int n, platform::complex *a, int lda, + double *w, platform::complex *work, int lwork, double *rwork, + int lrwork, int *iwork, int liwork, int *info) { platform::dynload::zheevd_(&jobz, &uplo, &n, reinterpret_cast *>(a), &lda, w, reinterpret_cast *>(work), diff --git a/paddle/fluid/operators/math/lapack_function.h b/paddle/fluid/operators/math/lapack_function.h index d7e51477ba1a6..a4c2c865c859a 100644 --- a/paddle/fluid/operators/math/lapack_function.h +++ b/paddle/fluid/operators/math/lapack_function.h @@ -23,9 +23,9 @@ template void lapackLu(int m, int n, T* a, int lda, int* ipiv, int* info); template -void lapackEvd(char jobz, char uplo, int n, T* a, int lda, ValueType* w, - T* work, int lwork, ValueType* rwork, int lrwork, int* iwork, - int liwork, int* info); +void lapackEigh(char jobz, char uplo, int n, T* a, int lda, ValueType* w, + T* work, int lwork, ValueType* rwork, int lrwork, int* iwork, + int liwork, int* info); template void lapackEig(char jobvl, char jobvr, int n, T1* a, int lda, T1* w, T1* vl, From 3cf4e85884363dd42e1a811923adae0001166abf Mon Sep 17 00:00:00 2001 From: Zjq9409 <15205085056@163.com> Date: Sat, 25 Sep 2021 14:28:18 +0000 Subject: [PATCH 5/5] Support eigenvectors nullptr --- paddle/fluid/operators/eigh_op.cc | 17 ++-- paddle/fluid/operators/eigh_op.cu | 17 ++-- paddle/fluid/operators/eigh_op.h | 7 +- .../operators/math/eigen_values_vectors.h | 98 ++++++++++--------- 4 files changed, 72 insertions(+), 67 deletions(-) diff --git a/paddle/fluid/operators/eigh_op.cc b/paddle/fluid/operators/eigh_op.cc index 5577dfb8f889b..6835951a2381f 100644 --- a/paddle/fluid/operators/eigh_op.cc +++ b/paddle/fluid/operators/eigh_op.cc @@ -147,18 +147,17 @@ REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker, REGISTER_OPERATOR(eigh_grad, ops::EighGradOp); REGISTER_OP_CPU_KERNEL( - eigh, ops::EighKernel, - ops::EighKernel, - ops::EighKernel, + ops::EighKernel, + ops::EighKernel>, - ops::EighKernel>); REGISTER_OP_CPU_KERNEL( - eigh_grad, - ops::EighGradKernel, - ops::EighGradKernel, - ops::EighGradKernel, + ops::EighGradKernel, + ops::EighGradKernel>, - ops::EighGradKernel>); diff --git a/paddle/fluid/operators/eigh_op.cu b/paddle/fluid/operators/eigh_op.cu index 61d2b66ea536d..827c551637d4d 100644 --- a/paddle/fluid/operators/eigh_op.cu +++ b/paddle/fluid/operators/eigh_op.cu @@ -16,18 +16,17 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - eigh, ops::EighKernel, - ops::EighKernel, - ops::EighKernel, + ops::EighKernel, + ops::EighKernel>, - ops::EighKernel>); REGISTER_OP_CUDA_KERNEL( - eigh_grad, - ops::EighGradKernel, - ops::EighGradKernel, - ops::EighGradKernel, + ops::EighGradKernel, + ops::EighGradKernel>, - ops::EighGradKernel>); diff --git a/paddle/fluid/operators/eigh_op.h b/paddle/fluid/operators/eigh_op.h index 085e7531dd523..ad9b0f598311b 100644 --- a/paddle/fluid/operators/eigh_op.h +++ b/paddle/fluid/operators/eigh_op.h @@ -22,7 +22,7 @@ namespace operators { using Tensor = framework::Tensor; -template +template class EighKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -31,15 +31,16 @@ class EighKernel : public framework::OpKernel { auto output_v = ctx.Output("Eigenvectors"); std::string lower = ctx.Attr("UPLO"); bool is_lower = (lower == "L"); - math::MatrixEighFunctor functor; + math::MatrixEighFunctor functor; functor(ctx, *input, output_w, output_v, is_lower, true); } }; -template +template class EighGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + using ValueType = math::Real; auto& x_grad = *ctx.Output(framework::GradVarName("X")); x_grad.mutable_data(ctx.GetPlace()); auto& output_w = *ctx.Input("Eigenvalues"); diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h index 0c4f0858ec1fe..01f05530e34e6 100644 --- a/paddle/fluid/operators/math/eigen_values_vectors.h +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -47,7 +47,7 @@ static void CheckEighResult(const int batch, const int info) { batch, info)); } -template +template struct MatrixEighFunctor { void operator()(const framework::ExecutionContext &ctx, const Tensor &input, Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, @@ -57,22 +57,24 @@ struct MatrixEighFunctor { // Calculates the eigenvalues ​​and eigenvectors of Hermitian or real // symmetric matrices, and uses the variable has_vectors to // control whether to return the eigenvectors. -template -struct MatrixEighFunctor { +template +struct MatrixEighFunctor { public: void operator()(const framework::ExecutionContext &ctx, const Tensor &input, Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, bool has_vectors) { + using ValueType = math::Real; auto *out_value = eigen_values->mutable_data(ctx.GetPlace()); auto dito = math::DeviceIndependenceTensorOperations( ctx); - // lapack is a column-major storge, transpose make the eigen_vectors to + Tensor input_trans; + // lapack is a column-major storge, transpose make the input to // have a continuous memory layout - *eigen_vectors = dito.Transpose(input); - auto *out_vector = eigen_vectors->data(); + input_trans = dito.Transpose(input); + auto *input_vector = input_trans.data(); auto dims = input.dims(); int dim_size = dims.size(); @@ -96,10 +98,9 @@ struct MatrixEighFunctor { int info = 0; // Call lapackEigh to get the optimal size of work data - math::lapackEigh>(jobz, uplo, n, out_vector, lda, - out_value, &lwork_opt, lwork, &rwork_opt, - lrwork, &iwork_opt, liwork, &info); - + math::lapackEigh(jobz, uplo, n, input_vector, lda, out_value, + &lwork_opt, lwork, &rwork_opt, lrwork, + &iwork_opt, liwork, &info); lwork = std::max(1, static_cast(lwork_opt)); liwork = std::max(1, iwork_opt); @@ -107,12 +108,11 @@ struct MatrixEighFunctor { ValueType *rwork_data = nullptr; // complex type - if (framework::IsComplexType(eigen_vectors->type())) { + if (framework::IsComplexType(input.type())) { lrwork = std::max(1, static_cast(rwork_opt)); rwork_data = rwork_tensor.mutable_data( framework::make_ddim({lrwork}), ctx.GetPlace()); } - Tensor iwork_tensor, work_tensor; auto *iwork_data = iwork_tensor.mutable_data( framework::make_ddim({liwork}), ctx.GetPlace()); @@ -121,14 +121,20 @@ struct MatrixEighFunctor { for (auto i = 0; i < batch_size; i++) { auto *value_data = out_value + i * values_stride; - auto *vector_data = out_vector + i * vector_stride; - math::lapackEigh>(jobz, uplo, n, vector_data, lda, value_data, + auto *input_data = input_vector + i * vector_stride; + math::lapackEigh>(jobz, uplo, n, input_data, lda, value_data, work_data, lwork, rwork_data, lrwork, iwork_data, liwork, &info); CheckEighResult(i, info); } if (has_vectors) { - *eigen_vectors = dito.Transpose(*eigen_vectors); + PADDLE_ENFORCE_NOT_NULL(eigen_vectors, + platform::errors::InvalidArgument( + "When has_vectors is true," + "the eigenvectors needs to be calculated, " + "so the eigenvectors must be provided.")); + input_trans = dito.Transpose(input_trans); + eigen_vectors->ShareDataWith(input_trans); } } }; @@ -138,21 +144,22 @@ struct MatrixEighFunctor { // Calculates the eigenvalues ​​and eigenvectors of Hermitian or real // symmetric matrices on GPU, and uses the variable has_vectors // to control whether to return the eigenvectors. -template -struct MatrixEighFunctor { +template +struct MatrixEighFunctor { public: void operator()(const framework::ExecutionContext &ctx, const Tensor &input, Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, bool has_vectors) { + using ValueType = math::Real; auto *out_value = eigen_values->mutable_data(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context(); auto dito = math::DeviceIndependenceTensorOperations(ctx); - *eigen_vectors = dito.Transpose(input); - auto *out_vector = eigen_vectors->mutable_data(ctx.GetPlace()); - + Tensor input_trans; + input_trans = dito.Transpose(input); + auto *input_vector = input_trans.data(); auto &dims = input.dims(); int dim_size = dims.size(); int64_t batch_size = GetBatchSize(dims); @@ -166,7 +173,6 @@ struct MatrixEighFunctor { int lda = std::max(1, n); auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2]; auto values_stride = dims[dim_size - 1]; - int lwork = 0; auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_size); auto *info_ptr = reinterpret_cast(info->ptr()); @@ -174,10 +180,8 @@ struct MatrixEighFunctor { // When the input type is float32, and the feature value input dimension is // greater than or equal to [*,32,32] and less than or equal to // [*,512,512], Syevj has better performance. - bool use_syevj = - (eigen_vectors->type() == framework::proto::VarType::FP32 && - values_stride >= 32 && values_stride <= 512); - + bool use_syevj = (input.type() == framework::proto::VarType::FP32 && + values_stride >= 32 && values_stride <= 512); syevjInfo_t syevj_params; if (use_syevj) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -185,30 +189,28 @@ struct MatrixEighFunctor { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cusolverDnSsyevj_bufferSize( dev_ctx.cusolver_dn_handle(), jobz, uplo, n, - reinterpret_cast(out_vector), lda, + reinterpret_cast(input_vector), lda, reinterpret_cast(out_value), &lwork, syevj_params)); } else { - EvdBuffer(dev_ctx.cusolver_dn_handle(), jobz, uplo, n, out_vector, lda, + EvdBuffer(dev_ctx.cusolver_dn_handle(), jobz, uplo, n, input_vector, lda, out_value, &lwork); } - auto work = memory::Alloc(dev_ctx, sizeof(T) * lwork); auto *work_ptr = reinterpret_cast(work->ptr()); - for (auto i = 0; i < batch_size; i++) { - auto vector_data = out_vector + i * vector_stride; - auto value_data = out_value + i * values_stride; + auto *input_data = input_vector + i * vector_stride; + auto *value_data = out_value + i * values_stride; auto handle = dev_ctx.cusolver_dn_handle(); if (use_syevj) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSsyevj( - handle, jobz, uplo, n, reinterpret_cast(vector_data), lda, + handle, jobz, uplo, n, reinterpret_cast(input_data), lda, reinterpret_cast(value_data), reinterpret_cast(work_ptr), lwork, info_ptr, syevj_params)); } else { - Evd(handle, jobz, uplo, n, vector_data, lda, value_data, work_ptr, - lwork, info_ptr); + Evd(handle, jobz, uplo, n, input_data, lda, value_data, work_ptr, lwork, + info_ptr); } int error_info = 0; memory::Copy(platform::CPUPlace(), &error_info, @@ -221,12 +223,18 @@ struct MatrixEighFunctor { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cusolverDnDestroySyevjInfo(syevj_params)); } - if (has_vectors) { - *eigen_vectors = dito.Transpose(*eigen_vectors); + PADDLE_ENFORCE_NOT_NULL(eigen_vectors, + platform::errors::InvalidArgument( + "When has_vectors is true," + "the eigenvectors needs to be calculated," + "so the eigenvectors must be provided.")); + input_trans = dito.Transpose(input_trans); + eigen_vectors->ShareDataWith(input_trans); } } + using ValueType = math::Real; inline void EvdBuffer(cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, int *lwork) const; @@ -236,15 +244,14 @@ struct MatrixEighFunctor { T *work, int lwork, int *devInfo) const; }; -#define FUNC_WITH_TYPES(m) \ - m(float, float, Ssy, float) m(double, double, Dsy, double) \ - m(float, paddle::platform::complex, Che, cuComplex) \ - m(double, paddle::platform::complex, Zhe, cuDoubleComplex) +#define FUNC_WITH_TYPES(m) \ + m(float, Ssy, float) m(double, Dsy, double) \ + m(paddle::platform::complex, Che, cuComplex) \ + m(paddle::platform::complex, Zhe, cuDoubleComplex) -#define EVDBUFFER_INSTANCE(ValueType, T, C, CastType) \ +#define EVDBUFFER_INSTANCE(T, C, CastType) \ template <> \ - inline void \ - MatrixEighFunctor::EvdBuffer( \ + inline void MatrixEighFunctor::EvdBuffer( \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, \ int *lwork) const { \ @@ -256,10 +263,9 @@ struct MatrixEighFunctor { FUNC_WITH_TYPES(EVDBUFFER_INSTANCE); -#define EVD_INSTANCE(ValueType, T, C, CastType) \ +#define EVD_INSTANCE(T, C, CastType) \ template <> \ - inline void \ - MatrixEighFunctor::Evd( \ + inline void MatrixEighFunctor::Evd( \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, T *work, \ int lwork, int *devInfo) const { \