Skip to content

Commit

Permalink
vfcvt: fix timing
Browse files: browse the repository at this point in the history
  • Loading branch information
sinceforYy committed Mar 14, 2024
1 parent cc7d3c4 commit b2a8aec
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 142 deletions.
168 changes: 84 additions & 84 deletions src/main/scala/yunsuan/vector/VectorConvert/CVT32.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ import yunsuan.util._


class CVT32(width: Int = 32) extends CVT(width){
/** cycle0 | cycle1 | cycle2
* fp2int in(32) raw_in(33) left,right ShiftRightJam(25) | RoundingUnit(11) adder |
* int2fp in(32) in_abs(32) lzc in_shift exp_raw | RoundingUnit(10) adder | -> result & fflags
* vfr in(32) lzc exp_nor sig_nor clz out_exp adder | Table |
/**          cycle0                             | cycle1                                        | cycle2
 *  fp2int   in(32) raw_in(33) left,right       | ShiftRightJam(25) RoundingUnit(11) adder      |
 *  int2fp   in(32) in_abs(32) lzc              | in_shift exp_raw RoundingUnit(10) adder       | -> result & fflags
 *  vfr      in(32) lzc                         | exp_nor sig_nor clz out_exp adder Table       |
 *  fp2fp
 */
// control path
Expand Down Expand Up @@ -55,14 +55,10 @@ class CVT32(width: Int = 32) extends CVT(width){
val fflags0 = WireInit(Cat(NV, DZ, OF, UF, NX))
val fflags0_reg1 = RegEnable(fflags0, fireReg)

val round_in = Wire(UInt(24.W))
val round_roundIn = Wire(Bool())
val round_stickyIn = Wire(Bool())
val round_signIn = Wire(Bool())
val round_in_reg0 = RegEnable(round_in, 0.U(24.W), fire)
val round_roundIn_reg0 = RegEnable(round_roundIn, false.B, fire)
val round_stickyIn_reg0 = RegEnable(round_stickyIn, false.B, fire)
val round_signIn_reg0 = RegEnable(round_signIn, false.B, fire)
val round_in_reg0 = Wire(UInt(24.W))
val round_roundIn_reg0 = Wire(Bool())
val round_stickyIn_reg0 = Wire(Bool())
val round_signIn_reg0 = Wire(Bool())
val rm_reg0 = RegEnable(io.rm, fire)

val is_normal = Wire(Bool())
Expand Down Expand Up @@ -333,10 +329,15 @@ class CVT32(width: Int = 32) extends CVT(width){
(VectorFloat.expBias(f16.expWidth) + 8 - 1).U)
) - exp

val rpath_sig_shifted0 = Wire(UInt(25.W))
val rpath_sig_shifted_reg0 = RegEnable(rpath_sig_shifted0, 0.U(25.W), fire)
val (rpath_sig_shifted, rpath_sticky) = ShiftRightJam(Cat(sig, 0.U), rpath_shamt)
rpath_sig_shifted0 := rpath_sig_shifted
// cycle1
val sig_cat0 = Cat(sig, 0.U)
val sig_cat0_reg = RegEnable(sig_cat0, fire)
val rpath_shamt_reg0 = RegEnable(rpath_shamt, fire)
val rpath_sig_shifted_reg0 = Wire(UInt(25.W))
val rpath_sticky_reg0 = Wire(Bool())
val (rpath_sig_shifted, rpath_sticky) = ShiftRightJam(sig_cat0_reg, rpath_shamt_reg0)
rpath_sig_shifted_reg0 := rpath_sig_shifted
rpath_sticky_reg0 := rpath_sticky

// int2fp
val in_abs = Mux1H(
Expand All @@ -345,6 +346,7 @@ class CVT32(width: Int = 32) extends CVT(width){
Seq((~in).asUInt + 1.U,
in)
)
val in_abs_reg0 = RegEnable(in_abs, fire)

val int2fp_clz = Mux1H(
Seq(out_is_fp32 || exp_of,
Expand All @@ -353,82 +355,82 @@ class CVT32(width: Int = 32) extends CVT(width){
Seq(CLZ(in_abs(31, 0)),
CLZ(in_abs(15, 0)))
)
val int2fp_clz_reg0 = RegEnable(int2fp_clz, fire)

val in_shift = Mux1H(
Seq(out_is_fp32 || exp_of,
int32tofp16 && !in_abs.tail(1).head(16).orR || int16tofp16 || int8tofp16
val in_shift_reg0 = Mux1H(
Seq(out_is_fp32_reg0 || exp_of_reg0,
int32tofp16_reg0 && !in_abs_reg0.tail(1).head(16).orR || int16tofp16_reg0 || int8tofp16_reg0
),
Seq((in_abs.tail(1) << int2fp_clz)(30, 0),
Cat((in_abs.tail(1) << int2fp_clz)(14, 0), 0.U(16.W)))
Seq((in_abs_reg0.tail(1) << int2fp_clz_reg0)(30, 0),
Cat((in_abs_reg0.tail(1) << int2fp_clz_reg0)(14, 0), 0.U(16.W)))
)

val exp_raw = Wire(UInt(8.W))
val exp_raw_reg0 = RegEnable(exp_raw, 0.U(8.W), fire)
exp_raw := Mux1H(
Seq(is_int2fp && out_is_fp32,
is_int2fp && out_is_fp16),
Seq(VectorFloat.expBias(f32.expWidth).asUInt +& 31.U - int2fp_clz,
VectorFloat.expBias(f16.expWidth).asUInt +& 15.U - int2fp_clz)
val exp_raw_reg0 = Wire(UInt(8.W))
exp_raw_reg0 := Mux1H(
Seq(is_int2fp_reg0 && out_is_fp32_reg0,
is_int2fp_reg0 && out_is_fp16_reg0),
Seq(VectorFloat.expBias(f32.expWidth).asUInt +& 31.U - int2fp_clz_reg0,
VectorFloat.expBias(f16.expWidth).asUInt +& 15.U - int2fp_clz_reg0)
)

// share RoundingUnit
round_in := Mux1H(
Seq(fp32toint32,
fp32toint16,
fp16toint16 || fp16toint32,
fp16toint8,
int16tofp32,
int32tofp32,
int8tofp16,
int16tofp16 || int32tofp16),
Seq(rpath_sig_shifted.head(f32.precision),
rpath_sig_shifted.head(16),
rpath_sig_shifted.head(f16.precision),
rpath_sig_shifted.head(8),
in_shift.head(16),
in_shift.head(23),
in_shift.head(8),
in_shift.head(10))
)
round_roundIn := Mux1H(
Seq(fp32toint32,
fp32toint16,
fp16toint16 || fp16toint32,
fp16toint8,
int16tofp32,
int32tofp32,
int8tofp16,
int16tofp16 || int32tofp16),
Seq(rpath_sig_shifted.tail(f32.precision).head(1),
rpath_sig_shifted.tail(16).head(1),
rpath_sig_shifted.tail(f16.precision).head(1),
rpath_sig_shifted.tail(8).head(1),
in_shift.tail(16).head(1),
in_shift.tail(23).head(1),
in_shift.tail(8).head(1),
in_shift.tail(10).head(1)
round_in_reg0 := Mux1H(
Seq(fp32toint32_reg0,
fp32toint16_reg0,
fp16toint16_reg0 || fp16toint32_reg0,
fp16toint8_reg0,
int16tofp32_reg0,
int32tofp32_reg0,
int8tofp16_reg0,
int16tofp16_reg0 || int32tofp16_reg0),
Seq(rpath_sig_shifted_reg0.head(f32.precision),
rpath_sig_shifted_reg0.head(16),
rpath_sig_shifted_reg0.head(f16.precision),
rpath_sig_shifted_reg0.head(8),
in_shift_reg0.head(16),
in_shift_reg0.head(23),
in_shift_reg0.head(8),
in_shift_reg0.head(10))
)
round_roundIn_reg0 := Mux1H(
Seq(fp32toint32_reg0,
fp32toint16_reg0,
fp16toint16_reg0 || fp16toint32_reg0,
fp16toint8_reg0,
int16tofp32_reg0,
int32tofp32_reg0,
int8tofp16_reg0,
int16tofp16_reg0 || int32tofp16_reg0),
Seq(rpath_sig_shifted_reg0.tail(f32.precision).head(1),
rpath_sig_shifted_reg0.tail(16).head(1),
rpath_sig_shifted_reg0.tail(f16.precision).head(1),
rpath_sig_shifted_reg0.tail(8).head(1),
in_shift_reg0.tail(16).head(1),
in_shift_reg0.tail(23).head(1),
in_shift_reg0.tail(8).head(1),
in_shift_reg0.tail(10).head(1)
)
)
round_stickyIn := Mux1H(
Seq(fp32toint32,
fp32toint16,
fp16toint16 || fp16toint32,
fp16toint8,
int16tofp32,
int32tofp32,
int8tofp16,
int16tofp16 || int32tofp16),
Seq(rpath_sticky,
rpath_sticky || rpath_sig_shifted.tail(17).orR,
rpath_sticky || rpath_sig_shifted.tail(12).orR,
rpath_sticky || rpath_sig_shifted.tail(9).orR,
in_shift.tail(16).orR,
in_shift.tail(f32.precision).orR,
in_shift.tail(8).orR,
in_shift.tail(f16.precision).orR
round_stickyIn_reg0 := Mux1H(
Seq(fp32toint32_reg0,
fp32toint16_reg0,
fp16toint16_reg0 || fp16toint32_reg0,
fp16toint8_reg0,
int16tofp32_reg0,
int32tofp32_reg0,
int8tofp16_reg0,
int16tofp16_reg0 || int32tofp16_reg0),
Seq(rpath_sticky_reg0,
rpath_sticky_reg0 || rpath_sig_shifted_reg0.tail(17).orR,
rpath_sticky_reg0 || rpath_sig_shifted_reg0.tail(12).orR,
rpath_sticky_reg0 || rpath_sig_shifted_reg0.tail(9).orR,
in_shift_reg0.tail(16).orR,
in_shift_reg0.tail(f32.precision).orR,
in_shift_reg0.tail(8).orR,
in_shift_reg0.tail(f16.precision).orR
)
)
round_signIn := (is_fp2int || is_int2fp) && sign
round_signIn_reg0 := (is_fp2int_reg0 || is_int2fp_reg0) && sign_reg0

val sel_lpath = Wire(Bool())
val sel_lpath_reg0 = RegEnable(sel_lpath, fire)
Expand Down Expand Up @@ -1166,6 +1168,4 @@ class CVT32(width: Int = 32) extends CVT(width){
io.result := result
io.fflags := fflags



}
}
Loading

0 comments on commit b2a8aec

Please sign in to comment.