Commit 337bb42

Add check with cublas call gemmex (#74306)
* add check with cublas call gemmex
* fix typo
1 parent 6827b0d commit 337bb42
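
Note: the diff below adds a small dimension guard, CheckGEMMNSize, to blas_impl.cu.h and calls it on each shown code path that ends in a 32-bit GEMM_EX / cublasGemmEx call; the GEMM_EX_64 / cublasGemmEx_64 branches (CUDA >= 12.3) are left unguarded since they take 64-bit dimensions. As a rough standalone illustration of that pattern, not the Paddle code itself (std::runtime_error stands in for PADDLE_THROW(common::errors::Unimplemented(...)), and GemmDispatch / has_gemm_ex_64 are hypothetical names):

#include <cstdint>
#include <stdexcept>
#include <string>

// Same cap as the commit: 1073741823 == 2^30 - 1.
inline void CheckGEMMNSize(int64_t N) {
  constexpr int64_t kMaxN = 1073741823;
  if (N > kMaxN) {
    throw std::runtime_error("cublas GEMM does not support N > " +
                             std::to_string(kMaxN) +
                             ". Got N = " + std::to_string(N));
  }
}

// Hypothetical caller sketching where the check sits: the *_64 path keeps
// 64-bit dimensions, while the legacy cublasGemmEx path takes int m/n/k,
// so N is validated before that call.
inline void GemmDispatch(int64_t N, bool has_gemm_ex_64) {
  if (has_gemm_ex_64) {
    // CUBlas<T>::GEMM_EX_64(..., N, ...);  // int64_t dimensions (CUDA >= 12.3)
  } else {
    CheckGEMMNSize(N);
    // CUBlas<T>::GEMM_EX(..., static_cast<int>(N), ...);  // 32-bit dimensions
  }
}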

File tree

1 file changed: 19 additions, 0 deletions

paddle/phi/kernels/funcs/blas/blas_impl.cu.h

Lines changed: 19 additions & 0 deletions
@@ -1260,6 +1260,14 @@ struct CUBlas<phi::dtype::complex<double>> {
   }
 };
 
+inline void CheckGEMMNSize(int64_t N) {
+  constexpr int64_t kMaxN = 1073741823;
+  if (N > kMaxN) {
+    PADDLE_THROW(common::errors::Unimplemented(
+        "cublas GEMM does not support N > %ld. Got N = %ld. ", kMaxN, N));
+  }
+}
+
 template <>
 template <typename T>
 void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
@@ -1307,6 +1315,7 @@ void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
         "GEMM_EX_64 is not supported on cuda < 12.3"));
 #endif
   } else {
+    CheckGEMMNSize(N);
     CUBlas<T>::GEMM_EX(&cuda_ctx,
                        cuTransB,
                        cuTransA,
@@ -1418,6 +1427,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
         "GEMM_EX_64 is not supported on cuda < 12.3"));
 #endif // CUDA_VERSION >= 12030
   } else {
+    CheckGEMMNSize(N);
     CUBlas<phi::dtype::float16>::GEMM_EX(&cuda_ctx,
                                          cuTransB,
                                          cuTransA,
@@ -1514,6 +1524,7 @@ void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
         "GEMM_EX_64 is not supported on cuda < 12.3"));
 #endif
   } else {
+    CheckGEMMNSize(N);
     CUBlas<T>::GEMM_EX(&cuda_ctx,
                        cuTransB,
                        cuTransA,
@@ -1627,6 +1638,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
 #endif // CUDA_VERSION >= 12030
   } else {
 #if CUDA_VERSION >= 8000
+    CheckGEMMNSize(N);
     CUBlas<phi::dtype::float16>::GEMM_EX(&cuda_ctx,
                                          cuTransB,
                                          cuTransA,
@@ -1736,6 +1748,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
         "cublasGemmEx_64 is not supported on cuda < 12.3"));
 #endif // CUDA_VERSION >= 12030
   } else {
+    CheckGEMMNSize(N);
     dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
       PADDLE_ENFORCE_GPU_SUCCESS(
           phi::dynload::cublasGemmEx(handle,
@@ -1836,6 +1849,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
         "cublasGemmEx_64 is not supported on cuda < 12.3"));
 #endif // CUDA_VERSION >= 12030
   } else {
+    CheckGEMMNSize(N);
     dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
       PADDLE_ENFORCE_GPU_SUCCESS(
           phi::dynload::cublasGemmEx(handle,
@@ -1931,6 +1945,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
 #endif // CUDA_VERSION >= 12030
   } else {
 #if CUDA_VERSION >= 8000
+    CheckGEMMNSize(N);
     CUBlas<phi::dtype::complex<float>>::GEMM_EX(&cuda_ctx,
                                                 cuTransB,
                                                 cuTransA,
@@ -2040,6 +2055,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
 #endif // CUDA_VERSION >= 12030
   } else {
 #if CUDA_VERSION >= 8000
+    CheckGEMMNSize(N);
     CUBlas<phi::dtype::complex<double>>::GEMM_EX(&cuda_ctx,
                                                  cuTransB,
                                                  cuTransA,
@@ -2101,6 +2117,7 @@ void Blas<phi::GPUContext>::GEMM(bool transA,
   cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
 
 #if CUDA_VERSION >= 8000
+  CheckGEMMNSize(N);
   if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
     auto &cuda_ctx = const_cast<phi::GPUContext &>(dev_ctx_);
     CUBlas<T>::GEMM_EX(&cuda_ctx,
@@ -2173,6 +2190,7 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
   if (use_tensor_op_math) {
     algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
   }
+  CheckGEMMNSize(N);
   dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
                                                           cuTransB,
@@ -2234,6 +2252,7 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
     algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
   }
 
+  CheckGEMMNSize(N);
   dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
                                                           cuTransB,
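
As a quick sanity check on the boundary behavior of that guard (hypothetical test, not part of the commit; it reuses the std::runtime_error stand-in from the sketch near the top of this page):

#include <cassert>
#include <cstdint>
#include <stdexcept>

constexpr int64_t kMaxN = 1073741823;  // 2^30 - 1, the cap used in the diff

inline void CheckGEMMNSize(int64_t N) {
  if (N > kMaxN) {
    throw std::runtime_error("cublas GEMM does not support this N");
  }
}

int main() {
  CheckGEMMNSize(kMaxN);  // exactly at the cap: accepted
  bool threw = false;
  try {
    CheckGEMMNSize(kMaxN + 1);  // one past the cap: rejected
  } catch (const std::runtime_error &) {
    threw = true;
  }
  assert(threw);
  return 0;
}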
