Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enum TxfmSize: Make a real enum #1240

Merged
merged 2 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/cdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ use crate::src::levels::BlockPartition;
use crate::src::levels::BlockSize;
use crate::src::levels::MVJoint;
use crate::src::levels::SegmentId;
use crate::src::levels::TxfmSize;
use crate::src::levels::N_COMP_INTER_PRED_MODES;
use crate::src::levels::N_INTRA_PRED_MODES;
use crate::src::levels::N_TX_SIZES;
use crate::src::levels::N_UV_INTRA_PRED_MODES;
use crate::src::tables::dav1d_partition_type_count;
use parking_lot::RwLock;
Expand Down Expand Up @@ -4999,7 +4999,7 @@ pub(crate) fn rav1d_cdf_thread_update(
);
}
update_cdf_2d!(8, 6, m.angle_delta);
for k in 0..N_TX_SIZES - 1 {
for k in 0..TxfmSize::NUM_SQUARE - 1 {
update_cdf_2d!(3, cmp::min(k + 1, 2), m.txsz[k]);
}
update_cdf_3d!(2, N_INTRA_PRED_MODES, 6, m.txtp_intra1);
Expand All @@ -5008,17 +5008,17 @@ pub(crate) fn rav1d_cdf_thread_update(
for k in 0..BlockLevel::COUNT {
update_cdf_2d!(4, dav1d_partition_type_count[k] as usize, m.partition[k]);
}
update_bit_2d!(N_TX_SIZES, 13, coef.skip);
update_bit_2d!(TxfmSize::NUM_SQUARE, 13, coef.skip);
update_cdf_3d!(2, 2, 4, coef.eob_bin_16);
update_cdf_3d!(2, 2, 5, coef.eob_bin_32);
update_cdf_3d!(2, 2, 6, coef.eob_bin_64);
update_cdf_3d!(2, 2, 7, coef.eob_bin_128);
update_cdf_3d!(2, 2, 8, coef.eob_bin_256);
update_cdf_2d!(2, 9, coef.eob_bin_512);
update_cdf_2d!(2, 10, coef.eob_bin_1024);
update_bit_3d!(N_TX_SIZES, 2, 11 /*22*/, coef.eob_hi_bit);
update_cdf_4d!(N_TX_SIZES, 2, 4, 2, coef.eob_base_tok);
update_cdf_4d!(N_TX_SIZES, 2, 41 /*42*/, 3, coef.base_tok);
update_bit_3d!(TxfmSize::NUM_SQUARE, 2, 11 /*22*/, coef.eob_hi_bit);
update_cdf_4d!(TxfmSize::NUM_SQUARE, 2, 4, 2, coef.eob_base_tok);
update_cdf_4d!(TxfmSize::NUM_SQUARE, 2, 41 /*42*/, 3, coef.base_tok);
update_bit_2d!(2, 3, coef.dc_sign);
update_cdf_4d!(4, 2, 21, 3, coef.br_tok);
update_cdf_2d!(3, SegmentId::COUNT - 1, m.seg_id);
Expand Down
59 changes: 32 additions & 27 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ use crate::src::levels::InterIntraPredMode;
use crate::src::levels::InterIntraType;
use crate::src::levels::MVJoint;
use crate::src::levels::MotionMode;
use crate::src::levels::RectTxfmSize;
use crate::src::levels::SegmentId;
use crate::src::levels::TxfmSize;
use crate::src::levels::CFL_PRED;
Expand All @@ -110,11 +109,7 @@ use crate::src::levels::NEWMV;
use crate::src::levels::NEWMV_NEWMV;
use crate::src::levels::N_COMP_INTER_PRED_MODES;
use crate::src::levels::N_INTRA_PRED_MODES;
use crate::src::levels::N_RECT_TX_SIZES;
use crate::src::levels::N_UV_INTRA_PRED_MODES;
use crate::src::levels::TX_4X4;
use crate::src::levels::TX_64X64;
use crate::src::levels::TX_8X8;
use crate::src::levels::VERT_LEFT_PRED;
use crate::src::levels::VERT_PRED;
use crate::src::lf_mask::rav1d_calc_eih;
Expand Down Expand Up @@ -285,7 +280,7 @@ fn read_tx_tree(
t: &mut Rav1dTaskContext,
f: &Rav1dFrameData,
ts_c: &mut Rav1dTileStateContext,
from: RectTxfmSize,
from: TxfmSize,
depth: c_int,
masks: &mut [u16; 2],
x_off: usize,
Expand All @@ -298,10 +293,10 @@ fn read_tx_tree(
let txh = t_dim.lh;
let is_split;

if depth < 2 && from > TX_4X4 {
let cat = 2 * (TX_64X64 as c_int - t_dim.max as c_int) - depth;
let a = (*f.a[t.a].tx.index(bx4 as usize) < txw) as c_int;
let l = (*t.l.tx.index(by4 as usize) < txh) as c_int;
if depth < 2 && from > TxfmSize::S4x4 {
let cat = 2 * (TxfmSize::S64x64 as c_int - t_dim.max as c_int) - depth;
let a = ((*f.a[t.a].tx.index(bx4 as usize) as u8) < txw) as c_int;
let l = ((*t.l.tx.index(by4 as usize) as u8) < txh) as c_int;

is_split = rav1d_msac_decode_bool_adapt(
&mut ts_c.msac,
Expand All @@ -313,9 +308,9 @@ fn read_tx_tree(
} else {
is_split = false;
}
if is_split && t_dim.max as TxfmSize > TX_8X8 {
let sub = t_dim.sub as RectTxfmSize;
let sub_t_dim = &dav1d_txfm_dimensions[usize::from(sub)]; // `from` used instead of `into` for rust-analyzer type inference
if is_split && t_dim.max > TxfmSize::S8x8 as _ {
let sub = t_dim.sub;
let sub_t_dim = &dav1d_txfm_dimensions[sub as usize];
let txsw = sub_t_dim.w as c_int;
let txsh = sub_t_dim.h as c_int;

Expand Down Expand Up @@ -377,7 +372,13 @@ fn read_tx_tree(
[t_dim.h as usize, t_dim.w as usize],
[by4 as usize, bx4 as usize],
|case, (dir, val)| {
case.set_disjoint(&dir.tx, if is_split { TX_4X4 } else { val });
let tx = if is_split {
TxfmSize::S4x4
} else {
// TODO check unwrap is optimized out
TxfmSize::from_repr(val as _).unwrap()
};
case.set_disjoint(&dir.tx, tx);
},
);
};
Expand Down Expand Up @@ -799,16 +800,17 @@ fn read_vartx_tree(
let frame_hdr = &***f.frame_hdr.as_ref().unwrap();
let txfm_mode = frame_hdr.txfm_mode;
let uvtx;
if b.skip == 0 && (frame_hdr.segmentation.lossless[b.seg_id.get()] || max_ytx == TX_4X4) {
uvtx = TX_4X4;
if b.skip == 0 && (frame_hdr.segmentation.lossless[b.seg_id.get()] || max_ytx == TxfmSize::S4x4)
{
uvtx = TxfmSize::S4x4;
max_ytx = uvtx;
if txfm_mode == Rav1dTxfmMode::Switchable {
CaseSet::<32, false>::many(
[&t.l, &f.a[t.a]],
[bh4 as usize, bw4 as usize],
[by4 as usize, bx4 as usize],
|case, dir| {
case.set_disjoint(&dir.tx, TX_4X4);
case.set_disjoint(&dir.tx, TxfmSize::S4x4);
},
);
}
Expand All @@ -819,13 +821,15 @@ fn read_vartx_tree(
[bh4 as usize, bw4 as usize],
[by4 as usize, bx4 as usize],
|case, (dir, dir_index)| {
case.set_disjoint(&dir.tx, b_dim[2 + dir_index]);
// TODO check unwrap is optimized out
let tx = TxfmSize::from_repr(b_dim[2 + dir_index] as _).unwrap();
case.set_disjoint(&dir.tx, tx);
},
);
}
uvtx = dav1d_max_txfm_size_for_bs[bs as usize][f.cur.p.layout as usize];
} else {
assert!(bw4 <= 16 || bh4 <= 16 || max_ytx == TX_64X64);
assert!(bw4 <= 16 || bh4 <= 16 || max_ytx == TxfmSize::S64x64);
let ytx = &dav1d_txfm_dimensions[max_ytx as usize];
let h = ytx.h as usize;
let w = ytx.w as usize;
Expand Down Expand Up @@ -1871,13 +1875,13 @@ fn decode_b(
let frame_hdr = f.frame_hdr();

let tx = if frame_hdr.segmentation.lossless[b.seg_id.get()] {
b.uvtx = TX_4X4;
b.uvtx = TxfmSize::S4x4;
b.uvtx
} else {
let mut tx = dav1d_max_txfm_size_for_bs[bs as usize][0];
b.uvtx = dav1d_max_txfm_size_for_bs[bs as usize][f.cur.p.layout as usize];
let mut t_dim = &dav1d_txfm_dimensions[tx as usize];
if frame_hdr.txfm_mode == Rav1dTxfmMode::Switchable && t_dim.max > TX_4X4 as u8 {
if frame_hdr.txfm_mode == Rav1dTxfmMode::Switchable && t_dim.max > TxfmSize::S4x4 as _ {
let tctx = get_tx_ctx(&f.a[t.a], &t.l, t_dim, by4, bx4);
let tx_cdf = &mut ts_c.cdf.m.txsz[(t_dim.max - 1) as usize][tctx as usize];
let depth =
Expand All @@ -1890,7 +1894,7 @@ fn decode_b(
}
}
if debug_block_info!(f, t.b) {
println!("Post-tx[{}]: r={}", tx, ts_c.msac.rng);
println!("Post-tx[{:?}]: r={}", tx, ts_c.msac.rng);
}
tx
};
Expand Down Expand Up @@ -1962,7 +1966,8 @@ fn decode_b(
[by4 as usize, bx4 as usize],
|case, (dir, lw_lh, dir_index)| {
case.set_disjoint(&dir.tx_intra, lw_lh as i8);
case.set_disjoint(&dir.tx, lw_lh);
// TODO check unwrap is optimized out
case.set_disjoint(&dir.tx, TxfmSize::from_repr(lw_lh as _).unwrap());
case.set_disjoint(&dir.mode, y_mode_nofilt);
case.set_disjoint(&dir.pal_sz, pal_sz[0]);
case.set_disjoint(&dir.seg_pred, seg_pred.into());
Expand Down Expand Up @@ -3064,8 +3069,8 @@ fn decode_b(
let mut ytx = max_ytx;
let mut uvtx = b.uvtx;
if frame_hdr.segmentation.lossless[b.seg_id.get()] {
ytx = TX_4X4;
uvtx = TX_4X4;
ytx = TxfmSize::S4x4;
uvtx = TxfmSize::S4x4;
}
let lflvl = match ts.lflvl.get() {
TileStateRef::Frame => &f.lf.lvl,
Expand Down Expand Up @@ -3790,7 +3795,7 @@ fn reset_context(ctx: &mut BlockContext, keyframe: bool, pass: c_int) {
ctx.tx_lpf_y.get_mut().0.fill(2);
ctx.tx_lpf_uv.get_mut().0.fill(1);
ctx.tx_intra.get_mut().0.fill(-1);
ctx.tx.get_mut().0.fill(TX_64X64);
ctx.tx.get_mut().0.fill(TxfmSize::S64x64);
if !keyframe {
for r#ref in &mut ctx.r#ref {
r#ref.get_mut().0.fill(-1);
Expand Down Expand Up @@ -4531,7 +4536,7 @@ pub(crate) fn rav1d_decode_frame_init(c: &Rav1dContext, fc: &Rav1dFrameContext)
// setup dequant tables
init_quant_tables(&seq_hdr, &frame_hdr, frame_hdr.quant.yac, &f.dq);
if frame_hdr.quant.qm != 0 {
for i in 0..N_RECT_TX_SIZES {
for i in 0..TxfmSize::COUNT {
f.qm[i][0] = dav1d_qm_tbl[frame_hdr.quant.qm_y as usize][0][i];
f.qm[i][1] = dav1d_qm_tbl[frame_hdr.quant.qm_u as usize][1][i];
f.qm[i][2] = dav1d_qm_tbl[frame_hdr.quant.qm_v as usize][1][i];
Expand Down
13 changes: 6 additions & 7 deletions src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ use crate::src::levels::DCT_DCT;
use crate::src::levels::H_ADST;
use crate::src::levels::H_FLIPADST;
use crate::src::levels::IDTX;
use crate::src::levels::TX_16X16;
use crate::src::levels::TX_32X32;
use crate::src::levels::V_ADST;
use crate::src::levels::V_FLIPADST;
use crate::src::refmvs::refmvs_candidate;
Expand All @@ -46,7 +44,7 @@ pub struct BlockContext {
pub filter: [DisjointMut<Align8<[Rav1dFilterMode; 32]>>; 2],

pub tx_intra: DisjointMut<Align8<[i8; 32]>>,
pub tx: DisjointMut<Align8<[u8; 32]>>,
pub tx: DisjointMut<Align8<[TxfmSize; 32]>>,
pub tx_lpf_y: DisjointMut<Align8<[u8; 32]>>,
pub tx_lpf_uv: DisjointMut<Align8<[u8; 32]>>,
pub partition: DisjointMut<Align8<[u8; 16]>>,
Expand Down Expand Up @@ -141,17 +139,18 @@ pub fn gather_top_partition_prob(r#in: &[u16; 16], bl: BlockLevel) -> u32 {

#[inline]
pub fn get_uv_inter_txtp(uvt_dim: &TxfmInfo, ytxtp: TxfmType) -> TxfmType {
if (*uvt_dim).max as TxfmSize == TX_32X32 {
if uvt_dim.max == TxfmSize::S32x32 as _ {
return if ytxtp == IDTX { IDTX } else { DCT_DCT };
}
if (*uvt_dim).min as TxfmSize == TX_16X16
&& ((1 << ytxtp) & ((1 << H_FLIPADST) | (1 << V_FLIPADST) | (1 << H_ADST) | (1 << V_ADST)))
if uvt_dim.min == TxfmSize::S16x16 as _
&& ((1 << ytxtp as u8)
& ((1 << H_FLIPADST) | (1 << V_FLIPADST) | (1 << H_ADST) | (1 << V_ADST)))
!= 0
{
return DCT_DCT;
}

return ytxtp;
ytxtp
}

#[inline]
Expand Down
Loading