Skip to content

Commit 362cc00

Browse files
akuegeltensorflower-gardener
authored andcommitted
Relax layout change check for scatter.
It is ok that the element size layout attribute does not match. PiperOrigin-RevId: 741469387
1 parent db573bc commit 362cc00

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

third_party/xla/xla/service/hlo_verifier.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2981,7 +2981,8 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
29812981
instruction->opcode() == HloOpcode::kCompare ||
29822982
instruction->opcode() == HloOpcode::kIsFinite ||
29832983
(instruction->opcode() == HloOpcode::kSelect &&
2984-
operand_shape.element_type() == PRED)) {
2984+
operand_shape.element_type() == PRED) ||
2985+
instruction->opcode() == HloOpcode::kScatter) {
29852986
// Some instructions can change element_size_in_bits
29862987
// Select instructions ignore element_size_in_bits for predicate
29872988
equal_predicate.IgnoreElementSize();

third_party/xla/xla/service/hlo_verifier_test.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,41 @@ TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
836836
HasSubstr("Instruction shouldn't change layouts"));
837837
}
838838

839+
TEST_F(HloVerifierTestLayoutSensitive,
840+
ScatterIgnoreElementSizeForLayoutComparison) {
841+
const char* const kScatterHlo = R"(
842+
HloModule module
843+
overwrite {
844+
%p0 = s4[] parameter(0)
845+
ROOT %p1 = s4[] parameter(1)
846+
}
847+
848+
scatter {
849+
%operand = s4[2048, 2048]{1,0:E(4)} parameter(0)
850+
%update = s4[32, 16, 32]{2,1,0} parameter(1)
851+
%iota = s32[8, 4]{1,0} iota(), iota_dimension=0
852+
%indices = s32[32, 1]{1,0} reshape(%iota)
853+
854+
ROOT %scatter = s4[2048, 2048]{1,0:E(4)} scatter(
855+
s4[2048, 2048]{1,0} %operand,
856+
s32[32, 1]{1,0} %indices,
857+
s4[32, 16, 32]{2,1,0} %update
858+
),
859+
update_window_dims={1,2},
860+
inserted_window_dims={},
861+
scatter_dims_to_operand_dims={0},
862+
index_vector_dim=1,
863+
unique_indices=false,
864+
indices_are_sorted=true,
865+
to_apply=overwrite
866+
}
867+
)";
868+
TF_ASSERT_OK_AND_ASSIGN(auto module,
869+
ParseAndReturnUnverifiedModule(kScatterHlo));
870+
auto status = verifier().Run(module.get()).status();
871+
EXPECT_TRUE(status.ok());
872+
}
873+
839874
TEST_F(HloVerifierTestLayoutSensitive, BitcastNeedsSameNumberOfElements) {
840875
const char* const hlo_string = R"(
841876
HloModule Module

0 commit comments

Comments
 (0)