Skip to content

Commit

Permalink
Add a flag to control a2a collective matmul rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
Tixxx committed Dec 19, 2024
1 parent 66dc33c commit 5b3f9b6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
12 changes: 12 additions & 0 deletions xla/service/gpu/transforms/windowed_einsum_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,12 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor {
// Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms
// to minimize communication overhead. To do this, the original input will
// be sliced into replica_group size and perform all-to-all+gemm.
if (!dot->GetModule()
->config()
.debug_options()
.xla_gpu_experimental_enable_alltoall_windowed_einsum()) {
return absl::OkStatus();
}
HloInstruction* lhs;
HloInstruction* rhs;
std::vector<xla::ReplicaGroup> replica_groups;
Expand Down Expand Up @@ -1185,6 +1191,12 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor {
absl::Status HandleAllToAll(HloInstruction* inst) override {
CHECK_EQ(inst->opcode(), HloOpcode::kAllToAll);
HloComputation* comp = inst->parent();
if (!inst->GetModule()
->config()
.debug_options()
.xla_gpu_experimental_enable_alltoall_windowed_einsum()) {
return absl::OkStatus();
}
// Rewrites a gemm+alltoall into multiple independent partial gemm+a2as
// to minimize communication overhead.
std::vector<xla::ReplicaGroup> replica_groups;
Expand Down
17 changes: 13 additions & 4 deletions xla/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,8 @@ TEST_F(CollectiveOpsTestE2E, NoAllToAllDecomposition) {
class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
public:
void CollectiveOpsCompareWindowedNonWindowed(
absl::string_view hlo_text, bool disable_dot_merger = false) {
absl::string_view hlo_text, bool disable_dot_merger = false,
bool enable_a2a_rewrite = false) {
const int64_t kNumReplicas = 1;
const int64_t kNumPartitions = 4;
SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions);
Expand All @@ -788,6 +789,8 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
auto opts = GetDebugOptionsForTest();
opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0);
opts.set_xla_gpu_multi_streamed_windowed_einsum(true);
opts.set_xla_gpu_experimental_enable_alltoall_windowed_einsum(
enable_a2a_rewrite);
opts.set_xla_gpu_graph_min_graph_size(200);
opts.set_xla_gpu_enable_triton_gemm(false);
if (disable_dot_merger) {
Expand Down Expand Up @@ -1061,7 +1064,9 @@ ENTRY main.9_spmd {
}
)";

CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr);
CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
/*disable_dot_merger=*/false,
/*enable_a2a_rewrite=*/true);
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
Expand All @@ -1077,7 +1082,9 @@ ENTRY main.9_spmd {
}
)";

CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr);
CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
/*disable_dot_merger=*/false,
/*enable_a2a_rewrite=*/true);
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
Expand All @@ -1098,7 +1105,9 @@ ENTRY main.9_spmd {
}
)";

CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr);
CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
/*disable_dot_merger=*/false,
/*enable_a2a_rewrite=*/true);
}

TEST_F(CollectiveOpsTestE2E, CollectivePipelinerF8) {
Expand Down
6 changes: 5 additions & 1 deletion xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,11 @@ message DebugOptions {
// be deterministic, although with additional overhead.
bool xla_gpu_enable_scatter_determinism_expander = 345;

// Next id: 359
// Enable windowed einsum(collective matmul) rewrite for all-to-all + gemm
// This feature is still experimental.
bool xla_gpu_experimental_enable_alltoall_windowed_einsum = 359;

// Next id: 360

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down

0 comments on commit 5b3f9b6

Please sign in to comment.