Skip to content

Commit

Permalink
Support RoCC accels which define CSRs (#3358)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryz123 authored May 16, 2023
1 parent 1272bd5 commit 7ddf02a
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 13 deletions.
22 changes: 19 additions & 3 deletions src/main/scala/rocket/CSR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,13 @@ class VType(implicit p: Parameters) extends CoreBundle {

class CSRFile(
perfEventSets: EventSets = new EventSets(Seq()),
customCSRs: Seq[CustomCSR] = Nil)(implicit p: Parameters)
customCSRs: Seq[CustomCSR] = Nil,
roccCSRs: Seq[CustomCSR] = Nil)(implicit p: Parameters)
extends CoreModule()(p)
with HasCoreParameters {
val io = IO(new CSRFileIO {
val customCSRs = Output(Vec(CSRFile.this.customCSRs.size, new CustomCSRIO))
val roccCSRs = Output(Vec(CSRFile.this.roccCSRs.size, new CustomCSRIO))
})

val reset_mstatus = WireDefault(0.U.asTypeOf(new MStatus()))
Expand Down Expand Up @@ -745,13 +747,15 @@ class CSRFile(
}

// implementation-defined CSRs
val reg_custom = customCSRs.map { csr =>
def generateCustomCSR(csr: CustomCSR) = {
require(csr.mask >= 0 && csr.mask.bitLength <= xLen)
require(!read_mapping.contains(csr.id))
val reg = csr.init.map(init => RegInit(init.U(xLen.W))).getOrElse(Reg(UInt(xLen.W)))
read_mapping += csr.id -> reg
reg
}
val reg_custom = customCSRs.map(generateCustomCSR(_))
val reg_rocc = roccCSRs.map(generateCustomCSR(_))

if (usingHypervisor) {
read_mapping += CSRs.mtinst -> 0.U
Expand Down Expand Up @@ -1102,6 +1106,12 @@ class CSRFile(
io.value := reg
}

for ((io, reg) <- io.roccCSRs zip reg_rocc) {
io.wen := false.B
io.wdata := wdata
io.value := reg
}

io.rw.rdata := Mux1H(for ((k, v) <- read_mapping) yield decoded_addr(k) -> v)

// cover access to register
Expand Down Expand Up @@ -1421,13 +1431,19 @@ class CSRFile(
pmp.addr := wdata
}
}
for ((io, csr, reg) <- (io.customCSRs, customCSRs, reg_custom).zipped) {
def writeCustomCSR(io: CustomCSRIO, csr: CustomCSR, reg: UInt) = {
val mask = csr.mask.U(xLen.W)
when (decoded_addr(csr.id)) {
reg := (wdata & mask) | (reg & ~mask)
io.wen := true.B
}
}
for ((io, csr, reg) <- (io.customCSRs, customCSRs, reg_custom).zipped) {
writeCustomCSR(io, csr, reg)
}
for ((io, csr, reg) <- (io.roccCSRs, roccCSRs, reg_rocc).zipped) {
writeCustomCSR(io, csr, reg)
}
if (usingVector) {
when (decoded_addr(CSRs.vstart)) { set_vs_dirty := true.B; reg_vstart.get := wdata }
when (decoded_addr(CSRs.vxrm)) { set_vs_dirty := true.B; reg_vxrm.get := wdata }
Expand Down
4 changes: 3 additions & 1 deletion src/main/scala/rocket/RocketCore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class RocketCustomCSRs(implicit p: Parameters) extends CustomCSRs with HasRocket
class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p)
with HasRocketCoreParameters
with HasCoreIO {
def nTotalRoCCCSRs = tile.roccCSRs.flatten.size

val clock_en_reg = RegInit(true.B)
val long_latency_stall = Reg(Bool())
Expand Down Expand Up @@ -310,7 +311,7 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p)
val ctrl_killd = Wire(Bool())
val id_npc = (ibuf.io.pc.asSInt + ImmGen(IMM_UJ, id_inst(0))).asUInt

val csr = Module(new CSRFile(perfEvents, coreParams.customCSRs.decls))
val csr = Module(new CSRFile(perfEvents, coreParams.customCSRs.decls, tile.roccCSRs.flatten))
val id_csr_en = id_ctrl.csr.isOneOf(CSR.S, CSR.C, CSR.W)
val id_system_insn = id_ctrl.csr === CSR.I
val id_csr_ren = id_ctrl.csr.isOneOf(CSR.S, CSR.C) && id_expanded_inst(0).rs1 === 0.U
Expand Down Expand Up @@ -814,6 +815,7 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p)
csr.io.rw.wdata := wb_reg_wdata
io.trace.insns := csr.io.trace
io.trace.time := csr.io.time
io.rocc.csrs := csr.io.roccCSRs
for (((iobpw, wphit), bp) <- io.bpwatch zip wb_reg_wphit zip csr.io.bp) {
iobpw.valid(0) := wphit
iobpw.action := bp.control.action
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/tile/Core.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class TraceBundle(implicit val p: Parameters) extends Bundle with HasCoreParamet

trait HasCoreIO extends HasTileParameters {
implicit val p: Parameters
def nTotalRoCCCSRs: Int
val io = new CoreBundle()(p) {
val hartid = UInt(hartIdLen.W).asInput
val reset_vector = UInt(resetVectorLen.W).asInput
Expand All @@ -168,7 +169,7 @@ trait HasCoreIO extends HasTileParameters {
val dmem = new HellaCacheIO
val ptw = new DatapathPTWIO().flip
val fpu = new FPUCoreIO().flip
val rocc = new RoCCCoreIO().flip
val rocc = new RoCCCoreIO(nTotalRoCCCSRs).flip
val trace = Output(new TraceBundle)
val bpwatch = Vec(coreParams.nBreakpoints, new BPWatch(coreParams.retireWidth)).asOutput
val cease = Bool().asOutput
Expand Down
22 changes: 14 additions & 8 deletions src/main/scala/tile/LazyRoCC.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,41 +38,46 @@ class RoCCResponse(implicit p: Parameters) extends CoreBundle()(p) {
val data = Bits(xLen.W)
}

class RoCCCoreIO(implicit p: Parameters) extends CoreBundle()(p) {
class RoCCCoreIO(val nRoCCCSRs: Int = 0)(implicit p: Parameters) extends CoreBundle()(p) {
val cmd = Flipped(Decoupled(new RoCCCommand))
val resp = Decoupled(new RoCCResponse)
val mem = new HellaCacheIO
val busy = Output(Bool())
val interrupt = Output(Bool())
val exception = Input(Bool())
val csrs = Input(Vec(nRoCCCSRs, new CustomCSRIO))
}

class RoCCIO(val nPTWPorts: Int)(implicit p: Parameters) extends RoCCCoreIO()(p) {
class RoCCIO(val nPTWPorts: Int, nRoCCCSRs: Int)(implicit p: Parameters) extends RoCCCoreIO(nRoCCCSRs)(p) {
val ptw = Vec(nPTWPorts, new TLBPTWIO)
val fpu_req = Decoupled(new FPInput)
val fpu_resp = Flipped(Decoupled(new FPResult))
}

/** Base classes for Diplomatic TL2 RoCC units **/
abstract class LazyRoCC(
val opcodes: OpcodeSet,
val nPTWPorts: Int = 0,
val usesFPU: Boolean = false
)(implicit p: Parameters) extends LazyModule {
val opcodes: OpcodeSet,
val nPTWPorts: Int = 0,
val usesFPU: Boolean = false,
val roccCSRs: Seq[CustomCSR] = Nil
)(implicit p: Parameters) extends LazyModule {
val module: LazyRoCCModuleImp
require(roccCSRs.map(_.id).toSet.size == roccCSRs.size)
val atlNode: TLNode = TLIdentityNode()
val tlNode: TLNode = TLIdentityNode()
}

class LazyRoCCModuleImp(outer: LazyRoCC) extends LazyModuleImp(outer) {
val io = IO(new RoCCIO(outer.nPTWPorts))
val io = IO(new RoCCIO(outer.nPTWPorts, outer.roccCSRs.size))
}

/** Mixins for including RoCC **/

trait HasLazyRoCC extends CanHavePTW { this: BaseTile =>
val roccs = p(BuildRoCC).map(_(p))

val roccCSRs = roccs.map(_.roccCSRs) // the set of custom CSRs requested by all roccs
require(roccCSRs.flatten.map(_.id).toSet.size == roccCSRs.flatten.size,
"LazyRoCC instantiations require overlapping CSRs")
roccs.map(_.atlNode).foreach { atl => tlMasterXbar.node :=* atl }
roccs.map(_.tlNode).foreach { tl => tlOtherMastersNode :=* tl }

Expand Down Expand Up @@ -115,6 +120,7 @@ trait HasLazyRoCCModule extends CanHavePTWModule
} else {
(None, None)
}
val roccCSRIOs = outer.roccs.map(_.module.io.csrs)
}

class AccumulatorExample(opcodes: OpcodeSet, val n: Int = 4)(implicit p: Parameters) extends LazyRoCC(opcodes) {
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/tile/RocketTile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ class RocketTileModuleImp(outer: RocketTile) extends BaseTileModuleImp(outer)
core.io.rocc.resp <> respArb.get.io.out
core.io.rocc.busy <> (cmdRouter.get.io.busy || outer.roccs.map(_.module.io.busy).reduce(_ || _))
core.io.rocc.interrupt := outer.roccs.map(_.module.io.interrupt).reduce(_ || _)
(core.io.rocc.csrs zip roccCSRIOs.flatten).foreach { t => t._2 := t._1 }
}

// Rocket has higher priority to DTIM than other TileLink clients
Expand Down

0 comments on commit 7ddf02a

Please sign in to comment.