From 6d19000ab82c0ab8c0261a107e2ba82791823664 Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey <dfm@google.com>
Date: Mon, 11 Nov 2024 14:51:05 -0800
Subject: [PATCH] Fix overflow error in GPU batched linear algebra kernels.

As reported in https://github.com/jax-ml/jax/issues/24843, our LU
decomposition on GPU hits overflow errors when the batch size approaches
int32 max. This was caused by an issue in how we were constructing the
batched pointers used by cuBLAS.

PiperOrigin-RevId: 695490133
---
 jaxlib/gpu/make_batch_pointers.cu.cc | 11 +++++++----
 jaxlib/gpu/make_batch_pointers.h     |  5 ++++-
 tests/linalg_test.py                 |  7 +++++++
 3 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/jaxlib/gpu/make_batch_pointers.cu.cc b/jaxlib/gpu/make_batch_pointers.cu.cc
index b10655645924..e4c84aeb7d1d 100644
--- a/jaxlib/gpu/make_batch_pointers.cu.cc
+++ b/jaxlib/gpu/make_batch_pointers.cu.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "jaxlib/gpu/make_batch_pointers.h"
 
 #include <algorithm>
+#include <cstddef>
 
 #include "jaxlib/gpu/vendor.h"
 
@@ -24,8 +25,9 @@ namespace JAX_GPU_NAMESPACE {
 namespace {
 __global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out,
-                                             int batch, int batch_elem_size) {
-  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch;
+                                             std::size_t batch,
+                                             std::size_t batch_elem_size) {
+  for (std::size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch;
        idx += blockDim.x * gridDim.x) {
     buffer_out[idx] = buffer_in + idx * batch_elem_size;
   }
 }
@@ -33,8 +35,9 @@ __global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out,
 }  // namespace
 
 void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
-                            void* buffer_out, int batch, int batch_elem_size) {
-  const int block_dim = 128;
+                            void* buffer_out, std::size_t batch,
+                            std::size_t batch_elem_size) {
+  const std::size_t block_dim = 128;
   const std::size_t grid_dim =
       std::min<std::size_t>(1024, (batch + block_dim - 1) / block_dim);
   MakeBatchPointersAsyncKernel<<<grid_dim, block_dim, 0, stream>>>(
diff --git a/jaxlib/gpu/make_batch_pointers.h b/jaxlib/gpu/make_batch_pointers.h
index f2fd064961e8..1637ad77b5ae 100644
--- a/jaxlib/gpu/make_batch_pointers.h
+++ b/jaxlib/gpu/make_batch_pointers.h
@@ -16,13 +16,16 @@ limitations under the License.
 #ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
 #define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
 
+#include <cstddef>
+
 #include "jaxlib/gpu/vendor.h"
 
 namespace jax {
 namespace JAX_GPU_NAMESPACE {
 
 void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
-                            void* buffer_out, int batch, int batch_elem_size);
+                            void* buffer_out, std::size_t batch,
+                            std::size_t batch_elem_size);
 
 }  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 5ace4b5ecf18..7c6fade7f2ab 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -1450,6 +1450,13 @@ def testLuBatching(self, shape, dtype):
     self.assertAllClose(ls, actual_ls, rtol=5e-6)
     self.assertAllClose(us, actual_us)
 
+  @jtu.skip_on_devices("cpu", "tpu")
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
+  def testBatchedLuOverflow(self):
+    x = self.rng().standard_normal((1500000, 20, 20)).astype(np.float32)
+    lu, _, _ = lax.linalg.lu(x)
+    self.assertTrue(jnp.all(lu.std(axis=[1, 2]) > 0.9))
+
   @jtu.skip_on_devices("cpu", "tpu")
   @jtu.ignore_warning(category=DeprecationWarning,
                       message="backend and device argument")