Skip to content

Commit c569f8e

Browse files
tlongerijax authors
authored andcommitted
[Mosaic] Fix check in C++ for supported (2, 128) -> (8, 128) retiling (to match Python)
PiperOrigin-RevId: 575702590
1 parent 361b430 commit c569f8e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3228,7 +3228,7 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
32283228
dst.offsets() == LayoutOffsets{0, 0} &&
32293229
src.tiling() == std::array<int64_t, 2>{2, 128} &&
32303230
dst.tiling() == std::array<int64_t, 2>{8, 128} &&
3231-
*(src_tiles.dimensions().end() - 2) <= 2) {
3231+
*(src_tiles.dimensions().end() - 2) == 1) {
32323232
xla::Array<Value> src_tiles_retiled(
32333233
dst.tileArrayShape(vty.getShape(), ctx.target_shape));
32343234
src_tiles_retiled.Each([&](const absl::Span<const int64_t> idx,

0 commit comments

Comments
 (0)