Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【complex op No.26】add complex support for inv #63229

Merged
merged 6 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions paddle/fluid/platform/dynload/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,14 @@ namespace dynload {
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
__macro(cublasDgetriBatched); \
__macro(cublasCgetrfBatched); \
__macro(cublasCgetriBatched); \
__macro(cublasZgetrfBatched); \
__macro(cublasZgetriBatched); \
__macro(cublasSmatinvBatched); \
__macro(cublasDmatinvBatched); \
__macro(cublasCmatinvBatched); \
__macro(cublasZmatinvBatched); \
__macro(cublasSgetrsBatched); \
__macro(cublasDgetrsBatched);

Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/backends/dynload/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,14 @@ extern void *cublas_dso_handle;
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
__macro(cublasDgetriBatched); \
__macro(cublasCgetrfBatched); \
__macro(cublasCgetriBatched); \
__macro(cublasZgetrfBatched); \
__macro(cublasZgetriBatched); \
__macro(cublasSmatinvBatched); \
__macro(cublasDmatinvBatched); \
__macro(cublasCmatinvBatched); \
__macro(cublasZmatinvBatched); \
__macro(cublasSgetrsBatched); \
__macro(cublasDgetrsBatched);

Expand Down
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/inverse_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,11 @@

#include "paddle/phi/core/kernel_registry.h"

// Register the CPU inverse_grad kernel for the real dtypes and the newly
// supported complex dtypes. (The scrape duplicated the pre-change two-line
// registration alongside the new one; only a single registration is valid.)
PD_REGISTER_KERNEL(inverse_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::InverseGradKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/inverse_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,11 @@

#include "paddle/phi/core/kernel_registry.h"

// Register the CPU inverse kernel for float/double plus complex64/complex128.
// (The scrape duplicated the pre-change registration next to the new one;
// only a single PD_REGISTER_KERNEL per kernel/backend is valid.)
PD_REGISTER_KERNEL(inverse,
                   CPU,
                   ALL_LAYOUT,
                   phi::InverseKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
114 changes: 114 additions & 0 deletions paddle/phi/kernels/funcs/blas/blas_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,63 @@ struct CUBlas<phi::dtype::complex<float>> {
ldb,
batch_size));
}

// Batched in-place LU factorization via cublasCgetrfBatched.
// On return each A[i] holds its LU factors; ipiv receives the pivot
// indices and info the per-matrix status codes.
static void GETRF_BATCH(cublasHandle_t handle,
                        int n,
                        phi::dtype::complex<float> **A,
                        int lda,
                        int *ipiv,
                        int *info,
                        int batch_size) {
  // phi::dtype::complex<float> is layout-compatible with cuFloatComplex,
  // so the pointer array can be reinterpreted directly for cuBLAS.
  auto **a_cu = reinterpret_cast<cuFloatComplex **>(A);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetrfBatched(
      handle, n, a_cu, lda, ipiv, info, batch_size));
}

// Batched inversion from precomputed LU factors via cublasCgetriBatched:
// Ainv[i] receives the inverse of the factorized A[i]; info reports
// per-matrix singularity.
static void GETRI_BATCH(cublasHandle_t handle,
                        int n,
                        const phi::dtype::complex<float> **A,
                        int lda,
                        const int *ipiv,
                        phi::dtype::complex<float> **Ainv,
                        int ldc,
                        int *info,
                        int batch_size) {
  const auto **lu_cu = reinterpret_cast<const cuFloatComplex **>(A);
  auto **inv_cu = reinterpret_cast<cuFloatComplex **>(Ainv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetriBatched(
      handle, n, lu_cu, lda, ipiv, inv_cu, ldc, info, batch_size));
}

// One-shot batched inversion via cublasCmatinvBatched (no separate
// factorization step); info reports per-matrix failures.
static void MATINV_BATCH(cublasHandle_t handle,
                         int n,
                         const phi::dtype::complex<float> **A,
                         int lda,
                         phi::dtype::complex<float> **Ainv,
                         int lda_inv,
                         int *info,
                         int batch_size) {
  const auto **src_cu = reinterpret_cast<const cuFloatComplex **>(A);
  auto **dst_cu = reinterpret_cast<cuFloatComplex **>(Ainv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCmatinvBatched(
      handle, n, src_cu, lda, dst_cu, lda_inv, info, batch_size));
}
};

template <>
Expand Down Expand Up @@ -923,6 +980,63 @@ struct CUBlas<phi::dtype::complex<double>> {
"cublasGemmEx is not supported on cuda <= 7.5"));
#endif
}

