Skip to content

Commit 3530bfc

Browse files
edwloefgitbot
authored and
gitbot
committed
optimize slice::ptr_rotate for compile-time-constant small rotates
1 parent c244dfb commit 3530bfc

File tree

1 file changed

+166
-161
lines changed

1 file changed

+166
-161
lines changed

core/src/slice/rotate.rs

+166-161
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,18 @@ use crate::{cmp, ptr};
1111
///
1212
/// # Algorithm
1313
///
14-
/// Algorithm 1 is used for small values of `left + right` or for large `T`. The elements are moved
15-
/// into their final positions one at a time starting at `mid - left` and advancing by `right` steps
16-
/// modulo `left + right`, such that only one temporary is needed. Eventually, we arrive back at
17-
/// `mid - left`. However, if `gcd(left + right, right)` is not 1, the above steps skipped over
18-
/// elements. For example:
14+
/// Algorithm 1 is used if `min(left, right)` is small enough to fit onto a stack buffer. The
15+
/// `min(left, right)` elements are copied onto the buffer, `memmove` is applied to the others, and
16+
/// the ones on the buffer are moved back into the hole on the opposite side of where they
17+
/// originated.
18+
///
19+
/// Algorithms that can be vectorized outperform the above once `left + right` becomes large enough.
20+
///
21+
/// Algorithm 2 is otherwise used for small values of `left + right` or for large `T`. The elements
22+
/// are moved into their final positions one at a time starting at `mid - left` and advancing by
23+
/// `right` steps modulo `left + right`, such that only one temporary is needed. Eventually, we
24+
/// arrive back at `mid - left`. However, if `gcd(left + right, right)` is not 1, the above steps
25+
/// skipped over elements. For example:
1926
/// ```text
2027
/// left = 10, right = 6
2128
/// the `^` indicates an element in its final place
@@ -39,13 +46,7 @@ use crate::{cmp, ptr};
3946
/// `gcd(left + right, right)` value). The end result is that all elements are finalized once and
4047
/// only once.
4148
///
42-
/// Algorithm 2 is used if `left + right` is large but `min(left, right)` is small enough to
43-
/// fit onto a stack buffer. The `min(left, right)` elements are copied onto the buffer, `memmove`
44-
/// is applied to the others, and the ones on the buffer are moved back into the hole on the
45-
/// opposite side of where they originated.
46-
///
47-
/// Algorithms that can be vectorized outperform the above once `left + right` becomes large enough.
48-
/// Algorithm 1 can be vectorized by chunking and performing many rounds at once, but there are too
49+
/// Algorithm 2 can be vectorized by chunking and performing many rounds at once, but there are too
4950
/// few rounds on average until `left + right` is enormous, and the worst case of a single
5051
/// round is always there. Instead, algorithm 3 utilizes repeated swapping of
5152
/// `min(left, right)` elements until a smaller rotate problem is left.
@@ -65,172 +66,176 @@ pub(super) unsafe fn ptr_rotate<T>(mut left: usize, mut mid: *mut T, mut right:
6566
if T::IS_ZST {
6667
return;
6768
}
68-
loop {
69-
// N.B. the below algorithms can fail if these cases are not checked
70-
if (right == 0) || (left == 0) {
71-
return;
69+
// N.B. the below algorithms can fail if these cases are not checked
70+
if (right == 0) || (left == 0) {
71+
return;
72+
}
73+
// `T` is not a zero-sized type, so it's okay to divide by its size.
74+
if !cfg!(feature = "optimize_for_size")
75+
&& cmp::min(left, right) <= mem::size_of::<BufType>() / mem::size_of::<T>()
76+
{
77+
// Algorithm 1
78+
// The `[T; 0]` here is to ensure this is appropriately aligned for T
79+
let mut rawarray = MaybeUninit::<(BufType, [T; 0])>::uninit();
80+
let buf = rawarray.as_mut_ptr() as *mut T;
81+
// SAFETY: `mid-left <= mid-left+right < mid+right`
82+
let dim = unsafe { mid.sub(left).add(right) };
83+
if left <= right {
84+
// SAFETY:
85+
//
86+
// 1) The `if` condition about the sizes ensures `[mid-left; left]` will fit in
87+
// `buf` without overflow and `buf` was created just above and so cannot be
88+
// overlapped with any value of `[mid-left; left]`
89+
// 2) [mid-left, mid+right) are all valid for reading and writing and we don't care
90+
// about overlaps here.
91+
// 3) The `if` condition about `left <= right` ensures writing `left` elements to
92+
// `dim = mid-left+right` is valid because:
93+
// - `buf` is valid and `left` elements were written in it in 1)
94+
// - `dim+left = mid-left+right+left = mid+right` and we write `[dim, dim+left)`
95+
unsafe {
96+
// 1)
97+
ptr::copy_nonoverlapping(mid.sub(left), buf, left);
98+
// 2)
99+
ptr::copy(mid, mid.sub(left), right);
100+
// 3)
101+
ptr::copy_nonoverlapping(buf, dim, left);
102+
}
103+
} else {
104+
// SAFETY: same reasoning as above but with `left` and `right` reversed
105+
unsafe {
106+
ptr::copy_nonoverlapping(mid, buf, right);
107+
ptr::copy(mid.sub(left), dim, left);
108+
ptr::copy_nonoverlapping(buf, mid.sub(left), right);
109+
}
72110
}
73-
if !cfg!(feature = "optimize_for_size")
74-
&& ((left + right < 24) || (mem::size_of::<T>() > mem::size_of::<[usize; 4]>()))
75-
{
76-
// Algorithm 1
77-
// Microbenchmarks indicate that the average performance for random shifts is better all
78-
// the way until about `left + right == 32`, but the worst case performance breaks even
79-
// around 16. 24 was chosen as middle ground. If the size of `T` is larger than 4
80-
// `usize`s, this algorithm also outperforms other algorithms.
81-
// SAFETY: callers must ensure `mid - left` is valid for reading and writing.
82-
let x = unsafe { mid.sub(left) };
83-
// beginning of first round
84-
// SAFETY: see previous comment.
85-
let mut tmp: T = unsafe { x.read() };
86-
let mut i = right;
87-
// `gcd` can be found before hand by calculating `gcd(left + right, right)`,
88-
// but it is faster to do one loop which calculates the gcd as a side effect, then
89-
// doing the rest of the chunk
90-
let mut gcd = right;
91-
// benchmarks reveal that it is faster to swap temporaries all the way through instead
92-
// of reading one temporary once, copying backwards, and then writing that temporary at
93-
// the very end. This is possibly due to the fact that swapping or replacing temporaries
94-
// uses only one memory address in the loop instead of needing to manage two.
111+
} else if !cfg!(feature = "optimize_for_size")
112+
&& ((left + right < 24) || (mem::size_of::<T>() > mem::size_of::<[usize; 4]>()))
113+
{
114+
// Algorithm 2
115+
// Microbenchmarks indicate that the average performance for random shifts is better all
116+
// the way until about `left + right == 32`, but the worst case performance breaks even
117+
// around 16. 24 was chosen as middle ground. If the size of `T` is larger than 4
118+
// `usize`s, this algorithm also outperforms other algorithms.
119+
// SAFETY: callers must ensure `mid - left` is valid for reading and writing.
120+
let x = unsafe { mid.sub(left) };
121+
// beginning of first round
122+
// SAFETY: see previous comment.
123+
let mut tmp: T = unsafe { x.read() };
124+
let mut i = right;
125+
// `gcd` can be found beforehand by calculating `gcd(left + right, right)`,
126+
// but it is faster to do one loop which calculates the gcd as a side effect, then
127+
// doing the rest of the chunk
128+
let mut gcd = right;
129+
// benchmarks reveal that it is faster to swap temporaries all the way through instead
130+
// of reading one temporary once, copying backwards, and then writing that temporary at
131+
// the very end. This is possibly due to the fact that swapping or replacing temporaries
132+
// uses only one memory address in the loop instead of needing to manage two.
133+
loop {
134+
// [long-safety-expl]
135+
// SAFETY: callers must ensure `[mid-left, mid+right)` are all valid for reading and
136+
// writing.
137+
//
138+
// - `i` starts at `right` so `mid-left <= x+i = x+right = mid-left+right < mid+right`
139+
// - `i <= left+right-1` is always true
140+
// - if `i < left`, `right` is added so `i < left+right` and on the next
141+
// iteration `left` is removed from `i` so it doesn't go further
142+
// - if `i >= left`, `left` is removed immediately and so it doesn't go further.
143+
// - overflows cannot happen for `i` since the function's safety contract asks for
144+
// `mid+right-1 = x+left+right-1` to be valid for writing
145+
// - underflows cannot happen because `i` must be bigger or equal to `left` for
146+
// a subtraction of `left` to happen.
147+
//
148+
// So `x+i` is valid for reading and writing if the caller respected the contract
149+
tmp = unsafe { x.add(i).replace(tmp) };
150+
// instead of incrementing `i` and then checking if it is outside the bounds, we
151+
// check if `i` will go outside the bounds on the next increment. This prevents
152+
// any wrapping of pointers or `usize`.
153+
if i >= left {
154+
i -= left;
155+
if i == 0 {
156+
// end of first round
157+
// SAFETY: tmp has been read from a valid source and x is valid for writing
158+
// according to the caller.
159+
unsafe { x.write(tmp) };
160+
break;
161+
}
162+
// this conditional must be here if `left + right >= 15`
163+
if i < gcd {
164+
gcd = i;
165+
}
166+
} else {
167+
i += right;
168+
}
169+
}
170+
// finish the chunk with more rounds
171+
for start in 1..gcd {
172+
// SAFETY: `gcd` is at most equal to `right` so all values in `1..gcd` are valid for
173+
// reading and writing as per the function's safety contract, see [long-safety-expl]
174+
// above
175+
tmp = unsafe { x.add(start).read() };
176+
// [safety-expl-addition]
177+
//
178+
// Here `start < gcd` so `start < right` so `i < right+right`: `right` being the
179+
// greatest common divisor of `(left+right, right)` means that `left = right` so
180+
// `i < left+right` so `x+i = mid-left+i` is always valid for reading and writing
181+
// according to the function's safety contract.
182+
i = start + right;
95183
loop {
96-
// [long-safety-expl]
97-
// SAFETY: callers must ensure `[left, left+mid+right)` are all valid for reading and
98-
// writing.
99-
//
100-
// - `i` start with `right` so `mid-left <= x+i = x+right = mid-left+right < mid+right`
101-
// - `i <= left+right-1` is always true
102-
// - if `i < left`, `right` is added so `i < left+right` and on the next
103-
// iteration `left` is removed from `i` so it doesn't go further
104-
// - if `i >= left`, `left` is removed immediately and so it doesn't go further.
105-
// - overflows cannot happen for `i` since the function's safety contract ask for
106-
// `mid+right-1 = x+left+right` to be valid for writing
107-
// - underflows cannot happen because `i` must be bigger or equal to `left` for
108-
// a subtraction of `left` to happen.
109-
//
110-
// So `x+i` is valid for reading and writing if the caller respected the contract
184+
// SAFETY: see [long-safety-expl] and [safety-expl-addition]
111185
tmp = unsafe { x.add(i).replace(tmp) };
112-
// instead of incrementing `i` and then checking if it is outside the bounds, we
113-
// check if `i` will go outside the bounds on the next increment. This prevents
114-
// any wrapping of pointers or `usize`.
115186
if i >= left {
116187
i -= left;
117-
if i == 0 {
118-
// end of first round
119-
// SAFETY: tmp has been read from a valid source and x is valid for writing
120-
// according to the caller.
121-
unsafe { x.write(tmp) };
188+
if i == start {
189+
// SAFETY: see [long-safety-expl] and [safety-expl-addition]
190+
unsafe { x.add(start).write(tmp) };
122191
break;
123192
}
124-
// this conditional must be here if `left + right >= 15`
125-
if i < gcd {
126-
gcd = i;
127-
}
128193
} else {
129194
i += right;
130195
}
131196
}
132-
// finish the chunk with more rounds
133-
for start in 1..gcd {
134-
// SAFETY: `gcd` is at most equal to `right` so all values in `1..gcd` are valid for
135-
// reading and writing as per the function's safety contract, see [long-safety-expl]
136-
// above
137-
tmp = unsafe { x.add(start).read() };
138-
// [safety-expl-addition]
139-
//
140-
// Here `start < gcd` so `start < right` so `i < right+right`: `right` being the
141-
// greatest common divisor of `(left+right, right)` means that `left = right` so
142-
// `i < left+right` so `x+i = mid-left+i` is always valid for reading and writing
143-
// according to the function's safety contract.
144-
i = start + right;
197+
}
198+
} else {
199+
loop {
200+
if left >= right {
201+
// Algorithm 3
202+
// There is an alternate way of swapping that involves finding where the last swap
203+
// of this algorithm would be, and swapping using that last chunk instead of swapping
204+
// adjacent chunks like this algorithm is doing, but this way is still faster.
145205
loop {
146-
// SAFETY: see [long-safety-expl] and [safety-expl-addition]
147-
tmp = unsafe { x.add(i).replace(tmp) };
148-
if i >= left {
149-
i -= left;
150-
if i == start {
151-
// SAFETY: see [long-safety-expl] and [safety-expl-addition]
152-
unsafe { x.add(start).write(tmp) };
153-
break;
154-
}
155-
} else {
156-
i += right;
206+
// SAFETY:
207+
// `left >= right` so `[mid-right, mid+right)` is valid for reading and writing
208+
// Subtracting `right` from `mid` each turn is counterbalanced by the addition and
209+
// check after it.
210+
unsafe {
211+
ptr::swap_nonoverlapping(mid.sub(right), mid, right);
212+
mid = mid.sub(right);
213+
}
214+
left -= right;
215+
if left < right {
216+
break;
157217
}
158-
}
159-
}
160-
return;
161-
// `T` is not a zero-sized type, so it's okay to divide by its size.
162-
} else if !cfg!(feature = "optimize_for_size")
163-
&& cmp::min(left, right) <= mem::size_of::<BufType>() / mem::size_of::<T>()
164-
{
165-
// Algorithm 2
166-
// The `[T; 0]` here is to ensure this is appropriately aligned for T
167-
let mut rawarray = MaybeUninit::<(BufType, [T; 0])>::uninit();
168-
let buf = rawarray.as_mut_ptr() as *mut T;
169-
// SAFETY: `mid-left <= mid-left+right < mid+right`
170-
let dim = unsafe { mid.sub(left).add(right) };
171-
if left <= right {
172-
// SAFETY:
173-
//
174-
// 1) The `else if` condition about the sizes ensures `[mid-left; left]` will fit in
175-
// `buf` without overflow and `buf` was created just above and so cannot be
176-
// overlapped with any value of `[mid-left; left]`
177-
// 2) [mid-left, mid+right) are all valid for reading and writing and we don't care
178-
// about overlaps here.
179-
// 3) The `if` condition about `left <= right` ensures writing `left` elements to
180-
// `dim = mid-left+right` is valid because:
181-
// - `buf` is valid and `left` elements were written in it in 1)
182-
// - `dim+left = mid-left+right+left = mid+right` and we write `[dim, dim+left)`
183-
unsafe {
184-
// 1)
185-
ptr::copy_nonoverlapping(mid.sub(left), buf, left);
186-
// 2)
187-
ptr::copy(mid, mid.sub(left), right);
188-
// 3)
189-
ptr::copy_nonoverlapping(buf, dim, left);
190218
}
191219
} else {
192-
// SAFETY: same reasoning as above but with `left` and `right` reversed
193-
unsafe {
194-
ptr::copy_nonoverlapping(mid, buf, right);
195-
ptr::copy(mid.sub(left), dim, left);
196-
ptr::copy_nonoverlapping(buf, mid.sub(left), right);
197-
}
198-
}
199-
return;
200-
} else if left >= right {
201-
// Algorithm 3
202-
// There is an alternate way of swapping that involves finding where the last swap
203-
// of this algorithm would be, and swapping using that last chunk instead of swapping
204-
// adjacent chunks like this algorithm is doing, but this way is still faster.
205-
loop {
206-
// SAFETY:
207-
// `left >= right` so `[mid-right, mid+right)` is valid for reading and writing
208-
// Subtracting `right` from `mid` each turn is counterbalanced by the addition and
209-
// check after it.
210-
unsafe {
211-
ptr::swap_nonoverlapping(mid.sub(right), mid, right);
212-
mid = mid.sub(right);
213-
}
214-
left -= right;
215-
if left < right {
216-
break;
220+
// Algorithm 3, `left < right`
221+
loop {
222+
// SAFETY: `[mid-left, mid+left)` is valid for reading and writing because
223+
// `left < right` so `mid+left < mid+right`.
224+
// Adding `left` to `mid` each turn is counterbalanced by the subtraction and check
225+
// after it.
226+
unsafe {
227+
ptr::swap_nonoverlapping(mid.sub(left), mid, left);
228+
mid = mid.add(left);
229+
}
230+
right -= left;
231+
if right < left {
232+
break;
233+
}
217234
}
218235
}
219-
} else {
220-
// Algorithm 3, `left < right`
221-
loop {
222-
// SAFETY: `[mid-left, mid+left)` is valid for reading and writing because
223-
// `left < right` so `mid+left < mid+right`.
224-
// Adding `left` to `mid` each turn is counterbalanced by the subtraction and check
225-
// after it.
226-
unsafe {
227-
ptr::swap_nonoverlapping(mid.sub(left), mid, left);
228-
mid = mid.add(left);
229-
}
230-
right -= left;
231-
if right < left {
232-
break;
233-
}
236+
237+
if (right == 0) || (left == 0) {
238+
return;
234239
}
235240
}
236241
}

0 commit comments

Comments
 (0)