[WIP] rewrite TrustedRandomAccess into two directional variants #120682

Draft: wants to merge 3 commits into master

1 change: 1 addition & 0 deletions library/alloc/src/lib.rs
@@ -160,6 +160,7 @@
#![feature(str_internals)]
#![feature(strict_provenance)]
#![feature(trusted_fused)]
#![feature(trusted_indexed_access)]
#![feature(trusted_len)]
#![feature(trusted_random_access)]
#![feature(try_trait_v2)]
271 changes: 175 additions & 96 deletions library/alloc/src/vec/in_place_collect.rs
@@ -80,7 +80,7 @@
//! # O(1) collect
//!
//! The main iteration itself is further specialized when the iterator implements
//! [`TrustedRandomAccessNoCoerce`] to let the optimizer see that it is a counted loop with a single
//! [`UncheckedIndexedIterator`] to let the optimizer see that it is a counted loop with a single
//! [induction variable]. This can turn some iterators into a noop, i.e. it reduces them from O(n) to
//! O(1). This particular optimization is quite fickle and doesn't always work, see [#79308]
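For illustration, a minimal sketch (hypothetical code, not from this PR) of the loop shape this buys: with unchecked indexed access, both the read and the write are derived from one index, so the trip count is a plain `len` and the optimizer can fold the whole loop away when the body is a no-op.

unsafe fn collect_counted<T: Copy>(src: &[T], dst: *mut T) -> usize {
    let len = src.len();
    for i in 0..len {
        // A single induction variable `i` drives both sides; there is no
        // separate source/destination cursor for the optimizer to track.
        unsafe { dst.add(i).write(*src.get_unchecked(i)) };
    }
    len
}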
//!
@@ -157,8 +157,10 @@
use crate::alloc::{handle_alloc_error, Global};
use core::alloc::Allocator;
use core::alloc::Layout;
use core::iter::{InPlaceIterable, SourceIter, TrustedRandomAccessNoCoerce};
use core::iter::UncheckedIndexedIterator;
use core::iter::{InPlaceIterable, SourceIter};
use core::marker::PhantomData;
use core::mem::needs_drop;
use core::mem::{self, ManuallyDrop, SizedTypeProperties};
use core::num::NonZero;
use core::ptr::{self, NonNull};
@@ -229,96 +231,105 @@ where
I: Iterator<Item = T> + InPlaceCollect,
<I as SourceIter>::Source: AsVecIntoIter,
{
default fn from_iter(mut iterator: I) -> Self {
default fn from_iter(iterator: I) -> Self {
// See "Layout constraints" section in the module documentation. We rely on const
// optimization here since these conditions currently cannot be expressed as trait bounds
if const { !in_place_collectible::<T, I::Src>(I::MERGE_BY, I::EXPAND_BY) } {
// fallback to more generic implementations
return SpecFromIterNested::from_iter(iterator);
}

let (src_buf, src_ptr, src_cap, mut dst_buf, dst_end, dst_cap) = unsafe {
let inner = iterator.as_inner().as_into_iter();
(
inner.buf.as_ptr(),
inner.ptr,
inner.cap,
inner.buf.as_ptr() as *mut T,
inner.end as *const T,
inner.cap * mem::size_of::<I::Src>() / mem::size_of::<T>(),
)
let fun = const {
if !in_place_collectible::<T, I::Src>(I::MERGE_BY, I::EXPAND_BY) {
SpecFromIterNested::<T, I>::from_iter
} else {
from_iter
}
};

// SAFETY: `dst_buf` and `dst_end` are the start and end of the buffer.
let len = unsafe { SpecInPlaceCollect::collect_in_place(&mut iterator, dst_buf, dst_end) };

let src = unsafe { iterator.as_inner().as_into_iter() };
// check if the SourceIter contract was upheld
// caveat: if it wasn't, we might not even make it to this point
debug_assert_eq!(src_buf, src.buf.as_ptr());
// check the InPlaceIterable contract. This is only possible if the iterator advanced the
// source pointer at all. If it uses unchecked access via TrustedRandomAccess
// then the source pointer will stay in its initial position and we can't use it as a reference
if src.ptr != src_ptr {
debug_assert!(
unsafe { dst_buf.add(len) as *const _ } <= src.ptr.as_ptr(),
"InPlaceIterable contract violation, write pointer advanced beyond read pointer"
);
}
fun(iterator)
}
}
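
The `let fun = const { .. }` above selects between the in-place and fallback implementations at compile time. A self-contained sketch of the same pattern, with hypothetical names (`Collectible`, `IN_PLACE`, `fast`, `slow`):

trait Collectible {
    const IN_PLACE: bool;
}

fn fast<T: Collectible>() -> &'static str {
    "in-place path"
}

fn slow<T: Collectible>() -> &'static str {
    "generic fallback"
}

fn dispatch<T: Collectible>() -> &'static str {
    // Evaluated at compile time per monomorphization; the branch not taken
    // never appears in the generated code for `dispatch::<T>`. Both arms
    // coerce to the common pointer type `fn() -> &'static str`.
    let fun = const {
        if T::IN_PLACE { fast::<T> } else { slow::<T> }
    };
    fun()
}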

// The ownership of the source allocation and the new `T` values is temporarily moved into `dst_guard`.
// This is safe because
// * `forget_allocation_drop_remaining` immediately forgets the allocation
// before any panic can occur in order to avoid any double free, and then proceeds to drop
// any remaining values at the tail of the source.
// * the shrink either panics without invalidating the allocation, aborts or
// succeeds. In the last case we disarm the guard.
//
// Note: This access to the source wouldn't be allowed by the TrustedRandomAccessNoCoerce
// contract (used by SpecInPlaceCollect below). But see the "O(1) collect" section in the
// module documentation for why this is ok anyway.
let dst_guard =
InPlaceDstDataSrcBufDrop { ptr: dst_buf, len, src_cap, src: PhantomData::<I::Src> };
src.forget_allocation_drop_remaining();

// Adjust the allocation if the source had a capacity in bytes that wasn't a multiple
// of the destination type size.
// Since the discrepancy should generally be small this should only result in some
// bookkeeping updates and no memmove.
if needs_realloc::<I::Src, T>(src_cap, dst_cap) {
let alloc = Global;
debug_assert_ne!(src_cap, 0);
debug_assert_ne!(dst_cap, 0);
unsafe {
// The old allocation exists, therefore it must have a valid layout.
let src_align = mem::align_of::<I::Src>();
let src_size = mem::size_of::<I::Src>().unchecked_mul(src_cap);
let old_layout = Layout::from_size_align_unchecked(src_size, src_align);

// The allocation must be equal or smaller for in-place iteration to be possible
// therefore the new layout must be ≤ the old one and therefore valid.
let dst_align = mem::align_of::<T>();
let dst_size = mem::size_of::<T>().unchecked_mul(dst_cap);
let new_layout = Layout::from_size_align_unchecked(dst_size, dst_align);

let result = alloc.shrink(
NonNull::new_unchecked(dst_buf as *mut u8),
old_layout,
new_layout,
);
let Ok(reallocated) = result else { handle_alloc_error(new_layout) };
dst_buf = reallocated.as_ptr() as *mut T;
}
} else {
debug_assert_eq!(src_cap * mem::size_of::<I::Src>(), dst_cap * mem::size_of::<T>());
}
fn from_iter<I, T>(mut iterator: I) -> Vec<T>
where
I: Iterator<Item = T> + InPlaceCollect,
<I as SourceIter>::Source: AsVecIntoIter,
{
let (src_buf, src_ptr, src_cap, mut dst_buf, dst_end, dst_cap) = unsafe {
let inner = iterator.as_inner().as_into_iter();
(
inner.buf.as_ptr(),
inner.ptr,
inner.cap,
inner.buf.as_ptr() as *mut T,
inner.end as *const T,
inner.cap * mem::size_of::<I::Src>() / mem::size_of::<T>(),
)
};

// SAFETY: `dst_buf` and `dst_end` are the start and end of the buffer.
let len = unsafe { SpecInPlaceCollect::collect_in_place(&mut iterator, dst_buf, dst_end) };

mem::forget(dst_guard);
let src = unsafe { iterator.as_inner().as_into_iter() };
// check if the SourceIter contract was upheld
// caveat: if it wasn't, we might not even make it to this point
debug_assert_eq!(src_buf, src.buf.as_ptr());
// check the InPlaceIterable contract. This is only possible if the iterator advanced the
// source pointer at all. If it uses unchecked access via UncheckedIndexedIterator
// and doesn't perform cleanup then the source pointer will stay in its initial position
// and we can't use it as a reference.
if src.ptr != src_ptr {
debug_assert!(
unsafe { dst_buf.add(len) as *const _ } <= src.ptr.as_ptr(),
"InPlaceIterable contract violation, write pointer advanced beyond read pointer"
);
}

// The ownership of the source allocation and the new `T` values is temporarily moved into `dst_guard`.
// This is safe because
// * `forget_allocation_drop_remaining` immediately forgets the allocation
// before any panic can occur in order to avoid any double free, and then proceeds to drop
// any remaining values at the tail of the source.
// * the shrink either panics without invalidating the allocation, aborts or
// succeeds. In the last case we disarm the guard.
//
// Note: This access to the source wouldn't be allowed by the UncheckedIndexedIterator
// contract (used by SpecInPlaceCollect below). But see the "O(1) collect" section in the
// module documentation for why this is ok anyway.
let dst_guard =
InPlaceDstDataSrcBufDrop { ptr: dst_buf, len, src_cap, src: PhantomData::<I::Src> };
src.forget_allocation_drop_remaining();

// Adjust the allocation if the source had a capacity in bytes that wasn't a multiple
// of the destination type size.
// Since the discrepancy should generally be small this should only result in some
// bookkeeping updates and no memmove.
if needs_realloc::<I::Src, T>(src_cap, dst_cap) {
let alloc = Global;
debug_assert_ne!(src_cap, 0);
debug_assert_ne!(dst_cap, 0);
unsafe {
// The old allocation exists, therefore it must have a valid layout.
let src_align = mem::align_of::<I::Src>();
let src_size = mem::size_of::<I::Src>().unchecked_mul(src_cap);
let old_layout = Layout::from_size_align_unchecked(src_size, src_align);

let vec = unsafe { Vec::from_raw_parts(dst_buf, len, dst_cap) };
// The allocation must be equal or smaller for in-place iteration to be possible
// therefore the new layout must be ≤ the old one and therefore valid.
let dst_align = mem::align_of::<T>();
let dst_size = mem::size_of::<T>().unchecked_mul(dst_cap);
let new_layout = Layout::from_size_align_unchecked(dst_size, dst_align);

vec
let result =
alloc.shrink(NonNull::new_unchecked(dst_buf as *mut u8), old_layout, new_layout);
let Ok(reallocated) = result else { handle_alloc_error(new_layout) };
dst_buf = reallocated.as_ptr() as *mut T;
}
} else {
debug_assert_eq!(src_cap * mem::size_of::<I::Src>(), dst_cap * mem::size_of::<T>());
}

mem::forget(dst_guard);

let vec = unsafe { Vec::from_raw_parts(dst_buf, len, dst_cap) };

vec
}
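
Concrete (made-up) numbers for the shrink branch above: when the source capacity in bytes is not an exact multiple of the destination element size, `dst_cap` rounds down and the allocation must shrink by the remainder.

fn main() {
    // Hypothetical sizes: 3-byte source elements, capacity 3 => 9 bytes owned.
    let (src_size, src_cap, dst_size) = (3usize, 3usize, 2usize);
    // Collecting into 2-byte elements: dst_cap = 9 / 2 = 4, an 8-byte layout.
    let dst_cap = src_cap * src_size / dst_size;
    assert_eq!(dst_cap, 4);
    // One byte of remainder, so the `needs_realloc` path shrinks 9 -> 8 bytes
    // (bookkeeping updates for the allocator, no memmove).
    assert_ne!(src_cap * src_size, dst_cap * dst_size);
}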

fn write_in_place_with_drop<T>(
@@ -369,28 +380,96 @@ where
}
}

// impl<T, I> SpecInPlaceCollect<T, I> for I
// where
// I: Iterator<Item = T> + TrustedRandomAccessNoCoerce,
// {
// #[inline]
// unsafe fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
// let len = self.size();
// let mut drop_guard = InPlaceDrop { inner: dst_buf, dst: dst_buf };
// for i in 0..len {
// // Safety: the InPlaceIterable contract guarantees that for every element we read
// // one slot in the underlying storage will have been freed up and we can immediately
// // write back the result.
// unsafe {
// let dst = dst_buf.add(i);
// debug_assert!(dst as *const _ <= end, "InPlaceIterable contract violation");
// ptr::write(dst, self.__iterator_get_unchecked(i));
// // Since this executes user code which can panic we have to bump the pointer
// // after each step.
// drop_guard.dst = dst.add(1);
// }
// }
// mem::forget(drop_guard);
// len
// }
// }

impl<T, I> SpecInPlaceCollect<T, I> for I
where
I: Iterator<Item = T> + TrustedRandomAccessNoCoerce,
I: Iterator<Item = T> + UncheckedIndexedIterator,
{
#[inline]
unsafe fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
let len = self.size();
let mut drop_guard = InPlaceDrop { inner: dst_buf, dst: dst_buf };
for i in 0..len {
// Safety: the InPlaceIterable contract guarantees that for every element we read
// one slot in the underlying storage will have been freed up and we can immediately
// write back the result.
let len = self.size_hint().0;

if len == 0 {
return 0;
}

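// Guard for the copy loop below: if reading the next element panics, `drop`
// restores the iterator's front index (when `CLEANUP_ON_DROP` requires it) so
// the iterator's own destructor only sees the elements that were never read,
// and then drops the values already written to `dst_buf`.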
struct LoopGuard<'a, I>
where
I: Iterator + UncheckedIndexedIterator,
{
it: &'a mut I,
len: usize,
idx: usize,
dst_buf: *mut I::Item,
}

impl<I> Drop for LoopGuard<'_, I>
where
I: Iterator + UncheckedIndexedIterator,
{
#[inline]
fn drop(&mut self) {
unsafe {
let new_len = self.len - self.idx;
if I::CLEANUP_ON_DROP {
self.it.set_front_index_from_end_unchecked(new_len, self.len);
}
if needs_drop::<I::Item>() && self.idx != self.len {
let raw_slice =
ptr::slice_from_raw_parts_mut::<I::Item>(self.dst_buf, self.idx);
ptr::drop_in_place(raw_slice);
}
}
}
}

let mut state = LoopGuard { it: self, len, idx: 0, dst_buf };

loop {
unsafe {
let dst = dst_buf.add(i);
let idx = state.idx;
state.idx = idx.unchecked_add(1);
let dst = state.dst_buf.add(idx);
debug_assert!(dst as *const _ <= end, "InPlaceIterable contract violation");
ptr::write(dst, self.__iterator_get_unchecked(i));
// Since this executes user code which can panic we have to bump the pointer
// after each step.
drop_guard.dst = dst.add(1);
dst.write(state.it.index_from_end_unchecked(len - idx));
}
if state.idx == len {
break;
}
}

// disarm the guard; we don't want the front elements to get dropped
mem::forget(state);
// since the guard was disarmed, update the iterator state
if Self::CLEANUP_ON_DROP {
unsafe { self.set_front_index_from_end_unchecked(0, len) };
}
mem::forget(drop_guard);

len
}
}
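
The definition of `UncheckedIndexedIterator` is not in the hunks shown here. For orientation, a reconstruction of the shape its uses above imply; the method names and `CLEANUP_ON_DROP` come from the diff, but the signatures, the `unsafe trait` qualifier, and the doc comments are inferred, not quoted:

// Reconstructed from usage above, not quoted from the PR; the real definition
// lives in core::iter behind #![feature(trusted_indexed_access)].
pub unsafe trait UncheckedIndexedIterator: Iterator {
    /// Whether a consumer that bypassed `next()` must call
    /// `set_front_index_from_end_unchecked` before the iterator is dropped.
    const CLEANUP_ON_DROP: bool;

    /// Reads the element `idx` slots from the end without advancing the
    /// iterator; the loop above passes `len`, `len - 1`, .., `1`, so `len`
    /// addresses the first remaining element. The caller must keep every
    /// index in bounds and read each element at most once.
    unsafe fn index_from_end_unchecked(&mut self, idx: usize) -> Self::Item;

    /// Tells the iterator that only `new_len` of the original `old_len` front
    /// elements remain unconsumed, so its destructor drops exactly those.
    unsafe fn set_front_index_from_end_unchecked(&mut self, new_len: usize, old_len: usize);
}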