Skip to content

Commit

Permalink
Merge branch 'main' into nicolai/csl-param-to-const
Browse files Browse the repository at this point in the history
  • Loading branch information
n-io authored Oct 25, 2024
2 parents 252bec3 + a3e2e7c commit 40df9be
Show file tree
Hide file tree
Showing 23 changed files with 775 additions and 319 deletions.
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ dev = [
"nbval<0.12",
"filecheck==1.0.1",
"lit<19.0.0",
"marimo==0.9.11",
"marimo==0.9.14",
"pre-commit==4.0.1",
"ruff==0.7.0",
"ruff==0.7.1",
"asv<0.7",
"nbconvert>=7.7.2,<8.0.0",
"textual-dev==1.6.1",
"pytest-asyncio==0.24.0",
"pyright==1.1.385",
"pyright==1.1.386",
]
gui = ["textual==0.83.0", "pyclip==0.7"]
jax = ["jax==0.4.34", "numpy==2.1.2"]
gui = ["textual==0.84.0", "pyclip==0.7"]
jax = ["jax==0.4.35", "numpy==2.1.2"]
onnx = ["onnx==1.17.0", "numpy==2.1.2"]
riscv = ["riscemu==2.2.7"]
wgpu = ["wgpu==0.18.1"]
Expand Down
16 changes: 0 additions & 16 deletions tests/dialects/test_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,22 +234,6 @@ def test_create_index_attr_from_list_edge_case2():
assert exc_info.value.args[0] == "Expected 1 to 3 indexes for stencil.index, got 4."


@pytest.mark.parametrize(
"indices1, indices2",
(([1], [4]), ([1, 2], [4, 5]), ([1, 2, 3], [5, 6, 7])),
)
def test_index_attr_size_from_bounds(indices1: list[int], indices2: list[int]):
stencil_index_attr1 = IndexAttr.get(*indices1)
stencil_index_attr2 = IndexAttr.get(*indices2)

size_from_bounds = IndexAttr.size_from_bounds(
stencil_index_attr1, stencil_index_attr2
)
expected_list = [abs(idx1 - idx2) for idx1, idx2 in zip(indices1, indices2)]

assert size_from_bounds == expected_list


