Skip to content

Commit b75f65a

Browse files
committed
Auto merge of #59078 - ssomers:btreeset_intersection_revisited, r=<try>
Improve worst-case performance of BTreeSet intersection. This is an alternative to [pull request #58577](#58577): it backs out of the attempts to optimize using ranges and results in more elegant code (I think). The stable public type `Intersection` changes from a struct to an enum. If that matters, then perhaps changing the fields, as the other proposal does, also mattered.
2 parents f8860f2 + a379504 commit b75f65a

File tree

3 files changed

+175
-72
lines changed

3 files changed

+175
-72
lines changed

src/liballoc/benches/btree/set.rs

+104-46
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,35 @@
11
use std::collections::BTreeSet;
2+
use std::collections::btree_set::Intersection;
23

34
use rand::{thread_rng, Rng};
45
use test::{black_box, Bencher};
56

6-
fn random(n1: u32, n2: u32) -> [BTreeSet<usize>; 2] {
7+
fn random(n1: usize, n2: usize) -> [BTreeSet<usize>; 2] {
78
let mut rng = thread_rng();
8-
let mut set1 = BTreeSet::new();
9-
let mut set2 = BTreeSet::new();
10-
for _ in 0..n1 {
11-
let i = rng.gen::<usize>();
12-
set1.insert(i);
13-
}
14-
for _ in 0..n2 {
15-
let i = rng.gen::<usize>();
16-
set2.insert(i);
9+
let mut sets = [BTreeSet::new(), BTreeSet::new()];
10+
for i in 0..2 {
11+
while sets[i].len() < [n1, n2][i] {
12+
sets[i].insert(rng.gen());
13+
}
1714
}
18-
[set1, set2]
15+
assert_eq!(sets[0].len(), n1);
16+
assert_eq!(sets[1].len(), n2);
17+
sets
1918
}
2019

21-
fn staggered(n1: u32, n2: u32) -> [BTreeSet<u32>; 2] {
22-
let mut even = BTreeSet::new();
23-
let mut odd = BTreeSet::new();
24-
for i in 0..n1 {
25-
even.insert(i * 2);
26-
}
27-
for i in 0..n2 {
28-
odd.insert(i * 2 + 1);
20+
fn stagger(n1: usize, factor: usize) -> [BTreeSet<u32>; 2] {
21+
let n2 = n1 * factor;
22+
let mut sets = [BTreeSet::new(), BTreeSet::new()];
23+
for i in 0..(n1 + n2) {
24+
let b = i % (factor + 1) != 0;
25+
sets[b as usize].insert(i as u32);
2926
}
30-
[even, odd]
27+
assert_eq!(sets[0].len(), n1);
28+
assert_eq!(sets[1].len(), n2);
29+
sets
3130
}
3231

33-
fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
32+
fn neg_vs_pos(n1: usize, n2: usize) -> [BTreeSet<i32>; 2] {
3433
let mut neg = BTreeSet::new();
3534
let mut pos = BTreeSet::new();
3635
for i in -(n1 as i32)..=-1 {
@@ -39,22 +38,38 @@ fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
3938
for i in 1..=(n2 as i32) {
4039
pos.insert(i);
4140
}
41+
assert_eq!(neg.len(), n1);
42+
assert_eq!(pos.len(), n2);
4243
[neg, pos]
4344
}
4445

45-
fn pos_vs_neg(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
46-
let mut neg = BTreeSet::new();
47-
let mut pos = BTreeSet::new();
48-
for i in -(n1 as i32)..=-1 {
49-
neg.insert(i);
46+
fn pos_vs_neg(n1: usize, n2: usize) -> [BTreeSet<i32>; 2] {
47+
let mut sets = neg_vs_pos(n2, n1);
48+
sets.reverse();
49+
assert_eq!(sets[0].len(), n1);
50+
assert_eq!(sets[1].len(), n2);
51+
sets
52+
}
53+
54+
fn intersection_search<T>(sets: &[BTreeSet<T>; 2]) -> Intersection<T>
55+
where T: std::cmp::Ord
56+
{
57+
Intersection::Search {
58+
a_iter: sets[0].iter(),
59+
b_set: &sets[1],
5060
}
51-
for i in 1..=(n2 as i32) {
52-
pos.insert(i);
61+
}
62+
63+
fn intersection_stitch<T>(sets: &[BTreeSet<T>; 2]) -> Intersection<T>
64+
where T: std::cmp::Ord
65+
{
66+
Intersection::Stitch {
67+
a_iter: sets[0].iter(),
68+
b_iter: sets[1].iter(),
5369
}
54-
[pos, neg]
5570
}
5671

57-
macro_rules! set_intersection_bench {
72+
macro_rules! intersection_bench {
5873
($name: ident, $sets: expr) => {
5974
#[bench]
6075
pub fn $name(b: &mut Bencher) {
@@ -68,21 +83,64 @@ macro_rules! set_intersection_bench {
6883
})
6984
}
7085
};
86+
($name: ident, $sets: expr, $intersection_kind: ident) => {
87+
#[bench]
88+
pub fn $name(b: &mut Bencher) {
89+
// setup
90+
let sets = $sets;
91+
assert!(sets[0].len() >= 1);
92+
assert!(sets[1].len() >= sets[0].len());
93+
94+
// measure
95+
b.iter(|| {
96+
let x = $intersection_kind(&sets).count();
97+
black_box(x);
98+
})
99+
}
100+
};
71101
}
72102

73-
set_intersection_bench! {intersect_random_100, random(100, 100)}
74-
set_intersection_bench! {intersect_random_10k, random(10_000, 10_000)}
75-
set_intersection_bench! {intersect_random_10_vs_10k, random(10, 10_000)}
76-
set_intersection_bench! {intersect_random_10k_vs_10, random(10_000, 10)}
77-
set_intersection_bench! {intersect_staggered_100, staggered(100, 100)}
78-
set_intersection_bench! {intersect_staggered_10k, staggered(10_000, 10_000)}
79-
set_intersection_bench! {intersect_staggered_10_vs_10k, staggered(10, 10_000)}
80-
set_intersection_bench! {intersect_staggered_10k_vs_10, staggered(10_000, 10)}
81-
set_intersection_bench! {intersect_neg_vs_pos_100, neg_vs_pos(100, 100)}
82-
set_intersection_bench! {intersect_neg_vs_pos_10k, neg_vs_pos(10_000, 10_000)}
83-
set_intersection_bench! {intersect_neg_vs_pos_10_vs_10k,neg_vs_pos(10, 10_000)}
84-
set_intersection_bench! {intersect_neg_vs_pos_10k_vs_10,neg_vs_pos(10_000, 10)}
85-
set_intersection_bench! {intersect_pos_vs_neg_100, pos_vs_neg(100, 100)}
86-
set_intersection_bench! {intersect_pos_vs_neg_10k, pos_vs_neg(10_000, 10_000)}
87-
set_intersection_bench! {intersect_pos_vs_neg_10_vs_10k,pos_vs_neg(10, 10_000)}
88-
set_intersection_bench! {intersect_pos_vs_neg_10k_vs_10,pos_vs_neg(10_000, 10)}
103+
intersection_bench! {intersect_100_neg_vs_100_pos, neg_vs_pos(100, 100)}
104+
intersection_bench! {intersect_100_neg_vs_10k_pos, neg_vs_pos(100, 10_000)}
105+
intersection_bench! {intersect_100_pos_vs_100_neg, pos_vs_neg(100, 100)}
106+
intersection_bench! {intersect_100_pos_vs_10k_neg, pos_vs_neg(100, 10_000)}
107+
intersection_bench! {intersect_10k_neg_vs_100_pos, neg_vs_pos(10_000, 100)}
108+
intersection_bench! {intersect_10k_neg_vs_10k_pos, neg_vs_pos(10_000, 10_000)}
109+
intersection_bench! {intersect_10k_pos_vs_100_neg, pos_vs_neg(10_000, 100)}
110+
intersection_bench! {intersect_10k_pos_vs_10k_neg, pos_vs_neg(10_000, 10_000)}
111+
intersection_bench! {intersect_random_100_vs_100_actual,random(100, 100)}
112+
intersection_bench! {intersect_random_100_vs_100_search,random(100, 100), intersection_search}
113+
intersection_bench! {intersect_random_100_vs_100_stitch,random(100, 100), intersection_stitch}
114+
intersection_bench! {intersect_random_100_vs_10k_actual,random(100, 10_000)}
115+
intersection_bench! {intersect_random_100_vs_10k_search,random(100, 10_000), intersection_search}
116+
intersection_bench! {intersect_random_100_vs_10k_stitch,random(100, 10_000), intersection_stitch}
117+
intersection_bench! {intersect_random_10k_vs_10k_actual,random(10_000, 10_000)}
118+
intersection_bench! {intersect_random_10k_vs_10k_search,random(10_000, 10_000), intersection_search}
119+
intersection_bench! {intersect_random_10k_vs_10k_stitch,random(10_000, 10_000), intersection_stitch}
120+
intersection_bench! {intersect_stagger_100_actual, stagger(100, 1)}
121+
intersection_bench! {intersect_stagger_100_search, stagger(100, 1), intersection_search}
122+
intersection_bench! {intersect_stagger_100_stitch, stagger(100, 1), intersection_stitch}
123+
intersection_bench! {intersect_stagger_10k_actual, stagger(10_000, 1)}
124+
intersection_bench! {intersect_stagger_10k_search, stagger(10_000, 1), intersection_search}
125+
intersection_bench! {intersect_stagger_10k_stitch, stagger(10_000, 1), intersection_stitch}
126+
intersection_bench! {intersect_stagger_1_actual, stagger(1, 1)}
127+
intersection_bench! {intersect_stagger_1_search, stagger(1, 1), intersection_search}
128+
intersection_bench! {intersect_stagger_1_stitch, stagger(1, 1), intersection_stitch}
129+
intersection_bench! {intersect_stagger_diff1_actual, stagger(100, 1 << 1)}
130+
intersection_bench! {intersect_stagger_diff1_search, stagger(100, 1 << 1), intersection_search}
131+
intersection_bench! {intersect_stagger_diff1_stitch, stagger(100, 1 << 1), intersection_stitch}
132+
intersection_bench! {intersect_stagger_diff2_actual, stagger(100, 1 << 2)}
133+
intersection_bench! {intersect_stagger_diff2_search, stagger(100, 1 << 2), intersection_search}
134+
intersection_bench! {intersect_stagger_diff2_stitch, stagger(100, 1 << 2), intersection_stitch}
135+
intersection_bench! {intersect_stagger_diff3_actual, stagger(100, 1 << 3)}
136+
intersection_bench! {intersect_stagger_diff3_search, stagger(100, 1 << 3), intersection_search}
137+
intersection_bench! {intersect_stagger_diff3_stitch, stagger(100, 1 << 3), intersection_stitch}
138+
intersection_bench! {intersect_stagger_diff4_actual, stagger(100, 1 << 4)}
139+
intersection_bench! {intersect_stagger_diff4_search, stagger(100, 1 << 4), intersection_search}
140+
intersection_bench! {intersect_stagger_diff4_stitch, stagger(100, 1 << 4), intersection_stitch}
141+
intersection_bench! {intersect_stagger_diff5_actual, stagger(100, 1 << 5)}
142+
intersection_bench! {intersect_stagger_diff5_search, stagger(100, 1 << 5), intersection_search}
143+
intersection_bench! {intersect_stagger_diff5_stitch, stagger(100, 1 << 5), intersection_stitch}
144+
intersection_bench! {intersect_stagger_diff6_actual, stagger(100, 1 << 6)}
145+
intersection_bench! {intersect_stagger_diff6_search, stagger(100, 1 << 6), intersection_search}
146+
intersection_bench! {intersect_stagger_diff6_stitch, stagger(100, 1 << 6), intersection_stitch}

src/liballoc/benches/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![feature(repr_simd)]
22
#![feature(test)]
3+
#![feature(benches_btree_set)]
34

45
extern crate test;
56

src/liballoc/collections/btree/set.rs

+70-26
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
use core::borrow::Borrow;
55
use core::cmp::Ordering::{self, Less, Greater, Equal};
6-
use core::cmp::{min, max};
6+
use core::cmp::max;
77
use core::fmt::{self, Debug};
88
use core::iter::{Peekable, FromIterator, FusedIterator};
99
use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds};
@@ -163,18 +163,34 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
163163
/// [`BTreeSet`]: struct.BTreeSet.html
164164
/// [`intersection`]: struct.BTreeSet.html#method.intersection
165165
#[stable(feature = "rust1", since = "1.0.0")]
166-
pub struct Intersection<'a, T: 'a> {
167-
a: Peekable<Iter<'a, T>>,
168-
b: Peekable<Iter<'a, T>>,
166+
pub enum Intersection<'a, T: 'a> {
167+
#[doc(hidden)]
168+
#[unstable(feature = "benches_btree_set", reason = "benchmarks for pull #58577", issue = "0")]
169+
Stitch {
170+
a_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets
171+
b_iter: Iter<'a, T>,
172+
},
173+
#[doc(hidden)]
174+
#[unstable(feature = "benches_btree_set", reason = "benchmarks for pull #58577", issue = "0")]
175+
Search {
176+
a_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets
177+
b_set: &'a BTreeSet<T>,
178+
},
169179
}
170180

171181
#[stable(feature = "collection_debug", since = "1.17.0")]
172182
impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
173183
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174-
f.debug_tuple("Intersection")
175-
.field(&self.a)
176-
.field(&self.b)
177-
.finish()
184+
match self {
185+
Intersection::Stitch { a_iter, b_iter } => f
186+
.debug_tuple("Intersection")
187+
.field(&a_iter)
188+
.field(&b_iter)
189+
.finish(),
190+
Intersection::Search { a_iter, b_set: _ } => {
191+
f.debug_tuple("Intersection").field(&a_iter).finish()
192+
}
193+
}
178194
}
179195
}
180196

@@ -326,9 +342,22 @@ impl<T: Ord> BTreeSet<T> {
326342
/// ```
327343
#[stable(feature = "rust1", since = "1.0.0")]
328344
pub fn intersection<'a>(&'a self, other: &'a BTreeSet<T>) -> Intersection<'a, T> {
329-
Intersection {
330-
a: self.iter().peekable(),
331-
b: other.iter().peekable(),
345+
let (a_set, b_set) = if self.len() <= other.len() {
346+
(self, other)
347+
} else {
348+
(other, self)
349+
};
350+
if a_set.len() > b_set.len() / 16 {
351+
Intersection::Stitch {
352+
a_iter: a_set.iter(),
353+
b_iter: b_set.iter(),
354+
}
355+
} else {
356+
// Iterate small set only and find matches in large set.
357+
Intersection::Search {
358+
a_iter: a_set.iter(),
359+
b_set,
360+
}
332361
}
333362
}
334363

@@ -1072,9 +1101,15 @@ impl<T: Ord> FusedIterator for SymmetricDifference<'_, T> {}
10721101
#[stable(feature = "rust1", since = "1.0.0")]
10731102
impl<T> Clone for Intersection<'_, T> {
10741103
fn clone(&self) -> Self {
1075-
Intersection {
1076-
a: self.a.clone(),
1077-
b: self.b.clone(),
1104+
match self {
1105+
Intersection::Stitch { a_iter, b_iter } => Intersection::Stitch {
1106+
a_iter: a_iter.clone(),
1107+
b_iter: b_iter.clone(),
1108+
},
1109+
Intersection::Search { a_iter, b_set } => Intersection::Search {
1110+
a_iter: a_iter.clone(),
1111+
b_set,
1112+
},
10781113
}
10791114
}
10801115
}
@@ -1083,24 +1118,33 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
10831118
type Item = &'a T;
10841119

10851120
fn next(&mut self) -> Option<&'a T> {
1086-
loop {
1087-
match Ord::cmp(self.a.peek()?, self.b.peek()?) {
1088-
Less => {
1089-
self.a.next();
1090-
}
1091-
Equal => {
1092-
self.b.next();
1093-
return self.a.next();
1094-
}
1095-
Greater => {
1096-
self.b.next();
1121+
match self {
1122+
Intersection::Stitch { a_iter, b_iter } => {
1123+
let mut a_next = a_iter.next()?;
1124+
let mut b_next = b_iter.next()?;
1125+
loop {
1126+
match Ord::cmp(a_next, b_next) {
1127+
Less => a_next = a_iter.next()?,
1128+
Greater => b_next = b_iter.next()?,
1129+
Equal => return Some(a_next),
1130+
}
10971131
}
10981132
}
1133+
Intersection::Search { a_iter, b_set } => loop {
1134+
let a_next = a_iter.next()?;
1135+
if b_set.contains(&a_next) {
1136+
return Some(a_next);
1137+
}
1138+
},
10991139
}
11001140
}
11011141

11021142
fn size_hint(&self) -> (usize, Option<usize>) {
1103-
(0, Some(min(self.a.len(), self.b.len())))
1143+
let max_size = match self {
1144+
Intersection::Stitch { a_iter, .. } => a_iter.len(),
1145+
Intersection::Search { a_iter, .. } => a_iter.len(),
1146+
};
1147+
(0, Some(max_size))
11041148
}
11051149
}
11061150

0 commit comments

Comments
 (0)