diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 2a524ceb9db887..9f58e9055acadb 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Support/MathExtras.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" @@ -209,6 +210,76 @@ struct ConvertMemRefLoad final : OpConversionPattern { return success(); } }; + +//===----------------------------------------------------------------------===// +// ConvertMemRefSubview +//===----------------------------------------------------------------------===// + +/// Emulating narrow ints on subview have limited support, supporting only +/// static offset and size and stride of 1. Ideally, the subview should be +/// folded away before running narrow type emulation, and this pattern would +/// never run. This pattern is mostly used for testing pruposes. +struct ConvertMemRefSubview final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemRefType newTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + if (!newTy) { + return rewriter.notifyMatchFailure( + op->getLoc(), + llvm::formatv("failed to convert memref type: {0}", op.getType())); + } + + auto convertedElementType = newTy.getElementType(); + auto oldElementType = op.getType().getElementType(); + int srcBits = oldElementType.getIntOrFloatBitWidth(); + int dstBits = convertedElementType.getIntOrFloatBitWidth(); + if (dstBits % srcBits != 0) { + return rewriter.notifyMatchFailure( + op, "only dstBits % srcBits == 0 supported"); + } + + // Only support offset for 1-D subview. + if (op.getType().getRank() != 1) { + return rewriter.notifyMatchFailure( + op->getLoc(), "subview with rank > 1 is not supported"); + } + + // Only support stride of 1. + if (op.getStaticStride(0) != 1) { + return rewriter.notifyMatchFailure( + op->getLoc(), "subview with stride != 1 is not supported"); + } + + int64_t size = op.getStaticSize(0); + int64_t offset = op.getStaticOffset(0); + // Only support static sizes and offsets. + if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) { + return rewriter.notifyMatchFailure( + op->getLoc(), "subview with dynamic size or offset is not supported"); + } + + int elementsPerByte = dstBits / srcBits; + if (offset % elementsPerByte != 0) { + return rewriter.notifyMatchFailure( + op->getLoc(), + "subview with offset not multiple of elementsPerByte is not " + "supported"); + } + + size = ceilDiv(size, elementsPerByte); + offset = offset / elementsPerByte; + + rewriter.replaceOpWithNewOp( + op, newTy, *adaptor.getODSOperands(0).begin(), offset, size, + op.getStaticStrides()); + return success(); + } +}; + } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -220,9 +291,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns( RewritePatternSet &patterns) { // Populate `memref.*` conversion patterns. - patterns - .add( - typeConverter, patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext()); memref::populateResolveExtractStridedMetadataPatterns(patterns); } @@ -271,9 +342,22 @@ void memref::populateMemRefNarrowTypeEmulationConversions( return std::nullopt; StridedLayoutAttr layoutAttr; + // If the offset is 0, we do not need a strided layout as the stride is + // 1, so we only use the strided layout if the offset is not 0. if (offset != 0) { - layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset, - ArrayRef{1}); + if (offset == ShapedType::kDynamic) { + layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset, + ArrayRef{1}); + } else { + // Check if the number of bytes are a multiple of the loadStoreWidth + // and if so, divide it by the loadStoreWidth to get the offset. + if ((offset * width) % loadStoreWidth != 0) + return std::nullopt; + offset = (offset * width) / loadStoreWidth; + + layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset, + ArrayRef{1}); + } } return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth), diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir index c0050d8c510d53..6ed97f05aa7cff 100644 --- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir @@ -155,3 +155,22 @@ func.func @rank_zero_memref() -> i4 { // CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref // CHECK32: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4 // CHECK32: return %[[TRUNC]] + +// ----- + +func.func @memref_strided_i4(%idx : index) -> i4 { + %arr = memref.alloc() : memref<128xi4> + %subview = memref.subview %arr[32] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset:32>> + %1 = memref.load %subview[%idx] : memref<32xi4, strided<[1], offset:32>> + return %1 : i4 +} + +// CHECK-LABEL: func @memref_strided_i4 +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64xi8> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: 16>> +// CHECK: %[[LOAD:.+]] = memref.load %[[SUBVIEW]] + +// CHECK32-LABEL: func @memref_strided_i4 +// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32> +// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>> +// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]