Skip to content

Commit e7d15a0

Browse files
Voultaphergitbot
authored and
gitbot
committed
Use more explicit and reliable ptr select in sort impls
Using if ... with the intent to avoid branches can be surprising to readers and carries the risk of turning into jumps/branches generated by some future compiler version, breaking crucial optimizations. This commit replaces their usage with the explicit and IR annotated `bool::select_unpredictable`.
1 parent e78b619 commit e7d15a0

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

core/src/slice/sort/shared/smallsort.rs

+13-18
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ unsafe fn swap_if_less<T, F>(v_base: *mut T, a_pos: usize, b_pos: usize, is_less
387387
where
388388
F: FnMut(&T, &T) -> bool,
389389
{
390-
// SAFETY: the caller must guarantee that `a` and `b` each added to `v_base` yield valid
390+
// SAFETY: the caller must guarantee that `a_pos` and `b_pos` each added to `v_base` yield valid
391391
// pointers into `v_base`, and are properly aligned, and part of the same allocation.
392392
unsafe {
393393
let v_a = v_base.add(a_pos);
@@ -404,16 +404,16 @@ where
404404
// The equivalent code with a branch would be:
405405
//
406406
// if should_swap {
407-
// ptr::swap(left, right, 1);
407+
// ptr::swap(v_a, v_b, 1);
408408
// }
409409

410410
// The goal is to generate cmov instructions here.
411-
let left_swap = if should_swap { v_b } else { v_a };
412-
let right_swap = if should_swap { v_a } else { v_b };
411+
let v_a_swap = should_swap.select_unpredictable(v_b, v_a);
412+
let v_b_swap = should_swap.select_unpredictable(v_a, v_b);
413413

414-
let right_swap_tmp = ManuallyDrop::new(ptr::read(right_swap));
415-
ptr::copy(left_swap, v_a, 1);
416-
ptr::copy_nonoverlapping(&*right_swap_tmp, v_b, 1);
414+
let v_b_swap_tmp = ManuallyDrop::new(ptr::read(v_b_swap));
415+
ptr::copy(v_a_swap, v_a, 1);
416+
ptr::copy_nonoverlapping(&*v_b_swap_tmp, v_b, 1);
417417
}
418418
}
419419

@@ -640,26 +640,21 @@ pub unsafe fn sort4_stable<T, F: FnMut(&T, &T) -> bool>(
640640
// 1, 1 | c b a d
641641
let c3 = is_less(&*c, &*a);
642642
let c4 = is_less(&*d, &*b);
643-
let min = select(c3, c, a);
644-
let max = select(c4, b, d);
645-
let unknown_left = select(c3, a, select(c4, c, b));
646-
let unknown_right = select(c4, d, select(c3, b, c));
643+
let min = c3.select_unpredictable(c, a);
644+
let max = c4.select_unpredictable(b, d);
645+
let unknown_left = c3.select_unpredictable(a, c4.select_unpredictable(c, b));
646+
let unknown_right = c4.select_unpredictable(d, c3.select_unpredictable(b, c));
647647

648648
// Sort the last two unknown elements.
649649
let c5 = is_less(&*unknown_right, &*unknown_left);
650-
let lo = select(c5, unknown_right, unknown_left);
651-
let hi = select(c5, unknown_left, unknown_right);
650+
let lo = c5.select_unpredictable(unknown_right, unknown_left);
651+
let hi = c5.select_unpredictable(unknown_left, unknown_right);
652652

653653
ptr::copy_nonoverlapping(min, dst, 1);
654654
ptr::copy_nonoverlapping(lo, dst.add(1), 1);
655655
ptr::copy_nonoverlapping(hi, dst.add(2), 1);
656656
ptr::copy_nonoverlapping(max, dst.add(3), 1);
657657
}
658-
659-
#[inline(always)]
660-
fn select<T>(cond: bool, if_true: *const T, if_false: *const T) -> *const T {
661-
if cond { if_true } else { if_false }
662-
}
663658
}
664659

665660
/// SAFETY: The caller MUST guarantee that `v_base` is valid for 8 reads and

0 commit comments

Comments
 (0)