Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions paddle/phi/backends/dynload/cusolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ extern void *cusolver_dso_handle;
__macro(cusolverDnSetStream); \
__macro(cusolverDnSpotrf_bufferSize); \
__macro(cusolverDnDpotrf_bufferSize); \
__macro(cusolverDnXpotrf_bufferSize); \
__macro(cusolverDnSpotrf); \
__macro(cusolverDnDpotrf); \
__macro(cusolverDnXpotrf); \
__macro(cusolverDnSpotrs); \
__macro(cusolverDnDpotrs); \
__macro(cusolverDnCpotrs); \
Expand Down Expand Up @@ -120,6 +122,8 @@ CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
#if CUDA_VERSION >= 9020
#define CUSOLVER_ROUTINE_EACH_R2(__macro) \
__macro(cusolverDnCreateSyevjInfo); \
__macro(cusolverDnCreateParams); \
__macro(cusolverDnDestroyParams); \
__macro(cusolverDnSsyevj_bufferSize); \
__macro(cusolverDnDsyevj_bufferSize); \
__macro(cusolverDnCheevj_bufferSize); \
Expand Down
136 changes: 104 additions & 32 deletions paddle/phi/kernels/gpu/cholesky_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,60 @@ struct MatrixBandPartFunctor {

FUNC_WITH_TYPES(POTRF_INSTANCE);

#if CUDA_VERSION >= 11040
// Potrf64<T>: in-place Cholesky factorization (POTRF) of one n x n matrix A
// through the cuSOLVER 64-bit "X" generic API, which takes int64_t dimensions
// and explicit cudaDataType_t descriptors. cusolverDnXpotrf appeared in
// CUDA 11.4, hence the version guard.
//
// Flow: create a cusolverDnParams_t, query device- and host-side workspace
// sizes with cusolverDnXpotrf_bufferSize, allocate both workspaces, run
// cusolverDnXpotrf, then destroy the params handle.
//
// The macro parameter C is unused here: unlike the legacy per-type entry
// points (e.g. cusolverDnSpotrf/Dpotrf), the X API dispatches on the
// cudaDataType_t argument, selected below from T (float -> CUDA_R_32F,
// otherwise CUDA_R_64F).
//
// NOTE(review): `info` is forwarded straight to cusolverDnXpotrf, which per
// the cuSOLVER API expects a device pointer — confirm all callers pass
// device memory.
// NOTE(review): if a PADDLE_ENFORCE_GPU_SUCCESS between CreateParams and
// DestroyParams throws, `params` leaks; an RAII guard would fix this.
// NOTE(review): the host workspace is allocated on CPUPlace yet tagged with
// the device stream — presumably memory_utils::Alloc tolerates this for CPU
// allocations; verify.
#define POTRF64_INSTANCE(T, C) \
void Potrf64(const GPUContext& dev_ctx, \
cublasFillMode_t uplo, \
int64_t n, \
T* A, \
int64_t lda, \
int* info) { \
auto handle = dev_ctx.cusolver_dn_handle(); \
cusolverDnParams_t params; \
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnCreateParams(&params)); \
size_t workspace_device_size = 0; \
size_t workspace_host_size = 0; \
cudaDataType_t data_type = \
std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_64F; \
PADDLE_ENFORCE_GPU_SUCCESS( \
dynload::cusolverDnXpotrf_bufferSize(handle, \
params, \
uplo, \
n, \
data_type, \
A, \
lda, \
data_type, \
&workspace_device_size, \
&workspace_host_size)); \
auto workspace_device = phi::memory_utils::Alloc( \
dev_ctx.GetPlace(), \
workspace_device_size, \
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream()))); \
auto workspace_host = phi::memory_utils::Alloc( \
phi::CPUPlace(), \
workspace_host_size, \
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream()))); \
PADDLE_ENFORCE_GPU_SUCCESS( \
dynload::cusolverDnXpotrf(handle, \
params, \
uplo, \
n, \
data_type, \
A, \
lda, \
data_type, \
workspace_device->ptr(), \
workspace_device_size, \
workspace_host->ptr(), \
workspace_host_size, \
info)); \
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDestroyParams(params)); \
}

// Instantiate Potrf64 for each element type enumerated by FUNC_WITH_TYPES.
// NOTE(review): data_type above only distinguishes float from double, so this
// assumes FUNC_WITH_TYPES expands over real types only — confirm against its
// definition.
FUNC_WITH_TYPES(POTRF64_INSTANCE);
#endif

