Skip to content

Commit

Permalink
【complex op No.26】add complex support for inv (#63229)
Browse files Browse the repository at this point in the history
* add complex dtype for inv

* fix

* update

* fix

* fix
  • Loading branch information
zbt78 authored Jun 3, 2024
1 parent 86d347b commit 252b746
Show file tree
Hide file tree
Showing 13 changed files with 289 additions and 32 deletions.
6 changes: 6 additions & 0 deletions paddle/fluid/platform/dynload/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,14 @@ namespace dynload {
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
__macro(cublasDgetriBatched); \
__macro(cublasCgetrfBatched); \
__macro(cublasCgetriBatched); \
__macro(cublasZgetrfBatched); \
__macro(cublasZgetriBatched); \
__macro(cublasSmatinvBatched); \
__macro(cublasDmatinvBatched); \
__macro(cublasCmatinvBatched); \
__macro(cublasZmatinvBatched); \
__macro(cublasSgetrsBatched); \
__macro(cublasDgetrsBatched);

Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/backends/dynload/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,14 @@ extern void *cublas_dso_handle;
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
__macro(cublasDgetriBatched); \
__macro(cublasCgetrfBatched); \
__macro(cublasCgetriBatched); \
__macro(cublasZgetrfBatched); \
__macro(cublasZgetriBatched); \
__macro(cublasSmatinvBatched); \
__macro(cublasDmatinvBatched); \
__macro(cublasCmatinvBatched); \
__macro(cublasZmatinvBatched); \
__macro(cublasSgetrsBatched); \
__macro(cublasDgetrsBatched);

Expand Down
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/inverse_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,11 @@

#include "paddle/phi/core/kernel_registry.h"

// Registers the CPU inverse_grad kernel for real and complex floating-point
// dtypes. The stale float/double-only registration is removed: registering
// the same (kernel, backend, layout) twice is a duplicate-registration error.
PD_REGISTER_KERNEL(inverse_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::InverseGradKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/inverse_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,11 @@

#include "paddle/phi/core/kernel_registry.h"

// Registers the CPU inverse kernel for real and complex floating-point
// dtypes. The stale float/double-only registration is removed: registering
// the same (kernel, backend, layout) twice is a duplicate-registration error.
PD_REGISTER_KERNEL(inverse,
                   CPU,
                   ALL_LAYOUT,
                   phi::InverseKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
114 changes: 114 additions & 0 deletions paddle/phi/kernels/funcs/blas/blas_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,63 @@ struct CUBlas<phi::dtype::complex<float>> {
ldb,
batch_size));
}

// Batched LU factorization for single-precision complex matrices.
// Thin wrapper over cublasCgetrfBatched: pivot indices are written to
// `ipiv`, one status code per matrix to `info`.
static void GETRF_BATCH(cublasHandle_t handle,
                        int n,
                        phi::dtype::complex<float> **A,
                        int lda,
                        int *ipiv,
                        int *info,
                        int batch_size) {
  // phi's complex type is layout-compatible with cuFloatComplex.
  auto a_cu = reinterpret_cast<cuFloatComplex **>(A);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetrfBatched(
      handle, n, a_cu, lda, ipiv, info, batch_size));
}

// Batched inversion of already LU-factorized single-precision complex
// matrices via cublasCgetriBatched. `ipiv` holds the pivots produced by
// GETRF_BATCH; inverses land in `Ainv`, statuses in `info`.
static void GETRI_BATCH(cublasHandle_t handle,
                        int n,
                        const phi::dtype::complex<float> **A,
                        int lda,
                        const int *ipiv,
                        phi::dtype::complex<float> **Ainv,
                        int ldc,
                        int *info,
                        int batch_size) {
  // phi's complex type is layout-compatible with cuFloatComplex.
  auto a_cu = reinterpret_cast<const cuFloatComplex **>(A);
  auto a_inv_cu = reinterpret_cast<cuFloatComplex **>(Ainv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetriBatched(
      handle, n, a_cu, lda, ipiv, a_inv_cu, ldc, info, batch_size));
}

// Batched direct inversion of single-precision complex matrices via
// cublasCmatinvBatched.
// NOTE(review): per the cuBLAS docs, matinvBatched only supports n <= 32;
// presumably callers fall back to GETRF/GETRI for larger n — confirm at
// the call site.
static void MATINV_BATCH(cublasHandle_t handle,
                         int n,
                         const phi::dtype::complex<float> **A,
                         int lda,
                         phi::dtype::complex<float> **Ainv,
                         int lda_inv,
                         int *info,
                         int batch_size) {
  // phi's complex type is layout-compatible with cuFloatComplex.
  auto a_cu = reinterpret_cast<const cuFloatComplex **>(A);
  auto a_inv_cu = reinterpret_cast<cuFloatComplex **>(Ainv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCmatinvBatched(
      handle, n, a_cu, lda, a_inv_cu, lda_inv, info, batch_size));
}
};

