Skip to content

Commit

Permalink
tokio: reduce LLVM code generation (#5859)
Browse files Browse the repository at this point in the history
  • Loading branch information
dullbananas committed Jul 15, 2023
1 parent 91ad76c commit 304d140
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 58 deletions.
4 changes: 1 addition & 3 deletions tokio/src/runtime/scheduler/multi_thread/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ impl Handle {
{
let (handle, notified) = me.shared.owned.bind(future, me.clone(), id);

if let Some(notified) = notified {
me.schedule_task(notified, false);
}
me.schedule_option_task_without_yield(notified);

handle
}
Expand Down
6 changes: 6 additions & 0 deletions tokio/src/runtime/scheduler/multi_thread/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,12 @@ impl Handle {
})
}

pub(super) fn schedule_option_task_without_yield(&self, task: Option<Notified>) {
if let Some(task) = task {
self.schedule_task(task, false);
}
}

fn schedule_local(&self, core: &mut Core, task: Notified, is_yield: bool) {
core.stats.inc_local_schedule_count();

Expand Down
71 changes: 50 additions & 21 deletions tokio/src/runtime/task/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,44 +211,66 @@ impl<T: Future, S: Schedule> Cell<T, S> {
/// Allocates a new task cell, containing the header, trailer, and core
/// structures.
pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>> {
// Separated into a non-generic function to reduce LLVM codegen
fn new_header(
state: State,
vtable: &'static Vtable,
#[cfg(all(tokio_unstable, feature = "tracing"))] tracing_id: Option<tracing::Id>,
) -> Header {
Header {
state,
queue_next: UnsafeCell::new(None),
vtable,
owner_id: UnsafeCell::new(0),
#[cfg(all(tokio_unstable, feature = "tracing"))]
tracing_id,
}
}

#[cfg(all(tokio_unstable, feature = "tracing"))]
let tracing_id = future.id();
let vtable = raw::vtable::<T, S>();
let result = Box::new(Cell {
header: Header {
header: new_header(
state,
queue_next: UnsafeCell::new(None),
vtable: raw::vtable::<T, S>(),
owner_id: UnsafeCell::new(0),
vtable,
#[cfg(all(tokio_unstable, feature = "tracing"))]
tracing_id,
},
),
core: Core {
scheduler,
stage: CoreStage {
stage: UnsafeCell::new(Stage::Running(future)),
},
task_id,
},
trailer: Trailer {
waker: UnsafeCell::new(None),
owned: linked_list::Pointers::new(),
},
trailer: Trailer::new(),
});

#[cfg(debug_assertions)]
{
let trailer_addr = (&result.trailer) as *const Trailer as usize;
let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(&result.header)) };
assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize);

let scheduler_addr = (&result.core.scheduler) as *const S as usize;
let scheduler_ptr =
unsafe { Header::get_scheduler::<S>(NonNull::from(&result.header)) };
assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize);

let id_addr = (&result.core.task_id) as *const Id as usize;
let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(&result.header)) };
assert_eq!(id_addr, id_ptr.as_ptr() as usize);
// Using a separate function for this code avoids instantiating it separately for every `T`.
unsafe fn check<S>(header: &Header, trailer: &Trailer, scheduler: &S, task_id: &Id) {
let trailer_addr = trailer as *const Trailer as usize;
let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(header)) };
assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize);

let scheduler_addr = scheduler as *const S as usize;
let scheduler_ptr = unsafe { Header::get_scheduler::<S>(NonNull::from(header)) };
assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize);

let id_addr = task_id as *const Id as usize;
let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(header)) };
assert_eq!(id_addr, id_ptr.as_ptr() as usize);
}
unsafe {
check(
&result.header,
&result.trailer,
&result.core.scheduler,
&result.core.task_id,
);
}
}

result
Expand Down Expand Up @@ -442,6 +464,13 @@ impl Header {
}

