From 8f3d10a7caf9ef3ce90bc757bc0af60f62909d6e Mon Sep 17 00:00:00 2001
From: Lukasz Anforowicz <lukasza@chromium.org>
Date: Fri, 25 Oct 2024 15:58:19 +0000
Subject: [PATCH] Simplify safety checks in `inv_memchr`, `LOOP_SIZE`-stepping
 loop.

Before this commit the safety of `end_ptr.sub(loop_size)` depended on a
somewhat remote and indirect reasoning that `haystack.len() >=
LOOP_SIZE` if `loop_size == LOOP_SIZE` (that reasoning was based on
having `let loop_size = cmp::min(LOOP_SIZE, haystack.len())`).

After this commit, the safety checks are done in an `if` statement right
above where `end_ptr.sub(LOOP_SIZE)` happens.  This simplification also
allowed removing the `loop_size` binding/variable.
---
 src/byteset/scalar.rs | 56 ++++++++++++++++++++++++-------------------
 1 file changed, 32 insertions(+), 24 deletions(-)

diff --git a/src/byteset/scalar.rs b/src/byteset/scalar.rs
index 24dcf37..fe5ab63 100644
--- a/src/byteset/scalar.rs
+++ b/src/byteset/scalar.rs
@@ -2,8 +2,6 @@
 // the 'inverse' query of memchr, e.g. finding the first byte not in the
 // provided set. This is simple for the 1-byte case.
 
-use core::{cmp, usize};
-
 const USIZE_BYTES: usize = core::mem::size_of::<usize>();
 const ALIGN_MASK: usize = core::mem::align_of::<usize>() - 1;
 
@@ -22,7 +20,6 @@ fn repeat_byte(b: u8) -> usize {
 pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
     let vn1 = repeat_byte(n1);
     let confirm = |byte| byte != n1;
-    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
     let start_ptr = haystack.as_ptr();
 
     unsafe {
@@ -41,18 +38,25 @@ pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
         ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & ALIGN_MASK));
         debug_assert!(ptr > start_ptr);
         debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
-        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
-            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);
-
-            let a = *(ptr as *const usize);
-            let b = *(ptr.add(USIZE_BYTES) as *const usize);
-            let eqa = (a ^ vn1) != 0;
-            let eqb = (b ^ vn1) != 0;
-            if eqa || eqb {
-                break;
+
+        if haystack.len() >= LOOP_SIZE {
+            // The `if` condition guarantees that `end_ptr.sub(LOOP_SIZE)` (in the loop condition)
+            // meets the safety requrement that the result must be in bounds of the same allocated
+            // object.
+            while ptr <= end_ptr.sub(LOOP_SIZE) {
+                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);
+
+                let a = *(ptr as *const usize);
+                let b = *(ptr.add(USIZE_BYTES) as *const usize);
+                let eqa = (a ^ vn1) != 0;
+                let eqb = (b ^ vn1) != 0;
+                if eqa || eqb {
+                    break;
+                }
+                ptr = ptr.add(LOOP_SIZE);
             }
-            ptr = ptr.add(LOOP_SIZE);
         }
+
         forward_search(start_ptr, end_ptr, ptr, confirm)
     }
 }
@@ -61,7 +65,6 @@ pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
 pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
     let vn1 = repeat_byte(n1);
     let confirm = |byte| byte != n1;
-    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
     let start_ptr = haystack.as_ptr();
 
     unsafe {
@@ -79,17 +82,22 @@ pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
 
         ptr = ptr.sub(end_ptr as usize & ALIGN_MASK);
         debug_assert!(start_ptr <= ptr && ptr <= end_ptr);
-        while loop_size == LOOP_SIZE && ptr >= start_ptr.add(loop_size) {
-            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);
-
-            let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
-            let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
-            let eqa = (a ^ vn1) != 0;
-            let eqb = (b ^ vn1) != 0;
-            if eqa || eqb {
-                break;
+        if haystack.len() >= LOOP_SIZE {
+            // The `if` condition guarantees that `start_ptr.add(LOOP_SIZE)` (in the loop
+            // condition) meets the safety requrement that the result must be in bounds of the same
+            // allocated object.
+            while ptr >= start_ptr.add(LOOP_SIZE) {
+                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);
+
+                let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
+                let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
+                let eqa = (a ^ vn1) != 0;
+                let eqb = (b ^ vn1) != 0;
+                if eqa || eqb {
+                    break;
+                }
+                ptr = ptr.sub(LOOP_SIZE);
             }
-            ptr = ptr.sub(loop_size);
         }
         reverse_search(start_ptr, end_ptr, ptr, confirm)
     }