Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bounds checked get_unchecked, use it everywhere. #231

Merged
merged 5 commits into from
Jul 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/avx2/deser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::arch::x86_64::{
use std::mem;

pub use crate::error::{Error, ErrorType};
use crate::safer_unchecked::GetSaferUnchecked;
use crate::stringparse::{handle_unicode_codepoint, ESCAPE_MAP};
use crate::Deserializer;
pub use crate::Result;
Expand Down Expand Up @@ -44,7 +45,7 @@ impl<'de> Deserializer<'de> {
// This is safe since we check sub's length in the range access above and only
// create sub sliced form sub to `sub.len()`.

let src: &[u8] = unsafe { data.get_unchecked(idx..) };
let src: &[u8] = unsafe { data.get_kinda_unchecked(idx..) };
let mut src_i: usize = 0;
let mut len = src_i;
loop {
Expand Down Expand Up @@ -77,7 +78,7 @@ impl<'de> Deserializer<'de> {

len += quote_dist as usize;
unsafe {
let v = input.get_unchecked(idx..idx + len) as *const [u8] as *const str;
let v = input.get_kinda_unchecked(idx..idx + len) as *const [u8] as *const str;
return Ok(&*v);
}

Expand Down Expand Up @@ -142,10 +143,10 @@ impl<'de> Deserializer<'de> {
dst_i += quote_dist as usize;
unsafe {
input
.get_unchecked_mut(idx + len..idx + len + dst_i)
.clone_from_slice(buffer.get_unchecked(..dst_i));
.get_kinda_unchecked_mut(idx + len..idx + len + dst_i)
.clone_from_slice(buffer.get_kinda_unchecked(..dst_i));
let v =
input.get_unchecked(idx..idx + len + dst_i) as *const [u8] as *const str;
input.get_kinda_unchecked(idx..idx + len + dst_i) as *const [u8] as *const str;
return Ok(&*v);
}

Expand All @@ -155,16 +156,16 @@ impl<'de> Deserializer<'de> {
if (quote_bits.wrapping_sub(1) & bs_bits) != 0 {
// find out where the backspace is
let bs_dist: u32 = bs_bits.trailing_zeros();
let escape_char: u8 = unsafe { *src.get_unchecked(src_i + bs_dist as usize + 1) };
let escape_char: u8 = unsafe { *src.get_kinda_unchecked(src_i + bs_dist as usize + 1) };
// we encountered backslash first. Handle backslash
if escape_char == b'u' {
// move src/dst up to the start; they will be further adjusted
// within the unicode codepoint handling code.
src_i += bs_dist as usize;
dst_i += bs_dist as usize;
let (o, s) = if let Ok(r) =
handle_unicode_codepoint(unsafe { src.get_unchecked(src_i..) }, unsafe {
buffer.get_unchecked_mut(dst_i..)
handle_unicode_codepoint(unsafe { src.get_kinda_unchecked(src_i..) }, unsafe {
buffer.get_kinda_unchecked_mut(dst_i..)
}) {
r
} else {
Expand All @@ -182,12 +183,12 @@ impl<'de> Deserializer<'de> {
// note this may reach beyond the part of the buffer we've actually
// seen. I think this is ok
let escape_result: u8 =
unsafe { *ESCAPE_MAP.get_unchecked(escape_char as usize) };
unsafe { *ESCAPE_MAP.get_kinda_unchecked(escape_char as usize) };
if escape_result == 0 {
return Err(Self::raw_error(src_i, escape_char as char, InvalidEscape));
}
unsafe {
*buffer.get_unchecked_mut(dst_i + bs_dist as usize) = escape_result;
*buffer.get_kinda_unchecked_mut(dst_i + bs_dist as usize) = escape_result;
}
src_i += bs_dist as usize + 2;
dst_i += bs_dist as usize + 1;
Expand Down
34 changes: 18 additions & 16 deletions src/charutils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::safer_unchecked::GetSaferUnchecked;

const STRUCTURAL_OR_WHITESPACE_NEGATED: [u32; 256] = [
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
Expand All @@ -22,12 +24,12 @@ const STRUCTURAL_OR_WHITESPACE: [u32; 256] = [

#[cfg_attr(not(feature = "no-inline"), inline(always))]
pub fn is_not_structural_or_whitespace(c: u8) -> u32 {
unsafe { *STRUCTURAL_OR_WHITESPACE_NEGATED.get_unchecked(c as usize) }
unsafe { *STRUCTURAL_OR_WHITESPACE_NEGATED.get_kinda_unchecked(c as usize) }
}

#[cfg_attr(not(feature = "no-inline"), inline(always))]
pub fn is_structural_or_whitespace(c: u8) -> u32 {
unsafe { *STRUCTURAL_OR_WHITESPACE.get_unchecked(c as usize) }
unsafe { *STRUCTURAL_OR_WHITESPACE.get_kinda_unchecked(c as usize) }
}

const DIGITTOVAL: [i8; 256] = [
Expand Down Expand Up @@ -55,10 +57,10 @@ pub fn hex_to_u32_nocheck(src: &[u8]) -> u32 {
// invalid value. After the shifts, this will *still* result in the outcome that the high 16 bits of any
// value with any invalid char will be all 1's. We check for this in the caller.
unsafe {
let v1: i32 = i32::from(*DIGITTOVAL.get_unchecked(*src.get_unchecked(0) as usize));
let v2: i32 = i32::from(*DIGITTOVAL.get_unchecked(*src.get_unchecked(1) as usize));
let v3: i32 = i32::from(*DIGITTOVAL.get_unchecked(*src.get_unchecked(2) as usize));
let v4: i32 = i32::from(*DIGITTOVAL.get_unchecked(*src.get_unchecked(3) as usize));
let v1: i32 = i32::from(*DIGITTOVAL.get_kinda_unchecked(*src.get_kinda_unchecked(0) as usize));
let v2: i32 = i32::from(*DIGITTOVAL.get_kinda_unchecked(*src.get_kinda_unchecked(1) as usize));
let v3: i32 = i32::from(*DIGITTOVAL.get_kinda_unchecked(*src.get_kinda_unchecked(2) as usize));
let v4: i32 = i32::from(*DIGITTOVAL.get_kinda_unchecked(*src.get_kinda_unchecked(3) as usize));
(v1 << 12 | v2 << 8 | v3 << 4 | v4) as u32
}
}
Expand All @@ -80,27 +82,27 @@ pub fn hex_to_u32_nocheck(src: &[u8]) -> u32 {
pub fn codepoint_to_utf8(cp: u32, c: &mut [u8]) -> usize {
unsafe {
if cp <= 0x7F {
*c.get_unchecked_mut(0) = cp as u8;
*c.get_kinda_unchecked_mut(0) = cp as u8;
return 1; // ascii
}
if cp <= 0x7FF {
*c.get_unchecked_mut(0) = ((cp >> 6) + 192) as u8;
*c.get_unchecked_mut(1) = ((cp & 63) + 128) as u8;
*c.get_kinda_unchecked_mut(0) = ((cp >> 6) + 192) as u8;
*c.get_kinda_unchecked_mut(1) = ((cp & 63) + 128) as u8;
return 2; // universal plane
// Surrogates are treated elsewhere...
//} //else if (0xd800 <= cp && cp <= 0xdfff) {
// return 0; // surrogates // could put assert here
} else if cp <= 0xFFFF {
*c.get_unchecked_mut(0) = ((cp >> 12) + 224) as u8;
*c.get_unchecked_mut(1) = (((cp >> 6) & 63) + 128) as u8;
*c.get_unchecked_mut(2) = ((cp & 63) + 128) as u8;
*c.get_kinda_unchecked_mut(0) = ((cp >> 12) + 224) as u8;
*c.get_kinda_unchecked_mut(1) = (((cp >> 6) & 63) + 128) as u8;
*c.get_kinda_unchecked_mut(2) = ((cp & 63) + 128) as u8;
return 3;
} else if cp <= 0x0010_FFFF {
// if you know you have a valid code point, this is not needed
*c.get_unchecked_mut(0) = ((cp >> 18) + 240) as u8;
*c.get_unchecked_mut(1) = (((cp >> 12) & 63) + 128) as u8;
*c.get_unchecked_mut(2) = (((cp >> 6) & 63) + 128) as u8;
*c.get_unchecked_mut(3) = ((cp & 63) + 128) as u8;
*c.get_kinda_unchecked_mut(0) = ((cp >> 18) + 240) as u8;
*c.get_kinda_unchecked_mut(1) = (((cp >> 12) & 63) + 128) as u8;
*c.get_kinda_unchecked_mut(2) = (((cp >> 6) & 63) + 128) as u8;
*c.get_kinda_unchecked_mut(3) = ((cp & 63) + 128) as u8;
return 4;
}
}
Expand Down
33 changes: 20 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#![deny(warnings)]
#![cfg_attr(feature = "hints", feature(core_intrinsics))]
#![deny(warnings)]
#![warn(unused_extern_crates)]
#![deny(
clippy::all,
Expand Down Expand Up @@ -139,6 +138,9 @@ mod macros;
mod error;
mod numberparse;
mod stringparse;
mod safer_unchecked;

use safer_unchecked::GetSaferUnchecked;

/// Reexport of Cow
pub mod cow;
Expand Down Expand Up @@ -479,16 +481,20 @@ impl<'de> Deserializer<'de> {
}

unsafe {
input_buffer
.as_mut_slice()
.get_unchecked_mut(..len)
.clone_from_slice(input);
*(input_buffer.get_unchecked_mut(len)) = 0;
input_buffer.set_len(len);
std::ptr::copy_nonoverlapping(
input.as_ptr(),
input_buffer.as_mut_ptr(),
len,
);

let to_fill = input_buffer.capacity() - len;
std::ptr::write_bytes(input_buffer.as_mut_ptr().add(len), 0, to_fill);
Comment on lines +490 to +491
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it make sense to guard this behind the same if cfg!(debug_assertions) that is in get_kinda_unchecked? It doesn't really seem like it's needed during runtime, just during tests to validate the gets. OTOH, it's probably worth benchmarking if it makes a measurable difference and is worth guarding or fine to just leave in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because if you just write one byte, you get a use-of-uninitialized-value from memory sanitizer, and I'm sure we're reading uninit bytes into a u64 (say, read_true_atom run on [t r u e \x00 <uninit> <uninit> <uninit>]

Simply reading uninit bytes into a u64 is UB (see: rust-lang/unsafe-code-guidelines#71), and memory sanitizer complaining means there's definitely UB here (memory sanitizer is a very conservative check, only checking for branching on uninit bytes).

Upstream seems to... just ignore the problem (they run memory sanitizer, but ignore this).

We presumably only need to ensure that input.len() + SIMDJSON_PADDING bytes are zero filled, so the user could re-use buffers manually (which would remove both the allocation and the zeroing). As in, the write_bytes could only be done if we're using a buffer that isn't long enough (in len, not capacity).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They ignore it because it's not a real problem, it's a "UB by decree" not because any behavior is actually undefined, the discussion you linked says this quite clearly.

There are no possible invalid combinations of bits in a u64. Rust has the tendency to be overly careful with safety, which is great, but sometimes it's worth understanding what happens under the hood - that's why unsafe exists, after all, the compiler can't always make the right call as it misses context, like that' we're not reading a u64 but a SIMD register.

But it's not worth arguing over this, criterion says the change has no significant impact on performance, some tests are marginally faster others marginally slower - it's with run on run varriance. So there is no reason to keep potentially UB behaviour just for the sake of it, that'd be silly, if we can be compiler-agreed-safe and fast that's definitely the route to take :D

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it's not worth arguing over this

I'm not commenting here to argue with you, but to note that you should be significantly more careful with uninitialized integers since rust-lang/rust#106294


input_buffer.set_len(input_buffer.capacity());
};

let s1_result: std::result::Result<Vec<u32>, ErrorType> =
unsafe { Self::find_structural_bits(input_buffer) };
unsafe { Self::find_structural_bits(input) };

let structural_indexes = match s1_result {
Ok(i) => i,
Expand Down Expand Up @@ -519,7 +525,7 @@ impl<'de> Deserializer<'de> {
#[cfg_attr(not(feature = "no-inline"), inline(always))]
pub unsafe fn next_(&mut self) -> Node<'de> {
self.idx += 1;
*self.tape.get_unchecked(self.idx)
*self.tape.get_kinda_unchecked(self.idx)
}

//#[inline(never)]
Expand Down Expand Up @@ -569,7 +575,7 @@ impl<'de> Deserializer<'de> {
__builtin_prefetch(buf + idx + 128);
#endif
*/
let chunk = input.get_unchecked(idx..idx + 64);
let chunk = input.get_kinda_unchecked(idx..idx + 64);
utf8_validator.update_from_chunks(chunk);

let input = SimdInput::new(chunk);
Expand Down Expand Up @@ -717,15 +723,16 @@ impl AlignedBuf {
}
}

fn as_mut_ptr(&mut self) -> *mut u8 {
self.inner.as_ptr()
}

fn capacity_overflow() -> ! {
panic!("capacity overflow");
}
fn capacity(&self) -> usize {
self.capacity
}
fn as_mut_slice(&mut self) -> &mut [u8] {
unsafe { std::slice::from_raw_parts_mut(self.inner.as_ptr(), self.len) }
}
unsafe fn set_len(&mut self, n: usize) {
assert!(
n <= self.capacity,
Expand Down
33 changes: 17 additions & 16 deletions src/neon/deser.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::safer_unchecked::GetSaferUnchecked;
use crate::error::ErrorType;
use crate::neon::stage1::bit_mask;
use crate::stringparse::{handle_unicode_codepoint, ESCAPE_MAP};
Expand Down Expand Up @@ -64,14 +65,14 @@ impl<'de> Deserializer<'de> {
// This is safe since we check sub's length in the range access above and only
// create sub sliced form sub to `sub.len()`.

let src: &[u8] = unsafe { data.get_unchecked(idx..) };
let src: &[u8] = unsafe { data.get_kinda_unchecked(idx..) };
let mut src_i: usize = 0;
let mut len = src_i;
loop {
let (v0, v1) = unsafe {
(
vld1q_u8(src.get_unchecked(src_i..src_i + 16).as_ptr()),
vld1q_u8(src.get_unchecked(src_i + 16..src_i + 32).as_ptr()),
vld1q_u8(src.get_kinda_unchecked(src_i..src_i + 16).as_ptr()),
vld1q_u8(src.get_kinda_unchecked(src_i + 16..src_i + 32).as_ptr()),
)
};

Expand All @@ -92,7 +93,7 @@ impl<'de> Deserializer<'de> {

len += quote_dist as usize;
unsafe {
let v = input.get_unchecked(idx..idx + len) as *const [u8] as *const str;
let v = input.get_kinda_unchecked(idx..idx + len) as *const [u8] as *const str;
return Ok(&*v);
}

Expand All @@ -119,15 +120,15 @@ impl<'de> Deserializer<'de> {
loop {
let (v0, v1) = unsafe {
(
vld1q_u8(src.get_unchecked(src_i..src_i + 16).as_ptr()),
vld1q_u8(src.get_unchecked(src_i + 16..src_i + 32).as_ptr()),
vld1q_u8(src.get_kinda_unchecked(src_i..src_i + 16).as_ptr()),
vld1q_u8(src.get_kinda_unchecked(src_i + 16..src_i + 32).as_ptr()),
)
};

unsafe {
buffer
.get_unchecked_mut(dst_i..dst_i + 32)
.copy_from_slice(src.get_unchecked(src_i..src_i + 32));
.get_kinda_unchecked_mut(dst_i..dst_i + 32)
.copy_from_slice(src.get_kinda_unchecked(src_i..src_i + 32));
}

// store to dest unconditionally - we can overwrite the bits we don't like
Expand All @@ -150,10 +151,10 @@ impl<'de> Deserializer<'de> {
dst_i += quote_dist as usize;
unsafe {
input
.get_unchecked_mut(idx + len..idx + len + dst_i)
.clone_from_slice(buffer.get_unchecked(..dst_i));
.get_kinda_unchecked_mut(idx + len..idx + len + dst_i)
.clone_from_slice(buffer.get_kinda_unchecked(..dst_i));
let v =
input.get_unchecked(idx..idx + len + dst_i) as *const [u8] as *const str;
input.get_kinda_unchecked(idx..idx + len + dst_i) as *const [u8] as *const str;
return Ok(&*v);
}

Expand All @@ -163,16 +164,16 @@ impl<'de> Deserializer<'de> {
if (quote_bits.wrapping_sub(1) & bs_bits) != 0 {
// find out where the backspace is
let bs_dist: u32 = bs_bits.trailing_zeros();
let escape_char: u8 = unsafe { *src.get_unchecked(src_i + bs_dist as usize + 1) };
let escape_char: u8 = unsafe { *src.get_kinda_unchecked(src_i + bs_dist as usize + 1) };
// we encountered backslash first. Handle backslash
if escape_char == b'u' {
// move src/dst up to the start; they will be further adjusted
// within the unicode codepoint handling code.
src_i += bs_dist as usize;
dst_i += bs_dist as usize;
let (o, s) = if let Ok(r) =
handle_unicode_codepoint(unsafe { src.get_unchecked(src_i..) }, unsafe {
buffer.get_unchecked_mut(dst_i..)
handle_unicode_codepoint(unsafe { src.get_kinda_unchecked(src_i..) }, unsafe {
buffer.get_kinda_unchecked_mut(dst_i..)
}) {
r
} else {
Expand All @@ -190,12 +191,12 @@ impl<'de> Deserializer<'de> {
// note this may reach beyond the part of the buffer we've actually
// seen. I think this is ok
let escape_result: u8 =
unsafe { *ESCAPE_MAP.get_unchecked(escape_char as usize) };
unsafe { *ESCAPE_MAP.get_kinda_unchecked(escape_char as usize) };
if escape_result == 0 {
return Err(Self::raw_error(src_i, escape_char as char, InvalidEscape));
}
unsafe {
*buffer.get_unchecked_mut(dst_i + bs_dist as usize) = escape_result;
*buffer.get_kinda_unchecked_mut(dst_i + bs_dist as usize) = escape_result;
}
src_i += bs_dist as usize + 2;
dst_i += bs_dist as usize + 1;
Expand Down
Loading