From 2b2766535f4468a0bfdef14450d811f66681f5ae Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Fri, 20 Dec 2024 11:26:19 -0800 Subject: [PATCH] [Mosaic TPU] Validate inserted layout in relayout-insertion pass. PiperOrigin-RevId: 708378418 --- .../mosaic/dialect/tpu/transforms/relayout_insertion.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc index a4b58676a6ef..2ae63a3af63b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc @@ -57,6 +57,9 @@ FailureOr> relayout( auto src_int_vty = make_vty(src.bitwidth()); auto dst_int_vty = make_vty(dst.bitwidth()); auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling()); + // TODO(jevinjiang): Since dst_bitwidth_layout will be firstly used in the + // extSI or truncI below, we can reuse the infer functions from + // infer-vector-layout pass. auto dst_bitwidth_layout = VectorLayout( dst.bitwidth(), { @@ -66,6 +69,12 @@ FailureOr> relayout( : LayoutOffset(), }, src.tiling(), src.implicit_dim()); + if (!dst_bitwidth_layout.isValid(target_shape)) { + return emitError(v.getLoc(), + "Not implemented: failed to infer valid layout during " + "relayout, got ") + << dst_bitwidth_layout; + } auto ext_op = builder.create(v.getLoc(), src_int_vty, v); setLayout(ext_op, src, src);