Skip to content

Commit

Permalink
[xls][mlir] Adds support for ArrayUpdateSliceOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699915048
  • Loading branch information
dimitriv authored and copybara-github committed Nov 25, 2024
1 parent 36ec15f commit fc06ae6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
10 changes: 10 additions & 0 deletions xls/contrib/mlir/testdata/array_to_bits.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,13 @@ func.func @array_concat(%arg0: !xls.array<2 x i32>, %arg1: !xls.array<2 x i32>)
%0 = "xls.array_concat"(%arg0, %arg1) : (!xls.array<2 x i32>, !xls.array<2 x i32>) -> !xls.array<4 x i32>
return %0 : !xls.array<4 x i32>
}

// CHECK-LABEL: func.func @array_update_slice(
// CHECK-SAME: %[[VAL_0:.*]]: i128,
// CHECK-SAME: %[[VAL_1:.*]]: i64,
// CHECK-SAME: %[[VAL_2:.*]]: i32) -> i128 attributes {xls = true} {
// CHECK: %[[VAL_3:.*]] = "xls.bit_slice_update"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (i128, i64, i32) -> i128
func.func @array_update_slice(%arg0: !xls.array<4 x i32>, %arg1: !xls.array<2 x i32>, %arg2: i32) -> !xls.array<4 x i32> attributes {xls = true} {
%0 = xls.array_update_slice %arg1 into %arg0[%arg2 +: 2] : !xls.array<4 x i32>
return %0 : !xls.array<4 x i32>
}
20 changes: 18 additions & 2 deletions xls/contrib/mlir/transforms/array_to_bits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,21 @@ class LegalizeArraySlicePattern : public OpConversionPattern<ArraySliceOp> {
}
};

class LegalizeArrayUpdateSlicePattern
: public OpConversionPattern<ArrayUpdateSliceOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
ArrayUpdateSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
(void)adaptor;
Value slice = CoerceFloats({adaptor.getSlice()}, rewriter, op)[0];
rewriter.replaceOpWithNewOp<BitSliceUpdateOp>(
op, adaptor.getArray().getType(), adaptor.getArray(), slice,
adaptor.getStart());
return success();
}
};

class LegalizeArrayIndexPattern : public OpConversionPattern<ArrayIndexOp> {
using OpConversionPattern::OpConversionPattern;

Expand Down Expand Up @@ -328,8 +343,8 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
all_of(op->getResultTypes(), is_legal);
});
target.addIllegalOp<VectorizedCallOp, ArrayOp, ArrayUpdateOp, ArraySliceOp,
ArrayIndexOp, ArrayIndexStaticOp, ArrayZeroOp,
ArrayConcatOp>();
ArrayUpdateSliceOp, ArrayIndexOp, ArrayIndexStaticOp,
ArrayZeroOp, ArrayConcatOp>();
RewritePatternSet chanPatterns(&getContext());
chanPatterns.add<LegalizeChanOpPattern>(typeConverter, &getContext());
FrozenRewritePatternSet frozenChanPatterns(std::move(chanPatterns));
Expand All @@ -341,6 +356,7 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
LegalizeArrayPattern,
LegalizeArrayUpdatePattern,
LegalizeArraySlicePattern,
LegalizeArrayUpdateSlicePattern,
LegalizeArrayIndexPattern,
LegalizeArrayIndexStaticPattern,
LegalizeArrayZeroPattern,
Expand Down

0 comments on commit fc06ae6

Please sign in to comment.