[heapsort] Protect against integer overflow
(Firstly, I changed `n` to `b`, as that is less confusing. It's not a length, it's a right boundary.)

The invariant maintained is `cur < b`, so `cur` is at most `b - 1`. In the worst case (`a = 0`), the child index `2*cur + 1` can therefore be as large as `2b - 1`. Since `2b - 1` is not guaranteed to fit within `maxInt`, we have to add an overflow check to `siftDown` to avoid undefined behavior.
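
The bound above can be illustrated with a small sketch. The helper name `leftChild` is hypothetical and not part of this commit. It computes the same child index as `siftDown`, but returns `null` instead of breaking out of a loop. It also checks the two additions, which is more conservative than the diff and is an assumption of this sketch rather than something the commit does.

```zig
const std = @import("std");

/// Hypothetical helper (not in the commit): index of the left child of `cur`
/// in a heap stored in the range `[a, b)`, or `null` if the index does not
/// fit in a `usize`.
fn leftChild(a: usize, cur: usize) ?usize {
    // `cur - a` cannot underflow as long as `a <= cur`.
    std.debug.assert(a <= cur);

    // `std.math.mul` returns `error.Overflow` instead of wrapping or trapping.
    const doubled = std.math.mul(usize, cur - a, 2) catch return null;

    // Assumption of this sketch: check the additions as well, rather than
    // relying on an argument that they cannot overflow.
    const offset = std.math.add(usize, doubled, 1) catch return null;
    return std.math.add(usize, offset, a) catch return null;
}
```

For example, `leftChild(3, 5)` evaluates `2 * (5 - 3) + 1 + 3` and yields `8`, while `leftChild(0, maxInt(usize) / 2 + 1)` yields `null` because the multiplication overflows.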

LLVM also seems to have a nicer time compiling this version of the function. It is about 2x faster in my tests (I think LLVM was stumped by the `child += @intFromBool` line), and adding/removing the overflow check has a negligible performance difference on my machine. Of course, we could check `2b <= maxInt` in the parent function, and dispatch to a version of the function without the overflow check in the common case, but that probably is not worth the code size just to eliminate a single instruction.
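
The dispatch idea mentioned above might look roughly like the following sketch. All names here (`doubleFits`, `leftChildFromZero`, `leftChildDispatch`) are hypothetical, and `a = 0` is assumed for brevity. The check is written as `b <= maxInt / 2` because computing `2b` directly could itself overflow.

```zig
const std = @import("std");

/// Hypothetical predicate: true when `2 * b` is guaranteed to fit in a `usize`.
fn doubleFits(b: usize) bool {
    return b <= std.math.maxInt(usize) / 2;
}

/// Hypothetical child computation for `a = 0`, specialised at compile time.
fn leftChildFromZero(comptime checked: bool, cur: usize) ?usize {
    if (checked) {
        const doubled = std.math.mul(usize, cur, 2) catch return null;
        return std.math.add(usize, doubled, 1) catch return null;
    }
    // Unchecked path: only valid when the caller has established
    // `cur < b` and `doubleFits(b)`.
    return 2 * cur + 1;
}

/// Hypothetical dispatcher: picks the unchecked version in the common case.
fn leftChildDispatch(cur: usize, b: usize) ?usize {
    std.debug.assert(cur < b);
    return if (doubleFits(b))
        leftChildFromZero(false, cur)
    else
        leftChildFromZero(true, cur);
}
```

Each value of the `comptime` flag produces a separate instantiation, which is the code-size cost the message above considers not worth paying to eliminate a single instruction.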
Validark authored Jun 22, 2023
1 parent c608967 commit 7d511d6
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions lib/std/sort.zig
@@ -36,6 +36,8 @@ pub fn insertion(
 /// O(1) memory (no allocator required).
 /// Sorts in ascending order with respect to the given `lessThan` function.
 pub fn insertionContext(a: usize, b: usize, context: anytype) void {
+    assert(a <= b);
+
     var i = a + 1;
     while (i < b) : (i += 1) {
         var j = i;
@@ -73,6 +75,7 @@ pub fn heap(
 /// O(1) memory (no allocator required).
 /// Sorts in ascending order with respect to the given `lessThan` function.
 pub fn heapContext(a: usize, b: usize, context: anytype) void {
+    assert(a <= b);
     // build the heap in linear time.
     var i = a + (b - a) / 2;
     while (i > a) {
@@ -89,22 +92,33 @@ pub fn heapContext(a: usize, b: usize, context: anytype) void {
     }
 }

-fn siftDown(a: usize, root: usize, n: usize, context: anytype) void {
-    var node = root;
+fn siftDown(a: usize, target: usize, b: usize, context: anytype) void {
+    var cur = target;
     while (true) {
-        var child = a + 2 * (node - a) + 1;
-        if (child >= n) break;
+        // When we don't overflow from the multiply below, the following expression equals (2*cur) - (2*a) + a + 1
+        // The `+ a + 1` is safe because:
+        //  for `a > 0` then `2a >= a + 1`.
+        //  for `a = 0`, the expression equals `2*cur+1`. `2*cur` is an even number, therefore adding 1 is safe.
+        var child = (math.mul(usize, cur - a, 2) catch break) + a + 1;
+
+        // stop if we overshot the boundary
+        if (!(child < b)) break;

-        // choose the greater child.
-        child += @intFromBool(child + 1 < n and context.lessThan(child, child + 1));
+        // `next_child` is at most `b`, therefore no overflow is possible
+        const next_child = child + 1;
+
+        // store the greater child in `child`
+        if (next_child < b and context.lessThan(child, next_child)) {
+            child = next_child;
+        }

-        // stop if the invariant holds at `node`.
-        if (!context.lessThan(node, child)) break;
+        // stop if the Heap invariant holds at `cur`.
+        if (context.lessThan(child, cur)) break;

-        // swap `node` with the greater child,
+        // swap `cur` with the greater child,
         // move one step down, and continue sifting.
-        context.swap(node, child);
-        node = child;
+        context.swap(child, cur);
+        cur = child;
     }
 }

