diff --git a/tests/dialects/test_csl.py b/tests/dialects/test_csl.py new file mode 100644 index 0000000000..5c41c3e05d --- /dev/null +++ b/tests/dialects/test_csl.py @@ -0,0 +1,43 @@ +import pytest + +from xdsl.dialects.builtin import Float32Type, IntegerType, Signedness, TensorType +from xdsl.dialects.csl import Add16Op, DsdKind, DsdType, GetMemDsdOp +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.test_value import TestSSAValue + +tensor = TestSSAValue(TensorType(Float32Type(), [4])) +size_i32 = TestSSAValue(IntegerType(32, Signedness.SIGNED)) +dest_dsd = GetMemDsdOp( + operands=[tensor, size_i32], result_types=[DsdType(DsdKind.mem1d_dsd)] +) +src_dsd1 = GetMemDsdOp( + operands=[tensor, size_i32], result_types=[DsdType(DsdKind.mem1d_dsd)] +) +src_dsd2 = GetMemDsdOp( + operands=[tensor, size_i32], result_types=[DsdType(DsdKind.mem1d_dsd)] +) +i16_value = TestSSAValue(IntegerType(16, Signedness.SIGNED)) +u16_value = TestSSAValue(IntegerType(16, Signedness.UNSIGNED)) + + +def test_verify_valid_builtin_signature(): + Add16Op(operands=[(dest_dsd, src_dsd1, src_dsd2)], result_types=[]).verify_() + Add16Op(operands=[(dest_dsd, i16_value, src_dsd1)], result_types=[]).verify_() + Add16Op(operands=[(dest_dsd, u16_value, src_dsd1)], result_types=[]).verify_() + Add16Op(operands=[(dest_dsd, src_dsd1, i16_value)], result_types=[]).verify_() + Add16Op(operands=[(dest_dsd, src_dsd1, u16_value)], result_types=[]).verify_() + + +def test_verify_invalid_builtin_signature(): + with pytest.raises(VerifyException): + Add16Op( + operands=[(dest_dsd, src_dsd1, src_dsd2, dest_dsd)], result_types=[] + ).verify_() + with pytest.raises(VerifyException): + Add16Op(operands=[(dest_dsd, src_dsd1)], result_types=[]).verify_() + with pytest.raises(VerifyException): + Add16Op(operands=[(dest_dsd, i16_value, u16_value)], result_types=[]).verify_() + with pytest.raises(VerifyException): + Add16Op(operands=[(i16_value, src_dsd1, u16_value)], result_types=[]).verify_() + with pytest.raises(VerifyException): + Add16Op(operands=[(dest_dsd, src_dsd1, size_i32)], result_types=[]).verify_() diff --git a/tests/filecheck/dialects/csl/ops.mlir b/tests/filecheck/dialects/csl/ops.mlir index 56e74205d2..dd6b24ade6 100644 --- a/tests/filecheck/dialects/csl/ops.mlir +++ b/tests/filecheck/dialects/csl/ops.mlir @@ -87,10 +87,201 @@ csl.func @initialize() { %fabin_dsd = "csl.get_fab_dsd"(%scalar) : (i32) -> !csl %fabout_dsd = "csl.get_fab_dsd"(%scalar) : (i32) -> !csl + %f16_ptr, %f16_val, %f32_ptr = "test.op"() : () -> (!csl.ptr, #csl>, f16, !csl.ptr, #csl>) + "csl.faddh"(%dsd_1d1, %dsd_1d2, %dsd_1d3) : (!csl, !csl, !csl) -> () + "csl.faddh"(%f16_ptr, %f16_val, %dsd_1d3) : (!csl.ptr, #csl>, f16, !csl) -> () + // this will fail as expected: + // "csl.faddh"(%f32_ptr, %f16_val, %dsd_1d3) : (!csl.ptr, #csl>, f16, !csl) -> () csl.return } +csl.func @builtins() { + %i16_value, %i32_value, %u16_value, %u32_value, %f16_value, %f32_value = "test.op"() : () -> (si16, si32, ui16, ui32, f16, f32) + %i16_pointer, %i32_pointer = "test.op"() : () -> (!csl.ptr, #csl>, !csl.ptr, #csl>) + %u16_pointer, %u32_pointer = "test.op"() : () -> (!csl.ptr, #csl>, !csl.ptr, #csl>) + %f16_pointer, %f32_pointer = "test.op"() : () -> (!csl.ptr, #csl>, !csl.ptr, #csl>) + %tens = "test.op"() : () -> (tensor<510xf32>) + %dest_dsd = "csl.get_mem_dsd"(%tens, %i32_value) : (tensor<510xf32>, si32) -> !csl + %src_dsd1 = "csl.get_mem_dsd"(%tens, %i32_value) : (tensor<510xf32>, si32) -> !csl + %src_dsd2 = "csl.get_mem_dsd"(%tens, %i32_value) : (tensor<510xf32>, si32) -> !csl + + "csl.add16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.add16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () + "csl.add16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () + "csl.add16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () + "csl.add16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () + + "csl.addc16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.addc16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () + "csl.addc16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () + "csl.addc16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () + "csl.addc16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () + + "csl.and16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.and16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () + "csl.and16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () + "csl.and16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () + "csl.and16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () + + "csl.clz"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.clz"(%dest_dsd, %i16_value) : (!csl, si16) -> () + "csl.clz"(%dest_dsd, %u16_value) : (!csl, ui16) -> () + + "csl.ctz"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.ctz"(%dest_dsd, %i16_value) : (!csl, si16) -> () + "csl.ctz"(%dest_dsd, %u16_value) : (!csl, ui16) -> () + + "csl.fabsh"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.fabsh"(%dest_dsd, %f16_value) : (!csl, f16) -> () + + "csl.fabss"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.fabss"(%dest_dsd, %f32_value) : (!csl, f32) -> () + + "csl.faddh"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.faddh"(%dest_dsd, %f16_value, %src_dsd1) : (!csl, f16, !csl) -> () + "csl.faddh"(%dest_dsd, %src_dsd1, %f16_value) : (!csl, !csl, f16) -> () + "csl.faddh"(%f16_pointer, %f16_value, %src_dsd1) : (!csl.ptr, #csl>, f16, !csl) -> () + + "csl.faddhs"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.faddhs"(%dest_dsd, %f16_value, %src_dsd1) : (!csl, f16, !csl) -> () + "csl.faddhs"(%dest_dsd, %src_dsd1, %f16_value) : (!csl, !csl, f16) -> () + "csl.faddhs"(%f32_pointer, %f32_value, %src_dsd1) : (!csl.ptr, #csl>, f32, !csl) -> () + + "csl.fadds"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.fadds"(%dest_dsd, %f32_value, %src_dsd1) : (!csl, f32, !csl) -> () + "csl.fadds"(%dest_dsd, %src_dsd1, %f32_value) : (!csl, !csl, f32) -> () + "csl.fadds"(%f32_pointer, %f32_value, %src_dsd1) : (!csl.ptr, #csl>, f32, !csl) -> () + + "csl.fh2s"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.fh2s"(%dest_dsd, %f16_value) : (!csl, f16) -> () + + "csl.fh2xp16"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.fh2xp16"(%dest_dsd, %f16_value) : (!csl, f16) -> () + "csl.fh2xp16"(%i16_pointer, %f16_value) : (!csl.ptr, #csl>, f16) -> () + + "csl.fmach" (%dest_dsd, %src_dsd1, %src_dsd2, %f16_value) : (!csl, !csl, !csl, f16) -> () + "csl.fmachs"(%dest_dsd, %src_dsd1, %src_dsd2, %f16_value) : (!csl, !csl, !csl, f16) -> () + "csl.fmacs" (%dest_dsd, %src_dsd1, %src_dsd2, %f32_value) : (!csl, !csl, !csl, f32) -> () + + "csl.fmaxh"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.fmaxh"(%dest_dsd, %f16_value, %src_dsd1) : (!csl, f16, !csl) -> () + "csl.fmaxh"(%dest_dsd, %src_dsd1, %f16_value) : (!csl, !csl, f16) -> () + "csl.fmaxh"(%f16_pointer, %f16_value, %src_dsd1) : (!csl.ptr, #csl>, f16, !csl) -> () + + "csl.fmaxs"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.fmaxs"(%dest_dsd, %f32_value, %src_dsd1) : (!csl, f32, !csl) -> () + "csl.fmaxs"(%dest_dsd, %src_dsd1, %f32_value) : (!csl, !csl, f32) -> () + "csl.fmaxs"(%f32_pointer, %f32_value, %src_dsd1) : (!csl.ptr, #csl>, f32, !csl) -> () + + "csl.fmovh"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.fmovh"(%f16_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () + "csl.fmovh"(%dest_dsd, %f16_value) : (!csl, f16) -> () + + "csl.fmovs"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.fmovs"(%f32_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () + "csl.fmovs"(%dest_dsd, %f32_value) : (!csl, f32) -> () + + "csl.fmulh"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.fmulh"(%dest_dsd, %f16_value, %src_dsd1) : (!csl, f16, !csl) -> () + "csl.fmulh"(%dest_dsd, %src_dsd1, %f16_value) : (!csl, !csl, f16) -> () + "csl.fmulh"(%f16_pointer, %f16_value, %src_dsd1) : (!csl.ptr, #csl>, f16, !csl) -> () + + "csl.fmuls"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.fmuls"(%dest_dsd, %f32_value, %src_dsd1) : (!csl, f32, !csl) -> () + "csl.fmuls"(%dest_dsd, %src_dsd1, %f32_value) : (!csl, !csl, f32) -> () + "csl.fmuls"(%f32_pointer, %f32_value, %src_dsd1) : (!csl.ptr, #csl>, f32, !csl) -> () + + "csl.fnegh"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.fnegh"(%dest_dsd, %f16_value) : (!csl, f16) -> () + + "csl.fnegs"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.fnegs"(%dest_dsd, %f32_value) : (!csl, f32) -> () + + "csl.fnormh"(%f16_pointer, %f16_value) : (!csl.ptr, #csl>, f16) -> () + "csl.fnorms"(%f32_pointer, %f32_value) : (!csl.ptr, #csl>, f32) -> () + + "csl.fs2h"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.fs2h"(%dest_dsd, %f32_value) : (!csl, f32) -> () + + "csl.fs2xp16"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.fs2xp16"(%dest_dsd, %f32_value) : (!csl, f32) -> () + "csl.fs2xp16"(%i16_pointer, %f32_value) : (!csl.ptr, #csl>, f32) -> () + + "csl.fscaleh"(%f16_pointer, %f16_value, %i16_value) : (!csl.ptr, #csl>, f16, si16) -> () + "csl.fscales"(%f32_pointer, %f32_value, %i16_value) : (!csl.ptr, #csl>, f32, si16) -> () + + "csl.fsubh"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.fsubh"(%dest_dsd, %f16_value, %src_dsd1) : (!csl, f16, !csl) -> () + "csl.fsubh"(%dest_dsd, %src_dsd1, %f16_value) : (!csl, !csl, f16) -> () + "csl.fsubh"(%f16_pointer, %f16_value, %src_dsd1) : (!csl.ptr, #csl>, f16, !csl) -> () + + "csl.fsubs"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.fsubs"(%dest_dsd, %f32_value, %src_dsd1) : (!csl, f32, !csl) -> () + "csl.fsubs"(%dest_dsd, %src_dsd1, %f32_value) : (!csl, !csl, f32) -> () + "csl.fsubs"(%f32_pointer, %f32_value, %src_dsd1) : (!csl.ptr, #csl>, f32, !csl) -> () + + "csl.mov16"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.mov16"(%i16_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () + "csl.mov16"(%u16_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () + "csl.mov16"(%dest_dsd, %i16_value) : (!csl, si16) -> () + "csl.mov16"(%dest_dsd, %u16_value) : (!csl, ui16) -> () + + "csl.mov32"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.mov32"(%i32_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () + "csl.mov32"(%u32_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () + "csl.mov32"(%dest_dsd, %i32_value) : (!csl, si32) -> () + "csl.mov32"(%dest_dsd, %u32_value) : (!csl, ui32) -> () + + "csl.or16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.or16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () + "csl.or16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () + "csl.or16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () + "csl.or16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () + + "csl.popcnt"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.popcnt"(%dest_dsd, %i16_value) : (!csl, si16) -> () + "csl.popcnt"(%dest_dsd, %u16_value) : (!csl, ui16) -> () + + "csl.sar16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.sar16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () + "csl.sar16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () + "csl.sar16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () + "csl.sar16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () + + "csl.sll16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.sll16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () + "csl.sll16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () + "csl.sll16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () + "csl.sll16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () + + "csl.slr16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.slr16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () + "csl.slr16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () + "csl.slr16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () + "csl.slr16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () + + "csl.sub16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.sub16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () + "csl.sub16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () + + "csl.xor16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () + "csl.xor16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () + "csl.xor16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () + "csl.xor16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () + "csl.xor16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () + + "csl.xp162fh"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.xp162fh"(%dest_dsd, %i16_value) : (!csl, si16) -> () + "csl.xp162fh"(%dest_dsd, %u16_value) : (!csl, ui16) -> () + + "csl.xp162fs"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () + "csl.xp162fs"(%dest_dsd, %i16_value) : (!csl, si16) -> () + "csl.xp162fs"(%dest_dsd, %u16_value) : (!csl, ui16) -> () + + csl.return +} + %global_ptr = "test.op"() : () -> !csl.ptr, #csl> "csl.export"() <{var_name = @initialize, type = () -> ()}> : () -> () @@ -173,6 +364,155 @@ csl.func @initialize() { // CHECK-NEXT: %tensor_dsd2 = "csl.set_dsd_base_addr"(%dsd_1d, %tens) : (!csl, tensor<510xf32>) -> !csl // CHECK-NEXT: %fabin_dsd = "csl.get_fab_dsd"(%scalar) : (i32) -> !csl // CHECK-NEXT: %fabout_dsd = "csl.get_fab_dsd"(%scalar) : (i32) -> !csl +// CHECK-NEXT: %f16_ptr, %f16_val, %f32_ptr = "test.op"() : () -> (!csl.ptr, #csl>, f16, !csl.ptr, #csl>) +// CHECK-NEXT: "csl.faddh"(%dsd_1d1, %dsd_1d2, %dsd_1d3) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.faddh"(%f16_ptr, %f16_val, %dsd_1d3) : (!csl.ptr, #csl>, f16, !csl) -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @builtins() { +// CHECK-NEXT: %i16_value, %i32_value, %u16_value, %u32_value, %f16_value, %f32_value = "test.op"() : () -> (si16, si32, ui16, ui32, f16, f32) +// CHECK-NEXT: %i16_pointer, %i32_pointer = "test.op"() : () -> (!csl.ptr, #csl>, !csl.ptr, #csl>) +// CHECK-NEXT: %u16_pointer, %u32_pointer = "test.op"() : () -> (!csl.ptr, #csl>, !csl.ptr, #csl>) +// CHECK-NEXT: %f16_pointer, %f32_pointer = "test.op"() : () -> (!csl.ptr, #csl>, !csl.ptr, #csl>) +// CHECK-NEXT: %tens_1 = "test.op"() : () -> tensor<510xf32> +// CHECK-NEXT: %dest_dsd = "csl.get_mem_dsd"(%tens_1, %i32_value) : (tensor<510xf32>, si32) -> !csl +// CHECK-NEXT: %src_dsd1 = "csl.get_mem_dsd"(%tens_1, %i32_value) : (tensor<510xf32>, si32) -> !csl +// CHECK-NEXT: %src_dsd2 = "csl.get_mem_dsd"(%tens_1, %i32_value) : (tensor<510xf32>, si32) -> !csl +// CHECK-NEXT: "csl.add16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.add16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () +// CHECK-NEXT: "csl.add16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () +// CHECK-NEXT: "csl.add16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () +// CHECK-NEXT: "csl.add16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () +// CHECK-NEXT: "csl.addc16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.addc16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () +// CHECK-NEXT: "csl.addc16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () +// CHECK-NEXT: "csl.addc16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () +// CHECK-NEXT: "csl.addc16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () +// CHECK-NEXT: "csl.and16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.and16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () +// CHECK-NEXT: "csl.and16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () +// CHECK-NEXT: "csl.and16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () +// CHECK-NEXT: "csl.and16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () +// CHECK-NEXT: "csl.clz"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.clz"(%dest_dsd, %i16_value) : (!csl, si16) -> () +// CHECK-NEXT: "csl.clz"(%dest_dsd, %u16_value) : (!csl, ui16) -> () +// CHECK-NEXT: "csl.ctz"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.ctz"(%dest_dsd, %i16_value) : (!csl, si16) -> () +// CHECK-NEXT: "csl.ctz"(%dest_dsd, %u16_value) : (!csl, ui16) -> () +// CHECK-NEXT: "csl.fabsh"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.fabsh"(%dest_dsd, %f16_value) : (!csl, f16) -> () +// CHECK-NEXT: "csl.fabss"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.fabss"(%dest_dsd, %f32_value) : (!csl, f32) -> () +// CHECK-NEXT: "csl.faddh"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.faddh"(%dest_dsd, %f16_value, %src_dsd1) : (!csl, f16, !csl) -> () +// CHECK-NEXT: "csl.faddh"(%dest_dsd, %src_dsd1, %f16_value) : (!csl, !csl, f16) -> () +// CHECK-NEXT: "csl.faddh"(%f16_pointer, %f16_value, %src_dsd1) : (!csl.ptr, #csl>, f16, !csl) -> () +// CHECK-NEXT: "csl.faddhs"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.faddhs"(%dest_dsd, %f16_value, %src_dsd1) : (!csl, f16, !csl) -> () +// CHECK-NEXT: "csl.faddhs"(%dest_dsd, %src_dsd1, %f16_value) : (!csl, !csl, f16) -> () +// CHECK-NEXT: "csl.faddhs"(%f32_pointer, %f32_value, %src_dsd1) : (!csl.ptr, #csl>, f32, !csl) -> () +// CHECK-NEXT: "csl.fadds"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.fadds"(%dest_dsd, %f32_value, %src_dsd1) : (!csl, f32, !csl) -> () +// CHECK-NEXT: "csl.fadds"(%dest_dsd, %src_dsd1, %f32_value) : (!csl, !csl, f32) -> () +// CHECK-NEXT: "csl.fadds"(%f32_pointer, %f32_value, %src_dsd1) : (!csl.ptr, #csl>, f32, !csl) -> () +// CHECK-NEXT: "csl.fh2s"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.fh2s"(%dest_dsd, %f16_value) : (!csl, f16) -> () +// CHECK-NEXT: "csl.fh2xp16"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.fh2xp16"(%dest_dsd, %f16_value) : (!csl, f16) -> () +// CHECK-NEXT: "csl.fh2xp16"(%i16_pointer, %f16_value) : (!csl.ptr, #csl>, f16) -> () +// CHECK-NEXT: "csl.fmach"(%dest_dsd, %src_dsd1, %src_dsd2, %f16_value) : (!csl, !csl, !csl, f16) -> () +// CHECK-NEXT: "csl.fmachs"(%dest_dsd, %src_dsd1, %src_dsd2, %f16_value) : (!csl, !csl, !csl, f16) -> () +// CHECK-NEXT: "csl.fmacs"(%dest_dsd, %src_dsd1, %src_dsd2, %f32_value) : (!csl, !csl, !csl, f32) -> () +// CHECK-NEXT: "csl.fmaxh"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.fmaxh"(%dest_dsd, %f16_value, %src_dsd1) : (!csl, f16, !csl) -> () +// CHECK-NEXT: "csl.fmaxh"(%dest_dsd, %src_dsd1, %f16_value) : (!csl, !csl, f16) -> () +// CHECK-NEXT: "csl.fmaxh"(%f16_pointer, %f16_value, %src_dsd1) : (!csl.ptr, #csl>, f16, !csl) -> () +// CHECK-NEXT: "csl.fmaxs"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.fmaxs"(%dest_dsd, %f32_value, %src_dsd1) : (!csl, f32, !csl) -> () +// CHECK-NEXT: "csl.fmaxs"(%dest_dsd, %src_dsd1, %f32_value) : (!csl, !csl, f32) -> () +// CHECK-NEXT: "csl.fmaxs"(%f32_pointer, %f32_value, %src_dsd1) : (!csl.ptr, #csl>, f32, !csl) -> () +// CHECK-NEXT: "csl.fmovh"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.fmovh"(%f16_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () +// CHECK-NEXT: "csl.fmovh"(%dest_dsd, %f16_value) : (!csl, f16) -> () +// CHECK-NEXT: "csl.fmovs"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.fmovs"(%f32_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () +// CHECK-NEXT: "csl.fmovs"(%dest_dsd, %f32_value) : (!csl, f32) -> () +// CHECK-NEXT: "csl.fmulh"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.fmulh"(%dest_dsd, %f16_value, %src_dsd1) : (!csl, f16, !csl) -> () +// CHECK-NEXT: "csl.fmulh"(%dest_dsd, %src_dsd1, %f16_value) : (!csl, !csl, f16) -> () +// CHECK-NEXT: "csl.fmulh"(%f16_pointer, %f16_value, %src_dsd1) : (!csl.ptr, #csl>, f16, !csl) -> () +// CHECK-NEXT: "csl.fmuls"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.fmuls"(%dest_dsd, %f32_value, %src_dsd1) : (!csl, f32, !csl) -> () +// CHECK-NEXT: "csl.fmuls"(%dest_dsd, %src_dsd1, %f32_value) : (!csl, !csl, f32) -> () +// CHECK-NEXT: "csl.fmuls"(%f32_pointer, %f32_value, %src_dsd1) : (!csl.ptr, #csl>, f32, !csl) -> () +// CHECK-NEXT: "csl.fnegh"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.fnegh"(%dest_dsd, %f16_value) : (!csl, f16) -> () +// CHECK-NEXT: "csl.fnegs"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.fnegs"(%dest_dsd, %f32_value) : (!csl, f32) -> () +// CHECK-NEXT: "csl.fnormh"(%f16_pointer, %f16_value) : (!csl.ptr, #csl>, f16) -> () +// CHECK-NEXT: "csl.fnorms"(%f32_pointer, %f32_value) : (!csl.ptr, #csl>, f32) -> () +// CHECK-NEXT: "csl.fs2h"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.fs2h"(%dest_dsd, %f32_value) : (!csl, f32) -> () +// CHECK-NEXT: "csl.fs2xp16"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.fs2xp16"(%dest_dsd, %f32_value) : (!csl, f32) -> () +// CHECK-NEXT: "csl.fs2xp16"(%i16_pointer, %f32_value) : (!csl.ptr, #csl>, f32) -> () +// CHECK-NEXT: "csl.fscaleh"(%f16_pointer, %f16_value, %i16_value) : (!csl.ptr, #csl>, f16, si16) -> () +// CHECK-NEXT: "csl.fscales"(%f32_pointer, %f32_value, %i16_value) : (!csl.ptr, #csl>, f32, si16) -> () +// CHECK-NEXT: "csl.fsubh"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.fsubh"(%dest_dsd, %f16_value, %src_dsd1) : (!csl, f16, !csl) -> () +// CHECK-NEXT: "csl.fsubh"(%dest_dsd, %src_dsd1, %f16_value) : (!csl, !csl, f16) -> () +// CHECK-NEXT: "csl.fsubh"(%f16_pointer, %f16_value, %src_dsd1) : (!csl.ptr, #csl>, f16, !csl) -> () +// CHECK-NEXT: "csl.fsubs"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.fsubs"(%dest_dsd, %f32_value, %src_dsd1) : (!csl, f32, !csl) -> () +// CHECK-NEXT: "csl.fsubs"(%dest_dsd, %src_dsd1, %f32_value) : (!csl, !csl, f32) -> () +// CHECK-NEXT: "csl.fsubs"(%f32_pointer, %f32_value, %src_dsd1) : (!csl.ptr, #csl>, f32, !csl) -> () +// CHECK-NEXT: "csl.mov16"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.mov16"(%i16_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () +// CHECK-NEXT: "csl.mov16"(%u16_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () +// CHECK-NEXT: "csl.mov16"(%dest_dsd, %i16_value) : (!csl, si16) -> () +// CHECK-NEXT: "csl.mov16"(%dest_dsd, %u16_value) : (!csl, ui16) -> () +// CHECK-NEXT: "csl.mov32"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.mov32"(%i32_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () +// CHECK-NEXT: "csl.mov32"(%u32_pointer, %src_dsd1) : (!csl.ptr, #csl>, !csl) -> () +// CHECK-NEXT: "csl.mov32"(%dest_dsd, %i32_value) : (!csl, si32) -> () +// CHECK-NEXT: "csl.mov32"(%dest_dsd, %u32_value) : (!csl, ui32) -> () +// CHECK-NEXT: "csl.or16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.or16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () +// CHECK-NEXT: "csl.or16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () +// CHECK-NEXT: "csl.or16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () +// CHECK-NEXT: "csl.or16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () +// CHECK-NEXT: "csl.popcnt"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.popcnt"(%dest_dsd, %i16_value) : (!csl, si16) -> () +// CHECK-NEXT: "csl.popcnt"(%dest_dsd, %u16_value) : (!csl, ui16) -> () +// CHECK-NEXT: "csl.sar16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.sar16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () +// CHECK-NEXT: "csl.sar16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () +// CHECK-NEXT: "csl.sar16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () +// CHECK-NEXT: "csl.sar16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () +// CHECK-NEXT: "csl.sll16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.sll16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () +// CHECK-NEXT: "csl.sll16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () +// CHECK-NEXT: "csl.sll16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () +// CHECK-NEXT: "csl.sll16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () +// CHECK-NEXT: "csl.slr16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.slr16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () +// CHECK-NEXT: "csl.slr16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () +// CHECK-NEXT: "csl.slr16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () +// CHECK-NEXT: "csl.slr16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () +// CHECK-NEXT: "csl.sub16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.sub16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () +// CHECK-NEXT: "csl.sub16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () +// CHECK-NEXT: "csl.xor16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "csl.xor16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () +// CHECK-NEXT: "csl.xor16"(%dest_dsd, %u16_value, %src_dsd1) : (!csl, ui16, !csl) -> () +// CHECK-NEXT: "csl.xor16"(%dest_dsd, %src_dsd1, %i16_value) : (!csl, !csl, si16) -> () +// CHECK-NEXT: "csl.xor16"(%dest_dsd, %src_dsd1, %u16_value) : (!csl, !csl, ui16) -> () +// CHECK-NEXT: "csl.xp162fh"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.xp162fh"(%dest_dsd, %i16_value) : (!csl, si16) -> () +// CHECK-NEXT: "csl.xp162fh"(%dest_dsd, %u16_value) : (!csl, ui16) -> () +// CHECK-NEXT: "csl.xp162fs"(%dest_dsd, %src_dsd1) : (!csl, !csl) -> () +// CHECK-NEXT: "csl.xp162fs"(%dest_dsd, %i16_value) : (!csl, si16) -> () +// CHECK-NEXT: "csl.xp162fs"(%dest_dsd, %u16_value) : (!csl, ui16) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: %global_ptr = "test.op"() : () -> !csl.ptr, #csl> diff --git a/xdsl/dialects/csl.py b/xdsl/dialects/csl.py index c4af908396..acc98615c9 100644 --- a/xdsl/dialects/csl.py +++ b/xdsl/dialects/csl.py @@ -12,7 +12,7 @@ from abc import ABC from collections.abc import Sequence from dataclasses import dataclass -from typing import Annotated, TypeAlias +from typing import Annotated, ClassVar, TypeAlias from xdsl.dialects.builtin import ( AnyFloatAttr, @@ -288,6 +288,30 @@ def get_element_type(self) -> Attribute: ) +f16_pointer = PtrType( + [Float16Type(), PtrKindAttr(PtrKind.SINGLE), PtrConstAttr(PtrConst.VAR)] +) +f32_pointer = PtrType( + [Float32Type(), PtrKindAttr(PtrKind.SINGLE), PtrConstAttr(PtrConst.VAR)] +) +u16_value = IntegerType(16, Signedness.UNSIGNED) +i16_value = IntegerType(16, Signedness.SIGNED) +u32_value = IntegerType(32, Signedness.UNSIGNED) +i32_value = IntegerType(32, Signedness.SIGNED) +i16_pointer = PtrType( + [i16_value, PtrKindAttr(PtrKind.SINGLE), PtrConstAttr(PtrConst.VAR)] +) +u16_pointer = PtrType( + [u16_value, PtrKindAttr(PtrKind.SINGLE), PtrConstAttr(PtrConst.VAR)] +) +i32_pointer = PtrType( + [i32_value, PtrKindAttr(PtrKind.SINGLE), PtrConstAttr(PtrConst.VAR)] +) +u32_pointer = PtrType( + [u32_value, PtrKindAttr(PtrKind.SINGLE), PtrConstAttr(PtrConst.VAR)] +) + + @irdl_attr_definition class DsdType(EnumAttribute[DsdKind], TypeAttribute, SpacedOpaqueSyntaxAttribute): """ @@ -819,7 +843,7 @@ class IncrementDsdOffsetOp(IRDLOperation): name = "csl.increment_dsd_offset" op = operand_def(DsdType) - offset = operand_def(IntegerType(16, Signedness.SIGNED)) + offset = operand_def(i16_value) elem_type = prop_def(DsdElementType) result = result_def(DsdType) @@ -845,7 +869,7 @@ class SetDsdLengthOp(IRDLOperation): name = "csl.set_dsd_length" op = operand_def(DsdType) - length = operand_def(IntegerType(16, Signedness.UNSIGNED)) + length = operand_def(u16_value) result = result_def(DsdType) def verify_(self) -> None: @@ -883,6 +907,435 @@ def verify_(self) -> None: raise VerifyException(f"{self.name} can only operate on mem1d_dsd type") +FunctionSignatures = list[tuple[Attribute | type[Attribute], ...]] + + +class BuiltinDsdOp(IRDLOperation, ABC): + ops = var_operand_def() + + SIGNATURES: ClassVar[FunctionSignatures] + + def verify_(self) -> None: + def typcheck( + op_typ: Attribute, + sig_typ: Attribute | type[Attribute], + ) -> bool: + if isinstance(sig_typ, type): + return isinstance(op_typ, sig_typ) + else: + return op_typ == sig_typ + + for sig in self.SIGNATURES: + if len(self.ops) == len(sig): + if all(typcheck(op.type, sig_t) for (op, sig_t) in zip(self.ops, sig)): + return + raise VerifyException("Cannot find matching type signature") + + +class SymmetricBinary16BitOp(BuiltinDsdOp): + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType), + (DsdType, i16_value, DsdType), + (DsdType, u16_value, DsdType), + (DsdType, DsdType, i16_value), + (DsdType, DsdType, u16_value), + ] + + +class Unary16BitOp(BuiltinDsdOp): + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (DsdType, i16_value), + (DsdType, u16_value), + ] + + +@irdl_op_definition +class Add16Op(SymmetricBinary16BitOp): + name = "csl.add16" + + +@irdl_op_definition +class Add16cOp(SymmetricBinary16BitOp): + name = "csl.addc16" + + +@irdl_op_definition +class And16Op(SymmetricBinary16BitOp): + name = "csl.and16" + + +@irdl_op_definition +class ClzOp(Unary16BitOp): + name = "csl.clz" + + +@irdl_op_definition +class CtzOp(Unary16BitOp): + name = "csl.ctz" + + +@irdl_op_definition +class FabshOp(BuiltinDsdOp): + name = "csl.fabsh" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (DsdType, Float16Type), + ] + + +@irdl_op_definition +class FabssOp(BuiltinDsdOp): + name = "csl.fabss" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (DsdType, Float32Type), + ] + + +@irdl_op_definition +class FaddhOp(BuiltinDsdOp): + name = "csl.faddh" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType), + (DsdType, Float16Type, DsdType), + (DsdType, DsdType, Float16Type), + (f16_pointer, Float16Type, DsdType), + ] + + +@irdl_op_definition +class FaddhsOp(BuiltinDsdOp): + name = "csl.faddhs" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType), + (DsdType, Float16Type, DsdType), + (DsdType, DsdType, Float16Type), + (f32_pointer, Float32Type, DsdType), + ] + + +@irdl_op_definition +class FaddsOp(BuiltinDsdOp): + name = "csl.fadds" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType), + (DsdType, Float32Type, DsdType), + (DsdType, DsdType, Float32Type), + (f32_pointer, Float32Type, DsdType), + ] + + +@irdl_op_definition +class Fh2sOp(BuiltinDsdOp): + name = "csl.fh2s" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (DsdType, Float16Type), + ] + + +@irdl_op_definition +class Fh2xp16Op(BuiltinDsdOp): + name = "csl.fh2xp16" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (DsdType, Float16Type), + (i16_pointer, Float16Type), + ] + + +@irdl_op_definition +class FmachOp(BuiltinDsdOp): + name = "csl.fmach" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType, Float16Type) + ] + + +@irdl_op_definition +class FmachsOp(BuiltinDsdOp): + name = "csl.fmachs" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType, Float16Type) + ] + + +@irdl_op_definition +class FmacsOp(BuiltinDsdOp): + name = "csl.fmacs" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType, Float32Type) + ] + + +@irdl_op_definition +class FmaxhOp(BuiltinDsdOp): + name = "csl.fmaxh" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType), + (DsdType, Float16Type, DsdType), + (DsdType, DsdType, Float16Type), + (f16_pointer, Float16Type, DsdType), + ] + + +@irdl_op_definition +class FmaxsOp(BuiltinDsdOp): + name = "csl.fmaxs" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType), + (DsdType, Float32Type, DsdType), + (DsdType, DsdType, Float32Type), + (f32_pointer, Float32Type, DsdType), + ] + + +@irdl_op_definition +class FmovhOp(BuiltinDsdOp): + name = "csl.fmovh" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (f16_pointer, DsdType), + (DsdType, Float16Type), + ] + + +@irdl_op_definition +class FmovsOp(BuiltinDsdOp): + name = "csl.fmovs" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (f32_pointer, DsdType), + (DsdType, Float32Type), + ] + + +@irdl_op_definition +class FmulhOp(BuiltinDsdOp): + name = "csl.fmulh" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType), + (DsdType, Float16Type, DsdType), + (DsdType, DsdType, Float16Type), + (f16_pointer, Float16Type, DsdType), + ] + + +@irdl_op_definition +class FmulsOp(BuiltinDsdOp): + name = "csl.fmuls" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType), + (DsdType, Float32Type, DsdType), + (DsdType, DsdType, Float32Type), + (f32_pointer, Float32Type, DsdType), + ] + + +@irdl_op_definition +class FneghOp(BuiltinDsdOp): + name = "csl.fnegh" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (DsdType, Float16Type), + ] + + +@irdl_op_definition +class FnegsOp(BuiltinDsdOp): + name = "csl.fnegs" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (DsdType, Float32Type), + ] + + +@irdl_op_definition +class FnormhOp(BuiltinDsdOp): + name = "csl.fnormh" + + SIGNATURES: ClassVar[FunctionSignatures] = [(f16_pointer, Float16Type)] + + +@irdl_op_definition +class FnormsOp(BuiltinDsdOp): + name = "csl.fnorms" + + SIGNATURES: ClassVar[FunctionSignatures] = [(f32_pointer, Float32Type)] + + +@irdl_op_definition +class Fs2hOp(BuiltinDsdOp): + name = "csl.fs2h" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (DsdType, Float32Type), + ] + + +@irdl_op_definition +class Fs2xp16Op(BuiltinDsdOp): + """ + Implements @fs2xp16 + Note: this actually converts to i16, not to i32 + """ + + name = "csl.fs2xp16" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (DsdType, Float32Type), + (i16_pointer, Float32Type), + ] + + +@irdl_op_definition +class FscalehOp(BuiltinDsdOp): + name = "csl.fscaleh" + + SIGNATURES: ClassVar[FunctionSignatures] = [(f16_pointer, Float16Type, i16_value)] + + +@irdl_op_definition +class FscalesOp(BuiltinDsdOp): + name = "csl.fscales" + + SIGNATURES: ClassVar[FunctionSignatures] = [(f32_pointer, Float32Type, i16_value)] + + +@irdl_op_definition +class FsubhOp(BuiltinDsdOp): + name = "csl.fsubh" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType), + (DsdType, Float16Type, DsdType), + (DsdType, DsdType, Float16Type), + (f16_pointer, Float16Type, DsdType), + ] + + +@irdl_op_definition +class FsubsOp(BuiltinDsdOp): + name = "csl.fsubs" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType), + (DsdType, Float32Type, DsdType), + (DsdType, DsdType, Float32Type), + (f32_pointer, Float32Type, DsdType), + ] + + +@irdl_op_definition +class Mov16Op(BuiltinDsdOp): + name = "csl.mov16" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (i16_pointer, DsdType), + (u16_pointer, DsdType), + (DsdType, i16_value), + (DsdType, u16_value), + ] + + +@irdl_op_definition +class Mov32Op(BuiltinDsdOp): + name = "csl.mov32" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (i32_pointer, DsdType), + (u32_pointer, DsdType), + (DsdType, i32_value), + (DsdType, u32_value), + ] + + +@irdl_op_definition +class Or16Op(SymmetricBinary16BitOp): + name = "csl.or16" + + +@irdl_op_definition +class PopcntOp(Unary16BitOp): + name = "csl.popcnt" + + +@irdl_op_definition +class Sar16Op(SymmetricBinary16BitOp): + name = "csl.sar16" + + +@irdl_op_definition +class Sll16Op(SymmetricBinary16BitOp): + name = "csl.sll16" + + +@irdl_op_definition +class Slr16Op(SymmetricBinary16BitOp): + name = "csl.slr16" + + +@irdl_op_definition +class Sub16Op(BuiltinDsdOp): + name = "csl.sub16" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType, DsdType), + (DsdType, DsdType, i16_value), + (DsdType, DsdType, u16_value), + ] + + +@irdl_op_definition +class Xor16Op(SymmetricBinary16BitOp): + name = "csl.xor16" + + +@irdl_op_definition +class Xp162fhOp(BuiltinDsdOp): + name = "csl.xp162fh" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (DsdType, i16_value), + (DsdType, u16_value), + ] + + +@irdl_op_definition +class Xp162fsOp(BuiltinDsdOp): + name = "csl.xp162fs" + + SIGNATURES: ClassVar[FunctionSignatures] = [ + (DsdType, DsdType), + (DsdType, i16_value), + (DsdType, u16_value), + ] + + @irdl_op_definition class SymbolExportOp(IRDLOperation): """ @@ -1064,6 +1517,48 @@ def verify_(self) -> None: IncrementDsdOffsetOp, SetDsdLengthOp, SetDsdStrideOp, + Add16Op, + Add16cOp, + And16Op, + ClzOp, + CtzOp, + FabshOp, + FabssOp, + FaddhOp, + FaddhsOp, + FaddsOp, + Fh2sOp, + Fh2xp16Op, + FmachOp, + FmachsOp, + FmacsOp, + FmaxhOp, + FmaxsOp, + FmovhOp, + FmovsOp, + FmulhOp, + FmulsOp, + FneghOp, + FnegsOp, + FnormhOp, + FnormsOp, + Fs2hOp, + Fs2xp16Op, + FscalehOp, + FscalesOp, + FsubhOp, + FsubsOp, + Mov16Op, + Mov32Op, + Or16Op, + PopcntOp, + Sar16Op, + Sll16Op, + Slr16Op, + Sub16Op, + Xor16Op, + Xp162fhOp, + Xp162fsOp, AddressOfOp, SymbolExportOp, RpcOp,