@pytest.mark.parametrize(
"indices",
(([1]), ([1, 2]), ([1, 2, 3])),
Expand Down
2 changes: 2 additions & 0 deletions tests/filecheck/backend/csl/print_csl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@
%const_array = memref.get_global @const_array : memref<10xi32>

%literal_array = arith.constant dense<[1.200000e+00, 2.300000e+00, 3.400000e+00]> : memref<3xf32>
%literal_array_w_zeros = arith.constant dense<[1.200000e+00, 0, 3.400000e+00, 0]> : memref<4xf32>

%uninit_ptr = "csl.addressof"(%uninit_array) : (memref<10xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
%global_ptr = "csl.addressof"(%global_array) : (memref<10xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
Expand Down Expand Up @@ -675,6 +676,7 @@ csl.func @builtins() {
// CHECK-NEXT: var global_array : [10]f32 = @constants([10]f32, 4.2);
// CHECK-NEXT: const const_array : [10]i32 = @constants([10]i32, 10);
// CHECK-NEXT: const literal_array : [3]f32 = [3]f32 { 1.2, 2.3, 3.4 };
// CHECK-NEXT: const literal_array_w_zeros : [4]f32 = [4]f32 { 1.2, 0.0, 3.4, 0.0 };
// CHECK-NEXT: var uninit_ptr : [*]f32 = &uninit_array;
// CHECK-NEXT: var global_ptr : [*]f32 = &global_array;
// CHECK-NEXT: const const_ptr : [*]const i32 = &const_array;
Expand Down
45 changes: 45 additions & 0 deletions tests/filecheck/backend/riscv/convert_arith_to_riscv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,51 @@ builtin.module {
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf14_2, 1 : (!riscv.reg) -> !riscv.reg
%cmpf15 = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 15 : i32} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.li 1 : !riscv.reg

// tests with fastmath flags when set to "fast"
%cmpf1_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 1 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%cmpf2_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 2 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%cmpf3_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 3 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.fle.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%cmpf4_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 4 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%cmpf5_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 5 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.fle.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%cmpf6_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 6 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.or %cmpf6_fm_1, %cmpf6_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
%cmpf7_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 7 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.feq.s %rhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.and %cmpf7_fm_1, %cmpf7_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
%cmpf8_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 8 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.or %cmpf8_fm_1, %cmpf8_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf8_fm_2, 1 : (!riscv.reg) -> !riscv.reg
%cmpf9_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 9 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.fle.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf9_fm, 1 : (!riscv.reg) -> !riscv.reg
%cmpf10_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 10 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf10_fm, 1 : (!riscv.reg) -> !riscv.reg
%cmpf11_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 11 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.fle.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf11_fm, 1 : (!riscv.reg) -> !riscv.reg
%cmpf12_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 12 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf12_fm, 1 : (!riscv.reg) -> !riscv.reg
%cmpf13_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 13 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf13_fm, 1 : (!riscv.reg) -> !riscv.reg
%cmpf14_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 14 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.feq.s %rhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.and %cmpf14_fm_1, %cmpf14_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf14_fm_2, 1 : (!riscv.reg) -> !riscv.reg
%index_cast = "arith.index_cast"(%lhsindex) : (index) -> i32
// CHECK-NEXT: }
}
32 changes: 32 additions & 0 deletions tests/filecheck/dialects/llvm/arith_invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: xdsl-opt %s --split-input-file --verify-diagnostics | filecheck %s

"builtin.module"() ({

%arg0 = "test.op"() : () -> (i32)

%trunc = llvm.trunc %arg0 : i32 to i64
// CHECK: invalid cast opcode for cast from i32 to i64

}) : () -> ()

// -----

"builtin.module"() ({

%arg0 = "test.op"() : () -> (i32)

%zext = llvm.zext %arg0 : i32 to i16
// CHECK: invalid cast opcode for cast from i32 to i16

}) : () -> ()

// -----

"builtin.module"() ({

%arg0 = "test.op"() : () -> (i32)

%sext = llvm.sext %arg0 : i32 to i16
// CHECK: invalid cast opcode for cast from i32 to i16

}) : () -> ()
9 changes: 9 additions & 0 deletions tests/filecheck/dialects/llvm/arithmetic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,12 @@

%ashr = llvm.ashr %arg0, %arg1 : i32
// CHECK: %ashr = llvm.ashr %arg0, %arg1 : i32

%trunc = llvm.trunc %arg0 : i32 to i16
// CHECK: %trunc = llvm.trunc %arg0 : i32 to i16

%sext = llvm.sext %arg0 : i32 to i64
// CHECK: %sext = llvm.sext %arg0 : i32 to i64

%zext = llvm.zext %arg0 : i32 to i64
// CHECK: %zext = llvm.zext %arg0 : i32 to i64
46 changes: 43 additions & 3 deletions tests/filecheck/transforms/convert-arith-to-varith.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func.func @test_addi() {
// CHECK-NEXT: %a, %b, %c = "test.op"() : () -> (i32, i32, i32)
// CHECK-NEXT: %0, %1, %2 = "test.op"() : () -> (i32, i32, i32)
// CHECK-NEXT: %x2 = arith.addi %0, %1 : i32
// CHECK-NEXT: %r = varith.add %c, %a, %b, %2, %0, %1 : i32
// CHECK-NEXT: %r = varith.add %a, %b, %c, %0, %1, %2 : i32
// CHECK-NEXT: "test.op"(%r, %x2) : (i32, i32) -> ()
}

Expand All @@ -45,7 +45,7 @@ func.func @test_addf() {
// CHECK-NEXT: %a, %b, %c = "test.op"() : () -> (f32, f32, f32)
// CHECK-NEXT: %0, %1, %2 = "test.op"() : () -> (f32, f32, f32)
// CHECK-NEXT: %x2 = arith.addf %0, %1 : f32
// CHECK-NEXT: %r = varith.add %c, %a, %b, %2, %0, %1 : f32
// CHECK-NEXT: %r = varith.add %a, %b, %c, %0, %1, %2 : f32
// CHECK-NEXT: "test.op"(%r, %x2) : (f32, f32) -> ()
}

Expand All @@ -69,6 +69,46 @@ func.func @test_mulf() {
// CHECK-NEXT: %a, %b, %c = "test.op"() : () -> (f32, f32, f32)
// CHECK-NEXT: %0, %1, %2 = "test.op"() : () -> (f32, f32, f32)
// CHECK-NEXT: %x2 = arith.mulf %0, %1 : f32
// CHECK-NEXT: %r = varith.mul %c, %a, %b, %2, %0, %1 : f32
// CHECK-NEXT: %r = varith.mul %a, %b, %c, %0, %1, %2 : f32
// CHECK-NEXT: "test.op"(%r, %x2) : (f32, f32) -> ()
}

func.func @test() {
%0, %1, %2, %3, %4, %5 = "test.op"() : () -> (f32, f32, f32, f32, f32, f32)
%6 = arith.constant 1.234500e-01 : f32
%a = arith.addf %5, %4 : f32
%b = arith.addf %a, %3 : f32
%c = arith.addf %b, %2 : f32
%d = arith.addf %c, %1 : f32
%e = arith.addf %d, %0 : f32
%12 = arith.mulf %e, %6 : f32
"test.op"(%12) : (f32) -> ()
func.return

// CHECK-LABEL: @test
// CHECK-NEXT: %0, %1, %2, %3, %4, %5 = "test.op"() : () -> (f32, f32, f32, f32, f32, f32)
// CHECK-NEXT: %6 = arith.constant 1.234500e-01 : f32
// CHECK-NEXT: %e = varith.add %5, %4, %3, %2, %1, %0 : f32
// CHECK-NEXT: %7 = arith.mulf %e, %6 : f32
// CHECK-NEXT: "test.op"(%7) : (f32) -> ()
}

func.func @test2() {
%0, %1, %2, %3, %4, %5 = "test.op"() : () -> (f32, f32, f32, f32, f32, f32)
%6 = arith.constant 1.234500e-01 : f32
%a = arith.addf %5, %4 : f32
%b = arith.addf %3, %a : f32
%c = arith.addf %2, %b : f32
%d = arith.addf %1, %c : f32
%e = arith.addf %0, %d : f32
%12 = arith.mulf %e, %6 : f32
"test.op"(%12) : (f32) -> ()
func.return

// CHECK-LABEL: @test
// CHECK-NEXT: %0, %1, %2, %3, %4, %5 = "test.op"() : () -> (f32, f32, f32, f32, f32, f32)
// CHECK-NEXT: %6 = arith.constant 1.234500e-01 : f32
// CHECK-NEXT: %e = varith.add %1, %2, %3, %5, %4, %0 : f32
// CHECK-NEXT: %7 = arith.mulf %e, %6 : f32
// CHECK-NEXT: "test.op"(%7) : (f32) -> ()
}
Loading

0 comments on commit 40df9be

Please sign in to comment.