performance regression of binary_search #115271
If your use case permits it, you can try partition_point instead, which doesn't check the equality case. Its limitation is that it doesn't distinguish between the presence and absence of the element at the returned index.
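A minimal sketch of that suggestion, assuming you still want the Ok/Err shape of binary_search. The helper name search_via_partition_point is hypothetical. partition_point returns the first index whose element is not less than the target, and a single extra equality check recovers the hit/miss distinction that partition_point alone does not give you.

```rust
/// Hypothetical helper: emulates `slice::binary_search` on top of
/// `slice::partition_point`. `partition_point` only finds the boundary
/// (the first index where the predicate is false), so one extra
/// comparison is needed to tell a hit from a miss.
fn search_via_partition_point<T: Ord>(s: &[T], x: &T) -> Result<usize, usize> {
    // Index of the first element that is not less than `x`.
    let idx = s.partition_point(|p| p < x);
    // The element at `idx` (if any) is >= x; it is a hit only if it equals x.
    if s.get(idx) == Some(x) {
        Ok(idx)
    } else {
        Err(idx)
    }
}
```

With duplicates this returns the first matching index, whereas binary_search only documents that any one of the matches may be returned. If you only need the insertion point, calling partition_point directly skips the extra comparison entirely.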
I haven't run your benchmark yet, but #117722 was merged, which might reclaim some performance. Using …
@okaneco Hi, thank you very much for your contribution. Here are the new benchmarks (MacBook Pro M2). It looks like …
lib.rs
#![feature(core_intrinsics)]
use std::cmp::Ord;
use std::cmp::Ordering::{self, Equal, Greater, Less};
pub fn old_binary_search<T>(s: &[T], x: &T) -> Result<usize, usize>
where
    T: Ord,
{
    old_binary_search_by(s, |p| p.cmp(x))
}
#[inline(always)]
pub fn old_binary_search_by<'a, T, F>(s: &'a [T], mut f: F) -> Result<usize, usize>
where
    F: FnMut(&'a T) -> Ordering,
{
    let mut size = s.len();
    if size == 0 {
        return Err(0);
    }
    let mut base = 0usize;
    while size > 1 {
        let half = size / 2;
        let mid = base + half;
        // mid is always in [0, size), that means mid is >= 0 and < size.
        // mid >= 0: by definition
        // mid < size: mid = size / 2 + size / 4 + size / 8 ...
        let cmp = f(unsafe { s.get_unchecked(mid) });
        base = if cmp == Greater { base } else { mid };
        size -= half;
    }
    // base is always in [0, size) because base <= mid.
    let cmp = f(unsafe { s.get_unchecked(base) });
    if cmp == Equal {
        Ok(base)
    } else {
        Err(base + (cmp == Less) as usize)
    }
}
pub fn new_binary_search<T>(s: &[T], x: &T) -> Result<usize, usize>
where
    T: Ord,
{
    new_binary_search_by(s, |p| p.cmp(x))
}
#[inline(always)]
pub fn new_binary_search_by<'a, T, F>(s: &'a [T], mut f: F) -> Result<usize, usize>
where
    F: FnMut(&'a T) -> Ordering,
{
    // INVARIANTS:
    // - 0 <= left <= left + size = right <= self.len()
    // - f returns Less for everything in self[..left]
    // - f returns Greater for everything in self[right..]
    let mut size = s.len();
    let mut left = 0;
    let mut right = size;
    while left < right {
        let mid = left + size / 2;
        // SAFETY: the while condition means `size` is strictly positive, so
        // `size/2 < size`. Thus `left + size/2 < left + size`, which
        // coupled with the `left + size <= self.len()` invariant means
        // we have `left + size/2 < self.len()`, and this is in-bounds.
        let cmp = f(unsafe { s.get_unchecked(mid) });
        // The reason why we use if/else control flow rather than match
        // is because match reorders comparison operations, which is perf sensitive.
        // This is x86 asm for u8: https://rust.godbolt.org/z/8Y8Pra.
        if cmp == Less {
            left = mid + 1;
        } else if cmp == Greater {
            right = mid;
        } else {
            return Ok(mid);
        }
        size = right - left;
    }
    Err(left)
}
pub fn binary_search_117722<T>(s: &[T], x: &T) -> Result<usize, usize>
where
    T: Ord,
{
    binary_search_by_117722(s, |p| p.cmp(x))
}
#[inline(always)]
pub fn binary_search_by_117722<'a, F, T>(s: &'a [T], mut f: F) -> Result<usize, usize>
where
    F: FnMut(&'a T) -> Ordering,
{
    // INVARIANTS:
    // - 0 <= left <= left + size = right <= self.len()
    // - f returns Less for everything in self[..left]
    // - f returns Greater for everything in self[right..]
    let mut size = s.len();
    let mut left = 0;
    let mut right = size;
    while left < right {
        let mid = left + size / 2;
        // SAFETY: the while condition means `size` is strictly positive, so
        // `size/2 < size`. Thus `left + size/2 < left + size`, which
        // coupled with the `left + size <= self.len()` invariant means
        // we have `left + size/2 < self.len()`, and this is in-bounds.
        let cmp = f(unsafe { s.get_unchecked(mid) });
        // This control flow produces conditional moves, which results in
        // fewer branches and instructions than if/else or matching on
        // cmp::Ordering.
        // This is x86 asm for u8: https://rust.godbolt.org/z/698eYffTx.
        left = if cmp == Less { mid + 1 } else { left };
        right = if cmp == Greater { mid } else { right };
        if cmp == Equal {
            // SAFETY: same as the `get_unchecked` above
            unsafe { std::intrinsics::assume(mid < s.len()) };
            return Ok(mid);
        }
        size = right - left;
    }
    // SAFETY: directly true from the overall invariant.
    // Note that this is `<=`, unlike the assume in the `Ok` path.
    unsafe { std::intrinsics::assume(left <= s.len()) };
    Err(left)
}
bench.rs
#![feature(test)]
extern crate test;
use test::black_box;
use test::Bencher;
use binary_search_bench::*;
enum Cache {
    L1,
    L2,
    L3,
}
fn old_bench_binary_search<F>(b: &mut Bencher, cache: Cache, mapper: F)
where
    F: Fn(usize) -> usize,
{
    let size = match cache {
        Cache::L1 => 1000,      // 8 KB
        Cache::L2 => 10_000,    // 80 KB
        Cache::L3 => 1_000_000, // 8 MB
    };
    let v = (0..size).map(&mapper).collect::<Vec<_>>();
    let mut r = 0usize;
    b.iter(move || {
        // LCG constants from https://en.wikipedia.org/wiki/Numerical_Recipes.
        r = r.wrapping_mul(1664525).wrapping_add(1013904223);
        // Lookup the whole range to get 50% hits and 50% misses.
        let i = mapper(r % size);
        black_box(old_binary_search(&v, &i).is_ok());
    })
}
fn old_bench_binary_search_worst_case(b: &mut Bencher, cache: Cache) {
    let size = match cache {
        Cache::L1 => 1000,      // 8 KB
        Cache::L2 => 10_000,    // 80 KB
        Cache::L3 => 1_000_000, // 8 MB
    };
    let mut v = vec![0; size];
    let i = 1;
    v[size - 1] = i;
    b.iter(move || {
        black_box(old_binary_search(&v, &i).is_ok());
    })
}
#[bench]
fn old_binary_search_l1(b: &mut Bencher) {
    old_bench_binary_search(b, Cache::L1, |i| i * 2);
}
#[bench]
fn old_binary_search_l2(b: &mut Bencher) {
    old_bench_binary_search(b, Cache::L2, |i| i * 2);
}
#[bench]
fn old_binary_search_l3(b: &mut Bencher) {
    old_bench_binary_search(b, Cache::L3, |i| i * 2);
}
#[bench]
fn old_binary_search_l1_with_dups(b: &mut Bencher) {
    old_bench_binary_search(b, Cache::L1, |i| i / 16 * 16);
}
#[bench]
fn old_binary_search_l2_with_dups(b: &mut Bencher) {
    old_bench_binary_search(b, Cache::L2, |i| i / 16 * 16);
}
#[bench]
fn old_binary_search_l3_with_dups(b: &mut Bencher) {
    old_bench_binary_search(b, Cache::L3, |i| i / 16 * 16);
}
#[bench]
fn old_binary_search_l1_worst_case(b: &mut Bencher) {
    old_bench_binary_search_worst_case(b, Cache::L1);
}
#[bench]
fn old_binary_search_l2_worst_case(b: &mut Bencher) {
    old_bench_binary_search_worst_case(b, Cache::L2);
}
#[bench]
fn old_binary_search_l3_worst_case(b: &mut Bencher) {
    old_bench_binary_search_worst_case(b, Cache::L3);
}
fn new_bench_binary_search<F>(b: &mut Bencher, cache: Cache, mapper: F)
where
    F: Fn(usize) -> usize,
{
    let size = match cache {
        Cache::L1 => 1000,      // 8 KB
        Cache::L2 => 10_000,    // 80 KB
        Cache::L3 => 1_000_000, // 8 MB
    };
    let v = (0..size).map(&mapper).collect::<Vec<_>>();
    let mut r = 0usize;
    b.iter(move || {
        // LCG constants from https://en.wikipedia.org/wiki/Numerical_Recipes.
        r = r.wrapping_mul(1664525).wrapping_add(1013904223);
        // Lookup the whole range to get 50% hits and 50% misses.
        let i = mapper(r % size);
        black_box(new_binary_search(&v, &i).is_ok());
    })
}
fn new_bench_binary_search_worst_case(b: &mut Bencher, cache: Cache) {
    let size = match cache {
        Cache::L1 => 1000,      // 8 KB
        Cache::L2 => 10_000,    // 80 KB
        Cache::L3 => 1_000_000, // 8 MB
    };
    let mut v = vec![0; size];
    let i = 1;
    v[size - 1] = i;
    b.iter(move || {
        black_box(new_binary_search(&v, &i).is_ok());
    })
}
#[bench]
fn new_binary_search_l1(b: &mut Bencher) {
    new_bench_binary_search(b, Cache::L1, |i| i * 2);
}
#[bench]
fn new_binary_search_l2(b: &mut Bencher) {
    new_bench_binary_search(b, Cache::L2, |i| i * 2);
}
#[bench]
fn new_binary_search_l3(b: &mut Bencher) {
    new_bench_binary_search(b, Cache::L3, |i| i * 2);
}
#[bench]
fn new_binary_search_l1_with_dups(b: &mut Bencher) {
    new_bench_binary_search(b, Cache::L1, |i| i / 16 * 16);
}
#[bench]
fn new_binary_search_l2_with_dups(b: &mut Bencher) {
    new_bench_binary_search(b, Cache::L2, |i| i / 16 * 16);
}
#[bench]
fn new_binary_search_l3_with_dups(b: &mut Bencher) {
    new_bench_binary_search(b, Cache::L3, |i| i / 16 * 16);
}
#[bench]
fn new_binary_search_l1_worst_case(b: &mut Bencher) {
    new_bench_binary_search_worst_case(b, Cache::L1);
}
#[bench]
fn new_binary_search_l2_worst_case(b: &mut Bencher) {
    new_bench_binary_search_worst_case(b, Cache::L2);
}
#[bench]
fn new_binary_search_l3_worst_case(b: &mut Bencher) {
    new_bench_binary_search_worst_case(b, Cache::L3);
}
fn bench_binary_search_117722<F>(b: &mut Bencher, cache: Cache, mapper: F)
where
    F: Fn(usize) -> usize,
{
    let size = match cache {
        Cache::L1 => 1000,      // 8 KB
        Cache::L2 => 10_000,    // 80 KB
        Cache::L3 => 1_000_000, // 8 MB
    };
    let v = (0..size).map(&mapper).collect::<Vec<_>>();
    let mut r = 0usize;
    b.iter(move || {
        // LCG constants from https://en.wikipedia.org/wiki/Numerical_Recipes.
        r = r.wrapping_mul(1664525).wrapping_add(1013904223);
        // Lookup the whole range to get 50% hits and 50% misses.
        let i = mapper(r % size);
        black_box(binary_search_117722(&v, &i).is_ok());
    })
}
fn bench_binary_search_117722_worst_case(b: &mut Bencher, cache: Cache) {
    let size = match cache {
        Cache::L1 => 1000,      // 8 KB
        Cache::L2 => 10_000,    // 80 KB
        Cache::L3 => 1_000_000, // 8 MB
    };
    let mut v = vec![0; size];
    let i = 1;
    v[size - 1] = i;
    b.iter(move || {
        black_box(binary_search_117722(&v, &i).is_ok());
    })
}
#[bench]
fn binary_search_117722_l1(b: &mut Bencher) {
    bench_binary_search_117722(b, Cache::L1, |i| i * 2);
}
#[bench]
fn binary_search_117722_l2(b: &mut Bencher) {
    bench_binary_search_117722(b, Cache::L2, |i| i * 2);
}
#[bench]
fn binary_search_117722_l3(b: &mut Bencher) {
    bench_binary_search_117722(b, Cache::L3, |i| i * 2);
}
#[bench]
fn binary_search_117722_l1_with_dups(b: &mut Bencher) {
    bench_binary_search_117722(b, Cache::L1, |i| i / 16 * 16);
}
#[bench]
fn binary_search_117722_l2_with_dups(b: &mut Bencher) {
    bench_binary_search_117722(b, Cache::L2, |i| i / 16 * 16);
}
#[bench]
fn binary_search_117722_l3_with_dups(b: &mut Bencher) {
    bench_binary_search_117722(b, Cache::L3, |i| i / 16 * 16);
}
#[bench]
fn binary_search_117722_l1_worst_case(b: &mut Bencher) {
    bench_binary_search_117722_worst_case(b, Cache::L1);
}
#[bench]
fn binary_search_117722_l2_worst_case(b: &mut Bencher) {
    bench_binary_search_117722_worst_case(b, Cache::L2);
}
#[bench]
fn binary_search_117722_l3_worst_case(b: &mut Bencher) {
    bench_binary_search_117722_worst_case(b, Cache::L3);
}
Cargo.toml
[package]
name = "binary_search_bench"
version = "0.1.0"
edition = "2021"
[profile.bench]
opt-level = 3
lto = true
codegen-units = 1
panic = "abort"
strip = true
debug = false
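When comparing several implementations like the ones in lib.rs, it can help to confirm that each one returns a valid answer before trusting the timings. The sketch below is a hypothetical helper, not part of the original crate. It encodes the contract documented for slice::binary_search (on a hit, any matching index may be returned; on a miss, the Err value is the insertion point) instead of requiring identical Ok indices, because the variants may pick different matches when the input has duplicates, as in the *_with_dups benches.

```rust
/// Hypothetical test helper: checks that `res` is a valid answer for
/// searching `x` in the sorted slice `s`, following the documented
/// `slice::binary_search` contract. Any matching `Ok` index is accepted.
fn is_valid_search_result<T: Ord>(s: &[T], x: &T, res: Result<usize, usize>) -> bool {
    match res {
        // A hit must point at an element equal to `x`.
        Ok(i) => i < s.len() && s[i] == *x,
        // A miss must be the insertion point: everything before `i` is
        // less than `x`, everything from `i` on is greater (so `x` is absent).
        Err(i) => {
            i <= s.len()
                && s[..i].iter().all(|p| p < x)
                && s[i..].iter().all(|p| p > x)
        }
    }
}

/// Hypothetical driver: runs `search` for every query in `0..=max` on the
/// sorted input `s` and returns the first query with an invalid result.
fn check_search<F>(s: &[usize], max: usize, search: F) -> Option<usize>
where
    F: Fn(&[usize], &usize) -> Result<usize, usize>,
{
    (0..=max).find(|x| !is_valid_search_result(s, x, search(s, x)))
}
```

With lib.rs in scope, a call such as check_search(&v, 2000, |s, x| old_binary_search(s, x)) should return None for a correct implementation. The check is quadratic in the worst case, so it is meant for the small (L1-sized) inputs, not the 1,000,000-element ones.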
See #128254 for a potential fix.
Rewrite binary search implementation
This PR builds on top of rust-lang#128250, which should be merged first. This restores the original binary search implementation from rust-lang#45333, which has the nice property of having a loop count that only depends on the size of the slice. This, along with explicit conditional moves from rust-lang#128250, means that the entire binary search loop can be perfectly predicted by the branch predictor. Additionally, LLVM is able to unroll the loop when the slice length is known at compile time. This results in a very compact code sequence of 3-4 instructions per binary search step and zero branches.
Fixes rust-lang#53823
Fixes rust-lang#115271
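To illustrate the "loop count that only depends on the size of the slice" property, here is a minimal sketch assuming the restored loop has the same shape as old_binary_search_by in lib.rs above (the #45333 version). Both functions are hypothetical instrumented copies that also return how many times the loop body ran; they use safe indexing instead of get_unchecked, and whether the select compiles to a conditional move depends on the compiler and target, so that part is not guaranteed here.

```rust
use std::cmp::Ordering;

/// Fixed-trip-count loop (same shape as `old_binary_search_by`),
/// instrumented to also return the number of loop iterations.
fn search_fixed_trip<T: Ord>(s: &[T], x: &T) -> (Result<usize, usize>, u32) {
    let mut size = s.len();
    if size == 0 {
        return (Err(0), 0);
    }
    let mut base = 0usize;
    let mut steps = 0u32;
    // The number of iterations depends only on `s.len()`, never on `x`.
    while size > 1 {
        let half = size / 2;
        let mid = base + half;
        // A data-dependent select rather than an early return; the compiler
        // may lower this to a conditional move.
        base = if s[mid] > *x { base } else { mid };
        size -= half;
        steps += 1;
    }
    let res = match s[base].cmp(x) {
        Ordering::Equal => Ok(base),
        Ordering::Less => Err(base + 1),
        Ordering::Greater => Err(base),
    };
    (res, steps)
}

/// Early-exit loop (same shape as `new_binary_search_by`), instrumented the
/// same way: a hit can return early, so the count depends on `x`.
fn search_early_exit<T: Ord>(s: &[T], x: &T) -> (Result<usize, usize>, u32) {
    let (mut left, mut right) = (0usize, s.len());
    let mut steps = 0u32;
    while left < right {
        let mid = left + (right - left) / 2;
        steps += 1;
        match s[mid].cmp(x) {
            Ordering::Less => left = mid + 1,
            Ordering::Greater => right = mid,
            Ordering::Equal => return (Ok(mid), steps),
        }
    }
    (Err(left), steps)
}
```

For a 1000-element slice, search_fixed_trip runs its loop exactly 10 times for every query, while search_early_exit can stop after a single iteration on a lucky hit. The unpredictable exit branch in the second version is what the fixed-trip-count design avoids.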
#128254 (MacBook M3): I just tested it locally, and it is undoubtedly better than any previous binary_search implementation.
This affects my real project.
binary_search on [usize] seems to cause a huge performance regression. After investigation, it is caused by the new binary_search introduced in PR #74024.
Why is the new binary_search not binary_search_unstable?
Environment: macOS Pro M2
lib.rs
bench.rs