Skip to content

Commit

Permalink
Fix overflow error in GPU batched linear algebra kernels.
Browse files Browse the repository at this point in the history
As reported in #24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS.

PiperOrigin-RevId: 695490133
  • Loading branch information
dfm authored and Google-ML-Automation committed Nov 11, 2024
1 parent 6892e62 commit 6d19000
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
11 changes: 7 additions & 4 deletions jaxlib/gpu/make_batch_pointers.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "jaxlib/gpu/make_batch_pointers.h"

#include <algorithm>
#include <cstddef>

#include "jaxlib/gpu/vendor.h"

Expand All @@ -24,17 +25,19 @@ namespace JAX_GPU_NAMESPACE {

namespace {
// Fills buffer_out[i] with the device pointer buffer_in + i * batch_elem_size
// for every batch element i in [0, batch). cuBLAS batched routines consume
// this array of per-matrix pointers.
//
// All index/size arithmetic is done in std::size_t: with int parameters the
// byte offset idx * batch_elem_size overflowed 32-bit arithmetic once
// batch * batch_elem_size exceeded INT32_MAX (see #24843).
__global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out,
                                             std::size_t batch,
                                             std::size_t batch_elem_size) {
  // Grid-stride loop: thread t of the whole grid handles indices
  // t, t + gridSize, t + 2*gridSize, ... so any grid size covers any batch.
  for (std::size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch;
       idx += blockDim.x * gridDim.x) {
    // char* arithmetic is in bytes; the multiply is promoted to std::size_t,
    // so offsets past 2 GiB are computed correctly.
    buffer_out[idx] = buffer_in + idx * batch_elem_size;
  }
}
} // namespace

void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
void* buffer_out, int batch, int batch_elem_size) {
const int block_dim = 128;
void* buffer_out, std::size_t batch,
std::size_t batch_elem_size) {
const std::size_t block_dim = 128;
const std::size_t grid_dim =
std::min<std::size_t>(1024, (batch + block_dim - 1) / block_dim);
MakeBatchPointersAsyncKernel<<<grid_dim, block_dim, 0, stream>>>(
Expand Down
5 changes: 4 additions & 1 deletion jaxlib/gpu/make_batch_pointers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ limitations under the License.
#ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
#define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_

#include <cstddef>

#include "jaxlib/gpu/vendor.h"

namespace jax {
namespace JAX_GPU_NAMESPACE {

void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
void* buffer_out, int batch, int batch_elem_size);
void* buffer_out, std::size_t batch,
std::size_t batch_elem_size);

} // namespace JAX_GPU_NAMESPACE
} // namespace jax
Expand Down
7 changes: 7 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,13 @@ def testLuBatching(self, shape, dtype):
self.assertAllClose(ls, actual_ls, rtol=5e-6)
self.assertAllClose(us, actual_us)

@jtu.skip_on_devices("cpu", "tpu")
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testBatchedLuOverflow(self):
  # Regression test for #24843: batched LU on GPU hit 32-bit overflow when
  # building the cuBLAS batch-pointer array. The input here occupies
  # 1_500_000 * 20 * 20 * 4 bytes = 2.4e9 bytes, which exceeds INT32_MAX
  # (~2.147e9), so per-matrix byte offsets no longer fit in an int32.
  # Slow/large-memory test, hence the jax_skip_slow_tests guard above.
  x = self.rng().standard_normal((1500000, 20, 20)).astype(np.float32)
  lu, _, _ = lax.linalg.lu(x)
  # Sanity check rather than exact comparison: every per-matrix std of the
  # LU factors must be non-degenerate (> 0.9 for standard-normal input).
  # NOTE(review): this presumes an overflow would corrupt later batch
  # entries enough to flatten their statistics — confirm against #24843.
  self.assertTrue(jnp.all(lu.std(axis=[1, 2]) > 0.9))

@jtu.skip_on_devices("cpu", "tpu")
@jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument")
Expand Down

0 comments on commit 6d19000

Please sign in to comment.