Skip to content

Commit 98c08f8

Browse files
committed
switch to multitask
Signed-off-by: Marc-Antoine Perennou <Marc-Antoine@Perennou.com>
1 parent a6f6d04 commit 98c08f8

File tree

5 files changed

+121
-19
lines changed

5 files changed

+121
-19
lines changed

Cargo.toml

+9-3
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ default = [
2929
"blocking",
3030
"kv-log-macro",
3131
"log",
32+
"multitask",
3233
"num_cpus",
3334
"pin-project-lite",
34-
"smol",
3535
]
3636
docs = ["attributes", "unstable", "default"]
3737
unstable = [
@@ -56,7 +56,7 @@ alloc = [
5656
"futures-core/alloc",
5757
"pin-project-lite",
5858
]
59-
tokio02 = ["smol/tokio02"]
59+
tokio02 = ["tokio"]
6060

6161
[dependencies]
6262
async-attributes = { version = "1.1.1", optional = true }
@@ -81,7 +81,7 @@ surf = { version = "1.0.3", optional = true }
8181
[target.'cfg(not(target_os = "unknown"))'.dependencies]
8282
async-io = { version = "0.1.2", optional = true }
8383
blocking = { version = "0.4.6", optional = true }
84-
smol = { version = "0.1.17", optional = true }
84+
multitask = { version = "0.2.0", optional = true }
8585

8686
[target.'cfg(target_arch = "wasm32")'.dependencies]
8787
futures-timer = { version = "3.0.2", optional = true, features = ["wasm-bindgen"] }
@@ -91,6 +91,12 @@ futures-channel = { version = "0.3.4", optional = true }
9191
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
9292
wasm-bindgen-test = "0.3.10"
9393

94+
[dependencies.tokio]
95+
version = "0.2"
96+
default-features = false
97+
features = ["rt-threaded"]
98+
optional = true
99+
94100
[dev-dependencies]
95101
femme = "1.3.0"
96102
rand = "0.7.3"

src/task/builder.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::task::{Context, Poll};
77
use pin_project_lite::pin_project;
88

99
use crate::io;
10-
use crate::task::{JoinHandle, Task, TaskLocalsWrapper};
10+
use crate::task::{self, JoinHandle, Task, TaskLocalsWrapper};
1111

1212
/// Task builder that configures the settings of a new task.
1313
#[derive(Debug, Default)]
@@ -61,9 +61,9 @@ impl Builder {
6161
});
6262

6363
let task = wrapped.tag.task().clone();
64-
let smol_task = smol::Task::spawn(wrapped).into();
64+
let handle = task::executor::spawn(wrapped);
6565

66-
Ok(JoinHandle::new(smol_task, task))
66+
Ok(JoinHandle::new(handle, task))
6767
}
6868

6969
/// Spawns a task locally with the configured settings.
@@ -81,9 +81,9 @@ impl Builder {
8181
});
8282

8383
let task = wrapped.tag.task().clone();
84-
let smol_task = smol::Task::local(wrapped).into();
84+
let handle = task::executor::local(wrapped);
8585

86-
Ok(JoinHandle::new(smol_task, task))
86+
Ok(JoinHandle::new(handle, task))
8787
}
8888

