diff --git a/xla/service/gpu/transforms/windowed_einsum_handler.cc b/xla/service/gpu/transforms/windowed_einsum_handler.cc index 2ffec420c30ae3..d4ffe20e9a953f 100644 --- a/xla/service/gpu/transforms/windowed_einsum_handler.cc +++ b/xla/service/gpu/transforms/windowed_einsum_handler.cc @@ -961,6 +961,12 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms // to minimize communication overhead. To do this, the original input will // be sliced into replica_group size and perform all-to-all+gemm. + if (!dot->GetModule() + ->config() + .debug_options() + .xla_gpu_experimental_enable_alltoall_windowed_einsum()) { + return absl::OkStatus(); + } HloInstruction* lhs; HloInstruction* rhs; std::vector<ReplicaGroup> replica_groups; @@ -1185,6 +1191,12 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { absl::Status HandleAllToAll(HloInstruction* inst) override { CHECK_EQ(inst->opcode(), HloOpcode::kAllToAll); HloComputation* comp = inst->parent(); + if (!inst->GetModule() + ->config() + .debug_options() + .xla_gpu_experimental_enable_alltoall_windowed_einsum()) { + return absl::OkStatus(); + } // Rewrites a gemm+alltoall into multiple independent partial gemm+a2as // to minimize communication overhead. 
std::vector<ReplicaGroup> replica_groups; diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc index 2eb8f2ed2d4b91..c60f9ec54f82f1 100644 --- a/xla/tests/collective_ops_e2e_test.cc +++ b/xla/tests/collective_ops_e2e_test.cc @@ -778,7 +778,8 @@ TEST_F(CollectiveOpsTestE2E, NoAllToAllDecomposition) { class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E { public: void CollectiveOpsCompareWindowedNonWindowed( - absl::string_view hlo_text, bool disable_dot_merger = false) { + absl::string_view hlo_text, bool disable_dot_merger = false, + bool enable_a2a_rewrite = false) { const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 4; SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); @@ -788,6 +789,8 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E { auto opts = GetDebugOptionsForTest(); opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0); opts.set_xla_gpu_multi_streamed_windowed_einsum(true); + opts.set_xla_gpu_experimental_enable_alltoall_windowed_einsum( + enable_a2a_rewrite); opts.set_xla_gpu_graph_min_graph_size(200); opts.set_xla_gpu_enable_triton_gemm(false); if (disable_dot_merger) { @@ -1061,7 +1064,9 @@ ENTRY main.9_spmd { } )"; - CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr); + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/false, + /*enable_a2a_rewrite=*/true); } TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, @@ -1077,7 +1082,9 @@ ENTRY main.9_spmd { } )"; - CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr); + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/false, + /*enable_a2a_rewrite=*/true); } TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, @@ -1098,7 +1105,9 @@ ENTRY main.9_spmd { } )"; - CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr); + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/false, + 
/*enable_a2a_rewrite=*/true); } TEST_F(CollectiveOpsTestE2E, CollectivePipelinerF8) { diff --git a/xla/xla.proto b/xla/xla.proto index 1382558c1f7a4d..368dc0f2c7233e 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -1098,7 +1098,11 @@ message DebugOptions { // be deterministic, although with additional overhead. bool xla_gpu_enable_scatter_determinism_expander = 345; - // Next id: 359 + // Enable windowed einsum (collective matmul) rewrite for all-to-all + gemm. + // This feature is still experimental. + bool xla_gpu_experimental_enable_alltoall_windowed_einsum = 359; + + // Next id: 360 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.