Skip to content

Commit

Permalink
transformations: Add memref.subview lowering to convert-memref-to-ris…
Browse files Browse the repository at this point in the history
…cv (#2847)

Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
  • Loading branch information
superlopuh and zero9178 authored Jul 8, 2024
1 parent e5f0039 commit 3ec488a
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 8 deletions.
73 changes: 73 additions & 0 deletions tests/filecheck/backend/riscv/memref_to_riscv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,79 @@ memref.store %v, %m[%d0] {"nontemporal" = false} : memref<1xi64>

// -----

// CHECK: builtin.module {

// CHECK-NEXT: %v_f64 = "test.op"() : () -> f64
// CHECK-NEXT: %i0, %i1, %offset = "test.op"() : () -> (index, index, index)
// CHECK-NEXT: %original = "test.op"() : () -> memref<4x3x2xf64>
%v_f64 = "test.op"() : () -> f64
%i0, %i1, %offset = "test.op"() : () -> (index, index, index)
%original = "test.op"() : () -> memref<4x3x2xf64>

// CHECK-NEXT: %zero_subview = builtin.unrealized_conversion_cast %original : memref<4x3x2xf64> to memref<3x2xf64>
%zero_subview = memref.subview %original[0, 0, 0][1, 3, 2][1, 1, 1] : memref<4x3x2xf64> to memref<3x2xf64>

// CHECK-NEXT: %static_subview = builtin.unrealized_conversion_cast %original : memref<4x3x2xf64> to !riscv.reg
// CHECK-NEXT: %static_subview_1 = riscv.addi %static_subview, 48 {"comment" = "subview offset"} : (!riscv.reg) -> !riscv.reg
// CHECK-NEXT: %static_subview_2 = builtin.unrealized_conversion_cast %static_subview_1 : !riscv.reg to memref<3x2xf64, strided<[2, 1], offset: 6>>
%static_subview = memref.subview %original[1, 0, 0][1, 3, 2][1, 1, 1] :
memref<4x3x2xf64> to memref<3x2xf64, strided<[2, 1], offset: 6>>

// CHECK-NEXT: %dynamic_subview = builtin.unrealized_conversion_cast %original : memref<4x3x2xf64> to !riscv.reg
// CHECK-NEXT: %subview_dim_index = builtin.unrealized_conversion_cast %offset : index to !riscv.reg
// CHECK-NEXT: %subview_dim_index_1 = riscv.li 0 : !riscv.reg
// CHECK-NEXT: %subview_dim_index_2 = riscv.li 0 : !riscv.reg
// CHECK-NEXT: %pointer_dim_stride = riscv.li 6 : !riscv.reg
// CHECK-NEXT: %pointer_dim_offset = riscv.mul %subview_dim_index, %pointer_dim_stride : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %pointer_dim_stride_1 = riscv.li 2 : !riscv.reg
// CHECK-NEXT: %pointer_dim_offset_1 = riscv.mul %subview_dim_index_1, %pointer_dim_stride_1 : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %pointer_offset = riscv.add %pointer_dim_offset, %pointer_dim_offset_1 : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %pointer_offset_1 = riscv.add %pointer_offset, %subview_dim_index_2 : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %bytes_per_element = riscv.li 8 : !riscv.reg
// CHECK-NEXT: %scaled_pointer_offset = riscv.mul %pointer_offset_1, %bytes_per_element {"comment" = "multiply by element size"} : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %offset_pointer = riscv.add %dynamic_subview, %scaled_pointer_offset : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %dynamic_subview_1 = builtin.unrealized_conversion_cast %offset_pointer : !riscv.reg to memref<3x2xf64, strided<[2, 1], offset: ?>>
%dynamic_subview = memref.subview %original[%offset, 0, 0][1, 3, 2][1, 1, 1] :
memref<4x3x2xf64> to memref<3x2xf64, strided<[2, 1], offset: ?>>

// CHECK-NEXT: %larger_original = "test.op"() : () -> memref<5x4x3x2xf64>
%larger_original = "test.op"() : () -> memref<5x4x3x2xf64>
// CHECK-NEXT: %larger_dynamic_subview = builtin.unrealized_conversion_cast %larger_original : memref<5x4x3x2xf64> to !riscv.reg
// CHECK-NEXT: %subview_dim_index_3 = builtin.unrealized_conversion_cast %offset : index to !riscv.reg
// CHECK-NEXT: %subview_dim_index_4 = builtin.unrealized_conversion_cast %offset : index to !riscv.reg
// CHECK-NEXT: %subview_dim_index_5 = riscv.li 0 : !riscv.reg
// CHECK-NEXT: %subview_dim_index_6 = riscv.li 0 : !riscv.reg
// CHECK-NEXT: %pointer_dim_stride_2 = riscv.li 24 : !riscv.reg
// CHECK-NEXT: %pointer_dim_offset_2 = riscv.mul %subview_dim_index_3, %pointer_dim_stride_2 : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %pointer_dim_stride_3 = riscv.li 6 : !riscv.reg
// CHECK-NEXT: %pointer_dim_offset_3 = riscv.mul %subview_dim_index_4, %pointer_dim_stride_3 : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %pointer_offset_2 = riscv.add %pointer_dim_offset_2, %pointer_dim_offset_3 : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %pointer_dim_stride_4 = riscv.li 2 : !riscv.reg
// CHECK-NEXT: %pointer_dim_offset_4 = riscv.mul %subview_dim_index_5, %pointer_dim_stride_4 : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %pointer_offset_3 = riscv.add %pointer_offset_2, %pointer_dim_offset_4 : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %pointer_offset_4 = riscv.add %pointer_offset_3, %subview_dim_index_6 : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %bytes_per_element_1 = riscv.li 8 : !riscv.reg
// CHECK-NEXT: %scaled_pointer_offset_1 = riscv.mul %pointer_offset_4, %bytes_per_element_1 {"comment" = "multiply by element size"} : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %offset_pointer_1 = riscv.add %larger_dynamic_subview, %scaled_pointer_offset_1 : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %larger_dynamic_subview_1 = builtin.unrealized_conversion_cast %offset_pointer_1 : !riscv.reg to memref<3x2xf64, strided<[2, 1], offset: ?>>
%larger_dynamic_subview = memref.subview %larger_original[%offset, %offset, 0, 0][1, 1, 3, 2][1, 1, 1, 1] :
memref<5x4x3x2xf64> to memref<3x2xf64, strided<[2, 1], offset: ?>>

// CHECK-NEXT: }

// -----

%0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>

// Subview with constant offsets, sizes and strides.
%1 = memref.subview %0[0, 2, 0][4, 4, 4][1, 1, 1]
: memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
memref<4x4x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>>

// CHECK: Only strided layout attrs implemented

// -----

%m = "test.op"() : () -> memref<2x3xf64, strided<[6, 1], offset: ?>>
%i0, %i1 = "test.op"() : () -> (index, index)

Expand Down
130 changes: 122 additions & 8 deletions xdsl/backend/riscv/lowering/convert_memref_to_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

def bitwidth_of_type(type_attribute: Attribute) -> int:
"""
Returns the width of an element type in bits, or raises ValueError for unknown inputs.
Returns the width of an element type in bits, or raises DiagnosticException for unknown inputs.
"""
if isinstance(type_attribute, AnyFloat):
return type_attribute.get_bitwidth
Expand All @@ -51,6 +51,21 @@ def bitwidth_of_type(type_attribute: Attribute) -> int:
)


