Skip to content

Commit

Permalink
[Mosaic:TPU] Fix fully replicated relayout
Browse files Browse the repository at this point in the history
It was incorrect since batch dims are not replicated

PiperOrigin-RevId: 703189919
  • Loading branch information
tlongeri authored and Google-ML-Automation committed Dec 5, 2024
1 parent 2a4a0e8 commit 23d5c10
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6588,15 +6588,20 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
/*use_implicit_shape=*/true);
}
if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() &&
!src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) {
!src.offsets()[1].has_value()) {
// A fully replicated value is always easy to relayout
// It would be nice to be able to assert this here, but given replicated
// values our rules can introduce equivalent expressions.
// assert all(t is src_tiles_list[0] for t in src_tiles_list)
xla::Array<Value> dst_tiles(
/*sizes=*/dst.tileArrayShape(vty.getShape(), target_shape),
/*value=*/src_tiles.data()[0]);
return assemble_with_mask_check(dst_tiles);
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
SmallVector<int64_t> idxs;
dst_tiles.Each([&](const absl::Span<const int64_t> src_idx, Value *vreg) {
idxs.assign(src_idx.begin(), src_idx.end());
dst.eraseImplicit(idxs);
src.insertImplicit<int64_t>(idxs, 0);
*(idxs.end() - 2) = 0;
*(idxs.end() - 1) = 0;
*vreg = src_tiles(idxs);
});
return assemble_with_mask_check(dst_tiles, /*use_implicit_shape=*/true);
}

// Consider (1,128),-2 -> (8,128). In this case we can change the implicit
Expand Down

0 comments on commit 23d5c10

Please sign in to comment.