Skip to content

Commit

Permalink
add tests to OSS
Browse files Browse the repository at this point in the history
  • Loading branch information
Cjkkkk committed May 25, 2023
1 parent 108caef commit c7d4862
Showing 1 changed file with 84 additions and 0 deletions.
84 changes: 84 additions & 0 deletions xla/service/gpu/tests/gemm_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,90 @@ ENTRY test {
GmockMatch(m::CustomCall({"__cublas$gemm"})));
}

TEST_P(ParameterizedGemmRewriteTest, SupportedMixTypeGemm) {
const char* hlo_text = R"(
HloModule test
ENTRY main {
param_0 = f16[240,88]{1,0} parameter(0)
param_1 = f16[88,4]{1,0} parameter(1)
dot = f16[240,4]{1,0} dot(param_0, param_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
constant_255 = f16[] constant(255)
broadcast = f16[240,4]{1,0} broadcast(constant_255), dimensions={}
multiply = f16[240,4]{1,0} multiply(dot, broadcast)
ROOT result = f32[240,4]{1,0} convert(multiply)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
GemmRewriter pass(GetCudaComputeCapability());
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
EXPECT_TRUE(changed);

// input fp16 and output fp32 combination is supported by legacy cublas and
// cublasLt, expect GemmRewriter to fuse the convert into gemm.
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::CustomCall({CustomCallTarget()})));
}

TEST_P(ParameterizedGemmRewriteTest, UnsupportedMixTypeGemm) {
const char* hlo_text = R"(
HloModule test
ENTRY main {
param_0 = f32[240,88]{1,0} parameter(0)
param_1 = f32[88,4]{1,0} parameter(1)
dot = f32[240,4]{1,0} dot(param_0, param_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
constant_255 = f32[] constant(255)
broadcast = f32[240,4]{1,0} broadcast(constant_255), dimensions={}
multiply = f32[240,4]{1,0} multiply(dot, broadcast)
ROOT result = u8[240,4]{1,0} convert(multiply)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
GemmRewriter pass(GetCudaComputeCapability());
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
EXPECT_TRUE(changed);

// u8 is not supported by legacy cublas and cublasLt, expect
// GemmRewriter to not fuse the convert into gemm.
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::Convert(m::CustomCall({CustomCallTarget()}))));
}

TEST_P(ParameterizedGemmRewriteTest, CheckIsGemmAliasedBeforeFusion) {
const char* hlo_text = R"(
HloModule test
ENTRY main {
Arg_0.1 = f16[8,16]{1,0} parameter(0)
Arg_1.2 = f16[16,32]{1,0} parameter(1)
dot.8 = f16[8,32]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
Arg_2.3 = f16[8,32]{1,0} parameter(2)
constant.5 = f16[] constant(1)
broadcast.6 = f16[8,32]{1,0} broadcast(constant.5), dimensions={}
add.7 = f16[8,32]{1,0} add(Arg_2.3, broadcast.6)
add.9 = f16[8,32]{1,0} add(dot.8, add.7)
convert.10 = f32[8,32]{1,0} convert(add.9)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
GemmRewriter pass(GetCudaComputeCapability());
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
EXPECT_TRUE(changed);

// input fp16 and output fp32 combination is supported by legacy cublas and
// cublasLt, but gemm output is already aliased with one of the input expect
// GemmRewriter to not fuse the convert into gemm.
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::Convert(m::CustomCall({CustomCallTarget()}))));
}

INSTANTIATE_TEST_SUITE_P(CublasTestsBothLegacyAndLt,
ParameterizedGemmRewriteTest, ::testing::Bool());

Expand Down

0 comments on commit c7d4862

Please sign in to comment.