Commit
Use vmap for random_gamma implementation on CPU backend
XLA:CPU is preparing to switch from compiling the whole XLA program into a single LLVM function to a mode where each fusion/kernel gets its own entry point, with a thin runtime dispatching compute functions concurrently. That execution mode performs poorly on while loops that run tiny computations for a large number of iterations. As on the GPU backend, use vmap to avoid excessive runtime overheads.

Context: openxla/community#96
PiperOrigin-RevId: 656199716
ezhulenev authored and jax authors committed Jul 26, 2024
1 parent 2eb1888 commit 15d4389
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/random.py
Expand Up @@ -1252,7 +1252,7 @@ def _gamma_batching_rule(batched_args, batch_dims, *, log_space):
     partial(_gamma_impl, use_vmap=True),
     multiple_results=False))
 mlir.register_lowering(random_gamma_p, mlir.lower_fun(
-    partial(_gamma_impl, use_vmap=False),
+    partial(_gamma_impl, use_vmap=True),
     multiple_results=False), platform='cpu')
 batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
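The trade-off behind this one-flag change can be illustrated with a minimal, self-contained sketch (not the actual `_gamma_impl`; `per_element` and its Newton iteration are purely hypothetical stand-ins for a small iterative per-sample computation). Looping over elements issues many tiny dispatches, while `jax.vmap` batches them into one larger computation:

```python
import jax
import jax.numpy as jnp
from jax import lax

def per_element(x):
    # A tiny iterative computation per element: 10 Newton steps
    # converging to sqrt(x). Stands in for a per-sample sampler loop.
    def body(i, y):
        return 0.5 * (y + x / y)
    return lax.fori_loop(0, 10, body, jnp.maximum(x, 1.0))

xs = jnp.arange(1.0, 5.0)

# Loop style: one small computation dispatched per element.
looped = jnp.stack([per_element(x) for x in xs])

# vmap style: a single batched computation over all elements,
# avoiding per-iteration dispatch overhead on the new CPU runtime.
vmapped = jax.vmap(per_element)(xs)
```

Both produce the same values; the difference is purely in how the work is scheduled, which is what motivates `use_vmap=True` on CPU.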

