Skip to content

Commit

Permalink
Move OsStr::slice_encoded_bytes validation to platform modules
Browse files Browse the repository at this point in the history
On Unix this opens the possibility of removing the checks later.

On other platforms this improves performance and error messaging.
  • Loading branch information
blyxxyz committed Dec 3, 2023
1 parent 463b3bc commit 694a789
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 47 deletions.
7 changes: 7 additions & 0 deletions library/std/src/ffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@
//! trait, which provides a [`from_wide`] method to convert a native Windows
//! string (without the terminating nul character) to an [`OsString`].
//!
//! ## Other platforms
//!
//! Many other platforms provide their own extension traits in a
//! `std::os::*::ffi` module.
//!
//! ## On all platforms
//!
//! On all platforms, [`OsStr`] consists of a sequence of bytes that is encoded as a superset of
Expand All @@ -135,6 +140,8 @@
//! For limited, inexpensive conversions from and to bytes, see [`OsStr::as_encoded_bytes`] and
//! [`OsStr::from_encoded_bytes_unchecked`].
//!
//! For basic string processing, see [`OsStr::slice_encoded_bytes`].
//!
//! [Unicode scalar value]: https://www.unicode.org/glossary/#unicode_scalar_value
//! [Unicode code point]: https://www.unicode.org/glossary/#code_point
//! [`env::set_var()`]: crate::env::set_var "env::set_var"
Expand Down
39 changes: 4 additions & 35 deletions library/std/src/ffi/os_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::hash::{Hash, Hasher};
use crate::ops::{self, Range};
use crate::rc::Rc;
use crate::slice;
use crate::str::{from_utf8 as str_from_utf8, FromStr};
use crate::str::FromStr;
use crate::sync::Arc;

