Skip to content

Commit

Permalink
Simplify safety checks in inv_memchr, LOOP_SIZE-stepping loop.
Browse files Browse the repository at this point in the history
Before this commit the safety of `end_ptr.sub(loop_size)` depended on a
somewhat remote and indirect reasoning that `haystack.len() >=
LOOP_SIZE` if `loop_size == LOOP_SIZE` (that reasoning was based on
having `let loop_size = cmp::min(LOOP_SIZE, haystack.len())`).

After this commit, the safety checks are done in an `if` statement right
above where `end_ptr.sub(LOOP_SIZE)` happens.  This simplification also
allowed removing the `loop_size` binding/variable.
  • Loading branch information
anforowicz committed Nov 14, 2024
1 parent 41f8bdb commit 8f3d10a
Showing 1 changed file with 32 additions and 24 deletions.
56 changes: 32 additions & 24 deletions src/byteset/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// the 'inverse' query of memchr, e.g. finding the first byte not in the
// provided set. This is simple for the 1-byte case.

use core::{cmp, usize};

const USIZE_BYTES: usize = core::mem::size_of::<usize>();
const ALIGN_MASK: usize = core::mem::align_of::<usize>() - 1;

Expand All @@ -22,7 +20,6 @@ fn repeat_byte(b: u8) -> usize {
pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
let vn1 = repeat_byte(n1);
let confirm = |byte| byte != n1;
let loop_size = cmp::min(LOOP_SIZE, haystack.len());
let start_ptr = haystack.as_ptr();

unsafe {
Expand All @@ -41,18 +38,25 @@ pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & ALIGN_MASK));
debug_assert!(ptr > start_ptr);
debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

let a = *(ptr as *const usize);
let b = *(ptr.add(USIZE_BYTES) as *const usize);
let eqa = (a ^ vn1) != 0;
let eqb = (b ^ vn1) != 0;
if eqa || eqb {
break;

if haystack.len() >= LOOP_SIZE {
// The `if` condition guarantees that `end_ptr.sub(LOOP_SIZE)` (in the loop condition)
meets the safety requirement that the result must be in bounds of the same allocated
// object.
while ptr <= end_ptr.sub(LOOP_SIZE) {
debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

let a = *(ptr as *const usize);
let b = *(ptr.add(USIZE_BYTES) as *const usize);
let eqa = (a ^ vn1) != 0;
let eqb = (b ^ vn1) != 0;
if eqa || eqb {
break;
}
ptr = ptr.add(LOOP_SIZE);
}
ptr = ptr.add(LOOP_SIZE);
}

forward_search(start_ptr, end_ptr, ptr, confirm)
}
}
Expand All @@ -61,7 +65,6 @@ pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
let vn1 = repeat_byte(n1);
let confirm = |byte| byte != n1;
let loop_size = cmp::min(LOOP_SIZE, haystack.len());
let start_ptr = haystack.as_ptr();

unsafe {
Expand All @@ -79,17 +82,22 @@ pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {

ptr = ptr.sub(end_ptr as usize & ALIGN_MASK);
debug_assert!(start_ptr <= ptr && ptr <= end_ptr);
while loop_size == LOOP_SIZE && ptr >= start_ptr.add(loop_size) {
debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
let eqa = (a ^ vn1) != 0;
let eqb = (b ^ vn1) != 0;
if eqa || eqb {
break;
if haystack.len() >= LOOP_SIZE {
// The `if` condition guarantees that `start_ptr.add(LOOP_SIZE)` (in the loop
condition) meets the safety requirement that the result must be in bounds of the same
// allocated object.
while ptr >= start_ptr.add(LOOP_SIZE) {
debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
let eqa = (a ^ vn1) != 0;
let eqb = (b ^ vn1) != 0;
if eqa || eqb {
break;
}
ptr = ptr.sub(LOOP_SIZE);
}
ptr = ptr.sub(loop_size);
}
reverse_search(start_ptr, end_ptr, ptr, confirm)
}
Expand Down

0 comments on commit 8f3d10a

Please sign in to comment.