
Commit b231d21

sergachev authored and Google-ML-Automation committed
PR #33794: [GPU] Support int4 in cuDNN GEMM fusions.
Imported from GitHub PR #33794

📝 Summary of Changes
Support int4 in cuDNN GEMM fusions.

🎯 Justification
Accelerates some int4 GEMM fusions (under the flag xla_gpu_cudnn_gemm_fusion_level).

🚀 Kind of Contribution
⚡️ Performance Improvement

📊 Benchmark (for Performance Improvements)
> Please measure and include speedups for one of the public HLOs in `compiler/xla/tools/benchmarks/hlo/`.

These do not use int4.

🧪 Unit Tests: yes
🧪 Execution Tests: yes

Copybara import of the project:

-- e1b8dc7 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Support int4 in cuDNN GEMM fusions.

Merging this change closes #33794

FUTURE_COPYBARA_INTEGRATE_REVIEW=#33794 from openxla:cudnn_gemm_int4 e1b8dc7
PiperOrigin-RevId: 830894321
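For context on the gating flag mentioned above: cuDNN GEMM fusions are opted into per level through the xla_gpu_cudnn_gemm_fusion_level debug option. Below is a minimal sketch of setting it programmatically; the helper function and the concrete level value 2 are illustrative assumptions, not part of this commit.

#include "xla/xla.pb.h"  // xla::DebugOptions

// Hypothetical helper: builds DebugOptions with cuDNN GEMM fusions enabled.
xla::DebugOptions MakeDebugOptionsWithCudnnGemmFusions() {
  xla::DebugOptions opts;
  // Non-zero levels route eligible GEMM fusions to the cuDNN backend; the
  // level chosen here (2) is an assumption for illustration only.
  opts.set_xla_gpu_cudnn_gemm_fusion_level(2);
  return opts;
}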
1 parent 9da6117 · commit b231d21

File tree

4 files changed: +36 additions, −7 deletions


xla/backends/gpu/codegen/BUILD
Lines changed: 0 additions & 1 deletion

@@ -104,7 +104,6 @@ xla_test(
         "//xla/stream_executor:dnn",
         "//xla/stream_executor:stream_executor_h",
         "//xla/stream_executor:stream_executor_memory_allocator",
-        "//xla/tsl/lib/core:status_test_util",
         "@com_google_absl//absl/status:status_matchers",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",

xla/backends/gpu/codegen/cudnn_test.cc
Lines changed: 28 additions & 6 deletions

@@ -46,7 +46,6 @@ limitations under the License.
 #include "xla/stream_executor/dnn.h"
 #include "xla/stream_executor/stream_executor.h"
 #include "xla/stream_executor/stream_executor_memory_allocator.h"
-#include "xla/tsl/lib/core/status_test_util.h"
 #include "xla/xla.pb.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/env.h"
@@ -59,8 +58,6 @@ namespace xla {
 namespace gpu {
 namespace {
 
-using ::tsl::testing::IsOkAndHolds;
-
 class CuDnnFusionTest : public GpuCodegenTest {
  public:
   DebugOptions GetDebugOptionsForTest() const override {
@@ -80,12 +77,14 @@ class CuDnnFusionTest : public GpuCodegenTest {
     return get_cuda_cc().IsAtLeastAmpere() &&
            GetDnnVersionInfoOrDefault(executor).major_version() >= 9;
   }
-  bool IsAtLeastCuDnn91() {
+  bool IsAtLeastCuDnnVersion(int major, int minor) {
     se::StreamExecutor* executor = backend().default_stream_executor();
     const se::dnn::VersionInfo version = GetDnnVersionInfoOrDefault(executor);
-    return (version.major_version() == 9 && version.minor_version() >= 1) ||
-           version.major_version() > 9;
+    return (version.major_version() == major &&
+            version.minor_version() >= minor) ||
+           version.major_version() > major;
   }
+  bool IsAtLeastCuDnn91() { return IsAtLeastCuDnnVersion(9, 1); }
 
  protected:
   void SetUp() override {
@@ -457,6 +456,29 @@ ENTRY e {
                             ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}));
 }
 
+TEST_F(CuDnnFusionExecutionTest, DotS4BF16ExecutesCorrectly) {
+  if (!IsAtLeastCuDnnVersion(9, 12)) {
+    GTEST_SKIP() << "This test case requires cuDNN 9.12+.";
+  }
+  EXPECT_TRUE(RunAndCompare(R"(
+f {
+  a = s4[3,128,128] parameter(0)
+  c = bf16[3,128,128] convert(a)
+  b = bf16[3,128,128] parameter(1)
+  d = bf16[3,128,128] dot(c, b),
+    lhs_batch_dims={0}, rhs_batch_dims={0},
+    lhs_contracting_dims={2}, rhs_contracting_dims={1}
+}
+
+e {
+  a = s4[3,128,128] parameter(0)
+  b = bf16[3,128,128] parameter(1)
+  f = bf16[3,128,128] fusion(a, b), kind=kCustom, calls=f,
+    backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
+})",
+                            ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}));
+}
+
 TEST_F(CuDnnFusionExecutionTest, DotF32WithOutputSubtractionExecutesCorrectly) {
   EXPECT_TRUE(RunAndCompare(R"(
 fusion1 {
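The IsAtLeastCuDnnVersion refactor above generalizes the previously hard-coded 9.1 comparison so the new test can gate on cuDNN 9.12. A standalone C++ sketch of the same predicate, using a toy struct in place of se::dnn::VersionInfo (an assumption so the snippet compiles on its own):

#include <cassert>

// Toy stand-in for se::dnn::VersionInfo (illustrative, not the real type).
struct VersionInfo {
  int major;
  int minor;
};

// Mirrors IsAtLeastCuDnnVersion from the diff: true when the runtime
// version is at least major.minor.
bool IsAtLeastVersion(const VersionInfo& v, int major, int minor) {
  return (v.major == major && v.minor >= minor) || v.major > major;
}

int main() {
  assert(IsAtLeastVersion({9, 12}, 9, 12));   // exact match passes
  assert(IsAtLeastVersion({10, 0}, 9, 12));   // later major passes
  assert(!IsAtLeastVersion({9, 11}, 9, 12));  // older minor fails
  return 0;
}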

xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc
Lines changed: 6 additions & 0 deletions

@@ -199,6 +199,12 @@ mlir::stablehlo::DotAlgorithmAttr ConvertDotAlgorithm(
       numPrimitiveOperations = 6;
       break;
     }
+    case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9: {
+      lhs = rhs = builder->getBF16Type();
+      accum = builder->getF32Type();
+      numPrimitiveOperations = 9;
+      break;
+    }
     case PrecisionConfig::ALG_DOT_TF32_TF32_F32: {
       lhs = rhs = builder->getTF32Type();
       accum = builder->getF32Type();
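Aside on the algorithm name: ALG_DOT_BF16_BF16_F32_X9 denotes emulating an f32 dot with 9 bf16 primitive operations: each f32 operand is split into three bf16 components and all 3 × 3 = 9 cross products are accumulated in f32, which is why numPrimitiveOperations is 9 above. A scalar C++ sketch of the idea (splitting by truncation is an illustrative choice here; real implementations operate on whole tensors):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncate an f32 to a bf16-representable value by zeroing the low 16
// mantissa bits (round-toward-zero; real splitters may round to nearest).
static float TruncateToBf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

// Split x into three bf16-representable parts with x ~= p[0] + p[1] + p[2].
static void SplitX3(float x, float p[3]) {
  p[0] = TruncateToBf16(x);
  p[1] = TruncateToBf16(x - p[0]);
  p[2] = TruncateToBf16(x - p[0] - p[1]);
}

// Approximates a * b using the 9 cross products of the two splits,
// accumulated in f32 -- the scalar analogue of the X9 dot algorithm.
float MulBf16x9(float a, float b) {
  float pa[3], pb[3], acc = 0.0f;
  SplitX3(a, pa);
  SplitX3(b, pb);
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 3; ++j) acc += pa[i] * pb[j];
  return acc;
}

int main() {
  float a = 1.234567f, b = 7.654321f;
  std::printf("exact f32: %.9g  bf16x9: %.9g\n", a * b, MulBf16x9(a, b));
  return 0;
}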

xla/service/gpu/transforms/cudnn_fusion_compiler.cc
Lines changed: 2 additions & 0 deletions

@@ -149,6 +149,8 @@ inline std::optional<fe::DataType_t> ToCudnnDataType(const PrimitiveType type) {
       return t::BFLOAT16;
     case PrimitiveType::S32:
       return t::INT32;
+    case PrimitiveType::S4:
+      return t::INT4;
     case PrimitiveType::S8:
       return t::INT8;
     case PrimitiveType::PRED:
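This mapping is what lets the fusion compiler describe s4 operands to the cuDNN graph API as INT4 tensors; types with no mapping cause the fusion to be rejected rather than compiled. A self-contained sketch of the same switch pattern, where the enums stand in for XLA's PrimitiveType and the cuDNN frontend's fe::DataType_t (both assumptions so the snippet compiles on its own):

#include <optional>

// Stand-ins for XLA's PrimitiveType and the cuDNN frontend's fe::DataType_t.
enum class PrimitiveType { S4, S8, S32, BF16, PRED };
enum class DnnDataType { INT4, INT8, INT32, BFLOAT16, BOOLEAN };

// Mirrors the extended ToCudnnDataType: unsupported types map to nullopt,
// which callers can treat as "not compilable by the cuDNN backend".
std::optional<DnnDataType> ToCudnnDataType(PrimitiveType type) {
  switch (type) {
    case PrimitiveType::S4:   return DnnDataType::INT4;  // added by this commit
    case PrimitiveType::S8:   return DnnDataType::INT8;
    case PrimitiveType::S32:  return DnnDataType::INT32;
    case PrimitiveType::BF16: return DnnDataType::BFLOAT16;
    case PrimitiveType::PRED: return DnnDataType::BOOLEAN;
  }
  return std::nullopt;
}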
