diff --git a/xla/hlo/translate/stablehlo.cc b/xla/hlo/translate/stablehlo.cc
index 2847109557731..5cc7849913b54 100644
--- a/xla/hlo/translate/stablehlo.cc
+++ b/xla/hlo/translate/stablehlo.cc
@@ -57,6 +57,36 @@ absl::Status MhloToStablehlo(mlir::ModuleOp module) {
   }
   return absl::OkStatus();
 }
+
+absl::Status StablehloToMhlo(mlir::ModuleOp module, bool run_canonicalizer) {
+  mlir::MLIRContext* context = module->getContext();
+  mlir::BaseScopedDiagnosticHandler diagnostic_handler(context);
+  {
+    mlir::PassManager pm(context);
+    pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
+    pm.addNestedPass<mlir::func::FuncOp>(
+        mlir::mhlo::createChloLegalizeToHloPass());
+    if (run_canonicalizer) {
+      pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
+    }
+    // In order to export to XLA, we must sink constants to control flow
+    // regions, since XLA uses functional control flow.
+    pm.addNestedPass<mlir::func::FuncOp>(
+        mlir::mhlo::createSinkConstantsToControlFlowPass());
+    if (failed(pm.run(module))) {
+      VLOG(1) << "StableHLO->MHLO lowering passes failed.";
+      module->dump();
+      return diagnostic_handler.ConsumeStatus();
+    }
+
+    VLOG(5) << "MHLO module after lowering, before HLO export";
+    if (VLOG_IS_ON(5)) {
+      module->dump();
+    }
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 void RegisterMlirToHloDependentDialects(mlir::DialectRegistry& registry) {
@@ -113,29 +143,7 @@ absl::Status ConvertStablehloToHloProto(mlir::ModuleOp module,
                                         xla::HloProto* hlo_proto) {
   if (!module) return absl::InvalidArgumentError("Module is null");
 
-  mlir::MLIRContext* context = module->getContext();
-  mlir::BaseScopedDiagnosticHandler diagnostic_handler(context);
-  {
-    mlir::PassManager pm(context);
-    pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
-    pm.addNestedPass<mlir::func::FuncOp>(
-        mlir::mhlo::createChloLegalizeToHloPass());
-    pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
-    // In order to export to XLA, we must sink constants to control flow
-    // regions, since XLA uses functional control flow.
-    pm.addNestedPass<mlir::func::FuncOp>(
-        mlir::mhlo::createSinkConstantsToControlFlowPass());
-    if (failed(pm.run(module))) {
-      VLOG(1) << "MHLO->HLO lowering passes failed.";
-      module->dump();
-      return diagnostic_handler.ConsumeStatus();
-    }
-
-    VLOG(5) << "MHLO module after lowering, before HLO import ";
-    if (VLOG_IS_ON(5)) {
-      module->dump();
-    }
-  }
+  TF_RETURN_IF_ERROR(StablehloToMhlo(module, /*run_canonicalizer=*/true));
 
   mlir::MlirToHloConversionOptions options;
   options.return_tuple = false;
@@ -144,4 +152,22 @@ absl::Status ConvertStablehloToHloProto(mlir::ModuleOp module,
   return absl::OkStatus();
 }
 
+absl::Status ConvertStablehloWithManyArgsToHloProto(mlir::ModuleOp module,
+                                                    xla::HloProto* hlo_proto,
+                                                    bool use_tuple_args) {
+  if (!module) return absl::InvalidArgumentError("Module is null");
+
+  TF_RETURN_IF_ERROR(StablehloToMhlo(module, /*run_canonicalizer=*/false));
+
+  mlir::MlirToHloConversionOptions options;
+  options.return_tuple = false;
+  options.use_tuple_args = use_tuple_args;
+  // Remove attributes introduced by `import_all_computations=true` in
+  // ConvertHloToStablehlo.
+  module->removeAttr("mhlo.xla_entry_computation_parameter_layouts");
+  module->removeAttr("mhlo.xla_entry_computation_parameter_tiles");
+  TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module, hlo_proto, options));
+  return absl::OkStatus();
+}
+
 }  // namespace xla
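Note: the following caller-side sketch is illustrative and not part of the patch. It assumes a StableHLO module already parsed into `module` (an mlir::OwningOpRef<mlir::ModuleOp>) and shows how the two conversion entry points above are meant to be used.

    // Default path: flat parameter list, canonicalizer enabled.
    xla::HloProto proto;
    absl::Status status = xla::ConvertStablehloToHloProto(*module, &proto);

    // Workaround path for very long argument lists: wrap all arguments in a
    // single tuple parameter and skip the canonicalizer.
    absl::Status status2 = xla::ConvertStablehloWithManyArgsToHloProto(
        *module, &proto, /*use_tuple_args=*/true);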
diff --git a/xla/hlo/translate/stablehlo.h b/xla/hlo/translate/stablehlo.h
index 933d0c895dd53..f7c290c7aa046 100644
--- a/xla/hlo/translate/stablehlo.h
+++ b/xla/hlo/translate/stablehlo.h
@@ -48,6 +48,15 @@ absl::StatusOr<std::unique_ptr<HloModule>> ConvertStablehloToHlo(
 absl::Status ConvertStablehloToHloProto(mlir::ModuleOp module,
                                         xla::HloProto* hlo_proto);
 
+// Converts a StableHLO module to an HloProto.
+// Some platforms run out of memory when the argument list is too long.
+// This API wraps the arguments in a tuple (if `use_tuple_args` is true)
+// as a workaround. The long-term solution is to add an HLO pass to do this.
+// In general, prefer ConvertStablehloToHloProto.
+absl::Status ConvertStablehloWithManyArgsToHloProto(
+    mlir::ModuleOp module, xla::HloProto* hlo_proto,
+    bool use_tuple_args = false);
+
 }  // namespace xla
 
 #endif  // XLA_HLO_TRANSLATE_STABLEHLO_H_
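The `use_tuple_args` workaround only changes the entry computation's parameter list. Roughly, in HLO text form (the shapes and names below are invented for illustration, not taken from this patch):

    ENTRY main (Arg_0: f32[8], Arg_1: f32[8], Arg_2: f32[8]) -> f32[8]   // use_tuple_args = false
    ENTRY main (arg_tuple.0: (f32[8], f32[8], f32[8])) -> f32[8]         // use_tuple_args = true

Each original parameter becomes a get-tuple-element of the single tuple parameter, which is why platforms that cap the number of entry parameters can still run such programs.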
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" -#include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h" @@ -37,7 +36,6 @@ void addMhloExportPipeline(mlir::OpPassManager& pm) { pm.addPass(createMhloRoundTripShardMapExportPass()); pm.addPass(createExportNamedComputationsPass()); pm.addPass(createExportMhloShardingsPass()); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); } void registerMhloExportPipeline() { diff --git a/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc b/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc index bf5c545dfa70b..c4d7a13a55bb9 100644 --- a/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc +++ b/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc @@ -31,6 +31,7 @@ using ::mlir::func::FuncOp; void addCommonPreImportPasses(mlir::OpPassManager& pm) { pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); // TODO(b/333505182): remove when partitioning is done in SDY. // We call prepare-for-export pass before SDY propagation, so that all IR // changes happen before shardings are added to operations, to ensure the diff --git a/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD b/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD index 207205bd537a1..5ed5e736198af 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD +++ b/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD @@ -20,26 +20,21 @@ cc_library( srcs = ["mhlo_to_hlo_to_mhlo.cc"], hdrs = ["mhlo_to_hlo_to_mhlo.h"], deps = [ - "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/hlo/translate:stablehlo", "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", - "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Support", - "@shardy//shardy/dialect/sdy/ir:dialect", - "@stablehlo//:stablehlo_ops", - "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -49,7 +44,6 @@ cc_library( hdrs = ["testing_pipeline.h"], deps = [ ":mhlo_to_hlo_to_mhlo", - "//xla/mlir_hlo:mhlo_passes", "//xla/service/spmd/shardy/sdy_round_trip:pipelines", "@llvm-project//mlir:Pass", ], diff --git a/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc b/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc index da7bda8f60e3b..d8da24860f80d 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc @@ -18,32 +18,30 @@ limitations under the License. 
diff --git a/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc b/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc
index da7bda8f60e3b..d8da24860f80d 100644
--- a/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc
+++ b/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc
@@ -18,32 +18,30 @@ limitations under the License.
 
 #include <memory>
 #include <utility>
 
+#include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
-#include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Support/TypeID.h"
-#include "shardy/dialect/sdy/ir/dialect.h"
-#include "stablehlo/dialect/StablehloOps.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
-#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
-#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/hlo/translate/stablehlo.h"
 #include "xla/mlir_hlo/mhlo/transforms/passes.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/hlo_module_config.h"
-#include "xla/shape.h"
-#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 namespace sdy {
 
@@ -55,33 +53,22 @@ using ::mlir::StringRef;
 
-// Converts an MHLO module to an HLO module.
+// Converts a StableHLO module to an HLO module.
 absl::StatusOr<std::unique_ptr<xla::HloModule>> toHlo(ModuleOp module) {
-  absl::StatusOr<std::unique_ptr<xla::HloModule>> hloModule;
-  xla::HloProto hloProto;
-  TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(module, &hloProto,
-                                         /*use_tuple_args=*/false,
-                                         /*return_tuple=*/false));
-  xla::HloModuleConfig moduleConfig;
-  xla::ProgramShape expectedProgramShape(
-      hloProto.hlo_module().host_program_shape());
-  moduleConfig.SetDefaultComputationLayout(expectedProgramShape);
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloModule> hloModule,
+                      xla::ConvertStablehloToHlo(module));
+  xla::HloModuleConfig& moduleConfig = hloModule->mutable_config();
   moduleConfig.set_use_spmd_partitioning(true);
-  return xla::HloModule::CreateFromProto(hloProto.hlo_module(), moduleConfig);
+  return hloModule;
 }
 
-// Converts an HLO module to an MHLO module.
-absl::Status toMhlo(std::unique_ptr<xla::HloModule> hloModule, ModuleOp module) {
-  // Delete the functions, which can be more than one due to preserving
-  // the shmap_body functions.
-  mlir::SymbolTableCollection symbolTableCollection;
-  mlir::SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(module);
-  for (mlir::Operation& op :
-       llvm::make_early_inc_range(module.getBodyRegion().getOps())) {
-    symbolTable.erase(&op);
-  }
-  TF_RETURN_IF_ERROR(
-      xla::ConvertHloToMlirHlo(module, hloModule.get(),
-                               /*import_all_computations=*/false,
-                               /*flatten_computation_args_result=*/true));
+// Converts an HLO module to a StableHLO module.
+absl::Status toStablehlo(std::unique_ptr<xla::HloModule> hloModule,
+                         ModuleOp& module) {
+  TF_ASSIGN_OR_RETURN(
+      mlir::OwningOpRef<mlir::ModuleOp> newModule,
+      xla::ConvertHloToStablehlo(*module->getContext(), hloModule.get()));
+  // Splice the converted StableHLO body into the original module, replacing
+  // its old contents.
+  mlir::IRMapping mapping;
+  module.getBodyRegion().getBlocks().front().erase();
+  newModule.get().getBodyRegion().cloneInto(&module.getBodyRegion(), mapping);
   return absl::OkStatus();
 }
 
@@ -103,7 +90,7 @@ class SdyRoundTripMhloToHloToMhloPass
     }
 
-    // 2. HLO -> MHLO
-    if (absl::Status status = toMhlo(std::move(*hloModule), module);
+    // 2. HLO -> StableHLO
+    if (absl::Status status = toStablehlo(std::move(*hloModule), module);
         !status.ok()) {
-      module.emitError(absl::StrCat("Failed to convert to MHLO from HLO: ",
+      module.emitError(absl::StrCat("Failed to convert to StableHLO from HLO: ",
                                     status.message()));
@@ -120,9 +107,7 @@ class SdyRoundTripMhloToHloToMhloPass
   }
 
   void getDependentDialects(mlir::DialectRegistry& registry) const final {
-    registry.insert<mlir::mhlo::MhloDialect, mlir::quant::QuantDialect,
-                    mlir::sparse_tensor::SparseTensorDialect,
-                    mlir::sdy::SdyDialect, mlir::stablehlo::StablehloDialect>();
+    xla::RegisterMlirToHloDependentDialects(registry);
   }
 };
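The body-splice in toStablehlo above follows a general MLIR pattern for replacing one module's contents with another's. A minimal standalone sketch (the function name is ours, not from the patch; needs mlir/IR/BuiltinOps.h and mlir/IR/IRMapping.h):

    // Replaces dst's body with a clone of src's body. A ModuleOp's body region
    // holds exactly one block, so erasing it and then cloning src's region
    // into dst leaves dst structurally valid.
    void replaceModuleBody(mlir::ModuleOp dst, mlir::ModuleOp src) {
      mlir::IRMapping mapping;
      dst.getBodyRegion().getBlocks().front().erase();
      src.getBodyRegion().cloneInto(&dst.getBodyRegion(), mapping);
    }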
diff --git a/xla/service/spmd/shardy/sdy_round_trip/test_utils/testing_pipeline.cc b/xla/service/spmd/shardy/sdy_round_trip/test_utils/testing_pipeline.cc
index 984186cb626c2..b4e25bafa8c87 100644
--- a/xla/service/spmd/shardy/sdy_round_trip/test_utils/testing_pipeline.cc
+++ b/xla/service/spmd/shardy/sdy_round_trip/test_utils/testing_pipeline.cc
@@ -17,7 +17,6 @@ limitations under the License.
 
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
-#include "xla/mlir_hlo/mhlo/transforms/passes.h"
 #include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
 #include "xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.h"
 
@@ -31,7 +30,6 @@ void registerSdyRoundTripTestingPipeline() {
       "MHLO, then import back to Shardy",
       [](mlir::OpPassManager& pm) {
         addSdyRoundTripExportPipeline(pm);
-        pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
         pm.addPass(createSdyRoundTripMhloToHloToMhloPass());
         addSdyRoundTripImportPipeline(pm);
       });
diff --git a/xla/service/spmd/shardy/shardy_xla_pass.cc b/xla/service/spmd/shardy/shardy_xla_pass.cc
index d7b85bccb6074..7fded37aff136 100644
--- a/xla/service/spmd/shardy/shardy_xla_pass.cc
+++ b/xla/service/spmd/shardy/shardy_xla_pass.cc
@@ -51,6 +51,7 @@ limitations under the License.
 #include "xla/hlo/transforms/simplifiers/tuple_simplifier.h"
 #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
 #include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
+#include "xla/hlo/translate/stablehlo.h"
 #include "xla/hlo/utils:hlo_sharding_util.h"
 #include "xla/layout.h"
 #include "xla/map_util.h"
@@ -307,17 +308,14 @@ absl::StatusOr<bool> ShardyXLA::Run(
     const absl::flat_hash_set<absl::string_view>& executionThreads) {
   LOG(INFO) << "Using Shardy for XLA SPMD propagation.";
 
-  // HLO -> MLIR MHLO
+  // HLO -> StableHLO
   auto mlirContext = std::make_unique<mlir::MLIRContext>();
   loadAllRequiredDialects(mlirContext.get());
-  mlir::OwningOpRef<mlir::ModuleOp> mlirModule =
-      xla::llvm_ir::CreateMlirModuleOp(
-          mlir::UnknownLoc::get(mlirContext.get()));
-  TF_RETURN_IF_ERROR(
-      ConvertHloToMlirHlo(*mlirModule, hloModule,
-                          /*import_all_computations=*/false,
-                          /*flatten_computation_args_result=*/true));
-
+  TF_ASSIGN_OR_RETURN(
+      mlir::OwningOpRef<mlir::ModuleOp> mlirModule,
+      xla::ConvertHloToStablehlo(*mlirContext.get(), hloModule));
 
   std::string shardyDir = hloModule->config().debug_options().xla_dump_to();
 
   if (shardyDir == "sponge") {
@@ -403,10 +401,10 @@ absl::StatusOr<bool> ShardyXLA::Run(
   tsl::StatusScopedDiagnosticHandler diagnosticHandler(mlirContext.get());
   TF_RETURN_IF_ERROR(diagnosticHandler.consumeStatus(pm.run(*mlirModule)));
 
-  // MLIR MHLO -> HLO
+  // StableHLO -> HLO
   HloProto hloProto;
-  TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(*mlirModule, &hloProto, useTupleArgs,
-                                         /*return_tuple=*/false));
+  TF_RETURN_IF_ERROR(ConvertStablehloWithManyArgsToHloProto(
+      *mlirModule, &hloProto, useTupleArgs));
   TF_RETURN_IF_ERROR(
       createFromProtoAndReplaceComputations(hloModule, hloProto.hlo_module()));
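For context, ShardyXLA remains an ordinary HLO pass; a hedged sketch of how it is typically scheduled (pipeline name and surrounding code are illustrative, not from the patch):

    // Assumes hlo_module is an xla::HloModule* prepared by the caller.
    xla::HloPassPipeline pipeline("shardy-propagation");
    pipeline.AddPass<xla::sdy::ShardyXLA>();
    TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(hlo_module));

Internally, Run now round-trips HLO -> StableHLO -> (SDY propagation) -> HLO, using ConvertStablehloWithManyArgsToHloProto on the way back so the tuple-args workaround is available when useTupleArgs is set.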
diff --git a/xla/service/spmd/shardy/shardy_xla_pass_test.cc b/xla/service/spmd/shardy/shardy_xla_pass_test.cc
index 6cb846048cff7..f0a8ef7906b48 100644
--- a/xla/service/spmd/shardy/shardy_xla_pass_test.cc
+++ b/xla/service/spmd/shardy/shardy_xla_pass_test.cc
@@ -560,9 +560,8 @@ TEST_F(ShardyXLATest, WhileWithFreeVariables) {
               op::Sharding("{devices=[2,1,2]<=[4] last_tile_dim_replicate}"));
   // Verify the sharding of the while, and specifically that the sharding of the
   // result that corresponds to parameter(1) is further sharded.
-  EXPECT_THAT(whileInst,
-              op::Sharding("{{devices=[2,2]<=[4]}, {replicated}, {replicated}, "
-                           "{devices=[2,2]<=[4]}, {replicated}}"));
+  EXPECT_THAT(whileInst, op::Sharding("{{devices=[2,2]<=[4]}, {replicated}, "
+                                      "{devices=[2,2]<=[4]}}"));
 }
 
 TEST_F(ShardyXLATest, ShardMap) {
diff --git a/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir b/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir
index 81348fb671610..d327cd439f07b 100644
--- a/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir
+++ b/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir
@@ -35,7 +35,7 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.shardi
                               %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_0", "axis_2"}]>},
                               %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_1"}]>})
     -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_0", "axis_1"}, {"axis_2"}]>}) {
-// CHECK-NEXT: mhlo.add
+// CHECK-NEXT: stablehlo.add
 // CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[8,1,4]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"}
   %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32>
   %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
@@ -55,7 +55,7 @@ func.func @single_axis(%arg0: tensor<32x8xf32> {sdy.sharding = #sdy.sharding<@me
 
 // CHECK-LABEL: func @multi_result_op
 func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) {
   %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: mhlo.reduce
+// CHECK: stablehlo.reduce
 // CHECK-SAME{LITERAL}: {mhlo.sharding = "{{devices=[1,4,8]<=[8,4]T(1,0) last_tile_dim_replicate}, {devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}}"}
   %1:2 = stablehlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{}, {"y"}]>, <@mesh_2, [{"y"}, {}]>]>} :
@@ -87,7 +87,7 @@ func.func @fully_replicated(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding
 // CHECK-SAME:      -> tensor<8x16xf32> {
 func.func @split_axes(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"y"}, {"x":(2)2}]>},
                       %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x":(1)2}, {"x":(2)4}]>}) -> tensor<8x16xf32> {
-// CHECK-NEXT: mhlo.dot
+// CHECK-NEXT: stablehlo.dot
 // CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[4,1,8]<=[2,2,2,4]T(0,2,1,3) last_tile_dim_replicate}"}
   %1 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
   return %1 : tensor<8x16xf32>
@@ -95,8 +95,8 @@ func.func @split_axes(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh
 
 // CHECK-LABEL: func @split_constants
 func.func @split_constants() -> (tensor<8x8xf32>, tensor<8x8xf32>) {
-  // CHECK-NEXT: %[[CONST_0:.*]] = mhlo.constant {mhlo.sharding = "{devices=[8,1,4]<=[32] last_tile_dim_replicate}"} dense<1.000000e+00>
-  // CHECK-NEXT: %[[CONST_1:.*]] = mhlo.constant {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} dense<1.000000e+00>
+  // CHECK-NEXT: %[[CONST_0:.*]] = stablehlo.constant {mhlo.sharding = "{devices=[8,1,4]<=[32] last_tile_dim_replicate}"} dense<1.000000e+00>
+  // CHECK-NEXT: %[[CONST_1:.*]] = stablehlo.constant {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} dense<1.000000e+00>
   // CHECK-NEXT: return %[[CONST_0]], %[[CONST_1]]
   %0 = sdy.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x"}, {}]>]>} dense<1.000000e+00> : tensor<8x8xf32>
   %1 = sdy.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"y"}, {}]>]>} dense<1.000000e+00> : tensor<8x8xf32>
@@ -130,15 +130,15 @@ func.func @reshard_fully_open_partially_open(%arg0: tensor<8x8xf32>) -> tensor<8
 // CHECK-SAME:      -> (tensor<8x32xf32> {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"}) {
 func.func @sharding_in_manual_computation_body(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_3, [{"a", ?}, {"b", ?}]>}, %arg1: tensor<16x32xf32> {sdy.sharding = #sdy.sharding<@mesh_3, [{"b", ?}, {?}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_3, [{"a"}, {}]>}) {
 // CHECK-NEXT: %[[COPY_0:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,2,4]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<8x16xf32>
-// CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_0]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<8x16xf32>) -> tensor<4x8xf32>
+// CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_0]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<8x16xf32>) -> tensor<4x8xf32>
 // CHECK-NEXT: %[[COPY_1:.*]] = mhlo.copy %arg1 {mhlo.sharding = "{devices=[2,1,8]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<16x32xf32>
-// CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_1]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<16x32xf32>) -> tensor<8x32xf32>
+// CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_1]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<16x32xf32>) -> tensor<8x32xf32>
 // CHECK-NEXT: %[[RESHARD:.*]] = mhlo.copy %[[FULL_TO_SHARD_0]] {mhlo.sharding = "{devices=[1,2,4,2]<=[8,2]T(1,0) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32>
-// CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[RESHARD]], %[[RESHARD]] {mhlo.sharding = "{devices=[2,1,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32>
-// CHECK-NEXT: %[[DOT:.*]] = "mhlo.dot"(%[[ADD]], %[[FULL_TO_SHARD_1]]) {mhlo.sharding = "{devices=[2,2,4]<=[4,4]T(1,0) last_tile_dims={manual}}"} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32>
-// CHECK-NEXT: %[[SINE:.*]] = mhlo.sine %[[DOT]] {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32>
+// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[RESHARD]], %[[RESHARD]] {mhlo.sharding = "{devices=[2,1,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32>
+// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[ADD]], %[[FULL_TO_SHARD_1]] {mhlo.sharding = "{devices=[2,2,4]<=[4,4]T(1,0) last_tile_dims={manual}}"} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32>
+// CHECK-NEXT: %[[SINE:.*]] = stablehlo.sine %[[DOT]] {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32>
 // CHECK-NEXT: %[[COPY_2:.*]] = mhlo.copy %[[SINE]] {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32>
-// CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_2]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<4x32xf32>) -> tensor<8x32xf32>
+// CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_2]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<4x32xf32>) -> tensor<8x32xf32>
 // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<8x32xf32>
   %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh_3, [{"b"}, {"a"}]>, <@mesh_3, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_3, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<4x8xf32>, %arg3: tensor<8x32xf32>) {
     %1 = sdy.reshard %arg2 <@mesh_3, [{}, {"d"}]> : tensor<4x8xf32>
@@ -152,18 +152,18 @@ func.func @sharding_in_manual_computation_body(%arg0: tensor<8x16xf32> {sdy.shar
 
 // CHECK-LABEL: func @mesh_with_device_id_should_be_converted_to_maximal_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}, %arg1: tensor<8x8xf32>)
 func.func @mesh_with_device_id_should_be_converted_to_maximal_sharding(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, []>}, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> {
-  // CHECK: %[[ADD:.*]] = mhlo.add %arg0, %arg1
+  // CHECK: %[[ADD:.*]] = stablehlo.add %arg0, %arg1
   %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32>
-  // CHECK: %[[ADD_WITH_SHARDING:.*]] = mhlo.add %[[ADD]], %[[ADD]] {mhlo.sharding = "{maximal device=1}"}
+  // CHECK: %[[ADD_WITH_SHARDING:.*]] = stablehlo.add %[[ADD]], %[[ADD]] {mhlo.sharding = "{maximal device=1}"}
   %1 = stablehlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_1, []>]>} : tensor<8x8xf32>
   return %1 : tensor<8x8xf32>
 }
 
 // CHECK-LABEL: func @mesh_empty_should_be_converted_to_replicated_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<8x8xf32>)
 func.func @mesh_empty_should_be_converted_to_replicated_sharding(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@empty_mesh_0, [{}, {}]>}, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> {
-  // CHECK: %[[ADD:.*]] = mhlo.add %arg0, %arg1
+  // CHECK: %[[ADD:.*]] = stablehlo.add %arg0, %arg1
   %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32>
-  // CHECK: %[[ADD_WITH_SHARDING:.*]] = mhlo.add %[[ADD]], %[[ADD]] {mhlo.sharding = "{replicated}"}
+  // CHECK: %[[ADD_WITH_SHARDING:.*]] = stablehlo.add %[[ADD]], %[[ADD]] {mhlo.sharding = "{replicated}"}
   %1 = stablehlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_1, [{}, {}]>]>} : tensor<8x8xf32>
   return %1 : tensor<8x8xf32>
 }
@@ -176,7 +176,7 @@ func.func @mesh_empty_should_be_converted_to_replicated_sharding(%arg0: tensor<8
 func.func @multiple_shardings_with_device_list(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_4, [{"axis_2"}, {"axis_0", "axis_1"}]>},
                                                %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_4, [{}, {"axis_0", "axis_2"}]>},
                                                %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_4, [{}, {"axis_1"}]>}) -> tensor<8x16xf32> {
-  // CHECK-NEXT: mhlo.add
+  // CHECK-NEXT: stablehlo.add
   // CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[4,1,2]0,2,1,3,4,6,5,7 last_tile_dim_replicate}"}
   %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32>
   %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
[{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> @@ -190,10 +190,10 @@ func.func @named_sharding_in_manual_computation( %arg0: tensor<32x2xi32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", "y"}, {}]>}) -> (tensor<32x2xi32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", "y"}, {}]>}) { // CHECK-NEXT: %[[COPY_0:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[32,1]<=[32]}"} : tensor<32x2xi32> - // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dims={manual}}"} : (tensor<32x2xi32>) -> tensor<4x2xi32> + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dims={manual}}"} : (tensor<32x2xi32>) -> tensor<4x2xi32> // CHECK-NEXT: %[[FOO:.*]] = call @foo(%[[FULL_TO_SHARD]]) {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} : (tensor<4x2xi32>) -> tensor<4x2xi32> // CHECK-NEXT: %[[COPY_1:.*]] = mhlo.copy %[[FOO]] {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dims={manual}}"} : tensor<4x2xi32> - // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_1]]) {mhlo.sharding = "{devices=[32,1]<=[32]}"} : (tensor<4x2xi32>) -> tensor<32x2xi32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_1]]) {mhlo.sharding = "{devices=[32,1]<=[32]}"} : (tensor<4x2xi32>) -> tensor<32x2xi32> // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<32x2xi32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_2, [{"x", "y"}, {}]>] out_shardings=[<@mesh_2, [{"x", "y"}, {}]>] manual_axes={"x"} (%arg1: tensor<4x2xi32>) { %1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh_2, [{"y"}, {}]>] out_shardings=[<@mesh_2, [{"y"}, {}]>] (%arg2: tensor<4x2xi32>) { @@ -210,11 +210,11 @@ func.func @free_axis_inside_in_out_shardings_manual_computation( %arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}, {}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i", ?}, {?}]>}) { // CHECK-NEXT: %[[COPY_OPERAND:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}"} : tensor<4x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND]]) {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : (tensor<4x8xf32>) -> tensor<4x8xf32> - // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %[[FULL_TO_SHARD]], %[[FULL_TO_SHARD]] {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : tensor<4x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND]]) {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : (tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %[[FULL_TO_SHARD]], %[[FULL_TO_SHARD]] {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : tensor<4x8xf32> // CHECK-NEXT: %[[COPY:.*]] = mhlo.copy %[[MULT]] {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : tensor<4x8xf32> // CHECK-NEXT: %[[COPY_RESULT:.*]] = mhlo.copy %[[COPY]] {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : tensor<4x8xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT]]) {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}"} : (tensor<4x8xf32>) -> tensor<4x8xf32> 
+  // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT]]) {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}"} : (tensor<4x8xf32>) -> tensor<4x8xf32>
   // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<4x8xf32>
   %0 = sdy.manual_computation(%arg0)
       in_shardings=[<@mesh_5, [{"i", ?}, {?}], replicated={"j"}>]
@@ -231,11 +231,9 @@ func.func @free_axis_inside_in_out_shardings_manual_computation(
 func.func @custom_call_erf_topk(
   %arg0: tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}, {}]>}
   ) -> (tensor<16x2xf32> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i", ?}, {?}]>}) {
-  // CHECK-NEXT: %[[ERF:.*]] = mhlo.erf %arg0 {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}", mhlo.version = 1 : i64} : tensor<16x8xf32>
-  // CHECK-NEXT: %[[VALUES:.*]], %[[IDX:.*]] = mhlo.topk(%[[ERF]], k = 2) {
-  // CHECK-SAME{LITERAL}: mhlo.sharding = "{{devices=[2,1,2]<=[4] last_tile_dim_replicate}, {devices=[2,1,2]<=[4] last_tile_dim_replicate}}"
-  // CHECK-SAME: } : tensor<16x8xf32> -> (tensor<16x2xf32>, tensor<16x2xi32>)
-  // CHECK-NEXT: return %[[VALUES]] : tensor<16x2xf32>
+  // CHECK-NEXT: %[[ERF:.*]] = stablehlo.custom_call @mhlo.erf(%arg0) {mhlo.attributes = {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}", mhlo.version = 1 : i64}} : (tensor<16x8xf32>) -> tensor<16x8xf32>
+  // CHECK-NEXT: stablehlo.custom_call @mhlo.topk(%[[ERF]])
+  // CHECK-SAME{LITERAL}: {mhlo.attributes = {k = 2 : i64, largest = true, mhlo.sharding = "{{devices=[2,1,2]<=[4] last_tile_dim_replicate}, {devices=[2,1,2]<=[4] last_tile_dim_replicate}}"}, mhlo.version = 1 : i64} : (tensor<16x8xf32>) -> (tensor<16x2xf32>, tensor<16x2xi32>)
   %0 = stablehlo.custom_call @mhlo.erf(%arg0) {
     mhlo.attributes = {mhlo.version = 1 : i64},
     sdy.sharding = #sdy.sharding_per_value<[<@mesh_5, [{"i", ?}, {?}]>]>
@@ -251,5 +249,5 @@ func.func @custom_call_erf_topk(
 // CHECK-LABEL: func private @foo
 // CHECK-SAME:    %arg0: tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}
 // CHECK-SAME:    -> (tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}) {
-// CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %arg0, %arg0 {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dims={manual}}"} : tensor<4x2xi32>
+// CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %arg0, %arg0 {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dims={manual}}"} : tensor<4x2xi32>
 // CHECK-NEXT: return %[[MULT]] : tensor<4x2xi32>
diff --git a/xla/service/spmd/shardy/test/round_trip_pipeline.mlir b/xla/service/spmd/shardy/test/round_trip_pipeline.mlir
index d51bea212139c..cf0dc80b83006 100644
--- a/xla/service/spmd/shardy/test/round_trip_pipeline.mlir
+++ b/xla/service/spmd/shardy/test/round_trip_pipeline.mlir
@@ -189,10 +189,10 @@ func.func @main(
   %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>})
   -> tensor<32x96xf32> {
   // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0>
-  // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32>
   // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]>
   // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]])
   // CHECK-NEXT: cond {
+  // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32>
   // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_0, %[[C32]]
   // CHECK-NEXT: stablehlo.return %[[COND]]
   // CHECK-NEXT: } do {
@@ -242,16 +242,16 @@ func.func @main(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
   // CHECK-NEXT: %[[HOST:.*]] = stablehlo.custom_call @MoveToHost(%[[NC]]#0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32>
   // CHECK-NEXT: return %[[HOST]] : tensor<8x2xi32>
   %0:2 = call @g.2(%arg0) {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, mhlo.sharding = "{{maximal device=0}, {replicated}}"} : (tensor<8x2xi32>) -> (tensor<8x2xi32>, tensor<8x2xi32>)
-  %1 = mhlo.custom_call @MoveToHost(%0#0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  %1 = stablehlo.custom_call @MoveToHost(%0#0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32>
   return %1 : tensor<8x2xi32>
 }
 
 // CHECK-NOT: g.2
 func.func private @g.2(%arg0: tensor<8x2xi32>) -> (tensor<8x2xi32>, tensor<8x2xi32>) {
-  %0 = mhlo.multiply %arg0, %arg0 : tensor<8x2xi32>
+  %0 = stablehlo.multiply %arg0, %arg0 : tensor<8x2xi32>
   return %0, %0 : tensor<8x2xi32>, tensor<8x2xi32>
 }
 
-// TODO(b/335481977): Add more tests for MHLO ops. So far tested all SDY
+// TODO(b/335481977): Add more tests for StableHLO ops. So far tested all SDY
 // compiler APIs other than shard as/like (doesn't exist yet). See
 // round_trip_pipeline_manual_computation.mlir for ManualComputationOp tests.