Skip to content

Commit

Permalink
convert-snrt-to-riscv: add lowerings for almost all info ops
Browse files Browse the repository at this point in the history
Missing ops:
- snrt.global_compute_core_idx
- snrt.global_compute_core_num
- snrt.global_dm_core_num
- snrt.cluster_compute_core_idx
- snrt.cluster_dm_core_idx

These seem to be relatively unused in practice and are skipped for now
  • Loading branch information
AntonLydike committed May 10, 2024
1 parent 81fc104 commit 930d009
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 7 deletions.
56 changes: 51 additions & 5 deletions tests/filecheck/transforms/convert-snrt-to-riscv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
// unsupported: %global_compute_core_idx = "snrt.global_compute_core_idx"() : () -> i32
// unsupported: %global_compute_core_num = "snrt.global_compute_core_num"() : () -> i32
// unsupported: %global_dm_core_num = "snrt.global_dm_core_num"() : () -> i32
// unsupported: %gcluster_core_idx = "snrt.cluster_core_idx"() : () -> i32
%gcluster_core_idx = "snrt.cluster_core_idx"() : () -> i32
%cluster_core_num = "snrt.cluster_core_num"() : () -> i32
// unsupported: %cluster_compute_core_idx = "snrt.cluster_compute_core_idx"() : () -> i32
%cluster_compute_core_num = "snrt.cluster_compute_core_num"() : () -> i32
// unsupported: %cluster_dm_core_idx = "snrt.cluster_dm_core_idx"() : () -> i32
%cluster_dm_core_num = "snrt.cluster_dm_core_num"() : () -> i32
// unsupported: %cluster_idx = "snrt.cluster_idx"() : () -> i32
%cluster_idx = "snrt.cluster_idx"() : () -> i32
%cluster_num = "snrt.cluster_num"() : () -> i32
// unsupported: %is_compute_core = "snrt.is_compute_core"() : () -> i1
// unsupported: %is_dm_core = "snrt.is_dm_core"() : () -> i1
%is_compute_core = "snrt.is_compute_core"() : () -> i1
%is_dm_core = "snrt.is_dm_core"() : () -> i1

