From f9a639c3a9a7f459cdc4914fb6ed8ec2c74fb37b Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Sat, 20 Jul 2024 20:22:28 +0100 Subject: [PATCH 1/3] testing: add ssum testcase to bottom-up tests --- .../riscv-backend-paper/bottom_up.mlir | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir index 660b986488..fc43db743d 100644 --- a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir @@ -1,5 +1,55 @@ // RUN: xdsl-opt -p convert-linalg-to-memref-stream,test-optimise-memref-stream,test-lower-memref-stream-to-snitch-stream,test-lower-snitch-stream-to-asm -t riscv-asm %s | filecheck %s + +func.func public @ssum( + %X: memref<8x16xf32>, + %Y: memref<8x16xf32>, + %Z: memref<8x16xf32> +) { + %X_1 = builtin.unrealized_conversion_cast %X : memref<8x16xf32> to !riscv.reg + %Y_1 = builtin.unrealized_conversion_cast %Y : memref<8x16xf32> to !riscv.reg + %Z_1 = builtin.unrealized_conversion_cast %Z : memref<8x16xf32> to !riscv.reg + snitch_stream.streaming_region { + patterns = [ + #snitch_stream.stride_pattern + ] + } ins(%X_1, %Y_1 : !riscv.reg, !riscv.reg) outs(%Z_1 : !riscv.reg) { + ^0(%x : !stream.readable, %y : !stream.readable, %0 : !stream.writable): + %1 = riscv.li 8 : !riscv.reg + %2 = riscv.li 0 : !riscv.reg + %3 = riscv.li 1 : !riscv.reg + %4 = riscv.li 64 : !riscv.reg + riscv_scf.for %5 : !riscv.reg = %2 to %4 step %3 { + %x_1 = riscv_snitch.read from %x : !riscv.freg + %y_1 = riscv_snitch.read from %y : !riscv.freg + %z = riscv.vfadd.s %x_1, %y_1 : (!riscv.freg, !riscv.freg) -> !riscv.freg + riscv_snitch.write %z to %0 : !riscv.freg + } + } + func.return +} + +// CHECK: .text +// CHECK-NEXT: .globl ssum +// CHECK-NEXT: .p2align 2 +// CHECK-NEXT: ssum: +// CHECK-NEXT: mv t2, a0 +// CHECK-NEXT: mv t1, a1 +// CHECK-NEXT: mv t0, a2 +// CHECK-NEXT: li t3, 63 +// CHECK-NEXT: scfgwi t3, 95 +// CHECK-NEXT: li t3, 8 +// CHECK-NEXT: scfgwi t3, 223 +// CHECK-NEXT: scfgwi t2, 768 +// CHECK-NEXT: scfgwi t1, 769 +// CHECK-NEXT: scfgwi t0, 898 +// CHECK-NEXT: csrrsi zero, 1984, 1 +// CHECK-NEXT: li t0, 63 +// CHECK-NEXT: frep.o t0, 1, 0, 0 +// CHECK-NEXT: vfadd.s ft2, ft0, ft1 +// CHECK-NEXT: csrrci zero, 1984, 1 +// CHECK-NEXT: ret + func.func public @conv_2d_nchw_fchw_d1_s1_3x3( %X: memref<1x1x10x10xf64>, %Y: memref<1x1x3x3xf64>, From cb8c277d74e91e393d30edd4dbba93c0c2b37ba2 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Sun, 21 Jul 2024 10:10:21 +0200 Subject: [PATCH 2/3] dialects: (riscv_snitch) add fastmath flags in vector ops --- tests/filecheck/dialects/riscv_snitch/ops.mlir | 4 ++-- xdsl/dialects/riscv_snitch.py | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/filecheck/dialects/riscv_snitch/ops.mlir b/tests/filecheck/dialects/riscv_snitch/ops.mlir index 81b679c2c9..107282367e 100644 --- a/tests/filecheck/dialects/riscv_snitch/ops.mlir +++ b/tests/filecheck/dialects/riscv_snitch/ops.mlir @@ -147,8 +147,8 @@ riscv_func.func @simd() { // CHECK-GENERIC-NEXT: }) {"sym_name" = "xdma", "function_type" = () -> ()} : () -> () // CHECK-GENERIC-NEXT: "riscv_func.func"() ({ // CHECK-GENERIC-NEXT: %v = "riscv.get_float_register"() : () -> !riscv.freg -// CHECK-GENERIC-NEXT: %0 = "riscv_snitch.vfmul.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg -// CHECK-GENERIC-NEXT: %1 = "riscv_snitch.vfadd.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg +// CHECK-GENERIC-NEXT: %0 = "riscv_snitch.vfmul.s"(%v, %v) {"fastmath" = #riscv.fastmath} : (!riscv.freg, !riscv.freg) -> !riscv.freg +// CHECK-GENERIC-NEXT: %1 = "riscv_snitch.vfadd.s"(%v, %v) {"fastmath" = #riscv.fastmath} : (!riscv.freg, !riscv.freg) -> !riscv.freg // CHECK-GENERIC-NEXT: %2 = "riscv_snitch.vfcpka.s.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg // CHECK-GENERIC-NEXT: "riscv_func.return"() : () -> () // CHECK-GENERIC-NEXT: }) {"sym_name" = "simd", "function_type" = () -> ()} : () -> () diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index 57a1528780..86858f5125 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -734,9 +734,7 @@ def assembly_instruction_name(self) -> str: @irdl_op_definition -class VFMulSOp( - RdRsRsOperation[FloatRegisterType, FloatRegisterType, FloatRegisterType] -): +class VFMulSOp(riscv.RdRsRsFloatOperationWithFastMath): """ Performs vectorial multiplication of corresponding f32 values from rs1 and rs2 and stores the results in the corresponding f32 lanes @@ -755,9 +753,7 @@ def assembly_instruction_name(self) -> str: @irdl_op_definition -class VFAddSOp( - RdRsRsOperation[FloatRegisterType, FloatRegisterType, FloatRegisterType] -): +class VFAddSOp(riscv.RdRsRsFloatOperationWithFastMath): """ Performs vectorial addition of corresponding f32 values from rs1 and rs2 and stores the results in the corresponding f32 lanes From 76bb14901a5694b966a79be6d73c8a9285c88a8b Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Sat, 20 Jul 2024 23:26:10 +0200 Subject: [PATCH 3/3] transformations: add convert-arith-to-riscv-snitch pass --- .../riscv/convert_arith_to_riscv_snitch.mlir | 29 +++++++ .../lowering/convert_arith_to_riscv_snitch.py | 87 +++++++++++++++++++ xdsl/tools/command_line_tool.py | 6 ++ ...st_lower_memref_stream_to_snitch_stream.py | 4 + 4 files changed, 126 insertions(+) create mode 100644 tests/filecheck/backend/riscv/convert_arith_to_riscv_snitch.mlir create mode 100644 xdsl/backend/riscv/lowering/convert_arith_to_riscv_snitch.py diff --git a/tests/filecheck/backend/riscv/convert_arith_to_riscv_snitch.mlir b/tests/filecheck/backend/riscv/convert_arith_to_riscv_snitch.mlir new file mode 100644 index 0000000000..e840ec83a0 --- /dev/null +++ b/tests/filecheck/backend/riscv/convert_arith_to_riscv_snitch.mlir @@ -0,0 +1,29 @@ +// RUN: xdsl-opt -p convert-arith-to-riscv-snitch,reconcile-unrealized-casts %s | filecheck %s + +// CHECK: builtin.module +// CHECK-NEXT: %l, %r = "test.op"() : () -> (!riscv.freg, !riscv.freg) +%l, %r = "test.op"() : () -> (!riscv.freg, !riscv.freg) +%l32 = builtin.unrealized_conversion_cast %l : !riscv.freg to vector<2xf32> +%r32 = builtin.unrealized_conversion_cast %r : !riscv.freg to vector<2xf32> +%lhsvf64 = builtin.unrealized_conversion_cast %l : !riscv.freg to vector<1xf64> +%rhsvf64 = builtin.unrealized_conversion_cast %r : !riscv.freg to vector<1xf64> + +// CHECK-NEXT: %addf32 = riscv_snitch.vfadd.s %l, %r : (!riscv.freg, !riscv.freg) -> !riscv.freg +%addf32 = arith.addf %l32, %r32 : vector<2xf32> + +// tests with fastmath flags when set to "fast" +// CHECK-NEXT: %addf32_fm = riscv_snitch.vfadd.s %l, %r fastmath : (!riscv.freg, !riscv.freg) -> !riscv.freg +%addf32_fm = arith.addf %l32, %r32 fastmath : vector<2xf32> + + +// CHECK-NEXT: %addf64 = riscv.fadd.d %l, %r : (!riscv.freg, !riscv.freg) -> !riscv.freg +%addf64 = arith.addf %lhsvf64, %rhsvf64 : vector<1xf64> + +// tests with fastmath flags when set to "fast" +// CHECK-NEXT: %addf64_fm = riscv.fadd.d %l, %r fastmath : (!riscv.freg, !riscv.freg) -> !riscv.freg +%addf64_fm = arith.addf %lhsvf64, %rhsvf64 fastmath : vector<1xf64> + +// tests with fastmath flags when set to "contract" +// CHECK-NEXT: %addf64_fm_contract = riscv.fadd.d %l, %r fastmath : (!riscv.freg, !riscv.freg) -> !riscv.freg +%addf64_fm_contract = arith.addf %lhsvf64, %rhsvf64 fastmath : vector<1xf64> + diff --git a/xdsl/backend/riscv/lowering/convert_arith_to_riscv_snitch.py b/xdsl/backend/riscv/lowering/convert_arith_to_riscv_snitch.py new file mode 100644 index 0000000000..71acd590c4 --- /dev/null +++ b/xdsl/backend/riscv/lowering/convert_arith_to_riscv_snitch.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass +from math import prod +from typing import Any, cast + +from xdsl.context import MLContext +from xdsl.dialects import arith, riscv, riscv_snitch +from xdsl.dialects.builtin import ( + Float32Type, + Float64Type, + ModuleOp, + UnrealizedConversionCastOp, + VectorType, +) +from xdsl.ir import Operation +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriter, + PatternRewriteWalker, + RewritePattern, +) + +_FLOAT_REGISTER_TYPE = riscv.FloatRegisterType.unallocated() + + +@dataclass +class LowerBinaryFloatVectorOp(RewritePattern): + arith_op_cls: type[arith.FloatingPointLikeBinaryOp] + riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath] + riscv_snitch_v_f_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath] + + def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None: + if not isinstance(op, self.arith_op_cls): + return + + operand_type = op.result.type + if not isinstance(operand_type, VectorType): + return + shape = operand_type.shape + count = prod(dim.data for dim in shape.data) + + operand_type = cast(VectorType[Any], operand_type) + scalar_type = operand_type.element_type + + lhs = UnrealizedConversionCastOp.get((op.lhs,), (_FLOAT_REGISTER_TYPE,)) + rhs = UnrealizedConversionCastOp.get((op.rhs,), (_FLOAT_REGISTER_TYPE,)) + + match scalar_type: + case Float64Type(): + if count != 1: + return + cls = self.riscv_d_op_cls + case Float32Type(): + if count != 2: + return + cls = self.riscv_snitch_v_f_op_cls + case _: + assert False, f"Unexpected float type {op.lhs.type}" + + rv_flags = riscv.FastMathFlagsAttr("none") + if op.fastmath is not None: + rv_flags = riscv.FastMathFlagsAttr(op.fastmath.data) + + new_op = cls(lhs, rhs, rd=_FLOAT_REGISTER_TYPE, fastmath=rv_flags) + cast_op = UnrealizedConversionCastOp.get((new_op.rd,), (op.result.type,)) + + rewriter.replace_matched_op((lhs, rhs, new_op, cast_op)) + + +lower_arith_addf = LowerBinaryFloatVectorOp( + arith.Addf, riscv.FAddDOp, riscv_snitch.VFAddSOp +) + + +class ConvertArithToRiscvSnitchPass(ModulePass): + name = "convert-arith-to-riscv-snitch" + + def apply(self, ctx: MLContext, op: ModuleOp) -> None: + walker = PatternRewriteWalker( + GreedyRewritePatternApplier( + [ + lower_arith_addf, + ] + ), + apply_recursively=False, + ) + walker.rewrite_module(op) diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py index 07cb6cb89e..b77a9e67cf 100644 --- a/xdsl/tools/command_line_tool.py +++ b/xdsl/tools/command_line_tool.py @@ -231,6 +231,11 @@ def get_convert_arith_to_riscv(): return convert_arith_to_riscv.ConvertArithToRiscvPass + def get_convert_arith_to_riscv_snitch(): + from xdsl.backend.riscv.lowering import convert_arith_to_riscv_snitch + + return convert_arith_to_riscv_snitch.ConvertArithToRiscvSnitchPass + def get_convert_func_to_riscv_func(): from xdsl.backend.riscv.lowering import convert_func_to_riscv_func @@ -364,6 +369,7 @@ def get_test_optimise_memref_stream(): "canonicalize": get_canonicalize, "constant-fold-interp": get_constant_fold_interp, "convert-arith-to-riscv": get_convert_arith_to_riscv, + "convert-arith-to-riscv-snitch": get_convert_arith_to_riscv_snitch, "convert-func-to-riscv-func": get_convert_func_to_riscv_func, "convert-linalg-to-memref-stream": get_convert_linalg_to_memref_stream, "convert-linalg-to-loops": get_convert_linalg_to_loops, diff --git a/xdsl/transforms/test_lower_memref_stream_to_snitch_stream.py b/xdsl/transforms/test_lower_memref_stream_to_snitch_stream.py index 1cf46cf822..7113ea9433 100644 --- a/xdsl/transforms/test_lower_memref_stream_to_snitch_stream.py +++ b/xdsl/transforms/test_lower_memref_stream_to_snitch_stream.py @@ -1,6 +1,9 @@ from dataclasses import dataclass from xdsl.backend.riscv.lowering.convert_arith_to_riscv import ConvertArithToRiscvPass +from xdsl.backend.riscv.lowering.convert_arith_to_riscv_snitch import ( + ConvertArithToRiscvSnitchPass, +) from xdsl.backend.riscv.lowering.convert_func_to_riscv_func import ( ConvertFuncToRiscvFuncPass, ) @@ -19,6 +22,7 @@ ConvertMemrefToRiscvPass(), LowerAffinePass(), ConvertScfToRiscvPass(), + ConvertArithToRiscvSnitchPass(), ConvertArithToRiscvPass(), ConvertFuncToRiscvFuncPass(), ConvertMemrefStreamToSnitchStreamPass(),