Skip to content

Commit c466f94

Browse files
authored
[Fix Big Tensor] Fix backward accuracy diff of paddle.linalg.slogdet (#74537)
1 parent d4f1984 commit c466f94

File tree

1 file changed

+48
-2
lines changed

1 file changed

+48
-2
lines changed

paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,54 @@ void SlogDeterminantGradKernel(const Context& dev_ctx,
8282
inverse_A.Resize(x.dims());
8383
dev_ctx.template Alloc<T>(&inverse_A);
8484

85-
phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
86-
mat_inv(dev_ctx, x, &inverse_A);
85+
const auto& mat_dims = x.dims();
86+
const int rank = mat_dims.size();
87+
int n = mat_dims[rank - 1];
88+
int64_t total_batch_size = rank > 2 ? x.numel() / (n * n) : 1;
89+
90+
// Divide the batch into chunks because of cublasMatInv limitation
91+
if (total_batch_size <= 65536) {
92+
phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
93+
mat_inv(dev_ctx, x, &inverse_A);
94+
} else {
95+
constexpr int64_t max_batch_size = 65536;
96+
int64_t processed = 0;
97+
98+
VLOG(3) << "Large batch size detected (" << total_batch_size
99+
<< "), processing in chunks of " << max_batch_size;
100+
101+
while (processed < total_batch_size) {
102+
int64_t current_batch =
103+
std::min(max_batch_size, total_batch_size - processed);
104+
105+
// Extract current batch data
106+
DenseTensor x_batch;
107+
x_batch.ShareDataWith(x);
108+
x_batch.Resize({total_batch_size, n, n});
109+
x_batch = x_batch.Slice(processed, processed + current_batch);
110+
x_batch.Resize({current_batch, n, n});
111+
112+
DenseTensor inverse_batch;
113+
inverse_batch.Resize({current_batch, n, n});
114+
dev_ctx.template Alloc<T>(&inverse_batch);
115+
116+
// Compute the inverse matrix for the current batch
117+
phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
118+
mat_inv(dev_ctx, x_batch, &inverse_batch);
119+
120+
// Copy the result to the output tensor
121+
DenseTensor output_slice;
122+
output_slice.ShareDataWith(inverse_A);
123+
output_slice.Resize({total_batch_size, n, n});
124+
output_slice = output_slice.Slice(processed, processed + current_batch);
125+
output_slice.Resize({current_batch, n, n});
126+
127+
phi::Copy(
128+
dev_ctx, inverse_batch, dev_ctx.GetPlace(), false, &output_slice);
129+
130+
processed += current_batch;
131+
}
132+
}
87133

88134
VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
89135

0 commit comments

Comments
 (0)