use crate::sys::os_str::{Buf, Slice};
Expand Down Expand Up @@ -997,42 +997,11 @@ impl OsStr {
/// ```
#[unstable(feature = "os_str_slice", issue = "118485")]
pub fn slice_encoded_bytes<R: ops::RangeBounds<usize>>(&self, range: R) -> &Self {
#[track_caller]
fn check_valid_boundary(bytes: &[u8], index: usize) {
if index == 0 || index == bytes.len() {
return;
}

// Fast path
if bytes[index - 1].is_ascii() || bytes[index].is_ascii() {
return;
}

let (before, after) = bytes.split_at(index);

// UTF-8 takes at most 4 bytes per codepoint, so we don't
// need to check more than that.
let after = after.get(..4).unwrap_or(after);
match str_from_utf8(after) {
Ok(_) => return,
Err(err) if err.valid_up_to() != 0 => return,
Err(_) => (),
}

for len in 2..=4.min(index) {
let before = &before[index - len..];
if str_from_utf8(before).is_ok() {
return;
}
}

panic!("byte index {index} is not an OsStr boundary");
}

let encoded_bytes = self.as_encoded_bytes();
let Range { start, end } = slice::range(range, ..encoded_bytes.len());
check_valid_boundary(encoded_bytes, start);
check_valid_boundary(encoded_bytes, end);

self.inner.check_public_boundary(start);
self.inner.check_public_boundary(end);

// SAFETY: `slice::range` ensures that `start` and `end` are valid
let slice = unsafe { encoded_bytes.get_unchecked(start..end) };
Expand Down
68 changes: 61 additions & 7 deletions library/std/src/ffi/os_str/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,65 @@ fn slice_encoded_bytes() {
}

#[test]
#[should_panic(expected = "byte index 2 is not an OsStr boundary")]
#[should_panic]
fn slice_out_of_bounds() {
let crab = OsStr::new("🦀");
let _ = crab.slice_encoded_bytes(..5);
}

#[test]
#[should_panic]
fn slice_mid_char() {
    // The crab is a single four-byte character, so byte 2 lies inside it.
    // No panic message is pinned here; only the fact that slicing panics.
    let _ = OsStr::new("🦀").slice_encoded_bytes(..2);
}

#[cfg(unix)]
#[test]
#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
fn slice_invalid_data() {
    use crate::os::unix::ffi::OsStrExt;

    // Neither side of index 1 is ASCII or a valid UTF-8 character, so the
    // split point is not adjacent to anything the check accepts.
    let garbage = OsStr::from_bytes(b"\xFF\xFF");
    let _ = garbage.slice_encoded_bytes(1..);
}

#[cfg(unix)]
#[test]
#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
fn slice_partial_utf8() {
    use crate::os::unix::ffi::{OsStrExt, OsStringExt};

    // Build 0xFF followed by only the first three bytes of the four-byte
    // crab. The truncated character is not valid UTF-8, so index 1 has an
    // invalid byte before it and an incomplete character after it.
    let truncated_crab = &"🦀".as_bytes()[..3];
    let mut buf = OsString::from_vec(vec![0xFF]);
    buf.push(OsStr::from_bytes(truncated_crab));
    let _ = buf.slice_encoded_bytes(1..);
}

#[cfg(unix)]
#[test]
fn slice_invalid_edge() {
    use crate::os::unix::ffi::{OsStrExt, OsStringExt};

    // An invalid byte sandwiched between ASCII: every split next to an
    // ASCII byte is allowed.
    {
        let sandwich = OsStr::from_bytes(b"a\xFFa");
        assert_eq!(sandwich.slice_encoded_bytes(..1), "a");
        assert_eq!(sandwich.slice_encoded_bytes(1..), OsStr::from_bytes(b"\xFFa"));
        assert_eq!(sandwich.slice_encoded_bytes(..2), OsStr::from_bytes(b"a\xFF"));
        assert_eq!(sandwich.slice_encoded_bytes(2..), "a");
    }

    // ASCII followed by a truncated crab: splitting right after the ASCII
    // part is allowed even though the remainder is incomplete UTF-8.
    {
        let truncated = OsStr::from_bytes(&"abc🦀".as_bytes()[..6]);
        assert_eq!(truncated.slice_encoded_bytes(..3), "abc");
        assert_eq!(truncated.slice_encoded_bytes(3..), OsStr::from_bytes(b"\xF0\x9F\xA6"));
    }

    // An invalid byte followed by a complete crab: the split is allowed
    // because a valid character starts right at the index.
    {
        let mut prefixed_crab = OsString::from_vec(vec![0xFF]);
        prefixed_crab.push("🦀");
        assert_eq!(prefixed_crab.slice_encoded_bytes(..1), OsStr::from_bytes(b"\xFF"));
        assert_eq!(prefixed_crab.slice_encoded_bytes(1..), "🦀");
    }
}

#[cfg(windows)]
#[test]
#[should_panic(expected = "byte index 3 is not an OsStr boundary")]
#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
fn slice_between_surrogates() {
use crate::os::windows::ffi::OsStringExt;

Expand All @@ -220,10 +270,14 @@ fn slice_between_surrogates() {
fn slice_surrogate_edge() {
use crate::os::windows::ffi::OsStringExt;

let os_string = OsString::from_wide(&[0xD800]);
let mut with_crab = os_string.clone();
with_crab.push("🦀");
let surrogate = OsString::from_wide(&[0xD800]);
let mut pre_crab = surrogate.clone();
pre_crab.push("🦀");
assert_eq!(pre_crab.slice_encoded_bytes(..3), surrogate);
assert_eq!(pre_crab.slice_encoded_bytes(3..), "🦀");

assert_eq!(with_crab.slice_encoded_bytes(..3), os_string);
assert_eq!(with_crab.slice_encoded_bytes(3..), "🦀");
let mut post_crab = OsString::from("🦀");
post_crab.push(&surrogate);
assert_eq!(post_crab.slice_encoded_bytes(..4), "🦀");
assert_eq!(post_crab.slice_encoded_bytes(4..), surrogate);
}
43 changes: 43 additions & 0 deletions library/std/src/sys/unix/os_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,49 @@ impl Slice {
unsafe { mem::transmute(s) }
}

/// Panics unless `index` is a position at which the public API may split
/// this string.
///
/// Accepted positions are: the start, the end, any position with an ASCII
/// byte directly before or after it, any position directly followed by a
/// valid UTF-8 character, and any position directly preceded by a valid
/// multi-byte UTF-8 character. Everything else panics with
/// "byte index {index} is not an OsStr boundary". An index past the end is
/// rejected by the `split_at` call on the slow path.
#[track_caller]
#[inline]
pub fn check_public_boundary(&self, index: usize) {
    let len = self.inner.len();
    if index == 0 || index == len {
        return;
    }

    // Cheap test first: a split that touches an ASCII byte is always fine.
    let touches_ascii =
        index < len && (self.inner[index - 1].is_ascii() || self.inner[index].is_ascii());
    if !touches_ascii {
        check_non_ascii_split(&self.inner, index);
    }

    /// The bet is that typical splits involve an ASCII character, so the
    /// expensive checks live out of line; keeping them in a separate
    /// function was observed to generate notably better assembly.
    #[track_caller]
    #[inline(never)]
    fn check_non_ascii_split(bytes: &[u8], index: usize) {
        let (head, tail) = bytes.split_at(index);

        // A UTF-8 character is at most 4 bytes long, so looking at the
        // next 4 bytes (or fewer, near the end) is enough to tell whether
        // a valid character starts at `index`.
        let probe = &tail[..tail.len().min(4)];
        let char_starts_here = match str::from_utf8(probe) {
            Ok(_) => true,
            Err(err) => err.valid_up_to() != 0,
        };
        if char_starts_here {
            return;
        }

        // Otherwise accept the split if a valid character of 2 to 4 bytes
        // ends exactly at `index`. (A one-byte character would be ASCII,
        // which the fast path already handled.)
        let char_ends_here =
            (2..=index.min(4)).any(|len| str::from_utf8(&head[index - len..]).is_ok());
        if char_ends_here {
            return;
        }

        panic!("byte index {index} is not an OsStr boundary");
    }
}

#[inline]
pub fn from_str(s: &str) -> &Slice {
unsafe { Slice::from_encoded_bytes_unchecked(s.as_bytes()) }
Expand Down
10 changes: 10 additions & 0 deletions library/std/src/sys/unsupported/os_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ impl Slice {
unsafe { mem::transmute(s) }
}

/// Panics unless `index` is a char boundary of the underlying string.
#[inline]
pub fn check_public_boundary(&self, index: usize) {
    // The requirement is `self.inner.is_char_boundary(index)`.
    // Slicing through the Index impl performs exactly that check and, on
    // failure, produces the detailed panic message from slice_error_fail.
    // (Calling slice_error_fail directly under #[track_caller] would be
    // preferable, but it isn't exported.)
    let _prefix = &self.inner[..index];
}

#[inline]
pub fn from_str(s: &str) -> &Slice {
unsafe { mem::transmute(s) }
Expand Down
7 changes: 6 additions & 1 deletion library/std/src/sys/windows/os_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::fmt;
use crate::mem;
use crate::rc::Rc;
use crate::sync::Arc;
use crate::sys_common::wtf8::{Wtf8, Wtf8Buf};
use crate::sys_common::wtf8::{check_utf8_boundary, Wtf8, Wtf8Buf};
use crate::sys_common::{AsInner, FromInner, IntoInner};

#[derive(Clone, Hash)]
Expand Down Expand Up @@ -171,6 +171,11 @@ impl Slice {
mem::transmute(Wtf8::from_bytes_unchecked(s))
}

/// Panics unless `index` is a position at which the public API may split
/// this string.
///
/// The actual validation is done by `check_utf8_boundary` on the inner
/// WTF-8 data; this method only forwards to it.
// `#[inline]` matches the other platform implementations of this method and
// lets the `#[inline]` on `check_utf8_boundary` take effect for callers in
// other crates instead of stopping at this thin wrapper.
#[track_caller]
#[inline]
pub fn check_public_boundary(&self, index: usize) {
    check_utf8_boundary(&self.inner, index);
}

#[inline]
pub fn from_str(s: &str) -> &Slice {
unsafe { mem::transmute(Wtf8::from_str(s)) }
Expand Down
36 changes: 32 additions & 4 deletions library/std/src/sys_common/wtf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -885,15 +885,43 @@ fn decode_surrogate_pair(lead: u16, trail: u16) -> char {
unsafe { char::from_u32_unchecked(code_point) }
}

/// Copied from core::str::StrPrelude::is_char_boundary
/// Copied from str::is_char_boundary
#[inline]
pub fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool {
if index == slice.len() {
if index == 0 {
return true;
}
match slice.bytes.get(index) {
None => false,
Some(&b) => b < 128 || b >= 192,
None => index == slice.len(),
Some(&b) => (b as i8) >= -0x40,
}
}

/// Panics unless `index` sits at the edge of a valid UTF-8 codepoint or at
/// either end of the whole string.
///
/// Those are the split points currently permitted by
/// `OsStr::slice_encoded_bytes`. A split between two surrogates would be
/// fine as far as WTF-8 itself is concerned, but the public API rejects it
/// because WTF-8 is considered an implementation detail.
#[track_caller]
#[inline]
pub fn check_utf8_boundary(slice: &Wtf8, index: usize) {
    if index == 0 {
        return;
    }

    // No byte at `index`: either we are exactly at the end, or past it.
    let Some(&next) = slice.bytes.get(index) else {
        if index == slice.len() {
            return;
        }
        panic!("byte index {index} is out of bounds");
    };

    // 0xED may start an encoded surrogate, so it gets extra scrutiny below.
    // Any other byte is judged purely on whether it is a continuation byte.
    if next != 0xED {
        if (next as i8) >= -0x40 {
            return;
        }
        panic!("byte index {index} is not a codepoint boundary");
    }

    // A surrogate is encoded as 0xED followed by a byte >= 0xA0. The split
    // is only rejected when such a sequence is found on both sides.
    let surrogate_after = slice.bytes[index + 1] >= 0xA0;
    if surrogate_after
        && index >= 3
        && slice.bytes[index - 3] == 0xED
        && slice.bytes[index - 2] >= 0xA0
    {
        panic!("byte index {index} lies between surrogate codepoints");
    }
}

Expand Down
62 changes: 62 additions & 0 deletions library/std/src/sys_common/wtf8/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,65 @@ fn wtf8_to_owned() {
assert_eq!(string.bytes, b"\xED\xA0\x80");
assert!(!string.is_known_utf8);
}

#[test]
fn wtf8_valid_utf8_boundaries() {
    let lone_surrogate = || CodePoint::from_u32(0xD800).unwrap();

    // "aé 💩" occupies 8 bytes and each surrogate adds 3 more. Every edge
    // of a regular character is accepted, as are the start of the first
    // surrogate and the end of the string. Index 11, between the two
    // surrogates, is deliberately not exercised here.
    let mut text = Wtf8Buf::from_str("aé 💩");
    text.push(lone_surrogate());
    text.push(lone_surrogate());
    for index in [0, 1, 3, 4, 8, 14] {
        check_utf8_boundary(&text, index);
    }
    assert_eq!(text.len(), 14);

    // Once a character follows, the old end is still a boundary and so is
    // the new end.
    text.push_char('a');
    for index in [14, 15] {
        check_utf8_boundary(&text, index);
    }

    // An ASCII character directly before a surrogate.
    let mut ascii_then_surrogate = Wtf8Buf::from_str("a");
    ascii_then_surrogate.push(lone_surrogate());
    check_utf8_boundary(&ascii_then_surrogate, 1);

    // U+D7FF is the last code point before the surrogate range: it also
    // starts with 0xED but is not a surrogate, so splitting next to one
    // must be allowed on either side.
    let mut scalar_then_surrogate = Wtf8Buf::from_str("\u{D7FF}");
    scalar_then_surrogate.push(lone_surrogate());
    check_utf8_boundary(&scalar_then_surrogate, 3);

    let mut surrogate_then_scalar = Wtf8Buf::new();
    surrogate_then_scalar.push(lone_surrogate());
    surrogate_then_scalar.push_char('\u{D7FF}');
    check_utf8_boundary(&surrogate_then_scalar, 3);
}

#[test]
#[should_panic(expected = "byte index 4 is out of bounds")]
fn wtf8_utf8_boundary_out_of_bounds() {
    // "aé" is three bytes long, so index 4 is one past the end position.
    check_utf8_boundary(Wtf8::from_str("aé"), 4);
}

#[test]
#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
fn wtf8_utf8_boundary_inside_codepoint() {
    // "é" is a two-byte character; index 1 points at its continuation byte.
    check_utf8_boundary(Wtf8::from_str("é"), 1);
}

#[test]
#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
fn wtf8_utf8_boundary_inside_surrogate() {
    // A lone surrogate is stored as three bytes; index 1 falls inside it.
    let mut lone = Wtf8Buf::new();
    lone.push(CodePoint::from_u32(0xD800).unwrap());
    check_utf8_boundary(&lone, 1);
}

#[test]
#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
fn wtf8_utf8_boundary_between_surrogates() {
    // Two surrogates back to back, three bytes each: index 3 is the seam
    // between them, which the public boundary check refuses.
    let mut pair = Wtf8Buf::new();
    for _ in 0..2 {
        pair.push(CodePoint::from_u32(0xD800).unwrap());
    }
    check_utf8_boundary(&pair, 3);
}

0 comments on commit 694a789

Please sign in to comment.