diff --git a/tests/filecheck/transforms/memref-to-dsd.mlir b/tests/filecheck/transforms/memref-to-dsd.mlir index 17048f3471..9907c976c7 100644 --- a/tests/filecheck/transforms/memref-to-dsd.mlir +++ b/tests/filecheck/transforms/memref-to-dsd.mlir @@ -98,6 +98,14 @@ builtin.module { // CHECK-NEXT: %26 = "test.op"() : () -> !csl // CHECK-NEXT: "csl.fadds"(%26, %26, %26) : (!csl, !csl, !csl) -> () +%33 = "csl.variable"() : () -> !csl.var> +%34 = "csl.load_var"(%33) : (!csl.var>) -> memref<512xf32> +"csl.store_var"(%33, %34) : (!csl.var>, memref<512xf32>) -> () + +// CHECK-NEXT: %27 = "csl.variable"() : () -> !csl.var> +// CHECK-NEXT: %28 = "csl.load_var"(%27) : (!csl.var>) -> !csl +// CHECK-NEXT: "csl.store_var"(%27, %28) : (!csl.var>, !csl) -> () + }) {sym_name = "program"} : () -> () } // CHECK-NEXT: }) {"sym_name" = "program"} : () -> () diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index ced6af4ed9..ee0b1d201f 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -435,10 +435,15 @@ class LoadVarOp(IRDLOperation): var = operand_def(VarType) res = result_def() - def __init__(self, var: VariableOp): + def __init__(self, var: VariableOp | SSAValue): + if isinstance(var, SSAValue): + assert isinstance(var.type, VarType) + result_t = var.type.get_element_type() + else: + result_t = var.get_element_type() super().__init__( operands=[var], - result_types=[var.get_element_type()], + result_types=[result_t], ) def verify_(self) -> None: diff --git a/xdsl/transforms/memref_to_dsd.py b/xdsl/transforms/memref_to_dsd.py index b4d92a164b..507e08d352 100644 --- a/xdsl/transforms/memref_to_dsd.py +++ b/xdsl/transforms/memref_to_dsd.py @@ -261,12 +261,46 @@ def match_and_rewrite(self, op: csl.AddressOfOp, rewriter: PatternRewriter, /): ) +class CslVarUpdate(RewritePattern): + """Update CSL Variable Definitions.""" + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: csl.VariableOp, rewriter: PatternRewriter, /): + if ( + not isinstance(op.res.type, csl.VarType) + or not isa(elem_t := op.res.type.get_element_type(), MemRefType[Attribute]) + or op.default + ): + return + dsd_t = csl.DsdType( + csl.DsdKind.mem1d_dsd if len(elem_t.shape) == 1 else csl.DsdKind.mem4d_dsd + ) + rewriter.replace_matched_op(csl.VariableOp.from_type(dsd_t)) + + +class CslVarLoad(RewritePattern): + """Update CSL Load Variables.""" + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: csl.LoadVarOp, rewriter: PatternRewriter, /): + if ( + not isa(op.res.type, MemRefType[Attribute]) + or not isinstance(op.var.type, csl.VarType) + or not isa(op.var.type.get_element_type(), csl.DsdType) + ): + return + rewriter.replace_matched_op(csl.LoadVarOp(op.var)) + + @dataclass(frozen=True) class MemrefToDsdPass(ModulePass): """ Lowers memref ops to CSL DSDs. - Note, that CSL uses memref types in some places + Note, that CSL uses memref types in some places. + + This performs a backwards pass translating memref-consuming ops to dsd-consuming ops when all memref type + information is known. A second forward pass translates memref-generating ops to dsd-generating ops. """ name = "memref-to-dsd" @@ -287,6 +321,8 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: forward_pass = PatternRewriteWalker( GreedyRewritePatternApplier( [ + CslVarUpdate(), + CslVarLoad(), LowerAllocOpPass(), DsdOpUpdateType(), RetainAddressOfOpPass(),