Skip to content

Commit 7631259

Browse files
committed
refactor(turbo-tasks): Tighten up id factory overflow checks, tweak API to make construction easier
1 parent 27d2fa2 commit 7631259

File tree

8 files changed

+175
-68
lines changed

8 files changed

+175
-68
lines changed

turbopack/crates/turbo-tasks-backend/src/backend/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,12 @@ impl<B: BackingStorage> TurboTasksBackendInner<B> {
218218
start_time: Instant::now(),
219219
session_id: backing_storage.next_session_id(),
220220
persisted_task_id_factory: IdFactoryWithReuse::new(
221-
*backing_storage.next_free_task_id() as u64,
222-
(TRANSIENT_TASK_BIT - 1) as u64,
221+
backing_storage.next_free_task_id(),
222+
TaskId::try_from(TRANSIENT_TASK_BIT - 1).unwrap(),
223223
),
224224
transient_task_id_factory: IdFactoryWithReuse::new(
225-
TRANSIENT_TASK_BIT as u64,
226-
u32::MAX as u64,
225+
TaskId::try_from(TRANSIENT_TASK_BIT).unwrap(),
226+
TaskId::MAX,
227227
),
228228
persisted_task_cache_log: need_log.then(|| Sharded::new(shard_amount)),
229229
task_cache: BiMap::new(),

turbopack/crates/turbo-tasks-backend/src/kv_backing_storage.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,13 @@ impl<T: KeyValueDatabase + Send + Sync + 'static> BackingStorage
128128
}
129129

130130
fn next_free_task_id(&self) -> TaskId {
131-
TaskId::from(get_infra_u32(&self.database, META_KEY_NEXT_FREE_TASK_ID).unwrap_or(1))
131+
TaskId::try_from(get_infra_u32(&self.database, META_KEY_NEXT_FREE_TASK_ID).unwrap_or(1))
132+
.unwrap()
132133
}
133134

134135
fn next_session_id(&self) -> SessionId {
135-
SessionId::from(get_infra_u32(&self.database, META_KEY_SESSION_ID).unwrap_or(0) + 1)
136+
SessionId::try_from(get_infra_u32(&self.database, META_KEY_SESSION_ID).unwrap_or(0) + 1)
137+
.unwrap()
136138
}
137139

138140
fn uncompleted_operations(&self) -> Vec<AnyOperation> {
@@ -367,7 +369,7 @@ impl<T: KeyValueDatabase + Send + Sync + 'static> BackingStorage
367369
return Ok(None);
368370
};
369371
let bytes = bytes.borrow().try_into()?;
370-
let id = TaskId::from(u32::from_le_bytes(bytes));
372+
let id = TaskId::try_from(u32::from_le_bytes(bytes)).unwrap();
371373
Ok(Some(id))
372374
}
373375
if self.database.is_empty() {

turbopack/crates/turbo-tasks-memory/src/memory_backend.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ impl MemoryBackend {
7979
persistent_tasks: NoMoveVec::new(),
8080
transient_tasks: NoMoveVec::new(),
8181
backend_jobs: NoMoveVec::new(),
82-
backend_job_id_factory: IdFactoryWithReuse::new(1, u32::MAX as u64),
82+
backend_job_id_factory: IdFactoryWithReuse::new(BackendJobId::MIN, BackendJobId::MAX),
8383
task_cache: DashMap::with_hasher_and_shard_amount(Default::default(), shard_amount),
8484
transient_task_cache: DashMap::with_hasher_and_shard_amount(
8585
Default::default(),

turbopack/crates/turbo-tasks-testing/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ impl VcStorage {
5656
})));
5757
i
5858
};
59-
let task_id = TaskId::from(i as u32 + 1);
59+
let task_id = TaskId::try_from(u32::try_from(i).unwrap() + 1).unwrap();
6060
handle.spawn(with_turbo_tasks_for_testing(
6161
this.clone(),
6262
task_id,
@@ -321,7 +321,7 @@ impl VcStorage {
321321
this: weak.clone(),
322322
..Default::default()
323323
}),
324-
TaskId::from(u32::MAX),
324+
TaskId::MAX,
325325
f,
326326
)
327327
}

turbopack/crates/turbo-tasks/src/id.rs

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ macro_rules! define_id {
2929
}
3030