"snrt.cluster_hw_barrier"() : () -> ()
"snrt.ssr_disable"() : () -> ()
Expand All @@ -37,17 +37,63 @@
// CHECK-NEXT: builtin.module {
// CHECK-NEXT: %global_core_base_hartid = riscv.li 0 : () -> !riscv.reg<>
// CHECK-NEXT: %global_core_base_hartid_1 = builtin.unrealized_conversion_cast %global_core_base_hartid : !riscv.reg<> to i32
// CHECK-NEXT: %global_core_idx = "snrt.global_core_idx"() : () -> i32
// CHECK-NEXT: %global_core_idx = riscv.get_register : () -> !riscv.reg<zero>
// CHECK-NEXT: %global_core_idx_1 = riscv.csrrs %global_core_idx, 3860, "r" : (!riscv.reg<zero>) -> !riscv.reg<>
// CHECK-NEXT: %global_core_idx_2 = riscv.li 0 : () -> !riscv.reg<>
// CHECK-NEXT: %global_core_idx_3 = riscv.sub %global_core_idx_1, %global_core_idx_2 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %global_core_idx_4 = builtin.unrealized_conversion_cast %global_core_idx_3 : !riscv.reg<> to i32
// CHECK-NEXT: %global_core_num = riscv.li 18 : () -> !riscv.reg<>
// CHECK-NEXT: %global_core_num_1 = builtin.unrealized_conversion_cast %global_core_num : !riscv.reg<> to i32
// CHECK-NEXT: %gcluster_core_idx = riscv.get_register : () -> !riscv.reg<zero>
// CHECK-NEXT: %gcluster_core_idx_1 = riscv.csrrs %gcluster_core_idx, 3860, "r" : (!riscv.reg<zero>) -> !riscv.reg<>
// CHECK-NEXT: %gcluster_core_idx_2 = riscv.li 0 : () -> !riscv.reg<>
// CHECK-NEXT: %gcluster_core_idx_3 = riscv.sub %gcluster_core_idx_1, %gcluster_core_idx_2 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %gcluster_core_idx_4 = builtin.unrealized_conversion_cast %gcluster_core_idx_3 : !riscv.reg<> to i32
// CHECK-NEXT: %gcluster_core_idx_5 = riscv.li 9 : () -> !riscv.reg<>
// CHECK-NEXT: %gcluster_core_idx_6 = builtin.unrealized_conversion_cast %gcluster_core_idx_5 : !riscv.reg<> to i32
// CHECK-NEXT: %gcluster_core_idx_7 = arith.remsi %gcluster_core_idx_4, %gcluster_core_idx_6 : i32
// CHECK-NEXT: %cluster_core_num = riscv.li 9 : () -> !riscv.reg<>
// CHECK-NEXT: %cluster_core_num_1 = builtin.unrealized_conversion_cast %cluster_core_num : !riscv.reg<> to i32
// CHECK-NEXT: %cluster_compute_core_num = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %cluster_compute_core_num_1 = builtin.unrealized_conversion_cast %cluster_compute_core_num : !riscv.reg<> to i32
// CHECK-NEXT: %cluster_dm_core_num = riscv.li 1 : () -> !riscv.reg<>
// CHECK-NEXT: %cluster_dm_core_num_1 = builtin.unrealized_conversion_cast %cluster_dm_core_num : !riscv.reg<> to i32
// CHECK-NEXT: %cluster_idx = riscv.li 9 : () -> !riscv.reg<>
// CHECK-NEXT: %cluster_idx_1 = builtin.unrealized_conversion_cast %cluster_idx : !riscv.reg<> to i32
// CHECK-NEXT: %cluster_idx_2 = riscv.get_register : () -> !riscv.reg<zero>
// CHECK-NEXT: %cluster_idx_3 = riscv.csrrs %cluster_idx_2, 3860, "r" : (!riscv.reg<zero>) -> !riscv.reg<>
// CHECK-NEXT: %cluster_idx_4 = riscv.li 0 : () -> !riscv.reg<>
// CHECK-NEXT: %cluster_idx_5 = riscv.sub %cluster_idx_3, %cluster_idx_4 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %cluster_idx_6 = builtin.unrealized_conversion_cast %cluster_idx_5 : !riscv.reg<> to i32
// CHECK-NEXT: %cluster_idx_7 = builtin.unrealized_conversion_cast %cluster_idx_1 : i32 to !riscv.reg<>
// CHECK-NEXT: %cluster_idx_8 = builtin.unrealized_conversion_cast %cluster_idx_6 : i32 to !riscv.reg<>
// CHECK-NEXT: %cluster_idx_9 = riscv.div %cluster_idx_8, %cluster_idx_7 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %cluster_idx_10 = builtin.unrealized_conversion_cast %cluster_idx_9 : !riscv.reg<> to i32
// CHECK-NEXT: %cluster_num = riscv.li 2 : () -> !riscv.reg<>
// CHECK-NEXT: %cluster_num_1 = builtin.unrealized_conversion_cast %cluster_num : !riscv.reg<> to i32
// CHECK-NEXT: %is_compute_core = riscv.get_register : () -> !riscv.reg<zero>
// CHECK-NEXT: %is_compute_core_1 = riscv.csrrs %is_compute_core, 3860, "r" : (!riscv.reg<zero>) -> !riscv.reg<>
// CHECK-NEXT: %is_compute_core_2 = riscv.li 0 : () -> !riscv.reg<>
// CHECK-NEXT: %is_compute_core_3 = riscv.sub %is_compute_core_1, %is_compute_core_2 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %is_compute_core_4 = builtin.unrealized_conversion_cast %is_compute_core_3 : !riscv.reg<> to i32
// CHECK-NEXT: %is_compute_core_5 = riscv.li 9 : () -> !riscv.reg<>
// CHECK-NEXT: %is_compute_core_6 = builtin.unrealized_conversion_cast %is_compute_core_5 : !riscv.reg<> to i32
// CHECK-NEXT: %is_compute_core_7 = arith.remsi %is_compute_core_4, %is_compute_core_6 : i32
// CHECK-NEXT: %is_compute_core_8 = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %is_compute_core_9 = builtin.unrealized_conversion_cast %is_compute_core_8 : !riscv.reg<> to i32
// CHECK-NEXT: %is_compute_core_10 = arith.cmpi slt, %is_compute_core_7, %is_compute_core_9 : i32
// CHECK-NEXT: %is_dm_core = riscv.get_register : () -> !riscv.reg<zero>
// CHECK-NEXT: %is_dm_core_1 = riscv.csrrs %is_dm_core, 3860, "r" : (!riscv.reg<zero>) -> !riscv.reg<>
// CHECK-NEXT: %is_dm_core_2 = riscv.li 0 : () -> !riscv.reg<>
// CHECK-NEXT: %is_dm_core_3 = riscv.sub %is_dm_core_1, %is_dm_core_2 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %is_dm_core_4 = builtin.unrealized_conversion_cast %is_dm_core_3 : !riscv.reg<> to i32
// CHECK-NEXT: %is_dm_core_5 = riscv.li 9 : () -> !riscv.reg<>
// CHECK-NEXT: %is_dm_core_6 = builtin.unrealized_conversion_cast %is_dm_core_5 : !riscv.reg<> to i32
// CHECK-NEXT: %is_dm_core_7 = arith.remsi %is_dm_core_4, %is_dm_core_6 : i32
// CHECK-NEXT: %is_dm_core_8 = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %is_dm_core_9 = builtin.unrealized_conversion_cast %is_dm_core_8 : !riscv.reg<> to i32
// CHECK-NEXT: %is_dm_core_10 = arith.cmpi sge, %is_dm_core_7, %is_dm_core_9 : i32


