Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move OsStr::slice_encoded_bytes validation to platform modules #118569

Merged
merged 1 commit into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions library/std/src/ffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@
//! trait, which provides a [`from_wide`] method to convert a native Windows
//! string (without the terminating nul character) to an [`OsString`].
//!
//! ## Other platforms
//!
//! Many other platforms provide their own extension traits in a
//! `std::os::*::ffi` module.
//!
//! ## On all platforms
//!
//! On all platforms, [`OsStr`] consists of a sequence of bytes that is encoded as a superset of
Expand All @@ -135,6 +140,8 @@
//! For limited, inexpensive conversions from and to bytes, see [`OsStr::as_encoded_bytes`] and
//! [`OsStr::from_encoded_bytes_unchecked`].
//!
//! For basic string processing, see [`OsStr::slice_encoded_bytes`].
//!
//! [Unicode scalar value]: https://www.unicode.org/glossary/#unicode_scalar_value
//! [Unicode code point]: https://www.unicode.org/glossary/#code_point
//! [`env::set_var()`]: crate::env::set_var "env::set_var"
Expand Down
43 changes: 8 additions & 35 deletions library/std/src/ffi/os_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::hash::{Hash, Hasher};
use crate::ops::{self, Range};
use crate::rc::Rc;
use crate::slice;
use crate::str::{from_utf8 as str_from_utf8, FromStr};
use crate::str::FromStr;
use crate::sync::Arc;

