Skip to content

Commit

Permalink
Add Zbb support
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryz123 committed Aug 12, 2024
1 parent 90da5bc commit 83e8c6b
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 14 deletions.
71 changes: 64 additions & 7 deletions src/main/scala/rocket/ALU.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
package freechips.rocketchip.rocket

import chisel3._
import chisel3.util.{BitPat, Fill, Cat, Reverse}
import chisel3.util.{BitPat, Fill, Cat, Reverse, PriorityEncoderOH, PopCount, MuxLookup}
import org.chipsalliance.cde.config.Parameters
import freechips.rocketchip.tile.CoreModule
import freechips.rocketchip.util._

object ALU {
val SZ_ALU_FN = 4
def FN_X = BitPat("b????")
val SZ_ALU_FN = 5
def FN_X = BitPat("b?????")
def FN_ADD = 0.U
def FN_SL = 1.U
def FN_SEQ = 2.U
Expand All @@ -27,6 +28,15 @@ object ALU {
def FN_SGE = 13.U
def FN_SLTU = 14.U
def FN_SGEU = 15.U
def FN_UNARY = 16.U
def FN_ROL = 17.U
def FN_ROR = 18.U

def FN_MAX = 28.U
def FN_MIN = 29.U
def FN_MAXU = 30.U
def FN_MINU = 31.U
def FN_MAXMIN = BitPat("b111??")

// Mul/div reuse some integer FNs
def FN_DIV = FN_XOR
Expand All @@ -41,10 +51,12 @@ object ALU {

def isMulFN(fn: UInt, cmp: UInt) = fn(1,0) === cmp(1,0)
def isSub(cmd: UInt) = cmd(3)
def isCmp(cmd: UInt) = cmd >= FN_SLT
def isCmp(cmd: UInt) = (cmd >= FN_SLT && cmd <= FN_SGEU)
def isMaxMin(cmd: UInt) = (cmd >= FN_MAX && cmd <= FN_MINU)
def cmpUnsigned(cmd: UInt) = cmd(1)
def cmpInverted(cmd: UInt) = cmd(0)
def cmpEq(cmd: UInt) = !cmd(3)
def shiftReverse(cmd: UInt) = !cmd.isOneOf(FN_SR, FN_SRA, FN_ROR)
}

import ALU._
Expand Down Expand Up @@ -84,11 +96,11 @@ class ALU(implicit p: Parameters) extends AbstractALU()(p) {
val shamt = Cat(io.in2(5) & (io.dw === DW_64), io.in2(4,0))
(shamt, Cat(shin_hi, io.in1(31,0)))
}
val shin = Mux(io.fn === FN_SR || io.fn === FN_SRA, shin_r, Reverse(shin_r))
val shin = Mux(shiftReverse(io.fn), Reverse(shin_r), shin_r)
val shout_r = (Cat(isSub(io.fn) & shin(xLen-1), shin).asSInt >> shamt)(xLen-1,0)
val shout_l = Reverse(shout_r)
val shout = Mux(io.fn === FN_SR || io.fn === FN_SRA, shout_r, 0.U) |
Mux(io.fn === FN_SL, shout_l, 0.U)
Mux(io.fn === FN_SL, shout_l, 0.U)

// CZEQZ, CZNEZ
val in2_not_zero = io.in2.orR
Expand All @@ -105,7 +117,52 @@ class ALU(implicit p: Parameters) extends AbstractALU()(p) {
case Some(co) => shift_logic | co
case _ => shift_logic
}
val out = Mux(io.fn === FN_ADD || io.fn === FN_SUB, io.adder_out, shift_logic_cond)

// CLZ, CTZ, CPOP
val tz_in = MuxLookup((io.dw === DW_32) ## !io.in2(0), 0.U)(Seq(
0.U -> io.in1,
1.U -> Reverse(io.in1),
2.U -> 1.U ## io.in1(31,0),
3.U -> 1.U ## Reverse(io.in1(31,0))
))
val popc_in = Mux(io.in2(1),
Mux(io.dw === DW_32, io.in1(31,0), io.in1),
PriorityEncoderOH(1.U ## tz_in) - 1.U)
val count = PopCount(popc_in)
val in1_bytes = io.in1.asTypeOf(Vec(xLen / 8, UInt(8.W)))
val orcb = VecInit(in1_bytes.map(b => Fill(8, b =/= 0.U))).asUInt
val rev8 = VecInit(in1_bytes.reverse).asUInt
val unary = MuxLookup(io.in2(11,0), count)(Seq(
0x287.U -> orcb,
(if (xLen == 32) 0x698 else 0x6b8).U -> rev8,
0x080.U -> io.in1(15,0),
0x604.U -> Fill(xLen-8, io.in1(7)) ## io.in1(7,0),
0x605.U -> Fill(xLen-16, io.in1(15)) ## io.in1(15,0)
))

// MAX, MIN, MAXU, MINU
val maxmin_out = Mux(io.cmp_out, io.in2, io.in1)

// ROL, ROR
val rot_shamt = Mux(io.dw === DW_32, 32.U, 64.U) - shamt
val rotin = Mux(io.fn(0), shin_r, Reverse(shin_r))
val rotout_r = (rotin >> rot_shamt)(xLen-1,0)
val rotout_l = Reverse(rotout_r)
val rotout = Mux(io.fn(0), rotout_r, rotout_l) | Mux(io.fn(0), shout_l, shout_r)

val out = MuxLookup(io.fn, shift_logic_cond)(Seq(
FN_ADD -> io.adder_out,
FN_SUB -> io.adder_out
) ++ (if (coreParams.useZbb) Seq(
FN_UNARY -> unary,
FN_MAX -> maxmin_out,
FN_MIN -> maxmin_out,
FN_MAXU -> maxmin_out,
FN_MINU -> maxmin_out,
FN_ROL -> rotout,
FN_ROR -> rotout,
) else Nil))


io.out := out
if (xLen > 32) {
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/rocket/Configs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class WithCease(enable: Boolean = true) extends RocketCoreConf
class WithCoreClockGatingEnabled extends RocketCoreConfig(_.copy(clockGate = true))
class WithPgLevels(n: Int) extends RocketCoreConfig(_.copy(pgLevels = n))
class WithZba extends RocketCoreConfig(_.copy(useZba = true))
class WithZbb extends RocketCoreConfig(_.copy(useZbb = true))
class WithSV48 extends WithPgLevels(4)
class WithSV39 extends WithPgLevels(3)

Expand Down
11 changes: 6 additions & 5 deletions src/main/scala/rocket/Consts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ trait ScalarOpConstants {
def IMM_I = 4.U(3.W)
def IMM_Z = 5.U(3.W)

def A2_X = BitPat("b??")
def A2_ZERO = 0.U(2.W)
def A2_SIZE = 1.U(2.W)
def A2_RS2 = 2.U(2.W)
def A2_IMM = 3.U(2.W)
def A2_X = BitPat("b???")
def A2_ZERO = 0.U(3.W)
def A2_SIZE = 1.U(3.W)
def A2_RS2 = 2.U(3.W)
def A2_IMM = 3.U(3.W)
def A2_RS2INV = 4.U(3.W)

def X = BitPat("b?")
def N = BitPat("b0")
Expand Down
46 changes: 46 additions & 0 deletions src/main/scala/rocket/IDecode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,52 @@ class Zba64Decode(implicit val p: Parameters) extends DecodeConstants
)
}

class ZbbDecode(implicit val p: Parameters) extends DecodeConstants
{
val table: Array[(BitPat, List[BitPat])] = Array(
ANDN -> List(Y,N,N,N,N,N,Y,Y,A2_RS2INV,A1_RS1,IMM_X,DW_XPR,FN_AND, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
ORN -> List(Y,N,N,N,N,N,Y,Y,A2_RS2INV,A1_RS1,IMM_X,DW_XPR,FN_OR , N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
XNOR -> List(Y,N,N,N,N,N,Y,Y,A2_RS2INV,A1_RS1,IMM_X,DW_XPR,FN_XOR, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
CLZ -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_XPR,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
CPOP -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_XPR,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
CTZ -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_XPR,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
MAX -> List(Y,N,N,N,N,N,Y,Y,A2_RS2 ,A1_RS1,IMM_X,DW_XPR,FN_MAX, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
MAXU -> List(Y,N,N,N,N,N,Y,Y,A2_RS2 ,A1_RS1,IMM_X,DW_XPR,FN_MAXU, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
MIN -> List(Y,N,N,N,N,N,Y,Y,A2_RS2 ,A1_RS1,IMM_X,DW_XPR,FN_MIN, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
MINU -> List(Y,N,N,N,N,N,Y,Y,A2_RS2 ,A1_RS1,IMM_X,DW_XPR,FN_MINU, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
ORC_B -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_XPR,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
ROL -> List(Y,N,N,N,N,N,Y,Y,A2_RS2 ,A1_RS1,IMM_X,DW_XPR,FN_ROL, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
ROR -> List(Y,N,N,N,N,N,Y,Y,A2_RS2 ,A1_RS1,IMM_X,DW_XPR,FN_ROR, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
RORI -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_XPR,FN_ROR, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
ZEXT_H -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_XPR,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
SEXT_B -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_XPR,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
SEXT_H -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_XPR,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
)
}

class Zbb64Decode(implicit val p: Parameters) extends DecodeConstants
{
val table: Array[(BitPat, List[BitPat])] = Array(
REV8 -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_XPR,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
CTZW -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_32 ,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
CLZW -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_32 ,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
CPOPW -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_32 ,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
ROLW -> List(Y,N,N,N,N,N,Y,Y,A2_RS2 ,A1_RS1,IMM_X,DW_32 ,FN_ROL, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
RORW -> List(Y,N,N,N,N,N,Y,Y,A2_RS2 ,A1_RS1,IMM_X,DW_32 ,FN_ROR, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
RORIW -> List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_32 ,FN_ROR, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
)
}


class Zbb32Decode(implicit val p: Parameters) extends DecodeConstants
{
val table: Array[(BitPat, List[BitPat])] = Array(
Instructions32.REV8 ->
List(Y,N,N,N,N,N,N,Y,A2_IMM ,A1_RS1,IMM_I,DW_XPR,FN_UNARY, N,M_X, N,N,N,N,N,N,Y,CSR.N,N,N,N,N),
)
}


class RoCCDecode(implicit val p: Parameters) extends DecodeConstants
{
val table: Array[(BitPat, List[BitPat])] = Array(
Expand Down
7 changes: 5 additions & 2 deletions src/main/scala/rocket/RocketCore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ case class RocketCoreParams(
useRVE: Boolean = false,
useConditionalZero: Boolean = false,
useZba: Boolean = false,
useZbb: Boolean = false,
nLocalInterrupts: Int = 0,
useNMI: Boolean = false,
nBreakpoints: Int = 1,
Expand Down Expand Up @@ -66,7 +67,6 @@ case class RocketCoreParams(
val instBits: Int = if (useCompressed) 16 else 32
val lrscCycles: Int = 80 // worst case is 14 mispredicted branches + slop
val traceHasWdata: Boolean = debugROB.isDefined // ooo wb, so no wdata in trace
val useZbb: Boolean = false
val useZbs: Boolean = false
override val useVector = vector.isDefined
override val vectorUseDCache = vector.map(_.useDCache).getOrElse(false)
Expand Down Expand Up @@ -235,6 +235,7 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p)
rocketParams.haveCease.option(new CeaseDecode) ++:
usingVector.option(new VCFGDecode) ++:
(if (coreParams.useZba) new ZbaDecode +: (xLen > 32).option(new Zba64Decode).toSeq else Nil) ++:
(if (coreParams.useZbb) Seq(new ZbbDecode, if (xLen == 32) new Zbb32Decode else new Zbb64Decode) else Nil) ++:
Seq(new IDecode)
} flatMap(_.table)

Expand Down Expand Up @@ -475,7 +476,9 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p)
val ex_op2 = MuxLookup(ex_ctrl.sel_alu2, 0.S)(Seq(
A2_RS2 -> ex_rs(1).asSInt,
A2_IMM -> ex_imm,
A2_SIZE -> Mux(ex_reg_rvc, 2.S, 4.S)))
A2_SIZE -> Mux(ex_reg_rvc, 2.S, 4.S),
A2_RS2INV -> (if (rocketParams.useZbb) (~ex_rs(1)).asSInt else 0.S)
))

val (ex_new_vl, ex_new_vconfig) = if (usingVector) {
val ex_new_vtype = VType.fromUInt(MuxCase(ex_rs(1), Seq(
Expand Down

0 comments on commit 83e8c6b

Please sign in to comment.