#if CUDA_VERSION >= 9020 && !defined(_WIN32)
#define POTRF_BATCH_INSTANCE(T, C) \
void PotrfBatched(const GPUContext& dev_ctx, \
Expand All @@ -114,7 +168,7 @@ void CholeskyKernel(const Context& dev_ctx,
const DenseTensor& x,
bool upper,
DenseTensor* out) {
if (out->numel() == 0) {
if (x.numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
Expand All @@ -125,7 +179,7 @@ void CholeskyKernel(const Context& dev_ctx,
batch_count *= dims[i];
}
int m = dims[dims.size() - 1];
int tensor_size = batch_count * m * m;
int64_t tensor_size = batch_count * static_cast<int64_t>(m) * m;

const auto* x_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
Expand All @@ -135,22 +189,16 @@ void CholeskyKernel(const Context& dev_ctx,
upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
// portf is inplace, thus copy the triangular part of the input matrices to
// the output and set the other triangular part to 0 firstly

phi::funcs::ForRange<GPUContext> for_range(dev_ctx, tensor_size);
// Pre-processing
if (upper) {
MatrixBandPartFunctor<T> matrix_band_part_functor(m,
m,
/* num_lower_diags */ 0,
/* num_upper_diags */ m,
x_data,
out_data);
MatrixBandPartFunctor<T> matrix_band_part_functor(
m, m, 0, -1, x_data, out_data);
for_range(matrix_band_part_functor);
} else {
MatrixBandPartFunctor<T> matrix_band_part_functor(m,
m,
/* num_lower_diags */ m,
/* num_upper_diags */ 0,
x_data,
out_data);
MatrixBandPartFunctor<T> matrix_band_part_functor(
m, m, -1, 0, x_data, out_data);
for_range(matrix_band_part_functor);
}

Expand All @@ -164,7 +212,7 @@ void CholeskyKernel(const Context& dev_ctx,
if (batch_count > 1) {
std::vector<T*> output_ptrs;
for (int i = 0; i < batch_count; i++) {
output_ptrs.emplace_back(out_data + i * m * m);
output_ptrs.emplace_back(out_data + static_cast<int64_t>(i) * m * m);
}
thrust::device_vector<T*> dev_output_ptrs(output_ptrs.begin(),
output_ptrs.end());
Expand All @@ -178,28 +226,28 @@ void CholeskyKernel(const Context& dev_ctx,
// TODO(guosheng): There seems to a bug in cusolver potrfBatched and need
// to clear the upper triangle of the output. Remove this workaround once
// the bug is fixed.

if (!upper) {
MatrixBandPartFunctor<T> matrix_band_part_functor(m,
m,
/* num_lower_diags */ m,
/* num_upper_diags */ 0,
out_data,
out_data);
MatrixBandPartFunctor<T> matrix_band_part_functor(
m, m, -1, 0, out_data, out_data);
for_range(matrix_band_part_functor);
}
} else {
#endif
for (int i = 0; i < batch_count; i++) {
Potrf(dev_ctx, uplo, m, out_data + i * m * m, m, info_ptr + i);
int64_t offset = static_cast<int64_t>(i) * m * m;
#if CUDA_VERSION >= 11040
Potrf64(dev_ctx, uplo, m, out_data + offset, m, info_ptr + i);
#else
Potrf(dev_ctx, uplo, m, out_data + offset, m, info_ptr + i);
#endif
}

#if CUDA_VERSION >= 9020 && !defined(_WIN32)
}
#endif
// check the info
std::vector<int> error_info; // only for checking positive matrix
std::vector<int> error_info;
error_info.resize(batch_count);

memory_utils::Copy(CPUPlace(),
error_info.data(),
dev_ctx.GetPlace(),
Expand All @@ -208,13 +256,37 @@ void CholeskyKernel(const Context& dev_ctx,
dev_ctx.stream());

for (int i = 0; i < batch_count; ++i) {
PADDLE_ENFORCE_EQ(error_info[i],
0,
errors::PreconditionNotMet(
"For batch [%d]: U(%d, %d) is zero, singular U.",
i,
error_info[i],
error_info[i]));
const int info = error_info[i];
if (info == 0) {
continue;
}
if (info < 0) {
PADDLE_ENFORCE_EQ(
info,
0,
errors::InvalidArgument("Cholesky kernel failed for batch %d: "
"The %d-th argument was invalid, please "
"check the kernel implementation.",
i,
-info));
}
PADDLE_ENFORCE_EQ(
info,
0,
errors::PreconditionNotMet(
"Cholesky decomposition failed for batch %d: "
"The leading minor of order %d is not positive definite.",
i,
info));
}

// Post-processing to clear the other triangle
if (upper) {
MatrixBandPartFunctor<T> band_part_post(m, m, 0, -1, out_data, out_data);
for_range(band_part_post);
} else {
MatrixBandPartFunctor<T> band_part_post(m, m, -1, 0, out_data, out_data);
for_range(band_part_post);
}
}

Expand Down
7 changes: 4 additions & 3 deletions paddle/phi/kernels/impl/cholesky_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ void CholeskyGradKernel(const Context& dev_ctx,
batch_count *= dims[i];
}
auto m = dims[dims.size() - 1];
int tensor_size = batch_count * m * m;
int64_t tensor_size = static_cast<int64_t>(batch_count) * m * m;

std::vector<int> axis(dims.size() - 2);
std::iota(axis.begin(), axis.end(), 0);
Expand Down Expand Up @@ -304,16 +304,17 @@ void CholeskyGradKernel(const Context& dev_ctx,
for_range(eye_functor);
// TODO(guosheng): use trsmBatched for GPU
for (int i = 0; i < batch_count; i++) {
int64_t offset = static_cast<int64_t>(i) * m * m;
blas.TRSM(/*side*/ CblasLeft,
/*uplo*/ CblasLower,
/*trans*/ CblasNoTrans,
/*diag*/ CblasNonUnit,
/*m*/ m,
/*n*/ m,
/*alpha*/ T(1),
l_data + i * m * m,
l_data + offset,
/*lda*/ m,
identity_data + i * m * m,
identity_data + offset,
/*ldb*/ m);
}
DenseTensor& l_inverse = identity;
Expand Down