
Commit f191703

allanrenucci authored and tensorflower-gardener committed
[XLA:GPU] Annotate combinable sync collectives.
Instead of computing the set of synchronous collectives once per combiner pass, we compute it once for all combiner passes: a new pass, run before the combiner passes, annotates the synchronous collectives.

PiperOrigin-RevId: 743207737
1 parent 60f3ace commit f191703

15 files changed: +443 −253 lines
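
For context on the mechanism described above: the new CollectiveCombinerAnnotator pass writes a marker onto each collective it determines will run synchronously, and the individual combiner passes later check that marker instead of recomputing the set themselves. Below is a minimal, self-contained C++ sketch of this annotate-once/check-everywhere pattern. It models an HLO instruction as a plain struct with a string-keyed attribute map, uses the sync_collective attribute key that appears in the updated tests, and mirrors the helper names IsCombinableSyncCollective / ContainsCombinableSyncCollective from the diff — the bodies are illustrative stand-ins, not the actual XLA implementation.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for an HLO instruction: a name plus frontend attributes.
struct Instruction {
  std::string name;
  std::map<std::string, std::string> frontend_attributes;
};

// Stand-in for an HLO module: a flat list of instructions.
struct Module {
  std::vector<Instruction> instructions;
};

// Attribute key taken from the updated tests (sync_collective="true").
constexpr char kSyncCollectiveAttr[] = "sync_collective";

// Annotator pass, run once before all combiner passes: mark a collective
// that the analysis decided will execute synchronously.
void AnnotateSyncCollective(Instruction& instr) {
  instr.frontend_attributes[kSyncCollectiveAttr] = "true";
}

// Cheap per-instruction check used by each combiner's key function.
bool IsCombinableSyncCollective(const Instruction& instr) {
  auto it = instr.frontend_attributes.find(kSyncCollectiveAttr);
  return it != instr.frontend_attributes.end() && it->second == "true";
}

// Cheap per-module check: should a combiner take the sync-combining path?
bool ContainsCombinableSyncCollective(const Module& module) {
  for (const Instruction& instr : module.instructions) {
    if (IsCombinableSyncCollective(instr)) return true;
  }
  return false;
}

int main() {
  Module module;
  module.instructions = {{"ag0"}, {"ag1"}, {"dot0"}};

  // The annotator runs once; later passes only read the annotation.
  AnnotateSyncCollective(module.instructions[0]);
  AnnotateSyncCollective(module.instructions[1]);

  std::cout << std::boolalpha
            << ContainsCombinableSyncCollective(module) << "\n"            // true
            << IsCombinableSyncCollective(module.instructions[2]) << "\n";  // false
  return 0;
}

The payoff of this structure is that the potentially expensive analysis behind the annotation runs once per module rather than once per combiner pass.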

third_party/xla/xla/service/gpu/BUILD

Lines changed: 3 additions & 1 deletion
@@ -1414,6 +1414,7 @@ cc_library(
         "@llvm-project//llvm:BitWriter",
         "@llvm-project//llvm:Core",
         "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:TargetParser",
         "@llvm-project//llvm:TransformUtils",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
@@ -1493,6 +1494,7 @@ cc_library(
         "//xla/service/gpu/model:gpu_hlo_cost_analysis",
         "//xla/service/gpu/model:matmul_ptable_stats_collection",
         "//xla/service/gpu/model:sol_gpu_cost_model_stats_collection",
+        "//xla/service/gpu/transforms/collectives:collective_combiner_annotator",
         "//xla/service/gpu/transforms/collectives:convert_async_collectives_to_sync",
         "//xla/service/gpu/transforms/collectives:gpu_all_gather_combiner",
         "//xla/service/gpu/transforms/collectives:gpu_all_reduce_combiner",
@@ -1641,7 +1643,7 @@ cc_library(
     ]) + xla_internal(["service:export_hlo"]) + if_google([
         "//xla/hlo/experimental/auto_sharding",
         "//xla/hlo/experimental/auto_sharding:auto_sharding_option",
-    ]) + ["@llvm-project//llvm:TargetParser"],
+    ]),
 )

 xla_test(

third_party/xla/xla/service/gpu/gpu_compiler.cc

Lines changed: 7 additions & 1 deletion
@@ -37,7 +37,6 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/types/span.h"
-#include "absl/types/variant.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringRef.h"
@@ -194,6 +193,7 @@ limitations under the License.
 #include "xla/service/gpu/transforms/collective_select_folder.h"
 #include "xla/service/gpu/transforms/collectives/all_gather_combiner.h"
 #include "xla/service/gpu/transforms/collectives/all_reduce_combiner.h"
+#include "xla/service/gpu/transforms/collectives/collective_combiner_annotator.h"
 #include "xla/service/gpu/transforms/collectives/convert_async_collectives_to_sync.h"
 #include "xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.h"
 #include "xla/service/gpu/transforms/collectives/reduce_scatter_combiner.h"
@@ -1150,6 +1150,12 @@ absl::Status RunPostFusionPasses(

   HloPassPipeline pipeline("post-fusion optimization");
   pipeline.AddPass<RenameFusions>();
+  if (hlo_module->config()
+          .debug_options()
+          .xla_gpu_experimental_enable_sync_collective_combining()) {
+    pipeline.AddPass<CollectiveCombinerAnnotator>(device_description,
+                                                  pointer_size);
+  }
   pipeline.AddPass<GpuAllGatherCombiner>(
       device_description,
       /*default_combine_threshold_in_bytes=*/kDefaultAllGatherCombineThreshold,
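
A usage note on the gating above: the annotator is only added to the post-fusion pipeline when the experimental debug option is set, so the combiner tests no longer toggle it themselves (those lines are deleted from the tests further down). For reference, the call chain the tests used — quoted from the removed test code, not a new API — was:

  module->mutable_config()
      .mutable_debug_options()
      .set_xla_gpu_experimental_enable_sync_collective_combining(true);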

third_party/xla/xla/service/gpu/transforms/collectives/BUILD

Lines changed: 44 additions & 13 deletions
@@ -67,23 +67,15 @@ cc_library(
     srcs = ["gpu_collective_combiner_utils.cc"],
     hdrs = ["gpu_collective_combiner_utils.h"],
     deps = [
-        ":collective_ops_utils",
-        ":convert_async_collectives_to_sync",
         "//xla:util",
         "//xla/hlo/ir:hlo",
-        "//xla/hlo/pass:hlo_pass_pipeline",
-        "//xla/hlo/utils:hlo_query",
         "//xla/service:collective_ops_utils",
         "//xla/service:collective_utils",
         "//xla/service/gpu:backend_configs_cc",
         "//xla/service/gpu:gpu_hlo_schedule",
         "//xla/stream_executor:device_description",
-        "//xla/tsl/platform:errors",
-        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
     ],
 )

@@ -103,9 +95,7 @@ xla_cc_test(
         "//xla/service/gpu:backend_configs_cc",
         "//xla/stream_executor:device_description",
         "//xla/tests:hlo_test_base",
-        "//xla/tsl/platform:status_matchers",
         "//xla/tsl/platform:statusor",
-        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest_main",
@@ -117,6 +107,7 @@ cc_library(
     srcs = ["all_gather_combiner.cc"],
     hdrs = ["all_gather_combiner.h"],
     deps = [
+        ":collective_combiner_annotator",
         ":gpu_collective_combiner_utils",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/pass:hlo_pass",
@@ -126,7 +117,6 @@ cc_library(
         "//xla/stream_executor:device_description",
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/functional:bind_front",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:string_view",
     ],
@@ -157,6 +147,7 @@ cc_library(
     srcs = ["reduce_scatter_combiner.cc"],
     hdrs = ["reduce_scatter_combiner.h"],
     deps = [
+        ":collective_combiner_annotator",
         ":gpu_collective_combiner_utils",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/pass:hlo_pass",
@@ -166,7 +157,6 @@ cc_library(
         "//xla/stream_executor:device_description",
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/functional:bind_front",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:string_view",
     ],
@@ -196,6 +186,7 @@ cc_library(
     srcs = ["all_reduce_combiner.cc"],
     hdrs = ["all_reduce_combiner.h"],
     deps = [
+        ":collective_combiner_annotator",
         ":gpu_collective_combiner_utils",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/pass:hlo_pass",
@@ -205,7 +196,6 @@ cc_library(
         "//xla/stream_executor:device_description",
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/functional:bind_front",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:string_view",
     ],
@@ -229,3 +219,44 @@ xla_cc_test(
         "@com_google_googletest//:gtest_main",
     ],
 )
+
+cc_library(
+    name = "collective_combiner_annotator",
+    srcs = ["collective_combiner_annotator.cc"],
+    hdrs = ["collective_combiner_annotator.h"],
+    deps = [
+        ":collective_ops_utils",
+        ":convert_async_collectives_to_sync",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/pass:hlo_pass",
+        "//xla/hlo/pass:hlo_pass_pipeline",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service/gpu:gpu_hlo_schedule",
+        "//xla/stream_executor:device_description",
+        "//xla/tsl/platform:errors",
+        "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+xla_cc_test(
+    name = "collective_combiner_annotator_test",
+    srcs = ["collective_combiner_annotator_test.cc"],
+    deps = [
+        ":collective_combiner_annotator",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
+        "//xla/stream_executor:device_description",
+        "//xla/tsl/platform:status_matchers",
+        "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+    ],
+)

third_party/xla/xla/service/gpu/transforms/collectives/all_gather_combiner.cc

Lines changed: 4 additions & 15 deletions
@@ -18,14 +18,14 @@ limitations under the License.
 #include <optional>

 #include "absl/container/flat_hash_set.h"
-#include "absl/functional/bind_front.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/transforms/collectives/all_gather_combiner.h"
 #include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/transforms/collectives/collective_combiner_annotator.h"
 #include "xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.h"
 #include "xla/service/hlo_domain_map.h"
 #include "xla/tsl/platform/statusor.h"
@@ -49,10 +49,9 @@ std::optional<AllGatherCombiner::GroupKey> PipelinedCombinerKey(
 }

 std::optional<AllGatherCombiner::GroupKey> SynchronousCombinerKey(
-    const absl::flat_hash_set<HloInstruction*>& sync_collectives,
     const HloInstruction* instruction, const HloDomainMap& domain_map,
     bool combine_by_dim, bool combine_different_dtypes) {
-  if (!sync_collectives.contains(instruction)) {
+  if (!IsCombinableSyncCollective(*instruction)) {
     return std::nullopt;
   }
   return AllGatherCombiner::CombineKey(instruction, domain_map, combine_by_dim,
@@ -77,21 +76,11 @@ absl::StatusOr<bool> GpuAllGatherCombiner::Run(
   bool changed = false;

   // Combine as much as possible for synchronous collectives.
-  absl::flat_hash_set<HloInstruction*> sync_collectives;
-  if (module->config()
-          .debug_options()
-          .xla_gpu_experimental_enable_sync_collective_combining()) {
-    TF_ASSIGN_OR_RETURN(
-        sync_collectives,
-        SynchronousCollectives(*module, pointer_size_, device_info_));
-  }
-  if (!sync_collectives.empty()) {
+  if (ContainsCombinableSyncCollective(*module)) {
     combine_threshold_in_bytes_ = MaxAvailableMemory(*module, device_info_);
     TF_ASSIGN_OR_RETURN(
         bool combined,
-        RunWithKeyCombiner(
-            module, execution_threads,
-            absl::bind_front(SynchronousCombinerKey, sync_collectives)));
+        RunWithKeyCombiner(module, execution_threads, SynchronousCombinerKey));
     changed |= combined;
   }

third_party/xla/xla/service/gpu/transforms/collectives/all_gather_combiner_test.cc

Lines changed: 4 additions & 8 deletions
@@ -356,8 +356,10 @@ TEST_F(GpuAllGatherCombinerTest, CombinesSynchronousCollectivesMaximally) {
   p1 = f16[5000000]{0} parameter(1)

   // 20MB combinable all-gather collectives. Default combiner threshold is 30MB.
-  ag0 = f16[10000000]{0} all-gather(p0), replica_groups={}, dimensions={0}
-  ag1 = f16[10000000]{0} all-gather(p1), replica_groups={}, dimensions={0}
+  ag0 = f16[10000000]{0} all-gather(p0), replica_groups={}, dimensions={0},
+    frontend_attributes={sync_collective="true"}
+  ag1 = f16[10000000]{0} all-gather(p1), replica_groups={}, dimensions={0},
+    frontend_attributes={sync_collective="true"}
   ROOT result = tuple(ag0, ag1)
 }
 )";
@@ -373,13 +375,7 @@ TEST_F(GpuAllGatherCombinerTest, CombinesSynchronousCollectivesMaximally) {
       /*combine_by_dim=*/false,
       /*combine_different_dtypes=*/true, /*pointer_size=*/4);

-  EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(false));
-
-  module->mutable_config()
-      .mutable_debug_options()
-      .set_xla_gpu_experimental_enable_sync_collective_combining(true);
   EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(true));
-
   Matcher<const HloInstruction*> combined_all_gather =
       op::AllGather(op::Parameter(0), op::Parameter(1));
   EXPECT_THAT(module->entry_computation()->root_instruction(),

third_party/xla/xla/service/gpu/transforms/collectives/all_reduce_combiner.cc

Lines changed: 4 additions & 15 deletions
@@ -18,14 +18,14 @@ limitations under the License.
 #include <optional>

 #include "absl/container/flat_hash_set.h"
-#include "absl/functional/bind_front.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/transforms/collectives/all_reduce_combiner.h"
 #include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/transforms/collectives/collective_combiner_annotator.h"
 #include "xla/service/gpu/transforms/collectives/gpu_collective_combiner_utils.h"
 #include "xla/service/hlo_domain_map.h"
 #include "xla/tsl/platform/statusor.h"
@@ -47,9 +47,8 @@ std::optional<AllReduceCombiner::GroupKey> PipelinedCombinerKey(
 }

 std::optional<AllReduceCombiner::GroupKey> SynchronousCombinerKey(
-    const absl::flat_hash_set<HloInstruction*>& sync_collectives,
     const HloInstruction* instruction, const HloDomainMap& domain_map) {
-  if (!sync_collectives.contains(instruction)) {
+  if (!IsCombinableSyncCollective(*instruction)) {
     return std::nullopt;
   }
   return AllReduceCombiner::CombineKey(instruction, domain_map);
@@ -73,21 +72,11 @@ absl::StatusOr<bool> GpuAllReduceCombiner::Run(
   bool changed = false;

   // Combine as much as possible for synchronous collectives.
-  absl::flat_hash_set<HloInstruction*> sync_collectives;
-  if (module->config()
-          .debug_options()
-          .xla_gpu_experimental_enable_sync_collective_combining()) {
-    TF_ASSIGN_OR_RETURN(
-        sync_collectives,
-        SynchronousCollectives(*module, pointer_size_, device_info_));
-  }
-  if (!sync_collectives.empty()) {
+  if (ContainsCombinableSyncCollective(*module)) {
     combine_threshold_in_bytes_ = MaxAvailableMemory(*module, device_info_);
     TF_ASSIGN_OR_RETURN(
         bool combined,
-        RunWithKeyCombiner(
-            module, execution_threads,
-            absl::bind_front(SynchronousCombinerKey, sync_collectives)));
+        RunWithKeyCombiner(module, execution_threads, SynchronousCombinerKey));
     changed |= combined;
   }

third_party/xla/xla/service/gpu/transforms/collectives/all_reduce_combiner_test.cc

Lines changed: 9 additions & 17 deletions
@@ -359,8 +359,10 @@ TEST_F(GpuAllReduceCombinerTest, CombinesSynchronousCollectivesMaximally) {
   p1 = f16[10000000]{0} parameter(1)

   // 20MB combinable all-reduce collectives. Default combiner threshold is 30MB.
-  ar0 = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add
-  ar1 = f16[10000000]{0} all-reduce(p1), replica_groups={}, to_apply=add
+  ar0 = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add,
+    frontend_attributes={sync_collective="true"}
+  ar1 = f16[10000000]{0} all-reduce(p1), replica_groups={}, to_apply=add,
+    frontend_attributes={sync_collective="true"}
   ROOT result = tuple(ar0, ar1)
 }
 )";
@@ -374,13 +376,7 @@ TEST_F(GpuAllReduceCombinerTest, CombinesSynchronousCollectivesMaximally) {
       /*combine_threshold_in_bytes=*/kDefaultAllReduceCombineThreshold,
       /*combine_threshold_count=*/256, /*pointer_size=*/4);

-  EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(false));
-
-  module->mutable_config()
-      .mutable_debug_options()
-      .set_xla_gpu_experimental_enable_sync_collective_combining(true);
   EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(true));
-
   Matcher<const HloInstruction*> combined_all_reduce =
       op::AllReduce(op::Parameter(0), op::Parameter(1));
   EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -400,17 +396,17 @@ TEST_F(GpuAllReduceCombinerTest,
 }

 ENTRY main {
-  p0 = f16[10000000]{0} parameter(0)
-  p1 = f16[10000000]{0} parameter(1)
+  p0 = f16[10000]{0} parameter(0)
+  p1 = f16[10000]{0} parameter(1)

   // This all-reduce must happen first, which is enforced by the control
   // dependency and must be respected.
-  lead_ar = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add
+  lead_ar = f16[10000]{0} all-reduce(p0), replica_groups={}, to_apply=add

   // These all-reduce have control dependencies and must not be combined.
-  ar0 = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add,
+  ar0 = f16[10000]{0} all-reduce(p0), replica_groups={}, to_apply=add,
     control-predecessors={lead_ar}
-  ar1 = f16[10000000]{0} all-reduce(p1), replica_groups={}, to_apply=add,
+  ar1 = f16[10000]{0} all-reduce(p1), replica_groups={}, to_apply=add,
     control-predecessors={lead_ar}
   ROOT result = tuple(ar0, ar1)
 }
@@ -424,10 +420,6 @@ TEST_F(GpuAllReduceCombinerTest,
       kDefaultAllReduceCombineThreshold,
       /*combine_threshold_in_bytes=*/kDefaultAllReduceCombineThreshold,
       /*combine_threshold_count=*/256, /*pointer_size=*/4);
-
-  module->mutable_config()
-      .mutable_debug_options()
-      .set_xla_gpu_experimental_enable_sync_collective_combining(true);
   EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(false));
 }