Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rt: do not trace tasks while locking OwnedTasks #6036

Merged
merged 9 commits into from
Oct 6, 2023
8 changes: 8 additions & 0 deletions tokio/src/runtime/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,14 @@ cfg_taskdump! {
scheduler::Handle::MultiThreadAlt(_) => panic!("task dump not implemented for this runtime flavor"),
}
}

/// Produces `true` if the current task is being traced for a dump;
/// otherwise false. This function is only public for integration
/// testing purposes. Do not rely on it.
#[doc(hidden)]
pub fn is_tracing() -> bool {
super::task::trace::Context::is_tracing()
}
}

cfg_rt_multi_thread! {
Expand Down
18 changes: 18 additions & 0 deletions tokio/src/runtime/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,14 @@ impl<S: 'static> Task<S> {
fn header_ptr(&self) -> NonNull<Header> {
self.raw.header_ptr()
}

cfg_taskdump! {
pub(super) fn notify_for_tracing(&self) -> Notified<S> {
self.as_raw().state().transition_to_notified_for_tracing();
// SAFETY: `transition_to_notified_for_tracing` increments the refcount.
unsafe { Notified(Task::new(self.raw)) }
}
}
}

impl<S: 'static> Notified<S> {
Expand Down Expand Up @@ -444,6 +452,16 @@ impl<S: Schedule> UnownedTask<S> {
}
}

impl<S: 'static> Clone for Task<S> {
fn clone(&self) -> Task<S> {
// SAFETY: We increment the ref count.
unsafe {
self.raw.ref_inc();
Task::new(self.raw)
}
}
}
jswrenn marked this conversation as resolved.
Show resolved Hide resolved

impl<S: 'static> Drop for Task<S> {
fn drop(&mut self) {
// Decrement the ref count
Expand Down
3 changes: 2 additions & 1 deletion tokio/src/runtime/task/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ impl RawTask {

/// Increment the task's reference count.
///
/// Currently, this is used only when creating an `AbortHandle`.
/// Currently, this is used only when creating an `AbortHandle`,
/// and when cloning a `Task`.
jswrenn marked this conversation as resolved.
Show resolved Hide resolved
pub(super) fn ref_inc(self) {
self.header().state.ref_inc();
}
Expand Down
73 changes: 41 additions & 32 deletions tokio/src/runtime/task/trace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ mod tree;
use symbol::Symbol;
use tree::Tree;

use super::{Notified, OwnedTasks};
use super::{Notified, OwnedTasks, Schedule};

type Backtrace = Vec<BacktraceFrame>;
type SymbolTrace = Vec<Symbol>;
Expand Down Expand Up @@ -100,6 +100,16 @@ impl Context {
Self::try_with_current(|context| f(&context.collector)).expect(FAIL_NO_THREAD_LOCAL)
}
}

/// Produces `true` if the current task is being traced; otherwise false.
pub(crate) fn is_tracing() -> bool {
Self::with_current_collector(|maybe_collector| {
let collector = maybe_collector.take();
let result = collector.is_some();
maybe_collector.set(collector);
result
})
}
}

impl Trace {
Expand Down Expand Up @@ -268,22 +278,8 @@ pub(in crate::runtime) fn trace_current_thread(
drop(task);
}

// notify each task
let mut tasks = vec![];
owned.for_each(|task| {
// set the notified bit
task.as_raw().state().transition_to_notified_for_tracing();
// store the raw tasks into a vec
tasks.push(task.as_raw());
});

tasks
.into_iter()
.map(|task| {
let ((), trace) = Trace::capture(|| task.poll());
trace
})
.collect()
// precondition: We have drained the tasks from the injection queue.
trace_owned(owned)
}

cfg_rt_multi_thread! {
Expand Down Expand Up @@ -316,21 +312,34 @@ cfg_rt_multi_thread! {

drop(synced);

// notify each task
let mut traces = vec![];
owned.for_each(|task| {
// set the notified bit
task.as_raw().state().transition_to_notified_for_tracing();

// trace the task
let ((), trace) = Trace::capture(|| task.as_raw().poll());
traces.push(trace);
// precondition: we have drained the tasks from the local and injection
// queues.
trace_owned(owned)
}
}

// reschedule the task
let _ = task.as_raw().state().transition_to_notified_by_ref();
task.as_raw().schedule();
});
/// Trace the `OwnedTasks`.
///
/// # Preconditions
///
/// This helper presumes exclusive access to each task. The tasks must not exist
/// in any other queue.
fn trace_owned<S: Schedule>(owned: &OwnedTasks<S>) -> Vec<Trace> {
// notify each task
let mut tasks = vec![];
owned.for_each(|task| {
// notify the task (and thus make it poll-able) and stash it
tasks.push(task.notify_for_tracing());
// we do not poll it here since we hold a lock on `owned` and the task
// may complete and need to remove itself from `owned`.
});

traces
}
tasks
.into_iter()
.map(|task| {
let local_notified = owned.assert_owner(task);
let ((), trace) = Trace::capture(|| local_notified.run());
trace
})
.collect()
}
57 changes: 57 additions & 0 deletions tokio/tests/dump.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,60 @@ fn multi_thread() {
);
});
}

/// Regression tests for #6035.
///
/// These tests ensure that dumping will not deadlock if a future completes
/// during a trace.
mod future_completes_during_trace {
use super::*;

use core::future::{poll_fn, Future};

/// A future that completes only during a trace.
fn complete_during_trace() -> impl Future<Output = ()> + Send {
use std::task::Poll;
poll_fn(|cx| {
if Handle::is_tracing() {
Poll::Ready(())
} else {
cx.waker().wake_by_ref();
Poll::Pending
}
})
}

#[test]
fn current_thread() {
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();

async fn dump() {
let handle = Handle::current();
let _dump = handle.dump().await;
}

rt.block_on(async {
let _ = tokio::join!(tokio::spawn(complete_during_trace()), dump());
});
}

#[test]
fn multi_thread() {
let rt = runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();

async fn dump() {
let handle = Handle::current();
let _dump = handle.dump().await;
}

rt.block_on(async {
let _ = tokio::join!(tokio::spawn(complete_during_trace()), dump());
});
}
}
Loading