fix: multi-device execution and sharding [take III] #713
Conversation
Don't merge, some more fixes needed.
Beautiful error on TPUs.
Just to note: this is only on TPUs; the same code works fine on GPUs.
The remaining changes are Julia-side only, so hopefully no more JLL building is needed.
Force-pushed from 965da33 to dfb63c0.
This is now ready; we just need the JLL build to go through. Once that is done I will test with TPUs.
Very minor, but technically if we want to be correct: HLO is the low-level post-XLA IR, while MHLO (a.k.a. MLIR HLO) and now StableHLO are what we have before, regarding the naming of code_hlo and code_mhlo. We don't need to fix it here, though.
module @reactant_fn_test1 attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22x\\\22=1, \\\22y\\\22=2]>}"}, mhlo.num_partitions = 2 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<12x4x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22x\\\22}, {}, {\\\22y\\\22}]>"}, mhlo.sharding = "{devices=[1,1,2]<=[2]}"}) -> (tensor<f32>, tensor<12x4x1xf32>, tensor<12x4x16xf32> {mhlo.sharding = "{devices=[1,1,2]<=[2]}"}, tensor<12x4x16xf32>) {
%0 = mhlo.constant dense<1.000000e+00> : tensor<f32>
%1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>) -> tensor<16x4x12xf32>
%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "mhlo.transpose"(%arg0) <{permutation = dense<[2, 1, 0]> : tensor<3xi64>}> : (tensor<12x4x16xf32>) -> tensor<16x4x12xf32>
%4 = mhlo.add %3, %3 : tensor<16x4x12xf32>
%5 = mhlo.add %3, %1 : tensor<16x4x12xf32>
%6 = mhlo.multiply %5, %4 : tensor<16x4x12xf32>
%7 = mhlo.reduce(%4 init: %2) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<16x4x12xf32>, tensor<f32>) -> tensor<f32>
%8 = mhlo.reduce(%4 init: %2) applies mhlo.add across dimensions = [0] : (tensor<16x4x12xf32>, tensor<f32>) -> tensor<4x12xf32>
%9 = "mhlo.transpose"(%8) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<4x12xf32>) -> tensor<12x4xf32>
%10 = mhlo.reshape %9 : (tensor<12x4xf32>) -> tensor<12x4x1xf32>
%11 = "mhlo.transpose"(%5) <{permutation = dense<[2, 1, 0]> : tensor<3xi64>}> : (tensor<16x4x12xf32>) -> tensor<12x4x16xf32>
%12 = "mhlo.transpose"(%6) <{permutation = dense<[2, 1, 0]> : tensor<3xi64>}> : (tensor<16x4x12xf32>) -> tensor<12x4x16xf32>
%13 = mhlo.custom_call @xla.sdy.FuncResultSharding(%11) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22x\\\22}, {}, {\\\22y\\\22}]>]>"}} : (tensor<12x4x16xf32>) -> tensor<12x4x16xf32>
return %7, %10, %13, %12 : tensor<f32>, tensor<12x4x1xf32>, tensor<12x4x16xf32>, tensor<12x4x16xf32>
}
}
We should definitely rename code_hlo to return the actual HLO module at some point, though I need to check how to get that without dumping it to a file.
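For reference, a rough Julia-side sketch of the kind of program that would lower to a module like the one above: a 1x2 mesh with axes x and y, an input sharded along those axes, and a small test function mixing elementwise ops with full and per-dimension reductions. The Sharding.Mesh and Sharding.NamedSharding constructors, the sharding keyword to Reactant.to_rarray, the Reactant.devices() call, and the axis-to-dimension mapping are assumptions inferred from the annotations in the dump, not a verbatim excerpt of the API this PR adds.

using Reactant

# Assumed API: a 1x2 device mesh with axes (:x, :y), matching
# #sdy.mesh<["x"=1, "y"=2]> in the module above.
mesh = Reactant.Sharding.Mesh(reshape(Reactant.devices()[1:2], 1, 2), (:x, :y))

# Assumed API: shard the first dimension along :x and the last along :y,
# matching the [{"x"}, {}, {"y"}] annotation on %arg0.
sharding = Reactant.Sharding.NamedSharding(mesh, (:x, nothing, :y))

x = Reactant.to_rarray(rand(Float32, 12, 4, 16); sharding)

# A plausible reconstruction of fn_test1 from the lowered ops:
# y = x + x, z = x + 1, a full reduction, and a reduction over one dimension.
function fn_test1(x)
    y = x .+ x
    z = x .+ 1
    return sum(y), sum(y; dims=3), z, z .* y
end

# Prints the traced MLIR module (mhlo/stablehlo-style ops) before XLA compiles it.
@code_hlo fn_test1(x)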
Force-pushed from 2e1b67a to 0f6f397.
This is now ready to go!
@code_mhlo that prints MLIR with the mhlo dialect.
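To make that naming split concrete, a small hedged sketch; only @code_hlo is known to exist today, while @code_mhlo is just the name floated in this thread:

# Today: @code_hlo shows the MLIR module Reactant traces (mhlo/stablehlo-style
# ops), as in the dump earlier in this conversation.
@code_hlo fn_test1(x)

# Proposed (hypothetical, not a shipped macro): keep a @code_mhlo for the
# MLIR view, so @code_hlo can eventually return the actual post-XLA HLO.
# @code_mhlo fn_test1(x)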