Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: add convert-arith-to-riscv-snitch pass #2914

Merged
merged 4 commits into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions tests/filecheck/backend/riscv/convert_arith_to_riscv_snitch.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: xdsl-opt -p convert-arith-to-riscv-snitch,reconcile-unrealized-casts %s | filecheck %s

// CHECK: builtin.module
// CHECK-NEXT: %l, %r = "test.op"() : () -> (!riscv.freg, !riscv.freg)
Comment on lines +3 to +4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// CHECK: builtin.module
// CHECK-NEXT: %l, %r = "test.op"() : () -> (!riscv.freg, !riscv.freg)
// CHECK: builtin.module
// CHECK-NEXT: %[[L:.*]], %[[R:.*]] = "test.op"() : () -> (!riscv.freg, !riscv.freg)

Since 9b43500 landed, does this mean we can use named SSA variables now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do it in this case if the variable names are stable and more readable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll merge for now since you've approved but happy to refactor tests together tomorrow, there's more to come

// Two float registers produced by an opaque op. The casts below present the
// same registers either as vector<2xf32> or as vector<1xf64>, so that
// arith.addf can be exercised on both vector types.
%l, %r = "test.op"() : () -> (!riscv.freg, !riscv.freg)
%l32 = builtin.unrealized_conversion_cast %l : !riscv.freg to vector<2xf32>
%r32 = builtin.unrealized_conversion_cast %r : !riscv.freg to vector<2xf32>
%lhsvf64 = builtin.unrealized_conversion_cast %l : !riscv.freg to vector<1xf64>
%rhsvf64 = builtin.unrealized_conversion_cast %r : !riscv.freg to vector<1xf64>

// vector<2xf32> addition is expected to lower to riscv_snitch.vfadd.s
// CHECK-NEXT: %addf32 = riscv_snitch.vfadd.s %l, %r : (!riscv.freg, !riscv.freg) -> !riscv.freg
%addf32 = arith.addf %l32, %r32 : vector<2xf32>

// tests with fastmath flags when set to "fast"
// CHECK-NEXT: %addf32_fm = riscv_snitch.vfadd.s %l, %r fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.freg
%addf32_fm = arith.addf %l32, %r32 fastmath<fast> : vector<2xf32>


// vector<1xf64> addition is expected to lower to the scalar riscv.fadd.d
// CHECK-NEXT: %addf64 = riscv.fadd.d %l, %r : (!riscv.freg, !riscv.freg) -> !riscv.freg
%addf64 = arith.addf %lhsvf64, %rhsvf64 : vector<1xf64>

// tests with fastmath flags when set to "fast"
// CHECK-NEXT: %addf64_fm = riscv.fadd.d %l, %r fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.freg
%addf64_fm = arith.addf %lhsvf64, %rhsvf64 fastmath<fast> : vector<1xf64>

// tests with fastmath flags when set to "contract"
// CHECK-NEXT: %addf64_fm_contract = riscv.fadd.d %l, %r fastmath<contract> : (!riscv.freg, !riscv.freg) -> !riscv.freg
%addf64_fm_contract = arith.addf %lhsvf64, %rhsvf64 fastmath<contract> : vector<1xf64>