template <>
Expand Down Expand Up @@ -923,6 +980,63 @@ struct CUBlas<phi::dtype::complex<double>> {
"cublasGemmEx is not supported on cuda <= 7.5"));
#endif
}

// Batched LU factorization for double-precision complex matrices.
// Thin wrapper over cublasZgetrfBatched: pivot indices are written to
// `ipiv`, one status code per matrix to `info`.
static void GETRF_BATCH(cublasHandle_t handle,
                        int n,
                        phi::dtype::complex<double> **A,
                        int lda,
                        int *ipiv,
                        int *info,
                        int batch_size) {
  // phi's complex type is layout-compatible with cuDoubleComplex.
  auto a_cu = reinterpret_cast<cuDoubleComplex **>(A);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetrfBatched(
      handle, n, a_cu, lda, ipiv, info, batch_size));
}

// Batched inversion of already LU-factorized double-precision complex
// matrices via cublasZgetriBatched. `ipiv` holds the pivots produced by
// GETRF_BATCH; inverses land in `Ainv`, statuses in `info`.
static void GETRI_BATCH(cublasHandle_t handle,
                        int n,
                        const phi::dtype::complex<double> **A,
                        int lda,
                        const int *ipiv,
                        phi::dtype::complex<double> **Ainv,
                        int ldc,
                        int *info,
                        int batch_size) {
  // phi's complex type is layout-compatible with cuDoubleComplex.
  auto a_cu = reinterpret_cast<const cuDoubleComplex **>(A);
  auto a_inv_cu = reinterpret_cast<cuDoubleComplex **>(Ainv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetriBatched(
      handle, n, a_cu, lda, ipiv, a_inv_cu, ldc, info, batch_size));
}

// Batched direct inversion of double-precision complex matrices via
// cublasZmatinvBatched.
// NOTE(review): per the cuBLAS docs, matinvBatched only supports n <= 32;
// presumably callers fall back to GETRF/GETRI for larger n — confirm at
// the call site.
static void MATINV_BATCH(cublasHandle_t handle,
                         int n,
                         const phi::dtype::complex<double> **A,
                         int lda,
                         phi::dtype::complex<double> **Ainv,
                         int lda_inv,
                         int *info,
                         int batch_size) {
  // phi's complex type is layout-compatible with cuDoubleComplex.
  auto a_cu = reinterpret_cast<const cuDoubleComplex **>(A);
  auto a_inv_cu = reinterpret_cast<cuDoubleComplex **>(Ainv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZmatinvBatched(
      handle, n, a_cu, lda, a_inv_cu, lda_inv, info, batch_size));
}
};

template <>
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_inverse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,

// Explicit instantiations of the CPU matrix-inverse functor for every dtype
// the inverse kernels register (real and complex, single and double).
template class MatrixInverseFunctor<CPUContext, float>;
template class MatrixInverseFunctor<CPUContext, double>;
template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<float>>;
template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<double>>;

} // namespace funcs
} // namespace phi
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_inverse.cu
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,