// Batched in-place LU factorization via cublasZgetrfBatched (complex128).
// On return each A[i] holds its LU factors; ipiv receives the pivot
// indices and info the per-matrix status codes.
static void GETRF_BATCH(cublasHandle_t handle,
                        int n,
                        phi::dtype::complex<double> **A,
                        int lda,
                        int *ipiv,
                        int *info,
                        int batch_size) {
  // phi::dtype::complex<double> is layout-compatible with cuDoubleComplex.
  auto **a_cu = reinterpret_cast<cuDoubleComplex **>(A);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetrfBatched(
      handle, n, a_cu, lda, ipiv, info, batch_size));
}

// Batched inversion from precomputed LU factors via cublasZgetriBatched
// (complex128): Ainv[i] receives the inverse of the factorized A[i].
static void GETRI_BATCH(cublasHandle_t handle,
                        int n,
                        const phi::dtype::complex<double> **A,
                        int lda,
                        const int *ipiv,
                        phi::dtype::complex<double> **Ainv,
                        int ldc,
                        int *info,
                        int batch_size) {
  const auto **lu_cu = reinterpret_cast<const cuDoubleComplex **>(A);
  auto **inv_cu = reinterpret_cast<cuDoubleComplex **>(Ainv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetriBatched(
      handle, n, lu_cu, lda, ipiv, inv_cu, ldc, info, batch_size));
}

// One-shot batched inversion via cublasZmatinvBatched (complex128);
// info reports per-matrix failures.
static void MATINV_BATCH(cublasHandle_t handle,
                         int n,
                         const phi::dtype::complex<double> **A,
                         int lda,
                         phi::dtype::complex<double> **Ainv,
                         int lda_inv,
                         int *info,
                         int batch_size) {
  const auto **src_cu = reinterpret_cast<const cuDoubleComplex **>(A);
  auto **dst_cu = reinterpret_cast<cuDoubleComplex **>(Ainv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZmatinvBatched(
      handle, n, src_cu, lda, dst_cu, lda_inv, info, batch_size));
}
};

template <>
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_inverse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,

// Explicit instantiations of the CPU matrix-inverse functor for every dtype
// the inverse kernels register: float/double and their complex variants.
template class MatrixInverseFunctor<CPUContext, float>;
template class MatrixInverseFunctor<CPUContext, double>;
template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<float>>;
template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<double>>;

} // namespace funcs
} // namespace phi
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_inverse.cu
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,

