diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 18dc1fe9ea0a..cde0b4be8657 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -847,7 +847,7 @@ class VectorLayoutInferer { tile_offsets.push_back(cast(cst_op.getValue()).getInt() % tiling[i]); } else { - if (failed(verifyDivisibleIndex(tiled_index, tiling[0], dim, op))) { + if (failed(verifyDivisibleIndex(tiled_index, tiling[i], dim, op))) { return failure(); } tile_offsets.push_back(0); @@ -1188,7 +1188,7 @@ class VectorLayoutInferer { tile_offsets.push_back(cast(cst_op.getValue()).getInt() % tiling[i]); } else { - if (failed(verifyDivisibleIndex(tiled_index, tiling[0], dim, op))) { + if (failed(verifyDivisibleIndex(tiled_index, tiling[i], dim, op))) { return failure(); } tile_offsets.push_back(0);