Skip to content

Commit b9205aa

Browse files
authored
Add tag (awslabs#98)
* Add tag * Tag: i64 -> u64 * Address comments from @jorajeev and @kraglb * Add spawn_and_join test * Add NOTE to get_current_task * Address comments from @jorajeev * Address Clippy too_many_arguments lint * Address comments from @jorajeev
1 parent 259bb95 commit b9205aa

File tree

5 files changed

+218
-5
lines changed

5 files changed

+218
-5
lines changed

src/current.rs

+18-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
88
use crate::runtime::execution::ExecutionState;
99
use crate::runtime::task::clock::VectorClock;
10-
use crate::runtime::task::TaskId;
10+
pub use crate::runtime::task::{Tag, TaskId};
1111

1212
/// The number of context switches that happened so far in the current Shuttle execution.
1313
///
@@ -22,7 +22,7 @@ pub fn context_switches() -> usize {
2222

2323
/// Get the current thread's vector clock
2424
pub fn clock() -> VectorClock {
25-
crate::runtime::execution::ExecutionState::with(|state| {
25+
ExecutionState::with(|state| {
2626
let me = state.current();
2727
state.get_clock(me.id()).clone()
2828
})
@@ -32,3 +32,19 @@ pub fn clock() -> VectorClock {
3232
pub fn clock_for(task_id: TaskId) -> VectorClock {
3333
ExecutionState::with(|state| state.get_clock(task_id).clone())
3434
}
35+
36+
/// Sets the `tag` field of the current task.
37+
/// Returns the `tag` which was there previously.
38+
pub fn set_tag_for_current_task(tag: Tag) -> Tag {
39+
ExecutionState::set_tag_for_current_task(tag)
40+
}
41+
42+
/// Gets the `tag` field of the current task.
43+
pub fn get_tag_for_current_task() -> Tag {
44+
ExecutionState::get_tag_for_current_task()
45+
}
46+
47+
/// Gets the `TaskId` of the current task, or `None` if there is no current task.
48+
pub fn get_current_task() -> Option<TaskId> {
49+
ExecutionState::with(|s| Some(s.try_current()?.id()))
50+
}

src/runtime/execution.rs

+26-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::runtime::failure::{init_panic_hook, persist_failure, persist_task_failure};
22
use crate::runtime::storage::{StorageKey, StorageMap};
33
use crate::runtime::task::clock::VectorClock;
4+
use crate::runtime::task::Tag;
45
use crate::runtime::task::{Task, TaskId, DEFAULT_INLINE_TASKS};
56
use crate::runtime::thread::continuation::PooledContinuation;
67
use crate::scheduler::{Schedule, Scheduler};
@@ -283,6 +284,7 @@ impl ExecutionState {
283284
.map(|t| t.span.clone())
284285
.unwrap_or_else(tracing::Span::current);
285286
let task_id = TaskId(state.tasks.len());
287+
let tag = state.get_tag_or_default_for_current_task();
286288
let clock = state.increment_clock_mut(); // Increment the parent's clock
287289
clock.extend(task_id); // and extend it with an entry for the new task
288290

@@ -294,6 +296,7 @@ impl ExecutionState {
294296
clock.clone(),
295297
parent_span,
296298
schedule_len,
299+
tag,
297300
);
298301

299302
state.tasks.push(task);
@@ -312,7 +315,7 @@ impl ExecutionState {
312315
{
313316
Self::with(|state| {
314317
let task_id = TaskId(state.tasks.len());
315-
318+
let tag = state.get_tag_or_default_for_current_task();
316319
let clock = if let Some(ref mut clock) = initial_clock {
317320
clock
318321
} else {
@@ -328,7 +331,7 @@ impl ExecutionState {
328331
.try_current()
329332
.map(|t| t.span.clone())
330333
.unwrap_or_else(tracing::Span::current);
331-
let task = Task::from_closure(f, stack_size, task_id, name, clock, parent_span, schedule_len);
334+
let task = Task::from_closure(f, stack_size, task_id, name, clock, parent_span, schedule_len, tag);
332335
state.tasks.push(task);
333336
task_id
334337
})
@@ -605,6 +608,27 @@ impl ExecutionState {
605608
}
606609
});
607610
}
611+
612+
// Sets the `tag` field of the current task.
613+
// Returns the `tag` which was there previously.
614+
fn set_tag_for_current_task_internal(&mut self, tag: Tag) -> Tag {
615+
self.current_mut().set_tag(tag)
616+
}
617+
618+
pub(crate) fn set_tag_for_current_task(tag: Tag) -> Tag {
619+
ExecutionState::with(|s| s.set_tag_for_current_task_internal(tag))
620+
}
621+
622+
fn get_tag_or_default_for_current_task(&self) -> Tag {
623+
match self.try_current() {
624+
Some(current) => current.get_tag(),
625+
None => Tag::default(),
626+
}
627+
}
628+
629+
pub(crate) fn get_tag_for_current_task() -> Tag {
630+
ExecutionState::with(|s| s.get_tag_or_default_for_current_task())
631+
}
608632
}
609633

610634
#[cfg(debug_assertions)]

src/runtime/task/mod.rs

+38-1
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,14 @@ pub(crate) struct Task {
5959

6060
local_storage: StorageMap,
6161
pub(super) span: tracing::Span,
62+
63+
// Arbitrarily settable tag which is inherited from the parent.
64+
tag: Tag,
6265
}
6366

6467
impl Task {
6568
/// Create a task from a continuation
69+
#[allow(clippy::too_many_arguments)]
6670
fn new<F>(
6771
f: F,
6872
stack_size: usize,
@@ -71,6 +75,7 @@ impl Task {
7175
clock: VectorClock,
7276
parent_span: tracing::Span,
7377
schedule_len: usize,
78+
tag: Tag,
7479
) -> Self
7580
where
7681
F: FnOnce() + Send + 'static,
@@ -100,9 +105,11 @@ impl Task {
100105
name,
101106
span,
102107
local_storage: StorageMap::new(),
108+
tag,
103109
}
104110
}
105111

112+
#[allow(clippy::too_many_arguments)]
106113
pub(crate) fn from_closure<F>(
107114
f: F,
108115
stack_size: usize,
@@ -111,13 +118,15 @@ impl Task {
111118
clock: VectorClock,
112119
parent_span: tracing::Span,
113120
schedule_len: usize,
121+
tag: Tag,
114122
) -> Self
115123
where
116124
F: FnOnce() + Send + 'static,
117125
{
118-
Self::new(f, stack_size, id, name, clock, parent_span, schedule_len)
126+
Self::new(f, stack_size, id, name, clock, parent_span, schedule_len, tag)
119127
}
120128

129+
#[allow(clippy::too_many_arguments)]
121130
pub(crate) fn from_future<F>(
122131
future: F,
123132
stack_size: usize,
@@ -126,6 +135,7 @@ impl Task {
126135
clock: VectorClock,
127136
parent_span: tracing::Span,
128137
schedule_len: usize,
138+
tag: Tag,
129139
) -> Self
130140
where
131141
F: Future<Output = ()> + Send + 'static,
@@ -147,6 +157,7 @@ impl Task {
147157
clock,
148158
parent_span,
149159
schedule_len,
160+
tag,
150161
)
151162
}
152163

@@ -335,6 +346,32 @@ impl Task {
335346
self.park_state.token_available = true;
336347
}
337348
}
349+
350+
pub(crate) fn get_tag(&self) -> Tag {
351+
self.tag
352+
}
353+
354+
/// Sets the `tag` field of the current task.
355+
/// Returns the `tag` which was there previously.
356+
pub(crate) fn set_tag(&mut self, tag: Tag) -> Tag {
357+
std::mem::replace(&mut self.tag, tag)
358+
}
359+
}
360+
361+
/// A `Tag` is an arbitrarily settable value for each task.
362+
#[derive(PartialEq, Eq, Clone, Copy, Debug, Default, Hash, PartialOrd, Ord)]
363+
pub struct Tag(u64);
364+
365+
impl From<u64> for Tag {
366+
fn from(tag: u64) -> Self {
367+
Tag(tag)
368+
}
369+
}
370+
371+
impl From<Tag> for u64 {
372+
fn from(tag: Tag) -> u64 {
373+
tag.0
374+
}
338375
}
339376

340377
#[derive(PartialEq, Eq, Clone, Copy, Debug)]

tests/basic/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ mod portfolio;
1616
mod replay;
1717
mod rwlock;
1818
mod shrink;
19+
mod tag;
1920
mod thread;
2021
mod timeout;
2122
mod uncontrolled_nondeterminism;

tests/basic/tag.rs

+135
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
use futures::future::join_all;
2+
use shuttle::{
3+
check_dfs, check_random,
4+
current::{get_tag_for_current_task, set_tag_for_current_task, Tag},
5+
future::block_on,
6+
thread,
7+
thread::JoinHandle,
8+
};
9+
use test_log::test;
10+
11+
fn spawn_some_futures_and_set_tag<F: (Fn(Tag, u64) -> Tag) + Send + Sync>(
12+
tag_on_entry: Tag,
13+
f: &'static F,
14+
num_threads: u64,
15+
) {
16+
let threads: Vec<_> = (0..num_threads)
17+
.map(|i| {
18+
shuttle::future::spawn(async move {
19+
assert_eq!(get_tag_for_current_task(), tag_on_entry);
20+
let new_tag = f(tag_on_entry, i);
21+
set_tag_for_current_task(new_tag);
22+
assert_eq!(get_tag_for_current_task(), new_tag);
23+
})
24+
})
25+
.collect();
26+
27+
block_on(join_all(threads));
28+
29+
assert_eq!(get_tag_for_current_task(), tag_on_entry);
30+
}
31+
32+
fn spawn_some_threads_and_set_tag<F: (Fn(Tag, u64) -> Tag) + Send + Sync>(
33+
tag_on_entry: Tag,
34+
f: &'static F,
35+
num_threads: u64,
36+
) {
37+
let threads: Vec<_> = (0..num_threads)
38+
.map(|i| {
39+
thread::spawn(move || {
40+
assert_eq!(get_tag_for_current_task(), tag_on_entry);
41+
let new_tag = f(tag_on_entry, i);
42+
set_tag_for_current_task(new_tag);
43+
assert_eq!(get_tag_for_current_task(), new_tag);
44+
})
45+
})
46+
.collect();
47+
48+
threads.into_iter().for_each(|t| t.join().expect("Failed"));
49+
50+
assert_eq!(get_tag_for_current_task(), tag_on_entry);
51+
}
52+
53+
fn spawn_threads_which_spawn_more_threads(
54+
tag_on_entry: Tag,
55+
num_threads_first_block: u64,
56+
num_threads_second_block: u64,
57+
) {
58+
let mut threads: Vec<_> = (0..num_threads_first_block)
59+
.map(|i| {
60+
thread::spawn(move || {
61+
assert_eq!(get_tag_for_current_task(), tag_on_entry);
62+
let new_tag = i.into();
63+
set_tag_for_current_task(new_tag);
64+
assert_eq!(get_tag_for_current_task(), new_tag);
65+
spawn_some_threads_and_set_tag(new_tag, &|_, _| 123.into(), 13);
66+
assert_eq!(get_tag_for_current_task(), new_tag);
67+
spawn_some_threads_and_set_tag(new_tag, &|_, x| (x * 13).into(), 7);
68+
assert_eq!(get_tag_for_current_task(), new_tag);
69+
spawn_some_threads_and_set_tag(new_tag, &|p, x| ((u64::from(p) << 4) + x).into(), 19);
70+
assert_eq!(get_tag_for_current_task(), new_tag);
71+
spawn_some_futures_and_set_tag(new_tag, &|p, x| ((u64::from(p) << 4) & x).into(), 17);
72+
assert_eq!(get_tag_for_current_task(), new_tag);
73+
})
74+
})
75+
.collect();
76+
77+
assert_eq!(get_tag_for_current_task(), tag_on_entry);
78+
79+
let new_tag_main_thread = 987654321.into();
80+
set_tag_for_current_task(new_tag_main_thread);
81+
assert_eq!(get_tag_for_current_task(), new_tag_main_thread);
82+
83+
threads.extend(
84+
(0..num_threads_second_block)
85+
.map(|i| {
86+
thread::spawn(move || {
87+
assert_eq!(get_tag_for_current_task(), new_tag_main_thread);
88+
let new_tag = i.into();
89+
set_tag_for_current_task(new_tag);
90+
assert_eq!(get_tag_for_current_task(), new_tag);
91+
spawn_some_threads_and_set_tag(new_tag, &|_, _| 123.into(), 13);
92+
assert_eq!(get_tag_for_current_task(), new_tag);
93+
spawn_some_threads_and_set_tag(new_tag, &|_, x| (x * 13).into(), 7);
94+
assert_eq!(get_tag_for_current_task(), new_tag);
95+
spawn_some_threads_and_set_tag(new_tag, &|p, x| ((u64::from(p) << 4) + x).into(), 19);
96+
assert_eq!(get_tag_for_current_task(), new_tag);
97+
spawn_some_futures_and_set_tag(new_tag, &|p, x| ((u64::from(p) << 4) & x).into(), 17);
98+
assert_eq!(get_tag_for_current_task(), new_tag);
99+
})
100+
})
101+
.collect::<Vec<_>>(),
102+
);
103+
104+
threads.into_iter().for_each(|t| t.join().expect("Failed"));
105+
106+
assert_eq!(get_tag_for_current_task(), new_tag_main_thread);
107+
}
108+
109+
#[test]
110+
fn threads_which_spawn_threads_which_spawn_threads() {
111+
check_random(|| spawn_threads_which_spawn_more_threads(Tag::default(), 3, 2), 10)
112+
}
113+
114+
fn spawn_thread_and_set_tag(tag_on_entry: Tag, new_tag: Tag) -> JoinHandle<u64> {
115+
thread::spawn(move || {
116+
assert_eq!(get_tag_for_current_task(), tag_on_entry);
117+
assert_eq!(set_tag_for_current_task(new_tag), tag_on_entry); // NOTE: Assertion with side effect
118+
assert_eq!(get_tag_for_current_task(), new_tag);
119+
new_tag.into()
120+
})
121+
}
122+
123+
fn spawn_and_join() {
124+
set_tag_for_current_task(42.into());
125+
let h1 = spawn_thread_and_set_tag(42.into(), 84.into());
126+
set_tag_for_current_task(50.into());
127+
let h2 = spawn_thread_and_set_tag(50.into(), 100.into());
128+
let results = [h1.join().unwrap(), h2.join().unwrap()];
129+
assert_eq!(results, [84, 100]);
130+
}
131+
132+
#[test]
133+
fn test_spawn_and_join() {
134+
check_dfs(spawn_and_join, None);
135+
}

0 commit comments

Comments
 (0)