vmap with scatter_add extremely slow when using xla_gpu_deterministic_ops #17844
Comments
At least at the moment I think this is expected: deterministic scatters are much slower on GPU because they eliminate any parallelism. XLA would need to emit different code for a faster deterministic scatter.
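For readers wondering why the order of the updates matters at all: floating-point addition is not associative, so a scatter-add that accumulates updates to the same index in a different order can produce different bits. A minimal sketch of that effect, with values picked purely for illustration (they are not from this issue):

```python
import jax.numpy as jnp

# float32 values chosen so that the rounding depends on evaluation order.
a = jnp.float32(1e8)
b = jnp.float32(-1e8)
c = jnp.float32(1.0)

# a and b cancel exactly first, so c survives.
left_to_right = (a + b) + c
# c is rounded away when added to b (float32 spacing near 1e8 is 8), then a cancels b.
right_to_left = a + (b + c)

print(float(left_to_right), float(right_to_left))
```

A parallel scatter-add on GPU does not fix the order in which colliding updates are accumulated, which is where the run-to-run differences come from.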
@hawkinsp thanks for a response, it absolutely makes sense that deterministic scatters should be slower. Do you think it's expected that you should get an additional slow-down from vmap? This:
scatter_add_batched = jax.vmap(scatter_add, in_axes=(0, 0, 0), out_axes=0)
scatter_add_batched_jit = jax.jit(scatter_add_batched)
scatter_add_batched_jit(operand, updates, indices).block_until_ready()
%timeit scatter_add_batched_jit(operand, updates, indices).block_until_ready()
>>> 17.4 s
is 2.7x as slow as this:
scatter_add_jit = jax.jit(scatter_add)
scatter_add_jit(operand, updates, indices).block_until_ready()
%timeit [scatter_add_jit(operand[i], updates[i], indices[i]).block_until_ready() for i in range(len(operand))]
>>> 4.6 s
with
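The definition of `scatter_add` and the input shapes are not shown in this comment. As a rough sketch, assuming `scatter_add` is a plain indexed add (the function body, the shapes and the random indices below are all assumptions for illustration), the two variants being timed can be checked to compute the same thing:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the `scatter_add` used above (assumption).
def scatter_add(operand, updates, indices):
    return operand.at[indices].add(updates)

# Small illustrative shapes (assumption, far smaller than anything worth timing).
batch, n, m = 4, 8, 16
operand = jnp.zeros((batch, n))
updates = jnp.ones((batch, m))
indices = jax.random.randint(jax.random.PRNGKey(0), (batch, m), 0, n)

# Variant 1: one vmapped, jitted call over the whole batch.
scatter_add_batched_jit = jax.jit(jax.vmap(scatter_add, in_axes=(0, 0, 0), out_axes=0))
batched = scatter_add_batched_jit(operand, updates, indices)

# Variant 2: a Python loop over the batch, calling the jitted unbatched function.
scatter_add_jit = jax.jit(scatter_add)
looped = jnp.stack(
    [scatter_add_jit(operand[i], updates[i], indices[i]) for i in range(len(operand))]
)

print(bool(jnp.allclose(batched, looped)))
```

Since both variants agree on the result, any timing difference comes from how the batched scatter is lowered rather than from the work being done.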
I encountered what I'm fairly confident is the same vmap-related slowdown on TPU, profiled it and discovered that while vmap of my function produces a
@BrunoKM I will try your Python loop workaround and see whether that improves my use-case for now.
Setup:
In [1]: from jax import jit, lax, vmap, make_jaxpr
In [2]: import jax.numpy as jnp
In [3]: operand = jnp.ones((3, 4, 5))
In [4]: updates = jnp.ones((3, 2, 5))
In [5]: starts = jnp.ones((3,), dtype='int32')
In [6]: from functools import partial
In [7]: f = partial(lax.dynamic_update_slice_in_dim, axis=0)
Printing the jaxpr, note there is a single scatter op:
In [8]: make_jaxpr(vmap(f))(operand, updates, starts)
Out[8]:
{ lambda ; a:f32[3,4,5] b:f32[3,2,5] c:i32[3]. let
d:bool[3] = lt c 0
e:i32[3] = add c 4
f:i32[3] = select_n d c e
g:i32[] = add 0 5
h:i32[] = select_n False 0 g
i:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] f
j:i32[3,1] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 1)] h
k:i32[3,2] = concatenate[dimension=1] i j
l:i32[3,1] = iota[dimension=0 dtype=int32 shape=(3, 1)]
m:i32[3,3] = concatenate[dimension=1] l k
n:f32[3,4,5] = scatter[
dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2))
indices_are_sorted=True
mode=GatherScatterMode.CLIP
unique_indices=True
update_consts=()
update_jaxpr=None
] a m b
in (n,) }
Printing the HLO. Because it is long and hard to read, I've surrounded the relevant parts in ########:
In [9]: print(jit(vmap(f)).lower(operand, updates, starts).compile().as_text())
HloModule jit__unnamed_function_, entry_computation_layout={(f32[3,4,5]{2,1,0}, f32[3,2,5]{2,1,0}, s32[3]{0})->f32[3,4,5]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
#############################################################
%fused_computation (param_0: f32[3,4,5], param_1.3: f32[3,2,5], param_2.7: s32[], param_3.8: pred[], param_4.8: s32[3,3]) -> f32[3,4,5] {
%param_0 = f32[3,4,5]{2,1,0} parameter(0)
%param_3.8 = pred[] parameter(3)
%broadcast.18 = pred[1,2,5]{2,1,0} broadcast(pred[] %param_3.8), dimensions={}
%param_1.3 = f32[3,2,5]{2,1,0} parameter(1)
%param_2.7 = s32[] parameter(2)
%constant.23 = s32[] constant(0)
%dynamic-slice.7 = f32[1,2,5]{2,1,0} dynamic-slice(f32[3,2,5]{2,1,0} %param_1.3, s32[] %param_2.7, s32[] %constant.23, s32[] %constant.23), dynamic_slice_sizes={1,2,5}
%param_4.8 = s32[3,3]{1,0} parameter(4)
%dynamic-slice.8 = s32[1,3]{1,0} dynamic-slice(s32[3,3]{1,0} %param_4.8, s32[] %param_2.7, s32[] %constant.23), dynamic_slice_sizes={1,3}
%slice.23 = s32[1,1]{1,0} slice(s32[1,3]{1,0} %dynamic-slice.8), slice={[0:1], [0:1]}
%bitcast.6 = s32[] bitcast(s32[1,1]{1,0} %slice.23)
%bitcast.7 = s32[3]{0} bitcast(s32[1,3]{1,0} %dynamic-slice.8)
%slice.22 = s32[1]{0} slice(s32[3]{0} %bitcast.7), slice={[1:2]}
%bitcast.5 = s32[] bitcast(s32[1]{0} %slice.22)
%dynamic-slice.6 = f32[1,2,5]{2,1,0} dynamic-slice(f32[3,4,5]{2,1,0} %param_0, s32[] %bitcast.6, s32[] %bitcast.5, s32[] %constant.23), dynamic_slice_sizes={1,2,5}
%select.1 = f32[1,2,5]{2,1,0} select(pred[1,2,5]{2,1,0} %broadcast.18, f32[1,2,5]{2,1,0} %dynamic-slice.7, f32[1,2,5]{2,1,0} %dynamic-slice.6)
###########################################################
# Note the shape of the update array, it is 1 in the batch dimension
ROOT %dynamic-update-slice.2 = f32[3,4,5]{2,1,0} dynamic-update-slice(f32[3,4,5]{2,1,0} %param_0, f32[1,2,5]{2,1,0} %select.1, s32[] %bitcast.6, s32[] %bitcast.5, s32[] %constant.23)
###########################################################
}
#############################################################
%and.reduce_sub_computation (lhs: pred[], rhs: pred[]) -> pred[] {
%lhs = pred[] parameter(0)
%rhs = pred[] parameter(1)
ROOT %and = pred[] and(pred[] %lhs, pred[] %rhs)
}
%fused_computation.1 (param_0.4: s32[3]) -> pred[] {
%constant.26 = s32[] constant(0)
%broadcast.19 = s32[3]{0} broadcast(s32[] %constant.26), dimensions={}
%param_0.4 = s32[3]{0} parameter(0)
%compare.4 = pred[3]{0} compare(s32[3]{0} %broadcast.19, s32[3]{0} %param_0.4), direction=LE
%constant.25 = s32[3]{0} constant({2, 2, 0})
%compare.3 = pred[3]{0} compare(s32[3]{0} %constant.25, s32[3]{0} %param_0.4), direction=GE
%and.2 = pred[3]{0} and(pred[3]{0} %compare.4, pred[3]{0} %compare.3)
%constant.24 = pred[] constant(true)
ROOT %reduce.1 = pred[] reduce(pred[3]{0} %and.2, pred[] %constant.24), dimensions={0}, to_apply=%and.reduce_sub_computation
}
%fused_computation.2 (param_0.7: s32[3,3], param_1.15: s32[]) -> s32[3] {
%param_0.7 = s32[3,3]{1,0} parameter(0)
%param_1.15 = s32[] parameter(1)
%constant.27 = s32[] constant(0)
%dynamic-slice.9 = s32[1,3]{1,0} dynamic-slice(s32[3,3]{1,0} %param_0.7, s32[] %param_1.15, s32[] %constant.27), dynamic_slice_sizes={1,3}
%slice.25 = s32[1,1]{1,0} slice(s32[1,3]{1,0} %dynamic-slice.9), slice={[0:1], [0:1]}
%bitcast.9 = s32[1]{0} bitcast(s32[1,1]{1,0} %slice.25)
%bitcast.8 = s32[3]{0} bitcast(s32[1,3]{1,0} %dynamic-slice.9)
%slice.24 = s32[2]{0} slice(s32[3]{0} %bitcast.8), slice={[1:3]}
ROOT %concatenate.2 = s32[3]{0} concatenate(s32[1]{0} %bitcast.9, s32[2]{0} %slice.24), dimensions={0}
}
#############################################################
%while_body (param.1: (s32[], f32[3,4,5], s32[3,3], f32[3,2,5])) -> (s32[], f32[3,4,5], s32[3,3], f32[3,2,5]) {
%param.1 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) parameter(0)
%get-tuple-element.12 = s32[] get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=0
%copy.3 = s32[] copy(s32[] %get-tuple-element.12)
%constant.10 = s32[] constant(1)
%add = s32[] add(s32[] %copy.3, s32[] %constant.10)
%get-tuple-element.13 = f32[3,4,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=1
%get-tuple-element.19 = f32[3,2,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=3
%get-tuple-element.18 = s32[3,3]{1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=2
%fusion.2 = s32[3]{0} fusion(s32[3,3]{1,0} %get-tuple-element.18, s32[] %copy.3), kind=kLoop, calls=%fused_computation.2
%fusion.1 = pred[] fusion(s32[3]{0} %fusion.2), kind=kLoop, calls=%fused_computation.1
###########################################################
%fusion = f32[3,4,5]{2,1,0} fusion(f32[3,4,5]{2,1,0} %get-tuple-element.13, f32[3,2,5]{2,1,0} %get-tuple-element.19, s32[] %copy.3, pred[] %fusion.1, s32[3,3]{1,0} %get-tuple-element.18), kind=kLoop, calls=%fused_computation
###########################################################
ROOT %tuple.5 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) tuple(s32[] %add, f32[3,4,5]{2,1,0} %fusion, s32[3,3]{1,0} %get-tuple-element.18, f32[3,2,5]{2,1,0} %get-tuple-element.19)
}
#############################################################
%while_cond (param.0: (s32[], f32[3,4,5], s32[3,3], f32[3,2,5])) -> pred[] {
%param.0 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.0), index=0
%constant.1 = s32[] constant(3)
ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.1), direction=LT
}
%fused_computation.3 (param_0.10: s32[3]) -> s32[3,3] {
%constant.30 = s32[] constant(0)
%broadcast.25 = s32[3,3]{1,0} broadcast(s32[] %constant.30), dimensions={}
%iota.1 = s32[3,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(<unnamed function>)/jit(main)/iota[dtype=int32 shape=(3, 1) dimension=0]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%param_0.10 = s32[3]{0} parameter(0)
%broadcast.24 = s32[3]{0} broadcast(s32[] %constant.30), dimensions={}
%compare.5 = pred[3]{0} compare(s32[3]{0} %param_0.10, s32[3]{0} %broadcast.24), direction=LT, metadata={op_name="jit(<unnamed function>)/jit(main)/lt" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%constant.29 = s32[] constant(4)
%broadcast.23 = s32[3]{0} broadcast(s32[] %constant.29), dimensions={}
%add.1 = s32[3]{0} add(s32[3]{0} %param_0.10, s32[3]{0} %broadcast.23), metadata={op_name="jit(<unnamed function>)/jit(main)/add" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%select.2 = s32[3]{0} select(pred[3]{0} %compare.5, s32[3]{0} %add.1, s32[3]{0} %param_0.10), metadata={op_name="jit(<unnamed function>)/jit(main)/select_n" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%bitcast.10 = s32[3,1]{1,0} bitcast(s32[3]{0} %select.2), metadata={op_name="jit(<unnamed function>)/jit(main)/select_n" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%broadcast.22 = s32[3,1]{1,0} broadcast(s32[] %constant.30), dimensions={}
%concatenate.3 = s32[3,3]{1,0} concatenate(s32[3,1]{1,0} %iota.1, s32[3,1]{1,0} %bitcast.10, s32[3,1]{1,0} %broadcast.22), dimensions={1}, metadata={op_name="jit(<unnamed function>)/jit(main)/concatenate[dimension=1]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%constant.28 = s32[3]{0} constant({2, 2, 0})
%broadcast.21 = s32[3,3]{1,0} broadcast(s32[3]{0} %constant.28), dimensions={1}, metadata={op_name="jit(<unnamed function>)/jit(main)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=(1,)]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
ROOT %clamp.0 = s32[3,3]{1,0} clamp(s32[3,3]{1,0} %broadcast.25, s32[3,3]{1,0} %concatenate.3, s32[3,3]{1,0} %broadcast.21), metadata={op_name="jit(<unnamed function>)/jit(main)/clamp" source_file="<ipython-input-9-61375df8be79>" source_line=1}
}
#############################################################
ENTRY %main.26 (Arg_0.1: f32[3,4,5], Arg_1.2: f32[3,2,5], Arg_2.3: s32[3]) -> f32[3,4,5] {
%constant.4 = s32[] constant(0)
%copy.8 = s32[] copy(s32[] %constant.4)
%Arg_0.1 = f32[3,4,5]{2,1,0} parameter(0), sharding={replicated}
%copy.7 = f32[3,4,5]{2,1,0} copy(f32[3,4,5]{2,1,0} %Arg_0.1)
%Arg_2.3 = s32[3]{0} parameter(2), sharding={replicated}
%fusion.3 = s32[3,3]{1,0} fusion(s32[3]{0} %Arg_2.3), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(<unnamed function>)/jit(main)/clamp" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%Arg_1.2 = f32[3,2,5]{2,1,0} parameter(1), sharding={replicated}
%tuple.3 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) tuple(s32[] %copy.8, f32[3,4,5]{2,1,0} %copy.7, s32[3,3]{1,0} %fusion.3, f32[3,2,5]{2,1,0} %Arg_1.2)
###########################################################
%while = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) while((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %tuple.3), condition=%while_cond, body=%while_body, metadata={op_name="jit(<unnamed function>)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
###########################################################
ROOT %get-tuple-element.5 = f32[3,4,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %while), index=1, metadata={op_name="jit(<unnamed function>)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
}
#############################################################
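The jaxpr observation above can also be checked without reading the dump by eye. A small sketch, reusing the same setup and simply searching the printed jaxpr for primitive names (the string search is my shortcut, not something from the thread):

```python
from functools import partial

import jax.numpy as jnp
from jax import lax, make_jaxpr, vmap

operand = jnp.ones((3, 4, 5))
updates = jnp.ones((3, 2, 5))
starts = jnp.ones((3,), dtype="int32")
f = partial(lax.dynamic_update_slice_in_dim, axis=0)

# Unbatched: one example at a time keeps the dynamic_update_slice primitive.
unbatched_text = str(make_jaxpr(f)(operand[0], updates[0], starts[0]))
# Batched: vmap over per-example start indices rewrites the update as a scatter.
batched_text = str(make_jaxpr(vmap(f))(operand, updates, starts))

print("dynamic_update_slice" in unbatched_text)
print("scatter" in batched_text)
```

How that scatter is then compiled (for example into the while loop highlighted in the HLO above) depends on the backend and XLA version, so the compiled text is still worth inspecting on the target device.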
Here's a more minimal, self-contained snippet which reproduces the vmap slowdown. On TPU the vmap version is around 2x slower than using a Python loop and
from timeit import timeit
from jax import jit, lax, vmap
import jax.numpy as jnp
# For f which outputs a single array, this simulates vmap using Python map
pymap = lambda f: lambda *args: jnp.stack(list(map(f, *args)))
operands = jnp.ones((100, 32))
updates = jnp.ones((100, 2))
starts = jnp.ones((100, 1), dtype='int32')
f = lax.dynamic_update_slice
f_vmapped = jit(vmap(f))
f_pymapped = jit(pymap(f))
# Ensure compiled
f_vmapped(operands, updates, starts)
f_pymapped(operands, updates, starts)
t_vmapped = timeit(
lambda: f_vmapped(operands, updates, starts).block_until_ready(), number=100
) / 100
t_pymapped = timeit(
lambda: f_pymapped(operands, updates, starts).block_until_ready(), number=100
) / 100
print(f"Time vmap(f): {t_vmapped:.2}s")
print(f"Time pymap(f): {t_pymapped:.2}s")
On a TPU v4-8 VM I get:
Running the script on CPU on my laptop, the Python loop version is slower than the vmap version.
I realize this is an older issue, but one option is to roll your own deterministic scatter_add (using prefix sums):
import jax
import jax.numpy as jp

def add_segment(iv, jt):
    # Combine step of the segmented sum: carry the left value only if both
    # elements belong to the same index.
    i, v = iv
    j, t = jt
    return j, v * jp.equal(i, j) + t

@jax.jit
def scatter_add_det(operand, updates, indices):
    indices = jp.reshape(indices, updates.shape)
    # Sort the indices and the values by the indices.
    indices, sorted_updates = jax.lax.sort_key_val(indices, updates, dimension=-1)
    # Sum up runs of the same index - the sum for each index will be at the end of each run.
    _, sums = jax.lax.associative_scan(add_segment, (indices, sorted_updates))
    # Produce an array of bools - if an element is set then the position
    # is the end of a run of the same index.
    end_of_run = jp.concatenate([jp.not_equal(indices[1:], indices[:-1]), jp.array([True])])
    # Set all positions that are not at the end of a run to an out-of-bounds index.
    indices = jp.where(end_of_run, indices, operand.shape[-1])
    # Now do scatter-add where we know the (in-bounds) indices are unique.
    # That is still fast on GPUs (no non-determinism from atomics).
    return operand.at[indices].add(sums, mode='drop', unique_indices=True)
This is 5-15x slower than the non-deterministic one (depending on the shape of things), but at least it's not multiple orders of magnitude. It would be nice if XLA could lower to something like this automatically.
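To make the scan step a bit more concrete, here is a small sketch of what the segmented sum produces on already-sorted indices. The input values are made up for illustration; `add_segment` is copied from the snippet above.

```python
import jax
import jax.numpy as jp

def add_segment(iv, jt):
    # Same combine step as above: keep the left value only within one index run.
    i, v = iv
    j, t = jt
    return j, v * jp.equal(i, j) + t

# Already-sorted indices with runs of duplicates, and one value per index.
indices = jp.array([0, 0, 1, 2, 2, 2])
values = jp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# Running sum that restarts whenever the index changes.
_, sums = jax.lax.associative_scan(add_segment, (indices, values))
# True at the last position of each run, where the full per-index total sits.
end_of_run = jp.concatenate([jp.not_equal(indices[1:], indices[:-1]), jp.array([True])])

print(sums)              # running per-index sums
print(sums[end_of_run])  # one total per distinct index
```

Note that the combine function is only associative when equal indices are contiguous, which is why the sort has to come before the scan.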
Imported from GitHub PR openxla/xla#17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance was extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operations with scalar indices and updates, and leaves the other scatter operations to be handled by the original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at openxla/xla#18326
Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit
Bugs resolved: jax-ml/jax#17844

Copybara import of the project:

--
42cc615ed047b28405a0634c42f741a678be605a by Chenhao Jiang <chenhaoj@nvidia.com>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:

| Input Size | Index Size | Non-Det | Original Det | New Det | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10 | 10 | 3.96E-05 | 7.82E-05 | 4.66E-05 | 1.18 | 1.68 |
| 10 | 100 | 3.72E-05 | 4.83E-04 | 9.73E-05 | 2.62 | 4.96 |
| 10 | 1000 | 3.92E-05 | 4.20E-03 | 6.62E-05 | 1.69 | 63.50 |
| 10 | 10000 | 4.36E-05 | 4.31E-02 | 1.21E-04 | 2.77 | 357.37 |
| 10 | 100000 | 1.06E-04 | 4.33E-01 | 1.71E-04 | 1.61 | 2536.56 |
| 10 | 1000000 | 4.31E-04 | 4.17E+00 | 4.45E-04 | 1.03 | 9372.37 |
| 100 | 10 | 4.27E-05 | 7.76E-05 | 4.71E-05 | 1.10 | 1.65 |
| 100 | 100 | 4.01E-05 | 4.91E-04 | 5.61E-05 | 1.40 | 8.75 |
| 100 | 1000 | 5.17E-05 | 4.21E-03 | 1.10E-04 | 2.13 | 38.24 |
| 100 | 10000 | 4.08E-05 | 4.27E-02 | 1.05E-04 | 2.57 | 407.45 |
| 100 | 100000 | 7.60E-05 | 4.14E-01 | 1.69E-04 | 2.22 | 2455.08 |
| 100 | 1000000 | 2.86E-04 | 4.17E+00 | 4.62E-04 | 1.62 | 9009.13 |
| 1000 | 10 | 3.95E-05 | 7.85E-05 | 4.97E-05 | 1.26 | 1.58 |
| 1000 | 100 | 4.16E-05 | 4.85E-04 | 5.27E-05 | 1.27 | 9.21 |
| 1000 | 1000 | 3.90E-05 | 4.25E-03 | 6.35E-05 | 1.63 | 66.86 |
| 1000 | 10000 | 4.08E-05 | 4.25E-02 | 1.22E-04 | 3.00 | 346.99 |
| 1000 | 100000 | 4.26E-05 | 4.15E-01 | 1.92E-04 | 4.51 | 2161.72 |
| 1000 | 1000000 | 1.73E-04 | 4.26E+00 | 4.75E-04 | 2.74 | 8964.91 |
| 10000 | 10 | 4.17E-05 | 8.00E-05 | 4.76E-05 | 1.14 | 1.68 |
| 10000 | 100 | 3.68E-05 | 7.16E-04 | 1.10E-04 | 3.00 | 6.49 |
| 10000 | 1000 | 4.13E-05 | 4.23E-03 | 1.01E-04 | 2.44 | 42.12 |
| 10000 | 10000 | 3.71E-05 | 4.23E-02 | 1.44E-04 | 3.89 | 293.14 |
| 10000 | 100000 | 9.70E-05 | 4.28E-01 | 1.72E-04 | 1.77 | 2494.21 |
| 10000 | 1000000 | 1.18E-04 | 4.17E+00 | 4.91E-04 | 4.15 | 8488.57 |
| 100000 | 10 | 3.73E-05 | 7.25E-05 | 4.92E-05 | 1.32 | 1.47 |
| 100000 | 100 | 4.09E-05 | 4.91E-04 | 6.33E-05 | 1.55 | 7.76 |
| 100000 | 1000 | 4.10E-05 | 4.25E-03 | 6.40E-05 | 1.56 | 66.39 |
| 100000 | 10000 | 3.78E-05 | 4.22E-02 | 1.26E-04 | 3.34 | 334.38 |
| 100000 | 100000 | 4.42E-05 | 4.16E-01 | 1.67E-04 | 3.79 | 2486.22 |
| 100000 | 1000000 | 5.37E-05 | 4.17E+00 | 4.92E-04 | 9.15 | 8474.51 |
| 1000000 | 10 | 3.97E-05 | 8.10E-05 | 5.12E-05 | 1.29 | 1.58 |
| 1000000 | 100 | 4.56E-05 | 4.94E-04 | 6.08E-05 | 1.33 | 8.13 |
| 1000000 | 1000 | 4.47E-05 | 4.29E-03 | 6.17E-05 | 1.38 | 69.44 |
| 1000000 | 10000 | 4.48E-05 | 4.27E-02 | 1.18E-04 | 2.63 | 362.68 |
| 1000000 | 100000 | 4.25E-05 | 4.19E-01 | 1.78E-04 | 4.19 | 2352.46 |
| 1000000 | 1000000 | 6.59E-05 | 4.18E+00 | 5.01E-04 | 7.60 | 8334.87 |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615ed047b28405a0634c42f741a678be605a
PiperOrigin-RevId: 686779279
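The PR text does not spell out how the last two columns are computed. Reading the headers as ratios of the timing columns (my assumption, not stated in the PR) reproduces the first row of the microbenchmark:

```python
# Timings copied from the first microbenchmark row (Input Size 10, Index Size 10).
non_det = 3.96e-05
original_det = 7.82e-05
new_det = 4.66e-05

# Assumed definitions: each derived column is a ratio of two timing columns.
slowdown_vs_non_det = new_det / non_det
speedup_vs_original_det = original_det / new_det

print(f"{slowdown_vs_non_det:.2f} {speedup_vs_original_det:.2f}")
```

Other rows agree only approximately with this reading, presumably because the derived columns were computed from unrounded timings.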
Imported from GitHub PR openxla/xla#17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR rewrites scatter operations with scalar indices and updates, and leaves the other scatter operations to be handled by the original ScatterExpander. The 2nd PR will handle non-scalar indices and updates.

The second PR is at openxla/xla#18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844

Copybara import of the project:

--
42cc615ed047b28405a0634c42f741a678be605a by Chenhao Jiang <chenhaoj@nvidia.com>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows a significant speedup over the original deterministic scatter, especially with larger index sizes, reaching over 9,300x for some input and index sizes.
- Our implementation is somewhat slower than the non-deterministic scatter. In most cases the slowdown is between 1x and 4x. In the worst case (input size 100000, index size 1000000) the slowdown factor is 9.15.

Full Microbenchmark:

| Input Size | Index Size | Non-Det | Original Det | New Det | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10 | 10 | 3.96E-05 | 7.82E-05 | 4.66E-05 | 1.18 | 1.68 |
| 10 | 100 | 3.72E-05 | 4.83E-04 | 9.73E-05 | 2.62 | 4.96 |
| 10 | 1000 | 3.92E-05 | 4.20E-03 | 6.62E-05 | 1.69 | 63.50 |
| 10 | 10000 | 4.36E-05 | 4.31E-02 | 1.21E-04 | 2.77 | 357.37 |
| 10 | 100000 | 1.06E-04 | 4.33E-01 | 1.71E-04 | 1.61 | 2536.56 |
| 10 | 1000000 | 4.31E-04 | 4.17E+00 | 4.45E-04 | 1.03 | 9372.37 |
| 100 | 10 | 4.27E-05 | 7.76E-05 | 4.71E-05 | 1.10 | 1.65 |
| 100 | 100 | 4.01E-05 | 4.91E-04 | 5.61E-05 | 1.40 | 8.75 |
| 100 | 1000 | 5.17E-05 | 4.21E-03 | 1.10E-04 | 2.13 | 38.24 |
| 100 | 10000 | 4.08E-05 | 4.27E-02 | 1.05E-04 | 2.57 | 407.45 |
| 100 | 100000 | 7.60E-05 | 4.14E-01 | 1.69E-04 | 2.22 | 2455.08 |
| 100 | 1000000 | 2.86E-04 | 4.17E+00 | 4.62E-04 | 1.62 | 9009.13 |
| 1000 | 10 | 3.95E-05 | 7.85E-05 | 4.97E-05 | 1.26 | 1.58 |
| 1000 | 100 | 4.16E-05 | 4.85E-04 | 5.27E-05 | 1.27 | 9.21 |
| 1000 | 1000 | 3.90E-05 | 4.25E-03 | 6.35E-05 | 1.63 | 66.86 |
| 1000 | 10000 | 4.08E-05 | 4.25E-02 | 1.22E-04 | 3.00 | 346.99 |
| 1000 | 100000 | 4.26E-05 | 4.15E-01 | 1.92E-04 | 4.51 | 2161.72 |
| 1000 | 1000000 | 1.73E-04 | 4.26E+00 | 4.75E-04 | 2.74 | 8964.91 |
| 10000 | 10 | 4.17E-05 | 8.00E-05 | 4.76E-05 | 1.14 | 1.68 |
| 10000 | 100 | 3.68E-05 | 7.16E-04 | 1.10E-04 | 3.00 | 6.49 |
| 10000 | 1000 | 4.13E-05 | 4.23E-03 | 1.01E-04 | 2.44 | 42.12 |
| 10000 | 10000 | 3.71E-05 | 4.23E-02 | 1.44E-04 | 3.89 | 293.14 |
| 10000 | 100000 | 9.70E-05 | 4.28E-01 | 1.72E-04 | 1.77 | 2494.21 |
| 10000 | 1000000 | 1.18E-04 | 4.17E+00 | 4.91E-04 | 4.15 | 8488.57 |
| 100000 | 10 | 3.73E-05 | 7.25E-05 | 4.92E-05 | 1.32 | 1.47 |
| 100000 | 100 | 4.09E-05 | 4.91E-04 | 6.33E-05 | 1.55 | 7.76 |
| 100000 | 1000 | 4.10E-05 | 4.25E-03 | 6.40E-05 | 1.56 | 66.39 |
| 100000 | 10000 | 3.78E-05 | 4.22E-02 | 1.26E-04 | 3.34 | 334.38 |
| 100000 | 100000 | 4.42E-05 | 4.16E-01 | 1.67E-04 | 3.79 | 2486.22 |
| 100000 | 1000000 | 5.37E-05 | 4.17E+00 | 4.92E-04 | 9.15 | 8474.51 |
| 1000000 | 10 | 3.97E-05 | 8.10E-05 | 5.12E-05 | 1.29 | 1.58 |
| 1000000 | 100 | 4.56E-05 | 4.94E-04 | 6.08E-05 | 1.33 | 8.13 |
| 1000000 | 1000 | 4.47E-05 | 4.29E-03 | 6.17E-05 | 1.38 | 69.44 |
| 1000000 | 10000 | 4.48E-05 | 4.27E-02 | 1.18E-04 | 2.63 | 362.68 |
| 1000000 | 100000 | 4.25E-05 | 4.19E-01 | 1.78E-04 | 4.19 | 2352.46 |
| 1000000 | 1000000 | 6.59E-05 | 4.18E+00 | 5.01E-04 | 7.60 | 8334.87 |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615ed047b28405a0634c42f741a678be605a
PiperOrigin-RevId: 686779279
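To make the while-loop versus prefix-scan contrast in the PR description concrete, here is a minimal sketch in Python. It is not the XLA implementation and is not taken from the design doc. NumPy stands in for the compiler ops, the function names `scatter_add_loop` and `scatter_add_sort_scan` are made up for illustration, and the scan is simplified to a plain cumulative sum with differences taken at run boundaries.

```python
import numpy as np


def scatter_add_loop(operand, indices, updates):
    """Reference: apply the updates one at a time, in order (serial, loop style)."""
    out = operand.copy()
    for i, u in zip(indices, updates):
        out[i] += u
    return out


def scatter_add_sort_scan(operand, indices, updates):
    """Scalar scatter-add built from a stable sort plus a prefix sum.

    Illustrative only: the order of accumulation is fixed by the sort, so the
    result does not depend on thread scheduling.
    """
    order = np.argsort(indices, kind="stable")  # fixed, reproducible ordering
    idx = indices[order]
    upd = updates[order]
    csum = np.cumsum(upd)  # inclusive prefix sum over the sorted updates

    # Mark the last element of every run of equal indices.
    is_last = np.ones(len(idx), dtype=bool)
    is_last[:-1] = idx[1:] != idx[:-1]

    # Sum of each run = prefix sum at its end minus prefix sum at the previous end.
    run_totals = np.diff(csum[is_last], prepend=0)

    out = operand.copy()
    out[idx[is_last]] += run_totals  # indices are unique here, so no collisions
    return out


operand = np.zeros(5, dtype=np.int64)
indices = np.array([3, 1, 3, 0, 1, 3])
updates = np.array([10, 20, 30, 40, 50, 60])
print(scatter_add_sort_scan(operand, indices, updates))
```

With integer data the two functions agree exactly. With floating-point data the sort-and-scan result is reproducible from run to run, but it can differ from the loop in the last bits because the additions are grouped differently.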
…r operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded to be deterministic in xla/service/ScatterExpander.cc. However, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

This PR builds on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844

Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 690490783
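One way to picture the step from scalar indices to non-scalar ones is to linearize each index vector into a single offset and then reuse a sort-based reduction. The sketch below assumes that reading of "non-scalar indices". It is not the method from the PR or the design doc, it does not cover update windows, and the name `scatter_add_nd` is hypothetical.

```python
import numpy as np


def scatter_add_nd(operand, index_vectors, updates):
    """Scatter-add where each row of `index_vectors` addresses one operand element.

    Illustrative only: index vectors are flattened to offsets, sorted with a
    stable sort, and each run of equal offsets is summed in sorted order.
    """
    if len(updates) == 0:
        return operand.copy()

    # Map every index vector to one offset into the flattened operand.
    flat_idx = np.ravel_multi_index(tuple(index_vectors.T), operand.shape)

    order = np.argsort(flat_idx, kind="stable")
    idx = flat_idx[order]
    upd = updates[order]

    # Start position of every run of equal offsets.
    starts = np.flatnonzero(np.r_[True, idx[1:] != idx[:-1]])
    run_totals = np.add.reduceat(upd, starts)  # one sum per run

    out = operand.copy().reshape(-1)
    out[idx[starts]] += run_totals  # offsets are unique after grouping
    return out.reshape(operand.shape)


operand = np.zeros((2, 3), dtype=np.int64)
index_vectors = np.array([[0, 1], [1, 2], [0, 1], [1, 0]])
updates = np.array([1, 2, 3, 4])
print(scatter_add_nd(operand, index_vectors, updates))
```

Here `np.add.reduceat` sums each run directly instead of differencing a running total, which avoids the cancellation that a plain cumulative sum can introduce for floating-point inputs.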
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded to be deterministic in xla/service/ScatterExpander.cc. However, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

This PR builds on top of #17886 and fixes the issues reported in the reverted PR #18326. The problem was that the changes in #18326 could not handle several complicated but realistic kinds of scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and handles all the different kinds of dimension numbers.

This PR also adds an option, `xla_gpu_enable_scatter_determinism_expander`, which defaults to true. If the changes in this PR cause a problem, which is unlikely, users can set this option to false to disable the `scatter_determinism_expander` pass instead of being blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844

Copybara import of the project:

--
3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326 (description as in that PR; it contains commit de647d4, "Support scatter with non-scalar indices and updates", PiperOrigin-RevId: 691023328)

--
126c952 by Chenhao Jiang <chenhaoj@nvidia.com>:

Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly.

--
1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <chenhaoj@nvidia.com>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
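For anyone who needs the escape hatch described above, the option can be passed like any other XLA debug flag through the `XLA_FLAGS` environment variable. The sketch below only builds the flag string and sets the variable; the helper name `build_xla_flags` is made up for illustration, and it assumes a JAX/XLA build recent enough to include this PR (older builds may reject the unknown flag at startup).

```python
import os


def build_xla_flags(deterministic_ops=True, scatter_determinism_expander=True):
    """Return an XLA_FLAGS string for the two options discussed in this thread."""
    flags = [
        f"--xla_gpu_deterministic_ops={str(deterministic_ops).lower()}",
        "--xla_gpu_enable_scatter_determinism_expander="
        f"{str(scatter_determinism_expander).lower()}",
    ]
    return " ".join(flags)


# Keep deterministic ops on, but fall back to the original ScatterExpander path.
new_flags = build_xla_flags(deterministic_ops=True, scatter_determinism_expander=False)

# Append to any flags that are already set rather than overwriting them.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = f"{existing} {new_flags}".strip()

# The variable has to be set before JAX initializes its backend, so import jax
# only after this point, for example:
# import jax
print(os.environ["XLA_FLAGS"])
```

Leaving `scatter_determinism_expander` at its default of true keeps the new prefix-scan-based rewrite active. Setting it to false is only meant as a fallback if the new pass misbehaves.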
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded to be deterministic in xla/service/ScatterExpander.cc. However, because that expansion took a while-loop-based approach, its performance was extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

This change is on top of #17886 and fixes the issues reported in the reverted PR #18326. The problem was that the changes in #18326 could not handle several complicated but realistic kinds of scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all the different kinds of dimension numbers.

This PR also adds an option, `xla_gpu_enable_scatter_determinism_expander`, which defaults to true. If the changes in this PR cause a problem, which is unlikely, users can disable the `scatter_determinism_expander` pass and avoid being blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844

Copybara import of the project:

--
3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded to be deterministic in xla/service/ScatterExpander.cc. However, because that expansion took a while-loop-based approach, its performance was extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844

Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <chenhaoj@nvidia.com>:

Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly.

--
1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <chenhaoj@nvidia.com>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696790875
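For anyone who needs the escape hatch mentioned in the commit message, here is a minimal sketch of turning the new pass off from Python. It assumes the option is passed through the `XLA_FLAGS` environment variable, the usual route for XLA debug options, and that the installed jaxlib is recent enough to know the flag. `xla_gpu_deterministic_ops` is the flag from this issue's title, and the helper `add_xla_flag` is a hypothetical name.

```python
import os


def add_xla_flag(current, flag):
    """Append `flag` to a space-separated XLA_FLAGS string.

    Any earlier setting of the same flag name is dropped, so the result
    contains it exactly once. (Hypothetical helper, for illustration.)
    """
    name = flag.split("=", 1)[0]
    kept = [f for f in current.split() if f.split("=", 1)[0] != name]
    return " ".join(kept + [flag])


flags = os.environ.get("XLA_FLAGS", "")
# Request deterministic GPU ops (the flag named in this issue's title).
flags = add_xla_flag(flags, "--xla_gpu_deterministic_ops=true")
# Opt out of the new expander pass; per the commit message it defaults to true.
flags = add_xla_flag(flags, "--xla_gpu_enable_scatter_determinism_expander=false")

# Set this before importing jax, so the backend sees it when it initializes.
os.environ["XLA_FLAGS"] = flags
```

Leaving the second flag at its default keeps the new pass enabled, which is the configuration the PR is meant to speed up.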
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR openxla/xla#19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR openxla/xla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. 
Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52 PiperOrigin-RevId: 696790875
… scatter operations Imported from GitHub PR openxla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes openxla#18326 COPYBARA_INTEGRATE_REVIEW=openxla#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 691023328
Description
The issue is 1) about a rather significant slow-down to the `scatter_add` operation when running jax with the `xla_gpu_deterministic_ops=true` flag, and 2) about a further disproportionately large slow-down when using `vmap` around a `scatter_add` operation.

Below is the code to reproduce the issue. The timings are run with and without prepending `os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"` at the start of the script.

Firstly, just a regular `scatter_add` benchmark:

Secondly, a `scatter_add` benchmark with `vmap`:

It seems pretty unexpected that, when the `xla_gpu_deterministic_ops` flag is set to true, calling `scatter_add` with `vmap` with a batch-size of 100 makes the runtime 377x longer, i.e. 3.7 times slower than just using a manual python `for`-loop.

Unrelatedly, although the slow-down of `scatter_add` is to be expected when enforcing determinism, it is rather severe (almost 2000x slower without `vmap`, and over 200000x slower with `vmap`).

I guess this operation doesn't come up very regularly, but it appears, for example, in the backward pass through a bilinear interpolation of an image (e.g. when using `jax.scipy.ndimage.map_coordinates`). Even if the `vmap` issue gets resolved, it would be absolutely fantastic if, in addition, there was some kind of warning about the potential impact on runtime that was shown when executing code with `--xla_gpu_deterministic_ops=true`.

What jax/jaxlib version are you using?
jax v0.4.16, jaxlib v0.4.16+cuda12.cudnn89
Which accelerator(s) are you using?
GPU
Additional system info
Linux, Ubuntu 22.04.3 LTS, Python 3.11.3
NVIDIA GPU info
Reproduced on a 3090 as well.
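The benchmark code from the description did not survive on this page. As a rough sketch of the comparison it describes (a `vmap`-ed `scatter_add` versus a manual Python loop over the batch), something like the following could be used. The body of `scatter_add`, the array sizes, and the repeat count are assumptions chosen for illustration and are not the original reproduction script. Only the `vmap`/`jit` wrapping and the `XLA_FLAGS` line follow what is quoted in the thread.

```python
import os

# Uncomment to time the deterministic variant. This must run before
# jax is imported, as the flag is read when the backend initialises.
# os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"

import timeit

import jax
import jax.numpy as jnp


def scatter_add(operand, updates, indices):
    # Assumed stand-in for the issue's benchmark function.
    return operand.at[indices].add(updates)


# Assumed sizes; the issue used a batch size of 100.
batch, size, n_updates = 8, 1024, 4096
key_u, key_i = jax.random.split(jax.random.PRNGKey(0))
operand = jnp.zeros((batch, size))
updates = jax.random.normal(key_u, (batch, n_updates))
indices = jax.random.randint(key_i, (batch, n_updates), 0, size)

scatter_add_jit = jax.jit(scatter_add)
scatter_add_batched_jit = jax.jit(
    jax.vmap(scatter_add, in_axes=(0, 0, 0), out_axes=0)
)


def run_vmap():
    return scatter_add_batched_jit(operand, updates, indices).block_until_ready()


def run_loop():
    return [
        scatter_add_jit(operand[i], updates[i], indices[i]).block_until_ready()
        for i in range(batch)
    ]


# Warm up so compilation time is excluded from the measurements.
run_vmap()
run_loop()

repeats = 10
t_vmap = timeit.timeit(run_vmap, number=repeats) / repeats
t_loop = timeit.timeit(run_loop, number=repeats) / repeats
print(f"vmap: {t_vmap:.6f} s per call, Python loop: {t_loop:.6f} s per call")
```

Running the script once as written and once with the `XLA_FLAGS` line uncommented gives the four timings the description compares. The absolute numbers will depend on the GPU, the jaxlib version, and the assumed sizes above.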