Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

task: various small improvements to LocalKey #4795

Merged
merged 8 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 211 additions & 84 deletions tokio/src/task/task_local.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use pin_project_lite::pin_project;
use std::cell::RefCell;
use std::error::Error;
use std::future::Future;
Expand Down Expand Up @@ -79,7 +78,7 @@ macro_rules! __task_local_inner {

/// A key for task-local data.
///
/// This type is generated by the `task_local!` macro.
/// This type is generated by the [`task_local!`] macro.
///
/// Unlike [`std::thread::LocalKey`], `tokio::task::LocalKey` will
/// _not_ lazily initialize the value on first access. Instead, the
Expand Down Expand Up @@ -107,7 +106,9 @@ macro_rules! __task_local_inner {
/// }).await;
/// # }
/// ```
///
/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey
/// [`task_local!`]: ../macro.task_local.html
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct LocalKey<T: 'static> {
#[doc(hidden)]
Expand All @@ -119,6 +120,11 @@ impl<T: 'static> LocalKey<T> {
///
/// On completion of `scope`, the task-local will be dropped.
///
/// ### Panics
///
/// If you poll the returned future inside a call to [`with`] or
/// [`try_with`] on the same `LocalKey`, then the call to `poll` will panic.
///
/// ### Examples
///
/// ```
Expand All @@ -132,14 +138,17 @@ impl<T: 'static> LocalKey<T> {
/// }).await;
/// # }
/// ```
///
/// [`with`]: fn@Self::with
/// [`try_with`]: fn@Self::try_with
pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F>
where
F: Future,
{
TaskLocalFuture {
local: self,
slot: Some(value),
future: f,
future: Some(f),
_pinned: PhantomPinned,
}
}
Expand All @@ -148,6 +157,11 @@ impl<T: 'static> LocalKey<T> {
///
/// On completion of `scope`, the task-local will be dropped.
///
/// ### Panics
///
/// This method panics if called inside a call to [`with`] or [`try_with`]
/// on the same `LocalKey`.
///
/// ### Examples
///
/// ```
Expand All @@ -161,34 +175,85 @@ impl<T: 'static> LocalKey<T> {
/// });
/// # }
/// ```
///
/// [`with`]: fn@Self::with
/// [`try_with`]: fn@Self::try_with
#[track_caller]
pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R
where
F: FnOnce() -> R,
{
let scope = TaskLocalFuture {
local: self,
slot: Some(value),
future: (),
_pinned: PhantomPinned,
};
crate::pin!(scope);
scope.with_task(|_| f())
let mut value = Some(value);
match self.scope_inner(&mut value, f) {
Ok(res) => res,
Err(ScopeInnerErr::BorrowError) => {
panic!("sync_scope called while Task Local Storage is borrowed")
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
}
Err(ScopeInnerErr::AccessError) => {
panic!("cannot access a Task Local Storage value during or after destruction")
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

fn scope_inner<F, R>(&'static self, slot: &mut Option<T>, f: F) -> Result<R, ScopeInnerErr>
where
F: FnOnce() -> R,
{
struct Guard<'a, T: 'static> {
local: &'static LocalKey<T>,
slot: &'a mut Option<T>,
}

impl<'a, T: 'static> Drop for Guard<'a, T> {
fn drop(&mut self) {
// This should not panic.
//
// We know that the RefCell was not borrowed before the call to
// `scope_inner`, so the only way for this to panic is if the
// closure has created but not destroyed a RefCell guard.
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
// However, we never give user-code access to the guards, so
// there's no way for user-code to forget to destroy a guard.
//
// The call to `with` also should not panic, since the
// thread-local wasn't destroyed when we first called
// `scope_inner`, and it shouldn't have gotten destroyed since
// then.
self.local.inner.with(|inner| {
let mut ref_mut = inner.borrow_mut();
std::mem::swap(self.slot, &mut *ref_mut);
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
});
}
}

self.inner.try_with(|inner| {
inner
.try_borrow_mut()
.map(|mut ref_mut| std::mem::swap(slot, &mut *ref_mut))
})??;

let guard = Guard { local: self, slot };

let res = f();

drop(guard);

Ok(res)
}

/// Accesses the current task-local and runs the provided closure.
///
/// # Panics
///
/// This function will panic if not called within the context
/// of a future containing a task-local with the corresponding key.
/// This function will panic if the task local doesn't have a value set.
#[track_caller]
pub fn with<F, R>(&'static self, f: F) -> R
where
F: FnOnce(&T) -> R,
{
self.try_with(f).expect(
"cannot access a Task Local Storage value \
without setting it via `LocalKey::set`",
)
match self.try_with(f) {
Ok(res) => res,
Err(_) => panic!("cannot access a Task Local Storage value without setting it first"),
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// Accesses the current task-local and runs the provided closure.
Expand All @@ -200,19 +265,31 @@ impl<T: 'static> LocalKey<T> {
where
F: FnOnce(&T) -> R,
{
self.inner.with(|v| {
if let Some(val) = v.borrow().as_ref() {
Ok(f(val))
} else {
Err(AccessError { _private: () })
}
})
// If called after the thread-local storing the task-local is destroyed,
// then we are outside of a closure where the task-local is set.
//
// Therefore, it is correct to return an AccessError if `try_with`
// returns an error.
let try_with_res = self.inner.try_with(|v| {
// This call to `borrow` cannot panic because no user-defined code
// runs while a `borrow_mut` call is active.
v.borrow().as_ref().map(f)
});

match try_with_res {
Ok(Some(res)) => Ok(res),
Ok(None) | Err(_) => Err(AccessError { _private: () }),
}
}
}

impl<T: Copy + 'static> LocalKey<T> {
/// Returns a copy of the task-local value
/// if the task-local value implements `Copy`.
///
/// # Panics
///
/// This function will panic if the task local doesn't have a value set.
pub fn get(&'static self) -> T {
self.with(|v| *v)
}
Expand All @@ -224,76 +301,109 @@ impl<T: 'static> fmt::Debug for LocalKey<T> {
}
}

pin_project! {
/// A future that sets a value `T` of a task local for the future `F` during
/// its execution.
///
/// The value of the task-local must be `'static` and will be dropped on the
/// completion of the future.
///
/// Created by the function [`LocalKey::scope`](self::LocalKey::scope).
///
/// ### Examples
///
/// ```
/// # async fn dox() {
/// tokio::task_local! {
/// static NUMBER: u32;
/// }
///
/// NUMBER.scope(1, async move {
/// println!("task local value: {}", NUMBER.get());
/// }).await;
/// # }
/// ```
pub struct TaskLocalFuture<T, F>
where
T: 'static
{
local: &'static LocalKey<T>,
slot: Option<T>,
#[pin]
future: F,
#[pin]
_pinned: PhantomPinned,
}
/// A future that sets a value `T` of a task local for the future `F` during
/// its execution.
///
/// The value of the task-local must be `'static` and will be dropped on the
/// completion of the future.
///
/// Created by the function [`LocalKey::scope`](self::LocalKey::scope).
///
/// ### Examples
///
/// ```
/// # async fn dox() {
/// tokio::task_local! {
/// static NUMBER: u32;
/// }
///
/// NUMBER.scope(1, async move {
/// println!("task local value: {}", NUMBER.get());
/// }).await;
/// # }
/// ```
// Doesn't use pin_project due to custom Drop.
pub struct TaskLocalFuture<T, F>
where
T: 'static,
{
local: &'static LocalKey<T>,
slot: Option<T>,
future: Option<F>,
_pinned: PhantomPinned,
}

impl<T: 'static, F> TaskLocalFuture<T, F> {
fn with_task<F2: FnOnce(Pin<&mut F>) -> R, R>(self: Pin<&mut Self>, f: F2) -> R {
struct Guard<'a, T: 'static> {
local: &'static LocalKey<T>,
slot: &'a mut Option<T>,
prev: Option<T>,
}

impl<T> Drop for Guard<'_, T> {
fn drop(&mut self) {
let value = self.local.inner.with(|c| c.replace(self.prev.take()));
*self.slot = value;
}
}
impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
type Output = F::Output;

let project = self.project();
let val = project.slot.take();
#[track_caller]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// safety: The TaskLocalFuture struct is `!Unpin` so there is no way to
// move `self.future` from now on.
let this = unsafe { Pin::into_inner_unchecked(self) };
let mut future_opt = unsafe { Pin::new_unchecked(&mut this.future) };

let prev = project.local.inner.with(|c| c.replace(val));
let res =
this.local
.scope_inner(&mut this.slot, || match future_opt.as_mut().as_pin_mut() {
Some(fut) => {
let res = fut.poll(cx);
if res.is_ready() {
future_opt.set(None);
}
Some(res)
}
None => None,
});

let _guard = Guard {
prev,
slot: project.slot,
local: *project.local,
};
match res {
Ok(Some(res)) => res,
Ok(None) => panic!("TaskLocalFuture polled after completion"),
Err(ScopeInnerErr::BorrowError) => {
panic!("TaskLocalFuture::poll called while task local is borrowed")
}
Err(ScopeInnerErr::AccessError) => {
panic!("cannot access a Task Local Storage value during or after destruction")
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
}
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

f(project.future)
impl<T: 'static, F> Drop for TaskLocalFuture<T, F> {
fn drop(&mut self) {
if std::mem::needs_drop::<F>() && self.future.is_some() {
// Drop the future while the task-local is set, if possible. Otherwise
// the future is dropped normally when the `Option<F>` field drops.
let future = &mut self.future;
let _ = self.local.scope_inner(&mut self.slot, || {
*future = None;
});
}
}
}

impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
type Output = F::Output;
impl<T: 'static, F> fmt::Debug for TaskLocalFuture<T, F>
where
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
/// Format the Option without Some.
struct TransparentOption<'a, T> {
value: &'a Option<T>,
}
impl<'a, T: fmt::Debug> fmt::Debug for TransparentOption<'a, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.value.as_ref() {
Some(value) => value.fmt(f),
// Hitting the None branch should not be possible.
None => f.pad("<missing>"),
}
}
}

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.with_task(|f| f.poll(cx))
f.debug_struct("TaskLocalFuture")
.field("value", &TransparentOption { value: &self.slot })
.finish()
}
}

Expand All @@ -316,3 +426,20 @@ impl fmt::Display for AccessError {
}

impl Error for AccessError {}

enum ScopeInnerErr {
BorrowError,
AccessError,
}

impl From<std::cell::BorrowMutError> for ScopeInnerErr {
fn from(_: std::cell::BorrowMutError) -> Self {
Self::BorrowError
}
}

impl From<std::thread::AccessError> for ScopeInnerErr {
fn from(_: std::thread::AccessError) -> Self {
Self::AccessError
}
}
Loading