Skip to content

Commit

Permalink
Improve based on review #2
Browse files Browse the repository at this point in the history
  • Loading branch information
wenscarl committed Nov 1, 2023
1 parent 9684568 commit e906d76
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
33 changes: 33 additions & 0 deletions xla/service/gpu/tests/gemm_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6862,6 +6862,39 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
)");
}

TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) {
#if CUDA_VERSION < 12000
GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
#endif // CUDA_VERSION < 12000
const char* hlo_template = R"(
HloModule test
ENTRY test {
x = f8e4m3fn[1600,3200] parameter(0)
y = f8e4m3fn[3200,1600] parameter(1)
x_f32 = f32[1600,3200] convert(x)
y_f32 = f32[3200,1600] convert(y)
x_scale = f32[] parameter(2)
y_scale = f32[] parameter(3)
x_scale_bcast = f32[1600,3200] broadcast(x_scale), dimensions={}
y_scale_bcast = f32[3200,1600] broadcast(y_scale), dimensions={}
x_unscaled = f32[1600,3200] multiply(x_f32, x_scale_bcast)
y_unscaled = f32[3200,1600] multiply(y_f32, y_scale_bcast)
ROOT out = f32[1600,1600] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={<<precision>>,<<precision>>}
}
)";

absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
replacements["<<precision>>"] = "default";
const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements);
EXPECT_TRUE(RunAndCompare(hlo_text_default, ErrorSpec{1e-3, 1e-3}));
EXPECT_FALSE(RunAndCompare(hlo_text_default, ErrorSpec{1e-4, 1e-4}));

replacements["<<precision>>"] = "highest";
const auto hlo_text_highest = absl::StrReplaceAll(hlo_template, replacements);
EXPECT_TRUE(RunAndCompare(hlo_text_highest, ErrorSpec{1e-4, 1e-4}));
}

TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {
#if CUDA_VERSION < 12000
GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
Expand Down
6 changes: 3 additions & 3 deletions xla/stream_executor/cuda/cuda_blas_lt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg,
cfg.compute_precision));
}

// For FP8 matmuls, there are two options available: fast
// accumulation(PrecisionConfig.Precision.DEFAULT) and
// higher precision accumulation (PrecisionConfig.Precision.HIGHEST).
// FP8 matmuls have a fast accumulation mode that is less precise than the
// default accumulation mode. Use the fast accumulation mode if the compute
// precision is DEFAULT.
bool enable_fast_accum = (xla::primitive_util::IsF8Type(lhs_layout.dtype) ||
xla::primitive_util::IsF8Type(rhs_layout.dtype)) &&
cfg.compute_precision == 0;
Expand Down

0 comments on commit e906d76

Please sign in to comment.