// Explicit instantiations of the GPU matrix-inverse functor for every dtype
// the inverse kernels register: float/double and their complex variants.
template class MatrixInverseFunctor<GPUContext, float>;
template class MatrixInverseFunctor<GPUContext, double>;
template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<float>>;
template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<double>>;

} // namespace funcs
} // namespace phi
79 changes: 65 additions & 14 deletions paddle/phi/kernels/funcs/matrix_inverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,69 @@ limitations under the License. */
namespace phi {
namespace funcs {

template <typename Context, typename T>
struct MapMatrixInverseFunctor {
void operator()(
const Context& dev_ctx, const T* a_ptr, T* a_inv_ptr, int offset, int n) {
using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;

ConstEigenMatrixMap mat(a_ptr + offset, n, n);
EigenMatrixMap mat_inv(a_inv_ptr + offset, n, n);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(mat);

const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
PADDLE_ENFORCE_GT(min_abs_pivot,
static_cast<T>(0),
errors::InvalidArgument("Input is not invertible."));
mat_inv.noalias() = lu.inverse();
}
};

// Specialization for complex dtypes.
//
// Putting phi::dtype::complex directly into an Eigen matrix does not produce
// the right result, so the data is first converted to std::complex, inverted
// there, and converted back.
template <typename Context, typename T>
struct MapMatrixInverseFunctor<Context, phi::dtype::complex<T>> {
  void operator()(const Context& dev_ctx,
                  const phi::dtype::complex<T>* a_ptr,
                  phi::dtype::complex<T>* a_inv_ptr,
                  int offset,
                  int n) {
    using Matrix = Eigen::Matrix<std::complex<T>,
                                 Eigen::Dynamic,
                                 Eigen::Dynamic,
                                 Eigen::RowMajor>;

    // Use an Eigen matrix (RAII) for the converted copy instead of raw
    // new[]/delete[]: the previous buffers leaked when PADDLE_ENFORCE threw
    // below. RowMajor storage makes data() match the source linear layout.
    Matrix mat(n, n);
    for (int i = 0; i < n * n; ++i) {
      mat.data()[i] = static_cast<std::complex<T>>(a_ptr[offset + i]);
    }

    Eigen::PartialPivLU<Matrix> lu;
    lu.compute(mat);

    // cwiseAbs() of the complex diagonal yields real magnitudes (>= 0), so
    // compare against a real zero with GT — consistent with the real-typed
    // primary template and equivalent to the old NE-against-complex-zero.
    const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
    PADDLE_ENFORCE_GT(min_abs_pivot,
                      static_cast<T>(0),
                      errors::InvalidArgument("Input is not invertible."));

    Matrix mat_inv = lu.inverse();
    for (int i = 0; i < n * n; ++i) {
      a_inv_ptr[offset + i] =
          static_cast<phi::dtype::complex<T>>(mat_inv.data()[i]);
    }
  }
};

template <typename Context, typename T>
void ComputeInverseEigen(const Context& dev_ctx,
const DenseTensor& a,
DenseTensor* a_inv) {
using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
const auto& mat_dims = a.dims();
const int rank = mat_dims.size();
int n = mat_dims[rank - 1];
Expand All @@ -41,17 +96,13 @@ void ComputeInverseEigen(const Context& dev_ctx,
const T* a_ptr = a.data<T>();
T* a_inv_ptr = dev_ctx.template Alloc<T>(a_inv);

// Putting phi::dtype::complex into eigen::matrix has a problem,
// it's not going to get the right result,
// so we're going to convert it to std::complex and
// then we're going to put it into eigen::matrix.
for (int i = 0; i < batch_size; ++i) {
ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(mat);

const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
PADDLE_ENFORCE_GT(min_abs_pivot,
static_cast<T>(0),
errors::InvalidArgument("Input is not invertible."));
mat_inv.noalias() = lu.inverse();
MapMatrixInverseFunctor<Context, T> functor;
functor(dev_ctx, a_ptr, a_inv_ptr, i * n * n, n);
}
}

Expand Down
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/inverse_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/inverse_grad_kernel_impl.h"

// Register the GPU inverse_grad kernel for float/double plus
// complex64/complex128. (The scrape duplicated the pre-change registration
// next to the new one; only a single registration is valid.)
PD_REGISTER_KERNEL(inverse_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::InverseGradKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/inverse_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/inverse_kernel_impl.h"

// Register the GPU inverse kernel for float/double plus complex64/complex128.
// (The scrape duplicated the pre-change registration next to the new one;
// only a single registration is valid.)
PD_REGISTER_KERNEL(inverse,
                   GPU,
                   ALL_LAYOUT,
                   phi::InverseKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
37 changes: 29 additions & 8 deletions paddle/phi/kernels/impl/inverse_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h"

Expand All @@ -37,15 +38,35 @@ void InverseGradKernel(const Context& dev_ctx,
tmp_out.Resize(out.dims());
dev_ctx.template Alloc<T>(&tmp_out);

auto mat_dim_a0 =
phi::funcs::CreateMatrixDescriptor(out_grad.dims(), 0, false);
auto mat_dim_b0 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
blas.MatMul(out_grad, mat_dim_a0, out, mat_dim_b0, T(1), &tmp_out, T(0));
if (IsComplexType(out.dtype())) {
DenseTensor out_conj;
out_conj.Resize(out.dims());
dev_ctx.template Alloc<T>(&out_conj);

auto mat_dim_a1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
auto mat_dim_b1 =
phi::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false);
blas.MatMul(out, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), in_grad, T(0));
phi::ConjKernel<T, Context>(dev_ctx, out, &out_conj);

auto mat_dim_a0 =
phi::funcs::CreateMatrixDescriptor(out_grad.dims(), 0, false);
auto mat_dim_b0 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
blas.MatMul(
out_grad, mat_dim_a0, out_conj, mat_dim_b0, T(1), &tmp_out, T(0));

auto mat_dim_a1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
auto mat_dim_b1 =
phi::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false);
blas.MatMul(
out_conj, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), in_grad, T(0));
} else {
auto mat_dim_a0 =
phi::funcs::CreateMatrixDescriptor(out_grad.dims(), 0, false);
auto mat_dim_b0 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
blas.MatMul(out_grad, mat_dim_a0, out, mat_dim_b0, T(1), &tmp_out, T(0));

auto mat_dim_a1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
auto mat_dim_b1 =
phi::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false);
blas.MatMul(out, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), in_grad, T(0));
}
}
}

Expand Down
Loading