Skip to content

Commit

Permalink
PR #16938: Add NANOO FP8 support for collaborative communication unit…
Browse files Browse the repository at this point in the history
… tests

Imported from GitHub PR #16938

This PR adds support for NANOO FP8 data format in the collaborative communication unit tests.
- For the context on OCP FP8 and NANOO FP8, please refer to this comment:
google/flax#3993 (comment)
- The unit tests in this PR are similar to GEMM unit test introduced in the following PR to be able to deal with both OCP and NANOO fp8 formats:
#10488
Copybara import of the project:

--
0fc74cc by Wen Chen <Wen.Chen@amd.com>:

[AMD] Added NCCL support for fp8e4m3fnuz and fp8e5m2fnuz.

--
d247af5 by scxfjiang <sc.xfjiang@gmail.com>:

refactor tests for collective comm ops

--
6f8c418 by scxfjiang <sc.xfjiang@gmail.com>:

rafactor collective comm e2e tests

--
8ecb6ec by scxfjiang <sc.xfjiang@gmail.com>:

update: replace str

--
338d3af by scxfjiang <sc.xfjiang@gmail.com>:

get rid of macros

Merging this change closes #16938

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16938 from ROCm:ci_dev_rccl_nanoo_fp8 338d3af
PiperOrigin-RevId: 675635116
  • Loading branch information
ScXfjiang authored and Google-ML-Automation committed Sep 18, 2024
1 parent 768b667 commit 00f8f62
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 85 deletions.
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime/nccl_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ static absl::StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType dtype,
case S8:
case F8E5M2:
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
return ncclInt8;
case PRED:
case U8:
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime/nccl_collective_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type,
// they involve actual computation and not just data movement.
case F8E5M2:
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
return !IsReductionCollective(reduction_op);
default:
return false;
Expand Down
40 changes: 29 additions & 11 deletions xla/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <variant>
#include <vector>

#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand Down Expand Up @@ -54,6 +55,13 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) {

class CollectiveOpsTestE2E : public HloTestBase {
public:
CollectiveOpsTestE2E() {
replacements_[kF8E4M3DatatypePlaceholder] =
IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
replacements_[kF8E5M2DatatypePlaceholder] =
IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
}

bool IsCuda() {
return std::holds_alternative<se::CudaComputeCapability>(Capability());
}
Expand Down Expand Up @@ -108,6 +116,13 @@ class CollectiveOpsTestE2E : public HloTestBase {
/*argument_provider*/ [](int64_t, int64_t) { return nullptr; },
num_replicas, /*run_hlo_passes=*/false, &device_assignment);
}

protected:
absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;

private:
static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
};

