Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enum InterPredMode and enum CompInterPredMode: make real enums #927

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/cdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use crate::src::internal::Rav1dContext;
use crate::src::levels::BlockLevel;
use crate::src::levels::BlockPartition;
use crate::src::levels::BlockSize;
use crate::src::levels::CompInterPredMode;
use crate::src::levels::MVJoint;
use crate::src::levels::N_COMP_INTER_PRED_MODES;
use crate::src::levels::N_INTRA_PRED_MODES;
use crate::src::levels::N_TX_SIZES;
use crate::src::levels::N_UV_INTRA_PRED_MODES;
Expand Down Expand Up @@ -89,7 +89,7 @@ pub struct CdfModeContext {
pub cfl_sign: Align16<[u16; 8]>,
pub angle_delta: Align16<[[u16; 8]; 8]>,
pub filter_intra: Align16<[u16; 8]>,
pub comp_inter_mode: Align16<[[u16; N_COMP_INTER_PRED_MODES]; 8]>,
pub comp_inter_mode: Align16<[[u16; CompInterPredMode::COUNT]; 8]>,
pub seg_id: Align16<[[u16; RAV1D_MAX_SEGMENTS as usize]; 3]>,
pub pal_sz: Align16<[[[u16; 8]; 7]; 2]>,
pub color_map: Align16<[[[[u16; 8]; 5]; 7]; 2]>,
Expand Down Expand Up @@ -5048,7 +5048,7 @@ pub(crate) fn rav1d_cdf_thread_update(
update_bit_1d!(2, m.globalmv_mode);
update_bit_1d!(6, m.refmv_mode);
update_bit_1d!(3, m.drl_bit);
update_cdf_2d!(8, N_COMP_INTER_PRED_MODES - 1, m.comp_inter_mode);
update_cdf_2d!(8, CompInterPredMode::COUNT - 1, m.comp_inter_mode);
update_bit_1d!(4, m.intra);
update_bit_1d!(5, m.comp);
update_bit_1d!(5, m.comp_dir);
Expand Down
82 changes: 44 additions & 38 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,20 @@ use crate::src::levels::Av1Block;
use crate::src::levels::BlockLevel;
use crate::src::levels::BlockPartition;
use crate::src::levels::BlockSize;
use crate::src::levels::CompInterPredMode;
use crate::src::levels::CompInterType;
use crate::src::levels::DrlProximity;
use crate::src::levels::Filter2d;
use crate::src::levels::InterIntraPredMode;
use crate::src::levels::InterIntraType;
use crate::src::levels::InterPredMode;
use crate::src::levels::MVJoint;
use crate::src::levels::MotionMode;
use crate::src::levels::RectTxfmSize;
use crate::src::levels::TxfmSize;
use crate::src::levels::CFL_PRED;
use crate::src::levels::DC_PRED;
use crate::src::levels::FILTER_PRED;
use crate::src::levels::GLOBALMV;
use crate::src::levels::GLOBALMV_GLOBALMV;
use crate::src::levels::NEARESTMV;
use crate::src::levels::NEARESTMV_NEARESTMV;
use crate::src::levels::NEARMV;
use crate::src::levels::NEWMV;
use crate::src::levels::NEWMV_NEWMV;
use crate::src::levels::N_COMP_INTER_PRED_MODES;
use crate::src::levels::N_INTRA_PRED_MODES;
use crate::src::levels::N_RECT_TX_SIZES;
use crate::src::levels::N_UV_INTRA_PRED_MODES;
Expand Down Expand Up @@ -910,7 +904,8 @@ unsafe fn splat_oneref_mv(
],
},
bs,
mf: (mode == GLOBALMV && cmp::min(bw4, bh4) >= 2) as u8 | (mode == NEWMV) as u8 * 2,
mf: (mode == InterPredMode::Global.into() && cmp::min(bw4, bh4) >= 2) as u8
| (mode == InterPredMode::New.into()) as u8 * 2,
}));
c.refmvs_dsp.splat_mv(
&mut t.rt.r[((t.b.y & 31) + 5) as usize..],
Expand Down Expand Up @@ -964,7 +959,8 @@ unsafe fn splat_tworef_mv(
r#ref: [b.r#ref()[0] + 1, b.r#ref()[1] + 1],
},
bs,
mf: (mode == GLOBALMV_GLOBALMV) as u8 | (1 << mode & 0xbc != 0) as u8 * 2,
mf: (mode == CompInterPredMode::GlobalGlobal) as u8
| (1 << mode as u8 & 0xbc != 0) as u8 * 2,
}));
c.refmvs_dsp.splat_mv(
&mut t.rt.r[((t.b.y & 31) + 5) as usize..],
Expand Down Expand Up @@ -2164,7 +2160,7 @@ unsafe fn decode_b_inner(
frame_hdr.skip_mode.refs[1] as i8,
];
*b.comp_type_mut() = Some(CompInterType::Avg);
*b.inter_mode_mut() = NEARESTMV_NEARESTMV;
*b.inter_mode_mut() = CompInterPredMode::NearestNearest;
*b.drl_idx_mut() = DrlProximity::Nearest;
has_subpel_filter = false;

Expand Down Expand Up @@ -2290,14 +2286,15 @@ unsafe fn decode_b_inner(
frame_hdr,
);

*b.inter_mode_mut() = rav1d_msac_decode_symbol_adapt8(
*b.inter_mode_mut() = CompInterPredMode::from_repr(rav1d_msac_decode_symbol_adapt8(
&mut ts.msac,
&mut ts.cdf.m.comp_inter_mode[ctx as usize],
N_COMP_INTER_PRED_MODES as usize - 1,
) as u8;
CompInterPredMode::COUNT as usize - 1,
) as usize)
.expect("valid variant");
if debug_block_info!(f, t.b) {
println!(
"Post-compintermode[{},ctx={},n_mvs={}]: r={}",
"Post-compintermode[{:?},ctx={},n_mvs={}]: r={}",
b.inter_mode(),
ctx,
n_mvs,
Expand All @@ -2307,7 +2304,7 @@ unsafe fn decode_b_inner(

let im = &dav1d_comp_inter_pred_modes[b.inter_mode() as usize];
*b.drl_idx_mut() = DrlProximity::Nearest;
if b.inter_mode() == NEWMV_NEWMV {
if b.inter_mode() == CompInterPredMode::NewNew {
if n_mvs > 1 {
// `Nearer` or `Near`
let drl_ctx_v1 = get_drl_context(&mvstack, 0);
Expand Down Expand Up @@ -2336,7 +2333,7 @@ unsafe fn decode_b_inner(
);
}
}
} else if im[0] == NEARMV || im[1] == NEARMV {
} else if im[0] == InterPredMode::Near || im[1] == InterPredMode::Near {
*b.drl_idx_mut() = DrlProximity::Nearer;
if n_mvs > 2 {
// `Near` or `Nearish`
Expand Down Expand Up @@ -2368,13 +2365,14 @@ unsafe fn decode_b_inner(
}
}

has_subpel_filter = cmp::min(bw4, bh4) == 1 || b.inter_mode() != GLOBALMV_GLOBALMV;
has_subpel_filter =
cmp::min(bw4, bh4) == 1 || b.inter_mode() != CompInterPredMode::GlobalGlobal;
let mut assign_comp_mv = |idx: usize| match im[idx] {
NEARMV | NEARESTMV => {
InterPredMode::Near | InterPredMode::Nearest => {
b.mv_mut()[idx] = mvstack[b.drl_idx() as usize].mv.mv[idx];
fix_mv_precision(frame_hdr, &mut b.mv_mut()[idx]);
}
GLOBALMV => {
InterPredMode::Global => {
has_subpel_filter |= frame_hdr.gmv[b.r#ref()[idx] as usize].r#type
== Rav1dWarpedMotionType::Translation;
b.mv_mut()[idx] = get_gmv_2d(
Expand All @@ -2386,7 +2384,7 @@ unsafe fn decode_b_inner(
frame_hdr,
);
}
NEWMV => {
InterPredMode::New => {
b.mv_mut()[idx] = mvstack[b.drl_idx() as usize].mv.mv[idx];
read_mv_residual(
t,
Expand All @@ -2396,7 +2394,6 @@ unsafe fn decode_b_inner(
!frame_hdr.force_integer_mv,
);
}
_ => {}
};
assign_comp_mv(0);
assign_comp_mv(1);
Expand Down Expand Up @@ -2593,7 +2590,7 @@ unsafe fn decode_b_inner(
&mut ts.cdf.m.globalmv_mode[(ctx >> 3 & 1) as usize],
)
{
*b.inter_mode_mut() = GLOBALMV;
*b.inter_mode_mut() = InterPredMode::Global.into();
b.mv_mut()[0] = get_gmv_2d(
&frame_hdr.gmv[b.r#ref()[0] as usize],
t.b.x,
Expand All @@ -2612,7 +2609,7 @@ unsafe fn decode_b_inner(
&mut ts.cdf.m.refmv_mode[(ctx >> 4 & 15) as usize],
) {
// `Nearer`, `Near` or `Nearish`
*b.inter_mode_mut() = NEARMV;
*b.inter_mode_mut() = InterPredMode::Near.into();
*b.drl_idx_mut() = DrlProximity::Nearer;
if n_mvs > 2 {
// `Nearer`, `Near` or `Nearish`
Expand All @@ -2636,7 +2633,7 @@ unsafe fn decode_b_inner(
}
}
} else {
*b.inter_mode_mut() = NEARESTMV as u8;
*b.inter_mode_mut() = InterPredMode::Nearest.into();
*b.drl_idx_mut() = DrlProximity::Nearest;
}
b.mv_mut()[0] = mvstack[b.drl_idx() as usize].mv.mv[0];
Expand All @@ -2647,7 +2644,7 @@ unsafe fn decode_b_inner(

if debug_block_info!(f, t.b) {
println!(
"Post-intermode[{},drl={:?},mv=y:{},x:{},n_mvs={}]: r={}",
"Post-intermode[{:?},drl={:?},mv=y:{},x:{},n_mvs={}]: r={}",
b.inter_mode(),
b.drl_idx(),
b.mv()[0].y,
Expand All @@ -2658,7 +2655,7 @@ unsafe fn decode_b_inner(
}
} else {
has_subpel_filter = true;
*b.inter_mode_mut() = NEWMV;
*b.inter_mode_mut() = InterPredMode::New.into();
*b.drl_idx_mut() = DrlProximity::Nearest;
if n_mvs > 1 {
// `Nearer`, `Near` or `Nearish`
Expand Down Expand Up @@ -2690,7 +2687,7 @@ unsafe fn decode_b_inner(
}
if debug_block_info!(f, t.b) {
println!(
"Post-intermode[{},drl={:?}]: r={}",
"Post-intermode[{:?},drl={:?}]: r={}",
b.inter_mode(),
b.drl_idx(),
ts.msac.rng,
Expand Down Expand Up @@ -2768,7 +2765,7 @@ unsafe fn decode_b_inner(
&& cmp::min(bw4, bh4) >= 2
// is not warped global motion
&& !(!frame_hdr.force_integer_mv
&& b.inter_mode() == GLOBALMV
&& b.inter_mode() == InterPredMode::Global.into()
&& frame_hdr.gmv[b.r#ref()[0] as usize].r#type > Rav1dWarpedMotionType::Translation)
// has overlappable neighbours
&& (have_left && findoddzero(&t.l.intra.0[by4 as usize..][..h4 as usize])
Expand Down Expand Up @@ -2911,8 +2908,12 @@ unsafe fn decode_b_inner(

let frame_hdr = f.frame_hdr();
if frame_hdr.loopfilter.level_y != [0, 0] {
let is_globalmv =
(b.inter_mode() == if is_comp { GLOBALMV_GLOBALMV } else { GLOBALMV }) as c_int;
let is_globalmv = (b.inter_mode()
== if is_comp {
CompInterPredMode::GlobalGlobal
} else {
InterPredMode::Global.into()
}) as c_int;
let tx_split = [b.tx_split0() as u16, b.tx_split1()];
let mut ytx = b.max_ytx() as RectTxfmSize;
let mut uvtx = b.uvtx as RectTxfmSize;
Expand Down Expand Up @@ -2981,7 +2982,7 @@ unsafe fn decode_b_inner(
case.set(&mut dir.comp_type.0, b.comp_type());
case.set(&mut dir.filter.0[0], filter[0]);
case.set(&mut dir.filter.0[1], filter[1]);
case.set(&mut dir.mode.0, b.inter_mode());
case.set(&mut dir.mode.0, b.inter_mode() as u8);
case.set(&mut dir.r#ref.0[0], b.r#ref()[0]);
case.set(&mut dir.r#ref.0[1], b.r#ref()[1]);
},
Expand Down Expand Up @@ -3041,7 +3042,8 @@ unsafe fn decode_b_inner(
if b.comp_type().is_none() {
// y
if cmp::min(bw4, bh4) > 1
&& (b.inter_mode() == GLOBALMV && f.gmv_warp_allowed[b.r#ref()[0] as usize] != 0
&& (b.inter_mode() == InterPredMode::Global.into()
&& f.gmv_warp_allowed[b.r#ref()[0] as usize] != 0
|| b.motion_mode() == MotionMode::Warp
&& t.warpmv.r#type > Rav1dWarpedMotionType::Translation)
{
Expand Down Expand Up @@ -3154,7 +3156,7 @@ unsafe fn decode_b_inner(
&f.svc[b.r#ref()[0] as usize][1],
);
} else if cmp::min(cbw4, cbh4) > 1
&& (b.inter_mode() == GLOBALMV
&& (b.inter_mode() == InterPredMode::Global.into()
&& f.gmv_warp_allowed[b.r#ref()[0] as usize] != 0
|| b.motion_mode() == MotionMode::Warp
&& t.warpmv.r#type > Rav1dWarpedMotionType::Translation)
Expand Down Expand Up @@ -3201,7 +3203,9 @@ unsafe fn decode_b_inner(
let refmvs =
|| std::iter::zip(b.r#ref(), b.mv()).map(|(r#ref, mv)| (r#ref as usize, mv));
for (r#ref, mv) in refmvs() {
if b.inter_mode() == GLOBALMV_GLOBALMV && f.gmv_warp_allowed[r#ref] != 0 {
if b.inter_mode() == CompInterPredMode::GlobalGlobal
&& f.gmv_warp_allowed[r#ref] != 0
{
affine_lowest_px_luma(
t,
&mut lowest_px[r#ref][0],
Expand All @@ -3220,7 +3224,9 @@ unsafe fn decode_b_inner(
}
}
for (r#ref, mv) in refmvs() {
if b.inter_mode() == GLOBALMV_GLOBALMV && f.gmv_warp_allowed[r#ref] != 0 {
if b.inter_mode() == CompInterPredMode::GlobalGlobal
&& f.gmv_warp_allowed[r#ref] != 0
{
affine_lowest_px_luma(
t,
&mut lowest_px[r#ref][0],
Expand All @@ -3242,7 +3248,7 @@ unsafe fn decode_b_inner(
// uv
if has_chroma {
for (r#ref, mv) in refmvs() {
if b.inter_mode() == GLOBALMV_GLOBALMV
if b.inter_mode() == CompInterPredMode::GlobalGlobal
&& cmp::min(cbw4, cbh4) > 1
&& f.gmv_warp_allowed[r#ref] != 0
{
Expand Down Expand Up @@ -3631,7 +3637,7 @@ fn reset_context(ctx: &mut BlockContext, keyframe: bool, pass: c_int) {
r#ref.fill(-1);
}
ctx.comp_type.0.fill(None);
ctx.mode.0.fill(NEARESTMV);
ctx.mode.0.fill(InterPredMode::Nearest as u8);
}
ctx.lcoef.0.fill(0x40);
for ccoef in &mut ctx.ccoef.0 {
Expand Down
53 changes: 33 additions & 20 deletions src/levels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,6 @@ pub enum MVJoint {
HV = 3,
}

pub type InterPredMode = u8;
pub const _N_INTER_PRED_MODES: usize = 4;
pub const NEWMV: InterPredMode = 3;
pub const GLOBALMV: InterPredMode = 2;
pub const NEARMV: InterPredMode = 1;
pub const NEARESTMV: InterPredMode = 0;

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum DrlProximity {
Nearest,
Expand All @@ -204,16 +197,36 @@ pub enum DrlProximity {
Nearish,
}

pub type CompInterPredMode = u8;
pub const N_COMP_INTER_PRED_MODES: usize = 8;
pub const NEWMV_NEWMV: CompInterPredMode = 7;
pub const GLOBALMV_GLOBALMV: CompInterPredMode = 6;
pub const NEWMV_NEARMV: CompInterPredMode = 5;
pub const NEARMV_NEWMV: CompInterPredMode = 4;
pub const NEWMV_NEARESTMV: CompInterPredMode = 3;
pub const NEARESTMV_NEWMV: CompInterPredMode = 2;
pub const NEARMV_NEARMV: CompInterPredMode = 1;
pub const NEARESTMV_NEARESTMV: CompInterPredMode = 0;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InterPredMode {
Nearest = 0,
Near = 1,
Global = 2,
New = 3,
Comment on lines +202 to +205
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've kept the number here for now because the ordering is important for the mapping to CompInterPredMode

}

#[derive(Debug, Clone, Copy, PartialEq, Eq, FromRepr, EnumCount)]
pub enum CompInterPredMode {
NearestNearest = 0,
NearNear = 1,
NearestNew = 2,
NewNearest = 3,
NearNew = 4,
NewNear = 5,
GlobalGlobal = 6,
NewNew = 7,
}

impl From<InterPredMode> for CompInterPredMode {
fn from(value: InterPredMode) -> Self {
match value {
InterPredMode::Nearest => CompInterPredMode::NearestNearest,
InterPredMode::Near => CompInterPredMode::NearNear,
InterPredMode::Global => CompInterPredMode::NearestNew,
InterPredMode::New => CompInterPredMode::NewNearest,
}
}
}
Comment on lines +220 to +229
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this conversion happens in practice, though I don't really understand why it makes sense. Maybe the CompInterPredMode field is just implicitly a InterPredMode in certain circumstances? or there is something more subtle going on.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but e.g. Global mapping to NearestNew is weird when GlobalGlobal exists.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is very weird. I have no idea why that's done. It looks like a correct translation of the C, though. @fbossen, have any idea?

Maybe the same fields are being used for the two enums, but never at the same time? I'm curious if the tests would still pass if we changed the conversion to what would be expected.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is this one expression that mixes the two types

(b.inter_mode() == if is_comp { GLOBALMV_GLOBALMV } else { GLOBALMV }) as c_int;

so, I guess, if is_comp is true then CompInterPredMode is used, and otherwise just InterPredMode. That at least intuitively makes sense here, and also seems to be true in this file.

Not sure what to do with that though, is_comp is a runtime value so it's hard encode that into the type. Some enum { Comp(_), NotComp(_) } enum also doesn't seem great (but would provide some extra type safety).


#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum CompInterType {
Expand Down Expand Up @@ -319,7 +332,7 @@ pub union Av1Block_inter_nd {
pub struct Av1Block_inter {
pub c2rust_unnamed: Av1Block_inter_nd,
pub comp_type: Option<CompInterType>,
pub inter_mode: u8,
pub inter_mode: CompInterPredMode,
pub motion_mode: MotionMode,
pub drl_idx: DrlProximity,
pub r#ref: [i8; 2],
Expand Down Expand Up @@ -383,11 +396,11 @@ impl Av1Block {
&mut self.c2rust_unnamed.c2rust_unnamed_0.drl_idx
}

pub unsafe fn inter_mode(&self) -> u8 {
pub unsafe fn inter_mode(&self) -> CompInterPredMode {
self.c2rust_unnamed.c2rust_unnamed_0.inter_mode
}

pub unsafe fn inter_mode_mut(&mut self) -> &mut u8 {
pub unsafe fn inter_mode_mut(&mut self) -> &mut CompInterPredMode {
&mut self.c2rust_unnamed.c2rust_unnamed_0.inter_mode
}

Expand Down
Loading
Loading