Specialize equality for [T] and comparison for [u8] to use memcmp when possible #32699

Merged · 3 commits · Apr 7, 2016
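The change relies on the then-unstable `specialization` feature: a hidden helper trait carries the slice comparison, a blanket impl supplies `default fn` bodies, and a narrower impl overrides them with a bulk memcmp-style comparison. Below is a condensed sketch of that shape, not the PR's exact code — names are simplified, the byte-slice impl uses a safe stand-in for the real `memcmp` call, and it needs a nightly toolchain:

```rust
#![feature(specialization)]

// Condensed, illustrative sketch of the specialization pattern used in the
// diff below (nightly only; the real code lives in libcore and dispatches
// on a BytewiseEquality marker trait).
trait SliceEq<B> {
    fn equal(&self, other: &[B]) -> bool;
}

// Generic fallback: element-by-element comparison.
impl<A: PartialEq<B>, B> SliceEq<B> for [A] {
    default fn equal(&self, other: &[B]) -> bool {
        self.len() == other.len() && self.iter().zip(other).all(|(a, b)| a == b)
    }
}

// Specialized impl: byte slices can be compared in one bulk operation.
// The PR calls memcmp here; this sketch just uses the built-in byte-slice
// comparison as a stand-in.
impl SliceEq<u8> for [u8] {
    fn equal(&self, other: &[u8]) -> bool {
        self == other
    }
}

fn main() {
    assert!(SliceEq::equal(&b"abc"[..], &b"abc"[..]));        // specialized path
    assert!(!SliceEq::equal(&[1i32, 2][..], &[1i32, 3][..])); // generic path
}
```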
42 changes: 36 additions & 6 deletions src/libcollectionstest/slice.rs
@@ -574,18 +574,48 @@ fn test_slice_2() {
assert_eq!(v[1], 3);
}

macro_rules! assert_order {
(Greater, $a:expr, $b:expr) => {
assert_eq!($a.cmp($b), Greater);
assert!($a > $b);
};
(Less, $a:expr, $b:expr) => {
assert_eq!($a.cmp($b), Less);
assert!($a < $b);
};
(Equal, $a:expr, $b:expr) => {
assert_eq!($a.cmp($b), Equal);
assert_eq!($a, $b);
}
}

#[test]
fn test_total_ord_u8() {
let c = &[1u8, 2, 3];
assert_order!(Greater, &[1u8, 2, 3, 4][..], &c[..]);
let c = &[1u8, 2, 3, 4];
assert_order!(Less, &[1u8, 2, 3][..], &c[..]);
let c = &[1u8, 2, 3, 6];
assert_order!(Equal, &[1u8, 2, 3, 6][..], &c[..]);
let c = &[1u8, 2, 3, 4, 5, 6];
assert_order!(Less, &[1u8, 2, 3, 4, 5, 5, 5, 5][..], &c[..]);
let c = &[1u8, 2, 3, 4];
assert_order!(Greater, &[2u8, 2][..], &c[..]);
}


#[test]
fn test_total_ord() {
fn test_total_ord_i32() {
let c = &[1, 2, 3];
[1, 2, 3, 4][..].cmp(c) == Greater;
assert_order!(Greater, &[1, 2, 3, 4][..], &c[..]);
let c = &[1, 2, 3, 4];
[1, 2, 3][..].cmp(c) == Less;
assert_order!(Less, &[1, 2, 3][..], &c[..]);
let c = &[1, 2, 3, 6];
[1, 2, 3, 4][..].cmp(c) == Equal;
assert_order!(Equal, &[1, 2, 3, 6][..], &c[..]);
let c = &[1, 2, 3, 4, 5, 6];
[1, 2, 3, 4, 5, 5, 5, 5][..].cmp(c) == Less;
assert_order!(Less, &[1, 2, 3, 4, 5, 5, 5, 5][..], &c[..]);
let c = &[1, 2, 3, 4];
[2, 2][..].cmp(c) == Greater;
assert_order!(Greater, &[2, 2][..], &c[..]);
}

#[test]
1 change: 1 addition & 0 deletions src/libcore/lib.rs
@@ -75,6 +75,7 @@
#![feature(unwind_attributes)]
#![feature(repr_simd, platform_intrinsics)]
#![feature(rustc_attrs)]
#![feature(specialization)]
#![feature(staged_api)]
#![feature(unboxed_closures)]
#![feature(question_mark)]
155 changes: 139 additions & 16 deletions src/libcore/slice.rs
@@ -1630,12 +1630,60 @@ pub unsafe fn from_raw_parts_mut<'a, T>(p: *mut T, len: usize) -> &'a mut [T] {
}

//
// Boilerplate traits
// Comparison traits
//

extern {
/// Call implementation provided memcmp
///
/// Interprets the data as u8.
///
/// Return 0 for equal, < 0 for less than and > 0 for greater
/// than.
// FIXME(#32610): Return type should be c_int
fn memcmp(s1: *const u8, s2: *const u8, n: usize) -> i32;
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {
fn eq(&self, other: &[B]) -> bool {
SlicePartialEq::equal(self, other)
}

fn ne(&self, other: &[B]) -> bool {
SlicePartialEq::not_equal(self, other)
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: Eq> Eq for [T] {}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: Ord> Ord for [T] {
fn cmp(&self, other: &[T]) -> Ordering {
SliceOrd::compare(self, other)
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: PartialOrd> PartialOrd for [T] {
fn partial_cmp(&self, other: &[T]) -> Option<Ordering> {
SlicePartialOrd::partial_compare(self, other)
}
}

#[doc(hidden)]
// intermediate trait for specialization of slice's PartialEq
trait SlicePartialEq<B> {
fn equal(&self, other: &[B]) -> bool;
fn not_equal(&self, other: &[B]) -> bool;
}

// Generic slice equality
impl<A, B> SlicePartialEq<B> for [A]
where A: PartialEq<B>
{
default fn equal(&self, other: &[B]) -> bool {
if self.len() != other.len() {
return false;
}
@@ -1648,7 +1696,8 @@ impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {

true
}
fn ne(&self, other: &[B]) -> bool {

default fn not_equal(&self, other: &[B]) -> bool {
if self.len() != other.len() {
return true;
}
@@ -1663,12 +1712,36 @@ impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: Eq> Eq for [T] {}
// Use memcmp for bytewise equality when the types allow
impl<A> SlicePartialEq<A> for [A]
where A: PartialEq<A> + BytewiseEquality
{
fn equal(&self, other: &[A]) -> bool {
if self.len() != other.len() {
return false;
}
unsafe {
let size = mem::size_of_val(self);
memcmp(self.as_ptr() as *const u8,
other.as_ptr() as *const u8, size) == 0
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: Ord> Ord for [T] {
fn cmp(&self, other: &[T]) -> Ordering {
fn not_equal(&self, other: &[A]) -> bool {
!self.equal(other)
}
}

#[doc(hidden)]
// intermediate trait for specialization of slice's PartialOrd
trait SlicePartialOrd<B> {
fn partial_compare(&self, other: &[B]) -> Option<Ordering>;
}

impl<A> SlicePartialOrd<A> for [A]
where A: PartialOrd
{
default fn partial_compare(&self, other: &[A]) -> Option<Ordering> {
let l = cmp::min(self.len(), other.len());

// Slice to the loop iteration range to enable bound check
@@ -1677,19 +1750,33 @@ impl<T: Ord> Ord for [T] {
let rhs = &other[..l];

for i in 0..l {
match lhs[i].cmp(&rhs[i]) {
Ordering::Equal => (),
match lhs[i].partial_cmp(&rhs[i]) {
Some(Ordering::Equal) => (),
non_eq => return non_eq,
}
}

self.len().cmp(&other.len())
self.len().partial_cmp(&other.len())
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: PartialOrd> PartialOrd for [T] {
fn partial_cmp(&self, other: &[T]) -> Option<Ordering> {
impl SlicePartialOrd<u8> for [u8] {
#[inline]
fn partial_compare(&self, other: &[u8]) -> Option<Ordering> {
Some(SliceOrd::compare(self, other))
}
}

#[doc(hidden)]
// intermediate trait for specialization of slice's Ord
trait SliceOrd<B> {
fn compare(&self, other: &[B]) -> Ordering;
}

impl<A> SliceOrd<A> for [A]
where A: Ord
{
default fn compare(&self, other: &[A]) -> Ordering {
let l = cmp::min(self.len(), other.len());

// Slice to the loop iteration range to enable bound check
@@ -1698,12 +1785,48 @@ impl<T: PartialOrd> PartialOrd for [T] {
let rhs = &other[..l];

for i in 0..l {
match lhs[i].partial_cmp(&rhs[i]) {
Some(Ordering::Equal) => (),
match lhs[i].cmp(&rhs[i]) {
Ordering::Equal => (),
non_eq => return non_eq,
}
}

self.len().partial_cmp(&other.len())
self.len().cmp(&other.len())
}
}

// memcmp compares a sequence of unsigned bytes lexicographically.
// this matches the order we want for [u8], but no others (not even [i8]).
Member: Couldn't we implement this for [i8]? We may not be able to implement it for larger types due to endianness (unless you're on a big-endian platform), but I'm drawing a blank on why not to implement it for i8. The semantics for comparing [u8] are the same with this and the above impl, right?

Member (Author): [-1] should compare less than [1], but when interpreted as unsigned bytes ([255] and [1]) it compares greater than.

Member: oops, right!

Member (Author): Thanks for the review!
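A tiny illustration of the point above (not part of the diff): reinterpreting signed bytes as unsigned flips the order for negative values, which is why the memcmp-based Ord specialization stops at [u8].

```rust
fn main() {
    let signed: &[i8] = &[-1];
    let as_bytes: &[u8] = &[(-1i8) as u8]; // the byte 255

    // Signed order: [-1] sorts before [1].
    assert!(signed < &[1i8][..]);
    // Bytewise (memcmp) order: [255] sorts after [1].
    assert!(as_bytes > &[1u8][..]);
}
```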

impl SliceOrd<u8> for [u8] {
#[inline]
fn compare(&self, other: &[u8]) -> Ordering {
let order = unsafe {
memcmp(self.as_ptr(), other.as_ptr(),
cmp::min(self.len(), other.len()))
};
if order == 0 {
self.len().cmp(&other.len())
} else if order < 0 {
Less
} else {
Greater
}
}
}

#[doc(hidden)]
/// Trait implemented for types that can be compared for equality using
/// their bytewise representation
trait BytewiseEquality { }

macro_rules! impl_marker_for {
($traitname:ident, $($ty:ty)*) => {
$(
impl $traitname for $ty { }
)*
}
}

impl_marker_for!(BytewiseEquality,
u8 i8 u16 i16 u32 i32 u64 i64 usize isize char bool);
Contributor: There is room to eke out more performance here if we specialize for [u16], [u32], etc. I did some experimentation with trying to speed up memcmp a year or so ago, and most implementations just check a byte at a time until the slices align on a usize. For these larger types, we could skip some of those alignment checks since we already know they're aligned.

Member (Author): Interesting! That sounds like exactly the motivation needed to make memcmp(*const u8, *const u8, usize) into an LLVM intrinsic, so that it uses the alignment information from the pointer (like memcpy already does). http://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic

Contributor: @bluss: Exactly. There were a few questions a number of years ago (1, 2) about adding one, but they didn't get any traction.

PS: I think I found the discussion from when I was talking about this on #rust-internals with you, @bluss, back in 2015 :) I can't say that this would be a big win; it might just shave off a few conditionals, which may or may not matter in real code. I also found my old benchmarks, which I've uploaded to https://github.com/erickt/rust-memcmp-benches.
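A small illustrative check of why the wider integer types can share the equality specialization but not an ordering one (not part of the diff; it uses today's `u16::to_ne_bytes` for brevity): equal values always have equal native-byte representations, but the native byte order only matches the numeric order on big-endian targets.

```rust
fn main() {
    let a: u16 = 0x0100; // 256
    let b: u16 = 0x0001; // 1

    // Bytewise equality agrees with value equality regardless of endianness,
    // which is what the BytewiseEquality marker relies on.
    assert_eq!(a == b, a.to_ne_bytes() == b.to_ne_bytes());

    // Bytewise ordering does not: on little-endian, 256 is stored as
    // [0x00, 0x01], which sorts below 1's [0x01, 0x00].
    if cfg!(target_endian = "little") {
        assert!(a > b);
        assert!(a.to_ne_bytes() < b.to_ne_bytes());
    }
}
```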


25 changes: 3 additions & 22 deletions src/libcore/str/mod.rs
@@ -1150,16 +1150,7 @@ Section: Comparing strings
#[lang = "str_eq"]
#[inline]
fn eq_slice(a: &str, b: &str) -> bool {
a.len() == b.len() && unsafe { cmp_slice(a, b, a.len()) == 0 }
}

/// Bytewise slice comparison.
/// NOTE: This uses the system's memcmp, which is currently dramatically
/// faster than comparing each byte in a loop.
#[inline]
unsafe fn cmp_slice(a: &str, b: &str, len: usize) -> i32 {
extern { fn memcmp(s1: *const i8, s2: *const i8, n: usize) -> i32; }
memcmp(a.as_ptr() as *const i8, b.as_ptr() as *const i8, len)
a.as_bytes() == b.as_bytes()
}

/*
@@ -1328,8 +1319,7 @@ Section: Trait implementations
*/

mod traits {
use cmp::{self, Ordering, Ord, PartialEq, PartialOrd, Eq};
use cmp::Ordering::{Less, Greater};
use cmp::{Ord, Ordering, PartialEq, PartialOrd, Eq};
use iter::Iterator;
use option::Option;
use option::Option::Some;
Expand All @@ -1340,16 +1330,7 @@ mod traits {
impl Ord for str {
#[inline]
fn cmp(&self, other: &str) -> Ordering {
let cmp = unsafe {
super::cmp_slice(self, other, cmp::min(self.len(), other.len()))
};
if cmp == 0 {
self.len().cmp(&other.len())
} else if cmp < 0 {
Less
} else {
Greater
}
self.as_bytes().cmp(other.as_bytes())
}
}
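The str comparisons above can delegate to the underlying byte slices because UTF-8 is designed so that comparing the encoded bytes gives the same result as comparing code points. A quick check of that property (not part of the diff):

```rust
fn main() {
    // 'z' is U+007A (one byte, 0x7A); 'é' is U+00E9 (two bytes, 0xC3 0xA9).
    // Byte order and code-point order agree, so str::cmp can just compare bytes.
    let (a, b) = ("z", "é");
    assert_eq!(a.cmp(b), a.as_bytes().cmp(b.as_bytes()));
    assert!(a < b);
}
```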
