From a5a36bf7f83d54033b81f4e6af3a52af3791a5a8 Mon Sep 17 00:00:00 2001 From: dk949 <56653556+dk949@users.noreply.github.com> Date: Wed, 18 Sep 2024 15:51:14 +0100 Subject: [PATCH] backend: (csl) added `arith.select` to csl backend `arith.select` can be translated to a CSL ternary operator `if (cond) lhs else rhs`. --- tests/filecheck/backend/csl/print_csl.mlir | 21 +++++++++++++++++++++ xdsl/backend/csl/print_csl.py | 6 ++++++ 2 files changed, 27 insertions(+) diff --git a/tests/filecheck/backend/csl/print_csl.mlir b/tests/filecheck/backend/csl/print_csl.mlir index 333b42722f..dfd60be868 100644 --- a/tests/filecheck/backend/csl/print_csl.mlir +++ b/tests/filecheck/backend/csl/print_csl.mlir @@ -139,6 +139,17 @@ csl.return %and_ : i1 } + csl.func @select() -> !csl { + %value1 = arith.constant 100 : si32 + %value2 = arith.constant 200 : si32 + %toggle = arith.constant 1 : i1 + %A = memref.get_global @A : memref<24xf32> + %dsd1 = "csl.get_mem_dsd"(%A, %value1) : (memref<24xf32>, si32) -> !csl + %dsd2 = "csl.get_mem_dsd"(%A, %value2) : (memref<24xf32>, si32) -> !csl + %selected_dsd = arith.select %toggle, %dsd1, %dsd2 : !csl + csl.return %selected_dsd : !csl + } + csl.func @constants() { %inline_const = arith.constant 100 : i32 @@ -570,6 +581,16 @@ csl.func @builtins() { // CHECK-NEXT: return (((0 <= 1) or (0.0 > 1.1)) and (0.0 >= 1.1)); // CHECK-NEXT: } // CHECK-NEXT: {{ *}} +// CHECK-NEXT: fn select() mem1d_dsd { +// CHECK-NEXT: const dsd1 : mem1d_dsd = @get_dsd( mem1d_dsd, .{ +// CHECK-NEXT: .tensor_access = | d0 | { 100 } -> A[ d0 ] +// CHECK-NEXT: }); +// CHECK-NEXT: const dsd2 : mem1d_dsd = @get_dsd( mem1d_dsd, .{ +// CHECK-NEXT: .tensor_access = | d0 | { 200 } -> A[ d0 ] +// CHECK-NEXT: }); +// CHECK-NEXT: return (if (true) dsd1 else dsd2); +// CHECK-NEXT: } +// CHECK-NEXT: {{ *}} // CHECK-NEXT: fn constants() void { // CHECK-NEXT: var v0 : [const27]i16 = @constants([const27]i16, const27); // CHECK-NEXT: const v1 : [const27]i16 = @constants([const27]i16, const27); diff --git a/xdsl/backend/csl/print_csl.py b/xdsl/backend/csl/print_csl.py index 8be8050f02..5e3ed79de7 100644 --- a/xdsl/backend/csl/print_csl.py +++ b/xdsl/backend/csl/print_csl.py @@ -625,6 +625,12 @@ def print_block(self, body: Block): self._print_or_promote_to_inline_expr( res, self._cmp_value_expr(op), brackets=True ) + case arith.Select(cond=cond, lhs=lhs, rhs=rhs, result=res): + cond = self._get_variable_name_for(cond) + lhs = self._get_variable_name_for(lhs) + rhs = self._get_variable_name_for(rhs) + if_str = f"if ({cond}) {lhs} else {rhs}" + self._print_or_promote_to_inline_expr(res, if_str, brackets=True) case csl.ConcatStructOp(this_struct=a, another_struct=b, result=res): a_var = self._get_variable_name_for(a) b_var = self._get_variable_name_for(b)