[Shardy] HLO ⇄ MHLO to HLO ⇄ StableHLO
PiperOrigin-RevId: 705548834
abhigunj authored and Google-ML-Automation committed Dec 19, 2024
1 parent ca3ddd2 commit 17d8779
Showing 13 changed files with 126 additions and 120 deletions.
72 changes: 49 additions & 23 deletions xla/hlo/translate/stablehlo.cc
@@ -57,6 +57,36 @@ absl::Status MhloToStablehlo(mlir::ModuleOp module) {
}
return absl::OkStatus();
}

absl::Status StablehloToMhlo(mlir::ModuleOp module, bool run_canonicalizer) {
mlir::MLIRContext* context = module->getContext();
mlir::BaseScopedDiagnosticHandler diagnostic_handler(context);
{
mlir::PassManager pm(context);
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::mhlo::createChloLegalizeToHloPass());
if (run_canonicalizer) {
pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
}
// In order to export to XLA, we must sink constants to control flow
// regions, since XLA uses functional control flow.
pm.addNestedPass<mlir::func::FuncOp>(
mlir::mhlo::createSinkConstantsToControlFlowPass());
if (failed(pm.run(module))) {
VLOG(1) << "MHLO->HLO lowering passes failed.";
module->dump();
return diagnostic_handler.ConsumeStatus();
}

VLOG(5) << "MHLO module after lowering, before HLO import ";
if (VLOG_IS_ON(5)) {
module->dump();
}
}
return absl::OkStatus();
}

} // namespace

void RegisterMlirToHloDependentDialects(mlir::DialectRegistry& registry) {
@@ -113,29 +143,7 @@ absl::Status ConvertStablehloToHloProto(mlir::ModuleOp module,
xla::HloProto* hlo_proto) {
if (!module) return absl::InvalidArgumentError("Module is null");

mlir::MLIRContext* context = module->getContext();
mlir::BaseScopedDiagnosticHandler diagnostic_handler(context);
{
mlir::PassManager pm(context);
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::mhlo::createChloLegalizeToHloPass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
// In order to export to XLA, we must sink constants to control flow
// regions, since XLA uses functional control flow.
pm.addNestedPass<mlir::func::FuncOp>(
mlir::mhlo::createSinkConstantsToControlFlowPass());
if (failed(pm.run(module))) {
VLOG(1) << "MHLO->HLO lowering passes failed.";
module->dump();
return diagnostic_handler.ConsumeStatus();
}

VLOG(5) << "MHLO module after lowering, before HLO import ";
if (VLOG_IS_ON(5)) {
module->dump();
}
}
TF_RETURN_IF_ERROR(StablehloToMhlo(module, /*run_canonicalizer=*/true));

mlir::MlirToHloConversionOptions options;
options.return_tuple = false;
@@ -144,4 +152,22 @@ absl::Status ConvertStablehloToHloProto(mlir::ModuleOp module,
return absl::OkStatus();
}

absl::Status ConvertStablehloWithManyArgsToHloProto(mlir::ModuleOp module,
xla::HloProto* hlo_proto,
bool use_tuple_args) {
if (!module) return absl::InvalidArgumentError("Module is null");

TF_RETURN_IF_ERROR(StablehloToMhlo(module, /*run_canonicalizer=*/false));

mlir::MlirToHloConversionOptions options;
options.return_tuple = false;
options.use_tuple_args = use_tuple_args;
// Remove attributes introduced by `import_all_computation=true` at
// ConvertHloToStablehlo.
module->removeAttr("mhlo.xla_entry_computation_parameter_layouts");
module->removeAttr("mhlo.xla_entry_computation_parameter_tiles");
TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module, hlo_proto, options));
return absl::OkStatus();
}

} // namespace xla
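For orientation, here is a minimal caller-side sketch of the StableHLO → HLO path exercised above. The helper name `StablehloTextToHloProto` and the parse-from-string setup are illustrative assumptions, not part of this change; only `RegisterMlirToHloDependentDialects` and `ConvertStablehloToHloProto` come from the file being edited.

```cpp
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "tsl/platform/errors.h"
#include "xla/hlo/translate/stablehlo.h"
#include "xla/service/hlo.pb.h"

// Hypothetical helper: parse StableHLO text and lower it to an HloProto.
absl::StatusOr<xla::HloProto> StablehloTextToHloProto(llvm::StringRef text) {
  // Register every dialect the StableHLO -> HLO translation depends on.
  mlir::DialectRegistry registry;
  xla::RegisterMlirToHloDependentDialects(registry);
  mlir::MLIRContext context(registry);

  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(text, &context);
  if (!module) return absl::InvalidArgumentError("failed to parse StableHLO");

  // Internally this legalizes StableHLO/CHLO to MHLO, canonicalizes, sinks
  // constants into control-flow regions, and then emits HLO.
  xla::HloProto proto;
  TF_RETURN_IF_ERROR(xla::ConvertStablehloToHloProto(*module, &proto));
  return proto;
}
```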
9 changes: 9 additions & 0 deletions xla/hlo/translate/stablehlo.h
@@ -48,6 +48,15 @@ absl::StatusOr<std::unique_ptr<xla::HloModule>> ConvertStablehloToHlo(
absl::Status ConvertStablehloToHloProto(mlir::ModuleOp module,
xla::HloProto* hlo_proto);

// Convert StableHLO module to HloModule.
// Some platforms run out of memory when the argument list is too long.
// This API wraps the arguments in a tuple (if use_tuple_args = true)
// as a workaround. The long-term solution is to add an HLO pass to do this.
// In general, prefer the other ConvertStablehloToHloProto method.
absl::Status ConvertStablehloWithManyArgsToHloProto(
mlir::ModuleOp module, xla::HloProto* hlo_proto,
bool use_tuple_args = false);

} // namespace xla

#endif // XLA_HLO_TRANSLATE_STABLEHLO_H_
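A hedged usage sketch for the new entry point; `stablehlo_module` is assumed to be an `mlir::ModuleOp` the caller already owns (for example, obtained via `ConvertHloToStablehlo`):

```cpp
// Sketch only: programs with very long parameter lists can exhaust memory on
// some platforms, so ask the converter to bundle all arguments into a single
// tuple parameter.
xla::HloProto hlo_proto;
absl::Status status = xla::ConvertStablehloWithManyArgsToHloProto(
    stablehlo_module, &hlo_proto, /*use_tuple_args=*/true);
if (!status.ok()) {
  LOG(ERROR) << "StableHLO -> HLO conversion failed: " << status;
}
```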
3 changes: 2 additions & 1 deletion xla/service/spmd/shardy/BUILD
@@ -37,6 +37,7 @@ cc_library(
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/transforms:hlo_dce",
"//xla/hlo/transforms:tuple_simplifier",
"//xla/hlo/translate:stablehlo",
"//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
"//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
"//xla/hlo/utils:hlo_sharding_util",
@@ -105,7 +106,7 @@ xla_cc_test(
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_matchers",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
"//xla/tests:remote_mtest_lib_buffer_donation_test_gpu_b100",
"@com_google_absl//absl/log",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
1 change: 0 additions & 1 deletion xla/service/spmd/shardy/mhlo_round_trip/BUILD
@@ -91,7 +91,6 @@ cc_library(
":export_ops",
":export_shardings",
":shard_map_export",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service/spmd/shardy/round_trip_common:export_named_computations",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
2 changes: 0 additions & 2 deletions xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h"
@@ -37,7 +36,6 @@ void addMhloExportPipeline(mlir::OpPassManager& pm) {
pm.addPass(createMhloRoundTripShardMapExportPass());
pm.addPass(createExportNamedComputationsPass());
pm.addPass(createExportMhloShardingsPass());
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
}

void registerMhloExportPipeline() {
@@ -31,6 +31,7 @@ using ::mlir::func::FuncOp;

void addCommonPreImportPasses(mlir::OpPassManager& pm) {
pm.addPass(mlir::createSymbolDCEPass());
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
// TODO(b/333505182): remove when partitioning is done in SDY.
// We call prepare-for-export pass before SDY propagation, so that all IR
// changes happen before shardings are added to operations, to ensure the
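In other words, after this change the common pre-import pipeline legalizes any incoming StableHLO down to MHLO right after symbol DCE, so the downstream Shardy import passes keep seeing MHLO. A sketch of the pipeline's new shape (later passes, elided here, are unchanged by this commit):

```cpp
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"

// Sketch of the pre-import pipeline's opening passes after this change.
void addCommonPreImportPasses(mlir::OpPassManager& pm) {
  pm.addPass(mlir::createSymbolDCEPass());
  pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
  // ... prepare-for-export and the remaining import passes follow ...
}
```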
12 changes: 3 additions & 9 deletions xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD
@@ -20,26 +20,21 @@ cc_library(
srcs = ["mhlo_to_hlo_to_mhlo.cc"],
hdrs = ["mhlo_to_hlo_to_mhlo.h"],
deps = [
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/translate:stablehlo",
"//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
"//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
"//xla/mlir_hlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SparseTensorDialect",
"@llvm-project//mlir:Support",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

@@ -49,7 +44,6 @@ cc_library(
hdrs = ["testing_pipeline.h"],
deps = [
":mhlo_to_hlo_to_mhlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service/spmd/shardy/sdy_round_trip:pipelines",
"@llvm-project//mlir:Pass",
],
@@ -18,32 +18,30 @@ limitations under the License.
#include <memory>
#include <utility>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/TypeID.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/hlo/translate/stablehlo.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace sdy {
@@ -55,33 +53,22 @@ using ::mlir::StringRef;

// Converts an MHLO module to an HLO module.
absl::StatusOr<std::unique_ptr<HloModule>> toHlo(ModuleOp module) {
absl::StatusOr<std::unique_ptr<HloModule>> hloModule;
xla::HloProto hloProto;
TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(module, &hloProto,
/*use_tuple_args=*/false,
/*return_tuple=*/false));
xla::HloModuleConfig moduleConfig;
xla::ProgramShape expectedProgramShape(
hloProto.hlo_module().host_program_shape());
moduleConfig.SetDefaultComputationLayout(expectedProgramShape);
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hloModule,
xla::ConvertStablehloToHlo(module));
xla::HloModuleConfig& moduleConfig = hloModule->mutable_config();
moduleConfig.set_use_spmd_partitioning(true);
return xla::HloModule::CreateFromProto(hloProto.hlo_module(), moduleConfig);
return hloModule;
}

// Converts an HLO module to an MHLO module.
absl::Status toMhlo(std::unique_ptr<HloModule> hloModule, ModuleOp module) {
// Delete the functions, which can be more than one due to preserving
// the shmap_body functions.
mlir::SymbolTableCollection symbolTableCollection;
mlir::SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(module);
for (mlir::Operation& op :
llvm::make_early_inc_range(module.getBodyRegion().getOps())) {
symbolTable.erase(&op);
}
TF_RETURN_IF_ERROR(
xla::ConvertHloToMlirHlo(module, hloModule.get(),
/*import_all_computations=*/false,
/*flatten_computation_args_result=*/true));
// Converts an HLO module to a StableHLO module.
absl::Status toStablehlo(std::unique_ptr<HloModule> hloModule,
ModuleOp& module) {
TF_ASSIGN_OR_RETURN(
mlir::OwningOpRef<mlir::ModuleOp> newModule,
xla::ConvertHloToStablehlo(*module->getContext(), hloModule.get()));
mlir::IRMapping mapping;
module.getBodyRegion().getBlocks().front().erase();
newModule.get().getBodyRegion().cloneInto(&module.getBodyRegion(), mapping);
return absl::OkStatus();
}

@@ -103,7 +90,7 @@ class SdyRoundTripMhloToHloToMhloPass
}

// 2. HLO -> MHLO
if (absl::Status status = toMhlo(std::move(*hloModule), module);
if (absl::Status status = toStablehlo(std::move(*hloModule), module);
!status.ok()) {
module.emitError(absl::StrCat("Failed to convert to MHLO from HLO: ",
status.message()));
@@ -120,9 +107,7 @@ class SdyRoundTripMhloToHloToMhloPass
}

void getDependentDialects(mlir::DialectRegistry& registry) const final {
registry.insert<mlir::sdy::SdyDialect, mlir::stablehlo::StablehloDialect,
mlir::mhlo::MhloDialect, mlir::quant::QuantDialect,
mlir::sparse_tensor::SparseTensorDialect>();
xla::RegisterMlirToHloDependentDialects(registry);
}
};

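The body-swap pattern used by `toStablehlo` above is worth calling out: because the pass must keep operating on the same `ModuleOp`, the freshly converted module is cloned into the original rather than returned. A standalone sketch of that pattern (the helper name is hypothetical):

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"

// Sketch: replace `target`'s contents with those of `source` while keeping
// the original ModuleOp (and therefore the pass manager's root op) alive.
void ReplaceModuleBody(mlir::ModuleOp target, mlir::ModuleOp source) {
  mlir::IRMapping mapping;
  // A ModuleOp has a single block; drop it...
  target.getBodyRegion().getBlocks().front().erase();
  // ...then clone the converted module's region into the now-empty region.
  source.getBodyRegion().cloneInto(&target.getBodyRegion(), mapping);
}
```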
@@ -17,7 +17,6 @@ limitations under the License.

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
#include "xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.h"

@@ -31,7 +30,6 @@ void registerSdyRoundTripTestingPipeline() {
"MHLO, then import back to Shardy",
[](mlir::OpPassManager& pm) {
addSdyRoundTripExportPipeline(pm);
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
pm.addPass(createSdyRoundTripMhloToHloToMhloPass());
addSdyRoundTripImportPipeline(pm);
});
22 changes: 10 additions & 12 deletions xla/service/spmd/shardy/shardy_xla_pass.cc
@@ -51,6 +51,7 @@ limitations under the License.
#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h"
#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
#include "xla/hlo/translate/stablehlo.h"
#include "xla/hlo/utils/hlo_sharding_util.h"
#include "xla/layout.h"
#include "xla/map_util.h"
@@ -307,17 +308,14 @@ absl::StatusOr<bool> ShardyXLA::Run(
const absl::flat_hash_set<absl::string_view>& executionThreads) {
LOG(INFO) << "Using Shardy for XLA SPMD propagation.";

// HLO -> MLIR MHLO
// HLO -> StableHLO
auto mlirContext = std::make_unique<mlir::MLIRContext>();
loadAllRequiredDialects(mlirContext.get());
mlir::OwningOpRef<mlir::ModuleOp> mlirModule =
xla::llvm_ir::CreateMlirModuleOp(
mlir::UnknownLoc::get(mlirContext.get()));
TF_RETURN_IF_ERROR(
ConvertHloToMlirHlo(*mlirModule, hloModule,
/*import_all_computations=*/false,
/*flatten_computation_args_result=*/true));

TF_ASSIGN_OR_RETURN(
mlir::OwningOpRef<mlir::ModuleOp> mlirModule,
xla::ConvertHloToStablehlo(*mlirContext.get(), hloModule));
LOG(INFO) << "StableHLO\n";
mlirModule->dump();
std::string shardyDir = hloModule->config().debug_options().xla_dump_to();

if (shardyDir == "sponge") {
@@ -403,10 +401,10 @@ absl::StatusOr<bool> ShardyXLA::Run(
tsl::StatusScopedDiagnosticHandler diagnosticHandler(mlirContext.get());
TF_RETURN_IF_ERROR(diagnosticHandler.consumeStatus(pm.run(*mlirModule)));

// MLIR MHLO -> HLO
// StableHlo -> HLO
HloProto hloProto;
TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(*mlirModule, &hloProto, useTupleArgs,
/*return_tuple=*/false));
TF_RETURN_IF_ERROR(ConvertStablehloWithManyArgsToHloProto(
*mlirModule, &hloProto, useTupleArgs));
TF_RETURN_IF_ERROR(
createFromProtoAndReplaceComputations(hloModule, hloProto.hlo_module()));

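Putting the pieces together, the pass now round-trips through StableHLO instead of MHLO. A condensed sketch of the flow inside `ShardyXLA::Run`, under the assumption that it reuses the surrounding file's helpers and members (`hloModule`, `useTupleArgs`, `loadAllRequiredDialects`, `createFromProtoAndReplaceComputations`); the Shardy pipeline contents and dump handling are omitted:

```cpp
// Sketch: HLO -> StableHLO, run the Shardy MLIR pipeline, StableHLO -> HLO.
auto mlirContext = std::make_unique<mlir::MLIRContext>();
loadAllRequiredDialects(mlirContext.get());
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> mlirModule,
                    xla::ConvertHloToStablehlo(*mlirContext, hloModule));

mlir::PassManager pm(mlirContext.get());
// ... Shardy import / propagation / export passes are added here ...
tsl::StatusScopedDiagnosticHandler diagnosticHandler(mlirContext.get());
TF_RETURN_IF_ERROR(diagnosticHandler.consumeStatus(pm.run(*mlirModule)));

// Use the many-args-aware converter so `useTupleArgs` can wrap long
// parameter lists into a single tuple.
xla::HloProto hloProto;
TF_RETURN_IF_ERROR(xla::ConvertStablehloWithManyArgsToHloProto(
    *mlirModule, &hloProto, useTupleArgs));
TF_RETURN_IF_ERROR(
    createFromProtoAndReplaceComputations(hloModule, hloProto.hlo_module()));
```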
5 changes: 2 additions & 3 deletions xla/service/spmd/shardy/shardy_xla_pass_test.cc
@@ -560,9 +560,8 @@ TEST_F(ShardyXLATest, WhileWithFreeVariables) {
op::Sharding("{devices=[2,1,2]<=[4] last_tile_dim_replicate}"));
// Verify the sharding of the while, and specifically that the sharding of the
// result that corresponds to parameter(1) is further sharded.
EXPECT_THAT(whileInst,
op::Sharding("{{devices=[2,2]<=[4]}, {replicated}, {replicated}, "
"{devices=[2,2]<=[4]}, {replicated}}"));
EXPECT_THAT(whileInst, op::Sharding("{{devices=[2,2]<=[4]}, {replicated}, "
"{devices=[2,2]<=[4]}}"));
}

TEST_F(ShardyXLATest, ShardMap) {