Skip to content

Commit

Permalink
dialects: (csl) added csl.addressof_fn (#3135)
Browse files Browse the repository at this point in the history
Functions like `csl.addressof`, but takes a prop `fn_name` instead of an
SSAValue. Result type is a single const pointer to a function (this is a
limitation imposed by the CSL language.
  • Loading branch information
dk949 authored Sep 2, 2024
1 parent dd00ea8 commit 8a70781
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/filecheck/backend/csl/print_csl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@
%ptr_to_arr = "csl.addressof"(%uninit_array) : (memref<10xf32>) -> !csl.ptr<memref<10xf32>, #csl<ptr_kind single>, #csl<ptr_const var>>
%ptr_to_val = "csl.addressof"(%const27) : (i16) -> !csl.ptr<i16, #csl<ptr_kind single>, #csl<ptr_const const>>

%ptr_1_fn = "csl.addressof_fn"() <{fn_name = @args_no_return}> : () -> !csl.ptr<(i32, i32) -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
%ptr_2_fn = "csl.addressof_fn"() <{fn_name = @no_args_return}> : () -> !csl.ptr<() -> (f32), #csl<ptr_kind single>, #csl<ptr_const const>>



"csl.export"(%global_ptr) <{
type = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>,
Expand Down Expand Up @@ -453,6 +457,8 @@ csl.func @builtins() {
// CHECK-NEXT: const const_ptr : [*]const i32 = &const_array;
// CHECK-NEXT: var ptr_to_arr : *[10]f32 = &uninit_array;
// CHECK-NEXT: const ptr_to_val : *const i16 = &const27;
// CHECK-NEXT: const ptr_1_fn : *const fn(i32, i32) void = &args_no_return;
// CHECK-NEXT: const ptr_2_fn : *const fn() f32 = &no_args_return;
// CHECK-NEXT: comptime {
// CHECK-NEXT: @export_symbol(global_ptr, "ptr_name");
// CHECK-NEXT: }
Expand Down
4 changes: 4 additions & 0 deletions tests/filecheck/dialects/csl/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ csl.func @initialize() {
%many_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const const>>
%single_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr<memref<10xf32>, #csl<ptr_kind single>, #csl<ptr_const const>>

%function_ptr = "csl.addressof_fn"() <{fn_name = @initialize}> : () -> !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>

%dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl<dsd mem1d_dsd>
%dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3, 4], "offsets" = [1, 2]}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
%dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl<dsd mem4d_dsd>
Expand Down Expand Up @@ -370,6 +372,7 @@ csl.func @builtins() {
// CHECK-NEXT: %scalar_ptr = "csl.addressof"(%scalar) : (i32) -> !csl.ptr<i32, #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %many_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const const>>
// CHECK-NEXT: %single_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr<memref<10xf32>, #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %function_ptr = "csl.addressof_fn"() <{"fn_name" = @initialize}> : () -> !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3 : i64, 4 : i64], "offsets" = [1 : i64, 2 : i64]}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl<dsd mem4d_dsd>
Expand Down Expand Up @@ -605,6 +608,7 @@ csl.func @builtins() {
// CHECK-GENERIC-NEXT: %scalar_ptr = "csl.addressof"(%scalar) : (i32) -> !csl.ptr<i32, #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-GENERIC-NEXT: %many_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const const>>
// CHECK-GENERIC-NEXT: %single_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr<memref<10xf32>, #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-GENERIC-NEXT: %function_ptr = "csl.addressof_fn"() <{"fn_name" = @initialize}> : () -> !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-GENERIC-NEXT: %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl<dsd mem1d_dsd>
// CHECK-GENERIC-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3 : i64, 4 : i64], "offsets" = [1 : i64, 2 : i64]}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-GENERIC-NEXT: %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl<dsd mem4d_dsd>
Expand Down
5 changes: 5 additions & 0 deletions xdsl/backend/csl/print_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,11 @@ def print_block(self, body: Block):
ty = cast(csl.PtrType, res.type)
use = self._var_use(res, ty.constness.data.value)
self.print(f"{use} = &{val_name};")

case csl.AddressOfFnOp(fn_name=name, res=res):
ty = cast(csl.PtrType, res.type)
use = self._var_use(res, ty.constness.data.value)
self.print(f"{use} = &{name.string_value()};")
case csl.SymbolExportOp(value=val, type=ty) as exp:
name = exp.get_name()
q_name = f'"{name}"'
Expand Down
34 changes: 34 additions & 0 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,39 @@ def verify_(self) -> None:
return super().verify_()


@irdl_op_definition
class AddressOfFnOp(IRDLOperation):
"""
Takes the address of a function from symbol ref.
Result has to have kind SINGLE and constness CONST
"""

name = "csl.addressof_fn"
fn_name = prop_def(SymbolRefAttr)

res = result_def(PtrType)

def __init__(self, fn_name: str | SymbolRefAttr):
if isinstance(fn_name, str):
fn_name = SymbolRefAttr(fn_name)

super().__init__(properties={"fn_name": fn_name})

def verify_(self) -> None:
ty = self.res.type
assert isa(ty, PtrType)
if not isa(ty.type, FunctionType):
raise VerifyException("Pointed to type must be a function type")
if ty.kind.data != PtrKind.SINGLE:
raise VerifyException("Pointer kind must be 'single'")

if ty.constness.data != PtrConst.CONST:
raise VerifyException("Function pointers must be const")

return super().verify_()


@irdl_op_definition
class AddressOfOp(IRDLOperation):
"""
Expand Down Expand Up @@ -1733,6 +1766,7 @@ def __init__(self, struct_a: Operation, struct_b: Operation):
[
Add16Op,
Add16cOp,
AddressOfFnOp,
AddressOfOp,
And16Op,
CallOp,
Expand Down

0 comments on commit 8a70781

Please sign in to comment.