Skip to content

Commit

Permalink
[VTA][Chisel] TSIM VTA Source Refactor (apache#4163)
Browse files Browse the repository at this point in the history
* app init push

* fix on readme

* change name, add bit serial explanantion

* rm serialLoadMM, change doc

* syntax change for readme

* add parallel test functionality

* fix readme

* add python doc

* syntax

* init commit

* fix empty line

* fix typo
  • Loading branch information
BenjaminTu authored and jroesch committed Oct 27, 2019
1 parent 5b1350f commit 0b53071
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 130 deletions.
153 changes: 104 additions & 49 deletions apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,31 @@ package accel
import chisel3._
import chisel3.util._
import vta.dpi._
import vta.core._
import vta.util.config._
import vta.shell._

class TestConfig extends Config(new CoreConfig ++ new PynqConfig)
/** Compute
*
* Bit Slice GEMM:
*
* 1. Wait for launch to be asserted
* 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address
* 2. Issue 1 read request for 8-bit value at inp1_baddr address (read matrix)
* 3. Wait for the value
* 4. Increment read-address for next value
* 5. Wait for sliced accumulator
* 6. Check if counter (cnt) is equal to length process,
otherwise goto step 2
* 7. Check if reset slice accumulator
* 8. Wait for overall accumulator
* 8. Issue a write request for 8-byte value at out_baddr address
* 5. Repeat until all inp1 data have been read
* 6. Issue 1 read request for 8-bit value at inp2_baddr address (read vector)
* 7. Wait for the value
* 8. Increment read-address for next value
* 9. Repeat until all inp2 data have been read
* 10. Wait for output to be calculated
* 11. Issue a write request for 8-byte value at out_baddr address
* 12. Increment write-address for next value to write
* 13. Check if counter (cntout) is equal to length to asser finish,
otherwise go to step 11
*/
class Compute(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle {
Expand All @@ -47,19 +57,24 @@ class Compute(implicit config: AccelConfig) extends Module {
val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val mem = new VTAMemDPIMaster
})
val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum(7)
implicit val p: Parameters = new TestConfig
val sIdle :: sReadAReq :: sReadAData :: sReadADone ::sReadBReq :: sReadBData :: sReadBDone :: sInpDone ::sWait:: sWriteReq :: sWriteData :: sWriteDone :: Nil = Enum(12)
val state = RegInit(sIdle)
val shift = io.vals(0)
val length = io.vals(1)
val rstAccum = io.vals(2)
val startDot = io.vals(3)
val cycles = RegInit(0.U(config.regBits.W))
val reg1 = Reg(chiselTypeOf(io.mem.rd.bits))
val reg2 = Reg(chiselTypeOf(io.mem.rd.bits))
val cnt = Reg(UInt(config.regBits.W))
val mvc = Module(new MatrixVectorMultiplication)
val reg1 = Reg(chiselTypeOf(mvc.io.wgt.data.bits))
val reg2 = Reg(chiselTypeOf(mvc.io.inp.data.bits))
val cntwgt = Reg(UInt(config.regBits.W))
val cntinp = Reg(UInt(config.regBits.W))
val cntout = Reg(UInt(config.regBits.W))
val raddr1 = Reg(UInt(config.ptrBits.W))
val raddr2 = Reg(UInt(config.ptrBits.W))
val waddr = Reg(UInt(config.ptrBits.W))
val accum = Module(new Accmulator(size = p(CoreKey).blockOut, accBits = p(CoreKey).accBits))

switch (state) {
is (sIdle) {
Expand All @@ -73,14 +88,38 @@ class Compute(implicit config: AccelConfig) extends Module {
}
is (sReadAData) {
when (io.mem.rd.valid) {
state := sReadADone
}
}
is (sReadADone) {
when (cntwgt === (length * length) - 1.U) {
state := sReadBReq
} .otherwise {
state := sReadAReq
}
}
is (sReadBReq) {
state := sReadBData
}
is (sReadBData) {
when (io.mem.rd.valid) {
state := sReadBDone
}
}
is (sReadBDone) {
when (cntinp === length-1.U) {
state := sInpDone
} .otherwise {
state := sReadBReq
}
}
// Both input is processed
is (sInpDone) {
state := sWait
}
// Wait for computation
is (sWait) {
when (accum.io.ready) {
state := sWriteReq
}
}
Expand All @@ -89,15 +128,18 @@ class Compute(implicit config: AccelConfig) extends Module {
state := sWriteData
}
is (sWriteData) {
when (cnt === (length - 1.U)) {
state := sWriteDone
}
is (sWriteDone) {
when (cntout === (length - 1.U)) {
state := sIdle
} .otherwise {
state := sReadAReq
state := sWriteReq
}
}
}

val last = state === sWriteData && cnt === (length - 1.U)
val last = state === sWriteDone && cntout === (length - 1.U)

// cycle counter
when (state === sIdle) {
Expand All @@ -114,10 +156,12 @@ class Compute(implicit config: AccelConfig) extends Module {
raddr1 := io.ptrs(0)
raddr2 := io.ptrs(1)
waddr := io.ptrs(2)
} .elsewhen (state === sWriteData) { // increment input array by 1-byte
} .elsewhen (state === sReadADone) { // increment input array by 1-byte
raddr1 := raddr1 + 1.U
} .elsewhen (state === sReadBDone) { // increment input array by 1-byte
raddr2 := raddr2 + 1.U
waddr := waddr
} .elsewhen (state === sWriteDone) {
waddr := waddr + 4.U // writing 4 bytes
}

// create request
Expand All @@ -128,59 +172,70 @@ class Compute(implicit config: AccelConfig) extends Module {

// read
when (state === sReadAData && io.mem.rd.valid) {
reg1 := io.mem.rd.bits(7, 0)
reg1(cntwgt/length)(cntwgt%length) := io.mem.rd.bits(7, 0)
}

when (state === sReadBData && io.mem.rd.valid) {
reg2 := io.mem.rd.bits(7, 0)
reg2(0)(cntinp) := io.mem.rd.bits(7, 0)
}

io.mem.rd.ready := state === sReadAData | state === sReadBData
mvc.io.inp.data.valid := state === sInpDone // 2 inputs have been processed
mvc.io.wgt.data.valid := state === sInpDone // 2 inputs have been processed

mvc.io.wgt.data.bits <> reg1
mvc.io.inp.data.bits <> reg2
// Modify when shift operation is supported
mvc.io.reset := false.B
mvc.io.acc_i.data.valid := true.B
for (i <- 0 until p(CoreKey).blockOut) {
mvc.io.acc_i.data.bits(0)(i) := 0.U
}


val sliceAccum = Module(new Accumulator(63))
val overallAccum = Module(new Accumulator(64))

sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed
sliceAccum.io.in := reg1 * reg2
sliceAccum.io.clear := startDot
overallAccum.io.clear := rstAccum
overallAccum.io.valid := last // last element has been processed
overallAccum.io.in := sliceAccum.io.sum << shift(7,0) // limit to 8 bits
accum.io.in := mvc.io.acc_o.data.bits
accum.io.shift := shift
accum.io.clear := rstAccum
accum.io.valid := mvc.io.acc_o.data.valid

// write
io.mem.wr.valid := overallAccum.io.ready
io.mem.wr.bits := overallAccum.io.sum

io.mem.wr.valid := state === sWriteData
io.mem.wr.bits := accum.io.sum(cntout)

// count read/write
when (state === sIdle) {
cnt := 0.U
} .elsewhen (state === sWriteData) {
cnt := cnt + 1.U
cntwgt := 0.U
cntinp := 0.U
cntout := 0.U
} .elsewhen (state === sReadADone) {
cntwgt := cntwgt + 1.U
} .elsewhen (state === sReadBDone) {
cntinp := cntinp + 1.U
} .elsewhen (state === sWriteDone) {
cntout := cntout + 1.U
}

io.finish := overallAccum.io.ready // data has been added
io.finish := last // data has been added
}


class Accumulator(dataBits: Int = 8) extends Module {
// Shift operation until supported in MVM
class Accmulator(size: Int = 16, accBits: Int = 32) extends Module {
val io = IO(new Bundle {
val clear = Input(Bool())
val valid = Input(Bool())
val ready = Output(Bool())
val in = Input(UInt(dataBits.W))
val sum = Output(UInt((dataBits).W))
val in = Input(Vec(1, Vec(size, (UInt(accBits.W)))))
val shift = Input(UInt(8.W))
val sum = Output(Vec(size, (UInt(accBits.W))))
})
val reg = RegInit(VecInit(Seq.fill(size)(0.U(accBits.W))))

val reg = RegInit(0.U((dataBits).W))
val ready = RegNext(io.valid)
when (io.clear) {
reg := 0.U
} .elsewhen (io.valid) {
reg := reg + io.in
}
io.ready := ready
io.sum := reg
for (i <- 0 until size) {
when (io.clear) {
reg(i) := 0.U
} .elsewhen(io.valid) {
reg(i) := reg(i) + (io.in(0)(i) << io.shift)
}
}
io.ready := RegNext(io.valid)
io.sum := reg
}

10 changes: 3 additions & 7 deletions apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,9 @@ import vta.dpi._
* Shift value | 0x08
* Vector length | 0x0c
* Reset Accumulator | 0x10
* Reset Dot Module | 0x14
* Input1 pointer lsb | 0x18
* Input1 pointer msb | 0x1c
* Input2 pointer lsb | 0x20
* Input2 pointer msb | 0x24
* Output pointer lsb | 0x28
* Output pointer msb | 0x2c
* Input1 pointer | 0x18
* Input2 pointer | 0x20
* Output pointer | 0x28
* -------------------------------
* ------------------------------
Expand Down
18 changes: 9 additions & 9 deletions apps/gemm/src/driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@ class Device {

uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) {
uint32_t cycles;
uint32_t length = inp1->shape[0];
size_t size1 = (inp1->dtype.bits >> 3) * length;
uint32_t length = inp2->shape[0];
// 1 matrix 1 vector input
size_t size1 = (inp1->dtype.bits >> 3) * length * length;
size_t size2 = (inp2->dtype.bits >> 3) * length;
size_t size3 = (64 >> 3);
// 1 vector output
size_t size3 = (32 >> 3) * length;
inp1_ = this->MemAlloc(size1);
inp2_ = this->MemAlloc(size2);
out_ = this->MemAlloc(size3);
Expand Down Expand Up @@ -115,19 +117,17 @@ class Device {

void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) {
dpi_->WriteReg(0x08, shiftVal);
dpi_->WriteReg(0x0c, length); // vector length
dpi_->WriteReg(0x0c, length); // tensor size
dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_));
dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_));
dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_));
dpi_->WriteReg(0x00, 0x1); // launch
dpi_->WriteReg(0x00, 0x0); // launch
dpi_->WriteReg(0x00, 0x0);

if (reset == 1) {
dpi_->WriteReg(0x10, 0x1); // reset accum
dpi_->WriteReg(0x10, 0x0); // stop reset accum
dpi_->WriteReg(0x10, 0x1); // reset accumulator
dpi_->WriteReg(0x10, 0x0);
}
dpi_->WriteReg(0x14, 0x1); // reset dot
dpi_->WriteReg(0x14, 0x0); // stop reset dot
}

uint32_t WaitForCompletion() {
Expand Down
Loading

0 comments on commit 0b53071

Please sign in to comment.