Skip to content

Don't emit divide-by-zero panic paths in StepBy::len #123564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 49 additions & 31 deletions library/core/src/iter/adapters/step_by.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{
intrinsics,
iter::{from_fn, TrustedLen, TrustedRandomAccess},
num::NonZeroUsize,
ops::{Range, Try},
};

Expand All @@ -22,7 +23,11 @@ pub struct StepBy<I> {
/// Additionally this type-dependent preprocessing means specialized implementations
/// cannot be used interchangeably.
iter: I,
step: usize,
/// This field is `step - 1`, aka the correct amount to pass to `nth` when iterating.
/// It MUST NOT be `usize::MAX`, as `unsafe` code depends on being able to add one
/// without the risk of overflow. (This is important so that length calculations
/// don't need to check for division-by-zero, for example.)
step_minus_one: usize,
first_take: bool,
}

Expand All @@ -31,7 +36,16 @@ impl<I> StepBy<I> {
pub(in crate::iter) fn new(iter: I, step: usize) -> StepBy<I> {
assert!(step != 0);
let iter = <I as SpecRangeSetup<I>>::setup(iter, step);
StepBy { iter, step: step - 1, first_take: true }
StepBy { iter, step_minus_one: step - 1, first_take: true }
}

/// The `step` that was originally passed to `Iterator::step_by(step)`,
/// aka `self.step_minus_one + 1`.
#[inline]
fn original_step(&self) -> NonZeroUsize {
// SAFETY: By type invariant, `step_minus_one` cannot be `MAX`, which
// means the addition cannot overflow and the result cannot be zero.
unsafe { NonZeroUsize::new_unchecked(intrinsics::unchecked_add(self.step_minus_one, 1)) }
}
}

Expand Down Expand Up @@ -81,8 +95,8 @@ where
// The zero-based index starting from the end of the iterator of the
// last element. Used in the `DoubleEndedIterator` implementation.
fn next_back_index(&self) -> usize {
let rem = self.iter.len() % (self.step + 1);
if self.first_take { if rem == 0 { self.step } else { rem - 1 } } else { rem }
let rem = self.iter.len() % self.original_step();
if self.first_take { if rem == 0 { self.step_minus_one } else { rem - 1 } } else { rem }
}
}

Expand Down Expand Up @@ -209,30 +223,30 @@ unsafe impl<I: Iterator> StepByImpl<I> for StepBy<I> {

#[inline]
default fn spec_next(&mut self) -> Option<I::Item> {
let step_size = if self.first_take { 0 } else { self.step };
let step_size = if self.first_take { 0 } else { self.step_minus_one };
self.first_take = false;
self.iter.nth(step_size)
}

#[inline]
default fn spec_size_hint(&self) -> (usize, Option<usize>) {
#[inline]
fn first_size(step: usize) -> impl Fn(usize) -> usize {
move |n| if n == 0 { 0 } else { 1 + (n - 1) / (step + 1) }
fn first_size(step: NonZeroUsize) -> impl Fn(usize) -> usize {
move |n| if n == 0 { 0 } else { 1 + (n - 1) / step }
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, Div<NonZeroUsize>, for just magically doing the right thing here 🎉

}

#[inline]
fn other_size(step: usize) -> impl Fn(usize) -> usize {
move |n| n / (step + 1)
fn other_size(step: NonZeroUsize) -> impl Fn(usize) -> usize {
move |n| n / step
}

let (low, high) = self.iter.size_hint();

if self.first_take {
let f = first_size(self.step);
let f = first_size(self.original_step());
(f(low), high.map(f))
} else {
let f = other_size(self.step);
let f = other_size(self.original_step());
(f(low), high.map(f))
}
}
Expand All @@ -247,10 +261,9 @@ unsafe impl<I: Iterator> StepByImpl<I> for StepBy<I> {
}
n -= 1;
}
// n and self.step are indices, we need to add 1 to get the amount of elements
// n and self.step_minus_one are indices, we need to add 1 to get the amount of elements
// When calling `.nth`, we need to subtract 1 again to convert back to an index
// step + 1 can't overflow because `.step_by` sets `self.step` to `step - 1`
let mut step = self.step + 1;
let mut step = self.original_step().get();
// n + 1 could overflow
// thus, if n is usize::MAX, instead of adding one, we call .nth(step)
if n == usize::MAX {
Expand Down Expand Up @@ -288,8 +301,11 @@ unsafe impl<I: Iterator> StepByImpl<I> for StepBy<I> {
R: Try<Output = Acc>,
{
#[inline]
fn nth<I: Iterator>(iter: &mut I, step: usize) -> impl FnMut() -> Option<I::Item> + '_ {
move || iter.nth(step)
fn nth<I: Iterator>(
iter: &mut I,
step_minus_one: usize,
) -> impl FnMut() -> Option<I::Item> + '_ {
move || iter.nth(step_minus_one)
}

if self.first_take {
Expand All @@ -299,16 +315,19 @@ unsafe impl<I: Iterator> StepByImpl<I> for StepBy<I> {
Some(x) => acc = f(acc, x)?,
}
}
from_fn(nth(&mut self.iter, self.step)).try_fold(acc, f)
from_fn(nth(&mut self.iter, self.step_minus_one)).try_fold(acc, f)
}

default fn spec_fold<Acc, F>(mut self, mut acc: Acc, mut f: F) -> Acc
where
F: FnMut(Acc, Self::Item) -> Acc,
{
#[inline]
fn nth<I: Iterator>(iter: &mut I, step: usize) -> impl FnMut() -> Option<I::Item> + '_ {
move || iter.nth(step)
fn nth<I: Iterator>(
iter: &mut I,
step_minus_one: usize,
) -> impl FnMut() -> Option<I::Item> + '_ {
move || iter.nth(step_minus_one)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I went through I also tried to be consistent about naming things step_minus_one when they're not the original step, since it was confusing to have the mix.

}

if self.first_take {
Expand All @@ -318,7 +337,7 @@ unsafe impl<I: Iterator> StepByImpl<I> for StepBy<I> {
Some(x) => acc = f(acc, x),
}
}
from_fn(nth(&mut self.iter, self.step)).fold(acc, f)
from_fn(nth(&mut self.iter, self.step_minus_one)).fold(acc, f)
}
}

Expand All @@ -336,7 +355,7 @@ unsafe impl<I: DoubleEndedIterator + ExactSizeIterator> StepByBackImpl<I> for St
// is out of bounds because the length of `self.iter` does not exceed
// `usize::MAX` (because `I: ExactSizeIterator`) and `nth_back` is
// zero-indexed
let n = n.saturating_mul(self.step + 1).saturating_add(self.next_back_index());
let n = n.saturating_mul(self.original_step().get()).saturating_add(self.next_back_index());
self.iter.nth_back(n)
}

