diff --git a/src/main/scala/yunsuan/encoding/Opcode/VialuOpcode.scala b/src/main/scala/yunsuan/encoding/Opcode/VialuOpcode.scala
index a24a842..987cba8 100644
--- a/src/main/scala/yunsuan/encoding/Opcode/VialuOpcode.scala
+++ b/src/main/scala/yunsuan/encoding/Opcode/VialuOpcode.scala
@@ -39,4 +39,14 @@ object VialuOpcode {
   def vssrl = 31.U(width.W)
   def vssra = 32.U(width.W)
   def vmvsx = 46.U(width.W)
+  // Zvbb (values 48..55 mirror the constants added to VAluOpcode in VAluDecode.scala)
+  def vcpop = 39.U(width.W)
+  def vbrev = 48.U(width.W)
+  def vbrev8 = 49.U(width.W)
+  def vrev8 = 50.U(width.W)
+  def vclz = 51.U(width.W)
+  def vctz = 52.U(width.W)
+  def vrol = 53.U(width.W)
+  def vror = 54.U(width.W)
+  def vwsll = 55.U(width.W)
 }
diff --git a/src/main/scala/yunsuan/package.scala b/src/main/scala/yunsuan/package.scala
index 9d74f6c..1a20c38 100644
--- a/src/main/scala/yunsuan/package.scala
+++ b/src/main/scala/yunsuan/package.scala
@@ -176,6 +176,17 @@ package object yunsuan {
   def vmorn_mm = LiteralCat(FMT.MMM , UINT, VialuOpcode.vorn) // "b10_0_001101".U(OpTypeWidth.W) // vorn
   def vmxnor_mm = LiteralCat(FMT.MMM , UINT, VialuOpcode.vxnor) // "b10_0_001110".U(OpTypeWidth.W) // vxnor
   def vmv_s_x = LiteralCat(FMT.ZXV , SINT, VialuOpcode.vmvsx) // "b00_1_101110".U(OpTypeWidth.W) // vmvsx
+  // Zvbb (the low six bits in each trailing comment are the VialuOpcode value in binary)
+  def vandn_vv = LiteralCat(FMT.VVV , UINT, VialuOpcode.vandn) // "b00_0_001001".U(OpTypeWidth.W) // vandn
+  def vbrev_v = LiteralCat(FMT.VVV , UINT, VialuOpcode.vbrev) // "b00_0_110000".U(OpTypeWidth.W) // vbrev
+  def vbrev8_v = LiteralCat(FMT.VVV , UINT, VialuOpcode.vbrev8) // "b00_0_110001".U(OpTypeWidth.W) // vbrev8
+  def vrev8_v = LiteralCat(FMT.VVV , UINT, VialuOpcode.vrev8) // "b00_0_110010".U(OpTypeWidth.W) // vrev8
+  def vclz_v = LiteralCat(FMT.VVV , UINT, VialuOpcode.vclz) // "b00_0_110011".U(OpTypeWidth.W) // vclz
+  def vctz_v = LiteralCat(FMT.VVV , UINT, VialuOpcode.vctz) // "b00_0_110100".U(OpTypeWidth.W) // vctz
+  def vcpop_v = LiteralCat(FMT.VVV , UINT, VialuOpcode.vcpop) // "b00_0_100111".U(OpTypeWidth.W) // vcpop
+  def vrol_vv = LiteralCat(FMT.VVV , UINT, VialuOpcode.vrol) // "b00_0_110101".U(OpTypeWidth.W) // vrol
+  def vror_vv = LiteralCat(FMT.VVV , UINT, VialuOpcode.vror) // "b00_0_110110".U(OpTypeWidth.W) // vror
+  def vwsll_vv = LiteralCat(FMT.WVW , UINT, VialuOpcode.vwsll) // "b10_0_110111".U(OpTypeWidth.W) // vwsll
 
   def getOpcode(fuOpType: UInt) : UInt = fuOpType(5, 0)
 
diff --git a/src/main/scala/yunsuan/vector/VectorALU/VAluBundles.scala b/src/main/scala/yunsuan/vector/VectorALU/VAluBundles.scala
index e908cd8..4c129f1 100644
--- a/src/main/scala/yunsuan/vector/VectorALU/VAluBundles.scala
+++ b/src/main/scala/yunsuan/vector/VectorALU/VAluBundles.scala
@@ -52,6 +52,7 @@ class VAluOpcode extends Bundle{
   def isIntFixp = op < vredsum || op === vmvsx
   def isVmvsx = op === vmvsx
   def isVmvxs = op === vmvxs
+  def isVmergeMove = op === vmerge || op === vmv || op === vmvsx // ops whose result is the mergeMove path
   // IMac opcode:
   def op3b = op(2, 0)
   def highHalf = op3b === 1.U
@@ -59,6 +60,18 @@ class VAluOpcode extends Bundle{
   def isSub = op3b === 3.U || op3b === 5.U
   def isFixP = op3b === 6.U
   def overWriteMultiplicand = op3b === 4.U || op3b === 5.U
+  // Zvbb opcode: single-op predicates first, then the groups used by the VIntMisc64b output select
+  def isVbrev = op === vbrev
+  def isVbrev8 = op === vbrev8
+  def isVrev8 = op === vrev8
+  def isClz = op === vclz
+  def isCtz = op === vctz
+  def isVrol = op === vrol
+  def isVror = op === vror
+  def isVwsll = op === vwsll
+  def isVrev = op === vbrev || op === vbrev8 || op === vrev8
+  def isVCount = op === vclz || op === vctz || op === vcpop
+  def isVro = op === vrol || op === vror
 }
 
 class VIFuInfo extends Bundle {
diff --git a/src/main/scala/yunsuan/vector/VectorALU/VAluDecode.scala b/src/main/scala/yunsuan/vector/VectorALU/VAluDecode.scala
index 682172e..965ee99 100644
--- a/src/main/scala/yunsuan/vector/VectorALU/VAluDecode.scala
+++ b/src/main/scala/yunsuan/vector/VectorALU/VAluDecode.scala
@@ -54,6 +54,15 @@ object VAluOpcode {
   val vid = 45.U(6.W)
   val vmvsx = 46.U(6.W)
   val vmvxs = 47.U(6.W)
+  // Zvbb (same numeric values as VialuOpcode.vbrev .. VialuOpcode.vwsll)
+  val vbrev = 48.U(6.W)
+  val vbrev8 = 49.U(6.W)
+  val vrev8 = 50.U(6.W)
+  val vclz = 51.U(6.W)
+  val vctz = 52.U(6.W)
+  val vrol = 53.U(6.W)
+  val vror = 54.U(6.W)
+  val vwsll = 55.U(6.W)
 }
 
 import VAluOpcode._
diff --git a/src/main/scala/yunsuan/vector/VectorALU/VIntMisc64b.scala b/src/main/scala/yunsuan/vector/VectorALU/VIntMisc64b.scala
index 492c48e..0bc4e7f 100644
--- a/src/main/scala/yunsuan/vector/VectorALU/VIntMisc64b.scala
+++ b/src/main/scala/yunsuan/vector/VectorALU/VIntMisc64b.scala
@@ -18,6 +18,7 @@ import chisel3._
 import chisel3.util._
 import yunsuan.vector._
 import yunsuan.vector.alu.VAluOpcode._
+import yunsuan.vector.VectorConvert.util.CLZ
 
 class VIntMisc64b extends Module {
   val io = IO(new Bundle {
@@ -195,8 +196,355 @@ class VIntMisc64b extends Module {
   }
   val mergeMove = Mux(vm || opcode.isVmvsx, vs1, mergeResult.asUInt)
 
+  /**
+    * Zvbb vbrev.v vbrev8.v vrev8.v: vbrev reverses the bits of each element, vbrev8 the bits of each byte, vrev8 the bytes of each element
+    */
+  val revResult = Wire(UInt(64.W))
+  val brevResult_8 = Wire(Vec(8, UInt(8.W)))
+  val brevResult_8_tmp = Wire(Vec(8, UInt(8.W)))
+  val brevResult_16 = Wire(Vec(4, UInt(16.W)))
+  val brevResult_16_tmp = Wire(Vec(4, UInt(16.W)))
+  val brevResult_32 = Wire(Vec(2, UInt(32.W)))
+  val brevResult_32_tmp = Wire(Vec(2, UInt(32.W)))
+  val brevResult_64 = Wire(Vec(1, UInt(64.W)))
+  val brevResult_64_tmp = Wire(Vec(1, UInt(64.W)))
+  brevResult_8_tmp := vs2.asTypeOf(brevResult_8_tmp)
+  brevResult_16_tmp := vs2.asTypeOf(brevResult_16_tmp)
+  brevResult_32_tmp := vs2.asTypeOf(brevResult_32_tmp)
+  brevResult_64_tmp := vs2.asTypeOf(brevResult_64_tmp)
+
+  for (i <- 0 until 8) {
+    brevResult_8(i) := VecInit(brevResult_8_tmp(i).asBools.reverse).asUInt
+  }
+  for (i <- 0 until 4) {
+    brevResult_16(i) := VecInit(brevResult_16_tmp(i).asBools.reverse).asUInt
+  }
+  for (i <- 0 until 2) {
+    brevResult_32(i) := VecInit(brevResult_32_tmp(i).asBools.reverse).asUInt
+  }
+  for (i <- 0 until 1) {
+    brevResult_64(i) := VecInit(brevResult_64_tmp(i).asBools.reverse).asUInt
+  }
+
+  val brev8Result = Wire(Vec(8, UInt(8.W)))
+  val brev8Result_tmp = Wire(Vec(8, UInt(8.W)))
+  brev8Result_tmp := vs2.asTypeOf(brev8Result_tmp)
+  for (i <- 0 until 8) {
+    brev8Result(i) := VecInit(brev8Result_tmp(i).asBools.reverse).asUInt
+  }
+
+  val rev8Result_16 = Wire(Vec(4, Vec(2, UInt(8.W))))
+  val rev8Result_16_tmp = Wire(Vec(4, Vec(2, UInt(8.W))))
+  val rev8Result_32 = Wire(Vec(2, Vec(4, UInt(8.W))))
+  val rev8Result_32_tmp = Wire(Vec(2, Vec(4, UInt(8.W))))
+  val rev8Result_64 = Wire(Vec(8, UInt(8.W)))
+  val rev8Result_64_tmp = Wire(Vec(8, UInt(8.W)))
+  rev8Result_16_tmp := vs2.asTypeOf(rev8Result_16_tmp)
+  rev8Result_32_tmp := vs2.asTypeOf(rev8Result_32_tmp)
+  rev8Result_64_tmp := vs2.asTypeOf(rev8Result_64_tmp)
+
+  for (i <- 0 until 4) {
+    for (j <- 0 until 2) {
+      rev8Result_16(i)(1-j) := rev8Result_16_tmp(i)(j)
+    }
+  }
+  for (i <- 0 until 2) {
+    for (j <- 0 until 4) {
+      rev8Result_32(i)(3-j) := rev8Result_32_tmp(i)(j)
+    }
+  }
+  for (i <- 0 until 8) {
+    rev8Result_64(7-i) := rev8Result_64_tmp(i)
+  }
+
+  revResult := Mux1H(
+    Seq(
+      opcode.isVbrev && eewVd.is8,
+      opcode.isVbrev && eewVd.is16,
+      opcode.isVbrev && eewVd.is32,
+      opcode.isVbrev && eewVd.is64,
+      opcode.isVbrev8,
+      opcode.isVrev8 && eewVd.is8,
+      opcode.isVrev8 && eewVd.is16,
+      opcode.isVrev8 && eewVd.is32,
+      opcode.isVrev8 && eewVd.is64,
+    ),
+    Seq(
+      brevResult_8.asUInt,
+      brevResult_16.asUInt,
+      brevResult_32.asUInt,
+      brevResult_64.asUInt,
+      brev8Result.asUInt,
+      vs2, // vrev8 with 8-bit elements: a single byte has nothing to swap
+      rev8Result_16.asUInt,
+      rev8Result_32.asUInt,
+      rev8Result_64.asUInt,
+    )
+  )
+
+  /**
+    * vclz.v / vctz.v / vcpop.v. vctz is computed as CLZ of the bit-reversed element.
+    * CLZ units are shared: a narrow element is left-aligned in a wider unit with a sentinel 1 directly
+    * below it, so the count stops at the element width. Counts are zero-extended to the element width.
+    */
+  val countResult = Wire(UInt(64.W))
+  val countResult_8 = Wire(Vec(4, UInt(8.W)))
+  val countResult_16 = Wire(Vec(2, UInt(16.W)))
+  val countResult_32 = Wire(UInt(32.W))
+  val countResult_64 = Wire(UInt(64.W))
+  val pop_8 = Wire(Vec(8, UInt(8.W)))
+  val pop_16 = Wire(Vec(4, UInt(16.W)))
+  val pop_32 = Wire(Vec(2, UInt(32.W)))
+  val pop_64 = Wire(Vec(1, UInt(64.W)))
+  val cnt8 = Wire(Vec(8, UInt(8.W)))
+  val cnt16 = Wire(Vec(4, UInt(16.W))) // element-wide so that cnt16.asUInt places one count per 16-bit lane
+  val cnt32 = Wire(Vec(2, UInt(32.W))) // element-wide so that cnt32.asUInt places one count per 32-bit lane
+  val cnt64 = Wire(Vec(1, UInt(64.W)))
+
+  pop_8 := vs2.asTypeOf(pop_8)
+  pop_16 := vs2.asTypeOf(pop_16)
+  pop_32 := vs2.asTypeOf(pop_32)
+  pop_64 := vs2.asTypeOf(pop_64)
+
+  for (i <- 0 until 4) {
+    countResult_8(i) := Mux(opcode.isClz, vs2(8*i+7, 8*i), VecInit(vs2(8*i+7, 8*i).asBools.reverse).asUInt)
+  }
+  for (i <- 0 until 2) {
+    countResult_16(i) := Mux1H(
+      Seq(
+        eewVd.is8,
+        eewVd.is16,
+      ),
+      Seq(
+        (Mux(opcode.isClz, vs2(8*i+7+32,8*i+32), VecInit(vs2(8*i+7+32,8*i+32).asBools.reverse).asUInt) << 8) | (1.U << 7), // bytes 4,5: reverse first, then left-align
+        Mux(opcode.isClz, vs2(16*i+15,16*i), VecInit(vs2(16*i+15,16*i).asBools.reverse).asUInt),
+      )
+    )
+  }
+
+  countResult_32 := Mux1H(
+    Seq(
+      eewVd.is8,
+      eewVd.is16,
+      eewVd.is32,
+    ),
+    Seq(
+      (Mux(opcode.isClz, vs2(55, 48), VecInit(vs2(55, 48).asBools.reverse).asUInt) << 24) | (1.U << 23), // byte 6
+      (Mux(opcode.isClz, vs2(47, 32), VecInit(vs2(47, 32).asBools.reverse).asUInt) << 16) | (1.U << 15), // halfword 2
+      Mux(opcode.isClz, vs2(31, 0), VecInit(vs2(31, 0).asBools.reverse).asUInt),
+    )
+  )
+  countResult_64 := Mux1H(
+    Seq(
+      eewVd.is8,
+      eewVd.is16,
+      eewVd.is32,
+      eewVd.is64,
+    ),
+    Seq(
+      (Mux(opcode.isClz, vs2(63, 56), VecInit(vs2(63, 56).asBools.reverse).asUInt) << 56) | (1.U << 55), // byte 7
+      (Mux(opcode.isClz, vs2(63, 48), VecInit(vs2(63, 48).asBools.reverse).asUInt) << 48) | (1.U << 47), // halfword 3
+      (Mux(opcode.isClz, vs2(63, 32), VecInit(vs2(63, 32).asBools.reverse).asUInt) << 32) | (1.U << 31), // word 1
+      Mux(opcode.isClz, vs2, VecInit(vs2.asBools.reverse).asUInt),
+    )
+  )
+
+  val cnt16_0_tmp = CLZ(countResult_16(0))
+  val cnt16_1_tmp = CLZ(countResult_16(1))
+  val cnt32_tmp = CLZ(countResult_32)
+  val cnt64_tmp = CLZ(countResult_64)
+
+  for (i <- 0 until 4) {
+    cnt8(i) := Mux(opcode.isVcpop, PopCount(pop_8(i)), CLZ(countResult_8(i)))
+  }
+  cnt8(4) := Mux(opcode.isVcpop, PopCount(pop_8(4)), cnt16_0_tmp)
+  cnt8(5) := Mux(opcode.isVcpop, PopCount(pop_8(5)), cnt16_1_tmp)
+  cnt8(6) := Mux(opcode.isVcpop, PopCount(pop_8(6)), cnt32_tmp)
+  cnt8(7) := Mux(opcode.isVcpop, PopCount(pop_8(7)), cnt64_tmp)
+  cnt16(0) := Mux(opcode.isVcpop, PopCount(pop_16(0)), cnt16_0_tmp)
+  cnt16(1) := Mux(opcode.isVcpop, PopCount(pop_16(1)), cnt16_1_tmp)
+  cnt16(2) := Mux(opcode.isVcpop, PopCount(pop_16(2)), cnt32_tmp)
+  cnt16(3) := Mux(opcode.isVcpop, PopCount(pop_16(3)), cnt64_tmp)
+  cnt32(0) := Mux(opcode.isVcpop, PopCount(pop_32(0)), cnt32_tmp)
+  cnt32(1) := Mux(opcode.isVcpop, PopCount(pop_32(1)), cnt64_tmp)
+  cnt64(0) := Mux(opcode.isVcpop, PopCount(pop_64(0)), cnt64_tmp)
+
+  countResult := Mux1H(
+    Seq(
+      opcode.isVCount && eewVd.is8,
+      opcode.isVCount && eewVd.is16,
+      opcode.isVCount && eewVd.is32,
+      opcode.isVCount && eewVd.is64,
+    ),
+    Seq(
+      cnt8.asUInt,
+      cnt16.asUInt,
+      cnt32.asUInt,
+      cnt64.asUInt,
+    )
+  )
+
+  val vroResult = Wire(UInt(64.W)) // Zvbb vrol / vror result, rotated per element
+  val vroResult_8 = Wire(Vec(8, UInt(8.W)))
+  val vroResult_16 = Wire(Vec(4, UInt(16.W)))
+  val vroResult_32 = Wire(Vec(2, UInt(32.W)))
+  val vroResult_64 = Wire(Vec(1, UInt(64.W)))
+  vroResult_8 := vs2.asTypeOf(vroResult_8)
+  vroResult_16 := vs2.asTypeOf(vroResult_16)
+  vroResult_32 := vs2.asTypeOf(vroResult_32)
+  vroResult_64 := vs2.asTypeOf(vroResult_64)
+  val vroShift8 = Wire(Vec(8, UInt(3.W))) // rotate amount: low log2(SEW) bits of each vs1 element
+  val vroShift16 = Wire(Vec(4, UInt(4.W)))
+  val vroShift32 = Wire(Vec(2, UInt(5.W)))
+  val vroShift64 = Wire(Vec(1, UInt(6.W)))
+  val vroShift8_neg = Wire(Vec(8, UInt(3.W))) // two's-complement negation, i.e. (SEW - amount) mod SEW
+  val vroShift16_neg = Wire(Vec(4, UInt(4.W)))
+  val vroShift32_neg = Wire(Vec(2, UInt(5.W)))
+  val vroShift64_neg = Wire(Vec(1, UInt(6.W)))
+
+  for (i <- 0 until 8) {
+    vroShift8(i) := vs1(8*i+2, 8*i)
+    vroShift8_neg(i) := (~vs1(8*i+2, 8*i)).asUInt + 1.U
+  }
+  for (i <- 0 until 4) {
+    vroShift16(i) := vs1(16*i+3, 16*i)
+    vroShift16_neg(i) := (~vs1(16*i+3, 16*i)).asUInt + 1.U
+  }
+  for (i <- 0 until 2) {
+    vroShift32(i) := vs1(32*i+4, 32*i)
+    vroShift32_neg(i) := (~vs1(32*i+4, 32*i)).asUInt + 1.U
+  }
+  for (i <- 0 until 1) {
+    vroShift64(i) := vs1(64*i+5, 64*i)
+    vroShift64_neg(i) := (~vs1(64*i+5, 64*i)).asUInt + 1.U
+  }
+
+  val vroResult_8_tmp = Wire(Vec(8, UInt(8.W)))
+  val vroResult_16_tmp = Wire(Vec(4, UInt(16.W)))
+  val vroResult_32_tmp = Wire(Vec(2, UInt(32.W)))
+  val vroResult_64_tmp = Wire(Vec(1, UInt(64.W)))
+
+  // rotate = (left shift) | (right shift by the negated amount); vs2 << vs1 is equal to (vs2.reverse >> vs1).reverse -- NOTE(review): relies on shiftOneElement(...)._1 being a logical right shift
+  for (i <- 0 until 8) {
+    vroResult_8_tmp(i) := Mux1H(
+      Seq(
+        opcode.isVrol,
+        opcode.isVror,
+      ),
+      Seq(
+        VecInit(shiftOneElement(vroShift8(i), VecInit(vroResult_8(i).asBools.reverse).asUInt, 8)._1.asBools.reverse).asUInt | shiftOneElement(vroShift8_neg(i), vroResult_8(i), 8)._1,
+        VecInit(shiftOneElement(vroShift8_neg(i), VecInit(vroResult_8(i).asBools.reverse).asUInt, 8)._1.asBools.reverse).asUInt | shiftOneElement(vroShift8(i), vroResult_8(i), 8)._1,
+      )
+    )
+  }
+  for (i <- 0 until 4) {
+    vroResult_16_tmp(i) := Mux1H(
+      Seq(
+        opcode.isVrol,
+        opcode.isVror,
+      ),
+      Seq(
+        VecInit(shiftOneElement(vroShift16(i), VecInit(vroResult_16(i).asBools.reverse).asUInt, 16)._1.asBools.reverse).asUInt | shiftOneElement(vroShift16_neg(i), vroResult_16(i), 16)._1,
+        VecInit(shiftOneElement(vroShift16_neg(i), VecInit(vroResult_16(i).asBools.reverse).asUInt, 16)._1.asBools.reverse).asUInt | shiftOneElement(vroShift16(i), vroResult_16(i), 16)._1,
+      )
+    )
+  }
+  for (i <- 0 until 2) {
+    vroResult_32_tmp(i) := Mux1H(
+      Seq(
+        opcode.isVrol,
+        opcode.isVror,
+      ),
+      Seq(
+        VecInit(shiftOneElement(vroShift32(i), VecInit(vroResult_32(i).asBools.reverse).asUInt, 32)._1.asBools.reverse).asUInt | shiftOneElement(vroShift32_neg(i), vroResult_32(i), 32)._1,
+        VecInit(shiftOneElement(vroShift32_neg(i), VecInit(vroResult_32(i).asBools.reverse).asUInt, 32)._1.asBools.reverse).asUInt | shiftOneElement(vroShift32(i), vroResult_32(i), 32)._1,
+      )
+    )
+  }
+  for (i <- 0 until 1) {
+    vroResult_64_tmp(i) := Mux1H(
+      Seq(
+        opcode.isVrol,
+        opcode.isVror,
+      ),
+      Seq(
+        VecInit(shiftOneElement(vroShift64(i), VecInit(vroResult_64(i).asBools.reverse).asUInt, 64)._1.asBools.reverse).asUInt | shiftOneElement(vroShift64_neg(i), vroResult_64(i), 64)._1,
+        VecInit(shiftOneElement(vroShift64_neg(i), VecInit(vroResult_64(i).asBools.reverse).asUInt, 64)._1.asBools.reverse).asUInt | shiftOneElement(vroShift64(i), vroResult_64(i), 64)._1,
+      )
+    )
+  }
+  vroResult := Mux1H(
+    Seq(
+      eewVd.is8,
+      eewVd.is16,
+      eewVd.is32,
+      eewVd.is64,
+    ),
+    Seq(
+      vroResult_8_tmp.asUInt,
+      vroResult_16_tmp.asUInt,
+      vroResult_32_tmp.asUInt,
+      vroResult_64_tmp.asUInt,
+    )
+  )
+
+  /**
+    * vwsll.vv vwsll.vx vwsll.vi: zero-extend each source element (low 32 bits of vs2) to twice its width, then shift it left
+    */
+  val wsllResult = Wire(UInt(64.W))
+  val wsllResult_8 = Wire(Vec(4, UInt(8.W)))
+  val wsllResult_16 = Wire(Vec(2, UInt(16.W)))
+  val wsllResult_32 = Wire(Vec(1, UInt(32.W)))
+  wsllResult_8 := vs2.asTypeOf(wsllResult_8)
+  wsllResult_16 := vs2.asTypeOf(wsllResult_16)
+  wsllResult_32 := vs2.asTypeOf(wsllResult_32)
+
+  val wsllResult_8_tmp = Wire(Vec(4, UInt(16.W)))
+  val wsllResult_16_tmp = Wire(Vec(2, UInt(32.W)))
+  val wsllResult_32_tmp = Wire(Vec(1, UInt(64.W)))
+  for (i <- 0 until 4) {
+    wsllResult_8_tmp(i) := VecInit(shiftOneElement(vs1(8*i+3, 8*i), VecInit(Cat(Fill(8, 0.U), wsllResult_8(i)).asBools.reverse).asUInt, 16)._1.asBools.reverse).asUInt
+  }
+  for (i <- 0 until 2) {
+    wsllResult_16_tmp(i) := VecInit(shiftOneElement(vs1(16*i+4, 16*i), VecInit(Cat(Fill(16, 0.U), wsllResult_16(i)).asBools.reverse).asUInt, 32)._1.asBools.reverse).asUInt
+  }
+  for (i <- 0 until 1) {
+    wsllResult_32_tmp(i) := VecInit(shiftOneElement(vs1(32*i+5, 32*i), VecInit(Cat(Fill(32, 0.U), wsllResult_32(i)).asBools.reverse).asUInt, 64)._1.asBools.reverse).asUInt
+  }
+  wsllResult := Mux1H(
+    Seq(
+      opcode.isVwsll && eewVd.is16,
+      opcode.isVwsll && eewVd.is32,
+      opcode.isVwsll && eewVd.is64,
+    ),
+    Seq(
+      wsllResult_8_tmp.asUInt,
+      wsllResult_16_tmp.asUInt,
+      wsllResult_32_tmp.asUInt,
+    )
+  )
+
   // Output arbiter
-  io.vd := Mux(opcode.isShift, shiftResult,
-           Mux(opcode.isVext, extResult,
-           Mux(opcode.isBitLogical, bitLogical, mergeMove)))
+  io.vd := Mux1H( // NOTE(review): unlike the old priority Mux there is no mergeMove default; confirm the selectors are mutually exclusive and cover every opcode routed here
+    Seq(
+      opcode.isShift,
+      opcode.isVext,
+      opcode.isBitLogical,
+      opcode.isVmergeMove,
+      opcode.isVrev,
+      opcode.isVCount,
+      opcode.isVro,
+      opcode.isVwsll,
+    ),
+    Seq(
+      shiftResult,
+      extResult,
+      bitLogical,
+      mergeMove,
+      revResult,
+      countResult,
+      vroResult,
+      wsllResult,
+    )
+  )
 }