Skip to content

Commit 4bf1df1

Browse files
committed
fix__linalg_solve
1 parent 8ad991d commit 4bf1df1

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

paddle/phi/kernels/funcs/matrix_solve.cu

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ limitations under the License. */
18 18   #include "paddle/phi/kernels/funcs/blas/blas.h"
19 19   #include "paddle/phi/kernels/funcs/math_function.h"
20 20   #include "paddle/phi/kernels/funcs/scatter.cu.h"
   21 + #include "paddle/phi/backends/gpu/cuda/cudnn_workspace_helper.h"
21 22
22 23   namespace phi {
23 24   namespace funcs {
@@ -161,11 +162,14 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& dev_ctx,
161 162   int n = a_dims[a_rank - 1];
162 163   int lda = n;
163 164   int64_t batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;
    165 + CUDNN_ENFORCE_TENSOR_SIZE_SUPPORTED(a);
    166 +
164 167
165 168   const auto& b_dims = b.dims();
166 169   const int b_rank = b_dims.size();
167 170   int nrhs = b_dims[b_rank - 1];
168 171   int ldb = n;
    172 + CUDNN_ENFORCE_TENSOR_SIZE_SUPPORTED(b);
169 173
170 174   // 1. Copy input A to a temporary tensor tmp_a for LU factorization.
171 175   DenseTensor tmp_a(a.dtype());

0 commit comments

Comments
 (0)