// Lowering of cluster_hw_barrier
// CHECK-NEXT: %0 = riscv.get_register : () -> !riscv.reg<zero>
Expand Down
138 changes: 136 additions & 2 deletions xdsl/transforms/convert_snrt_to_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Sequence
from dataclasses import dataclass

from xdsl.dialects import builtin, riscv, riscv_snitch, snitch_runtime
from xdsl.dialects import arith, builtin, riscv, riscv_snitch, snitch_runtime
from xdsl.dialects.builtin import IntegerAttr
from xdsl.ir import MLContext, Operation, SSAValue
from xdsl.passes import ModulePass
Expand All @@ -16,7 +16,7 @@


@dataclass(frozen=True)
class SnrtConstants:
class SnrtConstants(ABC):
"""
Constants used when compiling the snitch runtime, depend on the exact snitch
architecture target.
Expand Down Expand Up @@ -493,6 +493,67 @@ def match_and_rewrite(
)


class LowerIsComputeCore(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: snitch_runtime.IsComputeCoreOp, rewriter: PatternRewriter, /
):
"""
inline int __attribute__((const)) snrt_is_compute_core() {
return snrt_cluster_core_idx() < snrt_cluster_compute_core_num();
}
"""
rewriter.replace_matched_op(
[
cluster_core_idx := snitch_runtime.ClusterCoreIdxOp(),
compute_core_num := snitch_runtime.ClusterComputeCoreNumOp(),
arith.Cmpi(cluster_core_idx, compute_core_num, "slt"),
]
)


class LowerIsDmCore(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: snitch_runtime.IsDmCoreOp, rewriter: PatternRewriter, /
):
"""
inline int __attribute__((const)) snrt_is_compute_core() {
return snrt_cluster_core_idx() < snrt_cluster_compute_core_num();
}
inline int __attribute__((const)) snrt_is_dm_core() {
return !snrt_is_compute_core();
}
"""
rewriter.replace_matched_op(
[
cluster_core_idx := snitch_runtime.ClusterCoreIdxOp(),
compute_core_num := snitch_runtime.ClusterComputeCoreNumOp(),
arith.Cmpi(cluster_core_idx, compute_core_num, "sge"),
]
)


class LowerClusterCoreIdx(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: snitch_runtime.ClusterCoreIdxOp, rewriter: PatternRewriter, /
):
"""
inline uint32_t __attribute__((const)) snrt_cluster_core_idx() {
return snrt_global_core_idx() % snrt_cluster_core_num();
}
"""
rewriter.replace_matched_op(
[
global_core_idx := snitch_runtime.GlobalCoreIdxOp(),
cluster_core_num := snitch_runtime.ClusterCoreNumOp(),
arith.RemSI(global_core_idx, cluster_core_num),
]
)


@dataclass
class LowerClusterComputeCoreNum(RewritePattern):
constants: SnrtConstants
Expand Down Expand Up @@ -528,6 +589,72 @@ def match_and_rewrite(
)


@dataclass
class LowerGlobalCoreIdx(RewritePattern):
constants: SnrtConstants

@op_type_rewrite_pattern
def match_and_rewrite(
self, op: snitch_runtime.GlobalCoreIdxOp, rewriter: PatternRewriter, /
):
"""
Implementation:
inline uint32_t __attribute__((const)) snrt_hartid() {
uint32_t hartid;
asm("csrr %0, mhartid" : "=r"(hartid));
return hartid;
}
inline uint32_t __attribute__((const)) snrt_global_core_idx() {
return snrt_hartid() - snrt_global_core_base_hartid();
}
"""
rewriter.replace_matched_op(
[
zero := riscv.GetRegisterOp(riscv.Registers.ZERO),
hartid := riscv.CsrrsOp(zero, IntegerAttr(0xF14, 12), readonly=True),
base_hartid := riscv.LiOp(self.constants.base_hartid),
core_idx := riscv.SubOp(
hartid, base_hartid, rd=riscv.IntRegisterType.unallocated()
),
builtin.UnrealizedConversionCastOp.get([core_idx], [builtin.i32]),
]
)


class LowerClusterIdx(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: snitch_runtime.ClusterIdxOp, rewriter: PatternRewriter, /
):
"""
Implementation:
inline uint32_t __attribute__((const)) snrt_cluster_idx() {
return snrt_global_core_idx() / snrt_cluster_core_num();
}
"""
rewriter.replace_matched_op(
[
cluster_core_num := snitch_runtime.ClusterCoreNumOp(),
core_idx := snitch_runtime.GlobalCoreIdxOp(),
cluster_core_num_reg := builtin.UnrealizedConversionCastOp.get(
[cluster_core_num], [riscv.IntRegisterType.unallocated()]
),
core_idx_reg := builtin.UnrealizedConversionCastOp.get(
[core_idx], [riscv.IntRegisterType.unallocated()]
),
res := riscv.DivOp(
core_idx_reg,
cluster_core_num_reg,
rd=riscv.IntRegisterType.unallocated(),
),
builtin.UnrealizedConversionCastOp.get([res], [builtin.i32]),
]
)


@dataclass(frozen=True)
class ConvertSnrtToRISCV(SnrtConstants, ModulePass):
"""
Expand All @@ -536,6 +663,8 @@ class ConvertSnrtToRISCV(SnrtConstants, ModulePass):

name = "convert-snrt-to-riscv"

cluster_num: int

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier(
Expand All @@ -547,12 +676,17 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
LowerDMAStart2D(),
LowerDMAStart2DWideptr(),
# information getting ops:
LowerClusterIdx(),
LowerClusterNum(self),
LowerClusterCoreIdx(),
LowerClusterCoreNum(self),
LowerClusterDmCoreNum(self),
LowerClusterComputeCoreNum(self),
LowerGlobalCoreNum(self),
LowerGlobalCoreIdx(self),
LowerGlobalCoreBaseHartid(self),
LowerIsComputeCore(),
LowerIsDmCore(),
]
)
).rewrite_module(op)

0 comments on commit 930d009

Please sign in to comment.