3131
impl $name {
32+
pub const MIN: Self = Self { id: NonZero::<$primitive>::MIN };
33+
pub const MAX: Self = Self { id: NonZero::<$primitive>::MAX };
34+
3235
/// Constructs a wrapper type from the numeric identifier.
3336
///
3437
/// # Safety
@@ -37,6 +40,15 @@ macro_rules! define_id {
3740
pub const unsafe fn new_unchecked(id: $primitive) -> Self {
3841
Self { id: unsafe { NonZero::<$primitive>::new_unchecked(id) } }
3942
}
43+
44+
/// Allows `const` conversion to a [`NonZeroU64`], useful with
45+
/// [`crate::id_factory::IdFactory::new_const`].
46+
pub const fn to_non_zero_u64(self) -> NonZeroU64 {
47+
const {
48+
assert!(<$primitive>::BITS <= u64::BITS);
49+
}
50+
unsafe { NonZeroU64::new_unchecked(self.id.get() as u64) }
51+
}
4052
}
4153

4254
impl Display for $name {
@@ -53,30 +65,51 @@ macro_rules! define_id {
5365
}
5466
}
5567

56-
/// Converts a numeric identifier to the wrapper type.
57-
///
58-
/// Panics if the given id value is zero.
59-
impl From<$primitive> for $name {
60-
fn from(id: $primitive) -> Self {
68+
define_id!(@impl_try_from_primitive_conversion $name $primitive);
69+
70+
impl From<NonZero<$primitive>> for $name {
71+
fn from(id: NonZero::<$primitive>) -> Self {
6172
Self {
62-
id: NonZero::<$primitive>::new(id)
63-
.expect("Ids can only be created from non zero values")
73+
id,
6474
}
6575
}
6676
}
6777

68-
/// Converts a numeric identifier to the wrapper type.
78+
impl From<$name> for NonZeroU64 {
79+
fn from(id: $name) -> Self {
80+
id.to_non_zero_u64()
81+
}
82+
}
83+
84+
impl TraceRawVcs for $name {
85+
fn trace_raw_vcs(&self, _trace_context: &mut TraceRawVcsContext) {}
86+
}
87+
};
88+
(
89+
@impl_try_from_primitive_conversion $name:ident u64
90+
) => {
91+
// we get a `TryFrom` blanket impl for free via the `From` impl
92+
};
93+
(
94+
@impl_try_from_primitive_conversion $name:ident $primitive:ty
95+
) => {
96+
impl TryFrom<$primitive> for $name {
97+
type Error = TryFromIntError;
98+
99+
fn try_from(id: $primitive) -> Result<Self, Self::Error> {
100+
Ok(Self {
101+
id: NonZero::try_from(id)?
102+
})
103+
}
104+
}
105+
69106
impl TryFrom<NonZeroU64> for $name {
70107
type Error = TryFromIntError;
71108

72109
fn try_from(id: NonZeroU64) -> Result<Self, Self::Error> {
73110
Ok(Self { id: NonZero::try_from(id)? })
74111
}
75112
}
76-
77-
impl TraceRawVcs for $name {
78-
fn trace_raw_vcs(&self, _trace_context: &mut TraceRawVcsContext) {}
79-
}
80113
};
81114
}
82115

turbopack/crates/turbo-tasks/src/id_factory.rs

Lines changed: 91 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,36 @@ use concurrent_queue::ConcurrentQueue;
1111
///
1212
/// For ids that may be re-used, see [`IdFactoryWithReuse`].
1313
pub struct IdFactory<T> {
14-
next_id: AtomicU64,
15-
max_id: u64,
14+
/// A value starting at 0 and incremented each time a new id is allocated. Regardless of the
15+
/// underlying type, a u64 is used to cheaply detect overflows.
16+
counter: AtomicU64,
17+
/// We've overflowed if `counter > max_count`.
18+
max_count: u64,
19+
id_offset: u64, // added to the value received from `counter`
1620
_phantom_data: PhantomData<T>,
1721
}
1822

1923
impl<T> IdFactory<T> {
20-
pub const fn new(start: u64, max: u64) -> Self {
24+
/// Create a factory for ids in the range `start..=max`.
25+
pub fn new(start: T, max: T) -> Self
26+
where
27+
T: Into<NonZeroU64> + Ord,
28+
{
29+
Self::new_const(start.into(), max.into())
30+
}
31+
32+
/// Create a factory for ids in the range `start..=max`.
33+
///
34+
/// Provides a less convenient API than [`IdFactory::new`], but skips a type conversion that
35+
/// would make the function non-const.
36+
pub const fn new_const(start: NonZeroU64, max: NonZeroU64) -> Self {
37+
assert!(start.get() < max.get());
2138
Self {
22-
next_id: AtomicU64::new(start),
23-
max_id: max,
39+
// Always start `counter` at 0, don't use the value of `start` because `start` could be
40+
// close to `u64::MAX`.
41+
counter: AtomicU64::new(0),
42+
max_count: max.get() - start.get(),
43+
id_offset: start.get(),
2444
_phantom_data: PhantomData,
2545
}
2646
}
@@ -32,47 +52,80 @@ where
3252
{
3353
/// Return a unique new id.
3454
///
35-
/// Panics (best-effort) if the id type overflows.
55+
/// Panics if the id type overflows.
3656
pub fn get(&self) -> T {
37-
let new_id = self.next_id.fetch_add(1, Ordering::Relaxed);
57+
let count = self.counter.fetch_add(1, Ordering::Relaxed);
58+
59+
#[cfg(debug_assertions)]
60+
{
61+
if count == u64::MAX {
62+
// The u64 counter has just wrapped around -- this should never happen! A `u64` counter
63+
// starting at 0 should take decades to overflow on a single machine.
64+
//
65+
// This is unrecoverable because other threads may have already read the overflowed
66+
// value, so abort the entire process.
67+
std::process::abort()
68+
}
69+
}
3870

39-
if new_id > self.max_id {
71+
// `max_count` might be something like `u32::MAX`. The extra bits of `u64` are useful to
72+
// detect overflows in that case. We assume the u64 counter is large enough to never
73+
// overflow.
74+
if count > self.max_count {
4075
panic!(
41-
"Max id limit hit while attempting to generate a unique {}",
76+
"Max id limit (overflow) hit while attempting to generate a unique {}",
4277
type_name::<T>(),
4378
)
4479
}
4580

46-
// Safety: u64 will not overflow. This is *very* unlikely to happen (would take
47-
// decades).
48-
let new_id = unsafe { NonZeroU64::new_unchecked(new_id) };
81+
let new_id_u64 = count + self.id_offset;
82+
// Safety:
83+
// - `count` is assumed not to overflow.
84+
// - `id_offset` is a non-zero value.
85+
// - `id_offset + count <= max <= u64::MAX` (since `count <= max_count`), so the sum cannot wrap.
86+
let new_id = unsafe { NonZeroU64::new_unchecked(new_id_u64) };
4987

50-
// Use the extra bits of the AtomicU64 as cheap overflow detection when the
51-
// value is less than 64 bits.
5288
match new_id.try_into() {
5389
Ok(id) => id,
90+
// With any sane implementation of `TryFrom`, this shouldn't happen, as we've already
91+
// checked the `max_count` bound. (Could happen with the `new_const` constructor)
5492
Err(_) => panic!(
55-
"Overflow detected while attempting to generate a unique {}",
56-
type_name::<T>(),
93+
"Failed to convert NonZeroU64 value of {} into {}",
94+
new_id,
95+
type_name::<T>()
5796
),
5897
}
5998
}
6099
}
61100

62-
/// An [`IdFactory`], but extended with a free list to allow for id reuse for
63-
/// ids such as [`BackendJobId`][crate::backend::BackendJobId].
101+
/// An [`IdFactory`], but extended with a free list to allow for id reuse for ids such as
102+
/// [`BackendJobId`][crate::backend::BackendJobId].
64103
pub struct IdFactoryWithReuse<T> {
65104
factory: IdFactory<T>,
66105
free_ids: ConcurrentQueue<T>,
67106
}
68107

69-
impl<T> IdFactoryWithReuse<T> {
70-
pub const fn new(start: u64, max: u64) -> Self {
108+
impl<T> IdFactoryWithReuse<T>
109+
where
110+
T: Into<NonZeroU64> + Ord,
111+
{
112+
/// Create a factory for ids in the range `start..=max`.
113+
pub fn new(start: T, max: T) -> Self {
71114
Self {
72115
factory: IdFactory::new(start, max),
73116
free_ids: ConcurrentQueue::unbounded(),
74117
}
75118
}
119+
120+
/// Create a factory for ids in the range `start..=max`. Provides a less convenient API than
121+
/// [`IdFactoryWithReuse::new`], but skips a type conversion that would make the function
122+
/// non-const.
123+
pub const fn new_const(start: NonZeroU64, max: NonZeroU64) -> Self {
124+
Self {
125+
factory: IdFactory::new_const(start, max),
126+
free_ids: ConcurrentQueue::unbounded(),
127+
}
128+
}
76129
}
77130

78131
impl<T> IdFactoryWithReuse<T>
@@ -81,18 +134,18 @@ where
81134
{
82135
/// Return a new or potentially reused id.
83136
///
84-
/// Panics (best-effort) if the id type overflows.
137+
/// Panics if the id type overflows.
85138
pub fn get(&self) -> T {
86139
self.free_ids.pop().unwrap_or_else(|_| self.factory.get())
87140
}
88141

89-
/// Add an id to the free list, allowing it to be re-used on a subsequent
90-
/// call to [`IdFactoryWithReuse::get`].
142+
/// Add an id to the free list, allowing it to be re-used on a subsequent call to
143+
/// [`IdFactoryWithReuse::get`].
91144
///
92145
/// # Safety
93146
///
94-
/// It must be ensured that the id is no longer used. Id must be a valid id
95-
/// that was previously returned by `get`.
147+
/// The id must no longer be used. Must be a valid id that was previously returned by
148+
/// [`IdFactoryWithReuse::get`].
96149
pub unsafe fn reuse(&self, id: T) {
97150
let _ = self.free_ids.push(id);
98151
}
@@ -105,12 +158,21 @@ mod tests {
105158
use super::*;
106159

107160
#[test]
108-
#[should_panic(expected = "Overflow detected")]
109-
fn test_overflow() {
110-
let factory = IdFactory::<NonZeroU8>::new(1, u16::MAX as u64);
161+
#[should_panic(expected = "Max id limit (overflow)")]
162+
fn test_overflow_detection() {
163+
let factory = IdFactory::new(NonZeroU8::MIN, NonZeroU8::MAX);
111164
assert_eq!(factory.get(), NonZeroU8::new(1).unwrap());
112165
assert_eq!(factory.get(), NonZeroU8::new(2).unwrap());
113-
for _i in 2..256 {
166+
for _ in 2..256 {
167+
factory.get();
168+
}
169+
}
170+
171+
#[test]
172+
#[should_panic(expected = "Max id limit (overflow)")]
173+
fn test_overflow_detection_near_u64_max() {
174+
let factory = IdFactory::new(NonZeroU64::try_from(u64::MAX - 5).unwrap(), NonZeroU64::MAX);
175+
for _ in 0..=6 {
114176
factory.get();
115177
}
116178
}

turbopack/crates/turbo-tasks/src/manager.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,9 @@ impl CurrentTaskState {
419419

420420
fn create_local_task(&mut self, local_task: LocalTask) -> LocalTaskId {
421421
self.local_tasks.push(local_task);
422-
// generate a one-indexed id
422+
// generate a one-indexed id from len() -- we just pushed so len() is >= 1
423423
if cfg!(debug_assertions) {
424-
LocalTaskId::from(u32::try_from(self.local_tasks.len()).unwrap())
424+
LocalTaskId::try_from(u32::try_from(self.local_tasks.len()).unwrap()).unwrap()
425425
} else {
426426
unsafe { LocalTaskId::new_unchecked(self.local_tasks.len() as u32) }
427427
}
@@ -452,9 +452,12 @@ impl<B: Backend + 'static> TurboTasks<B> {
452452
// so we probably want to make sure that all tasks are joined
453453
// when trying to drop turbo tasks
454454
pub fn new(backend: B) -> Arc<Self> {
455-
let task_id_factory = IdFactoryWithReuse::new(1, (TRANSIENT_TASK_BIT - 1) as u64);
455+
let task_id_factory = IdFactoryWithReuse::new(
456+
TaskId::MIN,
457+
TaskId::try_from(TRANSIENT_TASK_BIT - 1).unwrap(),
458+
);
456459
let transient_task_id_factory =
457-
IdFactoryWithReuse::new(TRANSIENT_TASK_BIT as u64, u32::MAX as u64);
460+
IdFactoryWithReuse::new(TaskId::try_from(TRANSIENT_TASK_BIT).unwrap(), TaskId::MAX);
458461
let this = Arc::new_cyclic(|this| Self {
459462
this: this.clone(),
460463
backend,

0 commit comments

Comments
 (0)