Skip to content

Commit

Permalink
[mlir][memref] Fix emulate narrow types for strided memref offset (ll…
Browse files Browse the repository at this point in the history
…vm#68181)

This patch fixes strided memref offset calculation for emulating narrow
types.

As a side effect, this patch also adds support for a 1-D subviews with
static sizes, static offsets and strides of 1 for testing. Emulate
narrow types pass was not tested for strided memrefs before this patch.
  • Loading branch information
Groverkss authored and stellaraccident committed Oct 6, 2023
1 parent 0bc2395 commit 085c12f
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 5 deletions.
94 changes: 89 additions & 5 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
Expand Down Expand Up @@ -209,6 +210,76 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
return success();
}
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subview have limited support, supporting only
/// static offset and size and stride of 1. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern would
/// never run. This pattern is mostly used for testing pruposes.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType newTy =
dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("failed to convert memref type: {0}", op.getType()));
}

auto convertedElementType = newTy.getElementType();
auto oldElementType = op.getType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = convertedElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}

// Only support offset for 1-D subview.
if (op.getType().getRank() != 1) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with rank > 1 is not supported");
}

// Only support stride of 1.
if (op.getStaticStride(0) != 1) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with stride != 1 is not supported");
}

int64_t size = op.getStaticSize(0);
int64_t offset = op.getStaticOffset(0);
// Only support static sizes and offsets.
if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with dynamic size or offset is not supported");
}

int elementsPerByte = dstBits / srcBits;
if (offset % elementsPerByte != 0) {
return rewriter.notifyMatchFailure(
op->getLoc(),
"subview with offset not multiple of elementsPerByte is not "
"supported");
}

size = ceilDiv(size, elementsPerByte);
offset = offset / elementsPerByte;

rewriter.replaceOpWithNewOp<memref::SubViewOp>(
op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
op.getStaticStrides());
return success();
}
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
Expand All @@ -220,9 +291,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {

// Populate `memref.*` conversion patterns.
patterns
.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
typeConverter, patterns.getContext());
patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

Expand Down Expand Up @@ -271,9 +342,22 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
return std::nullopt;

StridedLayoutAttr layoutAttr;
// If the offset is 0, we do not need a strided layout as the stride is
// 1, so we only use the strided layout if the offset is not 0.
if (offset != 0) {
layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
ArrayRef<int64_t>{1});
if (offset == ShapedType::kDynamic) {
layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
ArrayRef<int64_t>{1});
} else {
// Check if the number of bytes are a multiple of the loadStoreWidth
// and if so, divide it by the loadStoreWidth to get the offset.
if ((offset * width) % loadStoreWidth != 0)
return std::nullopt;
offset = (offset * width) / loadStoreWidth;

layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
ArrayRef<int64_t>{1});
}
}

return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,22 @@ func.func @rank_zero_memref() -> i4 {
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref<i32>
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
// CHECK32: return %[[TRUNC]]

// -----

func.func @memref_strided_i4(%idx : index) -> i4 {
%arr = memref.alloc() : memref<128xi4>
%subview = memref.subview %arr[32] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset:32>>
%1 = memref.load %subview[%idx] : memref<32xi4, strided<[1], offset:32>>
return %1 : i4
}

// CHECK-LABEL: func @memref_strided_i4
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64xi8>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: 16>>
// CHECK: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]

// CHECK32-LABEL: func @memref_strided_i4
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]

0 comments on commit 085c12f

Please sign in to comment.