
vmap with scatter_add extremely slow when using xla_gpu_deterministic_ops #17844

Open
Labels: NVIDIA GPU, performance, XLA


BrunoKM commented Sep 29, 2023

Description

This issue is about 1) a rather significant slow-down of the scatter_add operation when running JAX with the xla_gpu_deterministic_ops=true flag, and 2) a further, disproportionately large slow-down when wrapping a scatter_add operation in vmap.

Below is the code to reproduce the issue. The timings are run with and without prepending os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true" at the start of the script.
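For concreteness, here is a minimal sketch of how the flag is toggled for the "with" timings (the environment variable has to be set before JAX initializes its backends, i.e. before the first import of jax):

import os
# Enable deterministic ops for the "with" timings; omit this line for the "without" timings.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"

import jax
import jax.numpy as jnp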

Firstly, just a regular scatter_add benchmark:

import jax
import jax.numpy as jnp

def scatter_add(
    operand,  # [operand_size]
    updates,  # [updates_size]
    indices,  # [updates_size, 1]
):
    # Define dimension numbers
    update_window_dims = tuple()
    inserted_window_dims = (0,)
    scatter_dims_to_operand_dims = (0,)
    res = jax.lax.scatter_add(
        operand,
        indices,
        updates,
        dimension_numbers=jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims),
        mode="drop",
    )
    return res

operand_size = 64 * 64  # e.g. a 64x64 image

operand = jnp.zeros((operand_size,))
updates = jnp.ones((operand_size * 4))
rng = jax.random.PRNGKey(0)
indices = jax.random.randint(rng, shape=(operand_size * 4, 1), minval=0, maxval=operand_size)

scatter_add_jit = jax.jit(scatter_add)
scatter_add_jit(operand, updates, indices).block_until_ready()
%timeit scatter_add_jit(operand, updates, indices).block_until_ready()
# Without: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 25.3 µs ± 81 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 46.1 ms ± 4.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Secondly, a scatter_add benchmark with vmap:

n_batches = 100

operand = jnp.zeros((n_batches, operand_size,))
updates = jnp.ones((n_batches, operand_size * 4))
rng = jax.random.PRNGKey(0)
indices = jax.random.randint(rng, shape=(n_batches, operand_size * 4, 1), minval=0, maxval=operand_size)

scatter_add_batched = jax.vmap(scatter_add, in_axes=(0, 0, 0), out_axes=0)

scatter_add_batched_jit = jax.jit(scatter_add_batched)
scatter_add_batched_jit(operand, updates, indices).block_until_ready()
%timeit scatter_add_batched_jit(operand, updates, indices).block_until_ready()
# Without: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 79.7 µs ± 173 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 17.4 s ± 61.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

It seems pretty unexpected that, when the xla_gpu_deterministic_ops flag is set to true, calling scatter_add under vmap with a batch size of 100 makes the runtime 377x longer (17.4 s vs. 46.1 ms), i.e. about 3.7 times slower than just running the un-vmapped function in a manual Python for-loop over the batch.

Separately, although some slow-down of scatter_add is to be expected when enforcing determinism, it is rather severe (almost 2000x slower without vmap, and over 200000x slower with vmap).
I guess this operation doesn't come up very often, but it appears, for example, in the backward pass through a bilinear interpolation of an image (e.g. when using jax.scipy.ndimage.map_coordinates). Even if the vmap issue gets resolved, it would be fantastic if, in addition, some kind of warning about the potential runtime impact were shown when executing code with --xla_gpu_deterministic_ops=true.
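As a concrete (hypothetical) sketch of that use case, the snippet below takes the gradient of a bilinear interpolation via jax.scipy.ndimage.map_coordinates; the backward pass scatter-adds the incoming cotangents back onto the image pixels, so it hits the deterministic scatter path when the flag is set (array shapes here are made up):

import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

image = jnp.zeros((64, 64))
# Fractional pixel coordinates at which to sample the image, shape [2, n_points].
coords = jnp.stack([jnp.linspace(0.0, 63.0, 256), jnp.linspace(0.0, 63.0, 256)])

def interp_sum(image):
    # order=1 is (bi)linear interpolation.
    return map_coordinates(image, coords, order=1).sum()

# The VJP of the gather performed by the interpolation is a scatter-add onto the image.
grad_fn = jax.jit(jax.grad(interp_sum))
grad_fn(image).block_until_ready()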

What jax/jaxlib version are you using?

jax v0.4.16, jaxlib v0.4.16+cuda12.cudnn89

Which accelerator(s) are you using?

GPU

Additional system info

Linux, Ubuntu 22.04.3 LTS, Python 3.11.3

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090         On | 00000000:01:00.0 Off |                  Off |
| 30%   43C    P2              139W / 500W|  20400MiB / 24564MiB |    100%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

Reproduced on a 3090 as well.

@BrunoKM added the bug label on Sep 29, 2023
@hawkinsp added the performance, XLA, and NVIDIA GPU labels and removed the bug label on Oct 3, 2023
hawkinsp (Collaborator) commented Oct 3, 2023

At least at the moment I think this is expected: deterministic scatters are much slower on GPU because they eliminate any parallelism. XLA would need to emit different code for a faster deterministic scatter.

BrunoKM (Author) commented Oct 4, 2023

@hawkinsp thanks for the response; it absolutely makes sense that deterministic scatters should be slower.

Do you think it's expected that you should get an additional slow-down from vmap? I.e. that this:

scatter_add_batched = jax.vmap(scatter_add, in_axes=(0, 0, 0), out_axes=0)

scatter_add_batched_jit = jax.jit(scatter_add_batched)
scatter_add_batched_jit(operand, updates, indices).block_until_ready()
%timeit scatter_add_batched_jit(operand, updates, indices).block_until_ready()
>>> 17.4 s

is 2.7x slower than this:

scatter_add_jit = jax.jit(scatter_add)
scatter_add_jit(operand, updates, indices).block_until_ready()
%timeit [scatter_add_jit(operand[i], updates[i], indices[i]).block_until_ready() for i in range(len(operand))]
>>> 4.6 s

with os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true" enabled? I couldn't think of a reason this should be the case; both are deterministic, and I'd think that compiling with XLA should be at least as fast as doing the loop in Python.

j-towns (Contributor) commented May 17, 2024

I encountered what I'm fairly confident is the same vmap-related slowdown on TPU, profiled it, and discovered that while vmap of my function produces a scatter in the jaxpr, that op lowers to a while op which loops over the mapped axis in XLA, with the body of the while loop containing a dynamic-update-slice, not a scatter. Presumably, since XLA whiles cannot in general be parallelized, the compiler is unable to see this potential optimization. I don't know if this issue is common to the lowering of any 'ragged' scatter.

@BrunoKM I will try your Python loop workaround and see whether that improves my use-case for now.

Setup:

In [1]: from jax import jit, lax, vmap, make_jaxpr
In [2]: import jax.numpy as jnp
In [3]: operand = jnp.ones((3, 4, 5))
In [4]: updates = jnp.ones((3, 2, 5))
In [5]: starts = jnp.ones((3,), dtype='int32')
In [6]: from functools import partial
In [7]: f = partial(lax.dynamic_update_slice_in_dim, axis=0)

Printing the jaxpr; note that there is a single scatter op:

In [8]: make_jaxpr(vmap(f))(operand, updates, starts)
Out[8]: 
{ lambda ; a:f32[3,4,5] b:f32[3,2,5] c:i32[3]. let
    d:bool[3] = lt c 0
    e:i32[3] = add c 4
    f:i32[3] = select_n d c e
    g:i32[] = add 0 5
    h:i32[] = select_n False 0 g
    i:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] f
    j:i32[3,1] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 1)] h
    k:i32[3,2] = concatenate[dimension=1] i j
    l:i32[3,1] = iota[dimension=0 dtype=int32 shape=(3, 1)] 
    m:i32[3,3] = concatenate[dimension=1] l k
    n:f32[3,4,5] = scatter[
      dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2))
      indices_are_sorted=True
      mode=GatherScatterMode.CLIP
      unique_indices=True
      update_consts=()
      update_jaxpr=None
    ] a m b
  in (n,) }

Printing the compiled HLO; because it is long and hard to read, I've surrounded the relevant parts with ########:

In [9]: print(jit(vmap(f)).lower(operand, updates, starts).compile().as_text())
HloModule jit__unnamed_function_, entry_computation_layout={(f32[3,4,5]{2,1,0}, f32[3,2,5]{2,1,0}, s32[3]{0})->f32[3,4,5]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}

#############################################################
%fused_computation (param_0: f32[3,4,5], param_1.3: f32[3,2,5], param_2.7: s32[], param_3.8: pred[], param_4.8: s32[3,3]) -> f32[3,4,5] {
  %param_0 = f32[3,4,5]{2,1,0} parameter(0)
  %param_3.8 = pred[] parameter(3)
  %broadcast.18 = pred[1,2,5]{2,1,0} broadcast(pred[] %param_3.8), dimensions={}
  %param_1.3 = f32[3,2,5]{2,1,0} parameter(1)
  %param_2.7 = s32[] parameter(2)
  %constant.23 = s32[] constant(0)
  %dynamic-slice.7 = f32[1,2,5]{2,1,0} dynamic-slice(f32[3,2,5]{2,1,0} %param_1.3, s32[] %param_2.7, s32[] %constant.23, s32[] %constant.23), dynamic_slice_sizes={1,2,5}
  %param_4.8 = s32[3,3]{1,0} parameter(4)
  %dynamic-slice.8 = s32[1,3]{1,0} dynamic-slice(s32[3,3]{1,0} %param_4.8, s32[] %param_2.7, s32[] %constant.23), dynamic_slice_sizes={1,3}
  %slice.23 = s32[1,1]{1,0} slice(s32[1,3]{1,0} %dynamic-slice.8), slice={[0:1], [0:1]}
  %bitcast.6 = s32[] bitcast(s32[1,1]{1,0} %slice.23)
  %bitcast.7 = s32[3]{0} bitcast(s32[1,3]{1,0} %dynamic-slice.8)
  %slice.22 = s32[1]{0} slice(s32[3]{0} %bitcast.7), slice={[1:2]}
  %bitcast.5 = s32[] bitcast(s32[1]{0} %slice.22)
  %dynamic-slice.6 = f32[1,2,5]{2,1,0} dynamic-slice(f32[3,4,5]{2,1,0} %param_0, s32[] %bitcast.6, s32[] %bitcast.5, s32[] %constant.23), dynamic_slice_sizes={1,2,5}
  %select.1 = f32[1,2,5]{2,1,0} select(pred[1,2,5]{2,1,0} %broadcast.18, f32[1,2,5]{2,1,0} %dynamic-slice.7, f32[1,2,5]{2,1,0} %dynamic-slice.6)
  ###########################################################
  # Note the shape of the update array: it is 1 in the batch dimension
  ROOT %dynamic-update-slice.2 = f32[3,4,5]{2,1,0} dynamic-update-slice(f32[3,4,5]{2,1,0} %param_0, f32[1,2,5]{2,1,0} %select.1, s32[] %bitcast.6, s32[] %bitcast.5, s32[] %constant.23)
  ###########################################################
}
#############################################################

%and.reduce_sub_computation (lhs: pred[], rhs: pred[]) -> pred[] {
  %lhs = pred[] parameter(0)
  %rhs = pred[] parameter(1)
  ROOT %and = pred[] and(pred[] %lhs, pred[] %rhs)
}

%fused_computation.1 (param_0.4: s32[3]) -> pred[] {
  %constant.26 = s32[] constant(0)
  %broadcast.19 = s32[3]{0} broadcast(s32[] %constant.26), dimensions={}
  %param_0.4 = s32[3]{0} parameter(0)
  %compare.4 = pred[3]{0} compare(s32[3]{0} %broadcast.19, s32[3]{0} %param_0.4), direction=LE
  %constant.25 = s32[3]{0} constant({2, 2, 0})
  %compare.3 = pred[3]{0} compare(s32[3]{0} %constant.25, s32[3]{0} %param_0.4), direction=GE
  %and.2 = pred[3]{0} and(pred[3]{0} %compare.4, pred[3]{0} %compare.3)
  %constant.24 = pred[] constant(true)
  ROOT %reduce.1 = pred[] reduce(pred[3]{0} %and.2, pred[] %constant.24), dimensions={0}, to_apply=%and.reduce_sub_computation
}

%fused_computation.2 (param_0.7: s32[3,3], param_1.15: s32[]) -> s32[3] {
  %param_0.7 = s32[3,3]{1,0} parameter(0)
  %param_1.15 = s32[] parameter(1)
  %constant.27 = s32[] constant(0)
  %dynamic-slice.9 = s32[1,3]{1,0} dynamic-slice(s32[3,3]{1,0} %param_0.7, s32[] %param_1.15, s32[] %constant.27), dynamic_slice_sizes={1,3}
  %slice.25 = s32[1,1]{1,0} slice(s32[1,3]{1,0} %dynamic-slice.9), slice={[0:1], [0:1]}
  %bitcast.9 = s32[1]{0} bitcast(s32[1,1]{1,0} %slice.25)
  %bitcast.8 = s32[3]{0} bitcast(s32[1,3]{1,0} %dynamic-slice.9)
  %slice.24 = s32[2]{0} slice(s32[3]{0} %bitcast.8), slice={[1:3]}
  ROOT %concatenate.2 = s32[3]{0} concatenate(s32[1]{0} %bitcast.9, s32[2]{0} %slice.24), dimensions={0}
}

#############################################################
%while_body (param.1: (s32[], f32[3,4,5], s32[3,3], f32[3,2,5])) -> (s32[], f32[3,4,5], s32[3,3], f32[3,2,5]) {
  %param.1 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) parameter(0)
  %get-tuple-element.12 = s32[] get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=0
  %copy.3 = s32[] copy(s32[] %get-tuple-element.12)
  %constant.10 = s32[] constant(1)
  %add = s32[] add(s32[] %copy.3, s32[] %constant.10)
  %get-tuple-element.13 = f32[3,4,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=1
  %get-tuple-element.19 = f32[3,2,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=3
  %get-tuple-element.18 = s32[3,3]{1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=2
  %fusion.2 = s32[3]{0} fusion(s32[3,3]{1,0} %get-tuple-element.18, s32[] %copy.3), kind=kLoop, calls=%fused_computation.2
  %fusion.1 = pred[] fusion(s32[3]{0} %fusion.2), kind=kLoop, calls=%fused_computation.1
  ###########################################################
  %fusion = f32[3,4,5]{2,1,0} fusion(f32[3,4,5]{2,1,0} %get-tuple-element.13, f32[3,2,5]{2,1,0} %get-tuple-element.19, s32[] %copy.3, pred[] %fusion.1, s32[3,3]{1,0} %get-tuple-element.18), kind=kLoop, calls=%fused_computation
  ###########################################################
  ROOT %tuple.5 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) tuple(s32[] %add, f32[3,4,5]{2,1,0} %fusion, s32[3,3]{1,0} %get-tuple-element.18, f32[3,2,5]{2,1,0} %get-tuple-element.19)
}
#############################################################

%while_cond (param.0: (s32[], f32[3,4,5], s32[3,3], f32[3,2,5])) -> pred[] {
  %param.0 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.0), index=0
  %constant.1 = s32[] constant(3)
  ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.1), direction=LT
}

%fused_computation.3 (param_0.10: s32[3]) -> s32[3,3] {
  %constant.30 = s32[] constant(0)
  %broadcast.25 = s32[3,3]{1,0} broadcast(s32[] %constant.30), dimensions={}
  %iota.1 = s32[3,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(<unnamed function>)/jit(main)/iota[dtype=int32 shape=(3, 1) dimension=0]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %param_0.10 = s32[3]{0} parameter(0)
  %broadcast.24 = s32[3]{0} broadcast(s32[] %constant.30), dimensions={}
  %compare.5 = pred[3]{0} compare(s32[3]{0} %param_0.10, s32[3]{0} %broadcast.24), direction=LT, metadata={op_name="jit(<unnamed function>)/jit(main)/lt" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %constant.29 = s32[] constant(4)
  %broadcast.23 = s32[3]{0} broadcast(s32[] %constant.29), dimensions={}
  %add.1 = s32[3]{0} add(s32[3]{0} %param_0.10, s32[3]{0} %broadcast.23), metadata={op_name="jit(<unnamed function>)/jit(main)/add" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %select.2 = s32[3]{0} select(pred[3]{0} %compare.5, s32[3]{0} %add.1, s32[3]{0} %param_0.10), metadata={op_name="jit(<unnamed function>)/jit(main)/select_n" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %bitcast.10 = s32[3,1]{1,0} bitcast(s32[3]{0} %select.2), metadata={op_name="jit(<unnamed function>)/jit(main)/select_n" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %broadcast.22 = s32[3,1]{1,0} broadcast(s32[] %constant.30), dimensions={}
  %concatenate.3 = s32[3,3]{1,0} concatenate(s32[3,1]{1,0} %iota.1, s32[3,1]{1,0} %bitcast.10, s32[3,1]{1,0} %broadcast.22), dimensions={1}, metadata={op_name="jit(<unnamed function>)/jit(main)/concatenate[dimension=1]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %constant.28 = s32[3]{0} constant({2, 2, 0})
  %broadcast.21 = s32[3,3]{1,0} broadcast(s32[3]{0} %constant.28), dimensions={1}, metadata={op_name="jit(<unnamed function>)/jit(main)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=(1,)]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  ROOT %clamp.0 = s32[3,3]{1,0} clamp(s32[3,3]{1,0} %broadcast.25, s32[3,3]{1,0} %concatenate.3, s32[3,3]{1,0} %broadcast.21), metadata={op_name="jit(<unnamed function>)/jit(main)/clamp" source_file="<ipython-input-9-61375df8be79>" source_line=1}
}

#############################################################
ENTRY %main.26 (Arg_0.1: f32[3,4,5], Arg_1.2: f32[3,2,5], Arg_2.3: s32[3]) -> f32[3,4,5] {
  %constant.4 = s32[] constant(0)
  %copy.8 = s32[] copy(s32[] %constant.4)
  %Arg_0.1 = f32[3,4,5]{2,1,0} parameter(0), sharding={replicated}
  %copy.7 = f32[3,4,5]{2,1,0} copy(f32[3,4,5]{2,1,0} %Arg_0.1)
  %Arg_2.3 = s32[3]{0} parameter(2), sharding={replicated}
  %fusion.3 = s32[3,3]{1,0} fusion(s32[3]{0} %Arg_2.3), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(<unnamed function>)/jit(main)/clamp" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  %Arg_1.2 = f32[3,2,5]{2,1,0} parameter(1), sharding={replicated}
  %tuple.3 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) tuple(s32[] %copy.8, f32[3,4,5]{2,1,0} %copy.7, s32[3,3]{1,0} %fusion.3, f32[3,2,5]{2,1,0} %Arg_1.2)
  ###########################################################
  %while = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) while((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %tuple.3), condition=%while_cond, body=%while_body, metadata={op_name="jit(<unnamed function>)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
  ###########################################################
  ROOT %get-tuple-element.5 = f32[3,4,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %while), index=1, metadata={op_name="jit(<unnamed function>)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
}
#############################################################

j-towns (Contributor) commented May 17, 2024

Here's a more minimal, self-contained snippet which reproduces the vmap slowdown. On TPU the vmap version is around 2x slower than using a Python loop and jnp.stacking the result.

from timeit import timeit

from jax import jit, lax, vmap
import jax.numpy as jnp


# For f which outputs a single array, this simulates vmap using Python map
pymap = lambda f: lambda *args: jnp.stack(list(map(f, *args)))

operands = jnp.ones((100, 32))
updates = jnp.ones((100, 2))
starts = jnp.ones((100, 1), dtype='int32')

f = lax.dynamic_update_slice

f_vmapped = jit(vmap(f))
f_pymapped = jit(pymap(f))

# Ensure compiled
f_vmapped(operands, updates, starts)
f_pymapped(operands, updates, starts)

t_vmapped = timeit(
    lambda: f_vmapped(operands, updates, starts).block_until_ready(), number=100
) / 100

t_pymapped = timeit(
    lambda: f_pymapped(operands, updates, starts).block_until_ready(), number=100
) / 100

print(f"Time vmap(f): {t_vmapped:.2}s")
print(f"Time pymap(f): {t_pymapped:.2}s")

On a TPU v4-8 VM I get:

Time vmap(f): 0.00088s
Time pymap(f): 0.00036s

Running the script on CPU on my laptop, the Python loop version is instead slower than the vmap version:

Time vmap(f): 1.3e-05s
Time pymap(f): 3.3e-05s

jaro-sevcik (Contributor) commented Jul 20, 2024

I realize this is an older issue, but one option is to roll your own deterministic scatter_add (using prefix sums):

import jax
import jax.numpy as jp

def add_segment(iv, jt):
  # Combine function for a segmented sum over sorted indices:
  # add the left value into the right accumulator only if the indices match.
  i, v = iv
  j, t = jt
  return j, v * jp.equal(i, j) + t
  
@jax.jit
def scatter_add_det(operand, updates, indices):
  indices = jp.reshape(indices, updates.shape)
  # Sort the indices and the values by the indices.
  indices, sorted = jax.lax.sort_key_val(indices, updates, dimension=-1)
  # Sum up runs of the same index - the sum for each index will be at the end of each run.
  _, sums = jax.lax.associative_scan(add_segment, (indices, sorted))
  # Produce an array of bools - an element is set iff the position
  # is the end of a run of the same index.
  end_of_run = jp.concatenate([jp.not_equal(indices[1:], indices[:-1]), jp.array([True])])
  # Set all positions that are not at the end of a run to an out-of-bounds index.
  indices = jp.where(end_of_run, indices, operand.shape[-1])
  # Now do scatter-add where we know the (in-bounds) indices are unique.
  # That is still fast on GPUs (no non-determinism from atomics).
  return operand.at[indices].add(sums, mode='drop', unique_indices=True)

This is 5-15x slower than the non-deterministic one (depending on the shapes involved), but at least it's not multiple orders of magnitude. It would be nice if XLA could lower to something like this automatically.
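For what it's worth, a quick (hypothetical) sanity check of the workaround against the built-in .at[].add, reusing the shapes from the original benchmark and assuming scatter_add_det from above is in scope:

import jax
import jax.numpy as jp

operand_size = 64 * 64
operand = jp.zeros((operand_size,))
updates = jp.ones((operand_size * 4,))
rng = jax.random.PRNGKey(0)
indices = jax.random.randint(rng, shape=(operand_size * 4, 1), minval=0, maxval=operand_size)

expected = operand.at[indices[:, 0]].add(updates)    # non-deterministic path on GPU without the flag
actual = scatter_add_det(operand, updates, indices)  # prefix-sum workaround from above
print(jp.max(jp.abs(expected - actual)))             # should print 0.0 (sums are small integers, exact in float32)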

copybara-service bot pushed a commit to openxla/xla that referenced this issue Oct 17, 2024
Imported from GitHub PR #17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op is expanded into a deterministic form in `xla/service/ScatterExpander.cc`; however, since that takes a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation into an efficient deterministic scatter. This PR rewrites scatter operations with scalar indices and updates, and leaves the other scatter operations to be handled by the original ScatterExpander. The 2nd PR, to come, will handle non-scalar indices and updates.

The second PR is at #18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615 by Chenhao Jiang <chenhaoj@nvidia.com>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows a significant speedup, especially with larger index sizes, achieving over 9,300x speedup for certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter: for most cases around 1x-4x, with a worst-case factor of 9.15 for a rare index-size setup.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615
PiperOrigin-RevId: 686779279
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Oct 17, 2024
Imported from GitHub PR openxla/xla#17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at openxla/xla#18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615ed047b28405a0634c42f741a678be605a by Chenhao Jiang <chenhaoj@nvidia.com>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615ed047b28405a0634c42f741a678be605a
PiperOrigin-RevId: 686779279
copybara-service bot pushed a commit to openxla/xla that referenced this issue Oct 17, 2024
Imported from GitHub PR #17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at #18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615 by Chenhao Jiang <chenhaoj@nvidia.com>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615
PiperOrigin-RevId: 686779279
copybara-service bot pushed a commit to openxla/xla that referenced this issue Oct 17, 2024
Imported from GitHub PR #17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at #18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615 by Chenhao Jiang <chenhaoj@nvidia.com>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615
PiperOrigin-RevId: 686779279
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Oct 17, 2024
Imported from GitHub PR openxla/xla#17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at openxla/xla#18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615ed047b28405a0634c42f741a678be605a by Chenhao Jiang <chenhaoj@nvidia.com>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615ed047b28405a0634c42f741a678be605a
PiperOrigin-RevId: 686779279
copybara-service bot pushed a commit to openxla/xla that referenced this issue Oct 17, 2024
Imported from GitHub PR #17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at #18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615 by Chenhao Jiang <chenhaoj@nvidia.com>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |
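
For reference, the two ratio columns appear to be simple quotients of the timing columns (times in seconds). A small sanity check against the (input size 10, index size 1,000,000) row, with that row's values hard-coded purely for illustration:

# Values copied from the table row: input size 10, index size 1,000,000.
non_det, original_det, new_det = 4.31e-04, 4.17e+00, 4.45e-04

slowdown_vs_non_det = new_det / non_det           # ~1.03, matching the table entry
speedup_vs_original_det = original_det / new_det  # ~9.4e3, matching the ~9,372x entry
print(slowdown_vs_non_det, speedup_vs_original_det)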

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615
PiperOrigin-RevId: 686779279
copybara-service bot pushed a commit to openxla/xla that referenced this issue Oct 28, 2024
…r operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.
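
To make the prefix-scan idea concrete, here is a minimal JAX-level sketch of a sort-plus-scan scatter-add over a 1D operand. It only illustrates the approach described above, not the actual HLO rewrite the pass performs, and the helper name is made up for this example:

import jax.numpy as jnp

def deterministic_scatter_add_1d(operand, indices, updates):
    # indices: 1D int array of target positions; updates: 1D array of addends.
    n_out = operand.shape[0]
    order = jnp.argsort(indices)                # stable sort groups duplicate indices
    sorted_idx = indices[order]
    sorted_upd = updates[order]

    inclusive = jnp.cumsum(sorted_upd)          # inclusive prefix scan
    exclusive = inclusive - sorted_upd          # exclusive prefix scan

    # Mark the first and last position of every run of equal indices.
    is_last = jnp.concatenate([sorted_idx[:-1] != sorted_idx[1:], jnp.array([True])])
    is_first = jnp.concatenate([jnp.array([True]), sorted_idx[1:] != sorted_idx[:-1]])

    oob = n_out  # out-of-range sentinel; writes to it are dropped below
    seg_end = jnp.zeros_like(operand).at[
        jnp.where(is_last, sorted_idx, oob)].set(inclusive, mode="drop")
    seg_start = jnp.zeros_like(operand).at[
        jnp.where(is_first, sorted_idx, oob)].set(exclusive, mode="drop")

    # Per-index sum = (scan value at the end of its run) - (scan value just before it).
    return operand + seg_end - seg_start

Because every duplicate index is reduced in the fixed order produced by the sort and the scan, the result no longer depends on which GPU thread wins a race on a given output element, which is what makes the rewrite deterministic.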

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 690490783
copybara-service bot pushed a commit to openxla/xla that referenced this issue Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change in this PR is on top of #17886 and fixes the issues reported in the reverted PR #18326: the changes in #18326 could not handle various complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of the 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all the different kinds of dimension numbers.
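
As a hedged illustration of what non-default scatter dimension numbers look like from the JAX side (this is not one of the PR's test cases, just an example of the kind of layout the expander has to map back to operand space), here is a scatter-add into the columns of a 2D operand instead of its rows:

import jax
import jax.numpy as jnp

operand = jnp.zeros((3, 4))
indices = jnp.array([[1], [3]])   # column indices, shape [num_updates, 1]
updates = jnp.ones((2, 3))        # one length-3 column per index

# Index component 0 addresses operand dimension 1 (columns), while the single
# update window dimension covers operand dimension 0 (rows).
dnums = jax.lax.ScatterDimensionNumbers(
    update_window_dims=(1,),
    inserted_window_dims=(1,),
    scatter_dims_to_operand_dims=(1,),
)
out = jax.lax.scatter_add(operand, indices, updates, dimension_numbers=dnums, mode="drop")
# Equivalent to: operand.at[:, indices[:, 0]].add(updates.T)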

Moreover, this PR adds an option, `xla_gpu_enable_scatter_determinism_expander`, whose default value is true. It ensures that, in the unlikely event something goes wrong with the changes in this PR, users can easily disable the `scatter_determinism_expander` pass without being blocked.
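
For anyone following this issue from the JAX side, a minimal sketch of how that escape hatch could be used, assuming both options are plumbed through the usual XLA_FLAGS environment variable (set before jax is imported):

import os

# Keep deterministic ops, but fall back to the old while-loop expansion by
# disabling the new pass (only needed if the rewrite misbehaves).
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_deterministic_ops=true "
    "--xla_gpu_enable_scatter_determinism_expander=false"
)

import jax  # imported after setting XLA_FLAGS so the GPU backend picks them up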

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <chenhaoj@nvidia.com>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <chenhaoj@nvidia.com>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <chenhaoj@nvidia.com>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <chenhaoj@nvidia.com>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <chenhaoj@nvidia.com>:

Fix the scatter determinism expander for various dimension numbers

--
985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <chenhaoj@nvidia.com>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to openxla/xla that referenced this issue Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <chenhaoj@nvidia.com>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <chenhaoj@nvidia.com>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to openxla/xla that referenced this issue Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <chenhaoj@nvidia.com>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <chenhaoj@nvidia.com>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to openxla/xla that referenced this issue Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <chenhaoj@nvidia.com>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <chenhaoj@nvidia.com>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to openxla/xla that referenced this issue Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <chenhaoj@nvidia.com>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <chenhaoj@nvidia.com>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <chenhaoj@nvidia.com>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <chenhaoj@nvidia.com>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <chenhaoj@nvidia.com>:

Fix the scatter determinism expander for various dimension numbers

--
985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <chenhaoj@nvidia.com>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to openxla/xla that referenced this issue Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <chenhaoj@nvidia.com>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.
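
(As a concrete picture of what scatter_dims_to_operand_dims controls: each column of an index vector maps to one operand dimension, so a permuted mapping changes which operand axis each column addresses. The toy example below is an illustration, not code from the PR.)

import jax
import jax.numpy as jnp

operand = jnp.zeros((3, 4))
updates = jnp.ones((2,))
# With scatter_dims_to_operand_dims=(1, 0), column 0 of each index vector
# addresses operand dimension 1 and column 1 addresses operand dimension 0.
indices = jnp.array([[2, 0],   # -> operand[0, 2]
                     [1, 2]])  # -> operand[2, 1]
dnums = jax.lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0, 1),
    scatter_dims_to_operand_dims=(1, 0),
)
jax.lax.scatter_add(operand, indices, updates, dimension_numbers=dnums)
>>> Array([[0., 0., 1., 0.],
           [0., 0., 0., 0.],
           [0., 1., 0., 0.]], dtype=float32)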

--
1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <chenhaoj@nvidia.com>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to openxla/xla that referenced this issue Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

PiperOrigin-RevId: 696790875
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

PiperOrigin-RevId: 696790875
serach24 added a commit to serach24/xla that referenced this issue Nov 15, 2024
… scatter operations

Imported from GitHub PR openxla#18326

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes openxla#18326

COPYBARA_INTEGRATE_REVIEW=openxla#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 691023328
copybara-service bot pushed a commit to openxla/xla that referenced this issue Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to openxla/xla that referenced this issue Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Merging this change closes #19275

COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696956113
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

PiperOrigin-RevId: 696956113