impl Trailer {
fn new() -> Self {
Trailer {
waker: UnsafeCell::new(None),
owned: linked_list::Pointers::new(),
}
}

pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) {
self.waker.with_mut(|ptr| {
*ptr = waker;
Expand Down
58 changes: 37 additions & 21 deletions tokio/src/runtime/task/harness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use crate::future::Future;
use crate::runtime::task::core::{Cell, Core, Header, Trailer};
use crate::runtime::task::state::{Snapshot, State};
use crate::runtime::task::waker::waker_ref;
use crate::runtime::task::{JoinError, Notified, RawTask, Schedule, Task};
use crate::runtime::task::{Id, JoinError, Notified, RawTask, Schedule, Task};

use std::any::Any;
use std::mem;
use std::mem::ManuallyDrop;
use std::panic;
Expand Down Expand Up @@ -192,6 +193,15 @@ where

match self.state().transition_to_running() {
TransitionToRunning::Success => {
// Separated to reduce LLVM codegen
fn transition_result_to_poll_future(result: TransitionToIdle) -> PollFuture {
match result {
TransitionToIdle::Ok => PollFuture::Done,
TransitionToIdle::OkNotified => PollFuture::Notified,
TransitionToIdle::OkDealloc => PollFuture::Dealloc,
TransitionToIdle::Cancelled => PollFuture::Complete,
}
}
let header_ptr = self.header_ptr();
let waker_ref = waker_ref::<T, S>(&header_ptr);
let cx = Context::from_waker(&waker_ref);
Expand All @@ -202,17 +212,13 @@ where
return PollFuture::Complete;
}

match self.state().transition_to_idle() {
TransitionToIdle::Ok => PollFuture::Done,
TransitionToIdle::OkNotified => PollFuture::Notified,
TransitionToIdle::OkDealloc => PollFuture::Dealloc,
TransitionToIdle::Cancelled => {
// The transition to idle failed because the task was
// cancelled during the poll.
cancel_task(self.core());
PollFuture::Complete
}
let transition_res = self.state().transition_to_idle();
if let TransitionToIdle::Cancelled = transition_res {
// The transition to idle failed because the task was
// cancelled during the poll.
cancel_task(self.core());
}
transition_result_to_poll_future(transition_res)
}
TransitionToRunning::Cancelled => {
cancel_task(self.core());
Expand Down Expand Up @@ -447,13 +453,16 @@ fn cancel_task<T: Future, S: Schedule>(core: &Core<T, S>) {
core.drop_future_or_output();
}));

core.store_output(Err(panic_result_to_join_error(core.task_id, res)));
}

fn panic_result_to_join_error(
task_id: Id,
res: Result<(), Box<dyn Any + Send + 'static>>,
) -> JoinError {
match res {
Ok(()) => {
core.store_output(Err(JoinError::cancelled(core.task_id)));
}
Err(panic) => {
core.store_output(Err(JoinError::panic(core.task_id, panic)));
}
Ok(()) => JoinError::cancelled(task_id),
Err(panic) => JoinError::panic(task_id, panic),
}
}

Expand Down Expand Up @@ -482,10 +491,7 @@ fn poll_future<T: Future, S: Schedule>(core: &Core<T, S>, cx: Context<'_>) -> Po
let output = match output {
Ok(Poll::Pending) => return Poll::Pending,
Ok(Poll::Ready(output)) => Ok(output),
Err(panic) => {
core.scheduler.unhandled_panic();
Err(JoinError::panic(core.task_id, panic))
}
Err(panic) => Err(panic_to_error(&core.scheduler, core.task_id, panic)),
};

// Catch and ignore panics if the future panics on drop.
Expand All @@ -499,3 +505,13 @@ fn poll_future<T: Future, S: Schedule>(core: &Core<T, S>, cx: Context<'_>) -> Po

Poll::Ready(())
}

#[cold]
fn panic_to_error<S: Schedule>(
scheduler: &S,
task_id: Id,
panic: Box<dyn Any + Send + 'static>,
) -> JoinError {
scheduler.unhandled_panic();
JoinError::panic(task_id, panic)
}
12 changes: 10 additions & 2 deletions tokio/src/runtime/task/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,15 @@ impl<S: 'static> OwnedTasks<S> {
T::Output: Send + 'static,
{
let (task, notified, join) = super::new_task(task, scheduler, id);
let notified = unsafe { self.bind_inner(task, notified) };
(join, notified)
}

/// The part of `bind` that's the same for every type of future.
unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>>
where
S: Schedule,
{
unsafe {
// safety: We just created the task, so we have exclusive access
// to the field.
Expand All @@ -108,10 +116,10 @@ impl<S: 'static> OwnedTasks<S> {
drop(lock);
drop(notified);
task.shutdown();
(join, None)
None
} else {
lock.list.push_front(task);
(join, Some(notified))
Some(notified)
}
}

Expand Down
26 changes: 15 additions & 11 deletions tokio/src/util/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@ cfg_trace! {
#[inline]
#[track_caller]
pub(crate) fn task<F>(task: F, kind: &'static str, name: Option<&str>, id: u64) -> Instrumented<F> {
#[track_caller]
fn get_span(kind: &'static str, name: Option<&str>, id: u64) -> tracing::Span {
let location = std::panic::Location::caller();
tracing::trace_span!(
target: "tokio::task",
"runtime.spawn",
%kind,
task.name = %name.unwrap_or_default(),
task.id = id,
loc.file = location.file(),
loc.line = location.line(),
loc.col = location.column(),
)
}
use tracing::instrument::Instrument;
let location = std::panic::Location::caller();
let span = tracing::trace_span!(
target: "tokio::task",
"runtime.spawn",
%kind,
task.name = %name.unwrap_or_default(),
task.id = id,
loc.file = location.file(),
loc.line = location.line(),
loc.col = location.column(),
);
let span = get_span(kind, name, id);
task.instrument(span)
}

Expand Down

0 comments on commit 304d140

Please sign in to comment.