Skip to content

Commit e37bd30

Browse files
authored
enum TxfmSize: Make a real enum (#1240)
* Part of #1180. This makes `TxfmSize` and `RectTxfmSize` into a real `enum`, which eliminates many bounds checks in hot `fn`s like `fn decode_ceofs`. It also lets us further simplify the `itx` macros and eliminate the distinction between the `R`ect ones. Note that `TxfmSize` defined the square sizes; these ones are named with an `S` prefix, while the `R`-prefixed `RectTxfmSize` stay `R`-prefixed. There are a few new places where we have to create `TxfmSize`s using `from_repr` and thus do a bounds check. These should be pretty cheap, but there are optimizable out, as they ultimately come from `static`/`const` data whos bounds are known but LLVM is not able to deduce them. This is why I added 7f0cdc2, starting to try to fix this, though it's a little complicated, so I didn't want to finish it in this PR.
2 parents b3451b8 + 47326c0 commit e37bd30

11 files changed

+618
-774
lines changed

src/cdf.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use crate::src::levels::BlockPartition;
1212
use crate::src::levels::BlockSize;
1313
use crate::src::levels::MVJoint;
1414
use crate::src::levels::SegmentId;
15+
use crate::src::levels::TxfmSize;
1516
use crate::src::levels::N_COMP_INTER_PRED_MODES;
1617
use crate::src::levels::N_INTRA_PRED_MODES;
17-
use crate::src::levels::N_TX_SIZES;
1818
use crate::src::levels::N_UV_INTRA_PRED_MODES;
1919
use crate::src::tables::dav1d_partition_type_count;
2020
use parking_lot::RwLock;
@@ -4999,7 +4999,7 @@ pub(crate) fn rav1d_cdf_thread_update(
49994999
);
50005000
}
50015001
update_cdf_2d!(8, 6, m.angle_delta);
5002-
for k in 0..N_TX_SIZES - 1 {
5002+
for k in 0..TxfmSize::NUM_SQUARE - 1 {
50035003
update_cdf_2d!(3, cmp::min(k + 1, 2), m.txsz[k]);
50045004
}
50055005
update_cdf_3d!(2, N_INTRA_PRED_MODES, 6, m.txtp_intra1);
@@ -5008,17 +5008,17 @@ pub(crate) fn rav1d_cdf_thread_update(
50085008
for k in 0..BlockLevel::COUNT {
50095009
update_cdf_2d!(4, dav1d_partition_type_count[k] as usize, m.partition[k]);
50105010
}
5011-
update_bit_2d!(N_TX_SIZES, 13, coef.skip);
5011+
update_bit_2d!(TxfmSize::NUM_SQUARE, 13, coef.skip);
50125012
update_cdf_3d!(2, 2, 4, coef.eob_bin_16);
50135013
update_cdf_3d!(2, 2, 5, coef.eob_bin_32);
50145014
update_cdf_3d!(2, 2, 6, coef.eob_bin_64);
50155015
update_cdf_3d!(2, 2, 7, coef.eob_bin_128);
50165016
update_cdf_3d!(2, 2, 8, coef.eob_bin_256);
50175017
update_cdf_2d!(2, 9, coef.eob_bin_512);
50185018
update_cdf_2d!(2, 10, coef.eob_bin_1024);
5019-
update_bit_3d!(N_TX_SIZES, 2, 11 /*22*/, coef.eob_hi_bit);
5020-
update_cdf_4d!(N_TX_SIZES, 2, 4, 2, coef.eob_base_tok);
5021-
update_cdf_4d!(N_TX_SIZES, 2, 41 /*42*/, 3, coef.base_tok);
5019+
update_bit_3d!(TxfmSize::NUM_SQUARE, 2, 11 /*22*/, coef.eob_hi_bit);
5020+
update_cdf_4d!(TxfmSize::NUM_SQUARE, 2, 4, 2, coef.eob_base_tok);
5021+
update_cdf_4d!(TxfmSize::NUM_SQUARE, 2, 41 /*42*/, 3, coef.base_tok);
50225022
update_bit_2d!(2, 3, coef.dc_sign);
50235023
update_cdf_4d!(4, 2, 21, 3, coef.br_tok);
50245024
update_cdf_2d!(3, SegmentId::COUNT - 1, m.seg_id);

src/decode.rs

+38-34
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ use crate::src::levels::InterIntraPredMode;
9595
use crate::src::levels::InterIntraType;
9696
use crate::src::levels::MVJoint;
9797
use crate::src::levels::MotionMode;
98-
use crate::src::levels::RectTxfmSize;
9998
use crate::src::levels::SegmentId;
10099
use crate::src::levels::TxfmSize;
101100
use crate::src::levels::CFL_PRED;
@@ -110,11 +109,7 @@ use crate::src::levels::NEWMV;
110109
use crate::src::levels::NEWMV_NEWMV;
111110
use crate::src::levels::N_COMP_INTER_PRED_MODES;
112111
use crate::src::levels::N_INTRA_PRED_MODES;
113-
use crate::src::levels::N_RECT_TX_SIZES;
114112
use crate::src::levels::N_UV_INTRA_PRED_MODES;
115-
use crate::src::levels::TX_4X4;
116-
use crate::src::levels::TX_64X64;
117-
use crate::src::levels::TX_8X8;
118113
use crate::src::levels::VERT_LEFT_PRED;
119114
use crate::src::levels::VERT_PRED;
120115
use crate::src::lf_mask::rav1d_calc_eih;
@@ -149,7 +144,6 @@ use crate::src::refmvs::RefMvsFrame;
149144
use crate::src::relaxed_atomic::RelaxedAtomic;
150145
use crate::src::tables::cfl_allowed_mask;
151146
use crate::src::tables::dav1d_al_part_ctx;
152-
use crate::src::tables::dav1d_block_dimensions;
153147
use crate::src::tables::dav1d_block_sizes;
154148
use crate::src::tables::dav1d_comp_inter_pred_modes;
155149
use crate::src::tables::dav1d_filter_2d;
@@ -285,7 +279,7 @@ fn read_tx_tree(
285279
t: &mut Rav1dTaskContext,
286280
f: &Rav1dFrameData,
287281
ts_c: &mut Rav1dTileStateContext,
288-
from: RectTxfmSize,
282+
from: TxfmSize,
289283
depth: c_int,
290284
masks: &mut [u16; 2],
291285
x_off: usize,
@@ -298,10 +292,10 @@ fn read_tx_tree(
298292
let txh = t_dim.lh;
299293
let is_split;
300294

301-
if depth < 2 && from > TX_4X4 {
302-
let cat = 2 * (TX_64X64 as c_int - t_dim.max as c_int) - depth;
303-
let a = (*f.a[t.a].tx.index(bx4 as usize) < txw) as c_int;
304-
let l = (*t.l.tx.index(by4 as usize) < txh) as c_int;
295+
if depth < 2 && from > TxfmSize::S4x4 {
296+
let cat = 2 * (TxfmSize::S64x64 as c_int - t_dim.max as c_int) - depth;
297+
let a = ((*f.a[t.a].tx.index(bx4 as usize) as u8) < txw) as c_int;
298+
let l = ((*t.l.tx.index(by4 as usize) as u8) < txh) as c_int;
305299

306300
is_split = rav1d_msac_decode_bool_adapt(
307301
&mut ts_c.msac,
@@ -313,9 +307,9 @@ fn read_tx_tree(
313307
} else {
314308
is_split = false;
315309
}
316-
if is_split && t_dim.max as TxfmSize > TX_8X8 {
317-
let sub = t_dim.sub as RectTxfmSize;
318-
let sub_t_dim = &dav1d_txfm_dimensions[usize::from(sub)]; // `from` used instead of `into` for rust-analyzer type inference
310+
if is_split && t_dim.max > TxfmSize::S8x8 as _ {
311+
let sub = t_dim.sub;
312+
let sub_t_dim = &dav1d_txfm_dimensions[sub as usize];
319313
let txsw = sub_t_dim.w as c_int;
320314
let txsh = sub_t_dim.h as c_int;
321315

@@ -377,7 +371,13 @@ fn read_tx_tree(
377371
[t_dim.h as usize, t_dim.w as usize],
378372
[by4 as usize, bx4 as usize],
379373
|case, (dir, val)| {
380-
case.set_disjoint(&dir.tx, if is_split { TX_4X4 } else { val });
374+
let tx = if is_split {
375+
TxfmSize::S4x4
376+
} else {
377+
// TODO check unwrap is optimized out
378+
TxfmSize::from_repr(val as _).unwrap()
379+
};
380+
case.set_disjoint(&dir.tx, tx);
381381
},
382382
);
383383
};
@@ -436,7 +436,7 @@ fn find_matching_ref(
436436
&& t.b.x + bw4 < ts.tiling.col_end
437437
&& intra_edge_flags.contains(EdgeFlags::I444_TOP_HAS_RIGHT);
438438

439-
let bs = |rp: refmvs_block| dav1d_block_dimensions[rp.bs as usize];
439+
let bs = |rp: refmvs_block| rp.bs.dimensions();
440440
let matches = |rp: refmvs_block| rp.r#ref.r#ref[0] == r#ref + 1 && rp.r#ref.r#ref[1] == -1;
441441

442442
if have_top {
@@ -541,7 +541,7 @@ fn derive_warpmv(
541541
*r.index(t.rt.r[(offset as isize + i as isize) as usize] + j as usize)
542542
};
543543

544-
let bs = |rp: refmvs_block| dav1d_block_dimensions[rp.bs as usize];
544+
let bs = |rp: refmvs_block| rp.bs.dimensions();
545545

546546
let mut add_sample = |np: usize, dx: i32, dy: i32, sx: i32, sy: i32, rp: refmvs_block| {
547547
pts[np][0][0] = 16 * (2 * dx + sx * bs(rp)[0] as i32) - 8;
@@ -789,7 +789,7 @@ fn read_vartx_tree(
789789
bx4: c_int,
790790
by4: c_int,
791791
) -> VarTx {
792-
let b_dim = &dav1d_block_dimensions[bs as usize];
792+
let b_dim = bs.dimensions();
793793
let bw4 = b_dim[0] as usize;
794794
let bh4 = b_dim[1] as usize;
795795

@@ -799,16 +799,17 @@ fn read_vartx_tree(
799799
let frame_hdr = &***f.frame_hdr.as_ref().unwrap();
800800
let txfm_mode = frame_hdr.txfm_mode;
801801
let uvtx;
802-
if b.skip == 0 && (frame_hdr.segmentation.lossless[b.seg_id.get()] || max_ytx == TX_4X4) {
803-
uvtx = TX_4X4;
802+
if b.skip == 0 && (frame_hdr.segmentation.lossless[b.seg_id.get()] || max_ytx == TxfmSize::S4x4)
803+
{
804+
uvtx = TxfmSize::S4x4;
804805
max_ytx = uvtx;
805806
if txfm_mode == Rav1dTxfmMode::Switchable {
806807
CaseSet::<32, false>::many(
807808
[&t.l, &f.a[t.a]],
808809
[bh4 as usize, bw4 as usize],
809810
[by4 as usize, bx4 as usize],
810811
|case, dir| {
811-
case.set_disjoint(&dir.tx, TX_4X4);
812+
case.set_disjoint(&dir.tx, TxfmSize::S4x4);
812813
},
813814
);
814815
}
@@ -819,13 +820,15 @@ fn read_vartx_tree(
819820
[bh4 as usize, bw4 as usize],
820821
[by4 as usize, bx4 as usize],
821822
|case, (dir, dir_index)| {
822-
case.set_disjoint(&dir.tx, b_dim[2 + dir_index]);
823+
// TODO check unwrap is optimized out
824+
let tx = TxfmSize::from_repr(b_dim[2 + dir_index] as _).unwrap();
825+
case.set_disjoint(&dir.tx, tx);
823826
},
824827
);
825828
}
826829
uvtx = dav1d_max_txfm_size_for_bs[bs as usize][f.cur.p.layout as usize];
827830
} else {
828-
assert!(bw4 <= 16 || bh4 <= 16 || max_ytx == TX_64X64);
831+
assert!(bw4 <= 16 || bh4 <= 16 || max_ytx == TxfmSize::S64x64);
829832
let ytx = &dav1d_txfm_dimensions[max_ytx as usize];
830833
let h = ytx.h as usize;
831834
let w = ytx.w as usize;
@@ -1092,7 +1095,7 @@ fn obmc_lowest_px(
10921095
let mut x = 0;
10931096
while x < w4 && i < cmp::min(b_dim[2] as c_int, 4) {
10941097
let a_r = *r.index(ri[0] + t.b.x as usize + x as usize + 1);
1095-
let a_b_dim = &dav1d_block_dimensions[a_r.bs as usize];
1098+
let a_b_dim = a_r.bs.dimensions();
10961099
if a_r.r#ref.r#ref[0] as c_int > 0 {
10971100
let oh4 = cmp::min(b_dim[1] as c_int, 16) >> 1;
10981101
mc_lowest_px(
@@ -1113,7 +1116,7 @@ fn obmc_lowest_px(
11131116
let mut y = 0;
11141117
while y < h4 && i < cmp::min(b_dim[3] as c_int, 4) {
11151118
let l_r = *r.index(ri[y as usize + 1 + 1] + t.b.x as usize - 1);
1116-
let l_b_dim = &dav1d_block_dimensions[l_r.bs as usize];
1119+
let l_b_dim = l_r.bs.dimensions();
11171120
if l_r.r#ref.r#ref[0] as c_int > 0 {
11181121
let oh4 = iclip(l_b_dim[1] as c_int, 2, b_dim[1] as c_int);
11191122
mc_lowest_px(
@@ -1164,7 +1167,7 @@ fn decode_b(
11641167

11651168
let ts = &f.ts[t.ts];
11661169
let bd_fn = f.bd_fn();
1167-
let b_dim = &dav1d_block_dimensions[bs as usize];
1170+
let b_dim = bs.dimensions();
11681171
let bx4 = t.b.x & 31;
11691172
let by4 = t.b.y & 31;
11701173
let ss_ver = (f.cur.p.layout == Rav1dPixelLayout::I420) as c_int;
@@ -1871,13 +1874,13 @@ fn decode_b(
18711874
let frame_hdr = f.frame_hdr();
18721875

18731876
let tx = if frame_hdr.segmentation.lossless[b.seg_id.get()] {
1874-
b.uvtx = TX_4X4;
1877+
b.uvtx = TxfmSize::S4x4;
18751878
b.uvtx
18761879
} else {
18771880
let mut tx = dav1d_max_txfm_size_for_bs[bs as usize][0];
18781881
b.uvtx = dav1d_max_txfm_size_for_bs[bs as usize][f.cur.p.layout as usize];
18791882
let mut t_dim = &dav1d_txfm_dimensions[tx as usize];
1880-
if frame_hdr.txfm_mode == Rav1dTxfmMode::Switchable && t_dim.max > TX_4X4 as u8 {
1883+
if frame_hdr.txfm_mode == Rav1dTxfmMode::Switchable && t_dim.max > TxfmSize::S4x4 as _ {
18811884
let tctx = get_tx_ctx(&f.a[t.a], &t.l, t_dim, by4, bx4);
18821885
let tx_cdf = &mut ts_c.cdf.m.txsz[(t_dim.max - 1) as usize][tctx as usize];
18831886
let depth =
@@ -1890,7 +1893,7 @@ fn decode_b(
18901893
}
18911894
}
18921895
if debug_block_info!(f, t.b) {
1893-
println!("Post-tx[{}]: r={}", tx, ts_c.msac.rng);
1896+
println!("Post-tx[{:?}]: r={}", tx, ts_c.msac.rng);
18941897
}
18951898
tx
18961899
};
@@ -1962,7 +1965,8 @@ fn decode_b(
19621965
[by4 as usize, bx4 as usize],
19631966
|case, (dir, lw_lh, dir_index)| {
19641967
case.set_disjoint(&dir.tx_intra, lw_lh as i8);
1965-
case.set_disjoint(&dir.tx, lw_lh);
1968+
// TODO check unwrap is optimized out
1969+
case.set_disjoint(&dir.tx, TxfmSize::from_repr(lw_lh as _).unwrap());
19661970
case.set_disjoint(&dir.mode, y_mode_nofilt);
19671971
case.set_disjoint(&dir.pal_sz, pal_sz[0]);
19681972
case.set_disjoint(&dir.seg_pred, seg_pred.into());
@@ -3064,8 +3068,8 @@ fn decode_b(
30643068
let mut ytx = max_ytx;
30653069
let mut uvtx = b.uvtx;
30663070
if frame_hdr.segmentation.lossless[b.seg_id.get()] {
3067-
ytx = TX_4X4;
3068-
uvtx = TX_4X4;
3071+
ytx = TxfmSize::S4x4;
3072+
uvtx = TxfmSize::S4x4;
30693073
}
30703074
let lflvl = match ts.lflvl.get() {
30713075
TileStateRef::Frame => &f.lf.lvl,
@@ -3790,7 +3794,7 @@ fn reset_context(ctx: &mut BlockContext, keyframe: bool, pass: c_int) {
37903794
ctx.tx_lpf_y.get_mut().0.fill(2);
37913795
ctx.tx_lpf_uv.get_mut().0.fill(1);
37923796
ctx.tx_intra.get_mut().0.fill(-1);
3793-
ctx.tx.get_mut().0.fill(TX_64X64);
3797+
ctx.tx.get_mut().0.fill(TxfmSize::S64x64);
37943798
if !keyframe {
37953799
for r#ref in &mut ctx.r#ref {
37963800
r#ref.get_mut().0.fill(-1);
@@ -4531,7 +4535,7 @@ pub(crate) fn rav1d_decode_frame_init(c: &Rav1dContext, fc: &Rav1dFrameContext)
45314535
// setup dequant tables
45324536
init_quant_tables(&seq_hdr, &frame_hdr, frame_hdr.quant.yac, &f.dq);
45334537
if frame_hdr.quant.qm != 0 {
4534-
for i in 0..N_RECT_TX_SIZES {
4538+
for i in 0..TxfmSize::COUNT {
45354539
f.qm[i][0] = dav1d_qm_tbl[frame_hdr.quant.qm_y as usize][0][i];
45364540
f.qm[i][1] = dav1d_qm_tbl[frame_hdr.quant.qm_u as usize][1][i];
45374541
f.qm[i][2] = dav1d_qm_tbl[frame_hdr.quant.qm_v as usize][1][i];

src/env.rs

+6-7
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ use crate::src::levels::DCT_DCT;
1818
use crate::src::levels::H_ADST;
1919
use crate::src::levels::H_FLIPADST;
2020
use crate::src::levels::IDTX;
21-
use crate::src::levels::TX_16X16;
22-
use crate::src::levels::TX_32X32;
2321
use crate::src::levels::V_ADST;
2422
use crate::src::levels::V_FLIPADST;
2523
use crate::src::refmvs::refmvs_candidate;
@@ -46,7 +44,7 @@ pub struct BlockContext {
4644
pub filter: [DisjointMut<Align8<[Rav1dFilterMode; 32]>>; 2],
4745

4846
pub tx_intra: DisjointMut<Align8<[i8; 32]>>,
49-
pub tx: DisjointMut<Align8<[u8; 32]>>,
47+
pub tx: DisjointMut<Align8<[TxfmSize; 32]>>,
5048
pub tx_lpf_y: DisjointMut<Align8<[u8; 32]>>,
5149
pub tx_lpf_uv: DisjointMut<Align8<[u8; 32]>>,
5250
pub partition: DisjointMut<Align8<[u8; 16]>>,
@@ -141,17 +139,18 @@ pub fn gather_top_partition_prob(r#in: &[u16; 16], bl: BlockLevel) -> u32 {
141139

142140
#[inline]
143141
pub fn get_uv_inter_txtp(uvt_dim: &TxfmInfo, ytxtp: TxfmType) -> TxfmType {
144-
if (*uvt_dim).max as TxfmSize == TX_32X32 {
142+
if uvt_dim.max == TxfmSize::S32x32 as _ {
145143
return if ytxtp == IDTX { IDTX } else { DCT_DCT };
146144
}
147-
if (*uvt_dim).min as TxfmSize == TX_16X16
148-
&& ((1 << ytxtp) & ((1 << H_FLIPADST) | (1 << V_FLIPADST) | (1 << H_ADST) | (1 << V_ADST)))
145+
if uvt_dim.min == TxfmSize::S16x16 as _
146+
&& ((1 << ytxtp as u8)
147+
& ((1 << H_FLIPADST) | (1 << V_FLIPADST) | (1 << H_ADST) | (1 << V_ADST)))
149148
!= 0
150149
{
151150
return DCT_DCT;
152151
}
153152

154-
return ytxtp;
153+
ytxtp
155154
}
156155

157156
#[inline]

0 commit comments

Comments
 (0)