From 6f8c418b3052f7c531896bd5f8cbbc7a766ef7fc Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Mon, 9 Sep 2024 06:57:14 -0500 Subject: [PATCH] refactor collective comm e2e tests --- xla/tests/collective_ops_e2e_test.cc | 185 +++++---------------------- 1 file changed, 33 insertions(+), 152 deletions(-) diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc index b9bc60e6ba484..32af4ef5ee4c7 100644 --- a/xla/tests/collective_ops_e2e_test.cc +++ b/xla/tests/collective_ops_e2e_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/strings/str_replace.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -54,6 +55,21 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) { class CollectiveOpsTestE2E : public HloTestBase { public: + CollectiveOpsTestE2E() { + replacements_[kF8E4M3DatatypePlaceholder] = +#if GOOGLE_CUDA + "f8e4m3fn"; +#else + "f8e4m3fnuz"; +#endif + replacements_[kF8E5M2DatatypePlaceholder] = +#if GOOGLE_CUDA + "f8e5m2"; +#else + "f8e5m2fnuz"; +#endif + } + bool IsCuda() { return std::holds_alternative(Capability()); } @@ -108,6 +124,13 @@ class CollectiveOpsTestE2E : public HloTestBase { /*argument_provider*/ [](int64_t, int64_t) { return nullptr; }, num_replicas, /*run_hlo_passes=*/false, &device_assignment); } + + protected: + absl::flat_hash_map replacements_; + + private: + static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; }; // E2E tests for collective ops. 
These will generally verify some HLO transform @@ -811,11 +834,11 @@ ENTRY main.12 { TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, WindowedEinsumE2EAllGatherAndReduceScatterF8) { absl::string_view kModuleReplicatedStr = R"( -HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<>[2,16,48]{2,1,0}, <>[48,192]{1,0}, <>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 ENTRY main.12 { - Arg_0.1 = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} - Arg_1.2 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + Arg_0.1 = <>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + Arg_1.2 = <>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} Arg_2.3 = bf16[] parameter(3) Arg_3.4 = bf16[] parameter(4) broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={} @@ -834,12 +857,12 @@ ENTRY main.12 { constant.1 = bf16[] constant(448.) 
broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={} clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4) - convert.2 = f8e4m3fn[2,16,192]{2,1,0} convert(clamp) + convert.2 = <>[2,16,192]{2,1,0} convert(clamp) Arg_5.6 = bf16[] parameter(6) broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={} convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2) multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5) - Arg_6.7 = f8e4m3fn[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]} + Arg_6.7 = <>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]} Arg_7.8 = bf16[] parameter(7) broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={} convert.4 = bf16[192,48]{1,0} convert(Arg_6.7) @@ -852,8 +875,9 @@ ENTRY main.12 { // Disable the dot merger pass which can prevent the creation of FP8 GEMM // Custom Calls. - CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, - /*disable_dot_merger=*/true); + CollectiveOpsCompareWindowedNonWindowed( + absl::StrReplaceAll(kModuleReplicatedStr, replacements_), + /*disable_dot_merger=*/true); // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer // architectures. 
@@ -866,54 +890,6 @@ ENTRY main.12 { CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); } -// TODO: Refactor the test to reduce the duplicate code for OCP fp8 and Nanoo fp8 -TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, - WindowedEinsumE2EAllGatherAndReduceScatterF8Nanoo) { - absl::string_view kModuleReplicatedStr = R"( -HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fnuz[2,16,48]{2,1,0}, f8e4m3fnuz[48,192]{1,0}, f8e4m3fnuz[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 -ENTRY main.12 { - Arg_0.1 = f8e4m3fnuz[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} - Arg_1.2 = f8e4m3fnuz[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} - Arg_2.3 = bf16[] parameter(3) - Arg_3.4 = bf16[] parameter(4) - broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={} - broadcast.1 = bf16[48,192]{1,0} broadcast(Arg_3.4), dimensions={} - convert = bf16[2,16,48]{2,1,0} convert(Arg_0.1) - convert.1 = bf16[48,192]{1,0} convert(Arg_1.2) - multiply = bf16[2,16,48]{2,1,0} multiply(broadcast, convert) - multiply.1 = bf16[48,192]{1,0} multiply(broadcast.1, convert.1) - dot.5 = bf16[2,16,192]{2,1,0} dot(multiply, multiply.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} - custom-call.7 = bf16[2,16,192]{2,1,0} custom-call(dot.5), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]} - Arg_4.5 = bf16[] parameter(5) - broadcast.2 = bf16[2,16,192]{2,1,0} broadcast(Arg_4.5), dimensions={} - divide = bf16[2,16,192]{2,1,0} divide(custom-call.7, broadcast.2) - constant = bf16[] constant(-448.) - broadcast.3 = bf16[2,16,192]{2,1,0} broadcast(constant), dimensions={} - constant.1 = bf16[] constant(448.) 
- broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={} - clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4) - convert.2 = f8e4m3fnuz[2,16,192]{2,1,0} convert(clamp) - Arg_5.6 = bf16[] parameter(6) - broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={} - convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2) - multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5) - Arg_6.7 = f8e4m3fnuz[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]} - Arg_7.8 = bf16[] parameter(7) - broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={} - convert.4 = bf16[192,48]{1,0} convert(Arg_6.7) - multiply.3 = bf16[192,48]{1,0} multiply(convert.4, broadcast.6) - dot.6 = bf16[2,16,48]{2,1,0} dot(multiply.2, multiply.3), lhs_contracting_dims={2}, rhs_contracting_dims={0} - tuple.10 = (bf16[2,16,48]{2,1,0}) tuple(dot.6) - ROOT get-tuple-element.11 = bf16[2,16,48]{2,1,0} get-tuple-element(tuple.10), index=0, sharding={devices=[1,4,1]<=[4]} -} // main.12 -)"; - - // Disable the dot merger pass which can prevent the creation of FP8 GEMM - // Custom Calls. - CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, - /*disable_dot_merger=*/true); -} - TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, WindowedEinsumE2EAllGatherMultiConsumerF8) { absl::string_view kModuleReplicatedStr = R"( @@ -1071,7 +1047,7 @@ while_body { r = bf16[32,128] bitcast(dynamic-slice.k) a = bf16[32,128] add(r, r), control-predecessors={constant.2559} // A fp8 pattern of quant-dequant before the collective AG. 
- qa = f8e4m3fn[32,128] convert(a) + qa = <>[32,128] convert(a) dqa = bf16[32,128] convert(qa) a_scale = bf16[] get-tuple-element(param), index=3 a_scales = bf16[32,128] broadcast(a_scale), dimensions={} @@ -1079,7 +1055,7 @@ while_body { mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}} ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128} - qma = f8e4m3fn[128,128] convert(ma) + qma = <>[128,128] convert(ma) dqma = bf16[128,128] convert(qma) ma_scale = bf16[] get-tuple-element(param), index=4 ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={} @@ -1112,101 +1088,6 @@ ENTRY entry { CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); } -// TODO: Refactor the test to reduce the duplicate code for OCP fp8 and Nanoo fp8 -TEST_F(CollectiveOpsTestE2E, PostLayoutCollectivePipelinerNanoo) { - // We need fp8 support to test the post-layout collective pipeliner. This will - // preserve the desired fp8 patterns and so the gemm rewriter can correctly - // recognize them and rewrite to custom fp8 gemm calls. 
- if (!HasFp8Support()) { - GTEST_SKIP() << "Test requires a post-Ada GPU."; - } - - absl::string_view kModuleReplicatedStr = R"( -HloModule module, entry_computation_layout={(bf16[384,128], bf16[96,128], bf16[], bf16[])->bf16[384,128]}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 -add { - lhs = bf16[] parameter(0) - rhs = bf16[] parameter(1) - ROOT add = bf16[] add(lhs, rhs) -} -while_cond { - param = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) parameter(0) - gte = s32[] get-tuple-element(param), index=0 - constant.1 = s32[] constant(3) - ROOT cmp = pred[] compare(gte, constant.1), direction=LT -} -while_body { - param = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) parameter(0) - get-tuple-element.394 = s32[] get-tuple-element(param), index=0 - get-tuple-element.395 = bf16[384,128] get-tuple-element(param), index=1 - get-tuple-element.k = bf16[96,128] get-tuple-element(param), index=2 - constant.2561 = s32[] constant(0) - constant.2557 = s32[] constant(1) - add.230 = s32[] add(get-tuple-element.394, constant.2557) - constant.2559 = s32[] constant(3) - subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) - constant.2560 = s32[] constant(-1) - add.231 = s32[] add(subtract.139, constant.2560) - compare.747 = pred[] compare(add.231, constant.2561), direction=LT - constant.2562 = s32[] constant(2) - add.232 = s32[] add(subtract.139, constant.2562) - select.1348 = s32[] select(compare.747, add.232, add.231) - dynamic-slice.k = bf16[32,128] dynamic-slice(get-tuple-element.k, select.1348, constant.2561), dynamic_slice_sizes={32,128} - r = bf16[32,128] bitcast(dynamic-slice.k) - a = bf16[32,128] add(r, r), control-predecessors={constant.2559} - // A fp8 pattern of quant-dequant before the collective AG. 
- qa = f8e4m3fnuz[32,128] convert(a) - dqa = bf16[32,128] convert(qa) - a_scale = bf16[] get-tuple-element(param), index=3 - a_scales = bf16[32,128] broadcast(a_scale), dimensions={} - dqa_unscaled = bf16[32,128] multiply(dqa, a_scales) - mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}} - ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128} - - qma = f8e4m3fnuz[128,128] convert(ma) - dqma = bf16[128,128] convert(qma) - ma_scale = bf16[] get-tuple-element(param), index=4 - ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={} - dqma_unscaled = bf16[128,128] multiply(dqma, ma_scales) - mc = bf16[128,128] dot(dqma_unscaled, mb), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dynamic-update-slice.35 = bf16[384,128] dynamic-update-slice(get-tuple-element.395, mc, select.1348, constant.2561) - ROOT tuple = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k, a_scale, ma_scale), control-predecessors={a} -} -ENTRY entry { - c0 = s32[] constant(0) - p0 = bf16[384,128] parameter(0) - p1 = bf16[96,128] parameter(1) - s0 = bf16[] parameter(2) - s1 = bf16[] parameter(3) - tuple = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) tuple(c0, p0, p1, s0, s1) - while = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) while(tuple), condition=while_cond, body=while_body - ROOT gte1 = bf16[384,128] get-tuple-element(while), index=1 -} -)"; - - const int64_t kNumReplicas = 1; - const int64_t kNumPartitions = 4; - - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - auto opts = GetDebugOptionsForTest(); - opts.set_xla_gpu_run_post_layout_collective_pipeliner(true); - opts.set_xla_gpu_enable_pipelined_collectives(true); - opts.set_xla_gpu_enable_triton_gemm(false); - config.set_debug_options(opts); - 
config.set_num_partitions(kNumPartitions); - TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); - - TF_ASSERT_OK_AND_ASSIGN(auto executable, - CreateExecutable(std::move(module), - /*run_hlo_passes=*/true)); - EXPECT_TRUE(executable->has_module()); - HloInstruction* gemm_op = - FindInstruction(&executable->module(), HloOpcode::kCustomCall); - EXPECT_THAT(gemm_op, NotNull()); - EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); -} - TEST_F(CollectiveOpsTestE2E, PostLayoutCollectivePipelinerShouldFlattenCallGraph) { // The allgather in the loop has a nested while loop as its operand,