def element_size_for_type(type_attribute: Attribute) -> int:
"""
Returns the width of an element type in bytes, or raises DiagnosticException for
unknown inputs, or sizes not divisible by 8.
"""
bitwidth = bitwidth_of_type(type_attribute)
if bitwidth % 8:
raise DiagnosticException(
f"Cannot determine size for element type {type_attribute}"
f" with bitwidth {bitwidth}"
)
bytes_per_element = bitwidth // 8
return bytes_per_element


class ConvertMemrefAllocOp(RewritePattern):

@op_type_rewrite_pattern
Expand Down Expand Up @@ -102,13 +117,7 @@ def get_strided_pointer(
a new pointer to the element being accessed by the 'indices'.
"""

bitwidth = bitwidth_of_type(memref_type.element_type)
if bitwidth % 8:
raise DiagnosticException(
f"Cannot create offset for element type {memref_type.element_type}"
f" with bitwidth {bitwidth}"
)
bytes_per_element = bitwidth // 8
bytes_per_element = element_size_for_type(memref_type.element_type)

match memref_type.layout:
case NoneAttr():
Expand Down Expand Up @@ -338,6 +347,110 @@ def match_and_rewrite(self, op: memref.GetGlobal, rewriter: PatternRewriter):
)


class ConvertMemrefSubviewOp(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.Subview, rewriter: PatternRewriter):
# Assumes that the operation is valid, meaning that the subview is indeed a
# subview, and that if the offset is stated in the layout attribute, then it's
# correct.

# From MLIR docs:
# https://github.com/llvm/llvm-project/blob/4a9aef683df895934c26591404692d41a687b005/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp#L173-L186
# Replace `dst = subview(memref, sub_offset, sub_sizes, sub_strides))`
# With
#
# \verbatim
# source_buffer, source_offset, source_sizes, source_strides =
# extract_strided_metadata(memref)
# offset = source_offset + sum(sub_offset#i * source_strides#i)
# sizes = sub_sizes
# strides#i = base_strides#i * sub_sizes#i
# dst = reinterpret_cast baseBuffer, offset, sizes, strides
# \endverbatim

# This lowering does not preserve offset, sizes, and strides at runtime, instead
# representing the memref as the base + offset directly, and relying on users of
# the memref to use the information in the type to scale accesses.

source = op.source
result = op.result
source_type = source.type
assert isinstance(source_type, MemRefType)
source_type = cast(MemRefType[Attribute], source_type)
result_type = cast(MemRefType[Attribute], result.type)

result_layout_attr = result_type.layout
if isinstance(result_layout_attr, NoneAttr):
# When a subview has no layout attr, the result is a perfect subview at offset
# 0.
rewriter.replace_matched_op(
UnrealizedConversionCastOp.get((source,), (result_type,))
)
return

if not isinstance(result_layout_attr, StridedLayoutAttr):
raise DiagnosticException("Only strided layout attrs implemented")

offset = result_layout_attr.get_offset()

factor = element_size_for_type(result_type.element_type)

if offset == 0:
rewriter.replace_matched_op(
UnrealizedConversionCastOp.get((source,), (result_type,))
)
return

src = UnrealizedConversionCastOp.get(
(source,), (riscv.IntRegisterType.unallocated(),)
)
src_rd = src.results[0]

if offset is None:
indices: list[SSAValue] = []
index_ops: list[Operation] = []

dynamic_offset_index = 0
for static_offset_attr in op.static_offsets.data:
static_offset = static_offset_attr.data
assert isinstance(static_offset, int)
if static_offset == memref.Subview.DYNAMIC_INDEX:
index_ops.append(
cast_index_op := UnrealizedConversionCastOp.get(
(op.offsets[dynamic_offset_index],),
(riscv.IntRegisterType.unallocated(),),
)
)
index_val = cast_index_op.results[0]
dynamic_offset_index += 1
else:
# No need to insert arithmetic ops that will be multiplied by zero
index_ops.append(offset_op := riscv.LiOp(static_offset))
index_val = offset_op.rd
index_val.name_hint = "subview_dim_index"
indices.append(index_val)
offset_ops, offset_rd = get_strided_pointer(src_rd, indices, source_type)
else:
factor_op = riscv.AddiOp(
src_rd,
offset * factor,
comment="subview offset",
)
index_ops = []
offset_ops = (factor_op,)
offset_rd = factor_op.rd

rewriter.replace_matched_op(
(
src,
*index_ops,
*offset_ops,
UnrealizedConversionCastOp.get((offset_rd,), (result_type,)),
)
)


class ConvertMemrefToRiscvPass(ModulePass):
name = "convert-memref-to-riscv"

Expand All @@ -356,6 +469,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
ConvertMemrefLoadOp(),
ConvertMemrefGlobalOp(),
ConvertMemrefGetGlobalOp(),
ConvertMemrefSubviewOp(),
]
)
).rewrite_module(op)
Expand Down
6 changes: 6 additions & 0 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,12 @@ def get_strides(self) -> Iterable[int | None]:
else:
yield stride.data

def get_offset(self) -> int | None:
if isinstance(self.offset, NoneAttr):
return None
else:
return self.offset.data


@irdl_attr_definition
class AffineMapAttr(MemrefLayoutAttr, Data[AffineMap]):
Expand Down

0 comments on commit 3ec488a

Please sign in to comment.