Skip to content

Commit

Permalink
time: use sharding for timer implementation (#6534)
Browse files Browse the repository at this point in the history
  • Loading branch information
wathenjiang authored May 22, 2024
1 parent e62c3e9 commit 1914e1e
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 70 deletions.
9 changes: 5 additions & 4 deletions tokio/src/runtime/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ impl Builder {
}
}

fn get_cfg(&self) -> driver::Cfg {
fn get_cfg(&self, workers: usize) -> driver::Cfg {
driver::Cfg {
enable_pause_time: match self.kind {
Kind::CurrentThread => true,
Expand All @@ -715,6 +715,7 @@ impl Builder {
enable_time: self.enable_time,
start_paused: self.start_paused,
nevents: self.nevents,
workers,
}
}

Expand Down Expand Up @@ -1095,7 +1096,7 @@ impl Builder {
use crate::runtime::scheduler::{self, CurrentThread};
use crate::runtime::{runtime::Scheduler, Config};

let (driver, driver_handle) = driver::Driver::new(self.get_cfg())?;
let (driver, driver_handle) = driver::Driver::new(self.get_cfg(1))?;

// Blocking pool
let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads);
Expand Down Expand Up @@ -1248,7 +1249,7 @@ cfg_rt_multi_thread! {

let core_threads = self.worker_threads.unwrap_or_else(num_cpus);

let (driver, driver_handle) = driver::Driver::new(self.get_cfg())?;
let (driver, driver_handle) = driver::Driver::new(self.get_cfg(core_threads))?;

// Create the blocking pool
let blocking_pool =
Expand Down Expand Up @@ -1295,7 +1296,7 @@ cfg_rt_multi_thread! {
use crate::runtime::scheduler::MultiThreadAlt;

let core_threads = self.worker_threads.unwrap_or_else(num_cpus);
let (driver, driver_handle) = driver::Driver::new(self.get_cfg())?;
let (driver, driver_handle) = driver::Driver::new(self.get_cfg(core_threads))?;

// Create the blocking pool
let blocking_pool =
Expand Down
12 changes: 8 additions & 4 deletions tokio/src/runtime/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::runtime::coop;

use std::cell::Cell;

#[cfg(any(feature = "rt", feature = "macros"))]
#[cfg(any(feature = "rt", feature = "macros", feature = "time"))]
use crate::util::rand::FastRand;

cfg_rt! {
Expand Down Expand Up @@ -57,7 +57,7 @@ struct Context {
#[cfg(feature = "rt")]
runtime: Cell<EnterRuntime>,

#[cfg(any(feature = "rt", feature = "macros"))]
#[cfg(any(feature = "rt", feature = "macros", feature = "time"))]
rng: Cell<Option<FastRand>>,

/// Tracks the amount of "work" a task may still do before yielding back to
Expand Down Expand Up @@ -100,7 +100,7 @@ tokio_thread_local! {
#[cfg(feature = "rt")]
runtime: Cell::new(EnterRuntime::NotEntered),

#[cfg(any(feature = "rt", feature = "macros"))]
#[cfg(any(feature = "rt", feature = "macros", feature = "time"))]
rng: Cell::new(None),

budget: Cell::new(coop::Budget::unconstrained()),
Expand All @@ -121,7 +121,11 @@ tokio_thread_local! {
}
}

#[cfg(any(feature = "macros", all(feature = "sync", feature = "rt")))]
#[cfg(any(
feature = "time",
feature = "macros",
all(feature = "sync", feature = "rt")
))]
pub(crate) fn thread_rng_n(n: u32) -> u32 {
CONTEXT.with(|ctx| {
let mut rng = ctx.rng.get().unwrap_or_else(FastRand::new);
Expand Down
8 changes: 6 additions & 2 deletions tokio/src/runtime/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub(crate) struct Cfg {
pub(crate) enable_pause_time: bool,
pub(crate) start_paused: bool,
pub(crate) nevents: usize,
pub(crate) workers: usize,
}

impl Driver {
Expand All @@ -48,7 +49,8 @@ impl Driver {

let clock = create_clock(cfg.enable_pause_time, cfg.start_paused);

let (time_driver, time_handle) = create_time_driver(cfg.enable_time, io_stack, &clock);
let (time_driver, time_handle) =
create_time_driver(cfg.enable_time, io_stack, &clock, cfg.workers);

Ok((
Self { inner: time_driver },
Expand Down Expand Up @@ -306,9 +308,10 @@ cfg_time! {
enable: bool,
io_stack: IoStack,
clock: &Clock,
workers: usize,
) -> (TimeDriver, TimeHandle) {
if enable {
let (driver, handle) = crate::runtime::time::Driver::new(io_stack, clock);
let (driver, handle) = crate::runtime::time::Driver::new(io_stack, clock, workers as u32);

(TimeDriver::Enabled { driver }, Some(handle))
} else {
Expand Down Expand Up @@ -361,6 +364,7 @@ cfg_not_time! {
_enable: bool,
io_stack: IoStack,
_clock: &Clock,
_workers: usize,
) -> (TimeDriver, TimeHandle) {
(io_stack, ())
}
Expand Down
5 changes: 5 additions & 0 deletions tokio/src/runtime/scheduler/multi_thread/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,11 @@ impl Context {
pub(crate) fn defer(&self, waker: &Waker) {
self.defer.defer(waker);
}

#[allow(dead_code)]
pub(crate) fn get_worker_index(&self) -> usize {
self.worker.index
}
}

impl Core {
Expand Down
5 changes: 5 additions & 0 deletions tokio/src/runtime/scheduler/multi_thread_alt/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,11 @@ impl Context {
fn shared(&self) -> &Shared {
&self.handle.shared
}

#[cfg_attr(not(feature = "time"), allow(dead_code))]
pub(crate) fn get_worker_index(&self) -> usize {
self.index
}
}

impl Core {
Expand Down
37 changes: 35 additions & 2 deletions tokio/src/runtime/time/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicU64;
use crate::loom::sync::atomic::Ordering;

use crate::runtime::context;
use crate::runtime::scheduler;
use crate::sync::AtomicWaker;
use crate::time::Instant;
Expand Down Expand Up @@ -328,6 +329,8 @@ pub(super) type EntryList = crate::util::linked_list::LinkedList<TimerShared, Ti
///
/// Note that this structure is located inside the `TimerEntry` structure.
pub(crate) struct TimerShared {
/// The shard id. We should never change it.
shard_id: u32,
/// A link within the doubly-linked list of timers on a particular level and
/// slot. Valid only if state is equal to Registered.
///
Expand Down Expand Up @@ -368,8 +371,9 @@ generate_addr_of_methods! {
}

impl TimerShared {
pub(super) fn new() -> Self {
pub(super) fn new(shard_id: u32) -> Self {
Self {
shard_id,
cached_when: AtomicU64::new(0),
pointers: linked_list::Pointers::new(),
state: StateCell::default(),
Expand Down Expand Up @@ -438,6 +442,11 @@ impl TimerShared {
pub(super) fn might_be_registered(&self) -> bool {
self.state.might_be_registered()
}

/// Gets the shard id.
pub(super) fn shard_id(&self) -> u32 {
self.shard_id
}
}

unsafe impl linked_list::Link for TimerShared {
Expand Down Expand Up @@ -485,8 +494,10 @@ impl TimerEntry {
fn inner(&self) -> &TimerShared {
let inner = unsafe { &*self.inner.get() };
if inner.is_none() {
let shard_size = self.driver.driver().time().inner.get_shard_size();
let shard_id = generate_shard_id(shard_size);
unsafe {
*self.inner.get() = Some(TimerShared::new());
*self.inner.get() = Some(TimerShared::new(shard_id));
}
}
return inner.as_ref().unwrap();
Expand Down Expand Up @@ -643,3 +654,25 @@ impl Drop for TimerEntry {
unsafe { Pin::new_unchecked(self) }.as_mut().cancel();
}
}

// Generates a shard id. If current thread is a worker thread, we use its worker index as a shard id.
// Otherwise, we use a random number generator to obtain the shard id.
cfg_rt! {
fn generate_shard_id(shard_size: u32) -> u32 {
let id = context::with_scheduler(|ctx| match ctx {
Some(scheduler::Context::CurrentThread(_ctx)) => 0,
#[cfg(feature = "rt-multi-thread")]
Some(scheduler::Context::MultiThread(ctx)) => ctx.get_worker_index() as u32,
#[cfg(all(tokio_unstable, feature = "rt-multi-thread"))]
Some(scheduler::Context::MultiThreadAlt(ctx)) => ctx.get_worker_index() as u32,
None => context::thread_rng_n(shard_size),
});
id % shard_size
}
}

cfg_not_rt! {
fn generate_shard_id(shard_size: u32) -> u32 {
context::thread_rng_n(shard_size)
}
}
Loading

0 comments on commit 1914e1e

Please sign in to comment.