Skip to content

Commit

Permalink
[Mosaic TPU] Validate inserted layout in relayout-insertion pass.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 708378418
  • Loading branch information
bythew3i authored and Google-ML-Automation committed Dec 20, 2024
1 parent aa386f8 commit 2b27665
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ FailureOr<TypedValue<VectorType>> relayout(
auto src_int_vty = make_vty(src.bitwidth());
auto dst_int_vty = make_vty(dst.bitwidth());
auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
// TODO(jevinjiang): Since dst_bitwidth_layout will be firstly used in the
// extSI or truncI below, we can reuse the infer functions from
// infer-vector-layout pass.
auto dst_bitwidth_layout = VectorLayout(
dst.bitwidth(),
{
Expand All @@ -66,6 +69,12 @@ FailureOr<TypedValue<VectorType>> relayout(
: LayoutOffset(),
},
src.tiling(), src.implicit_dim());
if (!dst_bitwidth_layout.isValid(target_shape)) {
return emitError(v.getLoc(),
"Not implemented: failed to infer valid layout during "
"relayout, got ")
<< dst_bitwidth_layout;
}
auto ext_op = builder.create<arith::ExtUIOp>(v.getLoc(), src_int_vty, v);
setLayout(ext_op, src, src);

Expand Down

0 comments on commit 2b27665

Please sign in to comment.