
fix: multi-device execution and sharding [take III] #713


Merged: 17 commits merged into main from ap/fixes on Feb 11, 2025

Conversation

avik-pal (Collaborator) commented Feb 8, 2025

  • Fixes for multi-GPU runs
  • Tests
  • API to load output shardings from XLA
  • JLL Changes
  • Adds @code_mhlo, which prints the MLIR module in the MHLO dialect (see the sketch after this list)
  • Remove forced replication of outputs
    • remove linear results sharding info
  • use OpSharding to regenerate the outputs
  • ConcreteRNumber can now be replicated
  • JLL bump
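
For concreteness, a rough sketch of how these pieces might fit together. The mesh and sharding constructors, the sharding keyword, and the device slicing below are assumptions for illustration, not APIs confirmed by this PR:

using Reactant

# Hypothetical sketch: shard a 12x4x16 input across a 1x2 device mesh.
# Sharding.Mesh, Sharding.NamedSharding, and the sharding keyword are assumed
# names for illustration; check the Reactant sharding docs for the real API.
mesh = Reactant.Sharding.Mesh(reshape(Reactant.devices()[1:2], 1, 2), (:x, :y))
shard = Reactant.Sharding.NamedSharding(mesh, (:x, nothing, :y))

fn_test1(x) = (sum(x .+ x), x .+ 1)

x = Reactant.to_rarray(rand(Float32, 12, 4, 16); sharding=shard)

# Print the MLIR module in the MHLO dialect (the new @code_mhlo), then
# compile and run the function on the sharded input.
Reactant.@code_mhlo fn_test1(x)
compiled = Reactant.@compile fn_test1(x)
compiled(x)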

avik-pal (Collaborator, Author) commented Feb 8, 2025

Don't merge; some more fixes are needed.

avik-pal (Collaborator, Author) commented Feb 8, 2025

Beautiful error on TPUs:

ERROR: INTERNAL: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc:5245) !HasReplicatedSharding(hlo->sharding()) || CanSideEffectingHaveReplicatedSharding(hlo) side-effect HLO cannot have a replicated sharding: %custom-call.8 = f32[12,2]{1,0} custom-call(f32[12,2]{1,0} %sine.1), custom_call_target="xla.sdy.FuncResultSharding", custom_call_has_side_effect=true, sharding={replicated}, frontend_attributes={xla.sdy.sharding="#sdy.sharding_per_value<[<@mesh, [{}, {}], replicated={\"data\", \"model\"}>]>"}, metadata={op_name="custom-call.9"}

@avik-pal avik-pal changed the title fix: use different API fix: mutli-device execution and sharding [take III] Feb 8, 2025
@avik-pal avik-pal marked this pull request as draft February 8, 2025 22:15
@avik-pal avik-pal changed the title fix: mutli-device execution and sharding [take III] fix: multi-device execution and sharding [take III] Feb 8, 2025
@avik-pal avik-pal linked an issue Feb 9, 2025 that may be closed by this pull request
avik-pal (Collaborator, Author) commented Feb 9, 2025

Beautiful error on TPUs:

ERROR: INTERNAL: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc:5245) !HasReplicatedSharding(hlo->sharding()) || CanSideEffectingHaveReplicatedSharding(hlo) side-effect HLO cannot have a replicated sharding: %custom-call.8 = f32[12,2]{1,0} custom-call(f32[12,2]{1,0} %sine.1), custom_call_target="xla.sdy.FuncResultSharding", custom_call_has_side_effect=true, sharding={replicated}, frontend_attributes={xla.sdy.sharding="#sdy.sharding_per_value<[<@mesh, [{}, {}], replicated={\"data\", \"model\"}>]>"}, metadata={op_name="custom-call.9"}

Just to note: this only happens on TPUs; the same code works fine on GPUs.

@avik-pal avik-pal changed the base branch from main to ap/jll February 10, 2025 03:51
avik-pal (Collaborator, Author) commented:

Remaining changes are only on the Julia side, so hopefully no more JLL builds are needed.

Base automatically changed from ap/jll to main February 10, 2025 04:06
@avik-pal avik-pal force-pushed the ap/fixes branch 2 times, most recently from 965da33 to dfb63c0, February 10, 2025 16:54
@avik-pal avik-pal marked this pull request as ready for review February 11, 2025 02:29
avik-pal (Collaborator, Author) commented:

This is now ready; we just need the JLL build to go through. Once that is done, I will test on TPUs.

wsmoses (Member) commented Feb 11, 2025

Very minor, but technically, if we want to be correct: HLO is the low-level post-XLA IR, while MHLO (a.k.a. MLIR HLO) and now StableHLO are what we have before that.

Re: the naming of code_hlo and code_mhlo, lol.

We don't need to fix it here, though.

avik-pal (Collaborator, Author) commented Feb 11, 2025

> Very minor, but technically, if we want to be correct: HLO is the low-level post-XLA IR, while MHLO (a.k.a. MLIR HLO) and now StableHLO are what we have before that.

@code_mhlo actually still prints the MHLO, not the XLA IR. But it is useful, since all the shardy ops are expanded into custom calls:

module @reactant_fn_test1 attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22x\\\22=1, \\\22y\\\22=2]>}"}, mhlo.num_partitions = 2 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<12x4x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22x\\\22}, {}, {\\\22y\\\22}]>"}, mhlo.sharding = "{devices=[1,1,2]<=[2]}"}) -> (tensor<f32>, tensor<12x4x1xf32>, tensor<12x4x16xf32> {mhlo.sharding = "{devices=[1,1,2]<=[2]}"}, tensor<12x4x16xf32>) {
    %0 = mhlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>) -> tensor<16x4x12xf32>
    %2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %3 = "mhlo.transpose"(%arg0) <{permutation = dense<[2, 1, 0]> : tensor<3xi64>}> : (tensor<12x4x16xf32>) -> tensor<16x4x12xf32>
    %4 = mhlo.add %3, %3 : tensor<16x4x12xf32>
    %5 = mhlo.add %3, %1 : tensor<16x4x12xf32>
    %6 = mhlo.multiply %5, %4 : tensor<16x4x12xf32>
    %7 = mhlo.reduce(%4 init: %2) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<16x4x12xf32>, tensor<f32>) -> tensor<f32>
    %8 = mhlo.reduce(%4 init: %2) applies mhlo.add across dimensions = [0] : (tensor<16x4x12xf32>, tensor<f32>) -> tensor<4x12xf32>
    %9 = "mhlo.transpose"(%8) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<4x12xf32>) -> tensor<12x4xf32>
    %10 = mhlo.reshape %9 : (tensor<12x4xf32>) -> tensor<12x4x1xf32>
    %11 = "mhlo.transpose"(%5) <{permutation = dense<[2, 1, 0]> : tensor<3xi64>}> : (tensor<16x4x12xf32>) -> tensor<12x4x16xf32>
    %12 = "mhlo.transpose"(%6) <{permutation = dense<[2, 1, 0]> : tensor<3xi64>}> : (tensor<16x4x12xf32>) -> tensor<12x4x16xf32>
    %13 = mhlo.custom_call @xla.sdy.FuncResultSharding(%11) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22x\\\22}, {}, {\\\22y\\\22}]>]>"}} : (tensor<12x4x16xf32>) -> tensor<12x4x16xf32>
    return %7, %10, %13, %12 : tensor<f32>, tensor<12x4x1xf32>, tensor<12x4x16xf32>, tensor<12x4x16xf32>
  }
}

We should definitely rename code_hlo to return the actual HLO module at some point, though I need to check how to get that without dumping it to a file.
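
For reference, a minimal usage sketch of how the two printers sit side by side (assumed usage; fn_test1 and the sharded input x are placeholders along the lines of the module above):

# Assumed usage, for illustration only.
Reactant.@code_hlo fn_test1(x)   # pre-XLA module; shardy shardings still appear as sdy attributes
Reactant.@code_mhlo fn_test1(x)  # MHLO module; shardy ops expanded into custom calls, as printed above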

avik-pal (Collaborator, Author) commented:

This is now ready to go!

@wsmoses wsmoses merged commit 90b0d1d into main Feb 11, 2025
36 of 39 checks passed
@wsmoses wsmoses deleted the ap/fixes branch February 11, 2025 20:41
Linked issue that may be closed by this pull request: shardy functions not visible on macos