4 changes: 2 additions & 2 deletions tests/filecheck/dialects/riscv_snitch/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ riscv_func.func @simd() {
// CHECK-GENERIC-NEXT: }) {"sym_name" = "xdma", "function_type" = () -> ()} : () -> ()
// CHECK-GENERIC-NEXT: "riscv_func.func"() ({
// CHECK-GENERIC-NEXT: %v = "riscv.get_float_register"() : () -> !riscv.freg
// CHECK-GENERIC-NEXT: %0 = "riscv_snitch.vfmul.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg
// CHECK-GENERIC-NEXT: %1 = "riscv_snitch.vfadd.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg
// CHECK-GENERIC-NEXT: %0 = "riscv_snitch.vfmul.s"(%v, %v) {"fastmath" = #riscv.fastmath<none>} : (!riscv.freg, !riscv.freg) -> !riscv.freg
// CHECK-GENERIC-NEXT: %1 = "riscv_snitch.vfadd.s"(%v, %v) {"fastmath" = #riscv.fastmath<none>} : (!riscv.freg, !riscv.freg) -> !riscv.freg
// CHECK-GENERIC-NEXT: %2 = "riscv_snitch.vfcpka.s.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg
// CHECK-GENERIC-NEXT: "riscv_func.return"() : () -> ()
// CHECK-GENERIC-NEXT: }) {"sym_name" = "simd", "function_type" = () -> ()} : () -> ()
Expand Down
50 changes: 50 additions & 0 deletions tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,55 @@
// RUN: xdsl-opt -p convert-linalg-to-memref-stream,test-optimise-memref-stream,test-lower-memref-stream-to-snitch-stream,test-lower-snitch-stream-to-asm -t riscv-asm %s | filecheck %s


// Adds the contents of %X and %Y into %Z using Snitch streams.
// All three buffers are accessed with the same stride pattern: 64 accesses,
// 8 bytes apart. Presumably each 8-byte access carries two packed f32 values
// (8 * 16 / 2 = 64) -- confirm against the stream configuration semantics.
func.func public @ssum(
  %X: memref<8x16xf32>,
  %Y: memref<8x16xf32>,
  %Z: memref<8x16xf32>
) {
  %X_1 = builtin.unrealized_conversion_cast %X : memref<8x16xf32> to !riscv.reg
  %Y_1 = builtin.unrealized_conversion_cast %Y : memref<8x16xf32> to !riscv.reg
  %Z_1 = builtin.unrealized_conversion_cast %Z : memref<8x16xf32> to !riscv.reg
  snitch_stream.streaming_region {
    patterns = [
      #snitch_stream.stride_pattern<ub = [64], strides = [8]>
    ]
  } ins(%X_1, %Y_1 : !riscv.reg, !riscv.reg) outs(%Z_1 : !riscv.reg) {
  ^0(%x : !stream.readable<!riscv.freg>, %y : !stream.readable<!riscv.freg>, %0 : !stream.writable<!riscv.freg>):
    %1 = riscv.li 8 : !riscv.reg
    %2 = riscv.li 0 : !riscv.reg
    %3 = riscv.li 1 : !riscv.reg
    %4 = riscv.li 64 : !riscv.reg
    riscv_scf.for %5 : !riscv.reg = %2 to %4 step %3 {
      %x_1 = riscv_snitch.read from %x : !riscv.freg
      %y_1 = riscv_snitch.read from %y : !riscv.freg
      // The vectorial f32 add is a riscv_snitch op (see the vfadd.s uses in
      // the riscv_snitch dialect tests), not a base riscv dialect op.
      %z = riscv_snitch.vfadd.s %x_1, %y_1 : (!riscv.freg, !riscv.freg) -> !riscv.freg
      riscv_snitch.write %z to %0 : !riscv.freg
    }
  }
  func.return
}

// CHECK: .text
// CHECK-NEXT: .globl ssum
// CHECK-NEXT: .p2align 2
// CHECK-NEXT: ssum:
// CHECK-NEXT: mv t2, a0
// CHECK-NEXT: mv t1, a1
// CHECK-NEXT: mv t0, a2
// CHECK-NEXT: li t3, 63
// CHECK-NEXT: scfgwi t3, 95
// CHECK-NEXT: li t3, 8
// CHECK-NEXT: scfgwi t3, 223
// CHECK-NEXT: scfgwi t2, 768
// CHECK-NEXT: scfgwi t1, 769
// CHECK-NEXT: scfgwi t0, 898
// CHECK-NEXT: csrrsi zero, 1984, 1
// CHECK-NEXT: li t0, 63
// CHECK-NEXT: frep.o t0, 1, 0, 0
// CHECK-NEXT: vfadd.s ft2, ft0, ft1
// CHECK-NEXT: csrrci zero, 1984, 1
// CHECK-NEXT: ret

func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
%X: memref<1x1x10x10xf64>,
%Y: memref<1x1x3x3xf64>,
Expand Down
87 changes: 87 additions & 0 deletions xdsl/backend/riscv/lowering/convert_arith_to_riscv_snitch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from dataclasses import dataclass
from math import prod
from typing import Any, cast

from xdsl.context import MLContext
from xdsl.dialects import arith, riscv, riscv_snitch
from xdsl.dialects.builtin import (
Float32Type,
Float64Type,
ModuleOp,
UnrealizedConversionCastOp,
VectorType,
)
from xdsl.ir import Operation
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
)

_FLOAT_REGISTER_TYPE = riscv.FloatRegisterType.unallocated()


@dataclass
class LowerBinaryFloatVectorOp(RewritePattern):
arith_op_cls: type[arith.FloatingPointLikeBinaryOp]
riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
riscv_snitch_v_f_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]

def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
if not isinstance(op, self.arith_op_cls):
return

operand_type = op.result.type
if not isinstance(operand_type, VectorType):
return
shape = operand_type.shape
count = prod(dim.data for dim in shape.data)

operand_type = cast(VectorType[Any], operand_type)
scalar_type = operand_type.element_type

lhs = UnrealizedConversionCastOp.get((op.lhs,), (_FLOAT_REGISTER_TYPE,))
rhs = UnrealizedConversionCastOp.get((op.rhs,), (_FLOAT_REGISTER_TYPE,))

