Skip to content

Commit fcd0683

Browse files
authored
fn get_lo_ctx: Optimize (#1241)
* Part of #1180. This optimizes `fn get_lo_ctx`. 9d00ae6 optimizes it enough that it is now inlined (it was already `#[inline]`), which enables other optimizations. a5c6209 makes `stride` a `u8`, which it is in the caller `fn decode_coefs`, and which tells LLVM that `_ * stride` can't overflow, and thus `2 * stride` being in bounds means that `0 * stride` is in bounds, for example. 242065b then pre-indexes `levels` so that there is only one bounds check per branch. There is still actually another since LLVM does not know that `stride != 0`. This can be optimized out with more effort in `fn decode_coefs`, as `stride` is ultimately derived from `t_dim.h`, which comes from `static`/`const` data.
2 parents dac2b82 + e42f2b8 commit fcd0683

File tree

2 files changed

+72
-106
lines changed

2 files changed

+72
-106
lines changed

src/levels.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ pub const DCT_ADST: TxfmType = 2;
145145
pub const ADST_DCT: TxfmType = 1;
146146
pub const DCT_DCT: TxfmType = 0;
147147

148-
#[derive(Clone, Copy, PartialEq, Eq)]
148+
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
149149
pub enum TxClass {
150150
TwoD,
151151
H,

src/recon.rs

+71-105
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ use crate::src::tables::dav1d_txtp_from_uvmode;
8282
use crate::src::tables::TxfmInfo;
8383
use crate::src::wedge::dav1d_ii_masks;
8484
use crate::src::wedge::dav1d_wedge_masks;
85+
use assert_matches::debug_assert_matches;
8586
use libc::intptr_t;
8687
use std::array;
8788
use std::cmp;
@@ -467,30 +468,46 @@ fn get_dc_sign_ctx(tx: TxfmSize, a: &[u8], l: &[u8]) -> c_uint {
467468
fn get_lo_ctx(
468469
levels: &[u8],
469470
tx_class: TxClass,
470-
hi_mag: &mut c_uint,
471+
hi_mag: &mut u32,
471472
ctx_offsets: Option<&[[u8; 5]; 5]>,
472-
x: usize,
473-
y: usize,
474-
stride: usize,
475-
) -> usize {
476-
let level = |y, x| levels[y * stride + x] as usize;
477-
478-
let mut mag = level(0, 1) + level(1, 0);
479-
let offset = match tx_class {
480-
TxClass::TwoD => {
473+
x: u32,
474+
y: u32,
475+
stride: u8,
476+
) -> u8 {
477+
let stride = stride as usize;
478+
let level = |y, x| levels[y * stride + x] as u32;
479+
480+
// Note that the first `mag` initialization is moved inside the `match`
481+
// so that the different bounds checks can be done inside the `match`,
482+
// as putting them outside the `match` in an identical one trips up LLVM.
483+
let mut mag;
484+
let offset;
485+
match ctx_offsets {
486+
Some(ctx_offsets) => {
487+
level(2, 1); // Bounds check all at once.
488+
mag = level(0, 1) + level(1, 0);
489+
debug_assert_matches!(tx_class, TxClass::TwoD);
481490
mag += level(1, 1);
482-
*hi_mag = mag as c_uint;
491+
*hi_mag = mag;
483492
mag += level(0, 2) + level(2, 0);
484-
ctx_offsets.unwrap()[cmp::min(y, 4)][cmp::min(x, 4)] as usize
493+
offset = ctx_offsets[cmp::min(y as usize, 4)][cmp::min(x as usize, 4)];
485494
}
486-
TxClass::H | TxClass::V => {
495+
None => {
496+
debug_assert_matches!(tx_class, TxClass::H | TxClass::V);
497+
level(1, 4); // Bounds check all at once.
498+
mag = level(0, 1) + level(1, 0);
487499
mag += level(0, 2);
488-
*hi_mag = mag as c_uint;
500+
*hi_mag = mag;
489501
mag += level(0, 3) + level(0, 4);
490-
26 + if y > 1 { 10 } else { y * 5 }
502+
offset = 26 + if y > 1 { 10 } else { y as u8 * 5 };
503+
}
504+
}
505+
offset
506+
+ if mag > 512 {
507+
4
508+
} else {
509+
((mag + 64) >> 7) as u8
491510
}
492-
};
493-
offset + if mag > 512 { 4 } else { (mag + 64) >> 7 }
494511
}
495512

496513
fn decode_coefs<BD: BitDepth>(
@@ -709,9 +726,9 @@ fn decode_coefs<BD: BitDepth>(
709726
let sh = cmp::min(t_dim.h, 8);
710727

711728
// eob
712-
let mut ctx: c_uint = 1
713-
+ (eob > sw as c_int * sh as c_int * 2) as c_uint
714-
+ (eob > sw as c_int * sh as c_int * 4) as c_uint;
729+
let mut ctx = 1
730+
+ (eob > sw as c_int * sh as c_int * 2) as u8
731+
+ (eob > sw as c_int * sh as c_int * 4) as u8;
715732
let eob_tok =
716733
rav1d_msac_decode_symbol_adapt4(&mut ts_c.msac, &mut eob_cdf[ctx as usize], 2) as c_int;
717734
let mut tok = eob_tok + 1;
@@ -727,7 +744,7 @@ fn decode_coefs<BD: BitDepth>(
727744
[nonsquare_tx.wrapping_add(tx as c_uint & nonsquare_tx) as usize],
728745
);
729746
scan = dav1d_scans[tx as usize];
730-
let stride = 4 * sh as usize;
747+
let stride = 4 * sh;
731748
let shift: c_uint = if t_dim.lh < 4 {
732749
t_dim.lh as c_uint + 2
733750
} else {
@@ -737,7 +754,7 @@ fn decode_coefs<BD: BitDepth>(
737754
let mask: c_uint = 4 * sh as c_uint - 1;
738755
// Optimizes better than `.fill(0)`,
739756
// which doesn't elide the bounds check, inline, or vectorize.
740-
for i in 0..stride * (4 * sw as usize + 2) {
757+
for i in 0..stride as usize * (4 * sw as usize + 2) {
741758
levels[i] = 0;
742759
}
743760
let mut x: c_uint;
@@ -793,7 +810,7 @@ fn decode_coefs<BD: BitDepth>(
793810
}
794811
}
795812
cf.set::<BD>(f, t_cf, rc as usize, (tok << 11).as_::<BD::Coef>());
796-
levels[x as usize * stride + y as usize] = level_tok as u8;
813+
levels[x as usize * stride as usize + y as usize] = level_tok as u8;
797814
let mut i = eob - 1;
798815
while i > 0 {
799816
// ac
@@ -816,16 +833,8 @@ fn decode_coefs<BD: BitDepth>(
816833
}
817834
}
818835
assert!(x < 32 && y < 32);
819-
let level = &mut levels[x as usize * stride + y as usize..];
820-
ctx = get_lo_ctx(
821-
level,
822-
tx_class,
823-
&mut mag,
824-
lo_ctx_offsets,
825-
x as usize,
826-
y as usize,
827-
stride,
828-
) as c_uint;
836+
let level = &mut levels[x as usize * stride as usize + y as usize..];
837+
ctx = get_lo_ctx(level, tx_class, &mut mag, lo_ctx_offsets, x, y, stride);
829838
if tx_class == TxClass::TwoD {
830839
y |= x;
831840
}
@@ -842,16 +851,11 @@ fn decode_coefs<BD: BitDepth>(
842851
}
843852
if tok == 3 {
844853
mag &= 63;
845-
ctx = ((if y > (tx_class == TxClass::TwoD) as c_uint {
854+
ctx = if y > (tx_class == TxClass::TwoD) as c_uint {
846855
14
847856
} else {
848857
7
849-
}) as c_uint)
850-
.wrapping_add(if mag > 12 {
851-
6
852-
} else {
853-
mag.wrapping_add(1) >> 1
854-
});
858+
} + if mag > 12 { 6 } else { (mag as u8 + 1) >> 1 };
855859
tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize])
856860
as c_int;
857861
if dbg {
@@ -891,7 +895,7 @@ fn decode_coefs<BD: BitDepth>(
891895
ctx = if tx_class == TxClass::TwoD {
892896
0
893897
} else {
894-
get_lo_ctx(levels, tx_class, &mut mag, lo_ctx_offsets, 0, 0, stride) as c_uint
898+
get_lo_ctx(levels, tx_class, &mut mag, lo_ctx_offsets, 0, 0, stride)
895899
};
896900
dc_tok =
897901
rav1d_msac_decode_symbol_adapt4(&mut ts_c.msac, &mut lo_cdf[ctx as usize], 3)
@@ -904,16 +908,12 @@ fn decode_coefs<BD: BitDepth>(
904908
}
905909
if dc_tok == 3 {
906910
if tx_class == TxClass::TwoD {
907-
mag = levels[0 * stride + 1] as c_uint
908-
+ levels[1 * stride + 0] as c_uint
909-
+ levels[1 * stride + 1] as c_uint;
911+
mag = levels[0 * stride as usize + 1] as c_uint
912+
+ levels[1 * stride as usize + 0] as c_uint
913+
+ levels[1 * stride as usize + 1] as c_uint;
910914
}
911915
mag &= 63;
912-
ctx = if mag > 12 {
913-
6
914-
} else {
915-
mag.wrapping_add(1) >> 1
916-
};
916+
ctx = if mag > 12 { 6 } else { (mag as u8 + 1) >> 1 };
917917
dc_tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize])
918918
as c_uint;
919919
if dbg {
@@ -935,7 +935,7 @@ fn decode_coefs<BD: BitDepth>(
935935
let mask: c_uint = 4 * sh as c_uint - 1;
936936
// Optimizes better than `.fill(0)`,
937937
// which doesn't elide the bounds check, inline, or vectorize.
938-
for i in 0..stride * (4 * sh as usize + 2) {
938+
for i in 0..stride as usize * (4 * sh as usize + 2) {
939939
levels[i] = 0;
940940
}
941941
let mut x: c_uint;
@@ -990,7 +990,7 @@ fn decode_coefs<BD: BitDepth>(
990990
}
991991
}
992992
cf.set::<BD>(f, t_cf, rc as usize, (tok << 11).as_::<BD::Coef>());
993-
levels[x as usize * stride + y as usize] = level_tok as u8;
993+
levels[x as usize * stride as usize + y as usize] = level_tok as u8;
994994
let mut i = eob - 1;
995995
while i > 0 {
996996
let rc_i: c_uint;
@@ -1012,16 +1012,8 @@ fn decode_coefs<BD: BitDepth>(
10121012
}
10131013
}
10141014
assert!(x < 32 && y < 32);
1015-
let level = &mut levels[x as usize * stride + y as usize..];
1016-
ctx = get_lo_ctx(
1017-
level,
1018-
tx_class,
1019-
&mut mag,
1020-
lo_ctx_offsets,
1021-
x as usize,
1022-
y as usize,
1023-
stride,
1024-
) as c_uint;
1015+
let level = &mut levels[x as usize * stride as usize + y as usize..];
1016+
ctx = get_lo_ctx(level, tx_class, &mut mag, lo_ctx_offsets, x, y, stride);
10251017
if tx_class == TxClass::TwoD {
10261018
y |= x;
10271019
}
@@ -1038,16 +1030,11 @@ fn decode_coefs<BD: BitDepth>(
10381030
}
10391031
if tok == 3 {
10401032
mag &= 63;
1041-
ctx = ((if y > (tx_class == TxClass::TwoD) as c_uint {
1033+
ctx = if y > (tx_class == TxClass::TwoD) as c_uint {
10421034
14
10431035
} else {
10441036
7
1045-
}) as c_uint)
1046-
.wrapping_add(if mag > 12 {
1047-
6
1048-
} else {
1049-
mag.wrapping_add(1) >> 1
1050-
});
1037+
} + if mag > 12 { 6 } else { (mag as u8 + 1) >> 1 };
10511038
tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize])
10521039
as c_int;
10531040
if dbg {
@@ -1084,7 +1071,7 @@ fn decode_coefs<BD: BitDepth>(
10841071
ctx = if tx_class == TxClass::TwoD {
10851072
0
10861073
} else {
1087-
get_lo_ctx(levels, tx_class, &mut mag, lo_ctx_offsets, 0, 0, stride) as c_uint
1074+
get_lo_ctx(levels, tx_class, &mut mag, lo_ctx_offsets, 0, 0, stride)
10881075
};
10891076
dc_tok =
10901077
rav1d_msac_decode_symbol_adapt4(&mut ts_c.msac, &mut lo_cdf[ctx as usize], 3)
@@ -1097,16 +1084,12 @@ fn decode_coefs<BD: BitDepth>(
10971084
}
10981085
if dc_tok == 3 {
10991086
if tx_class == TxClass::TwoD {
1100-
mag = levels[0 * stride + 1] as c_uint
1101-
+ levels[1 * stride + 0] as c_uint
1102-
+ levels[1 * stride + 1] as c_uint;
1087+
mag = levels[0 * stride as usize + 1] as c_uint
1088+
+ levels[1 * stride as usize + 0] as c_uint
1089+
+ levels[1 * stride as usize + 1] as c_uint;
11031090
}
11041091
mag &= 63;
1105-
ctx = if mag > 12 {
1106-
6
1107-
} else {
1108-
mag.wrapping_add(1) >> 1
1109-
};
1092+
ctx = if mag > 12 { 6 } else { (mag as u8 + 1) >> 1 };
11101093
dc_tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize])
11111094
as c_uint;
11121095
if dbg {
@@ -1128,7 +1111,7 @@ fn decode_coefs<BD: BitDepth>(
11281111
let mask: c_uint = 4 * sw as c_uint - 1;
11291112
// Optimizes better than `.fill(0)`,
11301113
// which doesn't elide the bounds check, inline, or vectorize.
1131-
for i in 0..stride * (4 * sw as usize + 2) {
1114+
for i in 0..stride as usize * (4 * sw as usize + 2) {
11321115
levels[i] = 0;
11331116
}
11341117
let mut x: c_uint;
@@ -1183,7 +1166,7 @@ fn decode_coefs<BD: BitDepth>(
11831166
}
11841167
}
11851168
cf.set::<BD>(f, t_cf, rc as usize, (tok << 11).as_::<BD::Coef>());
1186-
levels[x as usize * stride + y as usize] = level_tok as u8;
1169+
levels[x as usize * stride as usize + y as usize] = level_tok as u8;
11871170
let mut i = eob - 1;
11881171
while i > 0 {
11891172
let rc_i: c_uint;
@@ -1205,16 +1188,8 @@ fn decode_coefs<BD: BitDepth>(
12051188
}
12061189
}
12071190
assert!(x < 32 && y < 32);
1208-
let level = &mut levels[x as usize * stride + y as usize..];
1209-
ctx = get_lo_ctx(
1210-
level,
1211-
tx_class,
1212-
&mut mag,
1213-
lo_ctx_offsets,
1214-
x as usize,
1215-
y as usize,
1216-
stride,
1217-
) as c_uint;
1191+
let level = &mut levels[x as usize * stride as usize + y as usize..];
1192+
ctx = get_lo_ctx(level, tx_class, &mut mag, lo_ctx_offsets, x, y, stride);
12181193
if tx_class == TxClass::TwoD {
12191194
y |= x;
12201195
}
@@ -1231,16 +1206,11 @@ fn decode_coefs<BD: BitDepth>(
12311206
}
12321207
if tok == 3 {
12331208
mag &= 63;
1234-
ctx = ((if y > (tx_class == TxClass::TwoD) as c_uint {
1209+
ctx = if y > (tx_class == TxClass::TwoD) as c_uint {
12351210
14
12361211
} else {
12371212
7
1238-
}) as c_uint)
1239-
.wrapping_add(if mag > 12 {
1240-
6
1241-
} else {
1242-
mag.wrapping_add(1) >> 1
1243-
});
1213+
} + if mag > 12 { 6 } else { (mag as u8 + 1) >> 1 };
12441214
tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize])
12451215
as c_int;
12461216
if dbg {
@@ -1277,7 +1247,7 @@ fn decode_coefs<BD: BitDepth>(
12771247
ctx = if tx_class == TxClass::TwoD {
12781248
0
12791249
} else {
1280-
get_lo_ctx(levels, tx_class, &mut mag, lo_ctx_offsets, 0, 0, stride) as c_uint
1250+
get_lo_ctx(levels, tx_class, &mut mag, lo_ctx_offsets, 0, 0, stride)
12811251
};
12821252
dc_tok =
12831253
rav1d_msac_decode_symbol_adapt4(&mut ts_c.msac, &mut lo_cdf[ctx as usize], 3)
@@ -1290,16 +1260,12 @@ fn decode_coefs<BD: BitDepth>(
12901260
}
12911261
if dc_tok == 3 {
12921262
if tx_class == TxClass::TwoD {
1293-
mag = levels[0 * stride + 1] as c_uint
1294-
+ levels[1 * stride + 0] as c_uint
1295-
+ levels[1 * stride + 1] as c_uint;
1263+
mag = levels[0 * stride as usize + 1] as c_uint
1264+
+ levels[1 * stride as usize + 0] as c_uint
1265+
+ levels[1 * stride as usize + 1] as c_uint;
12961266
}
12971267
mag &= 63;
1298-
ctx = if mag > 12 {
1299-
6
1300-
} else {
1301-
mag.wrapping_add(1) >> 1
1302-
};
1268+
ctx = if mag > 12 { 6 } else { (mag as u8 + 1) >> 1 };
13031269
dc_tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize])
13041270
as c_uint;
13051271
if dbg {

0 commit comments

Comments
 (0)