diff --git a/src/pcmp.rs b/src/pcmp.rs index 232e99c..be9f5ca 100644 --- a/src/pcmp.rs +++ b/src/pcmp.rs @@ -11,6 +11,7 @@ extern crate unchecked_index; extern crate memchr; use std::cmp; +use std::mem; use std::iter::Zip; use self::unchecked_index::get_unchecked; @@ -37,8 +38,29 @@ use std::arch::x86_64::*; /// PCMPESTRI xmm1, xmm2/m128, imm8 /// /// Return value: least index for start of (partial) match, (16 if no match). +/// +/// Mask: `text` can be at at any point in valid memory, as long as `text_len` +/// bytes are readable. +#[target_feature(enable = "sse4.2")] +unsafe fn pcmpestri_16_mask(text: *const u8, offset: usize, text_len: usize, + needle: __m128i, needle_len: usize) -> u32 { + //debug_assert!(text_len + offset <= text.len()); // saturates at 16 + //debug_assert!(needle_len <= 16); // saturates at 16 + let text = mask_load(text.offset(offset as _) as *const _, text_len); + _mm_cmpestri(needle, needle_len as _, text, text_len as _, _SIDD_CMP_EQUAL_ORDERED) as _ +} + +/// `pcmpestri` +/// +/// “Packed compare explicit length strings (return index)” +/// +/// PCMPESTRI xmm1, xmm2/m128, imm8 +/// +/// Return value: least index for start of (partial) match, (16 if no match). +/// +/// No mask: `text` must be at least 16 bytes from the end of a memory region. #[target_feature(enable = "sse4.2")] -unsafe fn pcmpestri_16(text: *const u8, offset: usize, text_len: usize, +unsafe fn pcmpestri_16_nomask(text: *const u8, offset: usize, text_len: usize, needle: __m128i, needle_len: usize) -> u32 { //debug_assert!(text_len + offset <= text.len()); // saturates at 16 //debug_assert!(needle_len <= 16); // saturates at 16 @@ -130,7 +152,7 @@ unsafe fn first_start_of_match_inner(text: &[u8], pat: &[u8], p: __m128i) -> Opt } while text.len() >= offset - tp_align_offset + patl { let tlen = text.len() - (offset - tp_align_offset); - let ret = pcmpestri_16(tp_aligned, offset, tlen, p, patl) as usize; + let ret = pcmpestri_16_mask(tp_aligned, offset, tlen, p, patl) as usize; if ret == 16 { offset += 16; } else { @@ -154,7 +176,7 @@ unsafe fn first_start_of_match_unaligned(text: &[u8], pat_len: usize, p: __m128i while text.len() - pat_len >= offset { let tlen = text.len() - offset; - let ret = pcmpestri_16(tp, offset, tlen, p, pat_len) as usize; + let ret = pcmpestri_16_nomask(tp, offset, tlen, p, pat_len) as usize; if ret == 16 { offset += 16; } else { @@ -491,7 +513,21 @@ fn test_find() { /// Load the first 16 bytes of `pat` into a SIMD vector. #[inline(always)] fn pat128(pat: &[u8]) -> __m128i { - unsafe { _mm_loadu_si128(pat.as_ptr() as *const _) } + unsafe { + mask_load(pat.as_ptr() as *const _, pat.len()) + } +} + +#[inline(always)] +unsafe fn mask_load(ptr: *const u8, len: usize) -> __m128i { + const REGSZ: usize = mem::size_of::<__m128i>(); + if len >= REGSZ { + return _mm_loadu_si128(ptr as _); + } + + let mut data = [0; REGSZ]; + ::std::ptr::copy_nonoverlapping(ptr, data.as_mut_ptr(), len); + return _mm_loadu_si128(data.as_mut_ptr() as _); } /// Find longest shared prefix, return its length