Conversation

@avik-pal avik-pal commented Feb 13, 2025

Strangely enough, dot_general is giving me incorrect results when using sharding. Every other operator I tested with the exact same sharding setup gives correct results.

Once the JLL builds, I will test it out on a TPU pod to verify this isn't some weird behavior originating from --xla_force_host_platform_device_count=8
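
For local testing, a minimal sketch of that setup (assuming, as is usual for XLA, that the flag is picked up from the XLA_FLAGS environment variable set before Reactant is loaded; this snippet is illustrative, not part of the PR):

# Hypothetical local-CPU setup: pretend there are 8 host devices so the
# sharding path can be exercised without a TPU pod. Must be set before
# Reactant is loaded.
ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

using Reactant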

This should also unblock PRONTOLab/GB-25#8 (comment)

@avik-pal avik-pal mentioned this pull request Feb 14, 2025
@avik-pal (Collaborator, Author):

Forgot to export some names for macOS. Will fix in the next JLL.

@avik-pal (Collaborator, Author):

An even simpler sharding case that gives incorrect results:

# Currently an extremely simple test
using Reactant, Test

const addressable_devices = Reactant.addressable_devices()

mesh = Sharding.Mesh(reshape(collect(Int64, 0:3), (2, 2)), ("data", "model"))

# samples_sharding = Sharding.NamedSharding(mesh, (nothing, "data"))
w1_sharding = Sharding.NamedSharding(mesh, ("model", nothing))
# w2_sharding = Sharding.NamedSharding(mesh, ("data", nothing))

# samples = reshape(collect(Float32, 1:84), 7, 12)
w1 = reshape(collect(Float32, 1:4), 2, 2)
w2 = reshape(collect(Float32, 1:4), 2, 2)

w1_ra = Reactant.to_rarray(w1; sharding=w1_sharding)
w2_ra = Reactant.to_rarray(w2; sharding=w1_sharding)

@code_xla *(w2_ra, w1_ra)

# @jit *(w2_ra, w1_ra)
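
A minimal sketch of the corresponding correctness check (hypothetical, not part of the PR; it assumes the usual Array conversion of a ConcreteRArray back to a host array):

# Reference result on plain Julia arrays.
w_ref = w2 * w1

# Hypothetical check: materialize the sharded result and compare. With the
# dot_general sharding bug present, this comparison is expected to fail.
res = @jit *(w2_ra, w1_ra)
@test Array(res) ≈ w_ref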

@avik-pal (Collaborator, Author):

julia> @jit fn_test2(x_ra)
2025-02-14 10:32:46.155582: I external/xla/xla/service/spmd/shardy/shardy_xla_pass.cc:306] Using Shardy for XLA SPMD propagation.
2025-02-14 10:32:46.229661: I external/xla/xla/hlo/utils/hlo_sharding_util.cc:3063] There is no registered layout_canonicalization_callback.
4×4 ConcreteRArray{Float32, 2, 8, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, 8, Tuple{Nothing, Nothing}, 2}, NTuple{8, Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}:
  2.0   8.0  11.0  17.0
  8.0  14.0  17.0  23.0
 11.0  17.0  20.0  26.0
 17.0  23.0  26.0  32.0

julia> fn_test2(x)
4×4 Matrix{Float32}:
  2.0   7.0  12.0  17.0
  7.0  12.0  17.0  22.0
 12.0  17.0  22.0  27.0
 17.0  22.0  27.0  32.0

julia> @code_xla fn_test2(x_ra)
2025-02-14 10:33:00.715410: I external/xla/xla/service/spmd/shardy/shardy_xla_pass.cc:306] Using Shardy for XLA SPMD propagation.
2025-02-14 10:33:00.788563: I external/xla/xla/hlo/utils/hlo_sharding_util.cc:3063] There is no registered layout_canonicalization_callback.
HloModule reactant_fn_test2, is_scheduled=true, entry_computation_layout={(f32[2,1]{1,0})->f32[4,4]{1,0}}, num_partitions=8

%fused_computation (param_0.2: f32[4,4], param_1: f32[4,4]) -> f32[4,4] {
  %param_0.2 = f32[4,4]{1,0} parameter(0)
  %param_1 = f32[4,4]{1,0} parameter(1)
  %add.2 = f32[4,4]{1,0} add(f32[4,4]{1,0} %param_0.2, f32[4,4]{1,0} %param_1), metadata={op_name="add" source_file="/home/avik-pal/reactant/Reactant.jl/src/Ops.jl" source_line=266}
  %transpose.6 = f32[4,4]{0,1} transpose(f32[4,4]{1,0} %add.2), dimensions={1,0}, metadata={op_name="transpose.4"}
  ROOT %copy.4 = f32[4,4]{1,0} copy(f32[4,4]{0,1} %transpose.6), metadata={op_name="transpose.4"}
}

ENTRY %main.0_spmd (param: f32[2,1]) -> f32[4,4] {
  %param = f32[2,1]{1,0} parameter(0), sharding={devices=[2,4]<=[4,2]T(1,0)}, metadata={op_name="Arg_0.1"}
  %bitcast = f32[1,2]{0,1} bitcast(f32[2,1]{1,0} %param), metadata={op_name="Arg_0.1"}
  %bitcast.2 = f32[2,1]{0,1} bitcast(f32[2,1]{1,0} %param), sharding={devices=[2,4]<=[4,2]T(1,0)}, metadata={op_name="Arg_0.1"}
  %all-gather = f32[1,4]{0,1} all-gather(f32[1,2]{0,1} %bitcast), channel_id=1, replica_groups=[4,2]<=[8], dimensions={1}, use_global_device_ids=true, metadata={op_name="add" source_file="/home/avik-pal/reactant/Reactant.jl/src/Ops.jl" source_line=266}
  %all-gather.2 = f32[2,4]{0,1} all-gather(f32[2,1]{0,1} %bitcast.2), channel_id=3, replica_groups=[2,4]<=[4,2]T(1,0), dimensions={1}, use_global_device_ids=true, metadata={op_name="add" source_file="/home/avik-pal/reactant/Reactant.jl/src/Ops.jl" source_line=266}
  %bitcast.1 = f32[1,4]{1,0} bitcast(f32[1,4]{0,1} %all-gather), metadata={op_name="add" source_file="/home/avik-pal/reactant/Reactant.jl/src/Ops.jl" source_line=266}
  %copy.2 = f32[2,4]{1,0} copy(f32[2,4]{0,1} %all-gather.2), metadata={op_name="add" source_file="/home/avik-pal/reactant/Reactant.jl/src/Ops.jl" source_line=266}
  %all-gather.1 = f32[4,4]{1,0} all-gather(f32[1,4]{1,0} %bitcast.1), channel_id=2, replica_groups=[2,4]<=[4,2]T(1,0), dimensions={0}, use_global_device_ids=true, metadata={op_name="add" source_file="/home/avik-pal/reactant/Reactant.jl/src/Ops.jl" source_line=266}
  %all-gather.3 = f32[4,4]{1,0} all-gather(f32[2,4]{1,0} %copy.2), channel_id=4, replica_groups=[4,2]<=[8], dimensions={0}, use_global_device_ids=true, metadata={op_name="add" source_file="/home/avik-pal/reactant/Reactant.jl/src/Ops.jl" source_line=266}
  ROOT %transpose_copy_fusion = f32[4,4]{1,0} fusion(f32[4,4]{1,0} %all-gather.1, f32[4,4]{1,0} %all-gather.3), kind=kLoop, calls=%fused_computation, metadata={op_name="transpose.4"}
}

@avik-pal (Collaborator, Author):

using Reactant

foo(x) = x .+ x'

x = reshape(collect(Float32, 1:4), 2, 2)

x_ra = Reactant.to_rarray(
    x;
    sharding=Sharding.NamedSharding(
        Sharding.Mesh(reshape(collect(Int64, 0:3), (2, 2)), ("data", "model")),
        ("data", nothing),
    ),
)

@code_xla foo(x_ra)

@jit foo(x_ra)
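
A minimal sketch of the check for this reproducer (hypothetical, assuming the sharded result converts back with Array):

using Test  # only needed for @test

# Compare the sharded and unsharded results. With the transpose/add sharding
# bug shown above, these are expected to differ.
res_sharded = @jit foo(x_ra)
@test Array(res_sharded) ≈ foo(x)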

Comment on lines +168 to +172
tmp = Reactant.ConcreteRArray(
    ones(sharding_and_shape.shape); sharding=LazySharding(sharding_and_shape.sharding)
)
_, exec, _, _, _ = Reactant.Compiler.compile_xla(internal_simple_op, (tmp,))
return XLA.CondensedOpSharding(only(XLA.get_parameter_shardings(exec)))
@avik-pal (Collaborator, Author):

This is not the most ideal solution, but it is guaranteed to be correct. After GB I will see if there is a nicer way to do this.
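
To spell the trick out, a hypothetical restatement of the snippet above (the names trivial_op and resolved_op_sharding are placeholders, not Reactant API): compile a trivial function whose only argument carries the requested lazy sharding, then read back the sharding XLA actually assigned to that parameter.

# Hypothetical restatement; trivial_op and resolved_op_sharding are placeholder
# names, not part of this PR.
trivial_op(x) = x .+ 1

function resolved_op_sharding(shape, sharding)
    # Dummy array carrying the requested (lazy) sharding.
    dummy = Reactant.ConcreteRArray(ones(shape); sharding=LazySharding(sharding))
    # Compile the trivial op so XLA resolves the parameter sharding...
    _, exec, _, _, _ = Reactant.Compiler.compile_xla(trivial_op, (dummy,))
    # ...and read it back from the compiled executable.
    return XLA.CondensedOpSharding(only(XLA.get_parameter_shardings(exec)))
end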

@avik-pal (Collaborator, Author):

Tests pass locally. We need a new JLL before CI is green.

@avik-pal avik-pal marked this pull request as ready for review February 16, 2025 17:42
@avik-pal avik-pal requested a review from wsmoses February 16, 2025 20:49
@wsmoses wsmoses merged commit 20f7a3c into main Feb 16, 2025
35 of 39 checks passed
@wsmoses wsmoses deleted the ap/wider_support_sharding branch February 16, 2025 22:39