match scalar_type:
case Float64Type():
if count != 1:
return
cls = self.riscv_d_op_cls
case Float32Type():
if count != 2:
return
cls = self.riscv_snitch_v_f_op_cls
case _:
assert False, f"Unexpected float type {op.lhs.type}"

rv_flags = riscv.FastMathFlagsAttr("none")
if op.fastmath is not None:
rv_flags = riscv.FastMathFlagsAttr(op.fastmath.data)

new_op = cls(lhs, rhs, rd=_FLOAT_REGISTER_TYPE, fastmath=rv_flags)
cast_op = UnrealizedConversionCastOp.get((new_op.rd,), (op.result.type,))

rewriter.replace_matched_op((lhs, rhs, new_op, cast_op))


# Pattern instance for `arith.addf` on vectors: the f64 case uses
# `riscv.FAddDOp`, the f32 case uses `riscv_snitch.VFAddSOp`.
lower_arith_addf = LowerBinaryFloatVectorOp(
    arith_op_cls=arith.Addf,
    riscv_d_op_cls=riscv.FAddDOp,
    riscv_snitch_v_f_op_cls=riscv_snitch.VFAddSOp,
)


class ConvertArithToRiscvSnitchPass(ModulePass):
    """
    Module pass that runs the arith-to-riscv-snitch lowering patterns over a
    module, with recursive re-application disabled.
    """

    name = "convert-arith-to-riscv-snitch"

    def apply(self, ctx: MLContext, op: ModuleOp) -> None:
        """Walk `op` and rewrite every operation matched by the patterns."""
        applier = GreedyRewritePatternApplier([lower_arith_addf])
        PatternRewriteWalker(applier, apply_recursively=False).rewrite_module(op)
8 changes: 2 additions & 6 deletions xdsl/dialects/riscv_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,9 +734,7 @@ def assembly_instruction_name(self) -> str:


@irdl_op_definition
class VFMulSOp(
RdRsRsOperation[FloatRegisterType, FloatRegisterType, FloatRegisterType]
):
class VFMulSOp(riscv.RdRsRsFloatOperationWithFastMath):
"""
Performs vectorial multiplication of corresponding f32 values from
rs1 and rs2 and stores the results in the corresponding f32 lanes
Expand All @@ -755,9 +753,7 @@ def assembly_instruction_name(self) -> str:


@irdl_op_definition
class VFAddSOp(
RdRsRsOperation[FloatRegisterType, FloatRegisterType, FloatRegisterType]
):
class VFAddSOp(riscv.RdRsRsFloatOperationWithFastMath):
"""
Performs vectorial addition of corresponding f32 values from
rs1 and rs2 and stores the results in the corresponding f32 lanes
Expand Down
6 changes: 6 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ def get_convert_arith_to_riscv():

return convert_arith_to_riscv.ConvertArithToRiscvPass

def get_convert_arith_to_riscv_snitch():
    """Return the pass class, importing its module only when called."""
    from xdsl.backend.riscv.lowering.convert_arith_to_riscv_snitch import (
        ConvertArithToRiscvSnitchPass,
    )

    return ConvertArithToRiscvSnitchPass

def get_convert_func_to_riscv_func():
from xdsl.backend.riscv.lowering import convert_func_to_riscv_func

Expand Down Expand Up @@ -364,6 +369,7 @@ def get_test_optimise_memref_stream():
"canonicalize": get_canonicalize,
"constant-fold-interp": get_constant_fold_interp,
"convert-arith-to-riscv": get_convert_arith_to_riscv,
"convert-arith-to-riscv-snitch": get_convert_arith_to_riscv_snitch,
"convert-func-to-riscv-func": get_convert_func_to_riscv_func,
"convert-linalg-to-memref-stream": get_convert_linalg_to_memref_stream,
"convert-linalg-to-loops": get_convert_linalg_to_loops,
Expand Down
4 changes: 4 additions & 0 deletions xdsl/transforms/test_lower_memref_stream_to_snitch_stream.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from dataclasses import dataclass

from xdsl.backend.riscv.lowering.convert_arith_to_riscv import ConvertArithToRiscvPass
from xdsl.backend.riscv.lowering.convert_arith_to_riscv_snitch import (
ConvertArithToRiscvSnitchPass,
)
from xdsl.backend.riscv.lowering.convert_func_to_riscv_func import (
ConvertFuncToRiscvFuncPass,
)
Expand All @@ -19,6 +22,7 @@
ConvertMemrefToRiscvPass(),
LowerAffinePass(),
ConvertScfToRiscvPass(),
ConvertArithToRiscvSnitchPass(),
ConvertArithToRiscvPass(),
ConvertFuncToRiscvFuncPass(),
ConvertMemrefStreamToSnitchStreamPass(),
Expand Down
Loading