Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vperm: support lmul > 1 for vslideup/dn and vrgather #100

Merged
merged 1 commit into from
Jan 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 172 additions & 57 deletions src/main/scala/yunsuan/vector/VectorPerm/Permutation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,85 @@ import chisel3.util._
import chisel3.util.experimental.decode.TruthTable
import scala.language.{existentials, postfixOps}
import yunsuan.vector._
import chisel3.util.experimental.decode.{QMCMinimizer, TruthTable, decoder}

class slideupVs2VdTable() extends Module {
  // Converts the uop index of a vslideup instruction into the register
  // offsets of vs2 and vd. For a given LMUL the uops are enumerated as
  // (vs2 = i - j, vd = i) for i in [0, lmul) and j in [0, i], with uop index
  // i*(i+1)/2 + j, i.e. destination register i consumes i + 1 uops.
  val src = IO(Input(UInt(8.W)))            // {lmul (2b, log2 encoded), uopIdx (6b)}
  val outOffsetVs2 = IO(Output(UInt(3.W)))  // vs2 register offset (0..7)
  val outOffsetVd = IO(Output(UInt(3.W)))   // vd register offset (0..7)

  /** Software model: map (lmul, uopIdx) to (vs2 offset, vd offset).
    *
    * The uop indices i*(i+1)/2 + j are unique and contiguous over all valid
    * (i, j) pairs, so `collectFirst` yields the same result as the original
    * early-return search. Out-of-range uop indices map to (0, 0).
    */
  def compute_vs2_vd(lmul: Int, uopIdx: Int): (Int, Int) = {
    val mapping = for {
      i <- 0 until lmul
      j <- 0 to i
    } yield (i * (i + 1) / 2 + j) -> (i - j, i)
    mapping.collectFirst { case (idx, res) if idx == uopIdx => res }.getOrElse((0, 0))
  }

  // Enumerate every (lmul, uopIdx) combination. lmul is log2 encoded, so the
  // real multiplier is 1 << lmul; 36 uop indices cover the worst case
  // (lmul = 8 -> 8 * 9 / 2 = 36 uops).
  val combLmulUopIdx: Seq[(Int, Int, Int, Int)] = for {
    lmul <- 0 until 4
    uopIdx <- 0 until 36
  } yield {
    val (offsetVs2, offsetVd) = compute_vs2_vd(1 << lmul, uopIdx)
    (lmul, uopIdx, offsetVs2, offsetVd)
  }

  // Minimize the lookup into combinational logic:
  // input {lmul, uopIdx}, output {offsetVs2 (3b), offsetVd (3b)}.
  val out = decoder(QMCMinimizer, src, TruthTable(combLmulUopIdx.map {
    case (lmul, uopIdx, offsetVs2, offsetVd) =>
      (BitPat((lmul << 6 | uopIdx).U(8.W)), BitPat((offsetVs2 << 3 | offsetVd).U(6.W)))
  }, BitPat.N(6)))
  outOffsetVs2 := out(5, 3)
  outOffsetVd := out(2, 0)
}

class slidednVs2VdTable() extends Module {
  // Converts the uop index of a vslidedown instruction into the register
  // offsets of vs2 and vd, plus a flag marking the first (highest-vs2) uop
  // for each destination register.
  val src = IO(Input(UInt(8.W)))            // {lmul (2b, log2 encoded), uopIdx (6b)}
  val outOffsetVs2 = IO(Output(UInt(3.W)))  // vs2 register offset (0..7)
  val outOffsetVd = IO(Output(UInt(3.W)))   // vd register offset (0..7)
  val outIsFirst = IO(Output(Bool()))       // first uop writing this vd

  /** Software model: map (lmul, uopIdx) to (vs2 offset, vd offset, isFirst).
    *
    * Destination register i consumes lmul - i uops, issued from the highest
    * vs2 offset downwards; isFirst = 1 marks the j == lmul - i - 1 uop. The
    * generated uop indices are unique and contiguous, so `collectFirst`
    * matches the original early-return search. Out-of-range uop indices map
    * to (0, 0, 0).
    */
  def compute_vs2_vd(lmul: Int, uopIdx: Int): (Int, Int, Int) = {
    // NOTE: the original also computed uopNum = lmul * (lmul + 1) / 2 here,
    // but never used it; the dead local has been removed.
    val mapping = for {
      i <- 0 until lmul
      j <- 0 until lmul - i
    } yield {
      val prev = lmul * i - i * (i - 1) / 2
      (prev + lmul - i - j - 1) -> (j, i, if (j == lmul - i - 1) 1 else 0)
    }
    mapping.collectFirst { case (idx, res) if idx == uopIdx => res }.getOrElse((0, 0, 0))
  }

