Skip to content

Commit

Permalink
Fix pattern and text loads so that they stay in bounds
Browse files Browse the repository at this point in the history
The simd loads require that we can read all 16 bytes of memory from the
pointer.

Add function mask_load to read from shorter slices of memory if
required. Follow the code, wherever we have the variable "safetext", we
have already ensured we're only reading up until before the last 16
bytes of the input text, so here it is safe to use simd loads.

This way, the code passes fuzz testing without errors, and keeps the
benchmark results.
  • Loading branch information
bluss committed Oct 30, 2018
1 parent 9574174 commit d4e91ea
Showing 1 changed file with 40 additions and 4 deletions.
44 changes: 40 additions & 4 deletions src/pcmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ extern crate unchecked_index;
extern crate memchr;

use std::cmp;
use std::mem;
use std::iter::Zip;

use self::unchecked_index::get_unchecked;
Expand All @@ -37,8 +38,29 @@ use std::arch::x86_64::*;
/// PCMPESTRI xmm1, xmm2/m128, imm8
///
/// Return value: least index for start of (partial) match, (16 if no match).
///
/// Mask: `text` can be at at any point in valid memory, as long as `text_len`
/// bytes are readable.
#[target_feature(enable = "sse4.2")]
unsafe fn pcmpestri_16_mask(text: *const u8, offset: usize, text_len: usize,
needle: __m128i, needle_len: usize) -> u32 {
//debug_assert!(text_len + offset <= text.len()); // saturates at 16
//debug_assert!(needle_len <= 16); // saturates at 16
let text = mask_load(text.offset(offset as _) as *const _, text_len);
_mm_cmpestri(needle, needle_len as _, text, text_len as _, _SIDD_CMP_EQUAL_ORDERED) as _
}

/// `pcmpestri`
///
/// “Packed compare explicit length strings (return index)”
///
/// PCMPESTRI xmm1, xmm2/m128, imm8
///
/// Return value: least index for start of (partial) match, (16 if no match).
///
/// No mask: `text` must be at least 16 bytes from the end of a memory region.
#[target_feature(enable = "sse4.2")]
unsafe fn pcmpestri_16(text: *const u8, offset: usize, text_len: usize,
unsafe fn pcmpestri_16_nomask(text: *const u8, offset: usize, text_len: usize,
needle: __m128i, needle_len: usize) -> u32 {
//debug_assert!(text_len + offset <= text.len()); // saturates at 16
//debug_assert!(needle_len <= 16); // saturates at 16
Expand Down Expand Up @@ -130,7 +152,7 @@ unsafe fn first_start_of_match_inner(text: &[u8], pat: &[u8], p: __m128i) -> Opt
}
while text.len() >= offset - tp_align_offset + patl {
let tlen = text.len() - (offset - tp_align_offset);
let ret = pcmpestri_16(tp_aligned, offset, tlen, p, patl) as usize;
let ret = pcmpestri_16_mask(tp_aligned, offset, tlen, p, patl) as usize;
if ret == 16 {
offset += 16;
} else {
Expand All @@ -154,7 +176,7 @@ unsafe fn first_start_of_match_unaligned(text: &[u8], pat_len: usize, p: __m128i

while text.len() - pat_len >= offset {
let tlen = text.len() - offset;
let ret = pcmpestri_16(tp, offset, tlen, p, pat_len) as usize;
let ret = pcmpestri_16_nomask(tp, offset, tlen, p, pat_len) as usize;
if ret == 16 {
offset += 16;
} else {
Expand Down Expand Up @@ -491,7 +513,21 @@ fn test_find() {
/// Load the first 16 bytes of `pat` into a SIMD vector.
#[inline(always)]
fn pat128(pat: &[u8]) -> __m128i {
unsafe { _mm_loadu_si128(pat.as_ptr() as *const _) }
unsafe {
mask_load(pat.as_ptr() as *const _, pat.len())
}
}

#[inline(always)]
unsafe fn mask_load(ptr: *const u8, len: usize) -> __m128i {
const REGSZ: usize = mem::size_of::<__m128i>();
if len >= REGSZ {
return _mm_loadu_si128(ptr as _);
}

let mut data = [0; REGSZ];
::std::ptr::copy_nonoverlapping(ptr, data.as_mut_ptr(), len);
return _mm_loadu_si128(data.as_mut_ptr() as _);
}

/// Find longest shared prefix, return its length
Expand Down

0 comments on commit d4e91ea

Please sign in to comment.