diff --git a/library/core/src/str/pattern.rs b/library/core/src/str/pattern.rs index def11ca45c05e..c5be32861f9a5 100644 --- a/library/core/src/str/pattern.rs +++ b/library/core/src/str/pattern.rs @@ -956,15 +956,20 @@ impl<'a, 'b> Pattern<'a> for &'b str { match self.len().cmp(&haystack.len()) { Ordering::Less => { + if self.len() == 1 { + return haystack.as_bytes().contains(&self.as_bytes()[0]); + } + #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] - if self.as_bytes().len() <= 8 { - return simd_contains(self, haystack); + if self.len() <= 32 { + if let Some(result) = simd_contains(self, haystack) { + return result; + } } self.into_searcher(haystack).next_match().is_some() } - Ordering::Equal => self == haystack, - Ordering::Greater => false, + _ => self == haystack, } } @@ -1707,82 +1712,207 @@ impl TwoWayStrategy for RejectAndMatch { } } +/// SIMD search for short needles based on +/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0] +/// +/// It skips ahead by the vector width on each iteration (rather than the needle length as two-way +/// does) by probing the first and last byte of the needle for the whole vector width +/// and only doing full needle comparisons when the vectorized probe indicated potential matches. +/// +/// Since the x86_64 baseline only offers SSE2 we only use u8x16 here. +/// If we ever ship std with for x86-64-v3 or adapt this for other platforms then wider vectors +/// should be evaluated. +/// +/// For haystacks smaller than vector-size + needle length it falls back to +/// a naive O(n*m) search so this implementation should not be called on larger needles. +/// +/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] #[inline] -fn simd_contains(needle: &str, haystack: &str) -> bool { +fn simd_contains(needle: &str, haystack: &str) -> Option { let needle = needle.as_bytes(); let haystack = haystack.as_bytes(); - if needle.len() == 1 { - return haystack.contains(&needle[0]); - } - - const CHUNK: usize = 16; + debug_assert!(needle.len() > 1); + + use crate::ops::BitAnd; + use crate::simd::mask8x16 as Mask; + use crate::simd::u8x16 as Block; + use crate::simd::{SimdPartialEq, ToBitMask}; + + let first_probe = needle[0]; + + // the offset used for the 2nd vector + let second_probe_offset = if needle.len() == 2 { + // never bail out on len=2 needles because the probes will fully cover them and have + // no degenerate cases. + 1 + } else { + // try a few bytes in case first and last byte of the needle are the same + let Some(second_probe_offset) = (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe) else { + // fall back to other search methods if we can't find any different bytes + // since we could otherwise hit some degenerate cases + return None; + }; + second_probe_offset + }; - // do a naive search if if the haystack is too small to fit - if haystack.len() < CHUNK + needle.len() - 1 { - return haystack.windows(needle.len()).any(|c| c == needle); + // do a naive search if the haystack is too small to fit + if haystack.len() < Block::LANES + second_probe_offset { + return Some(haystack.windows(needle.len()).any(|c| c == needle)); } - use crate::arch::x86_64::{ - __m128i, _mm_and_si128, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8, - }; - - // SAFETY: no preconditions other than sse2 being available - let first: __m128i = unsafe { _mm_set1_epi8(needle[0] as i8) }; - // SAFETY: no preconditions other than sse2 being available - let last: __m128i = unsafe { _mm_set1_epi8(*needle.last().unwrap() as i8) }; + let first_probe: Block = Block::splat(first_probe); + let second_probe: Block = Block::splat(needle[second_probe_offset]); + // first byte are already checked by the outer loop. to verify a match only the + // remainder has to be compared. + let trimmed_needle = &needle[1..]; + // this #[cold] is load-bearing, benchmark before removing it... let check_mask = #[cold] - |idx, mut mask: u32| -> bool { + |idx, mask: u16, skip: bool| -> bool { + if skip { + return false; + } + + // and so is this. optimizations are weird. + let mut mask = mask; + while mask != 0 { let trailing = mask.trailing_zeros(); let offset = idx + trailing as usize + 1; - let sub = &haystack[offset..][..needle.len() - 2]; - let trimmed_needle = &needle[1..needle.len() - 1]; - - if sub == trimmed_needle { - return true; + // SAFETY: mask is between 0 and 15 trailing zeroes, we skip one additional byte that was already compared + // and then take trimmed_needle.len() bytes. This is within the bounds defined by the outer loop + unsafe { + let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len()); + if small_slice_eq(sub, trimmed_needle) { + return true; + } } mask &= !(1 << trailing); } return false; }; - let test_chunk = |i| -> bool { - // SAFETY: this requires at least CHUNK bytes being readable at offset i + let test_chunk = |idx| -> u16 { + // SAFETY: this requires at least LANES bytes being readable at idx // that is ensured by the loop ranges (see comments below) - let a: __m128i = unsafe { _mm_loadu_si128(haystack.as_ptr().add(i) as *const _) }; - let b: __m128i = - // SAFETY: this requires CHUNK + needle.len() - 1 bytes being readable at offset i - unsafe { _mm_loadu_si128(haystack.as_ptr().add(i + needle.len() - 1) as *const _) }; - - // SAFETY: no preconditions other than sse2 being available - let eq_first: __m128i = unsafe { _mm_cmpeq_epi8(first, a) }; - // SAFETY: no preconditions other than sse2 being available - let eq_last: __m128i = unsafe { _mm_cmpeq_epi8(last, b) }; - - // SAFETY: no preconditions other than sse2 being available - let mask: u32 = unsafe { _mm_movemask_epi8(_mm_and_si128(eq_first, eq_last)) } as u32; + let a: Block = unsafe { haystack.as_ptr().add(idx).cast::().read_unaligned() }; + // SAFETY: this requires LANES + block_offset bytes being readable at idx + let b: Block = unsafe { + haystack.as_ptr().add(idx).add(second_probe_offset).cast::().read_unaligned() + }; + let eq_first: Mask = a.simd_eq(first_probe); + let eq_last: Mask = b.simd_eq(second_probe); + let both = eq_first.bitand(eq_last); + let mask = both.to_bitmask(); - if mask != 0 { - return check_mask(i, mask); - } - return false; + return mask; }; let mut i = 0; let mut result = false; - while !result && i + CHUNK + needle.len() <= haystack.len() { - result |= test_chunk(i); - i += CHUNK; + // The loop condition must ensure that there's enough headroom to read LANE bytes, + // and not only at the current index but also at the index shifted by block_offset + const UNROLL: usize = 4; + while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result { + let mut masks = [0u16; UNROLL]; + for j in 0..UNROLL { + masks[j] = test_chunk(i + j * Block::LANES); + } + for j in 0..UNROLL { + let mask = masks[j]; + if mask != 0 { + result |= check_mask(i + j * Block::LANES, mask, result); + } + } + i += UNROLL * Block::LANES; + } + while i + second_probe_offset + Block::LANES < haystack.len() && !result { + let mask = test_chunk(i); + if mask != 0 { + result |= check_mask(i, mask, result); + } + i += Block::LANES; } - // process the tail that didn't fit into CHUNK-sized steps - // this simply repeats the same procedure but as right-aligned chunk instead + // Process the tail that didn't fit into LANES-sized steps. + // This simply repeats the same procedure but as right-aligned chunk instead // of a left-aligned one. The last byte must be exactly flush with the string end so // we don't miss a single byte or read out of bounds. - result |= test_chunk(haystack.len() + 1 - needle.len() - CHUNK); + let i = haystack.len() - second_probe_offset - Block::LANES; + let mask = test_chunk(i); + if mask != 0 { + result |= check_mask(i, mask, result); + } + + Some(result) +} + +/// Compares short slices for equality. +/// +/// It avoids a call to libc's memcmp which is faster on long slices +/// due to SIMD optimizations but it incurs a function call overhead. +/// +/// # Safety +/// +/// Both slices must have the same length. +#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86 +#[inline] +unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool { + // This function is adapted from + // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32 - return result; + // If we don't have enough bytes to do 4-byte at a time loads, then + // fall back to the naive slow version. + // + // Potential alternative: We could do a copy_nonoverlapping combined with a mask instead + // of a loop. Benchmark it. + if x.len() < 4 { + for (&b1, &b2) in x.iter().zip(y) { + if b1 != b2 { + return false; + } + } + return true; + } + // When we have 4 or more bytes to compare, then proceed in chunks of 4 at + // a time using unaligned loads. + // + // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is + // that this particular version of memcmp is likely to be called with tiny + // needles. That means that if we do 8 byte loads, then a higher proportion + // of memcmp calls will use the slower variant above. With that said, this + // is a hypothesis and is only loosely supported by benchmarks. There's + // likely some improvement that could be made here. The main thing here + // though is to optimize for latency, not throughput. + + // SAFETY: Via the conditional above, we know that both `px` and `py` + // have the same length, so `px < pxend` implies that `py < pyend`. + // Thus, derefencing both `px` and `py` in the loop below is safe. + // + // Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual + // end of of `px` and `py`. Thus, the final dereference outside of the + // loop is guaranteed to be valid. (The final comparison will overlap with + // the last comparison done in the loop for lengths that aren't multiples + // of four.) + // + // Finally, we needn't worry about alignment here, since we do unaligned + // loads. + unsafe { + let (mut px, mut py) = (x.as_ptr(), y.as_ptr()); + let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4)); + while px < pxend { + let vx = (px as *const u32).read_unaligned(); + let vy = (py as *const u32).read_unaligned(); + if vx != vy { + return false; + } + px = px.add(4); + py = py.add(4); + } + let vx = (pxend as *const u32).read_unaligned(); + let vy = (pyend as *const u32).read_unaligned(); + vx == vy + } }