Skip to content

Commit

Permalink
[VTA] [Hardware] Chisel implementation (#3258)
Browse files Browse the repository at this point in the history
  • Loading branch information
vegaluisjose authored and tqchen committed Jun 5, 2019
1 parent 1f62d95 commit 32f74f3
Show file tree
Hide file tree
Showing 43 changed files with 4,784 additions and 23 deletions.
3 changes: 0 additions & 3 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,6 @@ set(USE_SORT ON)
# Build ANTLR parser for Relay text format
set(USE_ANTLR OFF)

# Build TSIM for VTA
set(USE_VTA_TSIM OFF)

# Whether use Relay debug mode
set(USE_RELAY_DEBUG OFF)

16 changes: 8 additions & 8 deletions cmake/modules/VTA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ elseif(PYTHON)
--use-cfg=${CMAKE_CURRENT_BINARY_DIR}/vta_config.json)
endif()

execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE __vta_target)
string(STRIP ${__vta_target} VTA_TARGET)
execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE VTA_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE)

message(STATUS "Build VTA runtime with target: " ${VTA_TARGET})

Expand All @@ -44,6 +43,13 @@ elseif(PYTHON)

add_library(vta SHARED ${VTA_RUNTIME_SRCS})

if(${VTA_TARGET} STREQUAL "tsim")
target_compile_definitions(vta PUBLIC USE_TSIM)
include_directories("vta/include")
file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
endif()

target_include_directories(vta PUBLIC vta/include)

foreach(__def ${VTA_DEFINITIONS})
Expand All @@ -61,12 +67,6 @@ elseif(PYTHON)
target_link_libraries(vta ${__cma_lib})
endif()

if(NOT USE_VTA_TSIM STREQUAL "OFF")
include_directories("vta/include")
file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
endif()

else()
message(STATUS "Cannot found python in env, VTA build is skipped..")
endif()
2 changes: 1 addition & 1 deletion vta/apps/tsim_example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ sudo apt install verilator sbt
## Setup in TVM

1. Install `verilator` and `sbt` as described above
2. Enable VTA TSIM by turning on the switch `USE_VTA_TSIM` in config.cmake
2. Set the VTA TARGET to `tsim` on `<tvm-root>/vta/config/vta_config.json`
3. Build tvm

## How to run VTA TSIM examples
Expand Down
2 changes: 1 addition & 1 deletion vta/apps/tsim_example/cmake/modules/hw.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ else()
file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc)
add_library(hw SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC})

set(VERILATOR_DEF VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
set(VERILATOR_DEF VL_USER_FINISH VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
if (NOT TSIM_USE_TRACE STREQUAL "OFF")
list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd)
else()
Expand Down
76 changes: 76 additions & 0 deletions vta/hardware/chisel/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,81 @@
# specific language governing permissions and limitations
# under the License.

CONFIG = DefaultF1Config
TOP = VTA
TOP_TEST = Test
BUILD_NAME = build
USE_TRACE = 0
VTA_LIBNAME = libvta_hw

config_test = $(TOP_TEST)$(CONFIG)
vta_dir = $(abspath ../../)
tvm_dir = $(abspath ../../../)
verilator_inc_dir = /usr/local/share/verilator/include
verilator_build_dir = $(vta_dir)/$(BUILD_NAME)/verilator
chisel_build_dir = $(vta_dir)/$(BUILD_NAME)/chisel

verilator_opt = --cc
verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
verilator_opt += +define+RANDOMIZE_REG_INIT
verilator_opt += +define+RANDOMIZE_MEM_INIT
verilator_opt += --x-assign unique
verilator_opt += --output-split 20000
verilator_opt += --output-split-cfuncs 20000
verilator_opt += --top-module ${TOP_TEST}
verilator_opt += -Mdir ${verilator_build_dir}
verilator_opt += -I$(chisel_build_dir)

cxx_flags = -O2 -Wall -fPIC -shared
cxx_flags += -fvisibility=hidden -std=c++11
cxx_flags += -DVL_TSIM_NAME=V$(TOP_TEST)
cxx_flags += -DVL_PRINTF=printf
cxx_flags += -DVL_USER_FINISH
cxx_flags += -DVM_COVERAGE=0
cxx_flags += -DVM_SC=0
cxx_flags += -Wno-sign-compare
cxx_flags += -include V$(TOP_TEST).h
cxx_flags += -I$(verilator_build_dir)
cxx_flags += -I$(verilator_inc_dir)
cxx_flags += -I$(verilator_inc_dir)/vltstd
cxx_flags += -I$(vta_dir)/include
cxx_flags += -I$(tvm_dir)/include
cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include

cxx_files = $(verilator_inc_dir)/verilated.cpp
cxx_files += $(verilator_inc_dir)/verilated_dpi.cpp
cxx_files += $(wildcard $(verilator_build_dir)/*.cpp)
cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc

ifneq ($(USE_TRACE), 0)
verilator_opt += --trace
cxx_flags += -DVM_TRACE=1
cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP_TEST).vcd
cxx_files += $(verilator_inc_dir)/verilated_vcd_c.cpp
else
cxx_flags += -DVM_TRACE=0
endif

default: lib

lib: $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so
$(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so: $(verilator_build_dir)/V$(TOP_TEST).cpp
g++ $(cxx_flags) $(cxx_files) -o $@

verilator: $(verilator_build_dir)/V$(TOP_TEST).cpp
$(verilator_build_dir)/V$(TOP_TEST).cpp: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
verilator $(verilator_opt) $<

verilog: $(chisel_build_dir)/$(TOP).$(CONFIG).v
$(chisel_build_dir)/$(TOP).$(CONFIG).v:
sbt 'runMain vta.$(CONFIG) --target-dir $(chisel_build_dir) --top-name $(TOP).$(CONFIG)'

verilog_test: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
$(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v:
sbt 'runMain vta.$(config_test) --target-dir $(chisel_build_dir) --top-name $(TOP_TEST).$(CONFIG)'

clean:
-rm -rf target project/target project/project

cleanall:
-rm -rf $(vta_dir)/$(BUILD_NAME)
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ module VTAHostDPI #

always_ff @(posedge clock) begin
if (__exit == 'd1) begin
$display("[DONE] at cycle:%016d", cycles);
$display("[TSIM] Verilog $finish called at cycle:%016d", cycles);
$finish;
end
end
Expand Down
201 changes: 201 additions & 0 deletions vta/hardware/chisel/src/main/scala/core/Compute.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package vta.core

import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._

/** Compute.
*
* The compute unit is in charge of the following:
* - Loading micro-ops from memory (loadUop module)
* - Loading biases (acc) from memory (tensorAcc module)
* - Compute ALU instructions (tensorAlu module)
* - Compute GEMM instructions (tensorGemm module)
*/
class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val i_post = Vec(2, Input(Bool()))
val o_post = Vec(2, Output(Bool()))
val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
val uop_baddr = Input(UInt(mp.addrBits.W))
val acc_baddr = Input(UInt(mp.addrBits.W))
val vme_rd = Vec(2, new VMEReadMaster)
val inp = new TensorMaster(tensorType = "inp")
val wgt = new TensorMaster(tensorType = "wgt")
val out = new TensorMaster(tensorType = "out")
val finish = Output(Bool())
})
val sIdle :: sSync :: sExe :: Nil = Enum(3)
val state = RegInit(sIdle)

