From 8cafc1b4dbb6c89e5f55358717e1eb34f0e2fa04 Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Sat, 27 Jul 2019 13:39:37 -0700 Subject: [PATCH] [VTA] [Chisel] fix tensor issue/commit in gemm (#3637) * fix tensor issue/commit in gemm * remove trailing spaces --- vta/hardware/chisel/src/main/scala/core/TensorGemm.scala | 8 ++++---- vta/hardware/chisel/src/main/scala/core/TensorStore.scala | 2 +- vta/tests/hardware/metal_test/metal_test.cc | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala index bfa79dd600ef..051e0114190c 100644 --- a/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala +++ b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala @@ -39,10 +39,10 @@ class MAC(aBits: Int = 8, bBits: Int = 8, cBits: Int = 16) extends Module { val rA = RegNext(io.a) val rB = RegNext(io.b) val rC = RegNext(io.c) - + mult := rA * rB add := rC +& mult - + io.y := add } @@ -226,9 +226,9 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module when (state === sIdle) { inflight := 0.U } .elsewhen (!dec.reset) { - when (state === sExe && inflight =/= ((1 << pBits) - 1).asUInt) { // overflow check + when (state === sReadTensor) { // issue a tensor inflight := inflight + 1.U - } .elsewhen (mvc.io.acc_o.data.valid && inflight =/= 0.U) { // underflow check + } .elsewhen (mvc.io.acc_o.data.valid) { // commit a tensor inflight := inflight - 1.U } } diff --git a/vta/hardware/chisel/src/main/scala/core/TensorStore.scala b/vta/hardware/chisel/src/main/scala/core/TensorStore.scala index 1c6be03ba037..af53ba7b5409 100644 --- a/vta/hardware/chisel/src/main/scala/core/TensorStore.scala +++ b/vta/hardware/chisel/src/main/scala/core/TensorStore.scala @@ -28,7 +28,7 @@ import vta.shell._ * * Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM). */ -class TensorStore(tensorType: String = "true", debug: Boolean = false) +class TensorStore(tensorType: String = "none", debug: Boolean = false) (implicit p: Parameters) extends Module { val tp = new TensorParams(tensorType) val mp = p(ShellKey).memParams diff --git a/vta/tests/hardware/metal_test/metal_test.cc b/vta/tests/hardware/metal_test/metal_test.cc index 9c4600716996..1dec5d2de660 100644 --- a/vta/tests/hardware/metal_test/metal_test.cc +++ b/vta/tests/hardware/metal_test/metal_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY