@@ -1260,6 +1260,14 @@ struct CUBlas<phi::dtype::complex<double>> {
12601260 }
12611261};
12621262
1263+ inline void CheckGEMMNSize (int64_t N) {
1264+ constexpr int64_t kMaxN = 1073741823 ;
1265+ if (N > kMaxN ) {
1266+ PADDLE_THROW (common::errors::Unimplemented (
1267+ " cublas GEMM does not support N > %ld. Got N = %ld. " , kMaxN , N));
1268+ }
1269+ }
1270+
12631271template <>
12641272template <typename T>
12651273void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
@@ -1307,6 +1315,7 @@ void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
13071315 " GEMM_EX_64 is not supported on cuda < 12.3" ));
13081316#endif
13091317 } else {
1318+ CheckGEMMNSize (N);
13101319 CUBlas<T>::GEMM_EX (&cuda_ctx,
13111320 cuTransB,
13121321 cuTransA,
@@ -1418,6 +1427,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
14181427 " GEMM_EX_64 is not supported on cuda < 12.3" ));
14191428#endif // CUDA_VERSION >= 12030
14201429 } else {
1430+ CheckGEMMNSize (N);
14211431 CUBlas<phi::dtype::float16>::GEMM_EX (&cuda_ctx,
14221432 cuTransB,
14231433 cuTransA,
@@ -1514,6 +1524,7 @@ void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
15141524 " GEMM_EX_64 is not supported on cuda < 12.3" ));
15151525#endif
15161526 } else {
1527+ CheckGEMMNSize (N);
15171528 CUBlas<T>::GEMM_EX (&cuda_ctx,
15181529 cuTransB,
15191530 cuTransA,
@@ -1627,6 +1638,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
16271638#endif // CUDA_VERSION >= 12030
16281639 } else {
16291640#if CUDA_VERSION >= 8000
1641+ CheckGEMMNSize (N);
16301642 CUBlas<phi::dtype::float16>::GEMM_EX (&cuda_ctx,
16311643 cuTransB,
16321644 cuTransA,
@@ -1736,6 +1748,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
17361748 " cublasGemmEx_64 is not supported on cuda < 12.3" ));
17371749#endif // CUDA_VERSION >= 12030
17381750 } else {
1751+ CheckGEMMNSize (N);
17391752 dev_ctx_.TensorCoreCublasCallIfAvailable ([&](cublasHandle_t handle) {
17401753 PADDLE_ENFORCE_GPU_SUCCESS (
17411754 phi::dynload::cublasGemmEx (handle,
@@ -1836,6 +1849,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
18361849 " cublasGemmEx_64 is not supported on cuda < 12.3" ));
18371850#endif // CUDA_VERSION >= 12030
18381851 } else {
1852+ CheckGEMMNSize (N);
18391853 dev_ctx_.TensorCoreCublasCallIfAvailable ([&](cublasHandle_t handle) {
18401854 PADDLE_ENFORCE_GPU_SUCCESS (
18411855 phi::dynload::cublasGemmEx (handle,
@@ -1931,6 +1945,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
19311945#endif // CUDA_VERSION >= 12030
19321946 } else {
19331947#if CUDA_VERSION >= 8000
1948+ CheckGEMMNSize (N);
19341949 CUBlas<phi::dtype::complex <float >>::GEMM_EX (&cuda_ctx,
19351950 cuTransB,
19361951 cuTransA,
@@ -2040,6 +2055,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
20402055#endif // CUDA_VERSION >= 12030
20412056 } else {
20422057#if CUDA_VERSION >= 8000
2058+ CheckGEMMNSize (N);
20432059 CUBlas<phi::dtype::complex <double >>::GEMM_EX (&cuda_ctx,
20442060 cuTransB,
20452061 cuTransA,
@@ -2101,6 +2117,7 @@ void Blas<phi::GPUContext>::GEMM(bool transA,
21012117 cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
21022118
21032119#if CUDA_VERSION >= 8000
2120+ CheckGEMMNSize (N);
21042121 if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float >::value) {
21052122 auto &cuda_ctx = const_cast <phi::GPUContext &>(dev_ctx_);
21062123 CUBlas<T>::GEMM_EX (&cuda_ctx,
@@ -2173,6 +2190,7 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
21732190 if (use_tensor_op_math) {
21742191 algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
21752192 }
2193+ CheckGEMMNSize (N);
21762194 dev_ctx_.TensorCoreCublasCallIfAvailable ([&](cublasHandle_t handle) {
21772195 PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::cublasGemmEx (handle,
21782196 cuTransB,
@@ -2234,6 +2252,7 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
22342252 algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
22352253 }
22362254
2255+ CheckGEMMNSize (N);
22372256 dev_ctx_.TensorCoreCublasCallIfAvailable ([&](cublasHandle_t handle) {
22382257 PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::cublasGemmEx (handle,
22392258 cuTransB,
0 commit comments