use crate::sys::os_str::{Buf, Slice};
Expand Down Expand Up @@ -997,42 +997,15 @@ impl OsStr {
/// ```
#[unstable(feature = "os_str_slice", issue = "118485")]
pub fn slice_encoded_bytes<R: ops::RangeBounds<usize>>(&self, range: R) -> &Self {
#[track_caller]
fn check_valid_boundary(bytes: &[u8], index: usize) {
if index == 0 || index == bytes.len() {
return;
}

// Fast path
if bytes[index - 1].is_ascii() || bytes[index].is_ascii() {
return;
}

let (before, after) = bytes.split_at(index);

// UTF-8 takes at most 4 bytes per codepoint, so we don't
// need to check more than that.
let after = after.get(..4).unwrap_or(after);
match str_from_utf8(after) {
Ok(_) => return,
Err(err) if err.valid_up_to() != 0 => return,
Err(_) => (),
}

for len in 2..=4.min(index) {
let before = &before[index - len..];
if str_from_utf8(before).is_ok() {
return;
}
}

panic!("byte index {index} is not an OsStr boundary");
}

let encoded_bytes = self.as_encoded_bytes();
let Range { start, end } = slice::range(range, ..encoded_bytes.len());
check_valid_boundary(encoded_bytes, start);
check_valid_boundary(encoded_bytes, end);

// `check_public_boundary` should panic if the index does not lie on an
// `OsStr` boundary as described above. It's possible to do this in an
// encoding-agnostic way, but details of the internal encoding might
// permit a more efficient implementation.
self.inner.check_public_boundary(start);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What makes this a public boundary? Is there a private boundary?

I think it would be good to add some docs here if we're adding an extension point -- perhaps a couple lines of common describing what the function should (roughly) do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The public boundaries are where we let users split without panicking. The private boundaries would depend on the safety invariants for the internal encoding. For example if you split in the middle of a WTF-8 codepoint you can cause out-of-bounds reads, so that's neither a public nor a private boundary. But if you split between surrogate codepoints then that's fine as far as the implementation is concerned, we just don't allow users to do that, so that's a private boundary but not a public boundary.

I've added a comment, good call.

self.inner.check_public_boundary(end);

// SAFETY: `slice::range` ensures that `start` and `end` are valid
let slice = unsafe { encoded_bytes.get_unchecked(start..end) };
Expand Down
68 changes: 61 additions & 7 deletions library/std/src/ffi/os_str/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,15 +194,65 @@ fn slice_encoded_bytes() {
}

#[test]
#[should_panic(expected = "byte index 2 is not an OsStr boundary")]
#[should_panic]
fn slice_out_of_bounds() {
let crab = OsStr::new("🦀");
let _ = crab.slice_encoded_bytes(..5);
}

#[test]
#[should_panic]
fn slice_mid_char() {
let crab = OsStr::new("🦀");
let _ = crab.slice_encoded_bytes(..2);
}

#[cfg(unix)]
#[test]
#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
fn slice_invalid_data() {
use crate::os::unix::ffi::OsStrExt;

let os_string = OsStr::from_bytes(b"\xFF\xFF");
let _ = os_string.slice_encoded_bytes(1..);
}

#[cfg(unix)]
#[test]
#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
fn slice_partial_utf8() {
use crate::os::unix::ffi::{OsStrExt, OsStringExt};

let part_crab = OsStr::from_bytes(&"🦀".as_bytes()[..3]);
let mut os_string = OsString::from_vec(vec![0xFF]);
os_string.push(part_crab);
let _ = os_string.slice_encoded_bytes(1..);
}

#[cfg(unix)]
#[test]
fn slice_invalid_edge() {
use crate::os::unix::ffi::{OsStrExt, OsStringExt};

let os_string = OsStr::from_bytes(b"a\xFFa");
assert_eq!(os_string.slice_encoded_bytes(..1), "a");
assert_eq!(os_string.slice_encoded_bytes(1..), OsStr::from_bytes(b"\xFFa"));
assert_eq!(os_string.slice_encoded_bytes(..2), OsStr::from_bytes(b"a\xFF"));
assert_eq!(os_string.slice_encoded_bytes(2..), "a");

let os_string = OsStr::from_bytes(&"abc🦀".as_bytes()[..6]);
assert_eq!(os_string.slice_encoded_bytes(..3), "abc");
assert_eq!(os_string.slice_encoded_bytes(3..), OsStr::from_bytes(b"\xF0\x9F\xA6"));

let mut os_string = OsString::from_vec(vec![0xFF]);
os_string.push("🦀");
assert_eq!(os_string.slice_encoded_bytes(..1), OsStr::from_bytes(b"\xFF"));
assert_eq!(os_string.slice_encoded_bytes(1..), "🦀");
}

#[cfg(windows)]
#[test]
#[should_panic(expected = "byte index 3 is not an OsStr boundary")]
#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
fn slice_between_surrogates() {
use crate::os::windows::ffi::OsStringExt;

Expand All @@ -216,10 +266,14 @@ fn slice_between_surrogates() {
fn slice_surrogate_edge() {
use crate::os::windows::ffi::OsStringExt;

let os_string = OsString::from_wide(&[0xD800]);
let mut with_crab = os_string.clone();
with_crab.push("🦀");
let surrogate = OsString::from_wide(&[0xD800]);
let mut pre_crab = surrogate.clone();
pre_crab.push("🦀");
assert_eq!(pre_crab.slice_encoded_bytes(..3), surrogate);
assert_eq!(pre_crab.slice_encoded_bytes(3..), "🦀");

assert_eq!(with_crab.slice_encoded_bytes(..3), os_string);
assert_eq!(with_crab.slice_encoded_bytes(3..), "🦀");
let mut post_crab = OsString::from("🦀");
post_crab.push(&surrogate);
assert_eq!(post_crab.slice_encoded_bytes(..4), "🦀");
assert_eq!(post_crab.slice_encoded_bytes(4..), surrogate);
}
43 changes: 43 additions & 0 deletions library/std/src/sys/os_str/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,49 @@ impl Slice {
unsafe { mem::transmute(s) }
}

#[track_caller]
#[inline]
pub fn check_public_boundary(&self, index: usize) {
if index == 0 || index == self.inner.len() {
return;
}
if index < self.inner.len()
&& (self.inner[index - 1].is_ascii() || self.inner[index].is_ascii())
{
return;
}

slow_path(&self.inner, index);

/// We're betting that typical splits will involve an ASCII character.
///
/// Putting the expensive checks in a separate function generates notably
/// better assembly.
#[track_caller]
#[inline(never)]
fn slow_path(bytes: &[u8], index: usize) {
let (before, after) = bytes.split_at(index);

// UTF-8 takes at most 4 bytes per codepoint, so we don't
// need to check more than that.
let after = after.get(..4).unwrap_or(after);
match str::from_utf8(after) {
Ok(_) => return,
Err(err) if err.valid_up_to() != 0 => return,
Err(_) => (),
}

for len in 2..=4.min(index) {
let before = &before[index - len..];
if str::from_utf8(before).is_ok() {
return;
}
}

panic!("byte index {index} is not an OsStr boundary");
}
}

#[inline]
pub fn from_str(s: &str) -> &Slice {
unsafe { Slice::from_encoded_bytes_unchecked(s.as_bytes()) }
Expand Down
7 changes: 6 additions & 1 deletion library/std/src/sys/os_str/wtf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::fmt;
use crate::mem;
use crate::rc::Rc;
use crate::sync::Arc;
use crate::sys_common::wtf8::{Wtf8, Wtf8Buf};
use crate::sys_common::wtf8::{check_utf8_boundary, Wtf8, Wtf8Buf};
use crate::sys_common::{AsInner, FromInner, IntoInner};

#[derive(Clone, Hash)]
Expand Down Expand Up @@ -171,6 +171,11 @@ impl Slice {
mem::transmute(Wtf8::from_bytes_unchecked(s))
}

#[track_caller]
pub fn check_public_boundary(&self, index: usize) {
check_utf8_boundary(&self.inner, index);
}

#[inline]
pub fn from_str(s: &str) -> &Slice {
unsafe { mem::transmute(Wtf8::from_str(s)) }
Expand Down
36 changes: 32 additions & 4 deletions library/std/src/sys_common/wtf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -885,15 +885,43 @@ fn decode_surrogate_pair(lead: u16, trail: u16) -> char {
unsafe { char::from_u32_unchecked(code_point) }
}

/// Copied from core::str::StrPrelude::is_char_boundary
/// Copied from str::is_char_boundary
#[inline]
pub fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool {
if index == slice.len() {
if index == 0 {
return true;
}
match slice.bytes.get(index) {
None => false,
Some(&b) => b < 128 || b >= 192,
None => index == slice.len(),
Some(&b) => (b as i8) >= -0x40,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did this implementation change?

(In particular it seems like the behavior is no longer an unconditional true for index = 0, and also doesn't correspond with the str::is_char_boundary impl?)

The (b as i8) >= -0x40 is probably clearer as b.is_utf8_char_boundary().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is identical to the str::is_char_boundary impl. The point was to bring it in line with that and to be consistent with the new function.

I can't use b.is_utf8_char_boundary() because it's private to core. (Unless there are workarounds for that?)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, so it is. We should think about exposing that as a public API, it seems consistent with the is_ascii_ functions we already expose that have similar bit-twiddling internally.

}
}

/// Verify that `index` is at the edge of either a valid UTF-8 codepoint
/// (i.e. a codepoint that's not a surrogate) or of the whole string.
///
/// These are the cases currently permitted by `OsStr::slice_encoded_bytes`.
/// Splitting between surrogates is valid as far as WTF-8 is concerned, but
/// we do not permit it in the public API because WTF-8 is considered an
/// implementation detail.
#[track_caller]
#[inline]
pub fn check_utf8_boundary(slice: &Wtf8, index: usize) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused how this is different from the is_code_point_boundary (from just the method name/comments)?

str::is_char_boundary is documented as "is the first byte in a UTF-8 code point sequence or the end of the string", which sounds very similar to "at the edge of either a valid UTF-8 codepoint or of the whole string". Is this needed separately due to WTF-8 details perhaps?

I'm worried that it'll be easy to use the wrong function so I think some detail on when we should use each in comments would be good. It's also a bit worrying to me that we want a new function since that feels like it implies we're changing behavior rather than just optimizing here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm posting a longer explanation below, but the gist of it is that there are some WTF-8 codepoints that are not UTF-8 codepoints. If the string is pure UTF-8 then the boundaries are the same. I've tweaked the comment a little.

It's not a behavioral change, the old function wasn't used for this functionality to begin with. The cases in which this implementation panics should be the same as those in which the old one does.

if index == 0 {
return;
}
match slice.bytes.get(index) {
Some(0xED) => (), // Might be a surrogate
Some(&b) if (b as i8) >= -0x40 => return,
Some(_) => panic!("byte index {index} is not a codepoint boundary"),
None if index == slice.len() => return,
None => panic!("byte index {index} is out of bounds"),
}
if slice.bytes[index + 1] >= 0xA0 {
// There's a surrogate after index. Now check before index.
if index >= 3 && slice.bytes[index - 3] == 0xED && slice.bytes[index - 2] >= 0xA0 {
panic!("byte index {index} lies between surrogate codepoints");
}
}
}

Expand Down
62 changes: 62 additions & 0 deletions library/std/src/sys_common/wtf8/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -663,3 +663,65 @@ fn wtf8_to_owned() {
assert_eq!(string.bytes, b"\xED\xA0\x80");
assert!(!string.is_known_utf8);
}

#[test]
fn wtf8_valid_utf8_boundaries() {
let mut string = Wtf8Buf::from_str("aé 💩");
string.push(CodePoint::from_u32(0xD800).unwrap());
string.push(CodePoint::from_u32(0xD800).unwrap());
check_utf8_boundary(&string, 0);
check_utf8_boundary(&string, 1);
check_utf8_boundary(&string, 3);
check_utf8_boundary(&string, 4);
check_utf8_boundary(&string, 8);
check_utf8_boundary(&string, 14);
assert_eq!(string.len(), 14);

string.push_char('a');
check_utf8_boundary(&string, 14);
check_utf8_boundary(&string, 15);

let mut string = Wtf8Buf::from_str("a");
string.push(CodePoint::from_u32(0xD800).unwrap());
check_utf8_boundary(&string, 1);

let mut string = Wtf8Buf::from_str("\u{D7FF}");
string.push(CodePoint::from_u32(0xD800).unwrap());
check_utf8_boundary(&string, 3);

let mut string = Wtf8Buf::new();
string.push(CodePoint::from_u32(0xD800).unwrap());
string.push_char('\u{D7FF}');
check_utf8_boundary(&string, 3);
}

#[test]
#[should_panic(expected = "byte index 4 is out of bounds")]
fn wtf8_utf8_boundary_out_of_bounds() {
let string = Wtf8::from_str("aé");
check_utf8_boundary(&string, 4);
}

#[test]
#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
fn wtf8_utf8_boundary_inside_codepoint() {
let string = Wtf8::from_str("é");
check_utf8_boundary(&string, 1);
}

#[test]
#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
fn wtf8_utf8_boundary_inside_surrogate() {
let mut string = Wtf8Buf::new();
string.push(CodePoint::from_u32(0xD800).unwrap());
check_utf8_boundary(&string, 1);
}

#[test]
#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
fn wtf8_utf8_boundary_between_surrogates() {
let mut string = Wtf8Buf::new();
string.push(CodePoint::from_u32(0xD800).unwrap());
string.push(CodePoint::from_u32(0xD800).unwrap());
check_utf8_boundary(&string, 3);
}
Loading