refactor collective comm e2e tests
ScXfjiang committed Sep 9, 2024
1 parent d247af5 commit 6f8c418
Showing 1 changed file with 33 additions and 152 deletions.
185 changes: 33 additions & 152 deletions xla/tests/collective_ops_e2e_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/strings/str_replace.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
@@ -54,6 +55,21 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) {

class CollectiveOpsTestE2E : public HloTestBase {
public:
CollectiveOpsTestE2E() {
replacements_[kF8E4M3DatatypePlaceholder] =
#if GOOGLE_CUDA
"f8e4m3fn";
#else
"f8e4m3fnuz";
#endif
replacements_[kF8E5M2DatatypePlaceholder] =
#if GOOGLE_CUDA
"f8e5m2";
#else
"f8e5m2fnuz";
#endif
}

bool IsCuda() {
return std::holds_alternative<se::CudaComputeCapability>(Capability());
}
@@ -108,6 +124,13 @@ class CollectiveOpsTestE2E : public HloTestBase {
/*argument_provider*/ [](int64_t, int64_t) { return nullptr; },
num_replicas, /*run_hlo_passes=*/false, &device_assignment);
}

protected:
absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;

private:
static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
};
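
The replacements_ map above lets one HLO template serve both builds: on CUDA the placeholders expand to the OCP fp8 types (f8e4m3fn, f8e5m2), on ROCm to the Nanoo variants (f8e4m3fnuz, f8e5m2fnuz). A minimal sketch of the expansion step as the tests below use it (the template string here is illustrative, not taken from the file):

// Expand FP8 placeholders into the platform-appropriate type names.
absl::string_view hlo_template = R"(
  x = <<F8E4M3>>[16,16]{1,0} parameter(0)
)";
std::string hlo_text = absl::StrReplaceAll(hlo_template, replacements_);
// CUDA build:  "x = f8e4m3fn[16,16]{1,0} parameter(0)"
// ROCm build:  "x = f8e4m3fnuz[16,16]{1,0} parameter(0)"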

// E2E tests for collective ops. These will generally verify some HLO transform
@@ -811,11 +834,11 @@ ENTRY main.12 {
TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
WindowedEinsumE2EAllGatherAndReduceScatterF8) {
absl::string_view kModuleReplicatedStr = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<<F8E4M3>>[2,16,48]{2,1,0}, <<F8E4M3>>[48,192]{1,0}, <<F8E4M3>>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
ENTRY main.12 {
Arg_0.1 = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
Arg_1.2 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
Arg_0.1 = <<F8E4M3>>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
Arg_1.2 = <<F8E4M3>>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
Arg_2.3 = bf16[] parameter(3)
Arg_3.4 = bf16[] parameter(4)
broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={}
@@ -834,12 +857,12 @@ ENTRY main.12 {
constant.1 = bf16[] constant(448.)
broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={}
clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4)
convert.2 = f8e4m3fn[2,16,192]{2,1,0} convert(clamp)
convert.2 = <<F8E4M3>>[2,16,192]{2,1,0} convert(clamp)
Arg_5.6 = bf16[] parameter(6)
broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={}
convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2)
multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5)
Arg_6.7 = f8e4m3fn[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
Arg_6.7 = <<F8E4M3>>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
Arg_7.8 = bf16[] parameter(7)
broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={}
convert.4 = bf16[192,48]{1,0} convert(Arg_6.7)
@@ -852,8 +875,9 @@ ENTRY main.12 {

// Disable the dot merger pass which can prevent the creation of FP8 GEMM
// Custom Calls.
CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
/*disable_dot_merger=*/true);
CollectiveOpsCompareWindowedNonWindowed(
absl::StrReplaceAll(kModuleReplicatedStr, replacements_),
/*disable_dot_merger=*/true);
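
The disable_dot_merger flag is handled by a helper outside this hunk; presumably it disables the pass by name in the module's DebugOptions, along these lines (a sketch, assuming the pass is registered as "dot-merger"):

// Sketch (assumption, not shown in this diff): disabling the dot merger pass.
DebugOptions opts = GetDebugOptionsForTest();
opts.add_xla_disable_hlo_passes("dot-merger");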

// Verify the creation of FP8 GEMM Custom Calls on Hopper and newer
// architectures.
@@ -866,54 +890,6 @@ ENTRY main.12 {
CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
}
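
CollectiveOpsVerifyF8Matmul is defined outside this hunk; the deleted Nanoo pipeliner test further down performs the same check inline, so the helper presumably reduces to something like:

// Sketch, mirroring the deleted test body below: find the custom call the
// GEMM rewriter emitted and check that it targets the FP8 cuBLASLt matmul.
HloInstruction* gemm_op =
    FindInstruction(&executable->module(), HloOpcode::kCustomCall);
EXPECT_THAT(gemm_op, NotNull());
EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8");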

// TODO: Refactor the test to reduce the duplicate code for OCP fp8 and Nanoo fp8
TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
WindowedEinsumE2EAllGatherAndReduceScatterF8Nanoo) {
absl::string_view kModuleReplicatedStr = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fnuz[2,16,48]{2,1,0}, f8e4m3fnuz[48,192]{1,0}, f8e4m3fnuz[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
ENTRY main.12 {
Arg_0.1 = f8e4m3fnuz[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
Arg_1.2 = f8e4m3fnuz[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
Arg_2.3 = bf16[] parameter(3)
Arg_3.4 = bf16[] parameter(4)
broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={}
broadcast.1 = bf16[48,192]{1,0} broadcast(Arg_3.4), dimensions={}
convert = bf16[2,16,48]{2,1,0} convert(Arg_0.1)
convert.1 = bf16[48,192]{1,0} convert(Arg_1.2)
multiply = bf16[2,16,48]{2,1,0} multiply(broadcast, convert)
multiply.1 = bf16[48,192]{1,0} multiply(broadcast.1, convert.1)
dot.5 = bf16[2,16,192]{2,1,0} dot(multiply, multiply.1), lhs_contracting_dims={2}, rhs_contracting_dims={0}
custom-call.7 = bf16[2,16,192]{2,1,0} custom-call(dot.5), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]}
Arg_4.5 = bf16[] parameter(5)
broadcast.2 = bf16[2,16,192]{2,1,0} broadcast(Arg_4.5), dimensions={}
divide = bf16[2,16,192]{2,1,0} divide(custom-call.7, broadcast.2)
constant = bf16[] constant(-448.)
broadcast.3 = bf16[2,16,192]{2,1,0} broadcast(constant), dimensions={}
constant.1 = bf16[] constant(448.)
broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={}
clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4)
convert.2 = f8e4m3fnuz[2,16,192]{2,1,0} convert(clamp)
Arg_5.6 = bf16[] parameter(6)
broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={}
convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2)
multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5)
Arg_6.7 = f8e4m3fnuz[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
Arg_7.8 = bf16[] parameter(7)
broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={}
convert.4 = bf16[192,48]{1,0} convert(Arg_6.7)
multiply.3 = bf16[192,48]{1,0} multiply(convert.4, broadcast.6)
dot.6 = bf16[2,16,48]{2,1,0} dot(multiply.2, multiply.3), lhs_contracting_dims={2}, rhs_contracting_dims={0}
tuple.10 = (bf16[2,16,48]{2,1,0}) tuple(dot.6)
ROOT get-tuple-element.11 = bf16[2,16,48]{2,1,0} get-tuple-element(tuple.10), index=0, sharding={devices=[1,4,1]<=[4]}
} // main.12
)";

// Disable the dot merger pass which can prevent the creation of FP8 GEMM
// Custom Calls.
CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
/*disable_dot_merger=*/true);
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
WindowedEinsumE2EAllGatherMultiConsumerF8) {
absl::string_view kModuleReplicatedStr = R"(
@@ -1071,15 +1047,15 @@ while_body {
r = bf16[32,128] bitcast(dynamic-slice.k)
a = bf16[32,128] add(r, r), control-predecessors={constant.2559}
// A fp8 pattern of quant-dequant before the collective AG.
qa = f8e4m3fn[32,128] convert(a)
qa = <<F8E4M3>>[32,128] convert(a)
dqa = bf16[32,128] convert(qa)
a_scale = bf16[] get-tuple-element(param), index=3
a_scales = bf16[32,128] broadcast(a_scale), dimensions={}
dqa_unscaled = bf16[32,128] multiply(dqa, a_scales)
mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128}
qma = f8e4m3fn[128,128] convert(ma)
qma = <<F8E4M3>>[128,128] convert(ma)
dqma = bf16[128,128] convert(qma)
ma_scale = bf16[] get-tuple-element(param), index=4
ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={}
@@ -1112,101 +1088,6 @@ ENTRY entry {
CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
}
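
The options feeding this call are collapsed out of the hunk; the deleted Nanoo twin of this test (below) configures them explicitly, and the surviving test presumably keeps the same setup:

// Debug options as set in the deleted Nanoo variant of this test.
auto opts = GetDebugOptionsForTest();
opts.set_xla_gpu_run_post_layout_collective_pipeliner(true);
opts.set_xla_gpu_enable_pipelined_collectives(true);
opts.set_xla_gpu_enable_triton_gemm(false);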

// TODO: Refactor the test to reduce the duplicate code for OCP fp8 and Nanoo fp8
TEST_F(CollectiveOpsTestE2E, PostLayoutCollectivePipelinerNanoo) {
// We need fp8 support to test the post-layout collective pipeliner. This will
// preserve the desired fp8 patterns and so the gemm rewriter can correctly
// recognize them and rewrite to custom fp8 gemm calls.
if (!HasFp8Support()) {
GTEST_SKIP() << "Test requires a post-Ada GPU.";
}

absl::string_view kModuleReplicatedStr = R"(
HloModule module, entry_computation_layout={(bf16[384,128], bf16[96,128], bf16[], bf16[])->bf16[384,128]}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[384,128] get-tuple-element(param), index=1
get-tuple-element.k = bf16[96,128] get-tuple-element(param), index=2
constant.2561 = s32[] constant(0)
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
constant.2560 = s32[] constant(-1)
add.231 = s32[] add(subtract.139, constant.2560)
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.k = bf16[32,128] dynamic-slice(get-tuple-element.k, select.1348, constant.2561), dynamic_slice_sizes={32,128}
r = bf16[32,128] bitcast(dynamic-slice.k)
a = bf16[32,128] add(r, r), control-predecessors={constant.2559}
// A fp8 pattern of quant-dequant before the collective AG.
qa = f8e4m3fnuz[32,128] convert(a)
dqa = bf16[32,128] convert(qa)
a_scale = bf16[] get-tuple-element(param), index=3
a_scales = bf16[32,128] broadcast(a_scale), dimensions={}
dqa_unscaled = bf16[32,128] multiply(dqa, a_scales)
mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128}
qma = f8e4m3fnuz[128,128] convert(ma)
dqma = bf16[128,128] convert(qma)
ma_scale = bf16[] get-tuple-element(param), index=4
ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={}
dqma_unscaled = bf16[128,128] multiply(dqma, ma_scales)
mc = bf16[128,128] dot(dqma_unscaled, mb), lhs_contracting_dims={1}, rhs_contracting_dims={0}
dynamic-update-slice.35 = bf16[384,128] dynamic-update-slice(get-tuple-element.395, mc, select.1348, constant.2561)
ROOT tuple = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k, a_scale, ma_scale), control-predecessors={a}
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[384,128] parameter(0)
p1 = bf16[96,128] parameter(1)
s0 = bf16[] parameter(2)
s1 = bf16[] parameter(3)
tuple = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) tuple(c0, p0, p1, s0, s1)
while = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[384,128] get-tuple-element(while), index=1
}
)";

const int64_t kNumReplicas = 1;
const int64_t kNumPartitions = 4;

HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
auto opts = GetDebugOptionsForTest();
opts.set_xla_gpu_run_post_layout_collective_pipeliner(true);
opts.set_xla_gpu_enable_pipelined_collectives(true);
opts.set_xla_gpu_enable_triton_gemm(false);
config.set_debug_options(opts);
config.set_num_partitions(kNumPartitions);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config));

TF_ASSERT_OK_AND_ASSIGN(auto executable,
CreateExecutable(std::move(module),
/*run_hlo_passes=*/true));
EXPECT_TRUE(executable->has_module());
HloInstruction* gemm_op =
FindInstruction(&executable->module(), HloOpcode::kCustomCall);
EXPECT_THAT(gemm_op, NotNull());
EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8");
}

TEST_F(CollectiveOpsTestE2E,
PostLayoutCollectivePipelinerShouldFlattenCallGraph) {
// The allgather in the loop has a nested while loop as its operand,