[XLA:GPU] Enable symbolic tile analysis for bitcast and reshape.
FUTURE_COPYBARA_INTEGRATE_REVIEW=#13301 from Intel-tensorflow:amin/bug-fix-jax 47d5bde
PiperOrigin-RevId: 638276915
olegshyshkov authored and copybara-github committed Jun 4, 2024
1 parent febe8fc commit f237330
Showing 5 changed files with 159 additions and 28 deletions.
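
Note: besides the XLA:GPU change named in the title, this commit bundles oneDNN matmul fixes on the CPU side via the copybara integrate of #13301 above.

For orientation, a minimal sketch of how the analysis entry point might be driven. The entry-point name SymbolicTileAnalysis::AnalyzeComputation and the SymbolicTileAnalysisOrError variant are assumptions inferred from the test fixture further down, not something this commit adds:

#include <variant>

#include "mlir/IR/MLIRContext.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"

namespace xla::gpu {

// Hypothetical driver; sketch only, not part of this commit.
bool CanTile(const HloComputation& computation) {
  mlir::MLIRContext ctx;
  SymbolicTileAnalysisOrError analysis_or_error =
      SymbolicTileAnalysis::AnalyzeComputation(computation, &ctx);
  // After this commit, computations containing reshape/bitcast can produce
  // a SymbolicTileAnalysis instead of a FusionDecision bail-out.
  return std::holds_alternative<SymbolicTileAnalysis>(analysis_or_error);
}

}  // namespace xla::gpu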
8 changes: 8 additions & 0 deletions xla/service/cpu/onednn_matmul.cc
@@ -167,6 +167,14 @@ std::unique_ptr<matmul::primitive_desc> CreateMatMulPrimDesc(
       break;
     case OneDnnMatMulConfig::BINARY_ADD: {
       auto binary_md = fused_mds.at(fused_operand_idx);
+      // Extend addend rank to match result rank.
+      auto missed_rank = output_md.get_ndims() - binary_md.get_ndims();
+      XLA_LIGHTWEIGHT_CHECK(missed_rank >= 0);
+      if (missed_rank > 0) {
+        auto binary_dims = binary_md.get_dims();
+        binary_dims.insert(binary_dims.begin(), missed_rank, 1);
+        binary_md = binary_md.reshape(binary_dims);
+      }
       if (fused_operands_ref) {
         auto arg_idx =
             DNNL_ARG_ATTR_MULTIPLE_POST_OP(post_ops.len()) | DNNL_ARG_SRC_1;
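
To see the rank extension above in isolation, here is a standalone sketch using plain std::vector<int64_t> in place of dnnl::memory::desc; the helper name ExtendAddendRank is hypothetical:

#include <cstdint>
#include <vector>

// Hypothetical standalone helper mirroring the diff above: prepend unit
// dimensions so a lower-rank addend broadcasts against the matmul result.
std::vector<int64_t> ExtendAddendRank(std::vector<int64_t> addend_dims,
                                      int64_t result_rank) {
  const int64_t missed_rank =
      result_rank - static_cast<int64_t>(addend_dims.size());
  if (missed_rank > 0) {
    // E.g. a rank-3 result and an addend of dims {3} yield {1, 1, 3}.
    addend_dims.insert(addend_dims.begin(), missed_rank, 1);
  }
  return addend_dims;
}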
4 changes: 0 additions & 4 deletions xla/service/gpu/model/BUILD
@@ -556,7 +556,6 @@ cc_library(
         ":symbolic_tile",
         "//xla/hlo/ir:hlo",
         "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
@@ -588,9 +587,6 @@ cc_library(
         ":indexing_analysis",
         "//xla:util",
         "//xla/hlo/ir:hlo",
-        "//xla/service:name_uniquer",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
2 changes: 0 additions & 2 deletions xla/service/gpu/model/symbolic_tile_analysis.cc
@@ -171,8 +171,6 @@ absl::StatusOr<IndexingMap> ComputeBlockIdToTileOffsetIndexing(
   // line. This is not an inherent limitation of the approach, but simply
   // issues to be resolved in the current implementation.
   if (hlo->opcode() == HloOpcode::kDot ||
-      hlo->opcode() == HloOpcode::kReshape ||
-      hlo->opcode() == HloOpcode::kBitcast ||
       hlo->opcode() == HloOpcode::kConcatenate) {
     return FusionDecision{} << "Bailing out on " << hlo->ToString();
   }
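
As a hedged sketch of what removing the bail-out enables (modeled on the deleted BailOutOnUnsupportedReshape test further down, with the expectation flipped; SetAnalysis is the fixture helper from the test file, and the flipped EXPECT_TRUE is an inference from this commit, not a quoted test):

TEST_F(SymbolicTileAnalysisTest, ReshapeNoLongerBailsOut) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(R"(
ENTRY main {
  p0 = f32[1,2]{1,0} parameter(0)
  ROOT reshape = f32[2] reshape(p0)
})"));

  EXPECT_TRUE(SetAnalysis(module.get()));
}

Before this commit, the same module made SetAnalysis return false, as the two deleted tests in the diff below show.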
127 changes: 105 additions & 22 deletions xla/service/gpu/model/symbolic_tile_analysis_test.cc
@@ -39,9 +39,11 @@ namespace xla {
 namespace gpu {
 namespace {
 
+using ::testing::ElementsAre;
 using ::testing::ElementsAreArray;
 using ::testing::ExplainMatchResult;
 using ::testing::Matcher;
+using ::testing::Pointee;
 
 MATCHER_P3(MatchTiledHloInstructionImpl, tile_sizes, tile_strides,
            block_id_to_tile_offsets_indexing, "") {
@@ -54,6 +56,19 @@ MATCHER_P3(MatchTiledHloInstructionImpl, tile_sizes, tile_strides,
                             result_listener);
 }
 
+std::vector<const TiledHloInstruction*> GetInstructionsWithName(
+    const TiledHloComputation& tiled_hlo_computation,
+    absl::string_view instruction_name) {
+  std::vector<const TiledHloInstruction*> result;
+  for (const TiledHloInstruction* instruction :
+       tiled_hlo_computation.instructions()) {
+    if (instruction->hlo()->name() == instruction_name) {
+      result.push_back(instruction);
+    }
+  }
+  return result;
+}
+
 Matcher<const TiledHloInstruction> MatchTiledHloInstruction(
     absl::Span<const int64_t> tile_sizes,
     absl::Span<const int64_t> tile_strides,
@@ -133,6 +148,96 @@ ENTRY main {
 )"));
 }
 
+TEST_F(SymbolicTileAnalysisTest,
+       NormalizationDiamondWithBroadcastAndReshapeIsSupported) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+max {
+  p1.1 = f32[] parameter(1)
+  p0.1 = f32[] parameter(0)
+  ROOT m = f32[] maximum(p0.1, p1.1)
+}
+ENTRY main {
+  p0 = f32[48,97] parameter(0)
+  bitcast = f32[4,12,97] bitcast(p0)
+  p1 = pred[4,97] parameter(1)
+  broadcast.1 = pred[4,12,97]{2,1,0} broadcast(p1), dimensions={0,2}
+  constant.3 = f32[] constant(-2.38197633e+38)
+  broadcast.5 = f32[4,12,97] broadcast(constant.3), dimensions={}
+  select = f32[4,12,97] select(broadcast.1, bitcast, broadcast.5)
+  constant = f32[] constant(-inf)
+  reduce = f32[4,12] reduce(select, constant), dimensions={2}, to_apply=max
+  broadcast = f32[4,12,97] broadcast(reduce), dimensions={0,1}
+  ROOT subtract = f32[4,12,97] subtract(select, broadcast)
+})"));
+
+  EXPECT_TRUE(SetAnalysis(module.get()));
+
+  {
+    TF_ASSERT_OK_AND_ASSIGN(
+        TiledHloComputation tiled_hlo_computation,
+        analysis_->ComputeTiledHloInstructions(/*tile_parameters=*/{1, 1, 97}));
+
+    EXPECT_THAT(GetInstructionsWithName(tiled_hlo_computation, "p0"),
+                ElementsAre(Pointee(MatchTiledHloInstruction(
+                    /*tile_sizes=*/{1, 97}, /*tile_strides=*/{0, 1},
+                    /*block_id_to_tile_offsets_indexing=*/R"(
+      (d0) -> (d0, 0)
+      domain:
+      d0 in [0, 47]
+    )"))));
+
+    EXPECT_THAT(GetInstructionsWithName(tiled_hlo_computation, "p1"),
+                ElementsAre(Pointee(MatchTiledHloInstruction(
+                    /*tile_sizes=*/{1, 97}, /*tile_strides=*/{1, 1},
+                    /*block_id_to_tile_offsets_indexing=*/R"(
+      (d0) -> (d0 floordiv 12, 0)
+      domain:
+      d0 in [0, 47]
+    )"))));
+  }
+
+  {
+    TF_ASSERT_OK_AND_ASSIGN(
+        TiledHloComputation tiled_hlo_computation,
+        analysis_->ComputeTiledHloInstructions(/*tile_parameters=*/{1, 2, 10}));
+
+    EXPECT_THAT(GetInstructionsWithName(tiled_hlo_computation, "p0"),
+                ElementsAre(Pointee(MatchTiledHloInstruction(
+                    /*tile_sizes=*/{2, 10}, /*tile_strides=*/{1, 1},
+                    /*block_id_to_tile_offsets_indexing=*/R"(
+      (d0) -> (((d0 floordiv 10) mod 6) * 2 + (d0 floordiv 60) * 12, (d0 mod 10) * 10)
+      domain:
+      d0 in [0, 239]
+    )")),
+                            Pointee(MatchTiledHloInstruction(
+                                /*tile_sizes=*/{2, 97}, /*tile_strides=*/{1, 1},
+                                /*block_id_to_tile_offsets_indexing=*/R"(
+      (d0) -> (((d0 floordiv 10) mod 6) * 2 + (d0 floordiv 60) * 12, 0)
+      domain:
+      d0 in [0, 239]
+    )"))));
+
+    EXPECT_THAT(GetInstructionsWithName(tiled_hlo_computation, "p1"),
+                ElementsAre(Pointee(MatchTiledHloInstruction(
+                    /*tile_sizes=*/{1, 10}, /*tile_strides=*/{1, 1},
+                    /*block_id_to_tile_offsets_indexing=*/R"(
+      (d0) -> (d0 floordiv 60, (d0 mod 10) * 10)
+      domain:
+      d0 in [0, 239]
+    )")),
+                            Pointee(MatchTiledHloInstruction(
+                                /*tile_sizes=*/{1, 97}, /*tile_strides=*/{1, 1},
+                                /*block_id_to_tile_offsets_indexing=*/R"(
+      (d0) -> (d0 floordiv 60, 0)
+      domain:
+      d0 in [0, 239]
+    )"))));
+  }
+}
+
 TEST_F(SymbolicTileAnalysisTest, ElementwiseDiamondCSEIsSupported) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(R"(
@@ -252,28 +357,6 @@ ENTRY main {
   EXPECT_FALSE(SetAnalysis(module.get()));
 }
 
-TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedReshape) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY main {
-  p0 = f32[1,2]{1,0} parameter(0)
-  ROOT reshape = f32[2] reshape(p0)
-})"));
-
-  EXPECT_FALSE(SetAnalysis(module.get()));
-}
-
-TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedBitcast) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY main {
-  p0 = f32[1,2]{1,0} parameter(0)
-  ROOT bitcast = f32[2] bitcast(p0)
-})"));
-
-  EXPECT_FALSE(SetAnalysis(module.get()));
-}
-
 TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedConcatenate) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(R"(
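
The expected indexing maps in the NormalizationDiamond test above can be checked by hand. A standalone sketch (plain C++, not part of the commit) that re-derives both maps from the block decomposition:

#include <cstdint>
#include <cstdio>

// Standalone sketch: re-derive the expected block_id_to_tile_offsets maps
// from the NormalizationDiamond test above. Not part of the commit.
int main() {
  // Tiling {1, 1, 97} over f32[4,12,97]: 4 * 12 = 48 blocks, middle
  // dimension innermost, so p1 (shape [4,97]) sees
  // (d0) -> (d0 floordiv 12, 0).
  for (int64_t d0 : {0, 11, 12, 47}) {
    std::printf("{1,1,97}: block %2d -> p1 offset (%d, 0)\n",
                static_cast<int>(d0), static_cast<int>(d0 / 12));
  }
  // Tiling {1, 2, 10}: 4 * ceil(12/2) * ceil(97/10) = 240 blocks, decomposed
  // as d0 = b0 * 60 + b1 * 10 + b2. p0 is the f32[48,97] bitcast operand, so
  // its row folds (b0, b1) back into 48 rows: row = b0 * 12 + b1 * 2 and
  // col = b2 * 10, matching
  // (((d0 floordiv 10) mod 6) * 2 + (d0 floordiv 60) * 12, (d0 mod 10) * 10).
  for (int64_t d0 : {0, 59, 60, 239}) {
    const int64_t row = (d0 / 60) * 12 + ((d0 / 10) % 6) * 2;
    const int64_t col = (d0 % 10) * 10;
    std::printf("{1,2,10}: block %3d -> p0 offset (%d, %d)\n",
                static_cast<int>(d0), static_cast<int>(row),
                static_cast<int>(col));
  }
  return 0;
}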
46 changes: 46 additions & 0 deletions xla/tests/onednn_matmul_test.cc
@@ -1442,6 +1442,52 @@ TEST_F(MatmulTest, WeightsPrepackAndScratch) {
 )");
 }
 
+TEST_F(MatmulTest, ConsecutiveBinaryAdd) {
+  const char* matmul_module_str = R"(
+  HloModule matmul.test.f32
+  region_0.22 {
+    Arg_0.23 = f32[] parameter(0)
+    Arg_1.24 = f32[] parameter(1)
+    ROOT add.25 = f32[] add(Arg_0.23, Arg_1.24)
+  }
+  region_1.29 {
+    Arg_0.30 = f32[] parameter(0)
+    Arg_1.31 = f32[] parameter(1)
+    ROOT add.32 = f32[] add(Arg_0.30, Arg_1.31)
+  }
+  ENTRY main {
+    constant.2 = f32[] constant(1e-06)
+    broadcast.3 = f32[1000000] broadcast(constant.2), dimensions={}
+    constant.7 = f32[] constant(1)
+    broadcast.8 = f32[1000000,3] broadcast(constant.7), dimensions={}
+    Arg_0.1 = f32[3] parameter(0)
+    reshape.10 = f32[1,3] reshape(Arg_0.1)
+    broadcast.11 = f32[1,3] broadcast(reshape.10), dimensions={0,1}
+    reshape.12 = f32[3] reshape(broadcast.11)
+    broadcast.13 = f32[1000000,3] broadcast(reshape.12), dimensions={1}
+    subtract.14 = f32[1000000,3] subtract(broadcast.8, broadcast.13)
+    constant.4 = f32[] constant(0)
+    broadcast.5 = f32[3,3] broadcast(constant.4), dimensions={}
+    dot.15 = f32[1000000,3] dot(subtract.14, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    dot.16 = f32[1000000,3] dot(broadcast.3, dot.15), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}
+    dot.17 = f32[1000000,3] dot(broadcast.3, subtract.14), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}
+    dot.18 = f32[1000000,3] dot(dot.17, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+    add.19 = f32[1000000,3] add(dot.16, dot.18)
+    constant.9 = f32[3] constant({1, 2, 3})
+    dot.20 = f32[1000000,3] dot(broadcast.3, constant.9), lhs_contracting_dims={}, rhs_contracting_dims={}
+    add.21 = f32[1000000,3] add(add.19, dot.20)
+    constant.6 = f32[] constant(0)
+    reduce.26 = f32[3] reduce(add.21, constant.6), dimensions={0}, to_apply=region_0.22
+    reshape.27 = f32[1,3] reshape(reduce.26)
+    negate.28 = f32[1,3] negate(reshape.27)
+    ROOT reduce.33 = f32[3] reduce(negate.28, constant.6), dimensions={0}, to_apply=region_1.29
+  })";
+
+  EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4}));
+}
+
 }  // namespace cpu
 }  // namespace xla
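
The new ConsecutiveBinaryAdd test, bundled via the copybara integrate of #13301, exercises the rank-extension fix in onednn_matmul.cc above: after fusion, a matmul can acquire a binary-add post-op whose addend has lower rank than the result. A minimal oneDNN-flavored sketch of the same reshape, assuming oneDNN v3 headers; the shapes here are illustrative only:

#include <oneapi/dnnl/dnnl.hpp>

// Illustrative only: prepend unit dims to a rank-1 addend descriptor so it
// can serve as a binary post-op against a rank-2 matmul result, mirroring
// the onednn_matmul.cc change above.
int main() {
  using dnnl::memory;
  memory::desc addend_md({3}, memory::data_type::f32, memory::format_tag::a);
  memory::desc result_md({2, 3}, memory::data_type::f32,
                         memory::format_tag::ab);
  auto dims = addend_md.get_dims();
  const int missed_rank = result_md.get_ndims() - addend_md.get_ndims();
  if (missed_rank > 0) {
    dims.insert(dims.begin(), missed_rank, 1);  // {3} -> {1, 3}
    addend_md = addend_md.reshape(dims);
  }
  return addend_md.get_ndims() == result_md.get_ndims() ? 0 : 1;
}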
