@@ -46,7 +46,6 @@ limitations under the License.
4646#include " xla/stream_executor/dnn.h"
4747#include " xla/stream_executor/stream_executor.h"
4848#include " xla/stream_executor/stream_executor_memory_allocator.h"
49- #include " xla/tsl/lib/core/status_test_util.h"
5049#include " xla/xla.pb.h"
5150#include " xla/xla_data.pb.h"
5251#include " tsl/platform/env.h"
@@ -59,8 +58,6 @@ namespace xla {
5958namespace gpu {
6059namespace {
6160
62- using ::tsl::testing::IsOkAndHolds;
63-
6461class CuDnnFusionTest : public GpuCodegenTest {
6562 public:
6663 DebugOptions GetDebugOptionsForTest () const override {
@@ -80,12 +77,14 @@ class CuDnnFusionTest : public GpuCodegenTest {
8077 return get_cuda_cc ().IsAtLeastAmpere () &&
8178 GetDnnVersionInfoOrDefault (executor).major_version () >= 9 ;
8279 }
83- bool IsAtLeastCuDnn91 ( ) {
80+ bool IsAtLeastCuDnnVersion ( int major, int minor ) {
8481 se::StreamExecutor* executor = backend ().default_stream_executor ();
8582 const se::dnn::VersionInfo version = GetDnnVersionInfoOrDefault (executor);
86- return (version.major_version () == 9 && version.minor_version () >= 1 ) ||
87- version.major_version () > 9 ;
83+ return (version.major_version () == major &&
84+ version.minor_version () >= minor) ||
85+ version.major_version () > major;
8886 }
87+ bool IsAtLeastCuDnn91 () { return IsAtLeastCuDnnVersion (9 , 1 ); }
8988
9089 protected:
9190 void SetUp () override {
@@ -457,6 +456,29 @@ ENTRY e {
457456 ErrorSpec{/* aabs=*/ 1e-6 , /* arel=*/ 1e-6 }));
458457}
459458
// Runs a cuDNN custom fusion (`__cudnn$fusion`) whose body converts an
// s4[3,128,128] parameter to bf16 and uses it as the left operand of a dot
// with a bf16[3,128,128] right operand (batch dimension 0, contracting
// dimensions lhs 2 / rhs 1). The outcome is checked through RunAndCompare with
// absolute and relative tolerances of 1e-6.
// The test is skipped unless IsAtLeastCuDnnVersion(9, 12) holds.
459+ TEST_F (CuDnnFusionExecutionTest, DotS4BF16ExecutesCorrectly) {
460+ if (!IsAtLeastCuDnnVersion (9 , 12 )) {
461+ GTEST_SKIP () << " This test case requires cuDNN 9.12+." ;
462+ }
463+ EXPECT_TRUE (RunAndCompare (R"(
464+ f {
465+ a = s4[3,128,128] parameter(0)
466+ c = bf16[3,128,128] convert(a)
467+ b = bf16[3,128,128] parameter(1)
468+ d = bf16[3,128,128] dot(c, b),
469+ lhs_batch_dims={0}, rhs_batch_dims={0},
470+ lhs_contracting_dims={2}, rhs_contracting_dims={1}
471+ }
472+
473+ e {
474+ a = s4[3,128,128] parameter(0)
475+ b = bf16[3,128,128] parameter(1)
476+ f = bf16[3,128,128] fusion(a, b), kind=kCustom, calls=f,
477+ backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
478+ })" ,
479+ ErrorSpec{/* aabs=*/ 1e-6 , /* arel=*/ 1e-6 }));
480+ }
481+
460482TEST_F (CuDnnFusionExecutionTest, DotF32WithOutputSubtractionExecutesCorrectly) {
461483 EXPECT_TRUE (RunAndCompare (R"(
462484fusion1 {
0 commit comments