Skip to content

Commit

Permalink
Fix overflow error in GPU batched linear algebra kernels.
Browse files Browse the repository at this point in the history
As reported in jax-ml#24843, our LU decomposition on GPU hits overflow errors when the total buffer offset (batch size times per-element byte size) exceeds int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS: the offsets were computed in 32-bit integer arithmetic.

PiperOrigin-RevId: 695694648
  • Loading branch information
dfm authored and yliu120 committed Nov 16, 2024
1 parent 8886f46 commit 796d73b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
11 changes: 7 additions & 4 deletions jaxlib/gpu/make_batch_pointers.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "jaxlib/gpu/make_batch_pointers.h"

#include <algorithm>
#include <cstdint>

#include "jaxlib/gpu/vendor.h"

Expand All @@ -24,17 +25,19 @@ namespace JAX_GPU_NAMESPACE {

namespace {
// Fills buffer_out[i] with the device address of the i-th batch element in
// buffer_in, i.e. buffer_in + i * batch_elem_size, for i in [0, batch).
// Uses a grid-stride loop so any grid size covers all entries (the host
// wrapper caps the grid at 1024 blocks).
//
// All index math is 64-bit: with large batches the byte offset
// idx * batch_elem_size can exceed INT32_MAX, which previously overflowed
// and produced garbage pointers (see jax-ml/jax#24843).
__global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out,
                                             int64_t batch,
                                             int64_t batch_elem_size) {
  // Cast before multiplying so the thread-index product is computed in
  // 64 bits as well (blockIdx/blockDim are 32-bit unsigned).
  for (int64_t idx =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < batch; idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    buffer_out[idx] = buffer_in + idx * batch_elem_size;
  }
}
} // namespace

void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
void* buffer_out, int batch, int batch_elem_size) {
const int block_dim = 128;
void* buffer_out, int64_t batch,
int64_t batch_elem_size) {
const std::size_t block_dim = 128;
const std::size_t grid_dim =
std::min<std::size_t>(1024, (batch + block_dim - 1) / block_dim);
MakeBatchPointersAsyncKernel<<<grid_dim, block_dim, 0, stream>>>(
Expand Down
5 changes: 4 additions & 1 deletion jaxlib/gpu/make_batch_pointers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ limitations under the License.
#ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
#define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_

#include <cstdint>

#include "jaxlib/gpu/vendor.h"

namespace jax {
namespace JAX_GPU_NAMESPACE {

// Asynchronously (on `stream`) fills `buffer_out` — an array of `batch`
// void* pointers — with the addresses of the consecutive
// `batch_elem_size`-byte elements stored contiguously in `buffer_in`.
// `batch` and `batch_elem_size` are 64-bit so that byte offsets larger than
// INT32_MAX do not overflow (see jax-ml/jax#24843).
void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
                            void* buffer_out, int64_t batch,
                            int64_t batch_elem_size);

} // namespace JAX_GPU_NAMESPACE
} // namespace jax
Expand Down
8 changes: 8 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,14 @@ def testLuBatching(self, shape, dtype):
self.assertAllClose(ls, actual_ls, rtol=5e-6)
self.assertAllClose(us, actual_us)

@jtu.skip_on_devices("cpu", "tpu")
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testBatchedLuOverflow(self):
  """Regression test for int32 overflow in GPU batched LU.

  See https://github.com/jax-ml/jax/issues/24843. The batch pointer
  construction used for cuBLAS previously computed byte offsets in 32-bit
  arithmetic; 1.5M batches of 20x20 float32 matrices (1600 bytes each)
  push the offset past INT32_MAX and triggered the overflow.
  """
  x = self.rng().standard_normal((1500000, 20, 20)).astype(np.float32)
  lu, _, _ = lax.linalg.lu(x)
  # With the overflow bug, batches past the 2 GiB offset read garbage
  # pointers; a healthy decomposition has non-degenerate per-matrix spread.
  self.assertTrue(jnp.all(lu.std(axis=[1, 2]) > 0.9))

@jtu.skip_on_devices("cpu", "tpu")
@jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument")
Expand Down

0 comments on commit 796d73b

Please sign in to comment.