val s = Seq.tabulate(2)(_ => Module(new Semaphore(counterBits = 8, counterInitValue = 0)))

val loadUop = Module(new LoadUop)
val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
val tensorGemm = Module(new TensorGemm)
val tensorAlu = Module(new TensorAlu)

val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))

// decode
val dec = Module(new ComputeDecode)
dec.io.inst := inst_q.io.deq.bits

val inst_type = Cat(dec.io.isFinish,
dec.io.isAlu,
dec.io.isGemm,
dec.io.isLoadAcc,
dec.io.isLoadUop).asUInt

val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
val start = snext & sprev
val done =
MuxLookup(inst_type,
false.B, // default
Array(
"h_01".U -> loadUop.io.done,
"h_02".U -> tensorAcc.io.done,
"h_04".U -> tensorGemm.io.done,
"h_08".U -> tensorAlu.io.done,
"h_10".U -> true.B // Finish
)
)

// control
switch (state) {
is (sIdle) {
when (start) {
when (dec.io.isSync) {
state := sSync
} .elsewhen (inst_type.orR) {
state := sExe
}
}
}
is (sSync) {
state := sIdle
}
is (sExe) {
when (done) {
state := sIdle
}
}
}

// instructions
inst_q.io.enq <> io.inst
inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)

// uop
loadUop.io.start := state === sIdle & start & dec.io.isLoadUop
loadUop.io.inst := inst_q.io.deq.bits
loadUop.io.baddr := io.uop_baddr
io.vme_rd(0) <> loadUop.io.vme_rd
loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)

// acc
tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
tensorAcc.io.inst := inst_q.io.deq.bits
tensorAcc.io.baddr := io.acc_baddr
tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
io.vme_rd(1) <> tensorAcc.io.vme_rd

// gemm
tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
tensorGemm.io.inst := inst_q.io.deq.bits
tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
tensorGemm.io.inp <> io.inp
tensorGemm.io.wgt <> io.wgt
tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm
tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits

// alu
tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
tensorAlu.io.inst := inst_q.io.deq.bits
tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu
tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits

// out
io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx)
io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)

// semaphore
s(0).io.spost := io.i_post(0)
s(1).io.spost := io.i_post(1)
s(0).io.swait := dec.io.pop_prev & (state === sIdle & start)
s(1).io.swait := dec.io.pop_next & (state === sIdle & start)
io.o_post(0) := dec.io.push_prev & ((state === sExe & done) | (state === sSync))
io.o_post(1) := dec.io.push_next & ((state === sExe & done) | (state === sSync))

// finish
io.finish := state === sExe & done & dec.io.isFinish

// debug
if (debug) {
// start
when (state === sIdle && start) {
when (dec.io.isSync) {
printf("[Compute] start sync\n")
} .elsewhen (dec.io.isLoadUop) {
printf("[Compute] start load uop\n")
} .elsewhen (dec.io.isLoadAcc) {
printf("[Compute] start load acc\n")
} .elsewhen (dec.io.isGemm) {
printf("[Compute] start gemm\n")
} .elsewhen (dec.io.isAlu) {
printf("[Compute] start alu\n")
} .elsewhen (dec.io.isFinish) {
printf("[Compute] start finish\n")
}
}
// done
when (state === sSync) {
printf("[Compute] done sync\n")
}
when (state === sExe) {
when (done) {
when (dec.io.isLoadUop) {
printf("[Compute] done load uop\n")
} .elsewhen (dec.io.isLoadAcc) {
printf("[Compute] done load acc\n")
} .elsewhen (dec.io.isGemm) {
printf("[Compute] done gemm\n")
} .elsewhen (dec.io.isAlu) {
printf("[Compute] done alu\n")
} .elsewhen (dec.io.isFinish) {
printf("[Compute] done finish\n")
}
}
}
}
}
Loading

0 comments on commit 32f74f3

Please sign in to comment.