// Explicit instantiations of the GPU matrix-inverse functor for every dtype
// the inverse kernels register (real and complex, single and double).
template class MatrixInverseFunctor<GPUContext, float>;
template class MatrixInverseFunctor<GPUContext, double>;
template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<float>>;
template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<double>>;

} // namespace funcs
} // namespace phi
79 changes: 65 additions & 14 deletions paddle/phi/kernels/funcs/matrix_inverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,69 @@ limitations under the License. */
namespace phi {
namespace funcs {

template <typename Context, typename T>
struct MapMatrixInverseFunctor {
void operator()(
const Context& dev_ctx, const T* a_ptr, T* a_inv_ptr, int offset, int n) {
using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;

ConstEigenMatrixMap mat(a_ptr + offset, n, n);
EigenMatrixMap mat_inv(a_inv_ptr + offset, n, n);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(mat);

const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
PADDLE_ENFORCE_GT(min_abs_pivot,
static_cast<T>(0),
errors::InvalidArgument("Input is not invertible."));
mat_inv.noalias() = lu.inverse();
}
};

// Specialization for complex inputs.
//
// phi::dtype::complex<T> cannot be handed to Eigen directly (it does not
// produce correct results there), so the data is converted to
// std::complex<T>, inverted, and converted back.
//
// Fixes vs. the previous version:
//  * The scratch buffers were raw new[]/delete[] pairs, which leaked when
//    the singularity check threw. Eigen-owned matrices are RAII-safe.
//  * The singularity check compared the real-valued pivot magnitude against
//    static_cast<std::complex<T>>(0) with ENFORCE_NE; cwiseAbs() yields a
//    plain T >= 0, so the real-typed GT check (matching the primary
//    template) is equivalent and type-correct.
template <typename Context, typename T>
struct MapMatrixInverseFunctor<Context, phi::dtype::complex<T>> {
  void operator()(const Context& dev_ctx,
                  const phi::dtype::complex<T>* a_ptr,
                  phi::dtype::complex<T>* a_inv_ptr,
                  int offset,
                  int n) {
    using Matrix = Eigen::Matrix<std::complex<T>,
                                 Eigen::Dynamic,
                                 Eigen::Dynamic,
                                 Eigen::RowMajor>;
    // Owning, exception-safe storage for the converted input; RowMajor
    // layout makes data() a linear row-major view matching a_ptr's layout.
    Matrix mat(n, n);
    for (int i = 0; i < n * n; ++i) {
      mat.data()[i] = static_cast<std::complex<T>>(a_ptr[offset + i]);
    }

    Eigen::PartialPivLU<Matrix> lu;
    lu.compute(mat);

    // A zero pivot on the LU diagonal means the input is not invertible.
    const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
    PADDLE_ENFORCE_GT(min_abs_pivot,
                      static_cast<T>(0),
                      errors::InvalidArgument("Input is not invertible."));

    const Matrix mat_inv = lu.inverse();
    for (int i = 0; i < n * n; ++i) {
      a_inv_ptr[offset + i] =
          static_cast<phi::dtype::complex<T>>(mat_inv.data()[i]);
    }
  }
};

template <typename Context, typename T>
void ComputeInverseEigen(const Context& dev_ctx,
const DenseTensor& a,
DenseTensor* a_inv) {
using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
const auto& mat_dims = a.dims();
const int rank = mat_dims.size();
int n = mat_dims[rank - 1];
Expand All @@ -41,17 +96,13 @@ void ComputeInverseEigen(const Context& dev_ctx,
const T* a_ptr = a.data<T>();
T* a_inv_ptr = dev_ctx.template Alloc<T>(a_inv);

// Putting phi::dtype::complex into eigen::matrix has a problem,
// it's not going to get the right result,
// so we're going to convert it to std::complex and
// then we're going to put it into eigen::matrix.
for (int i = 0; i < batch_size; ++i) {
ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(mat);

const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
PADDLE_ENFORCE_GT(min_abs_pivot,
static_cast<T>(0),
errors::InvalidArgument("Input is not invertible."));
mat_inv.noalias() = lu.inverse();
MapMatrixInverseFunctor<Context, T> functor;
functor(dev_ctx, a_ptr, a_inv_ptr, i * n * n, n);
}
}

Expand Down
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/inverse_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/inverse_grad_kernel_impl.h"

// Registers the GPU inverse_grad kernel for real and complex floating-point
// dtypes. The stale float/double-only registration is removed: registering
// the same (kernel, backend, layout) twice is a duplicate-registration error.
PD_REGISTER_KERNEL(inverse_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::InverseGradKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/inverse_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/inverse_kernel_impl.h"

// Registers the GPU inverse kernel for real and complex floating-point
// dtypes. The stale float/double-only registration is removed: registering
// the same (kernel, backend, layout) twice is a duplicate-registration error.
PD_REGISTER_KERNEL(inverse,
                   GPU,
                   ALL_LAYOUT,
                   phi::InverseKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
37 changes: 29 additions & 8 deletions paddle/phi/kernels/impl/inverse_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h"

Expand All @@ -37,15 +38,35 @@ void InverseGradKernel(const Context& dev_ctx,
tmp_out.Resize(out.dims());
dev_ctx.template Alloc<T>(&tmp_out);

auto mat_dim_a0 =
phi::funcs::CreateMatrixDescriptor(out_grad.dims(), 0, false);
auto mat_dim_b0 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
blas.MatMul(out_grad, mat_dim_a0, out, mat_dim_b0, T(1), &tmp_out, T(0));
if (IsComplexType(out.dtype())) {
DenseTensor out_conj;
out_conj.Resize(out.dims());
dev_ctx.template Alloc<T>(&out_conj);

auto mat_dim_a1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
auto mat_dim_b1 =
phi::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false);
blas.MatMul(out, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), in_grad, T(0));
phi::ConjKernel<T, Context>(dev_ctx, out, &out_conj);

auto mat_dim_a0 =
phi::funcs::CreateMatrixDescriptor(out_grad.dims(), 0, false);
auto mat_dim_b0 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
blas.MatMul(
out_grad, mat_dim_a0, out_conj, mat_dim_b0, T(1), &tmp_out, T(0));

auto mat_dim_a1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
auto mat_dim_b1 =
phi::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false);
blas.MatMul(
out_conj, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), in_grad, T(0));
} else {
auto mat_dim_a0 =
phi::funcs::CreateMatrixDescriptor(out_grad.dims(), 0, false);
auto mat_dim_b0 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
blas.MatMul(out_grad, mat_dim_a0, out, mat_dim_b0, T(1), &tmp_out, T(0));

auto mat_dim_a1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
auto mat_dim_b1 =
phi::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false);
blas.MatMul(out, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), in_grad, T(0));
}
}
}

Expand Down
Loading

0 comments on commit 252b746

Please sign in to comment.