PR tensorflow#16938: Add NANOO FP8 support for collective communication unit tests

Imported from GitHub PR openxla/xla#16938

This PR adds support for the NANOO FP8 data format in the collective communication unit tests.
- For context on the OCP FP8 and NANOO FP8 formats, please refer to this comment:
google/flax#3993 (comment)
- The unit tests in this PR mirror the GEMM unit tests introduced in the following PR, so that they can handle both the OCP and NANOO FP8 formats (see the sketch after this list):
openxla/xla#10488
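
For readers who want the distinction without following the links: both families are 8-bit floats with the same exponent/mantissa bit budgets, but the NANOO (fnuz) variants reassign the infinity and negative-zero encodings, shifting the exponent bias. A minimal sketch, assuming only the publicly documented dtype names and limits (PickFp8Flavor is an illustrative helper, not XLA API):

#include <cstdio>

// Hedged sketch, not part of this PR: the four FP8 dtypes the tests use and
// how a platform check selects between them. Max-finite values are from the
// public OCP and NANOO FP8 descriptions; the NANOO "fnuz" formats drop
// infinities and keep a single NaN encoding.
struct Fp8Flavor {
  const char* e4m3;  // 4 exponent bits, 3 mantissa bits
  const char* e5m2;  // 5 exponent bits, 2 mantissa bits
};

Fp8Flavor PickFp8Flavor(bool is_cuda) {
  if (is_cuda) {
    return {"f8e4m3fn", "f8e5m2"};  // OCP: e4m3 max finite 448
  }
  return {"f8e4m3fnuz", "f8e5m2fnuz"};  // NANOO: e4m3 max finite 240
}

int main() {
  const Fp8Flavor f = PickFp8Flavor(/*is_cuda=*/false);
  std::printf("%s %s\n", f.e4m3, f.e5m2);  // prints: f8e4m3fnuz f8e5m2fnuz
}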
Copybara import of the project:

--
0fc74ccae6cfcaf4e8627ea338ee03783af0626b by Wen Chen <Wen.Chen@amd.com>:

[AMD] Added NCCL support for fp8e4m3fnuz and fp8e5m2fnuz.

--
d247af5cd33fe42698bb55ef1c18f32df8a02a21 by scxfjiang <sc.xfjiang@gmail.com>:

refactor tests for collective comm ops

--
6f8c418b3052f7c531896bd5f8cbbc7a766ef7fc by scxfjiang <sc.xfjiang@gmail.com>:

refactor collective comm e2e tests

--
8ecb6ecf08a1536c5b3f8ba87e0e9f8813b1b359 by scxfjiang <sc.xfjiang@gmail.com>:

update: replace str

--
338d3af2ca1a32302fdfe9d7abee335d24539ee9 by scxfjiang <sc.xfjiang@gmail.com>:

get rid of macros

Merging this change closes tensorflow#16938

PiperOrigin-RevId: 676615012
ScXfjiang committed Sep 20, 2024
1 parent 64ba48a commit e31d7b6
Showing 3 changed files with 109 additions and 71 deletions.
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/nccl_api.cc
@@ -112,6 +112,8 @@ static absl::StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType dtype,
case S8:
case F8E5M2:
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
return ncclInt8;
case PRED:
case U8:
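A note on why the new cases fall through to ncclInt8: NCCL exposes no FP8 data type, and every FP8 flavor is exactly one byte, so data-movement collectives can ship the payload as opaque int8. A minimal sketch of that idea (Fp8Kind and ToNcclTransportType are illustrative stand-ins, not the real XLA/NCCL symbols):

// Hedged sketch: all four 1-byte FP8 flavors share one NCCL transport type.
// kNcclInt8 stands in for NCCL's ncclInt8 enumerator.
enum class Fp8Kind { kF8E5M2, kF8E4M3FN, kF8E5M2FNUZ, kF8E4M3FNUZ };

constexpr int kNcclInt8 = 3;  // placeholder value, not the real enum

constexpr int ToNcclTransportType(Fp8Kind /*kind*/) {
  // Pure data movement never reinterprets bytes, so the FP8 flavor does not
  // matter; only the element width (1 byte) does.
  return kNcclInt8;
}

static_assert(ToNcclTransportType(Fp8Kind::kF8E4M3FNUZ) ==
              ToNcclTransportType(Fp8Kind::kF8E4M3FN));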
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc
@@ -93,6 +93,8 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type,
// they involve actual computation and not just data movement.
case F8E5M2:
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
return !IsReductionCollective(reduction_op);
default:
return false;
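The !IsReductionCollective(reduction_op) guard is why the tests below cover all-gather, all-to-all, and collective-permute but not all-reduce: a reduction would require NCCL to perform FP8 arithmetic, which it does not support, while data movement only copies bytes. A hedged sketch of the resulting support matrix (CollectiveKind and the helpers are illustrative, not XLA API):

// Hedged sketch of the FP8 support rule in IsTypeSupportedByNccl.
enum class CollectiveKind {
  kAllGather, kAllToAll, kCollectivePermute,  // data movement only
  kAllReduce, kReduceScatter                  // involve arithmetic
};

constexpr bool IsReduction(CollectiveKind k) {
  return k == CollectiveKind::kAllReduce ||
         k == CollectiveKind::kReduceScatter;
}

// FP8 (either family) is NCCL-eligible only when no computation is involved.
constexpr bool Fp8SupportedByNccl(CollectiveKind k) {
  return !IsReduction(k);
}

static_assert(Fp8SupportedByNccl(CollectiveKind::kCollectivePermute));
static_assert(!Fp8SupportedByNccl(CollectiveKind::kAllReduce));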
176 changes: 105 additions & 71 deletions third_party/xla/xla/tests/collective_ops_test.cc
@@ -1702,77 +1702,6 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) {
}
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e4m3fn[1,2] constant({{1,2}})
allgather = f8e4m3fn[2, 2] all-gather(a0), dimensions={0}
p = f8e4m3fn[4] reshape(allgather)
ROOT out = f32[4] convert(p)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
}
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e4m3fn[2] constant({1,2})
a2a = f8e4m3fn[2] all-to-all(a0), dimensions={0}
ROOT out = f32[2] convert(a2a)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e5m2[2] constant({1,2})
a1 = f8e5m2[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
ROOT out = f32[2] convert(a1)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) {
const char* const kModuleStr = R"(
HloModule test
@@ -2174,5 +2103,110 @@ body {
results[1]));
}

class Fp8CollectiveOpsTest : public CollectiveOpsTest {
public:
Fp8CollectiveOpsTest() {
replacements_[kF8E4M3DatatypePlaceholder] =
IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
replacements_[kF8E5M2DatatypePlaceholder] =
IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
}

protected:
bool IsCuda() {
return std::holds_alternative<se::CudaComputeCapability>(Capability());
}

const se::GpuComputeCapability& Capability() {
return backend()
.default_stream_executor()
->GetDeviceDescription()
.gpu_compute_capability();
}

absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;

private:
static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
};

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E4M3>>[1,2] constant({{1,2}})
allgather = <<F8E4M3>>[2, 2] all-gather(a0), dimensions={0}
p = <<F8E4M3>>[4] reshape(allgather)
ROOT out = f32[4] convert(p)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
}
}

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E4M3>>[2] constant({1,2})
a2a = <<F8E4M3>>[2] all-to-all(a0), dimensions={0}
ROOT out = f32[2] convert(a2a)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
}

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E5M2>>[2] constant({1,2})
a1 = <<F8E5M2>>[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
ROOT out = f32[2] convert(a1)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
}

} // namespace
} // namespace xla
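
How the new fixture works end to end: each HLO module carries <<F8E4M3>> / <<F8E5M2>> placeholders, and the fixture's replacements_ map rewrites them to the platform's dtype with absl::StrReplaceAll before the module is parsed. A standalone sketch of that substitution step (the HLO fragment is illustrative, not one of the real test modules):

#include <iostream>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"

int main() {
  // Hedged sketch of the placeholder pattern used by Fp8CollectiveOpsTest.
  const bool is_cuda = false;  // pretend we are running on an AMD GPU
  absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
  replacements["<<F8E4M3>>"] = is_cuda ? "f8e4m3fn" : "f8e4m3fnuz";
  replacements["<<F8E5M2>>"] = is_cuda ? "f8e5m2" : "f8e5m2fnuz";

  // Illustrative HLO fragment.
  const std::string hlo = R"(
    a0 = <<F8E4M3>>[2] constant({1,2})
    out = f32[2] convert(a0)
  )";

  // One StrReplaceAll call specializes the module text for the platform.
  std::cout << absl::StrReplaceAll(hlo, replacements);
}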
