
Commit dd5970f

Committed Mar 7, 2024
Auto merge of #120682 - the8472:indexed-access, r=<try>
[WIP] rewrite TrustedRandomAccess into two directional variants r? `@ghost`
2 parents 1c580bc + 2981354 commit dd5970f

File tree: 25 files changed (+1096, -122 lines)


library/alloc/src/lib.rs (+1)
@@ -160,6 +160,7 @@
 #![feature(str_internals)]
 #![feature(strict_provenance)]
 #![feature(trusted_fused)]
+#![feature(trusted_indexed_access)]
 #![feature(trusted_len)]
 #![feature(trusted_random_access)]
 #![feature(try_trait_v2)]

library/alloc/src/vec/in_place_collect.rs (+175, -96)
@@ -80,7 +80,7 @@
 //! # O(1) collect
 //!
 //! The main iteration itself is further specialized when the iterator implements
-//! [`TrustedRandomAccessNoCoerce`] to let the optimizer see that it is a counted loop with a single
+//! [`UncheckedIndexedIterator`] to let the optimizer see that it is a counted loop with a single
 //! [induction variable]. This can turn some iterators into a noop, i.e. it reduces them from O(n) to
 //! O(1). This particular optimization is quite fickle and doesn't always work, see [#79308]
 //!
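For orientation (this note is not part of the diff): the O(1) collect described above targets pipelines where the source allocation is reused for the result. A minimal illustrative sketch of such a pipeline follows; when the element layouts are compatible and the loop is recognized as a counted loop, the per-element work can often be optimized away entirely.

    // Illustration only: the kind of pipeline the in-place specialization aims at.
    // The source Vec's allocation is reused for the result, so with a counted loop
    // the whole collect can reduce to little more than returning the original buffer.
    fn recycle(v: Vec<u32>) -> Vec<i32> {
        v.into_iter().map(|n| n as i32).collect()
    }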
@@ -157,8 +157,10 @@
 use crate::alloc::{handle_alloc_error, Global};
 use core::alloc::Allocator;
 use core::alloc::Layout;
-use core::iter::{InPlaceIterable, SourceIter, TrustedRandomAccessNoCoerce};
+use core::iter::UncheckedIndexedIterator;
+use core::iter::{InPlaceIterable, SourceIter};
 use core::marker::PhantomData;
+use core::mem::needs_drop;
 use core::mem::{self, ManuallyDrop, SizedTypeProperties};
 use core::num::NonZero;
 use core::ptr::{self, NonNull};
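The newly imported `core::mem::needs_drop` is used by the loop guard further down to skip dropping already-written elements when the element type has no drop glue. A small standalone illustration of what it reports (not part of the diff):

    use std::mem::needs_drop;

    fn cleanup_needed() {
        // `needs_drop::<T>()` is a const-evaluable query: plain data like u32 has
        // no drop glue, so cleanup code guarded by it can be compiled out entirely,
        // while String values do need to be dropped.
        assert!(!needs_drop::<u32>());
        assert!(needs_drop::<String>());
    }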
@@ -229,96 +231,105 @@ where
     I: Iterator<Item = T> + InPlaceCollect,
     <I as SourceIter>::Source: AsVecIntoIter,
 {
-    default fn from_iter(mut iterator: I) -> Self {
+    default fn from_iter(iterator: I) -> Self {
         // See "Layout constraints" section in the module documentation. We rely on const
         // optimization here since these conditions currently cannot be expressed as trait bounds
-        if const { !in_place_collectible::<T, I::Src>(I::MERGE_BY, I::EXPAND_BY) } {
-            // fallback to more generic implementations
-            return SpecFromIterNested::from_iter(iterator);
-        }
-
-        let (src_buf, src_ptr, src_cap, mut dst_buf, dst_end, dst_cap) = unsafe {
-            let inner = iterator.as_inner().as_into_iter();
-            (
-                inner.buf.as_ptr(),
-                inner.ptr,
-                inner.cap,
-                inner.buf.as_ptr() as *mut T,
-                inner.end as *const T,
-                inner.cap * mem::size_of::<I::Src>() / mem::size_of::<T>(),
-            )
+        let fun = const {
+            if !in_place_collectible::<T, I::Src>(I::MERGE_BY, I::EXPAND_BY) {
+                SpecFromIterNested::<T, I>::from_iter
+            } else {
+                from_iter
+            }
         };
 
-        // SAFETY: `dst_buf` and `dst_end` are the start and end of the buffer.
-        let len = unsafe { SpecInPlaceCollect::collect_in_place(&mut iterator, dst_buf, dst_end) };
-
-        let src = unsafe { iterator.as_inner().as_into_iter() };
-        // check if SourceIter contract was upheld
-        // caveat: if they weren't we might not even make it to this point
-        debug_assert_eq!(src_buf, src.buf.as_ptr());
-        // check InPlaceIterable contract. This is only possible if the iterator advanced the
-        // source pointer at all. If it uses unchecked access via TrustedRandomAccess
-        // then the source pointer will stay in its initial position and we can't use it as reference
-        if src.ptr != src_ptr {
-            debug_assert!(
-                unsafe { dst_buf.add(len) as *const _ } <= src.ptr.as_ptr(),
-                "InPlaceIterable contract violation, write pointer advanced beyond read pointer"
-            );
-        }
+        fun(iterator)
+    }
+}
 
-        // The ownership of the source allocation and the new `T` values is temporarily moved into `dst_guard`.
-        // This is safe because
-        // * `forget_allocation_drop_remaining` immediately forgets the allocation
-        //   before any panic can occur in order to avoid any double free, and then proceeds to drop
-        //   any remaining values at the tail of the source.
-        // * the shrink either panics without invalidating the allocation, aborts or
-        //   succeeds. In the last case we disarm the guard.
-        //
-        // Note: This access to the source wouldn't be allowed by the TrustedRandomIteratorNoCoerce
-        // contract (used by SpecInPlaceCollect below). But see the "O(1) collect" section in the
-        // module documentation why this is ok anyway.
-        let dst_guard =
-            InPlaceDstDataSrcBufDrop { ptr: dst_buf, len, src_cap, src: PhantomData::<I::Src> };
-        src.forget_allocation_drop_remaining();
-
-        // Adjust the allocation if the source had a capacity in bytes that wasn't a multiple
-        // of the destination type size.
-        // Since the discrepancy should generally be small this should only result in some
-        // bookkeeping updates and no memmove.
-        if needs_realloc::<I::Src, T>(src_cap, dst_cap) {
-            let alloc = Global;
-            debug_assert_ne!(src_cap, 0);
-            debug_assert_ne!(dst_cap, 0);
-            unsafe {
-                // The old allocation exists, therefore it must have a valid layout.
-                let src_align = mem::align_of::<I::Src>();
-                let src_size = mem::size_of::<I::Src>().unchecked_mul(src_cap);
-                let old_layout = Layout::from_size_align_unchecked(src_size, src_align);
-
-                // The allocation must be equal or smaller for in-place iteration to be possible
-                // therefore the new layout must be ≤ the old one and therefore valid.
-                let dst_align = mem::align_of::<T>();
-                let dst_size = mem::size_of::<T>().unchecked_mul(dst_cap);
-                let new_layout = Layout::from_size_align_unchecked(dst_size, dst_align);
-
-                let result = alloc.shrink(
-                    NonNull::new_unchecked(dst_buf as *mut u8),
-                    old_layout,
-                    new_layout,
-                );
-                let Ok(reallocated) = result else { handle_alloc_error(new_layout) };
-                dst_buf = reallocated.as_ptr() as *mut T;
-            }
-        } else {
-            debug_assert_eq!(src_cap * mem::size_of::<I::Src>(), dst_cap * mem::size_of::<T>());
-        }
+fn from_iter<I, T>(mut iterator: I) -> Vec<T>
+where
+    I: Iterator<Item = T> + InPlaceCollect,
+    <I as SourceIter>::Source: AsVecIntoIter,
+{
+    let (src_buf, src_ptr, src_cap, mut dst_buf, dst_end, dst_cap) = unsafe {
+        let inner = iterator.as_inner().as_into_iter();
+        (
+            inner.buf.as_ptr(),
+            inner.ptr,
+            inner.cap,
+            inner.buf.as_ptr() as *mut T,
+            inner.end as *const T,
+            inner.cap * mem::size_of::<I::Src>() / mem::size_of::<T>(),
+        )
+    };
+
+    // SAFETY: `dst_buf` and `dst_end` are the start and end of the buffer.
+    let len = unsafe { SpecInPlaceCollect::collect_in_place(&mut iterator, dst_buf, dst_end) };
 
-        mem::forget(dst_guard);
+    let src = unsafe { iterator.as_inner().as_into_iter() };
+    // check if SourceIter contract was upheld
+    // caveat: if they weren't we might not even make it to this point
+    debug_assert_eq!(src_buf, src.buf.as_ptr());
+    // check InPlaceIterable contract. This is only possible if the iterator advanced the
+    // source pointer at all. If it uses unchecked access via UncheckedIndexedIterator
+    // and doesn't perform cleanup then the source pointer will stay in its initial position
+    // and we can't use it as reference.
+    if src.ptr != src_ptr {
+        debug_assert!(
+            unsafe { dst_buf.add(len) as *const _ } <= src.ptr.as_ptr(),
+            "InPlaceIterable contract violation, write pointer advanced beyond read pointer"
+        );
+    }
+
+    // The ownership of the source allocation and the new `T` values is temporarily moved into `dst_guard`.
+    // This is safe because
+    // * `forget_allocation_drop_remaining` immediately forgets the allocation
+    //   before any panic can occur in order to avoid any double free, and then proceeds to drop
+    //   any remaining values at the tail of the source.
+    // * the shrink either panics without invalidating the allocation, aborts or
+    //   succeeds. In the last case we disarm the guard.
+    //
+    // Note: This access to the source wouldn't be allowed by the TrustedRandomIteratorNoCoerce
+    // contract (used by SpecInPlaceCollect below). But see the "O(1) collect" section in the
+    // module documentation why this is ok anyway.
+    let dst_guard =
+        InPlaceDstDataSrcBufDrop { ptr: dst_buf, len, src_cap, src: PhantomData::<I::Src> };
+    src.forget_allocation_drop_remaining();
+
+    // Adjust the allocation if the source had a capacity in bytes that wasn't a multiple
+    // of the destination type size.
+    // Since the discrepancy should generally be small this should only result in some
+    // bookkeeping updates and no memmove.
+    if needs_realloc::<I::Src, T>(src_cap, dst_cap) {
+        let alloc = Global;
+        debug_assert_ne!(src_cap, 0);
+        debug_assert_ne!(dst_cap, 0);
+        unsafe {
+            // The old allocation exists, therefore it must have a valid layout.
+            let src_align = mem::align_of::<I::Src>();
+            let src_size = mem::size_of::<I::Src>().unchecked_mul(src_cap);
+            let old_layout = Layout::from_size_align_unchecked(src_size, src_align);
 
-        let vec = unsafe { Vec::from_raw_parts(dst_buf, len, dst_cap) };
+            // The allocation must be equal or smaller for in-place iteration to be possible
+            // therefore the new layout must be ≤ the old one and therefore valid.
+            let dst_align = mem::align_of::<T>();
+            let dst_size = mem::size_of::<T>().unchecked_mul(dst_cap);
+            let new_layout = Layout::from_size_align_unchecked(dst_size, dst_align);
 
-        vec
+            let result =
+                alloc.shrink(NonNull::new_unchecked(dst_buf as *mut u8), old_layout, new_layout);
+            let Ok(reallocated) = result else { handle_alloc_error(new_layout) };
+            dst_buf = reallocated.as_ptr() as *mut T;
+        }
+    } else {
+        debug_assert_eq!(src_cap * mem::size_of::<I::Src>(), dst_cap * mem::size_of::<T>());
     }
+
+    mem::forget(dst_guard);
+
+    let vec = unsafe { Vec::from_raw_parts(dst_buf, len, dst_cap) };
+
+    vec
 }
 
 fn write_in_place_with_drop<T>(
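A note on the reworked `default fn from_iter` above (commentary, not part of the diff): instead of an early return, a `const` block now picks one of two function pointers at compile time, so the chosen path carries no runtime branch. The following standalone sketch shows the same selection pattern with stand-in names; `use_fast_path`, `fast`, and `slow` are illustrative only, where the real code chooses between `SpecFromIterNested::from_iter` and the free `from_iter` shown in the diff.

    // Sketch of the compile-time function selection used above.
    const fn use_fast_path() -> bool {
        // stand-in for `in_place_collectible::<T, I::Src>(I::MERGE_BY, I::EXPAND_BY)`
        true
    }

    fn fast(v: Vec<u32>) -> Vec<u32> {
        v
    }

    fn slow(v: Vec<u32>) -> Vec<u32> {
        v.into_iter().collect()
    }

    fn collect_selected(v: Vec<u32>) -> Vec<u32> {
        // The `const` block is evaluated at compile time and yields a plain
        // function pointer, so no branch survives into the generated code.
        let fun: fn(Vec<u32>) -> Vec<u32> = const { if use_fast_path() { fast } else { slow } };
        fun(v)
    }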
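The `needs_realloc` branch above trims the reused allocation when the source capacity in bytes is not an exact multiple of the destination element size. A minimal standalone sketch of that trimming step, using the stable `std::alloc::realloc` as a stand-in for the unstable `Allocator::shrink` call in the diff; the sizes below are made up for illustration.

    use std::alloc::{alloc, dealloc, realloc, Layout};

    fn shrink_demo() {
        // Say the source held 5 elements of 3 bytes (15 bytes) and the destination
        // holds 7 elements of 2 bytes (14 bytes): one trailing byte is trimmed.
        let old_layout = Layout::from_size_align(15, 2).unwrap();
        unsafe {
            let ptr = alloc(old_layout);
            assert!(!ptr.is_null());
            // Shrinking typically only updates allocator bookkeeping;
            // the payload is not copied.
            let new_ptr = realloc(ptr, old_layout, 14);
            assert!(!new_ptr.is_null());
            dealloc(new_ptr, Layout::from_size_align(14, 2).unwrap());
        }
    }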
@@ -369,28 +380,96 @@ where
     }
 }
 
+// impl<T, I> SpecInPlaceCollect<T, I> for I
+// where
+//     I: Iterator<Item = T> + TrustedRandomAccessNoCoerce,
+// {
+//     #[inline]
+//     unsafe fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
+//         let len = self.size();
+//         let mut drop_guard = InPlaceDrop { inner: dst_buf, dst: dst_buf };
+//         for i in 0..len {
+//             // Safety: InplaceIterable contract guarantees that for every element we read
+//             // one slot in the underlying storage will have been freed up and we can immediately
+//             // write back the result.
+//             unsafe {
+//                 let dst = dst_buf.add(i);
+//                 debug_assert!(dst as *const _ <= end, "InPlaceIterable contract violation");
+//                 ptr::write(dst, self.__iterator_get_unchecked(i));
+//                 // Since this executes user code which can panic we have to bump the pointer
+//                 // after each step.
+//                 drop_guard.dst = dst.add(1);
+//             }
+//         }
+//         mem::forget(drop_guard);
+//         len
+//     }
+// }
+
 impl<T, I> SpecInPlaceCollect<T, I> for I
 where
-    I: Iterator<Item = T> + TrustedRandomAccessNoCoerce,
+    I: Iterator<Item = T> + UncheckedIndexedIterator,
 {
     #[inline]
     unsafe fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
-        let len = self.size();
-        let mut drop_guard = InPlaceDrop { inner: dst_buf, dst: dst_buf };
-        for i in 0..len {
-            // Safety: InplaceIterable contract guarantees that for every element we read
-            // one slot in the underlying storage will have been freed up and we can immediately
-            // write back the result.
+        let len = self.size_hint().0;
+
+        if len == 0 {
+            return 0;
+        }
+
+        struct LoopGuard<'a, I>
+        where
+            I: Iterator + UncheckedIndexedIterator,
+        {
+            it: &'a mut I,
+            len: usize,
+            idx: usize,
+            dst_buf: *mut I::Item,
+        }
+
+        impl<I> Drop for LoopGuard<'_, I>
+        where
+            I: Iterator + UncheckedIndexedIterator,
+        {
+            #[inline]
+            fn drop(&mut self) {
+                unsafe {
+                    let new_len = self.len - self.idx;
+                    if I::CLEANUP_ON_DROP {
+                        self.it.set_front_index_from_end_unchecked(new_len, self.len);
+                    }
+                    if needs_drop::<I::Item>() && self.idx != self.len {
+                        let raw_slice =
+                            ptr::slice_from_raw_parts_mut::<I::Item>(self.dst_buf, self.idx);
+                        ptr::drop_in_place(raw_slice);
+                    }
+                }
+            }
+        }
+
+        let mut state = LoopGuard { it: self, len, idx: 0, dst_buf };
+
+        loop {
             unsafe {
-                let dst = dst_buf.add(i);
+                let idx = state.idx;
+                state.idx = idx.unchecked_add(1);
+                let dst = state.dst_buf.add(idx);
                 debug_assert!(dst as *const _ <= end, "InPlaceIterable contract violation");
-                ptr::write(dst, self.__iterator_get_unchecked(i));
-                // Since this executes user code which can panic we have to bump the pointer
-                // after each step.
-                drop_guard.dst = dst.add(1);
+                dst.write(state.it.index_from_end_unchecked(len - idx));
             }
+            if state.idx == len {
+                break;
+            }
+        }
+
+        // disarm guard, we don't want the front elements to get dropped
+        mem::forget(state);
+        // since the guard was disarmed, update the iterator state
+        if Self::CLEANUP_ON_DROP {
+            unsafe { self.set_front_index_from_end_unchecked(0, len) };
         }
-        mem::forget(drop_guard);
+
         len
     }
 }
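The `LoopGuard` above does two jobs on unwind: it re-synchronizes the iterator's front index when `CLEANUP_ON_DROP` asks for it, and it drops the destination prefix that was already written. The following standalone sketch (commentary, not part of the diff) shows only the second job, using a slice source instead of the new iterator trait; the names here are illustrative.

    use std::mem::{self, MaybeUninit};
    use std::ptr;

    /// Clones `src` into the uninitialized buffer `dst` and returns how many
    /// elements were written. If a `clone()` panics mid-loop, the guard's `Drop`
    /// impl drops exactly the prefix that was already initialized.
    fn write_all<T: Clone>(src: &[T], dst: &mut [MaybeUninit<T>]) -> usize {
        assert!(dst.len() >= src.len());

        struct Guard<T> {
            dst: *mut T,
            written: usize,
        }

        impl<T> Drop for Guard<T> {
            fn drop(&mut self) {
                // Only reached on unwind: drop the initialized prefix so no
                // element leaks or is dropped twice.
                unsafe {
                    ptr::drop_in_place(ptr::slice_from_raw_parts_mut(self.dst, self.written))
                }
            }
        }

        let mut guard = Guard { dst: dst.as_mut_ptr().cast::<T>(), written: 0 };
        for item in src {
            // `clone()` stands in for user code that may panic.
            unsafe { guard.dst.add(guard.written).write(item.clone()) };
            guard.written += 1;
        }
        let written = guard.written;
        mem::forget(guard); // success: disarm the guard, the caller owns the prefix
        written
    }

In the diff, the guard additionally calls `set_front_index_from_end_unchecked` so that, when `CLEANUP_ON_DROP` is set, the source iterator's own position is corrected before its remaining elements are dropped.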

Comments (0)