// E2E tests for collective ops. These will generally verify some HLO transform
Expand Down Expand Up @@ -811,11 +826,11 @@ ENTRY main.12 {
TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
WindowedEinsumE2EAllGatherAndReduceScatterF8) {
absl::string_view kModuleReplicatedStr = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<<F8E4M3>>[2,16,48]{2,1,0}, <<F8E4M3>>[48,192]{1,0}, <<F8E4M3>>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
ENTRY main.12 {
Arg_0.1 = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
Arg_1.2 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
Arg_0.1 = <<F8E4M3>>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
Arg_1.2 = <<F8E4M3>>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
Arg_2.3 = bf16[] parameter(3)
Arg_3.4 = bf16[] parameter(4)
broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={}
Expand All @@ -834,12 +849,12 @@ ENTRY main.12 {
constant.1 = bf16[] constant(448.)
broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={}
clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4)
convert.2 = f8e4m3fn[2,16,192]{2,1,0} convert(clamp)
convert.2 = <<F8E4M3>>[2,16,192]{2,1,0} convert(clamp)
Arg_5.6 = bf16[] parameter(6)
broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={}
convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2)
multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5)
Arg_6.7 = f8e4m3fn[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
Arg_6.7 = <<F8E4M3>>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
Arg_7.8 = bf16[] parameter(7)
broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={}
convert.4 = bf16[192,48]{1,0} convert(Arg_6.7)
Expand All @@ -852,8 +867,9 @@ ENTRY main.12 {

// Disable the dot merger pass which can prevent the creation of FP8 GEMM
// Custom Calls.
CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
/*disable_dot_merger=*/true);
CollectiveOpsCompareWindowedNonWindowed(
absl::StrReplaceAll(kModuleReplicatedStr, replacements_),
/*disable_dot_merger=*/true);

// Verify the creation of FP8 GEMM Custom Calls on Hopper and newer
// architectures.
Expand All @@ -863,7 +879,8 @@ ENTRY main.12 {
opts.set_xla_gpu_graph_min_graph_size(200);
opts.set_xla_gpu_enable_triton_gemm(false);
opts.add_xla_disable_hlo_passes("dot-merger");
CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
CollectiveOpsVerifyF8Matmul(
absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts);
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
Expand Down Expand Up @@ -1023,15 +1040,15 @@ while_body {
r = bf16[32,128] bitcast(dynamic-slice.k)
a = bf16[32,128] add(r, r), control-predecessors={constant.2559}
// A fp8 pattern of quant-dequant before the collective AG.
qa = f8e4m3fn[32,128] convert(a)
qa = <<F8E4M3>>[32,128] convert(a)
dqa = bf16[32,128] convert(qa)
a_scale = bf16[] get-tuple-element(param), index=3
a_scales = bf16[32,128] broadcast(a_scale), dimensions={}
dqa_unscaled = bf16[32,128] multiply(dqa, a_scales)
mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128}
qma = f8e4m3fn[128,128] convert(ma)
qma = <<F8E4M3>>[128,128] convert(ma)
dqma = bf16[128,128] convert(qma)
ma_scale = bf16[] get-tuple-element(param), index=4
ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={}
Expand Down Expand Up @@ -1061,7 +1078,8 @@ ENTRY entry {
opts.set_xla_gpu_run_post_layout_collective_pipeliner(true);
opts.set_xla_gpu_enable_pipelined_collectives(true);
opts.set_xla_gpu_enable_triton_gemm(false);
CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
CollectiveOpsVerifyF8Matmul(
absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts);
}

TEST_F(CollectiveOpsTestE2E,
Expand Down
179 changes: 105 additions & 74 deletions xla/tests/collective_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1753,80 +1753,6 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) {
}
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e4m3fn[1,2] constant({{1,2}})
allgather = f8e4m3fn[2, 2] all-gather(a0), dimensions={0}
p = f8e4m3fn[4] reshape(allgather)
ROOT out = f32[4] convert(p)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
}
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e4m3fn[2] constant({1,2})
a2a = f8e4m3fn[2] all-to-all(a0), dimensions={0}
ROOT out = f32[2] convert(a2a)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e5m2[2] constant({1,2})
a1 = f8e5m2[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
ROOT out = f32[2] convert(a1)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) {
const char* const kModuleStr = R"(
HloModule test
Expand Down Expand Up @@ -2273,5 +2199,110 @@ body {
results[1]));
}

class Fp8CollectiveOpsTest : public CollectiveOpsTest {
public:
Fp8CollectiveOpsTest() {
replacements_[kF8E4M3DatatypePlaceholder] =
IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
replacements_[kF8E5M2DatatypePlaceholder] =
IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
}

protected:
bool IsCuda() {
return std::holds_alternative<se::CudaComputeCapability>(Capability());
}

const se::GpuComputeCapability& Capability() {
return backend()
.default_stream_executor()
->GetDeviceDescription()
.gpu_compute_capability();
}

absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;

private:
static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
};

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E4M3>>[1,2] constant({{1,2}})
allgather = <<F8E4M3>>[2, 2] all-gather(a0), dimensions={0}
p = <<F8E4M3>>[4] reshape(allgather)
ROOT out = f32[4] convert(p)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
}
}

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E4M3>>[2] constant({1,2})
a2a = <<F8E4M3>>[2] all-to-all(a0), dimensions={0}
ROOT out = f32[2] convert(a2a)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
}

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E5M2>>[2] constant({1,2})
a1 = <<F8E5M2>>[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
ROOT out = f32[2] convert(a1)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
}

} // namespace
} // namespace xla

0 comments on commit 00f8f62

Please sign in to comment.