
Commit b085142

wenscarl authored and tensorflower-gardener committed

PR tensorflow#6599: Fp8 Fast Accumulation support for cublasLt

Imported from GitHub PR openxla/xla#6599

FP8 cublasLt matmuls use fast accumulation when the precision of both operands is DEFAULT; otherwise they fall back to high-precision accumulation. Issue: openxla/xla#6168. This PR is closely related to Flax PR google/flax#3416.

Copybara import of the project:

- a4140da8ca08cd2d4796a7b8f032827867a361bc by shuw <shuw@nvidia.com>: Add FP8 fast accumulation support for cublasLt.
- 96845683cc4b1e7b947bc919fbf97d8865abeac9 by shuw <shuw@nvidia.com>: Improve based on review #1
- e906d7620780d2cf1fe8433c933648dcb98dc61d by shuw <shuw@nvidia.com>: Improve based on review #2

Merging this change closes tensorflow#6599

PiperOrigin-RevId: 578948593

1 parent f2bed49, commit b085142
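
To make the mechanism concrete, here is a minimal standalone sketch of how the fast-accumulation knob this commit wires up is expressed against the raw cuBLASLt API. This is not the XLA wrapper changed in the diff below: error checking is omitted, and the FP32 compute/scale types are illustrative assumptions for an FP8 GEMM that accumulates in FP32.

#include <cstdint>
#include <cublasLt.h>

// Minimal sketch, not the XLA BlasLt wrapper: create a matmul descriptor that
// accumulates in FP32 and toggle cuBLASLt's fast-accumulation attribute.
void ConfigureFastAccum(cublasLtMatmulDesc_t* desc, bool enable_fast_accum) {
  cublasLtMatmulDescCreate(desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  // CUBLASLT_MATMUL_DESC_FAST_ACCUM takes an int8_t; a nonzero value selects
  // the faster, less precise accumulation mode for FP8 operands.
  const int8_t fast_accum = enable_fast_accum ? 1 : 0;
  cublasLtMatmulDescSetAttribute(*desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                 &fast_accum, sizeof(fast_accum));
  // Transposes, epilogue, and scaling pointers are set the same way before
  // the descriptor is handed to cublasLtMatmul().
}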

File tree

3 files changed: +44, -3 lines

third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc

Lines changed: 32 additions & 0 deletions
@@ -6940,6 +6940,38 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 )");
 }

+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) {
+#if CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+  const char* hlo_template = R"(
+    HloModule test
+
+    ENTRY test {
+      x = f8e4m3fn[1600,3200] parameter(0)
+      y = f8e4m3fn[3200,1600] parameter(1)
+      x_f32 = f32[1600,3200] convert(x)
+      y_f32 = f32[3200,1600] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[1600,3200] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[3200,1600] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[1600,3200] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[3200,1600] multiply(y_f32, y_scale_bcast)
+      ROOT out = f32[1600,1600] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={<<precision>>,<<precision>>}
+    }
+)";
+
+  absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+  replacements["<<precision>>"] = "default";
+  const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements);
+  EXPECT_TRUE(RunAndCompare(hlo_text_default, ErrorSpec{1e-3, 1e-3}));
+
+  replacements["<<precision>>"] = "highest";
+  const auto hlo_text_highest = absl::StrReplaceAll(hlo_template, replacements);
+  EXPECT_TRUE(RunAndCompare(hlo_text_highest, ErrorSpec{1e-4, 1e-4}));
+}
+
 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {
 #if CUDA_VERSION < 12000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";

third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc

Lines changed: 11 additions & 2 deletions
@@ -191,7 +191,8 @@ cudaDataType_t BlasLt::MatrixLayout::type() const {
 /*static*/ tsl::StatusOr<BlasLt::MatmulDesc> BlasLt::MatmulDesc::Create(
     blas::ComputationType compute_type, blas::DataType scale_type,
     blas::Transpose trans_a, blas::Transpose trans_b,
-    gpu::BlasLt::Epilogue epilogue, PointerMode pointer_mode) {
+    gpu::BlasLt::Epilogue epilogue, bool enable_fast_accum,
+    PointerMode pointer_mode) {
   VLOG(2) << "MatmulDesc::Create: compute_type: " << (int)compute_type
           << " scale:" << (int)scale_type << " trans a/b: " << (int)trans_a
           << "," << (int)trans_b << " epilogue:" << (int)epilogue
@@ -210,6 +211,8 @@ cudaDataType_t BlasLt::MatrixLayout::type() const {
                              AsCublasOperation(trans_b)));
   TF_ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue));
   TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi));
+  TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
+                             static_cast<int8_t>(enable_fast_accum)));
   return std::move(desc);
 }

@@ -315,11 +318,17 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg,
                                         cfg.compute_precision));
   }

+  // FP8 matmuls have a fast accumulation mode that is less precise than the
+  // default accumulation mode. Use the fast accumulation mode if the compute
+  // precision is DEFAULT.
+  bool enable_fast_accum = (xla::primitive_util::IsF8Type(lhs_layout.dtype) ||
+                            xla::primitive_util::IsF8Type(rhs_layout.dtype)) &&
+                           cfg.compute_precision == 0;
   TF_ASSIGN_OR_RETURN(
       auto op_desc,
       MatmulDesc::Create(*compute_type,
                          gpu::GetScaleType(output_dtype, *compute_type),
-                         trans_a, trans_b, epilogue));
+                         trans_a, trans_b, epilogue, enable_fast_accum));

   TF_ASSIGN_OR_RETURN(auto a_desc, MatrixLayout::Create(lhs_layout));
   TF_ASSIGN_OR_RETURN(auto b_desc, MatrixLayout::Create(rhs_layout));
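
In the GetMatmulPlan hunk above, the literal 0 in cfg.compute_precision == 0 corresponds to xla::PrecisionConfig::DEFAULT (DEFAULT = 0, HIGH = 1, HIGHEST = 2). Below is a hedged restatement of that predicate as a standalone helper; ShouldUseFp8FastAccum is a hypothetical name introduced here for illustration, not part of this change.

#include "xla/primitive_util.h"
#include "xla/xla_data.pb.h"

// Hypothetical helper, not from this commit: the same decision made in
// GetMatmulPlan, with the precision enum spelled out instead of the literal 0.
bool ShouldUseFp8FastAccum(xla::PrimitiveType lhs_dtype,
                           xla::PrimitiveType rhs_dtype,
                           int compute_precision) {
  const bool has_fp8_operand = xla::primitive_util::IsF8Type(lhs_dtype) ||
                               xla::primitive_util::IsF8Type(rhs_dtype);
  // Fast accumulation is requested only when an FP8 operand is present and
  // the operands asked for no more than DEFAULT precision.
  return has_fp8_operand &&
         compute_precision == xla::PrecisionConfig::DEFAULT;
}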

third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ class BlasLt : public gpu::BlasLt {
       blas::ComputationType compute_type, blas::DataType scale_type,
       blas::Transpose trans_a = blas::Transpose::kNoTranspose,
       blas::Transpose trans_b = blas::Transpose::kNoTranspose,
-      Epilogue epilogue = Epilogue::kDefault,
+      Epilogue epilogue = Epilogue::kDefault, bool enable_fast_accum = false,
       PointerMode pointer_mode = PointerMode::kHost);

   cublasComputeType_t compute_type() const;
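
Because the new flag defaults to false and sits before pointer_mode, callers that rely on the defaults keep the precise accumulation path. A hypothetical call-site sketch, not from this commit; the ComputationType/DataType values are illustrative assumptions.

#include "xla/stream_executor/cuda/cuda_blas_lt.h"

namespace se = stream_executor;

void Example() {
  // Existing callers that stop at the defaults are unchanged and keep the
  // precise accumulation mode.
  auto plain = se::cuda::BlasLt::MatmulDesc::Create(
      se::blas::ComputationType::kF32, se::blas::DataType::kFloat);

  // FP8-aware callers (e.g. GetMatmulPlan in the .cc hunk above) opt in to
  // fast accumulation explicitly.
  auto fast = se::cuda::BlasLt::MatmulDesc::Create(
      se::blas::ComputationType::kF32, se::blas::DataType::kFloat,
      se::blas::Transpose::kNoTranspose, se::blas::Transpose::kNoTranspose,
      se::cuda::BlasLt::Epilogue::kDefault, /*enable_fast_accum=*/true);
}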
