Conversation

@avik-pal avik-pal commented Feb 28, 2025

Unless I am missing something obvious, XLA should not be modifying input shardings unless we explicitly opt in to that.
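
For context, here is a minimal sketch of the presumed setup; the body of fn_test is an assumption inferred from the emitted HLO, and the Reactant sharding constructors used to build x_ra / x_ra2 are omitted.

# Presumed test function: a row-wise sum, matching the stablehlo.reduce over
# dimension 1 followed by the reshape in the HLO below (an assumption, not from the PR).
fn_test(x) = sum(x; dims=2)

# Mesh: ["data"=2, "model"=4]; both inputs are sharded as [{"model"}, {"data"}],
# i.e. dim 1 split over "model" and dim 2 split over "data".
# x_ra2: sharded 8x2 Float32 array (Case I, both dims divide evenly).
# x_ra:  sharded 7x2 Float32 array (Case II, dim 1 does not divide by "model"=4).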

Case I: Divisible Dimensions

As expected, the per-device inputs are of size [2, 1]:

julia> @code_hlo fn_test(x_ra2)
module @reactant_fn_test attributes {mhlo.num_partitions = 8 : i64, mhlo.num_replicas = 1 : i64} {
  sdy.mesh @mesh = <["data"=2, "model"=4]>
  func.func @main(%arg0: tensor<8x2xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"model"}, {"data"}]>}) -> tensor<8x1xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.reduce(%arg0 init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<8x2xf32>, tensor<f32>) -> tensor<8xf32>
    %1 = stablehlo.reshape %0 : (tensor<8xf32>) -> tensor<8x1xf32>
    return %1 : tensor<8x1xf32>
  }
}

julia> @code_xla fn_test(x_ra2)
2025-02-28 12:12:18.743942: I external/xla/xla/service/spmd/shardy/shardy_xla_pass.cc:304] Using Shardy for XLA SPMD propagation.
2025-02-28 12:12:18.746532: I external/xla/xla/hlo/utils/hlo_sharding_util.cc:3025] There is no registered layout_canonicalization_callback.
HloModule reactant_fn_test, is_scheduled=true, entry_computation_layout={(f32[2,1]{1,0})->f32[2,1]{1,0}}, num_partitions=8

%region_0.0.clone (Arg_0.3: f32[], Arg_1.1: f32[]) -> f32[] {
  %Arg_0.3 = f32[] parameter(0), metadata={op_name="reduce.7"}
  %Arg_1.1 = f32[] parameter(1), metadata={op_name="reduce.7"}
  ROOT %add.1 = f32[] add(f32[] %Arg_0.3, f32[] %Arg_1.1), metadata={op_name="add" source_file="/mnt/software/lux/Reactant.jl/src/Ops.jl" source_line=267}
}

ENTRY %main.11_spmd (param: f32[2,1]) -> f32[2,1] {
  %param = f32[2,1]{1,0} parameter(0), sharding={devices=[4,2]<=[2,4]T(1,0)}, metadata={op_name="Arg_0.1"}
  %bitcast = f32[2]{0} bitcast(f32[2,1]{1,0} %param), metadata={op_name="Arg_0.1"}
  %all-reduce = f32[2]{0} all-reduce(f32[2]{0} %bitcast), channel_id=1, replica_groups=[4,2]<=[2,4]T(1,0), use_global_device_ids=true, to_apply=%region_0.0.clone, metadata={op_name="reduce.7"}
  ROOT %bitcast.1 = f32[2,1]{1,0} bitcast(f32[2]{0} %all-reduce), metadata={op_name="reduce.7"}
}
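
A quick sanity check on Case I (the arithmetic below is mine, not from the PR): splitting dim 1 over "model"=4 and dim 2 over "data"=2 gives 2x1 per-device shards, matching both the f32[2,1] entry computation layout and the devices=[4,2]<=[2,4]T(1,0) sharding on %param.

# Per-device shard shape for Case I: global (8, 2) split over ("model"=4, "data"=2).
cld.((8, 2), (4, 2))  # (2, 1)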

Case II: Non-Divisible Dimensions

The expected per-device input shape is [2, 1] (with 2 replicas requiring padded inputs):

julia> @code_hlo fn_test(x_ra)
module @reactant_fn_test attributes {mhlo.num_partitions = 8 : i64, mhlo.num_replicas = 1 : i64} {
  sdy.mesh @mesh = <["data"=2, "model"=4]>
  func.func @main(%arg0: tensor<7x2xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"model"}, {"data"}]>}) -> tensor<7x1xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.reduce(%arg0 init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<7x2xf32>, tensor<f32>) -> tensor<7xf32>
    %1 = stablehlo.reshape %0 : (tensor<7xf32>) -> tensor<7x1xf32>
    return %1 : tensor<7x1xf32>
  }
}

julia> @code_xla fn_test(x_ra)
2025-02-28 12:12:30.796850: I external/xla/xla/service/spmd/shardy/shardy_xla_pass.cc:304] Using Shardy for XLA SPMD propagation.
2025-02-28 12:12:30.799581: I external/xla/xla/hlo/utils/hlo_sharding_util.cc:3025] There is no registered layout_canonicalization_callback.
HloModule reactant_fn_test, is_scheduled=true, entry_computation_layout={(f32[7,1]{1,0})->f32[7,1]{1,0}}, num_partitions=8

%region_0.0.clone (Arg_0.3: f32[], Arg_1.1: f32[]) -> f32[] {
  %Arg_0.3 = f32[] parameter(0), metadata={op_name="reduce.7"}
  %Arg_1.1 = f32[] parameter(1), metadata={op_name="reduce.7"}
  ROOT %add.1 = f32[] add(f32[] %Arg_0.3, f32[] %Arg_1.1), metadata={op_name="add" source_file="/mnt/software/lux/Reactant.jl/src/Ops.jl" source_line=267}
}

%fused_computation (param_0.2: u32[], param_1.6: f32[7]) -> f32[2,1] {
  %param_1.6 = f32[7]{0} parameter(1)
  %constant.16 = f32[] constant(0)
  %pad.1 = f32[8]{0} pad(f32[7]{0} %param_1.6, f32[] %constant.16), padding=0_1, metadata={op_name="reduce.7"}
  %constant.15 = u32[1]{0} constant({0}), metadata={op_name="reduce.7"}
  %constant.14 = u32[8]{0} constant({0, 1, 2, 3, 0, 1, 2, 3}), metadata={op_name="reduce.7"}
  %param_0.2 = u32[] parameter(0)
  %dynamic-slice.4 = u32[1]{0} dynamic-slice(u32[8]{0} %constant.14, u32[] %param_0.2), dynamic_slice_sizes={1}, metadata={op_name="reduce.7"}
  %constant.13 = u32[1]{0} constant({3}), metadata={op_name="reduce.7"}
  %clamp.2 = u32[1]{0} clamp(u32[1]{0} %constant.15, u32[1]{0} %dynamic-slice.4, u32[1]{0} %constant.13), metadata={op_name="reduce.7"}
  %convert.2 = s32[1]{0} convert(u32[1]{0} %clamp.2), metadata={op_name="reduce.7"}
  %constant.12 = s32[1]{0} constant({2}), metadata={op_name="reduce.7"}
  %multiply.3 = s32[1]{0} multiply(s32[1]{0} %convert.2, s32[1]{0} %constant.12), metadata={op_name="reduce.7"}
  %bitcast.4 = s32[] bitcast(s32[1]{0} %multiply.3), metadata={op_name="reduce.7"}
  %dynamic-slice.3 = f32[2]{0} dynamic-slice(f32[8]{0} %pad.1, s32[] %bitcast.4), dynamic_slice_sizes={2}, metadata={op_name="reduce.7"}
  ROOT %bitcast.3 = f32[2,1]{1,0} bitcast(f32[2]{0} %dynamic-slice.3), metadata={op_name="reduce.7"}
}

ENTRY %main.11_spmd (param: f32[7,1]) -> f32[7,1] {
  %partition-id = u32[] partition-id()
  %param = f32[7,1]{1,0} parameter(0), sharding={devices=[1,2,4]<=[8] last_tile_dim_replicate}, metadata={op_name="Arg_0.1"}
  %bitcast = f32[7]{0} bitcast(f32[7,1]{1,0} %param), metadata={op_name="Arg_0.1"}
  %all-reduce = f32[7]{0} all-reduce(f32[7]{0} %bitcast), channel_id=1, replica_groups=[4,2]<=[2,4]T(1,0), use_global_device_ids=true, to_apply=%region_0.0.clone, metadata={op_name="reduce.7"}
  %dynamic-slice_bitcast_fusion = f32[2,1]{1,0} fusion(u32[] %partition-id, f32[7]{0} %all-reduce), kind=kLoop, calls=%fused_computation, metadata={op_name="reduce.7"}
  %all-gather = f32[8,1]{1,0} all-gather(f32[2,1]{1,0} %dynamic-slice_bitcast_fusion), channel_id=2, replica_groups=[2,4]<=[8], dimensions={0}, use_global_device_ids=true
  ROOT %slice = f32[7,1]{1,0} slice(f32[8,1]{1,0} %all-gather), slice={[0:7], [0:1]}
}
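
For comparison, the arithmetic for Case II (my sketch, not from the PR): padding dim 1 from 7 to cld(7, 4) * 4 = 8 and then sharding would again give 2x1 per-device shards, but the sharding XLA actually emitted on %param leaves dim 1 unsharded, so every device receives a full 7x1 column.

# Expected with implicit padding: pad dim 1 to 8, then shard over ("model"=4, "data"=2).
cld.((7, 2), (4, 2))  # (2, 1) per device, with one padded row in the last row shard
# What devices=[1,2,4]<=[8] last_tile_dim_replicate implies instead:
cld.((7, 2), (1, 2))  # (7, 1) per device -- matches the f32[7,1] entry layout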

@avik-pal avik-pal changed the title from "feat: sharding with non-divisible axes [alternatve approach]" to "feat: sharding with non-divisible axes [alternate approach]" on Feb 28, 2025
@avik-pal avik-pal force-pushed the ap/implicit_padding branch from 3e10c82 to 11a38ed on February 28, 2025 17:21
@avik-pal

The options at https://github.com/openxla/xla/blob/a83a2d3f1f977e4825cb210320597c5825c25ead/xla/client/executable_build_options.h#L327-L330 are false by default, so this is a bit unexpected.

module @reactant_fn_test attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22data\\\22=2, \\\22model\\\22=4]>}"}, mhlo.num_partitions = 8 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<7x2xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22model\\\22}, {\\\22data\\\22}]>"}, mhlo.sharding = "{devices=[4,2]<=[2,4]T(1,0)}"}) -> tensor<7x1xf32> {
    %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = mhlo.reduce(%arg0 init: %0) applies mhlo.add across dimensions = [1] : (tensor<7x2xf32>, tensor<f32>) -> tensor<7xf32>
    %2 = mhlo.reshape %1 : (tensor<7xf32>) -> tensor<7x1xf32>
    return %2 : tensor<7x1xf32>
  }
}

The MHLO input sharding is {devices=[4,2]<=[2,4]T(1,0)} (dim 1 split 4 ways over "model", dim 2 split 2 ways over "data"), but after the XLA passes the parameter sharding becomes {devices=[1,2,4]<=[8] last_tile_dim_replicate} (dim 1 left unsharded, dim 2 split 2 ways, each shard replicated across 4 devices).
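
A minimal decode of the requested tile assignment (mine, assuming the standard HLO V2 sharding notation), confirming it is the same sharding that Case I's parameter kept:

row_major(v, dims) = permutedims(reshape(v, reverse(dims)))  # 2-D only: emulate HLO's row-major reshape
tile_grid = permutedims(row_major(0:7, (2, 4)))              # apply the T(1,0) transpose
# tile_grid == [0 4; 1 5; 2 6; 3 7]: a 4x2 grid, i.e. dim 1 ("model") split 4 ways and
# dim 2 ("data") split 2 ways -- so the rewrite to last_tile_dim_replicate happens inside XLA.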

@avik-pal avik-pal changed the title from "feat: sharding with non-divisible axes [alternate approach]" to "feat: sharding with non-divisible dimensions [alternate approach]" on Feb 28, 2025
@avik-pal avik-pal requested a review from wsmoses on February 28, 2025 22:15
@avik-pal avik-pal force-pushed the ap/implicit_padding branch from 81f1477 to e70f4c2 on March 1, 2025 16:04
@avik-pal avik-pal commented Mar 4, 2025

@wsmoses is this good to go from your end?

@avik-pal avik-pal merged commit d1be533 into main Mar 4, 2025
37 of 39 checks passed
@avik-pal avik-pal deleted the ap/implicit_padding branch March 4, 2025 02:38
avik-pal added a commit that referenced this pull request Mar 4, 2025
* feat: support implicit padding from XLA

* feat: use XLA for shard-info if we need padding

* test: padding for sharding

* fix: return type