@@ -82,8 +82,54 @@ void SlogDeterminantGradKernel(const Context& dev_ctx,
8282 inverse_A.Resize (x.dims ());
8383 dev_ctx.template Alloc <T>(&inverse_A);
8484
85- phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
86- mat_inv (dev_ctx, x, &inverse_A);
85+ const auto & mat_dims = x.dims ();
86+ const int rank = mat_dims.size ();
87+ int n = mat_dims[rank - 1 ];
88+ int64_t total_batch_size = rank > 2 ? x.numel () / (n * n) : 1 ;
89+
90+ // cublasMatInv limits the batch size, so invert in chunks of at most 65536
91+ if (total_batch_size <= 65536 ) {
92+ phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
93+ mat_inv (dev_ctx, x, &inverse_A);
94+ } else {
95+ constexpr int64_t max_batch_size = 65536 ;
96+ int64_t processed = 0 ;
97+
98+ VLOG (3 ) << " Large batch size detected (" << total_batch_size
99+ << " ), processing in chunks of " << max_batch_size;
100+
101+ while (processed < total_batch_size) {
102+ int64_t current_batch =
103+ std::min (max_batch_size, total_batch_size - processed);
104+
105+ // Extract current batch data
106+ DenseTensor x_batch;
107+ x_batch.ShareDataWith (x);
108+ x_batch.Resize ({total_batch_size, n, n});
109+ x_batch = x_batch.Slice (processed, processed + current_batch);
110+ x_batch.Resize ({current_batch, n, n});
111+
112+ DenseTensor inverse_batch;
113+ inverse_batch.Resize ({current_batch, n, n});
114+ dev_ctx.template Alloc <T>(&inverse_batch);
115+
116+ // Compute the inverse matrix for the current batch
117+ phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
118+ mat_inv (dev_ctx, x_batch, &inverse_batch);
119+
120+ // Copy this chunk's inverse into the matching slice of inverse_A
121+ DenseTensor output_slice;
122+ output_slice.ShareDataWith (inverse_A);
123+ output_slice.Resize ({total_batch_size, n, n});
124+ output_slice = output_slice.Slice (processed, processed + current_batch);
125+ output_slice.Resize ({current_batch, n, n});
126+
127+ phi::Copy (
128+ dev_ctx, inverse_batch, dev_ctx.GetPlace (), false , &output_slice);
129+
130+ processed += current_batch;
131+ }
132+ }
87133
88134 VLOG (3 ) << " inverse(A) dims: " << inverse_A.dims ();
89135
0 commit comments