Skip to content

Commit e0b7945

Browse files
authored
transformations: (memref-to-dsd) Support memref.load ops (#3338)
This pass translates memref allocs to csl array allocs with a get_dsd on top of it, which will be used by various compute ops. This PR add support for memref.load ops, which should continue to load from the memref, not from a get_dsd op. --------- Co-authored-by: n-io <n-io@users.noreply.github.com>
1 parent 61bf632 commit e0b7945

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

tests/filecheck/transforms/memref-to-dsd.mlir

+6
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ builtin.module {
114114
// CHECK-NEXT: %29 = arith.constant 510 : i16
115115
// CHECK-NEXT: %30 = "csl.get_mem_dsd"(%b, %29) : (memref<510xf32>, i16) -> !csl<dsd mem1d_dsd>
116116

117+
%38 = memref.load %b[%28] : memref<510xf32>
118+
"test.op"(%38) : (f32) -> ()
119+
120+
// CHECK-NEXT: %31 = memref.load %b[%13] : memref<510xf32>
121+
// CHECK-NEXT: "test.op"(%31) : (f32) -> ()
122+
117123
}) {sym_name = "program"} : () -> ()
118124
}
119125
// CHECK-NEXT: }) {"sym_name" = "program"} : () -> ()

xdsl/transforms/memref_to_dsd.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,24 @@ def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter, /):
8989
raise ValueError("Failed to resolve GetMemDsdOp called on dsd type")
9090

9191

92+
class FixMemrefLoadOnGetDsd(RewritePattern):
93+
"""
94+
Memref load ops should load from the underlying memref, not from the dsd.
95+
"""
96+
97+
@op_type_rewrite_pattern
98+
def match_and_rewrite(self, op: memref.Load, rewriter: PatternRewriter, /):
99+
if isinstance(op.memref.type, csl.DsdType):
100+
if isinstance(op.memref, OpResult) and isinstance(
101+
op.memref.op, csl.GetMemDsdOp
102+
):
103+
rewriter.replace_matched_op(
104+
memref.Load.get(op.memref.op.base_addr, op.indices)
105+
)
106+
else:
107+
raise ValueError("Failed to resolve memref.load called on dsd type")
108+
109+
92110
class LowerSubviewOpPass(RewritePattern):
93111
"""Lowers memref.subview to dsd ops"""
94112

@@ -355,12 +373,10 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
355373
LowerAllocOpPass(),
356374
DsdOpUpdateType(),
357375
RetainAddressOfOpPass(),
376+
FixMemrefLoadOnGetDsd(),
377+
FixGetDsdOnGetDsd(),
358378
]
359379
),
360380
apply_recursively=False,
361381
)
362382
forward_pass.rewrite_module(op)
363-
cleanup_pass = PatternRewriteWalker(
364-
FixGetDsdOnGetDsd(),
365-
)
366-
cleanup_pass.rewrite_module(op)

0 commit comments

Comments
 (0)