diff --git a/src/day4.rs b/src/day4.rs index 0add765..29e5f15 100644 --- a/src/day4.rs +++ b/src/day4.rs @@ -1,34 +1,37 @@ use super::*; +static mut SCRATCH: [u8; 2048] = [0; 2048]; + #[target_feature(enable = "avx2,bmi1,bmi2,cmpxchg16b,lzcnt,movbe,popcnt")] unsafe fn inner1(s: &[u8]) -> u32 { let r = s.as_ptr_range(); let mut ptr = r.start; - let end = r.end; + let mut end = ptr.add(19264); let mut sums0 = i8x32::splat(0); let mut sums1 = i8x32::splat(0); let mut sums2 = i8x32::splat(0); let mut sums3 = i8x32::splat(0); + let mut finishing = false; + macro_rules! load { + ($x:expr, $y:expr) => { + (ptr.add($x).add($y * 141) as *const i8x32).read_unaligned() + }; + } + macro_rules! test_four { + ($sums:expr, $a:expr, $b:expr, $c:expr, $d:expr) => { + let diff0 = $d - $a; + let diff1 = $b - $c; + let abs0 = diff0.abs(); + let abs1 = diff1.abs(); + let eq0 = abs0.simd_eq(Simd::splat(b'X' - b'S').cast()); + let eq1 = abs1.simd_eq(Simd::splat(b'M' - b'A').cast()); + let sign = diff0 ^ diff1; + let eq = eq0 & eq1; + let signs_match = sign.simd_lt(Simd::splat(0)); + $sums -= (signs_match & eq).to_int(); + }; + } loop { - macro_rules! load { - ($x:expr, $y:expr) => { - (ptr.add($x).add($y * 141) as *const i8x32).read_unaligned() - }; - } - macro_rules! test_four { - ($sums:expr, $a:expr, $b:expr, $c:expr, $d:expr) => { - let diff0 = $d - $a; - let diff1 = $b - $c; - let abs0 = diff0.abs(); - let abs1 = diff1.abs(); - let eq0 = abs0.simd_eq(Simd::splat(b'X' - b'S').cast()); - let eq1 = abs1.simd_eq(Simd::splat(b'M' - b'A').cast()); - let sign = diff0 ^ diff1; - let eq = eq0 & eq1; - let signs_match = sign.simd_lt(Simd::splat(0)); - $sums -= (signs_match & eq).to_int(); - }; - } let v00 = load!(0, 0); let v10 = load!(1, 0); let v20 = load!(2, 0); @@ -47,8 +50,16 @@ unsafe fn inner1(s: &[u8]) -> u32 { test_four!(sums3, v00, v11, v22, v33); ptr = ptr.add(32); // yes we're reading hundreds of bytes past the end of the buffer. sue me - if ptr >= end { - break; + if ptr > end { + if finishing { + break; + } + finishing = true; + let scratch = &mut SCRATCH; + let remainder = r.end.offset_from(ptr) as usize; + scratch[0..remainder].copy_from_slice(std::slice::from_raw_parts(ptr, remainder)); + ptr = scratch as _; + end = ptr.add(remainder); } } let sums0: u16x16 = _mm256_maddubs_epi16(sums0.into(), u8x32::splat(1).into()).into();