Skip to content

Commit e1b8dc7

Browse files
committed
[GPU] Support int4 in cuDNN GEMM fusions.
1 parent b44eae5 commit e1b8dc7

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

xla/backends/gpu/codegen/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ xla_test(
104104
"//xla/stream_executor:dnn",
105105
"//xla/stream_executor:stream_executor_h",
106106
"//xla/stream_executor:stream_executor_memory_allocator",
107-
"//xla/tsl/lib/core:status_test_util",
108107
"@com_google_absl//absl/status:status_matchers",
109108
"@com_google_absl//absl/status:statusor",
110109
"@com_google_absl//absl/strings",

xla/backends/gpu/codegen/cudnn_test.cc

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ limitations under the License.
4646
#include "xla/stream_executor/dnn.h"
4747
#include "xla/stream_executor/stream_executor.h"
4848
#include "xla/stream_executor/stream_executor_memory_allocator.h"
49-
#include "xla/tsl/lib/core/status_test_util.h"
5049
#include "xla/xla.pb.h"
5150
#include "xla/xla_data.pb.h"
5251
#include "tsl/platform/env.h"
@@ -59,8 +58,6 @@ namespace xla {
5958
namespace gpu {
6059
namespace {
6160

62-
using ::tsl::testing::IsOkAndHolds;
63-
6461
class CuDnnFusionTest : public GpuCodegenTest {
6562
public:
6663
DebugOptions GetDebugOptionsForTest() const override {
@@ -80,12 +77,14 @@ class CuDnnFusionTest : public GpuCodegenTest {
8077
return get_cuda_cc().IsAtLeastAmpere() &&
8178
GetDnnVersionInfoOrDefault(executor).major_version() >= 9;
8279
}
83-
bool IsAtLeastCuDnn91() {
80+
bool IsAtLeastCuDnnVersion(int major, int minor) {
8481
se::StreamExecutor* executor = backend().default_stream_executor();
8582
const se::dnn::VersionInfo version = GetDnnVersionInfoOrDefault(executor);
86-
return (version.major_version() == 9 && version.minor_version() >= 1) ||
87-
version.major_version() > 9;
83+
return (version.major_version() == major &&
84+
version.minor_version() >= minor) ||
85+
version.major_version() > major;
8886
}
87+
bool IsAtLeastCuDnn91() { return IsAtLeastCuDnnVersion(9, 1); }
8988

9089
protected:
9190
void SetUp() override {
@@ -457,6 +456,29 @@ ENTRY e {
457456
ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}));
458457
}
459458

459+
TEST_F(CuDnnFusionExecutionTest, DotS4BF16ExecutesCorrectly) {
460+
if (!IsAtLeastCuDnnVersion(9, 12)) {
461+
GTEST_SKIP() << "This test case requires cuDNN 9.12+.";
462+
}
463+
EXPECT_TRUE(RunAndCompare(R"(
464+
f {
465+
a = s4[3,128,128] parameter(0)
466+
c = bf16[3,128,128] convert(a)
467+
b = bf16[3,128,128] parameter(1)
468+
d = bf16[3,128,128] dot(c, b),
469+
lhs_batch_dims={0}, rhs_batch_dims={0},
470+
lhs_contracting_dims={2}, rhs_contracting_dims={1}
471+
}
472+
473+
e {
474+
a = s4[3,128,128] parameter(0)
475+
b = bf16[3,128,128] parameter(1)
476+
f = bf16[3,128,128] fusion(a, b), kind=kCustom, calls=f,
477+
backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
478+
})",
479+
ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}));
480+
}
481+
460482
TEST_F(CuDnnFusionExecutionTest, DotF32WithOutputSubtractionExecutesCorrectly) {
461483
EXPECT_TRUE(RunAndCompare(R"(
462484
fusion1 {

xla/service/gpu/transforms/cudnn_fusion_compiler.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ inline std::optional<fe::DataType_t> ToCudnnDataType(const PrimitiveType type) {
149149
return t::BFLOAT16;
150150
case PrimitiveType::S32:
151151
return t::INT32;
152+
case PrimitiveType::S4:
153+
return t::INT4;
152154
case PrimitiveType::S8:
153155
return t::INT8;
154156
case PrimitiveType::PRED:

0 commit comments

Comments
 (0)