Skip to content

Commit

Permalink
transformations: Translate memref to dsd (#3092)
Browse files Browse the repository at this point in the history
Translate memref to csl dsd ops.

---------

Co-authored-by: n-io <n-io@users.noreply.github.com>
  • Loading branch information
n-io and n-io authored Aug 28, 2024
1 parent 60c6be5 commit 6ff156c
Show file tree
Hide file tree
Showing 6 changed files with 421 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/filecheck/backend/csl/print_csl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@

%4 = "csl.constants"(%inline_const, %inline_const) <{is_const}> : (i32, i32) -> memref<?xi32>

%5 = "csl.zeros"(%const27) <{is_const}> : (i16) -> memref<?xi16>

csl.return
}

Expand Down Expand Up @@ -408,6 +410,7 @@ csl.func @builtins() {
// CHECK-NEXT: const v1 : [const27]i16 = @constants([const27]i16, const27);
// CHECK-NEXT: const v2 : [const27]i32 = @constants([const27]i32, 100);
// CHECK-NEXT: const v3 : [100]i32 = @constants([100]i32, 100);
// CHECK-NEXT: const v4 : [const27]i16 = @zeros([const27]i16);
// CHECK-NEXT: return;
// CHECK-NEXT: }
// CHECK-NEXT: {{ *}}
Expand Down
97 changes: 97 additions & 0 deletions tests/filecheck/transforms/memref-to-dsd.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// RUN: xdsl-opt %s -p memref-to-dsd | filecheck %s

builtin.module {
"csl.module"() <{"kind" = #csl<module_kind program>}> ({
// CHECK-NEXT: builtin.module {
// CHECK-NEXT: "csl.module"() <{"kind" = #csl<module_kind program>}> ({

%0 = "test.op"() : () -> (index)
%a = memref.alloc() {"alignment" = 64 : i64} : memref<512xf32>
%b = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
%c = memref.alloc() {"alignment" = 64 : i64} : memref<1024xf32>
%d = memref.subview %a[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>>
%e = memref.subview %a[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>>
"csl.fadds"(%b, %d, %e) : (memref<510xf32>, memref<510xf32, strided<[1]>>, memref<510xf32, strided<[1], offset: 2>>) -> ()
%f = memref.subview %c[1] [510] [2] : memref<1024xf32> to memref<510xf32, strided<[2], offset: 1>>
"csl.fadds"(%b, %b, %f) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[2], offset: 1>>) -> ()

%1 = "csl.addressof"(%a) : (memref<512xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
%2 = "csl.addressof"(%b) : (memref<510xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
%3 = "csl.addressof"(%c) : (memref<1024xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
"csl.export"(%1) <{"var_name" = "a", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
"csl.export"(%2) <{"var_name" = "b", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
"csl.export"(%2) <{"var_name" = "c", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()

// CHECK-NEXT: %0 = "test.op"() : () -> index
// CHECK-NEXT: %a = "csl.zeros"() : () -> memref<512xf32>
// CHECK-NEXT: %a_1 = arith.constant 512 : i16
// CHECK-NEXT: %a_2 = "csl.get_mem_dsd"(%a, %a_1) : (memref<512xf32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %b = "csl.zeros"() : () -> memref<510xf32>
// CHECK-NEXT: %b_1 = arith.constant 510 : i16
// CHECK-NEXT: %b_2 = "csl.get_mem_dsd"(%b, %b_1) : (memref<510xf32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %c = "csl.zeros"() : () -> memref<1024xf32>
// CHECK-NEXT: %c_1 = arith.constant 1024 : i16
// CHECK-NEXT: %c_2 = "csl.get_mem_dsd"(%c, %c_1) : (memref<1024xf32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %d = arith.constant 510 : ui16
// CHECK-NEXT: %d_1 = "csl.set_dsd_length"(%a_2, %d) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %e = arith.constant 510 : ui16
// CHECK-NEXT: %e_1 = "csl.set_dsd_length"(%a_2, %e) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %e_2 = arith.constant 2 : si16
// CHECK-NEXT: %e_3 = "csl.increment_dsd_offset"(%e_1, %e_2) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "csl.fadds"(%b_2, %d_1, %e_3) : (!csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>) -> ()
// CHECK-NEXT: %f = arith.constant 510 : ui16
// CHECK-NEXT: %f_1 = "csl.set_dsd_length"(%c_2, %f) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %f_2 = arith.constant 2 : si8
// CHECK-NEXT: %f_3 = "csl.set_dsd_stride"(%f_1, %f_2) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %f_4 = arith.constant 1 : si16
// CHECK-NEXT: %f_5 = "csl.increment_dsd_offset"(%f_3, %f_4) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "csl.fadds"(%b_2, %b_2, %f_5) : (!csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>) -> ()
// CHECK-NEXT: %1 = "csl.addressof"(%a) : (memref<512xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
// CHECK-NEXT: %2 = "csl.addressof"(%b) : (memref<510xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
// CHECK-NEXT: %3 = "csl.addressof"(%c) : (memref<1024xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
// CHECK-NEXT: "csl.export"(%1) <{"var_name" = "a", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
// CHECK-NEXT: "csl.export"(%2) <{"var_name" = "b", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
// CHECK-NEXT: "csl.export"(%2) <{"var_name" = "c", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()


%23 = memref.alloc() {"alignment" = 64 : i64} : memref<10xi32>
%24 = memref.alloc() {"alignment" = 64 : i64} : memref<10xi32>
"memref.copy"(%23, %24) : (memref<10xi32>, memref<10xi32>) -> ()

// CHECK: %4 = "csl.zeros"() : () -> memref<10xi32>
// CHECK-NEXT: %5 = arith.constant 10 : i16
// CHECK-NEXT: %6 = "csl.get_mem_dsd"(%4, %5) : (memref<10xi32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %7 = "csl.zeros"() : () -> memref<10xi32>
// CHECK-NEXT: %8 = arith.constant 10 : i16
// CHECK-NEXT: %9 = "csl.get_mem_dsd"(%7, %8) : (memref<10xi32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "csl.mov32"(%9, %6) : (!csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>) -> ()


%25 = arith.constant 0 : index
%26 = arith.constant 510 : index
%27 = arith.constant 1 : index
%28 = arith.constant 2 : index
%29 = memref.subview %b[%25] [%26] [%27] : memref<510xf32> to memref<510xf32, strided<[1]>>
%30 = memref.subview %c[%27] [%26] [%28] : memref<1024xf32> to memref<510xf32, strided<[2], offset: 1>>

// CHECK-NEXT: %10 = arith.constant 0 : index
// CHECK-NEXT: %11 = arith.constant 510 : index
// CHECK-NEXT: %12 = arith.constant 1 : index
// CHECK-NEXT: %13 = arith.constant 2 : index
// CHECK-NEXT: %14 = arith.index_cast %11 : index to ui16
// CHECK-NEXT: %15 = "csl.set_dsd_length"(%b_2, %14) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %16 = arith.index_cast %12 : index to si8
// CHECK-NEXT: %17 = "csl.set_dsd_stride"(%15, %16) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %18 = arith.index_cast %10 : index to si16
// CHECK-NEXT: %19 = "csl.increment_dsd_offset"(%17, %18) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %20 = arith.index_cast %11 : index to ui16
// CHECK-NEXT: %21 = "csl.set_dsd_length"(%c_2, %20) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %22 = arith.index_cast %13 : index to si8
// CHECK-NEXT: %23 = "csl.set_dsd_stride"(%21, %22) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %24 = arith.index_cast %12 : index to si16
// CHECK-NEXT: %25 = "csl.increment_dsd_offset"(%23, %24) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>

}) {sym_name = "program"} : () -> ()
}
// CHECK-NEXT: }) {"sym_name" = "program"} : () -> ()
// CHECK-NEXT: }
5 changes: 5 additions & 0 deletions xdsl/backend/csl/print_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,11 @@ def print_block(self, body: Block):
self._print_or_promote_to_inline_expr(
res, f"@concat_structs({a_var}, {b_var})"
)
case csl.ZerosOp(result=res, is_const=constness):
type = self._memref_type_to_string(res)
res_name = self._get_variable_name_for(res)
kind = "const" if constness else "var"
self.print(f"{kind} {res_name} : {type} = @zeros({type});")
case csl.ConstantsOp(value=val, result=res, is_const=constness):
type = self._memref_type_to_string(res)
res_name = self._get_variable_name_for(res)
Expand Down
30 changes: 30 additions & 0 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,35 @@ def verify_(self) -> None:
)


@irdl_op_definition
class ZerosOp(IRDLOperation):
"""
Represents the @zeros operation in CSL.
"""

name = "csl.zeros"

T = Annotated[IntegerType | Float32Type | Float16Type, ConstraintVar("T")]

size = opt_operand_def(T)

result = result_def(MemRefType[T])

is_const = opt_prop_def(builtin.UnitAttr)

def __init__(
self,
memref: MemRefType[T],
dynamic_size: SSAValue | Operation | None = None,
is_const: builtin.UnitAttr | None = None,
):
super().__init__(
operands=[dynamic_size] if dynamic_size else [[]],
result_types=[memref],
properties={"is_const": is_const} if is_const else {},
)


@irdl_op_definition
class ConstantsOp(IRDLOperation):
"""
Expand Down Expand Up @@ -1749,6 +1778,7 @@ def __init__(self, struct_a: Operation, struct_b: Operation):
Xor16Op,
Xp162fhOp,
Xp162fsOp,
ZerosOp,
],
[
ColorType,
Expand Down
6 changes: 6 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ def get_memref_stream_legalize():

return memref_stream_legalize.MemrefStreamLegalizePass

def get_memref_to_dsd():
from xdsl.transforms import memref_to_dsd

return memref_to_dsd.MemrefToDsdPass

def get_mlir_opt():
from xdsl.transforms import mlir_opt

Expand Down Expand Up @@ -435,6 +440,7 @@ def get_stencil_bufferize():
"memref-stream-interleave": get_memref_stream_interleave,
"memref-stream-tile-outer-loops": get_memref_stream_tile_outer_loops,
"memref-stream-legalize": get_memref_stream_legalize,
"memref-to-dsd": get_memref_to_dsd,
"mlir-opt": get_mlir_opt,
"printf-to-llvm": get_printf_to_llvm,
"printf-to-putchar": get_printf_to_putchar,
Expand Down
Loading

0 comments on commit 6ff156c

Please sign in to comment.