Expand All @@ -348,16 +367,16 @@ unsafe impl<I: DoubleEndedIterator + ExactSizeIterator> StepByBackImpl<I> for St
#[inline]
fn nth_back<I: DoubleEndedIterator>(
iter: &mut I,
step: usize,
step_minus_one: usize,
) -> impl FnMut() -> Option<I::Item> + '_ {
move || iter.nth_back(step)
move || iter.nth_back(step_minus_one)
}

match self.next_back() {
None => try { init },
Some(x) => {
let acc = f(init, x)?;
from_fn(nth_back(&mut self.iter, self.step)).try_fold(acc, f)
from_fn(nth_back(&mut self.iter, self.step_minus_one)).try_fold(acc, f)
}
}
}
Expand All @@ -371,16 +390,16 @@ unsafe impl<I: DoubleEndedIterator + ExactSizeIterator> StepByBackImpl<I> for St
#[inline]
fn nth_back<I: DoubleEndedIterator>(
iter: &mut I,
step: usize,
step_minus_one: usize,
) -> impl FnMut() -> Option<I::Item> + '_ {
move || iter.nth_back(step)
move || iter.nth_back(step_minus_one)
}

match self.next_back() {
None => init,
Some(x) => {
let acc = f(init, x);
from_fn(nth_back(&mut self.iter, self.step)).fold(acc, f)
from_fn(nth_back(&mut self.iter, self.step_minus_one)).fold(acc, f)
}
}
}
Expand Down Expand Up @@ -424,8 +443,7 @@ macro_rules! spec_int_ranges {
fn spec_next(&mut self) -> Option<$t> {
// if a step size larger than the type has been specified fall back to
// t::MAX, in which case remaining will be at most 1.
// The `+ 1` can't overflow since the constructor substracted 1 from the original value.
let step = <$t>::try_from(self.step + 1).unwrap_or(<$t>::MAX);
let step = <$t>::try_from(self.original_step().get()).unwrap_or(<$t>::MAX);
let remaining = self.iter.end;
if remaining > 0 {
let val = self.iter.start;
Expand Down Expand Up @@ -474,7 +492,7 @@ macro_rules! spec_int_ranges {
{
// if a step size larger than the type has been specified fall back to
// t::MAX, in which case remaining will be at most 1.
let step = <$t>::try_from(self.step + 1).unwrap_or(<$t>::MAX);
let step = <$t>::try_from(self.original_step().get()).unwrap_or(<$t>::MAX);
let remaining = self.iter.end;
let mut acc = init;
let mut val = self.iter.start;
Expand All @@ -500,7 +518,7 @@ macro_rules! spec_int_ranges_r {
fn spec_next_back(&mut self) -> Option<Self::Item>
where Range<$t>: DoubleEndedIterator + ExactSizeIterator,
{
let step = (self.step + 1) as $t;
let step = self.original_step().get() as $t;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like how here the two steps in the previous code meant very different things, but since the let step is the width (not the "what to pass to nth") it's still called step.

let remaining = self.iter.end;
if remaining > 0 {
let start = self.iter.start;
Expand Down
26 changes: 26 additions & 0 deletions tests/codegen/step_by-overflow-checks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//@ compile-flags: -O

#![crate_type = "lib"]

use std::iter::StepBy;
use std::slice::Iter;

// The constructor for `StepBy` ensures we can never end up needing to do zero
// checks on denominators, so check that the code isn't emitting panic paths.

// CHECK-LABEL: @step_by_len_std
#[no_mangle]
pub fn step_by_len_std(x: &StepBy<Iter<i32>>) -> usize {
// CHECK-NOT: div_by_zero
// CHECK: udiv
// CHECK-NOT: div_by_zero
x.len()
}

// CHECK-LABEL: @step_by_len_naive
#[no_mangle]
pub fn step_by_len_naive(x: Iter<i32>, step_minus_one: usize) -> usize {
// CHECK: udiv
// CHECK: call{{.+}}div_by_zero
x.len() / (step_minus_one + 1)
}
Loading