diff --git a/tests/filecheck/projects/riscv-backend-paper/nsnet.mlir b/tests/filecheck/projects/riscv-backend-paper/nsnet.mlir index b3926ce6bb..6524c240a5 100644 --- a/tests/filecheck/projects/riscv-backend-paper/nsnet.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/nsnet.mlir @@ -1,6 +1,4 @@ // RUN: xdsl-opt -p arith-add-fastmath,test-lower-linalg-to-snitch -t riscv-asm %s | filecheck %s -// RUN: xdsl-opt -p arith-add-fastmath,test-lower-linalg-to-snitch{optimization-level=4} -t riscv-asm %s | filecheck %s -// RUN: xdsl-opt -p arith-add-fastmath,test-lower-linalg-to-snitch{optimization-level=0} -t riscv-asm %s | filecheck %s --check-prefix=CHECK-OPT func.func @main$async_dispatch_0_matmul_transpose_b_1x400x161_f64$xdsl_kernel1(%arg0: memref<1x161xf64>, %arg1: memref<5x161xf64, strided<[161, 1]>>, %arg2: memref<1x5xf64, strided<[40, 1]>>) { linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<1x161xf64>, memref<5x161xf64, strided<[161, 1]>>) outs(%arg2 : memref<1x5xf64, strided<[40, 1]>>) { @@ -59,50 +57,3 @@ func.func @main$async_dispatch_0_matmul_transpose_b_1x400x161_f64$xdsl_kernel1(% // CHECK-NEXT: fsd ft3, 32(t0) # store double value to memref of shape (1, 5) // CHECK-NEXT: csrrci zero, 1984, 1 // CHECK-NEXT: ret - -// CHECK-OPT: .text -// CHECK-OPT-NEXT: .globl main$async_dispatch_0_matmul_transpose_b_1x400x161_f64$xdsl_kernel1 -// CHECK-OPT-NEXT: .p2align 2 -// CHECK-OPT-NEXT: # Regalloc stats: {"preallocated_float": [], "preallocated_int": ["a0", "a1", "a2", "zero"], "allocated_float": ["ft0", "ft1", "ft2"], "allocated_int": ["a0", "a1", "a2", "a3", "a4", "a5", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "zero"]} -// CHECK-OPT-NEXT: main$async_dispatch_0_matmul_transpose_b_1x400x161_f64$xdsl_kernel1: -// CHECK-OPT-NEXT: mv t4, a0 -// CHECK-OPT-NEXT: mv t3, a1 -// CHECK-OPT-NEXT: mv t2, a2 -// CHECK-OPT-NEXT: li t6, 5 -// CHECK-OPT-NEXT: li t0, 161 -// CHECK-OPT-NEXT: mv t5, zero -// CHECK-OPT-NEXT: # Constant folded riscv_cf.bge -// CHECK-OPT-NEXT: scf_body_1_for: -// CHECK-OPT-NEXT: mv a3, zero -// CHECK-OPT-NEXT: # Constant folded riscv_cf.bge -// CHECK-OPT-NEXT: scf_body_0_for: -// CHECK-OPT-NEXT: mv a4, a3 -// CHECK-OPT-NEXT: li a5, 8 -// CHECK-OPT-NEXT: mul a4, a4, a5 # multiply by element size -// CHECK-OPT-NEXT: add a4, t4, a4 -// CHECK-OPT-NEXT: fld ft0, 0(a4) # load double from memref of shape (1, 161) -// CHECK-OPT-NEXT: li a4, 161 -// CHECK-OPT-NEXT: mul a4, t5, a4 -// CHECK-OPT-NEXT: add a4, a4, a3 -// CHECK-OPT-NEXT: li a5, 8 -// CHECK-OPT-NEXT: mul a4, a4, a5 # multiply by element size -// CHECK-OPT-NEXT: add a4, t3, a4 -// CHECK-OPT-NEXT: fld ft1, 0(a4) # load double from memref of shape (5, 161) -// CHECK-OPT-NEXT: mv a4, t5 -// CHECK-OPT-NEXT: li a5, 8 -// CHECK-OPT-NEXT: mul a4, a4, a5 # multiply by element size -// CHECK-OPT-NEXT: add a4, t2, a4 -// CHECK-OPT-NEXT: fld ft2, 0(a4) # load double from memref of shape (1, 5) -// CHECK-OPT-NEXT: fmadd.d ft0, ft0, ft1, ft2 -// CHECK-OPT-NEXT: mv a4, t5 -// CHECK-OPT-NEXT: li a5, 8 -// CHECK-OPT-NEXT: mul a4, a4, a5 # multiply by element size -// CHECK-OPT-NEXT: add a4, t2, a4 -// CHECK-OPT-NEXT: fsd ft0, 0(a4) # store double value to memref of shape (1, 5) -// CHECK-OPT-NEXT: addi a3, a3, 1 -// CHECK-OPT-NEXT: blt a3, t0, scf_body_0_for -// CHECK-OPT-NEXT: scf_body_end_0_for: -// CHECK-OPT-NEXT: addi t5, t5, 1 -// CHECK-OPT-NEXT: blt t5, t6, scf_body_1_for -// CHECK-OPT-NEXT: scf_body_end_1_for: -// CHECK-OPT-NEXT: ret diff --git a/tests/transforms/test_test_lower_linalg_to_snitch.py b/tests/transforms/test_test_lower_linalg_to_snitch.py deleted file mode 100644 index 0194ff086f..0000000000 --- a/tests/transforms/test_test_lower_linalg_to_snitch.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest - -from xdsl.passes import ModulePass -from xdsl.transforms import ( - convert_riscv_scf_for_to_frep, - memref_stream_interleave, - memref_stream_unnest_out_parameters, - memref_streamify, -) -from xdsl.transforms.test_lower_linalg_to_snitch import get_excluded_passes - - -@pytest.mark.parametrize( - "optimization_level,expected", - [ - ( - 0, - ( - memref_stream_interleave.MemrefStreamInterleavePass(), - convert_riscv_scf_for_to_frep.ConvertRiscvScfForToFrepPass(), - memref_stream_unnest_out_parameters.MemrefStreamUnnestOutParametersPass(), - memref_streamify.MemrefStreamifyPass(), - ), - ), - ( - 1, - ( - memref_stream_interleave.MemrefStreamInterleavePass(), - convert_riscv_scf_for_to_frep.ConvertRiscvScfForToFrepPass(), - memref_stream_unnest_out_parameters.MemrefStreamUnnestOutParametersPass(), - ), - ), - ( - 2, - ( - memref_stream_interleave.MemrefStreamInterleavePass(), - convert_riscv_scf_for_to_frep.ConvertRiscvScfForToFrepPass(), - ), - ), - ( - 3, - (memref_stream_interleave.MemrefStreamInterleavePass(),), - ), - (4, ()), - ], -) -def test_get_excluded_passes(optimization_level: int, expected: tuple[ModulePass, ...]): - assert get_excluded_passes(optimization_level) == expected diff --git a/xdsl/transforms/test_lower_linalg_to_snitch.py b/xdsl/transforms/test_lower_linalg_to_snitch.py index 94e9f9d5af..a3b0101ec2 100644 --- a/xdsl/transforms/test_lower_linalg_to_snitch.py +++ b/xdsl/transforms/test_lower_linalg_to_snitch.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass from xdsl.backend.riscv.lowering import ( convert_arith_to_riscv, @@ -96,73 +96,16 @@ memref_streamify.MemrefStreamifyPass(), ) -MAX_OPT_LEVEL = len(LINALG_SNITCH_OPTIMIZATION_PASSES) - - -def get_excluded_passes( - optimization_level: int = MAX_OPT_LEVEL, -) -> tuple[ModulePass, ...]: - """ - This function determines which optimization passes should be excluded from the - lowering pipeline based on the specified optimization level. A higher optimization - level includes more passes. - - Args: - optimization_level (int): The desired optimization level, ranging from 0 to - 4 (inclusive). Defaults to 4. - - Returns: - tuple[ModulePass, ...]: A tuple containing the ModulePass objects to be excluded - from the lowering pipeline. - """ - - if optimization_level == MAX_OPT_LEVEL: - return () - - return ( - LINALG_SNITCH_OPTIMIZATION_PASSES[:-optimization_level] - if optimization_level - else LINALG_SNITCH_OPTIMIZATION_PASSES - ) - - -def get_passes(optimization_level: int = MAX_OPT_LEVEL) -> tuple[ModulePass, ...]: - """ - This function returns a tuple of ModulePass objects to be applied in the lowering - pipeline, based on the specified optimization level. - - Args: - optimization_level (int): The desired optimization level, ranging from 0 to - 4 (inclusive). Defaults to 4. - - Returns: - tuple[ModulePass, ...]: A tuple containing the ModulePass objects to be applied - in the lowering pipeline. - """ - - excluded_passes = get_excluded_passes(optimization_level) - return tuple( - p for p in TEST_LOWER_LINALG_TO_SNITCH_PASSES if p not in excluded_passes - ) - @dataclass(frozen=True) class TestLowerLinalgToSnitchPass(ModulePass): """ A compiler pass used for testing lowering microkernels from linalg generic to snitch assembly. - - Args: - optimization_level (int): The desired optimization level, ranging from 0 to - 4 (inclusive). Defaults to 4. """ name = "test-lower-linalg-to-snitch" - optimization_level: int = field(default=MAX_OPT_LEVEL) - def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: - passes = get_passes(self.optimization_level) - - for p in passes: + for p in TEST_LOWER_LINALG_TO_SNITCH_PASSES: p.apply(ctx, op)