8989
/// Spawns a task locally with the configured settings.
@@ -166,8 +166,8 @@ impl Builder {
166166
unsafe {
167167
TaskLocalsWrapper::set_current(&wrapped.tag, || {
168168
let res = if should_run {
169-
// The first call should use run.
170-
smol::run(wrapped)
169+
// The first call should run the executor
170+
task::executor::run(wrapped)
171171
} else {
172172
blocking::block_on(wrapped)
173173
};

src/task/executor.rs

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
use std::cell::RefCell;
2+
use std::future::Future;
3+
use std::task::{Context, Poll};
4+
5+
static GLOBAL_EXECUTOR: once_cell::sync::Lazy<multitask::Executor> = once_cell::sync::Lazy::new(multitask::Executor::new);
6+
7+
struct Executor {
8+
local_executor: multitask::LocalExecutor,
9+
parker: async_io::parking::Parker,
10+
}
11+
12+
thread_local! {
13+
static EXECUTOR: RefCell<Executor> = RefCell::new({
14+
let (parker, unparker) = async_io::parking::pair();
15+
let local_executor = multitask::LocalExecutor::new(move || unparker.unpark());
16+
Executor { local_executor, parker }
17+
});
18+
}
19+
20+
pub(crate) fn spawn<F, T>(future: F) -> multitask::Task<T>
21+
where
22+
F: Future<Output = T> + Send + 'static,
23+
T: Send + 'static,
24+
{
25+
GLOBAL_EXECUTOR.spawn(future)
26+
}
27+
28+
#[cfg(feature = "unstable")]
29+
pub(crate) fn local<F, T>(future: F) -> multitask::Task<T>
30+
where
31+
F: Future<Output = T> + 'static,
32+
T: 'static,
33+
{
34+
EXECUTOR.with(|executor| executor.borrow().local_executor.spawn(future))
35+
}
36+
37+
pub(crate) fn run<F, T>(future: F) -> T
38+
where
39+
F: Future<Output = T>,
40+
{
41+
enter(|| EXECUTOR.with(|executor| {
42+
let executor = executor.borrow();
43+
let unparker = executor.parker.unparker();
44+
let global_ticker = GLOBAL_EXECUTOR.ticker(move || unparker.unpark());
45+
let unparker = executor.parker.unparker();
46+
let waker = async_task::waker_fn(move || unparker.unpark());
47+
let cx = &mut Context::from_waker(&waker);
48+
pin_utils::pin_mut!(future);
49+
loop {
50+
if let Poll::Ready(res) = future.as_mut().poll(cx) {
51+
return res;
52+
}
53+
if let Ok(false) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| executor.local_executor.tick() || global_ticker.tick())) {
54+
executor.parker.park();
55+
}
56+
}
57+
}))
58+
}
59+
60+
/// Enters the tokio context if the `tokio` feature is enabled.
61+
fn enter<T>(f: impl FnOnce() -> T) -> T {
62+
#[cfg(not(feature = "tokio02"))]
63+
return f();
64+
65+
#[cfg(feature = "tokio02")]
66+
{
67+
use std::cell::Cell;
68+
use tokio::runtime::Runtime;
69+
70+
thread_local! {
71+
/// The level of nested `enter` calls we are in, to ensure that the outermost always
72+
/// has a runtime spawned.
73+
static NESTING: Cell<usize> = Cell::new(0);
74+
}
75+
76+
/// The global tokio runtime.
77+
static RT: once_cell::sync::Lazy<Runtime> = once_cell::sync::Lazy::new(|| Runtime::new().expect("cannot initialize tokio"));
78+
79+
NESTING.with(|nesting| {
80+
let res = if nesting.get() == 0 {
81+
nesting.replace(1);
82+
RT.enter(f)
83+
} else {
84+
nesting.replace(nesting.get() + 1);
85+
f()
86+
};
87+
nesting.replace(nesting.get() - 1);
88+
res
89+
})
90+
}
91+
}

src/task/join_handle.rs

+12-9
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub struct JoinHandle<T> {
1818
}
1919

2020
#[cfg(not(target_os = "unknown"))]
21-
type InnerHandle<T> = async_task::JoinHandle<T, ()>;
21+
type InnerHandle<T> = multitask::Task<T>;
2222
#[cfg(target_arch = "wasm32")]
2323
type InnerHandle<T> = futures_channel::oneshot::Receiver<T>;
2424

@@ -54,8 +54,7 @@ impl<T> JoinHandle<T> {
5454
#[cfg(not(target_os = "unknown"))]
5555
pub async fn cancel(mut self) -> Option<T> {
5656
let handle = self.handle.take().unwrap();
57-
handle.cancel();
58-
handle.await
57+
handle.cancel().await
5958
}
6059

6160
/// Cancel this task.
@@ -67,15 +66,19 @@ impl<T> JoinHandle<T> {
6766
}
6867
}
6968

69+
#[cfg(not(target_os = "unknown"))]
70+
impl<T> Drop for JoinHandle<T> {
71+
fn drop(&mut self) {
72+
if let Some(handle) = self.handle.take() {
73+
handle.detach();
74+
}
75+
}
76+
}
77+
7078
impl<T> Future for JoinHandle<T> {
7179
type Output = T;
7280

7381
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
74-
match Pin::new(&mut self.handle.as_mut().unwrap()).poll(cx) {
75-
Poll::Pending => Poll::Pending,
76-
Poll::Ready(output) => {
77-
Poll::Ready(output.expect("cannot await the result of a panicked task"))
78-
}
79-
}
82+
Pin::new(&mut self.handle.as_mut().unwrap()).poll(cx)
8083
}
8184
}

src/task/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ cfg_default! {
148148
mod block_on;
149149
mod builder;
150150
mod current;
151+
#[cfg(not(target_os = "unknown"))]
152+
mod executor;
151153
mod join_handle;
152154
mod sleep;
153155
#[cfg(not(target_os = "unknown"))]

0 commit comments

Comments
 (0)