Skip to content

Commit

Permalink
Merge pull request #45 from ROCm/rocm-jaxlib-v0.4.30-qa_collective_fp8
Browse files Browse the repository at this point in the history
Add NANOO FP8 support for collaborative communication unit tests
  • Loading branch information
ScXfjiang authored Sep 24, 2024
2 parents ed82401 + 8353dcd commit a42d9cd
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 76 deletions.
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime/nccl_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ static absl::StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType dtype,
case S8:
case F8E5M2:
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
return ncclInt8;
case PRED:
case U8:
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime/nccl_collective_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type,
// they involve actual computation and not just data movement.
case F8E5M2:
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
return !IsReductionCollective(reduction_op);
default:
return false;
Expand Down
2 changes: 2 additions & 0 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2210,6 +2210,8 @@ xla_test(
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_matchers",
"//xla/service/gpu:backend_configs_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
],
)

Expand Down
176 changes: 105 additions & 71 deletions xla/tests/collective_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1722,77 +1722,6 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) {
}
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e4m3fn[1,2] constant({{1,2}})
allgather = f8e4m3fn[2, 2] all-gather(a0), dimensions={0}
p = f8e4m3fn[4] reshape(allgather)
ROOT out = f32[4] convert(p)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
}
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e4m3fn[2] constant({1,2})
a2a = f8e4m3fn[2] all-to-all(a0), dimensions={0}
ROOT out = f32[2] convert(a2a)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e5m2[2] constant({1,2})
a1 = f8e5m2[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
ROOT out = f32[2] convert(a1)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) {
const char* const kModuleStr = R"(
HloModule test
Expand Down Expand Up @@ -2230,5 +2159,110 @@ body {
results[1]));
}

class Fp8CollectiveOpsTest : public CollectiveOpsTest {
public:
Fp8CollectiveOpsTest() {
replacements_[kF8E4M3DatatypePlaceholder] =
IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
replacements_[kF8E5M2DatatypePlaceholder] =
IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
}

protected:
bool IsCuda() {
return std::holds_alternative<se::CudaComputeCapability>(Capability());
}

const se::GpuComputeCapability& Capability() {
return backend()
.default_stream_executor()
->GetDeviceDescription()
.gpu_compute_capability();
}

absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;

private:
static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
};

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E4M3>>[1,2] constant({{1,2}})
allgather = <<F8E4M3>>[2, 2] all-gather(a0), dimensions={0}
p = <<F8E4M3>>[4] reshape(allgather)
ROOT out = f32[4] convert(p)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
}
}

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E4M3>>[2] constant({1,2})
a2a = <<F8E4M3>>[2] all-to-all(a0), dimensions={0}
ROOT out = f32[2] convert(a2a)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
}

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E5M2>>[2] constant({1,2})
a1 = <<F8E5M2>>[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
ROOT out = f32[2] convert(a1)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
}

} // namespace
} // namespace xla
37 changes: 32 additions & 5 deletions xla/tests/collective_ops_test_e2e.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
Expand Down Expand Up @@ -47,6 +49,24 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) {

class CollectiveOpsTestE2E : public HloTestBase {
public:
CollectiveOpsTestE2E() {
replacements_[kF8E4M3DatatypePlaceholder] =
IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
replacements_[kF8E5M2DatatypePlaceholder] =
IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
}

bool IsCuda() {
return std::holds_alternative<se::CudaComputeCapability>(Capability());
}

const se::GpuComputeCapability& Capability() {
return backend()
.default_stream_executor()
->GetDeviceDescription()
.gpu_compute_capability();
}

absl::StatusOr<std::vector<Literal>> ExecuteReplicated(Executable* executable,
int64_t num_replicas) {
DeviceAssignment device_assignment = MakeDeviceAssn(num_replicas);
Expand All @@ -56,6 +76,13 @@ class CollectiveOpsTestE2E : public HloTestBase {
/*argument_provider*/ [](int64_t, int64_t) { return nullptr; },
num_replicas, /*run_hlo_passes=*/false, &device_assignment);
}

protected:
absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;

private:
static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
};

// E2E tests for collective ops. These will generally verify some HLO transform
Expand Down Expand Up @@ -740,11 +767,11 @@ ENTRY main.12 {
TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
WindowedEinsumE2EAllGatherAndReduceScatterF8) {
absl::string_view kModuleReplicatedStr = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<<F8E4M3>>[2,16,48]{2,1,0}, <<F8E4M3>>[48,192]{1,0}, <<F8E4M3>>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
ENTRY main.12 {
Arg_0.1 = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
Arg_1.2 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
Arg_0.1 = <<F8E4M3>>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
Arg_1.2 = <<F8E4M3>>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
Arg_2.3 = bf16[] parameter(3)
Arg_3.4 = bf16[] parameter(4)
broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={}
Expand All @@ -763,12 +790,12 @@ ENTRY main.12 {
constant.1 = bf16[] constant(448.)
broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={}
clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4)
convert.2 = f8e4m3fn[2,16,192]{2,1,0} convert(clamp)
convert.2 = <<F8E4M3>>[2,16,192]{2,1,0} convert(clamp)
Arg_5.6 = bf16[] parameter(6)
broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={}
convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2)
multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5)
Arg_6.7 = f8e4m3fn[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
Arg_6.7 = <<F8E4M3>>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
Arg_7.8 = bf16[] parameter(7)
broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={}
convert.4 = bf16[192,48]{1,0} convert(Arg_6.7)
Expand Down

0 comments on commit a42d9cd

Please sign in to comment.