From 32f74f31c8100dd562621c2c8005b87db2d3e6f0 Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Wed, 5 Jun 2019 10:17:11 -0700 Subject: [PATCH] [VTA] [Hardware] Chisel implementation (#3258) --- cmake/config.cmake | 3 - cmake/modules/VTA.cmake | 16 +- vta/apps/tsim_example/README.md | 2 +- vta/apps/tsim_example/cmake/modules/hw.cmake | 2 +- vta/hardware/chisel/Makefile | 76 ++++ .../src/main/resources/verilog/VTAHostDPI.v | 2 +- .../chisel/src/main/scala/core/Compute.scala | 201 ++++++++++ .../chisel/src/main/scala/core/Configs.scala | 46 +++ .../chisel/src/main/scala/core/Core.scala | 109 ++++++ .../chisel/src/main/scala/core/Decode.scala | 229 +++++++++++ .../chisel/src/main/scala/core/Fetch.scala | 197 ++++++++++ .../chisel/src/main/scala/core/ISA.scala | 93 +++++ .../chisel/src/main/scala/core/Load.scala | 131 +++++++ .../chisel/src/main/scala/core/LoadUop.scala | 214 ++++++++++ .../src/main/scala/core/Semaphore.scala | 42 ++ .../chisel/src/main/scala/core/Store.scala | 114 ++++++ .../src/main/scala/core/TensorAlu.scala | 295 ++++++++++++++ .../src/main/scala/core/TensorGemm.scala | 364 ++++++++++++++++++ .../src/main/scala/core/TensorLoad.scala | 278 +++++++++++++ .../src/main/scala/core/TensorStore.scala | 224 +++++++++++ .../src/main/scala/core/TensorUtil.scala | 304 +++++++++++++++ .../chisel/src/main/scala/core/package.scala | 23 ++ .../src/main/scala/dpi/VTAHostDPI.scala | 83 ++++ .../chisel/src/main/scala/dpi/VTAMemDPI.scala | 98 +++++ .../src/main/scala/interface/axi/AXI.scala | 312 +++++++++++++++ .../chisel/src/main/scala/shell/Configs.scala | 51 +++ .../src/main/scala/shell/SimShell.scala | 78 ++++ .../chisel/src/main/scala/shell/VCR.scala | 242 ++++++++++++ .../chisel/src/main/scala/shell/VME.scala | 254 ++++++++++++ .../src/main/scala/shell/VTAShell.scala | 57 +++ .../src/main/scala/shell/XilinxShell.scala | 117 ++++++ .../chisel/src/main/scala/test/Test.scala | 33 ++ .../chisel/src/main/scala/util/Config.scala | 104 +++++ .../util/GenericParameterizedBundle.scala | 40 ++ .../chisel/src/main/scala/vta/Configs.scala | 51 +++ vta/hardware/dpi/tsim_device.cc | 10 + vta/include/vta/driver.h | 16 + vta/python/vta/environment.py | 2 +- vta/python/vta/testing/simulator.py | 19 + vta/python/vta/testing/util.py | 5 +- vta/src/runtime.cc | 63 ++- vta/src/tsim/tsim_driver.cc | 179 +++++++++ vta/tests/python/unittest/test_vta_insn.py | 28 +- 43 files changed, 4784 insertions(+), 23 deletions(-) create mode 100644 vta/hardware/chisel/src/main/scala/core/Compute.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Configs.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Core.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Decode.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Fetch.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/ISA.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Load.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/LoadUop.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Semaphore.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Store.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/TensorAlu.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/TensorGemm.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/TensorLoad.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/TensorStore.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/TensorUtil.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/package.scala create mode 100644 vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/Configs.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/SimShell.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/VCR.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/VME.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/VTAShell.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala create mode 100644 vta/hardware/chisel/src/main/scala/test/Test.scala create mode 100644 vta/hardware/chisel/src/main/scala/util/Config.scala create mode 100644 vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala create mode 100644 vta/hardware/chisel/src/main/scala/vta/Configs.scala create mode 100644 vta/src/tsim/tsim_driver.cc diff --git a/cmake/config.cmake b/cmake/config.cmake index 7c5add5ce4a8..6239bc4e6dce 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -132,9 +132,6 @@ set(USE_SORT ON) # Build ANTLR parser for Relay text format set(USE_ANTLR OFF) -# Build TSIM for VTA -set(USE_VTA_TSIM OFF) - # Whether use Relay debug mode set(USE_RELAY_DEBUG OFF) diff --git a/cmake/modules/VTA.cmake b/cmake/modules/VTA.cmake index 1df6c6676fac..6d5ea000edc2 100644 --- a/cmake/modules/VTA.cmake +++ b/cmake/modules/VTA.cmake @@ -29,8 +29,7 @@ elseif(PYTHON) --use-cfg=${CMAKE_CURRENT_BINARY_DIR}/vta_config.json) endif() - execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE __vta_target) - string(STRIP ${__vta_target} VTA_TARGET) + execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE VTA_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE) message(STATUS "Build VTA runtime with target: " ${VTA_TARGET}) @@ -44,6 +43,13 @@ elseif(PYTHON) add_library(vta SHARED ${VTA_RUNTIME_SRCS}) + if(${VTA_TARGET} STREQUAL "tsim") + target_compile_definitions(vta PUBLIC USE_TSIM) + include_directories("vta/include") + file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS}) + endif() + target_include_directories(vta PUBLIC vta/include) foreach(__def ${VTA_DEFINITIONS}) @@ -61,12 +67,6 @@ elseif(PYTHON) target_link_libraries(vta ${__cma_lib}) endif() - if(NOT USE_VTA_TSIM STREQUAL "OFF") - include_directories("vta/include") - file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc) - list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS}) - endif() - else() message(STATUS "Cannot found python in env, VTA build is skipped..") endif() diff --git a/vta/apps/tsim_example/README.md b/vta/apps/tsim_example/README.md index b557b24ac690..dc06a92f2b0e 100644 --- a/vta/apps/tsim_example/README.md +++ b/vta/apps/tsim_example/README.md @@ -49,7 +49,7 @@ sudo apt install verilator sbt ## Setup in TVM 1. Install `verilator` and `sbt` as described above -2. Enable VTA TSIM by turning on the switch `USE_VTA_TSIM` in config.cmake +2. Set the VTA TARGET to `tsim` on `/vta/config/vta_config.json` 3. Build tvm ## How to run VTA TSIM examples diff --git a/vta/apps/tsim_example/cmake/modules/hw.cmake b/vta/apps/tsim_example/cmake/modules/hw.cmake index 87dd72b2e626..e016ea03b6fa 100644 --- a/vta/apps/tsim_example/cmake/modules/hw.cmake +++ b/vta/apps/tsim_example/cmake/modules/hw.cmake @@ -124,7 +124,7 @@ else() file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc) add_library(hw SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC}) - set(VERILATOR_DEF VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0) + set(VERILATOR_DEF VL_USER_FINISH VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0) if (NOT TSIM_USE_TRACE STREQUAL "OFF") list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd) else() diff --git a/vta/hardware/chisel/Makefile b/vta/hardware/chisel/Makefile index 65a9ed13c989..7371dd1b3686 100644 --- a/vta/hardware/chisel/Makefile +++ b/vta/hardware/chisel/Makefile @@ -15,5 +15,81 @@ # specific language governing permissions and limitations # under the License. +CONFIG = DefaultF1Config +TOP = VTA +TOP_TEST = Test +BUILD_NAME = build +USE_TRACE = 0 +VTA_LIBNAME = libvta_hw + +config_test = $(TOP_TEST)$(CONFIG) +vta_dir = $(abspath ../../) +tvm_dir = $(abspath ../../../) +verilator_inc_dir = /usr/local/share/verilator/include +verilator_build_dir = $(vta_dir)/$(BUILD_NAME)/verilator +chisel_build_dir = $(vta_dir)/$(BUILD_NAME)/chisel + +verilator_opt = --cc +verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN +verilator_opt += +define+RANDOMIZE_REG_INIT +verilator_opt += +define+RANDOMIZE_MEM_INIT +verilator_opt += --x-assign unique +verilator_opt += --output-split 20000 +verilator_opt += --output-split-cfuncs 20000 +verilator_opt += --top-module ${TOP_TEST} +verilator_opt += -Mdir ${verilator_build_dir} +verilator_opt += -I$(chisel_build_dir) + +cxx_flags = -O2 -Wall -fPIC -shared +cxx_flags += -fvisibility=hidden -std=c++11 +cxx_flags += -DVL_TSIM_NAME=V$(TOP_TEST) +cxx_flags += -DVL_PRINTF=printf +cxx_flags += -DVL_USER_FINISH +cxx_flags += -DVM_COVERAGE=0 +cxx_flags += -DVM_SC=0 +cxx_flags += -Wno-sign-compare +cxx_flags += -include V$(TOP_TEST).h +cxx_flags += -I$(verilator_build_dir) +cxx_flags += -I$(verilator_inc_dir) +cxx_flags += -I$(verilator_inc_dir)/vltstd +cxx_flags += -I$(vta_dir)/include +cxx_flags += -I$(tvm_dir)/include +cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include + +cxx_files = $(verilator_inc_dir)/verilated.cpp +cxx_files += $(verilator_inc_dir)/verilated_dpi.cpp +cxx_files += $(wildcard $(verilator_build_dir)/*.cpp) +cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc + +ifneq ($(USE_TRACE), 0) + verilator_opt += --trace + cxx_flags += -DVM_TRACE=1 + cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP_TEST).vcd + cxx_files += $(verilator_inc_dir)/verilated_vcd_c.cpp +else + cxx_flags += -DVM_TRACE=0 +endif + +default: lib + +lib: $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so +$(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so: $(verilator_build_dir)/V$(TOP_TEST).cpp + g++ $(cxx_flags) $(cxx_files) -o $@ + +verilator: $(verilator_build_dir)/V$(TOP_TEST).cpp +$(verilator_build_dir)/V$(TOP_TEST).cpp: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v + verilator $(verilator_opt) $< + +verilog: $(chisel_build_dir)/$(TOP).$(CONFIG).v +$(chisel_build_dir)/$(TOP).$(CONFIG).v: + sbt 'runMain vta.$(CONFIG) --target-dir $(chisel_build_dir) --top-name $(TOP).$(CONFIG)' + +verilog_test: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v +$(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v: + sbt 'runMain vta.$(config_test) --target-dir $(chisel_build_dir) --top-name $(TOP_TEST).$(CONFIG)' + clean: -rm -rf target project/target project/project + +cleanall: + -rm -rf $(vta_dir)/$(BUILD_NAME) diff --git a/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v b/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v index 02fcf0d779e1..8ab85f6b752c 100644 --- a/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v +++ b/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v @@ -112,7 +112,7 @@ module VTAHostDPI # always_ff @(posedge clock) begin if (__exit == 'd1) begin - $display("[DONE] at cycle:%016d", cycles); + $display("[TSIM] Verilog $finish called at cycle:%016d", cycles); $finish; end end diff --git a/vta/hardware/chisel/src/main/scala/core/Compute.scala b/vta/hardware/chisel/src/main/scala/core/Compute.scala new file mode 100644 index 000000000000..ef56c3d4224e --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Compute.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** Compute. + * + * The compute unit is in charge of the following: + * - Loading micro-ops from memory (loadUop module) + * - Loading biases (acc) from memory (tensorAcc module) + * - Compute ALU instructions (tensorAlu module) + * - Compute GEMM instructions (tensorGemm module) + */ +class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val i_post = Vec(2, Input(Bool())) + val o_post = Vec(2, Output(Bool())) + val inst = Flipped(Decoupled(UInt(INST_BITS.W))) + val uop_baddr = Input(UInt(mp.addrBits.W)) + val acc_baddr = Input(UInt(mp.addrBits.W)) + val vme_rd = Vec(2, new VMEReadMaster) + val inp = new TensorMaster(tensorType = "inp") + val wgt = new TensorMaster(tensorType = "wgt") + val out = new TensorMaster(tensorType = "out") + val finish = Output(Bool()) + }) + val sIdle :: sSync :: sExe :: Nil = Enum(3) + val state = RegInit(sIdle) + + val s = Seq.tabulate(2)(_ => Module(new Semaphore(counterBits = 8, counterInitValue = 0))) + + val loadUop = Module(new LoadUop) + val tensorAcc = Module(new TensorLoad(tensorType = "acc")) + val tensorGemm = Module(new TensorGemm) + val tensorAlu = Module(new TensorAlu) + + val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries)) + + // decode + val dec = Module(new ComputeDecode) + dec.io.inst := inst_q.io.deq.bits + + val inst_type = Cat(dec.io.isFinish, + dec.io.isAlu, + dec.io.isGemm, + dec.io.isLoadAcc, + dec.io.isLoadUop).asUInt + + val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B) + val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B) + val start = snext & sprev + val done = + MuxLookup(inst_type, + false.B, // default + Array( + "h_01".U -> loadUop.io.done, + "h_02".U -> tensorAcc.io.done, + "h_04".U -> tensorGemm.io.done, + "h_08".U -> tensorAlu.io.done, + "h_10".U -> true.B // Finish + ) + ) + + // control + switch (state) { + is (sIdle) { + when (start) { + when (dec.io.isSync) { + state := sSync + } .elsewhen (inst_type.orR) { + state := sExe + } + } + } + is (sSync) { + state := sIdle + } + is (sExe) { + when (done) { + state := sIdle + } + } + } + + // instructions + inst_q.io.enq <> io.inst + inst_q.io.deq.ready := (state === sExe & done) | (state === sSync) + + // uop + loadUop.io.start := state === sIdle & start & dec.io.isLoadUop + loadUop.io.inst := inst_q.io.deq.bits + loadUop.io.baddr := io.uop_baddr + io.vme_rd(0) <> loadUop.io.vme_rd + loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx) + + // acc + tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc + tensorAcc.io.inst := inst_q.io.deq.bits + tensorAcc.io.baddr := io.acc_baddr + tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx) + tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr) + io.vme_rd(1) <> tensorAcc.io.vme_rd + + // gemm + tensorGemm.io.start := state === sIdle & start & dec.io.isGemm + tensorGemm.io.inst := inst_q.io.deq.bits + tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm + tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits + tensorGemm.io.inp <> io.inp + tensorGemm.io.wgt <> io.wgt + tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm + tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits + tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm + tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits + + // alu + tensorAlu.io.start := state === sIdle & start & dec.io.isAlu + tensorAlu.io.inst := inst_q.io.deq.bits + tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu + tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits + tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu + tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits + tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu + tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits + + // out + io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx) + io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr) + + // semaphore + s(0).io.spost := io.i_post(0) + s(1).io.spost := io.i_post(1) + s(0).io.swait := dec.io.pop_prev & (state === sIdle & start) + s(1).io.swait := dec.io.pop_next & (state === sIdle & start) + io.o_post(0) := dec.io.push_prev & ((state === sExe & done) | (state === sSync)) + io.o_post(1) := dec.io.push_next & ((state === sExe & done) | (state === sSync)) + + // finish + io.finish := state === sExe & done & dec.io.isFinish + + // debug + if (debug) { + // start + when (state === sIdle && start) { + when (dec.io.isSync) { + printf("[Compute] start sync\n") + } .elsewhen (dec.io.isLoadUop) { + printf("[Compute] start load uop\n") + } .elsewhen (dec.io.isLoadAcc) { + printf("[Compute] start load acc\n") + } .elsewhen (dec.io.isGemm) { + printf("[Compute] start gemm\n") + } .elsewhen (dec.io.isAlu) { + printf("[Compute] start alu\n") + } .elsewhen (dec.io.isFinish) { + printf("[Compute] start finish\n") + } + } + // done + when (state === sSync) { + printf("[Compute] done sync\n") + } + when (state === sExe) { + when (done) { + when (dec.io.isLoadUop) { + printf("[Compute] done load uop\n") + } .elsewhen (dec.io.isLoadAcc) { + printf("[Compute] done load acc\n") + } .elsewhen (dec.io.isGemm) { + printf("[Compute] done gemm\n") + } .elsewhen (dec.io.isAlu) { + printf("[Compute] done alu\n") + } .elsewhen (dec.io.isFinish) { + printf("[Compute] done finish\n") + } + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/Configs.scala b/vta/hardware/chisel/src/main/scala/core/Configs.scala new file mode 100644 index 000000000000..b4e764b120cd --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Configs.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import vta.util.config._ + +/** CoreConfig. + * + * This is one supported configuration for VTA. This file will + * be eventually filled out with class configurations that can be + * mixed/matched with Shell configurations for different backends. + */ +class CoreConfig extends Config((site, here, up) => { + case CoreKey => CoreParams( + batch = 1, + blockOut = 16, + blockIn = 16, + inpBits = 8, + wgtBits = 8, + uopBits = 32, + accBits = 32, + outBits = 8, + uopMemDepth = 2048, + inpMemDepth = 2048, + wgtMemDepth = 1024, + accMemDepth = 2048, + outMemDepth = 2048, + instQueueEntries = 512) +}) diff --git a/vta/hardware/chisel/src/main/scala/core/Core.scala b/vta/hardware/chisel/src/main/scala/core/Core.scala new file mode 100644 index 000000000000..2a2d4e02784f --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Core.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import vta.util.config._ +import vta.shell._ + +/** Core parameters */ +case class CoreParams ( + batch: Int = 1, + blockOut: Int = 16, + blockIn: Int = 16, + inpBits: Int = 8, + wgtBits: Int = 8, + uopBits: Int = 32, + accBits: Int = 32, + outBits: Int = 8, + uopMemDepth: Int = 512, + inpMemDepth: Int = 512, + wgtMemDepth: Int = 512, + accMemDepth: Int = 512, + outMemDepth: Int = 512, + instQueueEntries: Int = 32 +) + +case object CoreKey extends Field[CoreParams] + +/** Core. + * + * The core defines the current VTA architecture by connecting memory and + * compute modules together such as load/store and compute. Most of the + * connections in the core are bulk (<>), and we should try to keep it this + * way, because it is easier to understand what is going on. + * + * Also, the core must be instantiated by a shell using the + * VTA Control Register (VCR) and the VTA Memory Engine (VME) interfaces. + * More info about these interfaces and modules can be found in the shell + * directory. + */ +class Core(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val vcr = new VCRClient + val vme = new VMEMaster + }) + val fetch = Module(new Fetch) + val load = Module(new Load) + val compute = Module(new Compute) + val store = Module(new Store) + + // Read(rd) and write(wr) from/to memory (i.e. DRAM) + io.vme.rd(0) <> fetch.io.vme_rd + io.vme.rd(1) <> compute.io.vme_rd(0) + io.vme.rd(2) <> load.io.vme_rd(0) + io.vme.rd(3) <> load.io.vme_rd(1) + io.vme.rd(4) <> compute.io.vme_rd(1) + io.vme.wr(0) <> store.io.vme_wr + + // Fetch instructions (tasks) from memory (DRAM) into queues (SRAMs) + fetch.io.launch := io.vcr.launch + fetch.io.ins_baddr := io.vcr.ptrs(0) + fetch.io.ins_count := io.vcr.vals(0) + + // Load inputs and weights from memory (DRAM) into scratchpads (SRAMs) + load.io.i_post := compute.io.o_post(0) + load.io.inst <> fetch.io.inst.ld + load.io.inp_baddr := io.vcr.ptrs(2) + load.io.wgt_baddr := io.vcr.ptrs(3) + + // The compute module performs the following: + // - Load micro-ops (uops) and accumulations (acc) + // - Compute dense and ALU instructions (tasks) + compute.io.i_post(0) := load.io.o_post + compute.io.i_post(1) := store.io.o_post + compute.io.inst <> fetch.io.inst.co + compute.io.uop_baddr := io.vcr.ptrs(1) + compute.io.acc_baddr := io.vcr.ptrs(4) + compute.io.inp <> load.io.inp + compute.io.wgt <> load.io.wgt + + // The store module performs the following: + // - Writes results from compute into scratchpads (SRAMs) + // - Store results from scratchpads (SRAMs) to memory (DRAM) + store.io.i_post := compute.io.o_post(1) + store.io.inst <> fetch.io.inst.st + store.io.out_baddr := io.vcr.ptrs(5) + store.io.out <> compute.io.out + + // Finish instruction is executed and asserts the VCR finish flag + val finish = RegNext(compute.io.finish) + io.vcr.finish := finish +} diff --git a/vta/hardware/chisel/src/main/scala/core/Decode.scala b/vta/hardware/chisel/src/main/scala/core/Decode.scala new file mode 100644 index 000000000000..f5bf3406347d --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Decode.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ + +import ISA._ + +/** MemDecode. + * + * Decode memory instructions with a Bundle. This is similar to an union, + * therefore order matters when declaring fields. These are the instructions + * decoded with this bundle: + * - LUOP + * - LWGT + * - LINP + * - LACC + * - SOUT + */ +class MemDecode extends Bundle { + val xpad_1 = UInt(M_PAD_BITS.W) + val xpad_0 = UInt(M_PAD_BITS.W) + val ypad_1 = UInt(M_PAD_BITS.W) + val ypad_0 = UInt(M_PAD_BITS.W) + val xstride = UInt(M_STRIDE_BITS.W) + val xsize = UInt(M_SIZE_BITS.W) + val ysize = UInt(M_SIZE_BITS.W) + val empty_0 = UInt(7.W) // derive this + val dram_offset = UInt(M_DRAM_OFFSET_BITS.W) + val sram_offset = UInt(M_SRAM_OFFSET_BITS.W) + val id = UInt(M_ID_BITS.W) + val push_next = Bool() + val push_prev = Bool() + val pop_next = Bool() + val pop_prev = Bool() + val op = UInt(OP_BITS.W) +} + +/** GemmDecode. + * + * Decode GEMM instruction with a Bundle. This is similar to an union, + * therefore order matters when declaring fields. + */ +class GemmDecode extends Bundle { + val wgt_1 = UInt(C_WIDX_BITS.W) + val wgt_0 = UInt(C_WIDX_BITS.W) + val inp_1 = UInt(C_IIDX_BITS.W) + val inp_0 = UInt(C_IIDX_BITS.W) + val acc_1 = UInt(C_AIDX_BITS.W) + val acc_0 = UInt(C_AIDX_BITS.W) + val empty_0 = Bool() + val lp_1 = UInt(C_ITER_BITS.W) + val lp_0 = UInt(C_ITER_BITS.W) + val uop_end = UInt(C_UOP_END_BITS.W) + val uop_begin = UInt(C_UOP_BGN_BITS.W) + val reset = Bool() + val push_next = Bool() + val push_prev = Bool() + val pop_next = Bool() + val pop_prev = Bool() + val op = UInt(OP_BITS.W) +} + +/** AluDecode. + * + * Decode ALU instructions with a Bundle. This is similar to an union, + * therefore order matters when declaring fields. These are the instructions + * decoded with this bundle: + * - VMIN + * - VMAX + * - VADD + * - VSHX + */ +class AluDecode extends Bundle { + val empty_1 = Bool() + val alu_imm = UInt(C_ALU_IMM_BITS.W) + val alu_use_imm = Bool() + val alu_op = UInt(C_ALU_DEC_BITS.W) + val src_1 = UInt(C_IIDX_BITS.W) + val src_0 = UInt(C_IIDX_BITS.W) + val dst_1 = UInt(C_AIDX_BITS.W) + val dst_0 = UInt(C_AIDX_BITS.W) + val empty_0 = Bool() + val lp_1 = UInt(C_ITER_BITS.W) + val lp_0 = UInt(C_ITER_BITS.W) + val uop_end = UInt(C_UOP_END_BITS.W) + val uop_begin = UInt(C_UOP_BGN_BITS.W) + val reset = Bool() + val push_next = Bool() + val push_prev = Bool() + val pop_next = Bool() + val pop_prev = Bool() + val op = UInt(OP_BITS.W) +} + +/** UopDecode. + * + * Decode micro-ops (uops). + */ +class UopDecode extends Bundle { + val u2 = UInt(10.W) + val u1 = UInt(11.W) + val u0 = UInt(11.W) +} + +/** FetchDecode. + * + * Partial decoding for dispatching instructions to Load, Compute, and Store. + */ +class FetchDecode extends Module { + val io = IO(new Bundle { + val inst = Input(UInt(INST_BITS.W)) + val isLoad = Output(Bool()) + val isCompute = Output(Bool()) + val isStore = Output(Bool()) + }) + val csignals = + ListLookup(io.inst, + List(N, OP_X), + Array( + LUOP -> List(Y, OP_G), + LWGT -> List(Y, OP_L), + LINP -> List(Y, OP_L), + LACC -> List(Y, OP_G), + SOUT -> List(Y, OP_S), + GEMM -> List(Y, OP_G), + FNSH -> List(Y, OP_G), + VMIN -> List(Y, OP_G), + VMAX -> List(Y, OP_G), + VADD -> List(Y, OP_G), + VSHX -> List(Y, OP_G) + ) + ) + + val (cs_val_inst: Bool) :: cs_op_type :: Nil = csignals + + io.isLoad := cs_val_inst & cs_op_type === OP_L + io.isCompute := cs_val_inst & cs_op_type === OP_G + io.isStore := cs_val_inst & cs_op_type === OP_S +} + +/** LoadDecode. + * + * Decode dependencies, type and sync for Load module. + */ +class LoadDecode extends Module { + val io = IO(new Bundle { + val inst = Input(UInt(INST_BITS.W)) + val push_next = Output(Bool()) + val pop_next = Output(Bool()) + val isInput = Output(Bool()) + val isWeight = Output(Bool()) + val isSync = Output(Bool()) + }) + val dec = io.inst.asTypeOf(new MemDecode) + io.push_next := dec.push_next + io.pop_next := dec.pop_next + io.isInput := io.inst === LINP & dec.xsize =/= 0.U + io.isWeight := io.inst === LWGT & dec.xsize =/= 0.U + io.isSync := (io.inst === LINP | io.inst === LWGT) & dec.xsize === 0.U +} + +/** ComputeDecode. + * + * Decode dependencies, type and sync for Compute module. + */ +class ComputeDecode extends Module { + val io = IO(new Bundle { + val inst = Input(UInt(INST_BITS.W)) + val push_next = Output(Bool()) + val push_prev = Output(Bool()) + val pop_next = Output(Bool()) + val pop_prev = Output(Bool()) + val isLoadAcc = Output(Bool()) + val isLoadUop = Output(Bool()) + val isSync = Output(Bool()) + val isAlu = Output(Bool()) + val isGemm = Output(Bool()) + val isFinish = Output(Bool()) + }) + val dec = io.inst.asTypeOf(new MemDecode) + io.push_next := dec.push_next + io.push_prev := dec.push_prev + io.pop_next := dec.pop_next + io.pop_prev := dec.pop_prev + io.isLoadAcc := io.inst === LACC & dec.xsize =/= 0.U + io.isLoadUop := io.inst === LUOP & dec.xsize =/= 0.U + io.isSync := (io.inst === LACC | io.inst === LUOP) & dec.xsize === 0.U + io.isAlu := io.inst === VMIN | io.inst === VMAX | io.inst === VADD | io.inst === VSHX + io.isGemm := io.inst === GEMM + io.isFinish := io.inst === FNSH +} + +/** StoreDecode. + * + * Decode dependencies, type and sync for Store module. + */ +class StoreDecode extends Module { + val io = IO(new Bundle { + val inst = Input(UInt(INST_BITS.W)) + val push_prev = Output(Bool()) + val pop_prev = Output(Bool()) + val isStore = Output(Bool()) + val isSync = Output(Bool()) + }) + val dec = io.inst.asTypeOf(new MemDecode) + io.push_prev := dec.push_prev + io.pop_prev := dec.pop_prev + io.isStore := io.inst === SOUT & dec.xsize =/= 0.U + io.isSync := io.inst === SOUT & dec.xsize === 0.U +} diff --git a/vta/hardware/chisel/src/main/scala/core/Fetch.scala b/vta/hardware/chisel/src/main/scala/core/Fetch.scala new file mode 100644 index 000000000000..bcc164a8f623 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Fetch.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** Fetch. + * + * The fetch unit reads instructions (tasks) from memory (i.e. DRAM), using the + * VTA Memory Engine (VME), and push them into an instruction queue called + * inst_q. Once the instruction queue is full, instructions are dispatched to + * the Load, Compute and Store module queues based on the instruction opcode. + * After draining the queue, the fetch unit checks if there are more instructions + * via the ins_count register which is written by the host. + * + * Additionally, instructions are read into two chunks (see sReadLSB and sReadMSB) + * because we are using a DRAM payload of 8-bytes or half of a VTA instruction. + * This should be configurable for larger payloads, i.e. 64-bytes, which can load + * more than one instruction at the time. Finally, the instruction queue is + * sized (entries_q), depending on the maximum burst allowed in the memory. + */ +class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { + val vp = p(ShellKey).vcrParams + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val launch = Input(Bool()) + val ins_baddr = Input(UInt(mp.addrBits.W)) + val ins_count = Input(UInt(vp.regBits.W)) + val vme_rd = new VMEReadMaster + val inst = new Bundle { + val ld = Decoupled(UInt(INST_BITS.W)) + val co = Decoupled(UInt(INST_BITS.W)) + val st = Decoupled(UInt(INST_BITS.W)) + } + }) + val entries_q = 1 << (mp.lenBits - 1) // one-instr-every-two-vme-word + val inst_q = Module(new Queue(UInt(INST_BITS.W), entries_q)) + val dec = Module(new FetchDecode) + + val s1_launch = RegNext(io.launch) + val pulse = io.launch & ~s1_launch + + val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr)) + val rlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) + val ilen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) + + val xrem = Reg(chiselTypeOf(io.ins_count)) + val xsize = (io.ins_count << 1.U) - 1.U + val xmax = (1 << mp.lenBits).U + val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + + val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5) + val state = RegInit(sIdle) + + // control + switch (state) { + is (sIdle) { + when (pulse) { + state := sReadCmd + when (xsize < xmax) { + rlen := xsize + ilen := xsize >> 1.U + xrem := 0.U + } .otherwise { + rlen := xmax - 1.U + ilen := (xmax >> 1.U) - 1.U + xrem := xsize - xmax + } + } + } + is (sReadCmd) { + when (io.vme_rd.cmd.ready) { + state := sReadLSB + } + } + is (sReadLSB) { + when (io.vme_rd.data.valid) { + state := sReadMSB + } + } + is (sReadMSB) { + when (io.vme_rd.data.valid) { + when (inst_q.io.count === ilen) { + state := sDrain + } .otherwise { + state := sReadLSB + } + } + } + is (sDrain) { + when (inst_q.io.count === 0.U) { + when (xrem === 0.U) { + state := sIdle + } .elsewhen (xrem < xmax) { + state := sReadCmd + rlen := xrem + ilen := xrem >> 1.U + xrem := 0.U + } .otherwise { + state := sReadCmd + rlen := xmax - 1.U + ilen := (xmax >> 1.U) - 1.U + xrem := xrem - xmax + } + } + } + } + + // read instructions from dram + when (state === sIdle) { + raddr := io.ins_baddr + } .elsewhen (state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) { + raddr := raddr + xmax_bytes + } + + io.vme_rd.cmd.valid := state === sReadCmd + io.vme_rd.cmd.bits.addr := raddr + io.vme_rd.cmd.bits.len := rlen + + io.vme_rd.data.ready := inst_q.io.enq.ready + + val lsb = Reg(chiselTypeOf(io.vme_rd.data.bits)) + val msb = io.vme_rd.data.bits + val inst = Cat(msb, lsb) + + when (state === sReadLSB) { lsb := io.vme_rd.data.bits } + + inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB + inst_q.io.enq.bits := inst + + // decode + dec.io.inst := inst_q.io.deq.bits + + // instruction queues + io.inst.ld.valid := dec.io.isLoad & inst_q.io.deq.valid & state === sDrain + io.inst.co.valid := dec.io.isCompute & inst_q.io.deq.valid & state === sDrain + io.inst.st.valid := dec.io.isStore & inst_q.io.deq.valid & state === sDrain + + io.inst.ld.bits := inst_q.io.deq.bits + io.inst.co.bits := inst_q.io.deq.bits + io.inst.st.bits := inst_q.io.deq.bits + + // check if selected queue is ready + val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt + val deq_ready = + MuxLookup(deq_sel, + false.B, // default + Array( + "h_01".U -> io.inst.ld.ready, + "h_02".U -> io.inst.st.ready, + "h_04".U -> io.inst.co.ready + ) + ) + + // dequeue instruction + inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain + + + // debug + if (debug) { + when (state === sIdle && pulse) { + printf("[Fetch] Launch\n") + } + // instruction + when (inst_q.io.deq.fire()) { + when (dec.io.isLoad) { + printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits) + } + when (dec.io.isCompute) { + printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits) + } + when (dec.io.isStore) { + printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits) + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/ISA.scala b/vta/hardware/chisel/src/main/scala/core/ISA.scala new file mode 100644 index 000000000000..c3bf6097adcd --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/ISA.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ + +/** ISAConstants. + * + * These constants are used for decoding (parsing) fields on instructions. + */ +trait ISAConstants +{ + val INST_BITS = 128 + + val OP_BITS = 3 + + val M_DEP_BITS = 4 + val M_ID_BITS = 2 + val M_SRAM_OFFSET_BITS = 16 + val M_DRAM_OFFSET_BITS = 32 + val M_SIZE_BITS = 16 + val M_STRIDE_BITS = 16 + val M_PAD_BITS = 4 + + val C_UOP_BGN_BITS = 13 + val C_UOP_END_BITS = 14 + val C_ITER_BITS = 14 + val C_AIDX_BITS = 11 + val C_IIDX_BITS = 11 + val C_WIDX_BITS = 10 + val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction + val C_ALU_OP_BITS = 3 + val C_ALU_IMM_BITS = 16 + + val Y = true.B + val N = false.B + + val OP_L = 0.asUInt(OP_BITS.W) + val OP_S = 1.asUInt(OP_BITS.W) + val OP_G = 2.asUInt(OP_BITS.W) + val OP_F = 3.asUInt(OP_BITS.W) + val OP_A = 4.asUInt(OP_BITS.W) + val OP_X = 5.asUInt(OP_BITS.W) + + val ALU_OP_NUM = 5 + val ALU_OP = Enum(ALU_OP_NUM) + + val M_ID_U = 0.asUInt(M_ID_BITS.W) + val M_ID_W = 1.asUInt(M_ID_BITS.W) + val M_ID_I = 2.asUInt(M_ID_BITS.W) + val M_ID_A = 3.asUInt(M_ID_BITS.W) +} + +/** ISA. + * + * This is the VTA ISA, here we specify the cares and dont-cares that makes + * decoding easier. Since instructions are quite long 128-bit, we could generate + * these based on ISAConstants. + * + * FIXME: VSHX should be replaced by VSHR and VSHL once we modify the compiler + * TODO: Add VXOR to clear accumulator + */ +object ISA { + def LUOP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000") + def LWGT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000") + def LINP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000") + def LACC = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000") + def SOUT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001") + def GEMM = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010") + def VMIN = BitPat("b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def VMAX = BitPat("b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def VADD = BitPat("b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def VSHX = BitPat("b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def FNSH = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011") +} diff --git a/vta/hardware/chisel/src/main/scala/core/Load.scala b/vta/hardware/chisel/src/main/scala/core/Load.scala new file mode 100644 index 000000000000..64795139aa4e --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Load.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** Load. + * + * Load inputs and weights from memory (DRAM) into scratchpads (SRAMs). + * This module instantiate the TensorLoad unit which is in charge of + * loading 1D and 2D tensors to scratchpads, so it can be used by + * other modules such as Compute. + */ +class Load(debug: Boolean = false)(implicit p: Parameters) extends Module { + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val i_post = Input(Bool()) + val o_post = Output(Bool()) + val inst = Flipped(Decoupled(UInt(INST_BITS.W))) + val inp_baddr = Input(UInt(mp.addrBits.W)) + val wgt_baddr = Input(UInt(mp.addrBits.W)) + val vme_rd = Vec(2, new VMEReadMaster) + val inp = new TensorClient(tensorType = "inp") + val wgt = new TensorClient(tensorType = "wgt") + }) + val sIdle :: sSync :: sExe :: Nil = Enum(3) + val state = RegInit(sIdle) + + val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0)) + val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries)) + + val dec = Module(new LoadDecode) + dec.io.inst := inst_q.io.deq.bits + + val tensorType = Seq("inp", "wgt") + val tensorDec = Seq(dec.io.isInput, dec.io.isWeight) + val tensorLoad = Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i)))) + + val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B) + val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done) + + // control + switch (state) { + is (sIdle) { + when (start) { + when (dec.io.isSync) { + state := sSync + } .elsewhen (dec.io.isInput || dec.io.isWeight) { + state := sExe + } + } + } + is (sSync) { + state := sIdle + } + is (sExe) { + when (done) { + state := sIdle + } + } + } + + // instructions + inst_q.io.enq <> io.inst + inst_q.io.deq.ready := (state === sExe & done) | (state === sSync) + + // load tensor + // [0] input (inp) + // [1] weight (wgt) + val ptr = Seq(io.inp_baddr, io.wgt_baddr) + val tsor = Seq(io.inp, io.wgt) + for (i <- 0 until 2) { + tensorLoad(i).io.start := state === sIdle & start & tensorDec(i) + tensorLoad(i).io.inst := inst_q.io.deq.bits + tensorLoad(i).io.baddr := ptr(i) + tensorLoad(i).io.tensor <> tsor(i) + io.vme_rd(i) <> tensorLoad(i).io.vme_rd + } + + // semaphore + s.io.spost := io.i_post + s.io.swait := dec.io.pop_next & (state === sIdle & start) + io.o_post := dec.io.push_next & ((state === sExe & done) | (state === sSync)) + + // debug + if (debug) { + // start + when (state === sIdle && start) { + when (dec.io.isSync) { + printf("[Load] start sync\n") + } .elsewhen (dec.io.isInput) { + printf("[Load] start input\n") + } .elsewhen (dec.io.isWeight) { + printf("[Load] start weight\n") + } + } + // done + when (state === sSync) { + printf("[Load] done sync\n") + } + when (state === sExe) { + when (done) { + when (dec.io.isInput) { + printf("[Load] done input\n") + } .elsewhen (dec.io.isWeight) { + printf("[Load] done weight\n") + } + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/LoadUop.scala b/vta/hardware/chisel/src/main/scala/core/LoadUop.scala new file mode 100644 index 000000000000..07296523b254 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/LoadUop.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** UopMaster. + * + * Uop interface used by a master module, i.e. TensorAlu or TensorGemm, + * to request a micro-op (uop) from the uop-scratchpad. The index (idx) is + * used as an address to find the uop in the uop-scratchpad. + */ +class UopMaster(implicit p: Parameters) extends Bundle { + val addrBits = log2Ceil(p(CoreKey).uopMemDepth) + val idx = ValidIO(UInt(addrBits.W)) + val data = Flipped(ValidIO(new UopDecode)) + override def cloneType = new UopMaster().asInstanceOf[this.type] +} + +/** UopClient. + * + * Uop interface used by a client module, i.e. LoadUop, to receive + * a request from a master module, i.e. TensorAlu or TensorGemm. + * The index (idx) is used as an address to find the uop in the uop-scratchpad. + */ +class UopClient(implicit p: Parameters) extends Bundle { + val addrBits = log2Ceil(p(CoreKey).uopMemDepth) + val idx = Flipped(ValidIO(UInt(addrBits.W))) + val data = ValidIO(new UopDecode) + override def cloneType = new UopClient().asInstanceOf[this.type] +} + +/** LoadUop. + * + * Load micro-ops (uops) from memory, i.e. DRAM, and store them in the + * uop-scratchpad. Currently, micro-ops are 32-bit wide and loaded in + * group of 2 given the fact that the DRAM payload is 8-bytes. This module + * should be modified later on to support different DRAM sizes efficiently. + */ +class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val baddr = Input(UInt(mp.addrBits.W)) + val vme_rd = new VMEReadMaster + val uop = new UopClient + }) + val numUop = 2 // store two uops per sram word + val uopBits = p(CoreKey).uopBits + val uopDepth = p(CoreKey).uopMemDepth / numUop + + val dec = io.inst.asTypeOf(new MemDecode) + val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr)) + val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) + val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) + val xrem = Reg(chiselTypeOf(dec.xsize)) + val xsize = dec.xsize(0) + (dec.xsize >> log2Ceil(numUop)) - 1.U + val xmax = (1 << mp.lenBits).U + val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + + val offsetIsEven = (dec.sram_offset % 2.U) === 0.U + val sizeIsEven = (dec.xsize % 2.U) === 0.U + + val sIdle :: sReadCmd :: sReadData :: Nil = Enum(3) + val state = RegInit(sIdle) + + // control + switch (state) { + is (sIdle) { + when (io.start) { + state := sReadCmd + when (xsize < xmax) { + xlen := xsize + xrem := 0.U + } .otherwise { + xlen := xmax - 1.U + xrem := xsize - xmax + } + } + } + is (sReadCmd) { + when (io.vme_rd.cmd.ready) { + state := sReadData + } + } + is (sReadData) { + when (io.vme_rd.data.valid) { + when(xcnt === xlen) { + when (xrem === 0.U) { + state := sIdle + } .elsewhen (xrem < xmax) { + state := sReadCmd + xlen := xrem + xrem := 0.U + } .otherwise { + state := sReadCmd + xlen := xmax - 1.U + xrem := xrem - xmax + } + } + } + } + } + + // read-from-dram + when (state === sIdle) { + when (offsetIsEven) { + raddr := io.baddr + dec.dram_offset + } .otherwise { + raddr := io.baddr + dec.dram_offset - 4.U + } + } .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) { + raddr := raddr + xmax_bytes + } + + io.vme_rd.cmd.valid := state === sReadCmd + io.vme_rd.cmd.bits.addr := raddr + io.vme_rd.cmd.bits.len := xlen + + io.vme_rd.data.ready := state === sReadData + + when (state =/= sReadData) { + xcnt := 0.U + } .elsewhen (io.vme_rd.data.fire()) { + xcnt := xcnt + 1.U + } + + val waddr = Reg(UInt(log2Ceil(uopDepth).W)) + when (state === sIdle) { + waddr := dec.sram_offset >> log2Ceil(numUop) + } .elsewhen (io.vme_rd.data.fire()) { + waddr := waddr + 1.U + } + + val wdata = Wire(Vec(numUop, UInt(uopBits.W))) + val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata)) + val wmask = Reg(Vec(numUop, Bool())) + + when (offsetIsEven) { + when (sizeIsEven) { + wmask := "b_11".U.asTypeOf(wmask) + } .elsewhen (io.vme_rd.cmd.fire()) { + when (dec.xsize === 1.U) { + wmask := "b_01".U.asTypeOf(wmask) + } .otherwise { + wmask := "b_11".U.asTypeOf(wmask) + } + } .elsewhen (io.vme_rd.data.fire()) { + when (xcnt === xlen - 1.U) { + wmask := "b_01".U.asTypeOf(wmask) + } .otherwise { + wmask := "b_11".U.asTypeOf(wmask) + } + } + } .otherwise { + when (io.vme_rd.cmd.fire()) { + wmask := "b_10".U.asTypeOf(wmask) + } .elsewhen (io.vme_rd.data.fire()) { + when (sizeIsEven && xcnt === xlen - 1.U) { + wmask := "b_01".U.asTypeOf(wmask) + } .otherwise { + wmask := "b_11".U.asTypeOf(wmask) + } + } + } + + wdata := io.vme_rd.data.bits.asTypeOf(wdata) + when (io.vme_rd.data.fire()) { + mem.write(waddr, wdata, wmask) + } + + // read-from-sram + io.uop.data.valid := RegNext(io.uop.idx.valid) + + val sIdx = io.uop.idx.bits % numUop.U + val rIdx = io.uop.idx.bits >> log2Ceil(numUop) + val memRead = mem.read(rIdx, io.uop.idx.valid) + val sWord = memRead.asUInt.asTypeOf(wdata) + val sUop = sWord(sIdx).asTypeOf(io.uop.data.bits) + + io.uop.data.bits <> sUop + + // done + io.done := state === sReadData & io.vme_rd.data.valid & xcnt === xlen & xrem === 0.U + + // debug + if (debug) { + when (io.vme_rd.cmd.fire()) { + printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem) + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/Semaphore.scala b/vta/hardware/chisel/src/main/scala/core/Semaphore.scala new file mode 100644 index 000000000000..06df51e20e27 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Semaphore.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ + +/** Semaphore. + * + * This semaphore is used instead of push/pop fifo, used in the initial + * version of VTA. This semaphore is incremented (spost) or decremented (swait) + * depending on the push and pop fields on instructions to prevent RAW and WAR + * hazards. + */ +class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module { + val io = IO(new Bundle { + val spost = Input(Bool()) + val swait = Input(Bool()) + val sready = Output(Bool()) + }) + val cnt = RegInit(counterInitValue.U(counterBits.W)) + when (io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) { cnt := cnt + 1.U } + when (!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U } + io.sready := cnt =/= 0.U +} diff --git a/vta/hardware/chisel/src/main/scala/core/Store.scala b/vta/hardware/chisel/src/main/scala/core/Store.scala new file mode 100644 index 000000000000..5d89871e65be --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Store.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** Store. + * + * Store results back to memory (DRAM) from scratchpads (SRAMs). + * This module instantiate the TensorStore unit which is in charge + * of storing 1D and 2D tensors to main memory. + */ +class Store(debug: Boolean = false)(implicit p: Parameters) extends Module { + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val i_post = Input(Bool()) + val o_post = Output(Bool()) + val inst = Flipped(Decoupled(UInt(INST_BITS.W))) + val out_baddr = Input(UInt(mp.addrBits.W)) + val vme_wr = new VMEWriteMaster + val out = new TensorClient(tensorType = "out") + }) + val sIdle :: sSync :: sExe :: Nil = Enum(3) + val state = RegInit(sIdle) + + val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0)) + val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries)) + + val dec = Module(new StoreDecode) + dec.io.inst := inst_q.io.deq.bits + + val tensorStore = Module(new TensorStore(tensorType = "out")) + + val start = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s.io.sready, true.B) + val done = tensorStore.io.done + + // control + switch (state) { + is (sIdle) { + when (start) { + when (dec.io.isSync) { + state := sSync + } .elsewhen (dec.io.isStore) { + state := sExe + } + } + } + is (sSync) { + state := sIdle + } + is (sExe) { + when (done) { + state := sIdle + } + } + } + + // instructions + inst_q.io.enq <> io.inst + inst_q.io.deq.ready := (state === sExe & done) | (state === sSync) + + // store + tensorStore.io.start := state === sIdle & start & dec.io.isStore + tensorStore.io.inst := inst_q.io.deq.bits + tensorStore.io.baddr := io.out_baddr + io.vme_wr <> tensorStore.io.vme_wr + tensorStore.io.tensor <> io.out + + // semaphore + s.io.spost := io.i_post + s.io.swait := dec.io.pop_prev & (state === sIdle & start) + io.o_post := dec.io.push_prev & ((state === sExe & done) | (state === sSync)) + + // debug + if (debug) { + // start + when (state === sIdle && start) { + when (dec.io.isSync) { + printf("[Store] start sync\n") + } .elsewhen (dec.io.isStore) { + printf("[Store] start\n") + } + } + // done + when (state === sSync) { + printf("[Store] done sync\n") + } + when (state === sExe) { + when (done) { + printf("[Store] done\n") + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala b/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala new file mode 100644 index 000000000000..7f429be7249f --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ + +/** ALU datapath */ +class Alu(implicit p: Parameters) extends Module { + val aluBits = p(CoreKey).accBits + val io = IO(new Bundle { + val opcode = Input(UInt(C_ALU_OP_BITS.W)) + val a = Input(SInt(aluBits.W)) + val b = Input(SInt(aluBits.W)) + val y = Output(SInt(aluBits.W)) + }) + + // FIXME: the following three will change once we support properly SHR and SHL + val ub = io.b.asUInt + val width = log2Ceil(aluBits) + val m = ~ub(width - 1, 0) + 1.U + + val n = ub(width - 1, 0) + val fop = Seq(Mux(io.a < io.b, io.a, io.b), + Mux(io.a < io.b, io.b, io.a), + io.a + io.b, + io.a >> n, + io.a << m) + + val opmux = Seq.tabulate(ALU_OP_NUM)(i => ALU_OP(i) -> fop(i)) + io.y := MuxLookup(io.opcode, io.a, opmux) +} + +/** Pipelined ALU */ +class AluReg(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val opcode = Input(UInt(C_ALU_OP_BITS.W)) + val a = Flipped(ValidIO(UInt(p(CoreKey).accBits.W))) + val b = Flipped(ValidIO(UInt(p(CoreKey).accBits.W))) + val y = ValidIO(UInt(p(CoreKey).accBits.W)) + }) + val alu = Module(new Alu) + val rA = RegEnable(io.a.bits, io.a.valid) + val rB = RegEnable(io.b.bits, io.b.valid) + val valid = RegNext(io.b.valid) + + alu.io.opcode := io.opcode + + // register input + alu.io.a := rA.asSInt + alu.io.b := rB.asSInt + + // output + io.y.valid := valid + io.y.bits := alu.io.y.asUInt +} + +/** Vector of pipeline ALUs */ +class AluVector(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val opcode = Input(UInt(C_ALU_OP_BITS.W)) + val acc_a = new TensorMasterData(tensorType = "acc") + val acc_b = new TensorMasterData(tensorType = "acc") + val acc_y = new TensorClientData(tensorType = "acc") + val out = new TensorClientData(tensorType = "out") + }) + val blockOut = p(CoreKey).blockOut + val f = Seq.fill(blockOut)(Module(new AluReg)) + val valid = Wire(Vec(blockOut, Bool())) + for (i <- 0 until blockOut) { + f(i).io.opcode := io.opcode + f(i).io.a.valid := io.acc_a.data.valid + f(i).io.a.bits := io.acc_a.data.bits(0)(i) + f(i).io.b.valid := io.acc_b.data.valid + f(i).io.b.bits := io.acc_b.data.bits(0)(i) + valid(i) := f(i).io.y.valid + io.acc_y.data.bits(0)(i) := f(i).io.y.bits + io.out.data.bits(0)(i) := f(i).io.y.bits + } + io.acc_y.data.valid := valid.asUInt.andR + io.out.data.valid := valid.asUInt.andR +} + +/** TensorAlu. + * + * This unit instantiate the ALU vector unit (AluVector) and go over the + * micro-ops (uops) which are used to read the source operands (vectors) + * from the acc-scratchpad and then they are written back the same + * acc-scratchpad. + */ +class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val uop = new UopMaster + val acc = new TensorMaster(tensorType = "acc") + val out = new TensorMaster(tensorType = "out") + }) + val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = Enum(6) + val state = RegInit(sIdle) + val alu = Module(new AluVector) + val dec = io.inst.asTypeOf(new AluDecode) + val uop_idx = Reg(chiselTypeOf(dec.uop_end)) + val uop_end = dec.uop_end + val uop_dst = Reg(chiselTypeOf(dec.uop_end)) + val uop_src = Reg(chiselTypeOf(dec.uop_end)) + val cnt_o = Reg(chiselTypeOf(dec.lp_0)) + val dst_o = Reg(chiselTypeOf(dec.uop_end)) + val src_o = Reg(chiselTypeOf(dec.uop_end)) + val cnt_i = Reg(chiselTypeOf(dec.lp_1)) + val dst_i = Reg(chiselTypeOf(dec.uop_end)) + val src_i = Reg(chiselTypeOf(dec.uop_end)) + val done = + state === sExe & + alu.io.out.data.valid & + (cnt_o === dec.lp_0 - 1.U) & + (cnt_i === dec.lp_1 - 1.U) & + (uop_idx === uop_end - 1.U) + + switch (state) { + is (sIdle) { + when (io.start) { + state := sReadUop + } + } + is (sReadUop) { + state := sComputeIdx + } + is (sComputeIdx) { + state := sReadTensorA + } + is (sReadTensorA) { + state := sReadTensorB + } + is (sReadTensorB) { + state := sExe + } + is (sExe) { + when (alu.io.out.data.valid) { + when ((cnt_o === dec.lp_0 - 1.U) && + (cnt_i === dec.lp_1 - 1.U) && + (uop_idx === uop_end - 1.U)) { + state := sIdle + } .otherwise { + state := sReadUop + } + } + } + } + + when (state === sIdle || + (state === sExe && + alu.io.out.data.valid && + uop_idx === uop_end - 1.U)) { + uop_idx := dec.uop_begin + } .elsewhen (state === sExe && alu.io.out.data.valid) { + uop_idx := uop_idx + 1.U + } + + when (state === sIdle) { + cnt_o := 0.U + dst_o := 0.U + src_o := 0.U + } .elsewhen (state === sExe && + alu.io.out.data.valid && + uop_idx === uop_end - 1.U && + cnt_i === dec.lp_1 - 1.U) { + cnt_o := cnt_o + 1.U + dst_o := dst_o + dec.dst_0 + src_o := src_o + dec.src_0 + } + + when (state === sIdle) { + cnt_i := 0.U + dst_i := 0.U + src_i := 0.U + } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) { + cnt_i := 0.U + dst_i := dst_o + src_i := src_o + } .elsewhen (state === sExe && + alu.io.out.data.valid && + uop_idx === uop_end - 1.U) { + cnt_i := cnt_i + 1.U + dst_i := dst_i + dec.dst_1 + src_i := src_i + dec.src_1 + } + + when (state === sComputeIdx && io.uop.data.valid) { + uop_dst := io.uop.data.bits.u0 + dst_i + uop_src := io.uop.data.bits.u1 + src_i + } + + // uop + io.uop.idx.valid := state === sReadUop + io.uop.idx.bits := uop_idx + + // acc_i + io.acc.rd.idx.valid := state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm) + io.acc.rd.idx.bits := Mux(state === sReadTensorA, uop_dst, uop_src) + + // imm + val tensorImm = Wire(new TensorClientData(tensorType = "acc")) + tensorImm.data.valid := state === sReadTensorB + tensorImm.data.bits.foreach { b => b.foreach { c => c := dec.alu_imm } } + + // alu + val isSHR = dec.alu_op === ALU_OP(3) + val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS-1) + val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op)) + alu.io.opcode := fixme_alu_op + alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB + alu.io.acc_a.data.bits <> io.acc.rd.data.bits + alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, tensorImm.data.valid, io.acc.rd.data.valid & state === sExe) + alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, tensorImm.data.bits, io.acc.rd.data.bits) + + // acc_o + io.acc.wr.valid := alu.io.acc_y.data.valid + io.acc.wr.bits.idx := uop_dst + io.acc.wr.bits.data <> alu.io.acc_y.data.bits + + // out + io.out.wr.valid := alu.io.out.data.valid + io.out.wr.bits.idx := uop_dst + io.out.wr.bits.data <> alu.io.out.data.bits + io.out.tieoffRead() // write-only + + io.done := done + + if (debug) { + + when (state === sReadUop) { + printf("[TensorAlu] [uop] idx:%x\n", uop_idx) + } + + when (state === sReadTensorA) { + printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src) + } + + when (state === sIdle && io.start) { + printf(p"[TensorAlu] decode:$dec\n") + } + + alu.io.acc_a.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (alu.io.acc_a.data.valid) { + printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem) + } + } + } + + alu.io.acc_b.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (alu.io.acc_b.data.valid) { + printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem) + } + } + } + + alu.io.acc_y.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (alu.io.acc_y.data.valid) { + printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem) + } + } + } + + alu.io.out.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (alu.io.out.data.valid) { + printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem) + } + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala new file mode 100644 index 000000000000..2dd8c33aea33 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import chisel3.experimental._ +import vta.util.config._ +import scala.math.pow + +/** Pipelined multiply and accumulate */ +class MAC(dataBits: Int = 8, cBits: Int = 16, outBits: Int = 17) extends Module { + require (cBits >= dataBits * 2) + require (outBits >= dataBits * 2) + val io = IO(new Bundle { + val a = Input(SInt(dataBits.W)) + val b = Input(SInt(dataBits.W)) + val c = Input(SInt(cBits.W)) + val y = Output(SInt(outBits.W)) + }) + val mult = Wire(SInt(cBits.W)) + val add = Wire(SInt(outBits.W)) + val rA = RegNext(io.a) + val rB = RegNext(io.b) + val rC = RegNext(io.c) + mult := rA * rB + add := rC + mult + io.y := add +} + +/** Pipelined adder */ +class Adder(dataBits: Int = 8, outBits: Int = 17) extends Module { + require (outBits >= dataBits) + val io = IO(new Bundle { + val a = Input(SInt(dataBits.W)) + val b = Input(SInt(dataBits.W)) + val y = Output(SInt(outBits.W)) + }) + val add = Wire(SInt(outBits.W)) + val rA = RegNext(io.a) + val rB = RegNext(io.b) + add := rA + rB + io.y := add +} + +/** Pipelined DotProduct based on MAC and Adder */ +class DotProduct(dataBits: Int = 8, size: Int = 16) extends Module { + val errMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n" + require(size >= 4 && isPow2(size), errMsg) + val b = dataBits * 2 + val outBits = b + log2Ceil(size) + 1 + val io = IO(new Bundle { + val a = Input(Vec(size, SInt(dataBits.W))) + val b = Input(Vec(size, SInt(dataBits.W))) + val y = Output(SInt(outBits.W)) + }) + val p = log2Ceil(size/2) + val s = Seq.tabulate(log2Ceil(size))(i => pow(2, p - i).toInt) + val da = Seq.tabulate(s(0))(i => RegNext(io.a(s(0) + i))) + val db = Seq.tabulate(s(0))(i => RegNext(io.b(s(0) + i))) + val m = Seq.tabulate(2)(i => + Seq.fill(s(0))(Module(new MAC(dataBits = dataBits, cBits = b + i, outBits = b + i + 1))) + ) + val a = Seq.tabulate(p)(i => + Seq.fill(s(i + 1))(Module(new Adder(dataBits = b + i + 2, outBits = b + i + 3))) + ) + + for (i <- 0 until log2Ceil(size)) { + for (j <- 0 until s(i)) { + if (i == 0) { + m(i)(j).io.a := io.a(j) + m(i)(j).io.b := io.b(j) + m(i)(j).io.c := 0.S + m(i + 1)(j).io.a := da(j) + m(i + 1)(j).io.b := db(j) + m(i + 1)(j).io.c := m(i)(j).io.y + } else if (i == 1) { + a(i - 1)(j).io.a := m(i)(2*j).io.y + a(i - 1)(j).io.b := m(i)(2*j + 1).io.y + } else { + a(i - 1)(j).io.a := a(i - 2)(2*j).io.y + a(i - 1)(j).io.b := a(i - 2)(2*j + 1).io.y + } + } + } + io.y := a(p-1)(0).io.y +} + +/** Perform matric-vector-multiplication based on DotProduct */ +class MatrixVectorCore(implicit p: Parameters) extends Module { + val accBits = p(CoreKey).accBits + val size = p(CoreKey).blockOut + val dataBits = p(CoreKey).inpBits + val io = IO(new Bundle{ + val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr + val inp = new TensorMasterData(tensorType = "inp") + val wgt = new TensorMasterData(tensorType = "wgt") + val acc_i = new TensorMasterData(tensorType = "acc") + val acc_o = new TensorClientData(tensorType = "acc") + val out = new TensorClientData(tensorType = "out") + }) + val dot = Seq.fill(size)(Module(new DotProduct(dataBits, size))) + val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1))) + val add = Seq.fill(size)(Wire(SInt(accBits.W))) + val vld = Wire(Vec(size, Bool())) + + for (i <- 0 until size) { + acc(i).io.enq.valid := io.inp.data.valid & io.wgt.data.valid & io.acc_i.data.valid & ~io.reset + acc(i).io.enq.bits := io.acc_i.data.bits(0)(i) + for (j <- 0 until size) { + dot(i).io.a(j) := io.inp.data.bits(0)(j).asSInt + dot(i).io.b(j) := io.wgt.data.bits(i)(j).asSInt + } + add(i) := acc(i).io.deq.bits.asSInt + dot(i).io.y + io.acc_o.data.bits(0)(i) := Mux(io.reset, 0.U, add(i).asUInt) + io.out.data.bits(0)(i) := add(i).asUInt + vld(i) := acc(i).io.deq.valid + } + io.acc_o.data.valid := vld.asUInt.andR | io.reset + io.out.data.valid := vld.asUInt.andR +} + +/** TensorGemm. + * + * This unit instantiate the MatrixVectorCore and go over the + * micro-ops (uops) which are used to read inputs, weights and biases, + * and writes results back to the acc and out scratchpads. + * + * Also, the TensorGemm uses the reset field in the Gemm instruction to + * clear or zero-out the acc-scratchpad locations based on the micro-ops. + */ +class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val uop = new UopMaster + val inp = new TensorMaster(tensorType = "inp") + val wgt = new TensorMaster(tensorType = "wgt") + val acc = new TensorMaster(tensorType = "acc") + val out = new TensorMaster(tensorType = "out") + }) + val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6) + val state = RegInit(sIdle) + val mvc = Module(new MatrixVectorCore) + val dec = io.inst.asTypeOf(new GemmDecode) + val uop_idx = Reg(chiselTypeOf(dec.uop_end)) + val uop_end = dec.uop_end + val uop_acc = Reg(chiselTypeOf(dec.uop_end)) + val uop_inp = Reg(chiselTypeOf(dec.uop_end)) + val uop_wgt = Reg(chiselTypeOf(dec.uop_end)) + val cnt_o = Reg(chiselTypeOf(dec.lp_0)) + val acc_o = Reg(chiselTypeOf(dec.uop_end)) + val inp_o = Reg(chiselTypeOf(dec.uop_end)) + val wgt_o = Reg(chiselTypeOf(dec.uop_end)) + val cnt_i = Reg(chiselTypeOf(dec.lp_1)) + val acc_i = Reg(chiselTypeOf(dec.uop_end)) + val inp_i = Reg(chiselTypeOf(dec.uop_end)) + val wgt_i = Reg(chiselTypeOf(dec.uop_end)) + val pBits = log2Ceil(p(CoreKey).blockOut) + 1 + val inflight = Reg(UInt(pBits.W)) + val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits)) + val done = inflight === 0.U & + ((state === sExe & + cnt_o === dec.lp_0 - 1.U & + cnt_i === dec.lp_1 - 1.U & + uop_idx === uop_end - 1.U & + inflight === 0.U) | + state === sWait) + + switch (state) { + is (sIdle) { + when (io.start) { + state := sReadUop + } + } + is (sReadUop) { + state := sComputeIdx + } + is (sComputeIdx) { + state := sReadTensor + } + is (sReadTensor) { + state := sExe + } + is (sExe) { + when ((cnt_o === dec.lp_0 - 1.U) && + (cnt_i === dec.lp_1 - 1.U) && + (uop_idx === uop_end - 1.U)) { + when (inflight =/= 0.U) { + state := sWait + } .otherwise { + state := sIdle + } + } .otherwise { + state := sReadUop + } + } + is (sWait) { + when (inflight === 0.U) { + state := sIdle + } + } + } + + when (state === sIdle) { + inflight := 0.U + } .elsewhen (!dec.reset) { + when (state === sExe && inflight =/= ((1 << pBits) - 1).asUInt) { // overflow check + inflight := inflight + 1.U + } .elsewhen (mvc.io.acc_o.data.valid && inflight =/= 0.U) { // underflow check + inflight := inflight - 1.U + } + } + + when (state === sIdle || + (state === sExe && + uop_idx === uop_end - 1.U)) { + uop_idx := dec.uop_begin + } .elsewhen (state === sExe) { + uop_idx := uop_idx + 1.U + } + + when (state === sIdle) { + cnt_o := 0.U + acc_o := 0.U + inp_o := 0.U + wgt_o := 0.U + } .elsewhen (state === sExe && + uop_idx === uop_end - 1.U && + cnt_i === dec.lp_1 - 1.U) { + cnt_o := cnt_o + 1.U + acc_o := acc_o + dec.acc_0 + inp_o := inp_o + dec.inp_0 + wgt_o := wgt_o + dec.wgt_0 + } + + when (state === sIdle) { + cnt_i := 0.U + acc_i := 0.U + inp_i := 0.U + wgt_i := 0.U + } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) { + cnt_i := 0.U + acc_i := acc_o + inp_i := inp_o + wgt_i := wgt_o + } .elsewhen (state === sExe && + uop_idx === uop_end - 1.U) { + cnt_i := cnt_i + 1.U + acc_i := acc_i + dec.acc_1 + inp_i := inp_i + dec.inp_1 + wgt_i := wgt_i + dec.wgt_1 + } + + when (state === sComputeIdx && io.uop.data.valid) { + uop_acc := io.uop.data.bits.u0 + acc_i + uop_inp := io.uop.data.bits.u1 + inp_i + uop_wgt := io.uop.data.bits.u2 + wgt_i + } + + wrpipe.io.enq.valid := state === sExe & ~dec.reset + wrpipe.io.enq.bits := uop_acc + + // uop + io.uop.idx.valid := state === sReadUop + io.uop.idx.bits := uop_idx + + // inp + io.inp.rd.idx.valid := state === sReadTensor + io.inp.rd.idx.bits := uop_inp + io.inp.tieoffWrite() // read-only + + // wgt + io.wgt.rd.idx.valid := state === sReadTensor + io.wgt.rd.idx.bits := uop_wgt + io.wgt.tieoffWrite() // read-only + + // acc_i + io.acc.rd.idx.valid := state === sReadTensor + io.acc.rd.idx.bits := uop_acc + + // mvc + mvc.io.reset := dec.reset & state === sExe + mvc.io.inp.data <> io.inp.rd.data + mvc.io.wgt.data <> io.wgt.rd.data + mvc.io.acc_i.data <> io.acc.rd.data + + // acc_o + io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, true.B, wrpipe.io.deq.valid) + io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits) + io.acc.wr.bits.data <> mvc.io.acc_o.data.bits + + // out + io.out.wr.valid := mvc.io.out.data.valid & wrpipe.io.deq.valid + io.out.wr.bits.idx := wrpipe.io.deq.bits + io.out.wr.bits.data <> mvc.io.out.data.bits + io.out.tieoffRead() // write-only + + io.done := done + + if (debug) { + when (state === sReadUop && ~dec.reset) { + printf("[TensorGemm] [uop] idx:%x\n", uop_idx) + } + + when (state === sReadTensor && ~dec.reset) { + printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt) + } + + io.inp.rd.data.bits.zipWithIndex.foreach { case(r, i) => + when (io.inp.rd.data.valid && ~dec.reset) { + printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt) + } + } + + io.wgt.rd.data.bits.zipWithIndex.foreach { case(r, i) => + when (io.wgt.rd.data.valid && ~dec.reset) { + printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt) + } + } + + io.acc.rd.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (io.acc.rd.data.valid && ~dec.reset) { + printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem) + } + } + } + + mvc.io.acc_o.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (mvc.io.acc_o.data.valid && ~dec.reset) { + printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem) + } + } + } + + mvc.io.out.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (mvc.io.out.data.valid && ~dec.reset) { + printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem) + } + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala b/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala new file mode 100644 index 000000000000..d96a681e7d69 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** TensorStore. + * + * Load 1D and 2D tensors from main memory (DRAM) to input/weight + * scratchpads (SRAM). Also, there is support for zero padding, while + * doing the load. Zero-padding works on the y and x axis, and it is + * managed by TensorPadCtrl. The TensorDataCtrl is in charge of + * handling the way tensors are stored on the scratchpads. + */ +class TensorLoad(tensorType: String = "none", debug: Boolean = false) + (implicit p: Parameters) extends Module { + val tp = new TensorParams(tensorType) + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val baddr = Input(UInt(mp.addrBits.W)) + val vme_rd = new VMEReadMaster + val tensor = new TensorClient(tensorType) + }) + val sizeFactor = tp.tensorLength * tp.numMemBlock + val strideFactor = tp.tensorLength * tp.tensorWidth + + val dec = io.inst.asTypeOf(new MemDecode) + val dataCtrl = Module(new TensorDataCtrl(sizeFactor, strideFactor)) + val dataCtrlDone = RegInit(false.B) + val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor)) + val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor)) + val xPadCtrl0 = Module(new TensorPadCtrl(padType = "XPad0", sizeFactor)) + val xPadCtrl1 = Module(new TensorPadCtrl(padType = "XPad1", sizeFactor)) + + val tag = Reg(UInt(8.W)) + val set = Reg(UInt(8.W)) + + val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil = Enum(7) + val state = RegInit(sIdle) + + // control + switch (state) { + is (sIdle) { + when (io.start) { + when (dec.ypad_0 =/= 0.U) { + state := sYPad0 + } .elsewhen (dec.xpad_0 =/= 0.U) { + state := sXPad0 + } .otherwise { + state := sReadCmd + } + } + } + is (sYPad0) { + when (yPadCtrl0.io.done) { + when (dec.xpad_0 =/= 0.U) { + state := sXPad0 + } .otherwise { + state := sReadCmd + } + } + } + is (sXPad0) { + when (xPadCtrl0.io.done) { + state := sReadCmd + } + } + is (sReadCmd) { + when (io.vme_rd.cmd.ready) { + state := sReadData + } + } + is (sReadData) { + when (io.vme_rd.data.valid) { + when (dataCtrl.io.done) { + when (dec.xpad_1 =/= 0.U) { + state := sXPad1 + } .elsewhen (dec.ypad_1 =/= 0.U) { + state := sYPad1 + } .otherwise { + state := sIdle + } + } .elsewhen (dataCtrl.io.stride || dataCtrl.io.split) { + when (dec.xpad_1 =/= 0.U) { + state := sXPad1 + } .elsewhen (dec.xpad_0 =/= 0.U) { + state := sXPad0 + } .otherwise { + state := sReadCmd + } + } + } + } + is (sXPad1) { + when (xPadCtrl1.io.done) { + when (dataCtrlDone) { + when (dec.ypad_1 =/= 0.U) { + state := sYPad1 + } .otherwise { + state := sIdle + } + } .otherwise { + when (dec.xpad_0 =/= 0.U) { + state := sXPad0 + } .otherwise { + state := sReadCmd + } + } + } + } + is (sYPad1) { + when (yPadCtrl1.io.done && dataCtrlDone) { + state := sIdle + } + } + } + + // data controller + dataCtrl.io.start := state === sIdle & io.start + dataCtrl.io.inst := io.inst + dataCtrl.io.baddr := io.baddr + dataCtrl.io.xinit := io.vme_rd.cmd.fire() + dataCtrl.io.xupdate := io.vme_rd.data.fire() + dataCtrl.io.yupdate := io.vme_rd.data.fire() + + when (state === sIdle) { + dataCtrlDone := false.B + } .elsewhen (io.vme_rd.data.fire() && dataCtrl.io.done) { + dataCtrlDone := true.B + } + + // pad + yPadCtrl0.io.start := dec.ypad_0 =/= 0.U & state === sIdle & io.start + + yPadCtrl1.io.start := dec.ypad_1 =/= 0.U & + ((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) | + (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone)) + + xPadCtrl0.io.start := dec.xpad_0 =/= 0.U & + ((state === sIdle & io.start) | + (state === sYPad0 & yPadCtrl0.io.done) | + (io.vme_rd.data.fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) | + (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone)) + + xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() & + ((dataCtrl.io.done) | + (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U)) + + yPadCtrl0.io.inst := io.inst + yPadCtrl1.io.inst := io.inst + xPadCtrl0.io.inst := io.inst + xPadCtrl1.io.inst := io.inst + + // read-from-dram + io.vme_rd.cmd.valid := state === sReadCmd + io.vme_rd.cmd.bits.addr := dataCtrl.io.addr + io.vme_rd.cmd.bits.len := dataCtrl.io.len + + io.vme_rd.data.ready := state === sReadData + + // write-to-sram + val isZeroPad = state === sYPad0 | + state === sXPad0 | + state === sXPad1 | + state === sYPad1 + + when (state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) { + tag := 0.U + } .elsewhen (io.vme_rd.data.fire() || isZeroPad) { + tag := tag + 1.U + } + + when (state === sIdle || state === sReadCmd || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) { + set := 0.U + } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) { + set := set + 1.U + } + + val waddr_cur = Reg(UInt(tp.memAddrBits.W)) + val waddr_nxt = Reg(UInt(tp.memAddrBits.W)) + when (state === sIdle) { + waddr_cur := dec.sram_offset + waddr_nxt := dec.sram_offset + } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) { + waddr_cur := waddr_cur + 1.U + } .elsewhen (dataCtrl.io.stride) { + waddr_cur := waddr_nxt + dec.xsize + waddr_nxt := waddr_nxt + dec.xsize + } + + val tensorFile = Seq.fill(tp.tensorLength) { SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) } + val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) } + val wdata = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) } + val no_mask = Wire(Vec(tp.numMemBlock, Bool())) + no_mask.foreach { m => m := true.B } + + for (i <- 0 until tp.tensorLength) { + for (j <- 0 until tp.numMemBlock) { + wmask(i)(j) := tag === j.U + wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits) + } + val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i)) + val muxWen = Mux(state === sIdle, io.tensor.wr.valid, (io.vme_rd.data.fire() | isZeroPad) & set === i.U) + val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur) + val muxWdata = Mux(state === sIdle, tdata, wdata(i)) + val muxWmask = Mux(state === sIdle, no_mask, wmask(i)) + when (muxWen) { + tensorFile(i).write(muxWaddr, muxWdata, muxWmask) + } + } + + // read-from-sram + val rvalid = RegNext(io.tensor.rd.idx.valid) + io.tensor.rd.data.valid := rvalid + + val rdata = tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid)) + rdata.zipWithIndex.foreach { case(r, i) => + io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i)) + } + + // done + val done_no_pad = io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U + val done_x_pad = state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone & dec.ypad_1 === 0.U + val done_y_pad = state === sYPad1 & dataCtrlDone & yPadCtrl1.io.done + io.done := done_no_pad | done_x_pad | done_y_pad + + // debug + if (debug) { + if (tensorType == "inp") { + when (io.vme_rd.cmd.fire()) { + printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len) + } + when (state === sYPad0) { + printf("[TensorLoad] [inp] sYPad0\n") + } + when (state === sYPad1) { + printf("[TensorLoad] [inp] sYPad1\n") + } + when (state === sXPad0) { + printf("[TensorLoad] [inp] sXPad0\n") + } + when (state === sXPad1) { + printf("[TensorLoad] [inp] sXPad1\n") + } + } else if (tensorType == "wgt") { + when (io.vme_rd.cmd.fire()) { + printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len) + } + } else if (tensorType == "acc") { + when (io.vme_rd.cmd.fire()) { + printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len) + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/TensorStore.scala b/vta/hardware/chisel/src/main/scala/core/TensorStore.scala new file mode 100644 index 000000000000..0012e4771c0e --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/TensorStore.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** TensorStore. + * + * Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM). + */ +class TensorStore(tensorType: String = "true", debug: Boolean = false) + (implicit p: Parameters) extends Module { + val tp = new TensorParams(tensorType) + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val baddr = Input(UInt(mp.addrBits.W)) + val vme_wr = new VMEWriteMaster + val tensor = new TensorClient(tensorType) + }) + val tensorLength = tp.tensorLength + val tensorWidth = tp.tensorWidth + val tensorElemBits = tp.tensorElemBits + val memBlockBits = tp.memBlockBits + val memDepth = tp.memDepth + val numMemBlock = tp.numMemBlock + + val dec = io.inst.asTypeOf(new MemDecode) + val waddr_cur = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr)) + val waddr_nxt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr)) + val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len)) + val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len)) + val xrem = Reg(chiselTypeOf(dec.xsize)) + val xsize = (dec.xsize << log2Ceil(tensorLength*numMemBlock)) - 1.U + val xmax = (1 << mp.lenBits).U + val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + val ycnt = Reg(chiselTypeOf(dec.ysize)) + val ysize = dec.ysize + val tag = Reg(UInt(8.W)) + val set = Reg(UInt(8.W)) + + val sIdle :: sWriteCmd :: sWriteData :: sReadMem :: sWriteAck :: Nil = Enum(5) + val state = RegInit(sIdle) + + // control + switch (state) { + is (sIdle) { + when (io.start) { + state := sWriteCmd + when (xsize < xmax) { + xlen := xsize + xrem := 0.U + } .otherwise { + xlen := xmax - 1.U + xrem := xsize - xmax + } + } + } + is (sWriteCmd) { + when (io.vme_wr.cmd.ready) { + state := sWriteData + } + } + is (sWriteData) { + when (io.vme_wr.data.ready) { + when (xcnt === xlen) { + state := sWriteAck + } .elsewhen (tag === (numMemBlock - 1).U) { + state := sReadMem + } + } + } + is (sReadMem) { + state := sWriteData + } + is (sWriteAck) { + when (io.vme_wr.ack) { + when (xrem === 0.U) { + when (ycnt === ysize - 1.U) { + state := sIdle + } .otherwise { + state := sWriteCmd + when (xsize < xmax) { + xlen := xsize + xrem := 0.U + } .otherwise { + xlen := xmax - 1.U + xrem := xsize - xmax + } + } + } .elsewhen (xrem < xmax) { + state := sWriteCmd + xlen := xrem + xrem := 0.U + } .otherwise { + state := sWriteCmd + xlen := xmax - 1.U + xrem := xrem - xmax + } + } + } + } + + // write-to-sram + val tensorFile = Seq.fill(tensorLength) { SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) } + val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W))) + val no_mask = Wire(Vec(numMemBlock, Bool())) + + wdata_t := DontCare + no_mask.foreach { m => m := true.B } + + for (i <- 0 until tensorLength) { + val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t) + when (io.tensor.wr.valid) { + tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask) + } + } + + // read-from-sram + val stride = state === sWriteAck & + io.vme_wr.ack & + xcnt === xlen + 1.U & + xrem === 0.U & + ycnt =/= ysize - 1.U + + when (state === sIdle) { + ycnt := 0.U + } .elsewhen (stride) { + ycnt := ycnt + 1.U + } + + when (state === sWriteCmd || tag === (numMemBlock - 1).U) { + tag := 0.U + } .elsewhen (io.vme_wr.data.fire()) { + tag := tag + 1.U + } + + when (state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) { + set := 0.U + } .elsewhen (io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) { + set := set + 1.U + } + + val raddr_cur = Reg(UInt(tp.memAddrBits.W)) + val raddr_nxt = Reg(UInt(tp.memAddrBits.W)) + when (state === sIdle) { + raddr_cur := dec.sram_offset + raddr_nxt := dec.sram_offset + } .elsewhen (io.vme_wr.data.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) { + raddr_cur := raddr_cur + 1.U + } .elsewhen (stride) { + raddr_cur := raddr_nxt + dec.xsize + raddr_nxt := raddr_nxt + dec.xsize + } + + val tread = Seq.tabulate(tensorLength) { i => i.U -> + tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) } + val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread) + + // write-to-dram + when (state === sIdle) { + waddr_cur := io.baddr + dec.dram_offset + waddr_nxt := io.baddr + dec.dram_offset + } .elsewhen (state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) { + waddr_cur := waddr_cur + xmax_bytes + } .elsewhen (stride) { + waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth)) + waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth)) + } + + io.vme_wr.cmd.valid := state === sWriteCmd + io.vme_wr.cmd.bits.addr := waddr_cur + io.vme_wr.cmd.bits.len := xlen + + io.vme_wr.data.valid := state === sWriteData + io.vme_wr.data.bits := mdata(tag) + + when (state === sWriteCmd) { + xcnt := 0.U + } .elsewhen (io.vme_wr.data.fire()) { + xcnt := xcnt + 1.U + } + + // disable external read-from-sram requests + io.tensor.tieoffRead() + + // done + io.done := state === sWriteAck & io.vme_wr.ack & xrem === 0.U & ycnt === ysize - 1.U + + // debug + if (debug) { + when (io.vme_wr.cmd.fire()) { + printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem) + } + when (io.vme_wr.data.fire()) { + printf("[TensorStore] data:%x\n", io.vme_wr.data.bits) + } + when (io.vme_wr.ack) { + printf("[TensorStore] ack\n") + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala b/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala new file mode 100644 index 000000000000..e41a2c5b18e9 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** TensorParams. + * + * This Bundle derives parameters for each tensorType, including inputs (inp), + * weights (wgt), biases (acc), and outputs (out). This is used to avoid + * doing the same boring calculations over and over again. + */ +class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle { + val errorMsg = s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n" + + require (tensorType == "inp" || tensorType == "wgt" + || tensorType == "acc" || tensorType == "out", errorMsg) + + val (tensorLength, tensorWidth, tensorElemBits) = + if (tensorType == "inp") + (p(CoreKey).batch, p(CoreKey).blockIn, p(CoreKey).inpBits) + else if (tensorType == "wgt") + (p(CoreKey).blockOut, p(CoreKey).blockIn, p(CoreKey).wgtBits) + else if (tensorType == "acc") + (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).accBits) + else + (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).outBits) + + val memBlockBits = p(ShellKey).memParams.dataBits + val numMemBlock = (tensorWidth * tensorElemBits) / memBlockBits + + val memDepth = + if (tensorType == "inp") + p(CoreKey).inpMemDepth + else if (tensorType == "wgt") + p(CoreKey).wgtMemDepth + else if (tensorType == "acc") + p(CoreKey).accMemDepth + else + p(CoreKey).outMemDepth + + val memAddrBits = log2Ceil(memDepth) +} + +/** TensorMaster. + * + * This interface issue read and write tensor-requests to scratchpads. For example, + * The TensorGemm unit uses this interface for managing the inputs (inp), weights (wgt), + * biases (acc), and outputs (out). + * + */ +class TensorMaster(tensorType: String = "none") + (implicit p: Parameters) extends TensorParams(tensorType) { + val rd = new Bundle { + val idx = ValidIO(UInt(memAddrBits.W)) + val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))) + } + val wr = ValidIO(new Bundle { + val idx = UInt(memAddrBits.W) + val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))) + }) + def tieoffRead() { + rd.idx.valid := false.B + rd.idx.bits := 0.U + } + def tieoffWrite() { + wr.valid := false.B + wr.bits.idx := 0.U + wr.bits.data.foreach { b => b.foreach { c => c := 0.U } } + } + override def cloneType = + new TensorMaster(tensorType).asInstanceOf[this.type] +} + +/** TensorClient. + * + * This interface receives read and write tensor-requests to scratchpads. For example, + * The TensorLoad unit uses this interface for receiving read and write requests from + * the TensorGemm unit. + */ +class TensorClient(tensorType: String = "none") + (implicit p: Parameters) extends TensorParams(tensorType) { + val rd = new Bundle { + val idx = Flipped(ValidIO(UInt(memAddrBits.W))) + val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))) + } + val wr = Flipped(ValidIO(new Bundle { + val idx = UInt(memAddrBits.W) + val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))) + })) + def tieoffRead() { + rd.data.valid := false.B + rd.data.bits.foreach { b => b.foreach { c => c := 0.U } } + } + override def cloneType = + new TensorClient(tensorType).asInstanceOf[this.type] +} + +/** TensorMasterData. + * + * This interface is only used for datapath only purposes and the direction convention + * is based on the TensorMaster interface, which means this is an input. This interface + * is used on datapath only module such MatrixVectorCore or AluVector. + */ +class TensorMasterData(tensorType: String = "none") + (implicit p: Parameters) extends TensorParams(tensorType) { + val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))) + override def cloneType = + new TensorMasterData(tensorType).asInstanceOf[this.type] +} + +/** TensorClientData. + * + * This interface is only used for datapath only purposes and the direction convention + * is based on the TensorClient interface, which means this is an output. This interface + * is used on datapath only module such MatrixVectorCore or AluVector. + */ +class TensorClientData(tensorType: String = "none") + (implicit p: Parameters) extends TensorParams(tensorType) { + val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))) + override def cloneType = + new TensorClientData(tensorType).asInstanceOf[this.type] +} + +/** TensorPadCtrl. Zero-padding controller for TensorLoad. */ +class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module { + val errorMsg = s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n" + require (padType == "YPad0" || padType == "YPad1" + || padType == "XPad0" || padType == "XPad1", errorMsg) + + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + }) + + val dec = io.inst.asTypeOf(new MemDecode) + + val xmax = Reg(chiselTypeOf(dec.xsize)) + val ymax = Reg(chiselTypeOf(dec.ypad_0)) + val xcnt = Reg(chiselTypeOf(dec.xsize)) + val ycnt = Reg(chiselTypeOf(dec.ypad_0)) + + val xval = + if (padType == "YPad0" || padType == "YPad1") + ((dec.xpad_0 + dec.xsize + dec.xpad_1) << log2Ceil(sizeFactor)) - 1.U + else if (padType == "XPad0") + (dec.xpad_0 << log2Ceil(sizeFactor)) - 1.U + else + (dec.xpad_1 << log2Ceil(sizeFactor)) - 1.U + + val yval = + if (padType == "YPad0") + Mux(dec.ypad_0 =/= 0.U, dec.ypad_0 - 1.U, 0.U) + else if (padType == "YPad1") + Mux(dec.ypad_1 =/= 0.U, dec.ypad_1 - 1.U, 0.U) + else + 0.U + + val sIdle :: sActive :: Nil = Enum(2) + val state = RegInit(sIdle) + + switch (state) { + is (sIdle) { + when (io.start) { + state := sActive + } + } + is (sActive) { + when (ycnt === ymax && xcnt === xmax) { + state := sIdle + } + } + } + + when (state === sIdle) { + xmax := xval + ymax := yval + } + + when (state === sIdle || xcnt === xmax) { + xcnt := 0.U + } .elsewhen (state === sActive) { + xcnt := xcnt + 1.U + } + + when (state === sIdle || ymax === 0.U) { + ycnt := 0.U + } .elsewhen (state === sActive && xcnt === xmax) { + ycnt := ycnt + 1.U + } + + io.done := state === sActive & ycnt === ymax & xcnt === xmax +} + +/** TensorDataCtrl. Data controller for TensorLoad. */ +class TensorDataCtrl(sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module { + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val baddr = Input(UInt(mp.addrBits.W)) + val xinit = Input(Bool()) + val xupdate = Input(Bool()) + val yupdate = Input(Bool()) + val stride = Output(Bool()) + val split = Output(Bool()) + val commit = Output(Bool()) + val addr = Output(UInt(mp.addrBits.W)) + val len = Output(UInt(mp.lenBits.W)) + }) + + val dec = io.inst.asTypeOf(new MemDecode) + + val caddr = Reg(UInt(mp.addrBits.W)) + val baddr = Reg(UInt(mp.addrBits.W)) + + val len = Reg(UInt(mp.lenBits.W)) + + val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + val xcnt = Reg(UInt(mp.lenBits.W)) + val xrem = Reg(chiselTypeOf(dec.xsize)) + val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U + val xmax = (1 << mp.lenBits).U + val ycnt = Reg(chiselTypeOf(dec.ysize)) + + val stride = xcnt === len & + xrem === 0.U & + ycnt =/= dec.ysize - 1.U + + val split = xcnt === len & xrem =/= 0.U + + when (io.start || (io.xupdate && stride)) { + when (xsize < xmax) { + len := xsize + xrem := 0.U + } .otherwise { + len := xmax - 1.U + xrem := xsize - xmax + } + } .elsewhen (io.xupdate && split) { + when (xrem < xmax) { + len := xrem + xrem := 0.U + } .otherwise { + len := xmax - 1.U + xrem := xrem - xmax + } + } + + when (io.xinit) { + xcnt := 0.U + } .elsewhen (io.xupdate) { + xcnt := xcnt + 1.U + } + + when (io.start) { + ycnt := 0.U + } .elsewhen (io.yupdate && stride) { + ycnt := ycnt + 1.U + } + + when (io.start) { + caddr := io.baddr + dec.dram_offset + baddr := io.baddr + dec.dram_offset + } .elsewhen (io.yupdate) { + when (split) { + caddr := caddr + xmax_bytes + } .elsewhen (stride) { + caddr := baddr + (dec.xstride << log2Ceil(strideFactor)) + baddr := baddr + (dec.xstride << log2Ceil(strideFactor)) + } + } + + io.stride := stride + io.split := split + io.commit := xcnt === len + io.addr := caddr + io.len := len + io.done := xcnt === len & + xrem === 0.U & + ycnt === dec.ysize - 1.U +} diff --git a/vta/hardware/chisel/src/main/scala/core/package.scala b/vta/hardware/chisel/src/main/scala/core/package.scala new file mode 100644 index 000000000000..673d390901de --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/package.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta + +/** This trick makes ISAConstants globally available */ +package object core extends vta.core.ISAConstants diff --git a/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala b/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala index aab2d630c307..115bcbcb5a93 100644 --- a/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala +++ b/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala @@ -21,6 +21,9 @@ package vta.dpi import chisel3._ import chisel3.util._ +import vta.util.config._ +import vta.interface.axi._ +import vta.shell._ /** Host DPI parameters */ trait VTAHostDPIParams { @@ -70,3 +73,83 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource { }) setResource("/verilog/VTAHostDPI.v") } + +/** Host DPI to AXI Converter. + * + * Convert Host DPI to AXI for VTAShell + */ + +class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val dpi = new VTAHostDPIClient + val axi = new AXILiteMaster(p(ShellKey).hostParams) + }) + val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr))) + val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value))) + val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6) + val state = RegInit(sIdle) + + switch (state) { + is (sIdle) { + when (io.dpi.req.valid) { + when (io.dpi.req.opcode) { + state := sWriteAddress + } .otherwise { + state := sReadAddress + } + } + } + is (sReadAddress) { + when (io.axi.ar.ready) { + state := sReadData + } + } + is (sReadData) { + when (io.axi.r.valid) { + state := sIdle + } + } + is (sWriteAddress) { + when (io.axi.aw.ready) { + state := sWriteData + } + } + is (sWriteData) { + when (io.axi.w.ready) { + state := sWriteResponse + } + } + is (sWriteResponse) { + when (io.axi.b.valid) { + state := sIdle + } + } + } + + when (state === sIdle && io.dpi.req.valid) { + addr := io.dpi.req.addr + data := io.dpi.req.value + } + + io.axi.aw.valid := state === sWriteAddress + io.axi.aw.bits.addr := addr + io.axi.w.valid := state === sWriteData + io.axi.w.bits.data := data + io.axi.w.bits.strb := "h_f".U + io.axi.b.ready := state === sWriteResponse + + io.axi.ar.valid := state === sReadAddress + io.axi.ar.bits.addr := addr + io.axi.r.ready := state === sReadData + + io.dpi.req.deq := (state === sReadAddress & io.axi.ar.ready) | (state === sWriteAddress & io.axi.aw.ready) + io.dpi.resp.valid := io.axi.r.valid + io.dpi.resp.bits := io.axi.r.bits.data + + if (debug) { + when (state === sWriteAddress && io.axi.aw.ready) { printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr) } + when (state === sReadAddress && io.axi.ar.ready) { printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr) } + when (io.axi.r.fire()) { printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data) } + when (io.axi.w.fire()) { printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data) } + } +} diff --git a/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala b/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala index 090f0459570a..5e2fa741d72a 100644 --- a/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala +++ b/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala @@ -21,6 +21,9 @@ package vta.dpi import chisel3._ import chisel3.util._ +import vta.util.config._ +import vta.interface.axi._ +import vta.shell._ /** Memory DPI parameters */ trait VTAMemDPIParams { @@ -71,3 +74,98 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource { }) setResource("/verilog/VTAMemDPI.v") } + +class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val dpi = new VTAMemDPIMaster + val axi = new AXIClient(p(ShellKey).memParams) + }) + val opcode = RegInit(false.B) + val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len))) + val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr))) + val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6) + val state = RegInit(sIdle) + + switch (state) { + is (sIdle) { + when (io.axi.ar.valid) { + state := sReadAddress + } .elsewhen (io.axi.aw.valid) { + state := sWriteAddress + } + } + is (sReadAddress) { + when (io.axi.ar.valid) { + state := sReadData + } + } + is (sReadData) { + when (io.axi.r.ready && io.dpi.rd.valid && len === 0.U) { + state := sIdle + } + } + is (sWriteAddress) { + when (io.axi.aw.valid) { + state := sWriteData + } + } + is (sWriteData) { + when (io.axi.w.valid && io.axi.w.bits.last) { + state := sWriteResponse + } + } + is (sWriteResponse) { + when (io.axi.b.ready) { + state := sIdle + } + } + } + + when (state === sIdle) { + when (io.axi.ar.valid) { + opcode := false.B + len := io.axi.ar.bits.len + addr := io.axi.ar.bits.addr + } .elsewhen (io.axi.aw.valid) { + opcode := true.B + len := io.axi.aw.bits.len + addr := io.axi.aw.bits.addr + } + } .elsewhen (state === sReadData) { + when (io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) { + len := len - 1.U + } + } + + io.dpi.req.valid := (state === sReadAddress & io.axi.ar.valid) | (state === sWriteAddress & io.axi.aw.valid) + io.dpi.req.opcode := opcode + io.dpi.req.len := len + io.dpi.req.addr := addr + + io.axi.ar.ready := state === sReadAddress + io.axi.aw.ready := state === sWriteAddress + + io.axi.r.valid := state === sReadData & io.dpi.rd.valid + io.axi.r.bits.data := io.dpi.rd.bits + io.axi.r.bits.last := len === 0.U + io.axi.r.bits.resp := 0.U + io.axi.r.bits.user := 0.U + io.axi.r.bits.id := 0.U + io.dpi.rd.ready := state === sReadData & io.axi.r.ready + + io.dpi.wr.valid := state === sWriteData & io.axi.w.valid + io.dpi.wr.bits := io.axi.w.bits.data + io.axi.w.ready := state === sWriteData + + io.axi.b.valid := state === sWriteResponse + io.axi.b.bits.resp := 0.U + io.axi.b.bits.user := 0.U + io.axi.b.bits.id := 0.U + + if (debug) { + when (state === sReadAddress && io.axi.ar.valid) { printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len) } + when (state === sWriteAddress && io.axi.aw.valid) { printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len) } + when (io.axi.r.fire()) { printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", io.axi.r.bits.last, io.axi.r.bits.data) } + when (io.axi.w.fire()) { printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", io.axi.w.bits.last, io.axi.w.bits.data) } + } +} diff --git a/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala b/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala new file mode 100644 index 000000000000..a853e85e2bd8 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.interface.axi + +import chisel3._ +import chisel3.util._ +import vta.util.genericbundle._ + +case class AXIParams( + addrBits: Int = 32, + dataBits: Int = 64 +) +{ + require (addrBits > 0) + require (dataBits >= 8 && dataBits % 2 == 0) + + val idBits = 1 + val userBits = 1 + val strbBits = dataBits/8 + val lenBits = 8 + val sizeBits = 3 + val burstBits = 2 + val lockBits = 2 + val cacheBits = 4 + val protBits = 3 + val qosBits = 4 + val regionBits = 4 + val respBits = 2 + val sizeConst = log2Ceil(dataBits/8) + val idConst = 0 + val userConst = 0 + val burstConst = 1 + val lockConst = 0 + val cacheConst = 3 + val protConst = 0 + val qosConst = 0 + val regionConst = 0 +} + +abstract class AXIBase(params: AXIParams) + extends GenericParameterizedBundle(params) + +// AXILite + +class AXILiteAddress(params: AXIParams) extends AXIBase(params) { + val addr = UInt(params.addrBits.W) +} + +class AXILiteWriteData(params: AXIParams) extends AXIBase(params) { + val data = UInt(params.dataBits.W) + val strb = UInt(params.strbBits.W) +} + +class AXILiteWriteResponse(params: AXIParams) extends AXIBase(params) { + val resp = UInt(params.respBits.W) +} + +class AXILiteReadData(params: AXIParams) extends AXIBase(params) { + val data = UInt(params.dataBits.W) + val resp = UInt(params.respBits.W) +} + +class AXILiteMaster(params: AXIParams) extends AXIBase(params) { + val aw = Decoupled(new AXILiteAddress(params)) + val w = Decoupled(new AXILiteWriteData(params)) + val b = Flipped(Decoupled(new AXILiteWriteResponse(params))) + val ar = Decoupled(new AXILiteAddress(params)) + val r = Flipped(Decoupled(new AXILiteReadData(params))) + + def tieoff() { + aw.valid := false.B + aw.bits.addr := 0.U + w.valid := false.B + w.bits.data := 0.U + w.bits.strb := 0.U + b.ready := false.B + ar.valid := false.B + ar.bits.addr := 0.U + r.ready := false.B + } +} + +class AXILiteClient(params: AXIParams) extends AXIBase(params) { + val aw = Flipped(Decoupled(new AXILiteAddress(params))) + val w = Flipped(Decoupled(new AXILiteWriteData(params))) + val b = Decoupled(new AXILiteWriteResponse(params)) + val ar = Flipped(Decoupled(new AXILiteAddress(params))) + val r = Decoupled(new AXILiteReadData(params)) + + def tieoff() { + aw.ready := false.B + w.ready := false.B + b.valid := false.B + b.bits.resp := 0.U + ar.ready := false.B + r.valid := false.B + r.bits.resp := 0.U + r.bits.data := 0.U + } +} + +// AXI extends AXILite + +class AXIAddress(params: AXIParams) extends AXILiteAddress(params) { + val id = UInt(params.idBits.W) + val user = UInt(params.userBits.W) + val len = UInt(params.lenBits.W) + val size = UInt(params.sizeBits.W) + val burst = UInt(params.burstBits.W) + val lock = UInt(params.lockBits.W) + val cache = UInt(params.cacheBits.W) + val prot = UInt(params.protBits.W) + val qos = UInt(params.qosBits.W) + val region = UInt(params.regionBits.W) +} + +class AXIWriteData(params: AXIParams) extends AXILiteWriteData(params) { + val last = Bool() + val id = UInt(params.idBits.W) + val user = UInt(params.userBits.W) +} + +class AXIWriteResponse(params: AXIParams) extends AXILiteWriteResponse(params) { + val id = UInt(params.idBits.W) + val user = UInt(params.userBits.W) +} + +class AXIReadData(params: AXIParams) extends AXILiteReadData(params) { + val last = Bool() + val id = UInt(params.idBits.W) + val user = UInt(params.userBits.W) +} + +class AXIMaster(params: AXIParams) extends AXIBase(params) { + val aw = Decoupled(new AXIAddress(params)) + val w = Decoupled(new AXIWriteData(params)) + val b = Flipped(Decoupled(new AXIWriteResponse(params))) + val ar = Decoupled(new AXIAddress(params)) + val r = Flipped(Decoupled(new AXIReadData(params))) + + def tieoff() { + aw.valid := false.B + aw.bits.addr := 0.U + aw.bits.id := 0.U + aw.bits.user := 0.U + aw.bits.len := 0.U + aw.bits.size := 0.U + aw.bits.burst := 0.U + aw.bits.lock := 0.U + aw.bits.cache := 0.U + aw.bits.prot := 0.U + aw.bits.qos := 0.U + aw.bits.region := 0.U + w.valid := false.B + w.bits.data := 0.U + w.bits.strb := 0.U + w.bits.last := false.B + w.bits.id := 0.U + w.bits.user := 0.U + b.ready := false.B + ar.valid := false.B + ar.bits.addr := 0.U + ar.bits.id := 0.U + ar.bits.user := 0.U + ar.bits.len := 0.U + ar.bits.size := 0.U + ar.bits.burst := 0.U + ar.bits.lock := 0.U + ar.bits.cache := 0.U + ar.bits.prot := 0.U + ar.bits.qos := 0.U + ar.bits.region := 0.U + r.ready := false.B + } + + def setConst() { + aw.bits.user := params.userConst.U + aw.bits.burst := params.burstConst.U + aw.bits.lock := params.lockConst.U + aw.bits.cache := params.cacheConst.U + aw.bits.prot := params.protConst.U + aw.bits.qos := params.qosConst.U + aw.bits.region := params.regionConst.U + aw.bits.size := params.sizeConst.U + aw.bits.id := params.idConst.U + w.bits.id := params.idConst.U + w.bits.user := params.userConst.U + w.bits.strb := Fill(params.strbBits, true.B) + ar.bits.user := params.userConst.U + ar.bits.burst := params.burstConst.U + ar.bits.lock := params.lockConst.U + ar.bits.cache := params.cacheConst.U + ar.bits.prot := params.protConst.U + ar.bits.qos := params.qosConst.U + ar.bits.region := params.regionConst.U + ar.bits.size := params.sizeConst.U + ar.bits.id := params.idConst.U + } +} + +class AXIClient(params: AXIParams) extends AXIBase(params) { + val aw = Flipped(Decoupled(new AXIAddress(params))) + val w = Flipped(Decoupled(new AXIWriteData(params))) + val b = Decoupled(new AXIWriteResponse(params)) + val ar = Flipped(Decoupled(new AXIAddress(params))) + val r = Decoupled(new AXIReadData(params)) + + def tieoff() { + aw.ready := false.B + w.ready := false.B + b.valid := false.B + b.bits.resp := 0.U + b.bits.user := 0.U + b.bits.id := 0.U + ar.ready := false.B + r.valid := false.B + r.bits.resp := 0.U + r.bits.data := 0.U + r.bits.user := 0.U + r.bits.last := false.B + r.bits.id := 0.U + } +} + +// XilinxAXILiteClient and XilinxAXIMaster bundles are needed +// for wrapper purposes, because the package RTL tool in Xilinx Vivado +// only allows certain name formats + +class XilinxAXILiteClient(params: AXIParams) extends AXIBase(params) { + val AWVALID = Input(Bool()) + val AWREADY = Output(Bool()) + val AWADDR = Input(UInt(params.addrBits.W)) + val WVALID = Input(Bool()) + val WREADY = Output(Bool()) + val WDATA = Input(UInt(params.dataBits.W)) + val WSTRB = Input(UInt(params.strbBits.W)) + val BVALID = Output(Bool()) + val BREADY = Input(Bool()) + val BRESP = Output(UInt(params.respBits.W)) + val ARVALID = Input(Bool()) + val ARREADY = Output(Bool()) + val ARADDR = Input(UInt(params.addrBits.W)) + val RVALID = Output(Bool()) + val RREADY = Input(Bool()) + val RDATA = Output(UInt(params.dataBits.W)) + val RRESP = Output(UInt(params.respBits.W)) +} + +class XilinxAXIMaster(params: AXIParams) extends AXIBase(params) { + val AWVALID = Output(Bool()) + val AWREADY = Input(Bool()) + val AWADDR = Output(UInt(params.addrBits.W)) + val AWID = Output(UInt(params.idBits.W)) + val AWUSER = Output(UInt(params.userBits.W)) + val AWLEN = Output(UInt(params.lenBits.W)) + val AWSIZE = Output(UInt(params.sizeBits.W)) + val AWBURST = Output(UInt(params.burstBits.W)) + val AWLOCK = Output(UInt(params.lockBits.W)) + val AWCACHE = Output(UInt(params.cacheBits.W)) + val AWPROT = Output(UInt(params.protBits.W)) + val AWQOS = Output(UInt(params.qosBits.W)) + val AWREGION = Output(UInt(params.regionBits.W)) + val WVALID = Output(Bool()) + val WREADY = Input(Bool()) + val WDATA = Output(UInt(params.dataBits.W)) + val WSTRB = Output(UInt(params.strbBits.W)) + val WLAST = Output(Bool()) + val WID = Output(UInt(params.idBits.W)) + val WUSER = Output(UInt(params.userBits.W)) + val BVALID = Input(Bool()) + val BREADY = Output(Bool()) + val BRESP = Input(UInt(params.respBits.W)) + val BID = Input(UInt(params.idBits.W)) + val BUSER = Input(UInt(params.userBits.W)) + val ARVALID = Output(Bool()) + val ARREADY = Input(Bool()) + val ARADDR = Output(UInt(params.addrBits.W)) + val ARID = Output(UInt(params.idBits.W)) + val ARUSER = Output(UInt(params.userBits.W)) + val ARLEN = Output(UInt(params.lenBits.W)) + val ARSIZE = Output(UInt(params.sizeBits.W)) + val ARBURST = Output(UInt(params.burstBits.W)) + val ARLOCK = Output(UInt(params.lockBits.W)) + val ARCACHE = Output(UInt(params.cacheBits.W)) + val ARPROT = Output(UInt(params.protBits.W)) + val ARQOS = Output(UInt(params.qosBits.W)) + val ARREGION = Output(UInt(params.regionBits.W)) + val RVALID = Input(Bool()) + val RREADY = Output(Bool()) + val RDATA = Input(UInt(params.dataBits.W)) + val RRESP = Input(UInt(params.respBits.W)) + val RLAST = Input(Bool()) + val RID = Input(UInt(params.idBits.W)) + val RUSER = Input(UInt(params.userBits.W)) +} diff --git a/vta/hardware/chisel/src/main/scala/shell/Configs.scala b/vta/hardware/chisel/src/main/scala/shell/Configs.scala new file mode 100644 index 000000000000..1d1d5223b73c --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/Configs.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.interface.axi._ + +/** PynqConfig. Shell configuration for Pynq */ +class PynqConfig extends Config((site, here, up) => { + case ShellKey => ShellParams( + hostParams = AXIParams( + addrBits = 16, + dataBits = 32), + memParams = AXIParams( + addrBits = 32, + dataBits = 64), + vcrParams = VCRParams(), + vmeParams = VMEParams()) +}) + +/** F1Config. Shell configuration for F1 */ +class F1Config extends Config((site, here, up) => { + case ShellKey => ShellParams( + hostParams = AXIParams( + addrBits = 16, + dataBits = 32), + memParams = AXIParams( + addrBits = 64, + dataBits = 64), + vcrParams = VCRParams(), + vmeParams = VMEParams()) +}) diff --git a/vta/hardware/chisel/src/main/scala/shell/SimShell.scala b/vta/hardware/chisel/src/main/scala/shell/SimShell.scala new file mode 100644 index 000000000000..3ad4b6548ce3 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/SimShell.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import vta.util.config._ +import vta.interface.axi._ +import vta.shell._ +import vta.dpi._ + +/** VTAHost. + * + * This module translate the DPI protocol into AXI. This is a simulation only + * module and used to test host-to-VTA communication. This module should be updated + * for testing hosts using a different bus protocol, other than AXI. + */ +class VTAHost(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val axi = new AXILiteMaster(p(ShellKey).hostParams) + }) + val host_dpi = Module(new VTAHostDPI) + val host_axi = Module(new VTAHostDPIToAXI) + host_dpi.io.reset := reset + host_dpi.io.clock := clock + host_axi.io.dpi <> host_dpi.io.dpi + io.axi <> host_axi.io.axi +} + +/** VTAMem. + * + * This module translate the DPI protocol into AXI. This is a simulation only + * module and used to test VTA-to-memory communication. This module should be updated + * for testing memories using a different bus protocol, other than AXI. + */ +class VTAMem(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val axi = new AXIClient(p(ShellKey).memParams) + }) + val mem_dpi = Module(new VTAMemDPI) + val mem_axi = Module(new VTAMemDPIToAXI) + mem_dpi.io.reset := reset + mem_dpi.io.clock := clock + mem_dpi.io.dpi <> mem_axi.io.dpi + mem_axi.io.axi <> io.axi +} + +/** SimShell. + * + * The simulation shell instantiate a host and memory simulation modules and it is + * intended to be connected to the VTAShell. + */ +class SimShell(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val mem = new AXIClient(p(ShellKey).memParams) + val host = new AXILiteMaster(p(ShellKey).hostParams) + }) + val host = Module(new VTAHost) + val mem = Module(new VTAMem) + io.mem <> mem.io.axi + io.host <> host.io.axi +} diff --git a/vta/hardware/chisel/src/main/scala/shell/VCR.scala b/vta/hardware/chisel/src/main/scala/shell/VCR.scala new file mode 100644 index 000000000000..463f55bc8bbd --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/VCR.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.util.genericbundle._ +import scala.collection.mutable.ListBuffer +import scala.collection.mutable.LinkedHashMap +import vta.interface.axi._ + +/** VCR parameters. + * + * These parameters are used on VCR interfaces and modules. + */ +case class VCRParams() +{ + val nValsReg: Int = 1 + val nPtrsReg: Int = 6 + val regBits: Int = 32 + val nCtrlReg: Int = 4 + val ctrlBaseAddr: Int = 0 + + require (nValsReg > 0) + require (nPtrsReg > 0) +} + +/** VCRBase. Parametrize base class. */ +abstract class VCRBase(implicit p: Parameters) + extends GenericParameterizedBundle(p) + +/** VCRMaster. + * + * This is the master interface used by VCR in the VTAShell to control + * the Core unit. + */ +class VCRMaster(implicit p: Parameters) extends VCRBase { + val vp = p(ShellKey).vcrParams + val mp = p(ShellKey).memParams + val launch = Output(Bool()) + val finish = Input(Bool()) + val irq = Output(Bool()) + val ptrs = Output(Vec(vp.nPtrsReg, UInt(mp.addrBits.W))) + val vals = Output(Vec(vp.nValsReg, UInt(vp.regBits.W))) +} + +/** VCRClient. + * + * This is the client interface used by the Core module to communicate + * to the VCR in the VTAShell. + */ +class VCRClient(implicit p: Parameters) extends VCRBase { + val vp = p(ShellKey).vcrParams + val mp = p(ShellKey).memParams + val launch = Input(Bool()) + val finish = Output(Bool()) + val irq = Input(Bool()) + val ptrs = Input(Vec(vp.nPtrsReg, UInt(mp.addrBits.W))) + val vals = Input(Vec(vp.nValsReg, UInt(vp.regBits.W))) +} + +/** VTA Control Registers (VCR). + * + * This unit provides control registers (32 and 64 bits) to be used by a control' + * unit, typically a host processor. These registers are read-only by the core + * at the moment but this will likely change once we add support to general purpose + * registers that could be used as event counters by the Core unit. + */ +class VCR(implicit p: Parameters) extends Module { + val io = IO(new Bundle{ + val host = new AXILiteClient(p(ShellKey).hostParams) + val vcr = new VCRMaster + }) + + val vp = p(ShellKey).vcrParams + val mp = p(ShellKey).memParams + val hp = p(ShellKey).hostParams + + // Write control (AW, W, B) + val waddr = RegInit("h_ffff".U(hp.addrBits.W)) // init with invalid address + val wdata = io.host.w.bits.data + val wstrb = io.host.w.bits.strb + val wmask = Cat(Fill(8, wstrb(3)), Fill(8, wstrb(2)), Fill(8, wstrb(1)), Fill(8, wstrb(0))) + val sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(3) + val wstate = RegInit(sWriteAddress) + switch (wstate) { + is (sWriteAddress) { + when (io.host.aw.valid) { + wstate := sWriteData + } + } + is (sWriteData) { + when (io.host.w.valid) { + wstate := sWriteResponse + } + } + is (sWriteResponse) { + when (io.host.b.ready) { + wstate := sWriteAddress + } + } + } + + when (io.host.aw.fire()) { waddr := io.host.aw.bits.addr } + + io.host.aw.ready := wstate === sWriteAddress + io.host.w.ready := wstate === sWriteData + io.host.b.valid := wstate === sWriteResponse + io.host.b.bits.resp := "h_0".U + + // read control (AR, R) + val sReadAddress :: sReadData :: Nil = Enum(2) + val rstate = RegInit(sReadAddress) + + switch (rstate) { + is (sReadAddress) { + when (io.host.ar.valid) { + rstate := sReadData + } + } + is (sReadData) { + when (io.host.r.ready) { + rstate := sReadAddress + } + } + } + + io.host.ar.ready := rstate === sReadAddress + io.host.r.valid := rstate === sReadData + + val nPtrsReg = vp.nPtrsReg + val nValsReg = vp.nValsReg + val regBits = vp.regBits + val ptrsBits = mp.addrBits + val nCtrlReg = vp.nCtrlReg + val rStride = regBits/8 + val pStride = ptrsBits/8 + val ctrlBaseAddr = vp.ctrlBaseAddr + val valsBaseAddr = ctrlBaseAddr + nCtrlReg*rStride + val ptrsBaseAddr = valsBaseAddr + nValsReg*rStride + + val ctrlAddr = Seq.tabulate(nCtrlReg)(i => i*rStride + ctrlBaseAddr) + val valsAddr = Seq.tabulate(nValsReg)(i => i*rStride + valsBaseAddr) + + val ptrsAddr = new ListBuffer[Int]() + for (i <- 0 until nPtrsReg) { + ptrsAddr += i*pStride + ptrsBaseAddr + if (ptrsBits == 64) { + ptrsAddr += i*pStride + rStride + ptrsBaseAddr + } + } + + // AP register + val c0 = RegInit(VecInit(Seq.fill(regBits)(false.B))) + + // ap start + when (io.host.w.fire() && waddr === ctrlAddr(0).asUInt && wstrb(0) && wdata(0)) { + c0(0) := true.B + } .elsewhen (io.vcr.finish) { + c0(0) := false.B + } + + // ap done = finish + when (io.vcr.finish) { + c0(1) := true.B + } .elsewhen (io.host.ar.fire() && io.host.ar.bits.addr === ctrlAddr(0).asUInt) { + c0(1) := false.B + } + + val c1 = 0.U + val c2 = 0.U + val c3 = 0.U + + val ctrlRegList = List(c0, c1, c2, c3) + + io.vcr.launch := c0(0) + + // interrupts not supported atm + io.vcr.irq := false.B + + // Write pointer and value registers + val pvAddr = valsAddr ++ ptrsAddr + val pvNumReg = if (ptrsBits == 64) nValsReg + nPtrsReg*2 else nValsReg + nPtrsReg + val pvReg = RegInit(VecInit(Seq.fill(pvNumReg)(0.U(regBits.W)))) + val pvRegList = new ListBuffer[UInt]() + + for (i <- 0 until pvNumReg) { + when (io.host.w.fire() && (waddr === pvAddr(i).U)) { + pvReg(i) := (wdata & wmask) | (pvReg(i) & ~wmask) + } + pvRegList += pvReg(i) + } + + for (i <- 0 until nValsReg) { + io.vcr.vals(i) := pvReg(i) + } + + for (i <- 0 until nPtrsReg) { + if (ptrsBits == 64) { + io.vcr.ptrs(i) := Cat(pvReg(nValsReg + i*2 + 1), pvReg(nValsReg + i*2)) + } else { + io.vcr.ptrs(i) := pvReg(nValsReg + i) + } + } + + // Read pointer and value registers + val mapAddr = ctrlAddr ++ valsAddr ++ ptrsAddr + val mapRegList = ctrlRegList ++ pvRegList + + val rdata = RegInit(0.U(regBits.W)) + val rmap = LinkedHashMap[Int,UInt]() + + val totalReg = mapRegList.length + for (i <- 0 until totalReg) { rmap += mapAddr(i) -> mapRegList(i).asUInt } + + val decodeAddr = rmap map { case (k, _) => k -> (io.host.ar.bits.addr === k.asUInt) } + + when (io.host.ar.fire()) { + rdata := Mux1H(for ((k, v) <- rmap) yield decodeAddr(k) -> v) + } + + io.host.r.bits.resp := 0.U + io.host.r.bits.data := rdata +} diff --git a/vta/hardware/chisel/src/main/scala/shell/VME.scala b/vta/hardware/chisel/src/main/scala/shell/VME.scala new file mode 100644 index 000000000000..862e9810c510 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/VME.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.util.genericbundle._ +import vta.interface.axi._ + +/** VME parameters. + * + * These parameters are used on VME interfaces and modules. + */ +case class VMEParams() { + val nReadClients: Int = 5 + val nWriteClients: Int = 1 + require (nReadClients > 0, s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n") + require (nWriteClients == 1, s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n") +} + +/** VMEBase. Parametrize base class. */ +abstract class VMEBase(implicit p: Parameters) + extends GenericParameterizedBundle(p) + +/** VMECmd. + * + * This interface is used for creating write and read requests to memory. + */ +class VMECmd(implicit p: Parameters) extends VMEBase { + val addrBits = p(ShellKey).memParams.addrBits + val lenBits = p(ShellKey).memParams.lenBits + val addr = UInt(addrBits.W) + val len = UInt(lenBits.W) +} + +/** VMEReadMaster. + * + * This interface is used by modules inside the core to generate read requests + * and receive responses from VME. + */ +class VMEReadMaster(implicit p: Parameters) extends Bundle { + val dataBits = p(ShellKey).memParams.dataBits + val cmd = Decoupled(new VMECmd) + val data = Flipped(Decoupled(UInt(dataBits.W))) + override def cloneType = + new VMEReadMaster().asInstanceOf[this.type] +} + +/** VMEReadClient. + * + * This interface is used by the VME to receive read requests and generate + * responses to modules inside the core. + */ +class VMEReadClient(implicit p: Parameters) extends Bundle { + val dataBits = p(ShellKey).memParams.dataBits + val cmd = Flipped(Decoupled(new VMECmd)) + val data = Decoupled(UInt(dataBits.W)) + override def cloneType = + new VMEReadClient().asInstanceOf[this.type] +} + +/** VMEWriteMaster. + * + * This interface is used by modules inside the core to generate write requests + * to the VME. + */ +class VMEWriteMaster(implicit p: Parameters) extends Bundle { + val dataBits = p(ShellKey).memParams.dataBits + val cmd = Decoupled(new VMECmd) + val data = Decoupled(UInt(dataBits.W)) + val ack = Input(Bool()) + override def cloneType = + new VMEWriteMaster().asInstanceOf[this.type] +} + +/** VMEWriteClient. + * + * This interface is used by the VME to handle write requests from modules inside + * the core. + */ +class VMEWriteClient(implicit p: Parameters) extends Bundle { + val dataBits = p(ShellKey).memParams.dataBits + val cmd = Flipped(Decoupled(new VMECmd)) + val data = Flipped(Decoupled(UInt(dataBits.W))) + val ack = Output(Bool()) + override def cloneType = + new VMEWriteClient().asInstanceOf[this.type] +} + +/** VMEMaster. + * + * Pack nRd number of VMEReadMaster interfaces and nWr number of VMEWriteMaster + * interfaces. + */ +class VMEMaster(implicit p: Parameters) extends Bundle { + val nRd = p(ShellKey).vmeParams.nReadClients + val nWr = p(ShellKey).vmeParams.nWriteClients + val rd = Vec(nRd, new VMEReadMaster) + val wr = Vec(nWr, new VMEWriteMaster) +} + +/** VMEClient. + * + * Pack nRd number of VMEReadClient interfaces and nWr number of VMEWriteClient + * interfaces. + */ +class VMEClient(implicit p: Parameters) extends Bundle { + val nRd = p(ShellKey).vmeParams.nReadClients + val nWr = p(ShellKey).vmeParams.nWriteClients + val rd = Vec(nRd, new VMEReadClient) + val wr = Vec(nWr, new VMEWriteClient) +} + +/** VTA Memory Engine (VME). + * + * This unit multiplexes the memory controller interface for the Core. Currently, + * it supports single-writer and multiple-reader mode and it is also based on AXI. + */ +class VME(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val mem = new AXIMaster(p(ShellKey).memParams) + val vme = new VMEClient + }) + + val nReadClients = p(ShellKey).vmeParams.nReadClients + val rd_arb = Module(new Arbiter(new VMECmd, nReadClients)) + val rd_arb_chosen = RegEnable(rd_arb.io.chosen, rd_arb.io.out.fire()) + + for (i <- 0 until nReadClients) { rd_arb.io.in(i) <> io.vme.rd(i).cmd } + + val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3) + val rstate = RegInit(sReadIdle) + + switch (rstate) { + is (sReadIdle) { + when (rd_arb.io.out.valid) { + rstate := sReadAddr + } + } + is (sReadAddr) { + when (io.mem.ar.ready) { + rstate := sReadData + } + } + is (sReadData) { + when (io.mem.r.fire() && io.mem.r.bits.last) { + rstate := sReadIdle + } + } + } + + val sWriteIdle :: sWriteAddr :: sWriteData :: sWriteResp :: Nil = Enum(4) + val wstate = RegInit(sWriteIdle) + val addrBits = p(ShellKey).memParams.addrBits + val lenBits = p(ShellKey).memParams.lenBits + val wr_cnt = RegInit(0.U(lenBits.W)) + + when (wstate === sWriteIdle) { + wr_cnt := 0.U + } .elsewhen (io.mem.w.fire()) { + wr_cnt := wr_cnt + 1.U + } + + switch (wstate) { + is (sWriteIdle) { + when (io.vme.wr(0).cmd.valid) { + wstate := sWriteAddr + } + } + is (sWriteAddr) { + when (io.mem.aw.ready) { + wstate := sWriteData + } + } + is (sWriteData) { + when (io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) { + wstate := sWriteResp + } + } + is (sWriteResp) { + when (io.mem.b.valid) { + wstate := sWriteIdle + } + } + } + + // registers storing read/write cmds + + val rd_len = RegInit(0.U(lenBits.W)) + val wr_len = RegInit(0.U(lenBits.W)) + val rd_addr = RegInit(0.U(addrBits.W)) + val wr_addr = RegInit(0.U(addrBits.W)) + + when (rd_arb.io.out.fire()) { + rd_len := rd_arb.io.out.bits.len + rd_addr := rd_arb.io.out.bits.addr + } + + when (io.vme.wr(0).cmd.fire()) { + wr_len := io.vme.wr(0).cmd.bits.len + wr_addr := io.vme.wr(0).cmd.bits.addr + } + + // rd arb + rd_arb.io.out.ready := rstate === sReadIdle + + // vme + for (i <- 0 until nReadClients) { + io.vme.rd(i).data.valid := rd_arb_chosen === i.asUInt & io.mem.r.valid + io.vme.rd(i).data.bits := io.mem.r.bits.data + } + + io.vme.wr(0).cmd.ready := wstate === sWriteIdle + io.vme.wr(0).ack := io.mem.b.fire() + io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready + + // mem + io.mem.aw.valid := wstate === sWriteAddr + io.mem.aw.bits.addr := wr_addr + io.mem.aw.bits.len := wr_len + + io.mem.w.valid := wstate === sWriteData & io.vme.wr(0).data.valid + io.mem.w.bits.data := io.vme.wr(0).data.bits + io.mem.w.bits.last := wr_cnt === io.vme.wr(0).cmd.bits.len + + io.mem.b.ready := wstate === sWriteResp + + io.mem.ar.valid := rstate === sReadAddr + io.mem.ar.bits.addr := rd_addr + io.mem.ar.bits.len := rd_len + + io.mem.r.ready := rstate === sReadData & io.vme.rd(rd_arb_chosen).data.ready + + // AXI constants - statically defined + io.mem.setConst() +} diff --git a/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala b/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala new file mode 100644 index 000000000000..c8093118308f --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import vta.util.config._ +import vta.interface.axi._ +import vta.core._ + +/** Shell parameters. */ +case class ShellParams( + hostParams: AXIParams, + memParams: AXIParams, + vcrParams: VCRParams, + vmeParams: VMEParams +) + +case object ShellKey extends Field[ShellParams] + +/** VTAShell. + * + * The VTAShell is based on a VME, VCR and core. This creates a complete VTA + * system that can be used for simulation or real hardware. + */ +class VTAShell(implicit p: Parameters) extends Module { + val io = IO(new Bundle{ + val host = new AXILiteClient(p(ShellKey).hostParams) + val mem = new AXIMaster(p(ShellKey).memParams) + }) + + val vcr = Module(new VCR) + val vme = Module(new VME) + val core = Module(new Core) + + core.io.vcr <> vcr.io.vcr + vme.io.vme <> core.io.vme + + vcr.io.host <> io.host + io.mem <> vme.io.mem +} diff --git a/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala b/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala new file mode 100644 index 000000000000..db721373b7e3 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import chisel3.experimental.{RawModule, withClockAndReset} +import vta.util.config._ +import vta.interface.axi._ + +/** XilinxShell. + * + * This is a wrapper shell mostly used to match Xilinx convention naming, + * therefore we can pack VTA as an IP for IPI based flows. + */ +class XilinxShell(implicit p: Parameters) extends RawModule { + + val hp = p(ShellKey).hostParams + val mp = p(ShellKey).memParams + + val ap_clk = IO(Input(Clock())) + val ap_rst_n = IO(Input(Bool())) + val m_axi_gmem = IO(new XilinxAXIMaster(mp)) + val s_axi_control = IO(new XilinxAXILiteClient(hp)) + + val shell = withClockAndReset (clock = ap_clk, reset = ~ap_rst_n) { Module(new VTAShell) } + + // memory + m_axi_gmem.AWVALID := shell.io.mem.aw.valid + shell.io.mem.aw.ready := m_axi_gmem.AWREADY + m_axi_gmem.AWADDR := shell.io.mem.aw.bits.addr + m_axi_gmem.AWID := shell.io.mem.aw.bits.id + m_axi_gmem.AWUSER := shell.io.mem.aw.bits.user + m_axi_gmem.AWLEN := shell.io.mem.aw.bits.len + m_axi_gmem.AWSIZE := shell.io.mem.aw.bits.size + m_axi_gmem.AWBURST := shell.io.mem.aw.bits.burst + m_axi_gmem.AWLOCK := shell.io.mem.aw.bits.lock + m_axi_gmem.AWCACHE := shell.io.mem.aw.bits.cache + m_axi_gmem.AWPROT := shell.io.mem.aw.bits.prot + m_axi_gmem.AWQOS := shell.io.mem.aw.bits.qos + m_axi_gmem.AWREGION := shell.io.mem.aw.bits.region + + m_axi_gmem.WVALID := shell.io.mem.w.valid + shell.io.mem.w.ready := m_axi_gmem.WREADY + m_axi_gmem.WDATA := shell.io.mem.w.bits.data + m_axi_gmem.WSTRB := shell.io.mem.w.bits.strb + m_axi_gmem.WLAST := shell.io.mem.w.bits.last + m_axi_gmem.WID := shell.io.mem.w.bits.id + m_axi_gmem.WUSER := shell.io.mem.w.bits.user + + shell.io.mem.b.valid := m_axi_gmem.BVALID + m_axi_gmem.BREADY := shell.io.mem.b.valid + shell.io.mem.b.bits.resp := m_axi_gmem.BRESP + shell.io.mem.b.bits.id := m_axi_gmem.BID + shell.io.mem.b.bits.user := m_axi_gmem.BUSER + + m_axi_gmem.ARVALID := shell.io.mem.ar.valid + shell.io.mem.ar.ready := m_axi_gmem.ARREADY + m_axi_gmem.ARADDR := shell.io.mem.ar.bits.addr + m_axi_gmem.ARID := shell.io.mem.ar.bits.id + m_axi_gmem.ARUSER := shell.io.mem.ar.bits.user + m_axi_gmem.ARLEN := shell.io.mem.ar.bits.len + m_axi_gmem.ARSIZE := shell.io.mem.ar.bits.size + m_axi_gmem.ARBURST := shell.io.mem.ar.bits.burst + m_axi_gmem.ARLOCK := shell.io.mem.ar.bits.lock + m_axi_gmem.ARCACHE := shell.io.mem.ar.bits.cache + m_axi_gmem.ARPROT := shell.io.mem.ar.bits.prot + m_axi_gmem.ARQOS := shell.io.mem.ar.bits.qos + m_axi_gmem.ARREGION := shell.io.mem.ar.bits.region + + shell.io.mem.r.valid := m_axi_gmem.RVALID + m_axi_gmem.RREADY := shell.io.mem.r.ready + shell.io.mem.r.bits.data := m_axi_gmem.RDATA + shell.io.mem.r.bits.resp := m_axi_gmem.RRESP + shell.io.mem.r.bits.last := m_axi_gmem.RLAST + shell.io.mem.r.bits.id := m_axi_gmem.RID + shell.io.mem.r.bits.user := m_axi_gmem.RUSER + + // host + shell.io.host.aw.valid := s_axi_control.AWVALID + s_axi_control.AWREADY := shell.io.host.aw.ready + shell.io.host.aw.bits.addr := s_axi_control.AWADDR + + shell.io.host.w.valid := s_axi_control.WVALID + s_axi_control.WREADY := shell.io.host.w.ready + shell.io.host.w.bits.data := s_axi_control.WDATA + shell.io.host.w.bits.strb := s_axi_control.WSTRB + + s_axi_control.BVALID := shell.io.host.b.valid + shell.io.host.b.ready := s_axi_control.BREADY + s_axi_control.BRESP := shell.io.host.b.bits.resp + + shell.io.host.ar.valid := s_axi_control.ARVALID + s_axi_control.ARREADY := shell.io.host.ar.ready + shell.io.host.ar.bits.addr := s_axi_control.ARADDR + + s_axi_control.RVALID := shell.io.host.r.valid + shell.io.host.r.ready := s_axi_control.RREADY + s_axi_control.RDATA := shell.io.host.r.bits.data + s_axi_control.RRESP := shell.io.host.r.bits.resp +} diff --git a/vta/hardware/chisel/src/main/scala/test/Test.scala b/vta/hardware/chisel/src/main/scala/test/Test.scala new file mode 100644 index 000000000000..db060739147d --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/test/Test.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.test + +import chisel3._ +import vta.util.config._ +import vta.shell._ + +/** Test. This generates a testbench file for simulation */ +class Test(implicit p: Parameters) extends Module { + val io = IO(new Bundle {}) + val sim_shell = Module(new SimShell) + val vta_shell = Module(new VTAShell) + vta_shell.io.host <> sim_shell.io.host + sim_shell.io.mem <> vta_shell.io.mem +} diff --git a/vta/hardware/chisel/src/main/scala/util/Config.scala b/vta/hardware/chisel/src/main/scala/util/Config.scala new file mode 100644 index 000000000000..6699507c9f13 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/util/Config.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.util.config + +// taken from https://github.com/vta.roject/rocket-chip + +abstract class Field[T] private (val default: Option[T]) +{ + def this() = this(None) + def this(default: T) = this(Some(default)) +} + +abstract class View { + final def apply[T](pname: Field[T]): T = apply(pname, this) + final def apply[T](pname: Field[T], site: View): T = { + val out = find(pname, site) + require (out.isDefined, s"Key ${pname} is not defined in Parameters") + out.get + } + + final def lift[T](pname: Field[T]): Option[T] = lift(pname, this) + final def lift[T](pname: Field[T], site: View): Option[T] = find(pname, site).map(_.asInstanceOf[T]) + + protected[config] def find[T](pname: Field[T], site: View): Option[T] +} + +abstract class Parameters extends View { + final def ++ (x: Parameters): Parameters = + new ChainParameters(this, x) + + final def alter(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = + Parameters(f) ++ this + + final def alterPartial(f: PartialFunction[Any,Any]): Parameters = + Parameters((_,_,_) => f) ++ this + + final def alterMap(m: Map[Any,Any]): Parameters = + new MapParameters(m) ++ this + + protected[config] def chain[T](site: View, tail: View, pname: Field[T]): Option[T] + protected[config] def find[T](pname: Field[T], site: View) = chain(site, new TerminalView, pname) +} + +object Parameters { + def empty: Parameters = new EmptyParameters + def apply(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = new PartialParameters(f) +} + +class Config(p: Parameters) extends Parameters { + def this(f: (View, View, View) => PartialFunction[Any,Any]) = this(Parameters(f)) + + protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = p.chain(site, tail, pname) + override def toString = this.getClass.getSimpleName + def toInstance = this +} + +// Internal implementation: + +private class TerminalView extends View { + def find[T](pname: Field[T], site: View): Option[T] = pname.default +} + +private class ChainView(head: Parameters, tail: View) extends View { + def find[T](pname: Field[T], site: View) = head.chain(site, tail, pname) +} + +private class ChainParameters(x: Parameters, y: Parameters) extends Parameters { + def chain[T](site: View, tail: View, pname: Field[T]) = x.chain(site, new ChainView(y, tail), pname) +} + +private class EmptyParameters extends Parameters { + def chain[T](site: View, tail: View, pname: Field[T]) = tail.find(pname, site) +} + +private class PartialParameters(f: (View, View, View) => PartialFunction[Any,Any]) extends Parameters { + protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = { + val g = f(site, this, tail) + if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T]) else tail.find(pname, site) + } +} + +private class MapParameters(map: Map[Any, Any]) extends Parameters { + protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = { + val g = map.get(pname) + if (g.isDefined) Some(g.get.asInstanceOf[T]) else tail.find(pname, site) + } +} diff --git a/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala b/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala new file mode 100644 index 000000000000..db19635c9345 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.util.genericbundle + +// taken from https://github.com/vta.roject/rocket-chip + +import chisel3._ + +abstract class GenericParameterizedBundle[+T <: Object](val params: T) extends Bundle +{ + override def cloneType = { + try { + this.getClass.getConstructors.head.newInstance(params).asInstanceOf[this.type] + } catch { + case e: java.lang.IllegalArgumentException => + throw new Exception("Unable to use GenericParameterizedBundle.cloneType on " + + this.getClass + ", probably because " + this.getClass + + "() takes more than one argument. Consider overriding " + + "cloneType() on " + this.getClass, e) + } + } +} + diff --git a/vta/hardware/chisel/src/main/scala/vta/Configs.scala b/vta/hardware/chisel/src/main/scala/vta/Configs.scala new file mode 100644 index 000000000000..d5aa12798fe7 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/vta/Configs.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta + +import chisel3._ +import vta.util.config._ +import vta.shell._ +import vta.core._ +import vta.test._ + +/** VTA. + * + * This file contains all the configurations supported by VTA. + * These configurations are built in a mix/match form based on core + * and shell configurations. + */ + +class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig) +class DefaultF1Config extends Config(new CoreConfig ++ new F1Config) + +object DefaultPynqConfig extends App { + implicit val p: Parameters = new DefaultPynqConfig + chisel3.Driver.execute(args, () => new XilinxShell) +} + +object DefaultF1Config extends App { + implicit val p: Parameters = new DefaultF1Config + chisel3.Driver.execute(args, () => new XilinxShell) +} + +object TestDefaultF1Config extends App { + implicit val p: Parameters = new DefaultF1Config + chisel3.Driver.execute(args, () => new Test) +} diff --git a/vta/hardware/dpi/tsim_device.cc b/vta/hardware/dpi/tsim_device.cc index 08954179f1d2..0b315e4cb541 100644 --- a/vta/hardware/dpi/tsim_device.cc +++ b/vta/hardware/dpi/tsim_device.cc @@ -70,8 +70,18 @@ void VTADPIInit(VTAContextHandle handle, _mem_dpi = mem_dpi; } + +// Override Verilator finish definition +// VL_USER_FINISH needs to be defined when compiling Verilator code +void vl_finish(const char* filename, int linenum, const char* hier) { + Verilated::gotFinish(true); + VL_PRINTF("[TSIM] exiting simulation\n"); +} + int VTADPISim(uint64_t max_cycles) { uint64_t trace_count = 0; + Verilated::flushCall(); + Verilated::gotFinish(false); #if VM_TRACE uint64_t start = 0; diff --git a/vta/include/vta/driver.h b/vta/include/vta/driver.h index d583051dc194..eca9e4da9799 100644 --- a/vta/include/vta/driver.h +++ b/vta/include/vta/driver.h @@ -53,7 +53,11 @@ extern "C" { typedef void * VTADeviceHandle; /*! \brief physical address */ +#ifdef USE_TSIM +typedef uint64_t vta_phy_addr_t; +#else typedef uint32_t vta_phy_addr_t; +#endif /*! * \brief Allocate a device resource handle @@ -76,10 +80,22 @@ void VTADeviceFree(VTADeviceHandle handle); * * \return 0 if running is successful, 1 if timeout. */ +#ifdef USE_TSIM +int VTADeviceRun(VTADeviceHandle device, + vta_phy_addr_t insn_phy_addr, + vta_phy_addr_t uop_phy_addr, + vta_phy_addr_t inp_phy_addr, + vta_phy_addr_t wgt_phy_addr, + vta_phy_addr_t acc_phy_addr, + vta_phy_addr_t out_phy_addr, + uint32_t insn_count, + uint32_t wait_cycles); +#else int VTADeviceRun(VTADeviceHandle device, vta_phy_addr_t insn_phy_addr, uint32_t insn_count, uint32_t wait_cycles); +#endif /*! * \brief Allocates physically contiguous region in memory (limited by MAX_XFER). diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index d5400d868ae4..4c2200d04727 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -239,7 +239,7 @@ def target_host(self): """The target host""" if self.TARGET == "pynq": return "llvm -target=armv7-none-linux-gnueabihf" - if self.TARGET == "sim": + if self.TARGET == "sim" or self.TARGET == "tsim": return "llvm" raise ValueError("Unknown target %s" % self.TARGET) diff --git a/vta/python/vta/testing/simulator.py b/vta/python/vta/testing/simulator.py index a1e15ba69880..858e1157d8b2 100644 --- a/vta/python/vta/testing/simulator.py +++ b/vta/python/vta/testing/simulator.py @@ -17,6 +17,8 @@ """Utilities to start simulator.""" import ctypes import json +import sys +import os import tvm from ..libinfo import find_libvta @@ -55,5 +57,22 @@ def stats(): x = tvm.get_global_func("vta.simulator.profiler_status")() return json.loads(x) +def tsim_init(hw_lib): + """Init hardware shared library for TSIM + + Parameters + ------------ + hw_lib : str + Name of hardware shared library + """ + cur_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + vta_build_path = os.path.join(cur_path, "..", "..", "..", "build") + if not hw_lib.endswith(("dylib", "so")): + hw_lib += ".dylib" if sys.platform == "darwin" else ".so" + lib = os.path.join(vta_build_path, hw_lib) + f = tvm.get_global_func("tvm.vta.tsim.init") + m = tvm.module.load(lib, "vta-tsim") + f(m) + LIBS = _load_lib() diff --git a/vta/python/vta/testing/util.py b/vta/python/vta/testing/util.py index 48dd08588962..06c700cd7119 100644 --- a/vta/python/vta/testing/util.py +++ b/vta/python/vta/testing/util.py @@ -31,7 +31,7 @@ def run(run_func): """ env = get_env() - if env.TARGET == "sim": + if env.TARGET in ["sim", "tsim"]: # Talk to local RPC if necessary to debug RPC server. # Compile vta on your host with make at the root. @@ -48,7 +48,8 @@ def run(run_func): # Make sure simulation library exists # If this fails, build vta on host (make) # with TARGET="sim" in the json.config file. - assert simulator.enabled() + if env.TARGET == "sim": + assert simulator.enabled() run_func(env, rpc.LocalSession()) elif env.TARGET == "pynq": diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc index 79a407fe521e..06b34743955f 100644 --- a/vta/src/runtime.cc +++ b/vta/src/runtime.cc @@ -56,7 +56,7 @@ struct DataBuffer { return data_; } /*! \return Physical address of the data. */ - uint32_t phy_addr() const { + vta_phy_addr_t phy_addr() const { return phy_addr_; } /*! @@ -113,7 +113,7 @@ struct DataBuffer { /*! \brief The internal data. */ void* data_; /*! \brief The physical address of the buffer, excluding header. */ - uint32_t phy_addr_; + vta_phy_addr_t phy_addr_; }; /*! @@ -302,7 +302,7 @@ class BaseQueue { return dram_buffer_; } /*! \return Physical address of DRAM. */ - uint32_t dram_phy_addr() const { + vta_phy_addr_t dram_phy_addr() const { return dram_phy_addr_; } /*! \return Whether there is pending information. */ @@ -367,7 +367,7 @@ class BaseQueue { // The buffer in DRAM char* dram_buffer_{nullptr}; // Physics address of the buffer - uint32_t dram_phy_addr_; + vta_phy_addr_t dram_phy_addr_; }; /*! @@ -424,7 +424,11 @@ class UopQueue : public BaseQueue { CHECK((dram_end_ - dram_begin_) == (sram_end_ - sram_begin_)); insn->memory_type = VTA_MEM_ID_UOP; insn->sram_base = sram_begin_; +#ifdef USE_TSIM + insn->dram_base = (uint32_t) dram_phy_addr_ + dram_begin_*kElemBytes; +#else insn->dram_base = dram_phy_addr_ / kElemBytes + dram_begin_; +#endif insn->y_size = 1; insn->x_size = (dram_end_ - dram_begin_); insn->x_stride = (dram_end_ - dram_begin_); @@ -958,7 +962,11 @@ class CommandQueue { insn->memory_type = dst_memory_type; insn->sram_base = dst_sram_index; DataBuffer* src = DataBuffer::FromHandle(src_dram_addr); +#ifdef USE_TSIM + insn->dram_base = (uint32_t) src->phy_addr() + src_elem_offset*GetElemBytes(dst_memory_type); +#else insn->dram_base = src->phy_addr() / GetElemBytes(dst_memory_type) + src_elem_offset; +#endif insn->y_size = y_size; insn->x_size = x_size; insn->x_stride = x_stride; @@ -981,7 +989,11 @@ class CommandQueue { insn->memory_type = src_memory_type; insn->sram_base = src_sram_index; DataBuffer* dst = DataBuffer::FromHandle(dst_dram_addr); +#ifdef USE_TSIM + insn->dram_base = (uint32_t) dst->phy_addr() + dst_elem_offset*GetElemBytes(src_memory_type); +#else insn->dram_base = dst->phy_addr() / GetElemBytes(src_memory_type) + dst_elem_offset; +#endif insn->y_size = y_size; insn->x_size = x_size; insn->x_stride = x_stride; @@ -1046,11 +1058,24 @@ class CommandQueue { // Make sure that we don't exceed contiguous physical memory limits CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER); +#ifdef USE_TSIM int timeout = VTADeviceRun( device_, insn_queue_.dram_phy_addr(), + uop_queue_.dram_phy_addr(), + inp_phy_addr_, + wgt_phy_addr_, + acc_phy_addr_, + out_phy_addr_, insn_queue_.count(), wait_cycles); +#else + int timeout = VTADeviceRun( + device_, + insn_queue_.dram_phy_addr(), + insn_queue_.count(), + wait_cycles); +#endif CHECK_EQ(timeout, 0); // Reset buffers uop_queue_.Reset(); @@ -1125,6 +1150,18 @@ class CommandQueue { ThreadLocal().reset(); } +#ifdef USE_TSIM + void SetBufPhyAddr(uint32_t type, vta_phy_addr_t addr) { + switch (type) { + case VTA_MEM_ID_INP: inp_phy_addr_ = addr; + case VTA_MEM_ID_WGT: wgt_phy_addr_ = addr; + case VTA_MEM_ID_ACC: acc_phy_addr_ = addr; + case VTA_MEM_ID_OUT: out_phy_addr_ = addr; + default: break; + } + } +#endif + private: // Push GEMM uop to the command buffer void PushGEMMOp(UopKernel* kernel) { @@ -1229,6 +1266,16 @@ class CommandQueue { InsnQueue insn_queue_; // Device handle VTADeviceHandle device_{nullptr}; +#ifdef USE_TSIM + // Input phy addr + vta_phy_addr_t inp_phy_addr_{0}; + // Weight phy addr + vta_phy_addr_t wgt_phy_addr_{0}; + // Accumulator phy addr + vta_phy_addr_t acc_phy_addr_{0}; + // Output phy addr + vta_phy_addr_t out_phy_addr_{0}; +#endif }; } // namespace vta @@ -1317,6 +1364,10 @@ void VTALoadBuffer2D(VTACommandHandle cmd, uint32_t y_pad_after, uint32_t dst_sram_index, uint32_t dst_memory_type) { +#ifdef USE_TSIM + vta::DataBuffer* src = vta::DataBuffer::FromHandle(src_dram_addr); + static_cast(cmd)->SetBufPhyAddr(dst_memory_type, src->phy_addr()); +#endif static_cast(cmd)-> LoadBuffer2D(src_dram_addr, src_elem_offset, x_size, y_size, x_stride, @@ -1333,6 +1384,10 @@ void VTAStoreBuffer2D(VTACommandHandle cmd, uint32_t x_size, uint32_t y_size, uint32_t x_stride) { +#ifdef USE_TSIM + vta::DataBuffer* dst = vta::DataBuffer::FromHandle(dst_dram_addr); + static_cast(cmd)->SetBufPhyAddr(src_memory_type, dst->phy_addr()); +#endif static_cast(cmd)-> StoreBuffer2D(src_sram_index, src_memory_type, dst_dram_addr, dst_elem_offset, diff --git a/vta/src/tsim/tsim_driver.cc b/vta/src/tsim/tsim_driver.cc new file mode 100644 index 000000000000..e0ceb9028503 --- /dev/null +++ b/vta/src/tsim/tsim_driver.cc @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +namespace vta { +namespace tsim { + +using vta::dpi::DPIModuleNode; +using tvm::runtime::Module; + +class DPILoader { + public: + void Init(Module module) { + mod_ = module; + } + + DPIModuleNode* Get() { + return static_cast(mod_.operator->()); + } + + static DPILoader* Global() { + static DPILoader inst; + return &inst; + } + + Module mod_; +}; + +class Device { + public: + Device() { + dpi_ = DPILoader::Global(); + } + + int Run(vta_phy_addr_t insn_phy_addr, + vta_phy_addr_t uop_phy_addr, + vta_phy_addr_t inp_phy_addr, + vta_phy_addr_t wgt_phy_addr, + vta_phy_addr_t acc_phy_addr, + vta_phy_addr_t out_phy_addr, + uint32_t insn_count, + uint32_t wait_cycles) { + this->Init(); + this->Launch(insn_phy_addr, + uop_phy_addr, + inp_phy_addr, + wgt_phy_addr, + acc_phy_addr, + out_phy_addr, + insn_count, + wait_cycles); + this->WaitForCompletion(wait_cycles); + dev_->Finish(); + return 0; + } + + private: + void Init() { + dev_ = dpi_->Get(); + } + + void Launch(vta_phy_addr_t insn_phy_addr, + vta_phy_addr_t uop_phy_addr, + vta_phy_addr_t inp_phy_addr, + vta_phy_addr_t wgt_phy_addr, + vta_phy_addr_t acc_phy_addr, + vta_phy_addr_t out_phy_addr, + uint32_t insn_count, + uint32_t wait_cycles) { + // launch simulation thread + dev_->Launch(wait_cycles); + dev_->WriteReg(0x10, insn_count); + dev_->WriteReg(0x14, insn_phy_addr); + dev_->WriteReg(0x18, insn_phy_addr >> 32); + dev_->WriteReg(0x1c, 0); + dev_->WriteReg(0x20, uop_phy_addr >> 32); + dev_->WriteReg(0x24, 0); + dev_->WriteReg(0x28, inp_phy_addr >> 32); + dev_->WriteReg(0x2c, 0); + dev_->WriteReg(0x30, wgt_phy_addr >> 32); + dev_->WriteReg(0x34, 0); + dev_->WriteReg(0x38, acc_phy_addr >> 32); + dev_->WriteReg(0x3c, 0); + dev_->WriteReg(0x40, out_phy_addr >> 32); + // start + dev_->WriteReg(0x00, 0x1); + } + + void WaitForCompletion(uint32_t wait_cycles) { + uint32_t i, val; + for (i = 0; i < wait_cycles; i++) { + val = dev_->ReadReg(0x00); + val &= 0x2; + if (val == 0x2) break; // finish + } + } + + DPILoader* dpi_; + DPIModuleNode* dev_; +}; + +using tvm::runtime::TVMRetValue; +using tvm::runtime::TVMArgs; + +TVM_REGISTER_GLOBAL("tvm.vta.tsim.init") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Module m = args[0]; + DPILoader::Global()->Init(m); + }); + +} // namespace tsim +} // namespace vta + +void* VTAMemAlloc(size_t size, int cached) { + void *p = malloc(size); + return p; +} + +void VTAMemFree(void* buf) { + free(buf); +} + +vta_phy_addr_t VTAMemGetPhyAddr(void* buf) { + return reinterpret_cast(reinterpret_cast(buf)); +} + +void VTAFlushCache(vta_phy_addr_t buf, int size) { +} + +void VTAInvalidateCache(vta_phy_addr_t buf, int size) { +} + +VTADeviceHandle VTADeviceAlloc() { + return new vta::tsim::Device(); +} + +void VTADeviceFree(VTADeviceHandle handle) { + delete static_cast(handle); +} + +int VTADeviceRun(VTADeviceHandle handle, + vta_phy_addr_t insn_phy_addr, + vta_phy_addr_t uop_phy_addr, + vta_phy_addr_t inp_phy_addr, + vta_phy_addr_t wgt_phy_addr, + vta_phy_addr_t acc_phy_addr, + vta_phy_addr_t out_phy_addr, + uint32_t insn_count, + uint32_t wait_cycles) { + return static_cast(handle)->Run( + insn_phy_addr, + uop_phy_addr, + inp_phy_addr, + wgt_phy_addr, + acc_phy_addr, + out_phy_addr, + insn_count, + wait_cycles); +} diff --git a/vta/tests/python/unittest/test_vta_insn.py b/vta/tests/python/unittest/test_vta_insn.py index 58835bbe3eab..2cedceae4e7d 100644 --- a/vta/tests/python/unittest/test_vta_insn.py +++ b/vta/tests/python/unittest/test_vta_insn.py @@ -68,6 +68,10 @@ def _run(env, remote): y_np = x_np.astype(y.dtype) x_nd = tvm.nd.array(x_np, ctx) y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) + + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + f(x_nd, y_nd) np.testing.assert_equal(y_np, y_nd.asnumpy()) @@ -126,6 +130,10 @@ def _run(env, remote): :] = x_np x_nd = tvm.nd.array(x_np, ctx) y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) + + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + f(x_nd, y_nd) np.testing.assert_equal(y_np, y_nd.asnumpy()) @@ -197,6 +205,9 @@ def verify(s): y_np = np.right_shift(y_np, 8) y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype) + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + if env.TARGET == "sim": simulator.clear_stats() f(x_nd, w_nd, y_nd) @@ -351,6 +362,10 @@ def check_alu(tvm_op, np_op=None, use_imm=False): a_nd = tvm.nd.array(a_np, ctx) res_nd = tvm.nd.array( np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) + + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + if use_imm: f(a_nd, res_nd) else: @@ -420,6 +435,10 @@ def _run(env, remote): a_nd = tvm.nd.array(a_np, ctx) res_nd = tvm.nd.array( np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) + + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + f(a_nd, res_nd) np.testing.assert_equal(res_np, res_nd.asnumpy()) @@ -479,6 +498,10 @@ def _run(env, remote): a_nd = tvm.nd.array(a_np, ctx) res_nd = tvm.nd.array( np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) + + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + f(a_nd, res_nd) np.testing.assert_equal(res_np, res_nd.asnumpy()) @@ -503,11 +526,12 @@ def _run(env, remote): print("Load/store test") test_save_load_out() print("Padded load test") - #test_padded_load() + test_padded_load() print("GEMM test") test_gemm() - test_alu() print("ALU test") + test_alu() + print("Relu test") test_relu() print("Shift and scale") test_shift_and_scale()