  // Enumerate every (lmul, uopIdx) combination; 36 uop indices cover the
  // worst case (lmul = 8 -> 36 uops).
  val combLmulUopIdx: Seq[(Int, Int, Int, Int, Int)] = for {
    lmul <- 0 until 4
    uopIdx <- 0 until 36
  } yield {
    val (offsetVs2, offsetVd, isFirst) = compute_vs2_vd(1 << lmul, uopIdx)
    (lmul, uopIdx, offsetVs2, offsetVd, isFirst)
  }

  // Minimize the lookup into combinational logic:
  // input {lmul, uopIdx}, output {isFirst (1b), offsetVs2 (3b), offsetVd (3b)}.
  val out = decoder(QMCMinimizer, src, TruthTable(combLmulUopIdx.map {
    case (lmul, uopIdx, offsetVs2, offsetVd, isFirst) =>
      (BitPat((lmul << 6 | uopIdx).U(8.W)), BitPat((isFirst << 6 | offsetVs2 << 3 | offsetVd).U(7.W)))
  }, BitPat.N(7)))
  outIsFirst := out(6).asBool
  outOffsetVs2 := out(5, 3)
  outOffsetVd := out(2, 0)
}

class Permutation extends Module {
val VLEN = 128
val xLen = 64
val LaneWidth = 64
val NLanes = VLEN / 64
val vlenb = VLEN / 8
val vlenbWidth = log2Ceil(vlenb)
val io = IO(new Bundle {
val in = Flipped(ValidIO(new VPermInput))
val out = Output(new VIFuOutput)
Expand All @@ -29,6 +101,7 @@ class Permutation extends Module {
val ma = io.in.bits.info.ma
val ta = io.in.bits.info.ta
val vlmul = io.in.bits.info.vlmul
val lmul = Mux(vlmul > 4.U, 0.U, vlmul)
val vstart = io.in.bits.info.vstart
val vl = io.in.bits.info.vl
val uopIdx = io.in.bits.info.uopIdx
Expand All @@ -37,7 +110,6 @@ class Permutation extends Module {
val vsew = srcTypeVs2(1, 0)
val vsew_plus1 = Wire(UInt(3.W))
vsew_plus1 := Cat(0.U(1.W), ~vsew) + 1.U
val signed = srcTypeVs2(3, 2) === 1.U
val widen = vdType(1, 0) === (srcTypeVs2(1, 0) + 1.U)
val vsew_bytes = 1.U << vsew
val vsew_bits = 8.U << vsew
Expand Down Expand Up @@ -244,62 +316,96 @@ class Permutation extends Module {
dontTouch(compressed_res)

val base = Wire(UInt(7.W))
val vmask0 = Mux(vcompress, vs1, vmask)
val vmask1 = Mux(vcompress, vs1 >> ele_cnt, vmask >> ele_cnt)
val vmask0 = vmask
val vmask_uop = Wire(UInt(VLEN.W))
val vmask_byte_strb = Wire(Vec(vlenb, UInt(1.W)))
val vs1_bytes = VecInit(Seq.tabulate(vlenb)(i => vs1((i + 1) * 8 - 1, i * 8)))
val vs2_bytes = VecInit(Seq.tabulate(vlenb)(i => vs2((i + 1) * 8 - 1, i * 8)))
val emul = vlmul(1, 0)
val evl = Mux1H(Seq.tabulate(4)(i => (emul === i.U) -> (ele_cnt << i.U)))

val vslideupOffset = Module(new slideupVs2VdTable)
vslideupOffset.src := Cat(lmul, uopIdx)
val vslideupVs2Id = vslideupOffset.outOffsetVs2
val vslideupVd2Id = vslideupOffset.outOffsetVd

val vslidednOffset = Module(new slidednVs2VdTable)
vslidednOffset.src := Cat(lmul, uopIdx)
val vslidednVs2Id = vslidednOffset.outOffsetVs2
val vslidednVd2Id = vslidednOffset.outOffsetVd

val vrgatherVdId = Mux1H(Seq(
(vlmul === 0.U) -> 0.U,
(vlmul === 1.U) -> uopIdx(1),
(vlmul === 2.U) -> uopIdx(3, 2),
(vlmul === 3.U) -> uopIdx(5, 3),
))

val vrgatherVs2Id = Mux1H(Seq(
(vlmul === 0.U) -> 0.U,
(vlmul === 1.U) -> uopIdx(0),
(vlmul === 2.U) -> uopIdx(1, 0),
(vlmul === 3.U) -> uopIdx(2, 0),
))

val vrgather16_sew8VdId = Mux1H(Seq(
(vlmul === 0.U) -> 0.U,
(vlmul === 1.U) -> uopIdx(2),
(vlmul === 2.U) -> uopIdx(4, 3),
))

val vrgather16_sew8Vs2Id = Mux1H(Seq(
(vlmul === 0.U) -> 0.U,
(vlmul === 1.U) -> uopIdx(1),
(vlmul === 2.U) -> uopIdx(2, 1),
))

val vdId = Mux1H(Seq(
((vrgather && !vrgather16_sew8) || vrgather_vx) -> vrgatherVdId,
vrgather16_sew8 -> vrgather16_sew8VdId,
(vslideup) -> vslideupVd2Id,
(vslidedn) -> vslidednVd2Id,
))
val vs2Id = Mux1H(Seq(
((vrgather && !vrgather16_sew8) || vrgather_vx) -> vrgatherVs2Id,
vrgather16_sew8 -> vrgather16_sew8Vs2Id,
(vslideup) -> vslideupVs2Id,
(vslidedn) -> vslidednVs2Id,
))

dontTouch(vdId)
dontTouch(vs2Id)

val vslideup_vl = Wire(UInt(8.W))
vlRemain := vslideup_vl
when((vcompress && uopIdx(1)) ||
(vslideup && ((uopIdx === 1.U) || (uopIdx === 2.U))) ||
(vslidedn && (uopIdx === 2.U)) ||
(((vrgather && !vrgather16_sew8) || vrgather_vx) && (uopIdx >= 2.U)) ||
(vrgather16_sew8 && (uopIdx >= 4.U))
) {
vlRemain := Mux(vslideup_vl >= ele_cnt, vslideup_vl - ele_cnt, 0.U)
}.elsewhen(vslide1up) {
when(vslide1up) {
vlRemain := Mux(vl >= (uopIdx << vsew_plus1), vl - (uopIdx << vsew_plus1), 0.U)
}.elsewhen(vslide1dn) {
vlRemain := Mux(vl >= (uopIdx(5, 1) << vsew_plus1), vl - (uopIdx(5, 1) << vsew_plus1), 0.U)
}.otherwise {
vlRemain := Mux1H(Seq.tabulate(8)(i => (vdId === i.U) -> Mux(vslideup_vl >= (ele_cnt * i.U), vslideup_vl - (ele_cnt * i.U), 0.U)))
}

vmask_uop := vmask0
when((vcompress && uopIdx(1)) ||
(vslideup && ((uopIdx === 1.U) || (uopIdx === 2.U))) ||
(vslidedn && (uopIdx === 2.U)) ||
(((vrgather && !vrgather16_sew8) || vrgather_vx) && (uopIdx >= 2.U)) ||
(vrgather16_sew8 && (uopIdx >= 4.U))
) {
vmask_uop := vmask1
}.elsewhen(vslide1up) {
when(vslide1up) {
vmask_uop := vmask >> (uopIdx << vsew_plus1)
}.elsewhen(vslide1dn) {
vmask_uop := vmask >> (uopIdx(5, 1) << vsew_plus1)
}

when((vcompress && (uopIdx === 3.U)) ||
(vslideup && (uopIdx === 1.U)) ||
(vslidedn && (uopIdx === 0.U) && (vlmul === 1.U))
) {
base := vlenb.U
}.otherwise {
base := 0.U
vmask_uop := Mux1H(Seq.tabulate(8)(i => (vdId === i.U) -> (vmask >> (ele_cnt * i.U))))
}

base := Mux1H(Seq.tabulate(8)(i => (vs2Id === i.U) -> (vlenb * i).U))

for (i <- 0 until vlenb) {
when(i.U < vlRemainBytes) {
vmask_byte_strb(i) := vmask_uop(i) | (vm & !vcompress)
vmask_byte_strb(i) := vmask_uop(i) | vm
when(vsew === 1.U(3.W)) {
vmask_byte_strb(i) := vmask_uop(i / 2) | (vm & !vcompress)
vmask_byte_strb(i) := vmask_uop(i / 2) | vm
}.elsewhen(vsew === 2.U(3.W)) {
vmask_byte_strb(i) := vmask_uop(i / 4) | (vm & !vcompress)
vmask_byte_strb(i) := vmask_uop(i / 4) | vm
}.elsewhen(vsew === 3.U(3.W)) {
vmask_byte_strb(i) := vmask_uop(i / 8) | (vm & !vcompress)
vmask_byte_strb(i) := vmask_uop(i / 8) | vm
}
}.otherwise {
vmask_byte_strb(i) := 0.U
Expand All @@ -309,9 +415,14 @@ class Permutation extends Module {
// vrgather/vrgather16
val vlmax_bytes = Wire(UInt(5.W))
val vrgather_byte_sel = Wire(Vec(vlenb, UInt(64.W)))
val first_gather = (vlmul >= 4.U) || (vlmul === 0.U) || ((vlmul === 1.U) && (Mux(vrgather16_sew8, uopIdx(1), uopIdx(0)) === 0.U))
val vs2_bytes_min = Mux((vrgather16_sew8 && uopIdx(1)) || (((vrgather && !vrgather16_sew8) || vrgather_vx) && uopIdx(0)), vlenb.U, 0.U)
val vs2_bytes_max = Mux((vrgather16_sew8 && uopIdx(1)) || (((vrgather && !vrgather16_sew8) || vrgather_vx) && uopIdx(0)), Cat(vlenb.U, 0.U), vlmax_bytes)
val first_gather = (vlmul >= 4.U) || vs2Id === 0.U
val vs2_bytes_min = Mux1H(Seq.tabulate(8)(i => (vs2Id === i.U) -> (vlenb * i).U))
val vs2_bytes_max = Mux1H(Seq(
(vs2Id === 0.U) -> vlmax_bytes,
) ++ (1 until 8).map(i => (vs2Id === i.U) -> (vlenb * (i + 1)).U))

dontTouch(vs2_bytes_min)
dontTouch(vs2_bytes_max)
val vrgather_vd = Wire(Vec(vlenb, UInt(8.W)))

vlmax_bytes := vlenb.U
Expand Down Expand Up @@ -354,10 +465,14 @@ class Permutation extends Module {
vrgather_byte_sel(i) := Cat(vs1((i / 4 + 1) * 16 - 1, i / 4 * 16), 0.U(2.W)) + i.U % 4.U
}
}.elsewhen(srcTypeVs2(1, 0) === 3.U) {
when(uopIdx(1) === 0.U) {
when(uopIdx(1, 0) === 0.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1) * 16 - 1, (i / 8) * 16), 0.U(3.W)) + i.U % 8.U
}.otherwise {
}.elsewhen(uopIdx(1, 0) === 1.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 2) * 16 - 1, (i / 8 + 2) * 16), 0.U(3.W)) + i.U % 8.U
}.elsewhen(uopIdx(1, 0) === 2.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 4) * 16 - 1, (i / 8 + 4) * 16), 0.U(3.W)) + i.U % 8.U
}.otherwise {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 6) * 16 - 1, (i / 8 + 6) * 16), 0.U(3.W)) + i.U % 8.U
}
}
}.elsewhen(srcTypeVs1(1, 0) === 2.U) {
Expand Down Expand Up @@ -394,10 +509,14 @@ class Permutation extends Module {
vrgather_byte_sel(i) := Cat(vs1((i / 4 + 1) * 16 - 1, i / 4 * 16), 0.U(2.W)) + i.U % 4.U
}
}.elsewhen(srcTypeVs2(1, 0) === 3.U) {
when(uopIdx(1) === 0.U) {
when(uopIdx(1, 0) === 0.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1) * 16 - 1, (i / 8) * 16), 0.U(3.W)) + i.U % 8.U
}.otherwise {
}.elsewhen(uopIdx(1, 0) === 1.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 2) * 16 - 1, (i / 8 + 2) * 16), 0.U(3.W)) + i.U % 8.U
}.elsewhen(uopIdx(1, 0) === 2.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 4) * 16 - 1, (i / 8 + 4) * 16), 0.U(3.W)) + i.U % 8.U
}.otherwise {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 6) * 16 - 1, (i / 8 + 6) * 16), 0.U(3.W)) + i.U % 8.U
}
}
}.elsewhen(srcTypeVs1(1, 0) === 2.U) {
Expand All @@ -413,7 +532,7 @@ class Permutation extends Module {
vrgather_vd(i) := Mux(ma, "hff".U, old_vd((i + 1) * 8 - 1, i * 8))
when(vmask_byte_strb(i).asBool) {
when((vrgather_byte_sel(i) >= vs2_bytes_min) && (vrgather_byte_sel(i) < vs2_bytes_max)) {
vrgather_vd(i) := vs2_bytes(vrgather_byte_sel(i.U) - vs2_bytes_min)
vrgather_vd(i) := vs2_bytes((vrgather_byte_sel(i) - vs2_bytes_min)(vlenbWidth - 1, 0))
}.elsewhen(first_gather) {
vrgather_vd(i) := 0.U
}.otherwise {
Expand All @@ -431,18 +550,19 @@ class Permutation extends Module {
val vslidedn_vd = Wire(Vec(vlenb, UInt(8.W)))
val vslide1dn_vd_wo_rs1 = Wire(Vec(vlenb, UInt(8.W)))
val vslide1dn_vd_rs1 = Wire(UInt(VLEN.W))
val first_slidedn = vslidedn && (uopIdx === 0.U || uopIdx === 2.U)
val first_slidedn = vslidedn && vslidednOffset.outIsFirst
val load_rs1 = (((vlmul >= 4.U) || (vlmul === 0.U)) && (uopIdx === 0.U)) ||
((vlmul === 1.U) && (uopIdx === 2.U)) ||
((vlmul === 2.U) && (uopIdx === 6.U)) ||
(uopIdx === 14.U)
val vslide1dn_vd = Mux((load_rs1 || uopIdx(0)), VecInit(Seq.tabulate(vlenb)(i => vslide1dn_vd_rs1((i + 1) * 8 - 1, i * 8))), vslide1dn_vd_wo_rs1)
dontTouch(base)

for (i <- 0 until vlenb) {
vslideup_vd(i) := Mux(ma, "hff".U, old_vd(i * 8 + 7, i * 8))
when(vmask_byte_strb(i).asBool) {
when(((base + i.U) >= slide_bytes) && ((base + i.U - slide_bytes) < vlenb.U)) {
vslideup_vd(i) := vs2_bytes(base + i.U - slide_bytes)
when(((base +& i.U) >= slide_bytes) && ((base +& i.U - slide_bytes) < vlmax_bytes)) {
vslideup_vd(i) := vs2_bytes((base +& i.U - slide_bytes)(vlenbWidth - 1, 0))
}.otherwise {
vslideup_vd(i) := old_vd(i * 8 + 7, i * 8)
}
Expand All @@ -452,8 +572,8 @@ class Permutation extends Module {
for (i <- 0 until vlenb) {
vslidedn_vd(i) := Mux(ma, "hff".U, old_vd(i * 8 + 7, i * 8))
when(vmask_byte_strb(i).asBool) {
when(((i.U + slide_bytes) >= base) && ((i.U + slide_bytes - base) < vlmax_bytes)) {
vslidedn_vd(i) := vs2_bytes(i.U + slide_bytes - base)
when(((i.U +& slide_bytes) >= base) && ((i.U +& slide_bytes - base) < vlmax_bytes)) {
vslidedn_vd(i) := vs2_bytes((i.U +& slide_bytes - base)(vlenbWidth - 1, 0))
}.elsewhen(first_slidedn) {
vslidedn_vd(i) := 0.U
}.otherwise {
Expand All @@ -466,7 +586,7 @@ class Permutation extends Module {
vslide1up_vd(i) := Mux(ma, "hff".U, old_vd(i * 8 + 7, i * 8))
when(vslide1up && (vmask_byte_strb(i) === 1.U)) {
when((i.U < vsew_bytes)) {
vslide1up_vd(i) := vs1_bytes(vlenb.U - vsew_bytes + i.U)
vslide1up_vd(i) := vs1_bytes((vlenb.U - vsew_bytes + i.U)(vlenbWidth - 1, 0))
}.otherwise {
vslide1up_vd(i) := vs2_bytes(i.U - vsew_bytes)
}
Expand Down Expand Up @@ -498,17 +618,12 @@ class Permutation extends Module {

val vslideup_vstart = Mux(vslideup & (slide_ele > vstart), Mux(slide_ele > VLEN.U, VLEN.U, slide_ele), vstart)
vstartRemain := vslideup_vstart
when((vcompress && (uopIdx === 3.U)) ||
((vslideup) && ((uopIdx === 1.U) || (uopIdx === 2.U))) ||
((vslidedn) && (uopIdx === 2.U)) ||
(((vrgather && !vrgather16_sew8) || vrgather_vx) && (uopIdx >= 2.U)) ||
(vrgather16_sew8 && (uopIdx >= 4.U))
) {
vstartRemain := Mux(vslideup_vstart >= ele_cnt, vslideup_vstart - ele_cnt, 0.U)
}.elsewhen(vslide1up) {
when(vslide1up) {
vstartRemain := Mux(vstart >= (uopIdx << vsew_plus1), vstart - (uopIdx << vsew_plus1), 0.U)
}.elsewhen(vslide1dn) {
vstartRemain := Mux(vstart >= (uopIdx(5, 1) << vsew_plus1), vstart - (uopIdx(5, 1) << vsew_plus1), 0.U)
}.otherwise {
vstartRemain := Mux1H(Seq.tabulate(8)(i => (vdId === i.U) -> Mux(vslideup_vstart >= (ele_cnt * i.U), vslideup_vstart - (ele_cnt * i.U), 0.U)))
}

val vd_reg = RegInit(0.U(VLEN.W))
Expand All @@ -523,9 +638,9 @@ class Permutation extends Module {
vd_reg := Cat(vslidedn_vd.reverse)
}.elsewhen(vslide1dn && fire) {
vd_reg := Cat(vslide1dn_vd.reverse)
}.elsewhen((vrgather || vrgather_vx) && !(vrgather16_sew8 && ((vlmul === 0.U) || (vlmul === 1.U))) && fire) {
}.elsewhen((vrgather || vrgather_vx) && !(vrgather16_sew8) && fire) {
vd_reg := Cat(vrgather_vd.reverse)
}.elsewhen(vrgather16_sew8 && (vlmul === 0.U) || (vlmul === 1.U) && fire) {
}.elsewhen(vrgather16_sew8 && fire) {
when(uopIdx(0)) {
vd_reg := Cat(Cat(vrgather_vd.reverse)(VLEN - 1, VLEN / 2), old_vd(VLEN / 2 - 1, 0))
}.otherwise {
Expand Down Expand Up @@ -565,10 +680,10 @@ class Permutation extends Module {
val tail_bytes = Mux((vlRemainBytes_reg >= vlenb.U), 0.U, vlenb.U - vlRemainBytes_reg)
val tail_bits = Cat(tail_bytes, 0.U(3.W))
val vmask_tail_bits = Wire(UInt(VLEN.W))
vmask_tail_bits := Mux(is_vmvnr_reg, vd_mask, vd_mask >> tail_bits)
vmask_tail_bits := vd_mask >> tail_bits
val tail_old_vd = old_vd_reg & (~vmask_tail_bits)
val tail_ones_vd = ~vmask_tail_bits
val tail_vd = Mux(is_vmvnr_reg, 0.U, Mux(ta_reg, tail_ones_vd, tail_old_vd))
val tail_vd = Mux(ta_reg, tail_ones_vd, tail_old_vd)
val perm_tail_mask_vd = Wire(UInt(VLEN.W))

val vstart_bytes = Mux(vstartRemainBytes_reg >= vlenb.U, vlenb.U, vstartRemainBytes_reg)
Expand Down
Loading