diff --git a/Cargo.toml b/Cargo.toml index 5761935a..5186db81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,6 +81,9 @@ wasm-bindgen-test = "0.3" [build-dependencies] rustversion = "1.0" +[patch.crates-io] +tokio = {git = "https://github.com/tokio-rs/tokio", branch = "v0.2.x"} + [profile.bench] codegen-units = 1 debug = 2 diff --git a/amadeus-serde/src/csv.rs b/amadeus-serde/src/csv.rs index b74995af..34df5283 100644 --- a/amadeus-serde/src/csv.rs +++ b/amadeus-serde/src/csv.rs @@ -1,5 +1,3 @@ -#![allow(clippy::unsafe_derive_deserialize)] // https://github.com/rust-lang/rust-clippy/issues/5789 - use csv::Error as InternalCsvError; use educe::Educe; use futures::{pin_mut, stream, AsyncReadExt, FutureExt, Stream, StreamExt}; diff --git a/amadeus-serde/src/json.rs b/amadeus-serde/src/json.rs index 161176c2..9b72eae1 100644 --- a/amadeus-serde/src/json.rs +++ b/amadeus-serde/src/json.rs @@ -1,5 +1,3 @@ -#![allow(clippy::unsafe_derive_deserialize)] // https://github.com/rust-lang/rust-clippy/issues/5789 - use educe::Educe; use futures::{pin_mut, stream, AsyncReadExt, FutureExt, Stream, StreamExt}; use serde::{Deserialize, Serialize}; diff --git a/benches/in_memory.rs b/benches/in_memory.rs index ecdf28f1..bc13d27b 100644 --- a/benches/in_memory.rs +++ b/benches/in_memory.rs @@ -64,6 +64,7 @@ fn run(b: &mut Bencher, bytes: u64, mut task: impl FnMut() -> F) where F: Future, { + let _ = rayon::ThreadPoolBuilder::new().build_global(); RT.enter(|| { let _ = Lazy::force(&POOL); b.bytes = bytes; diff --git a/src/pool/thread.rs b/src/pool/thread.rs index fc945491..e6f85f00 100644 --- a/src/pool/thread.rs +++ b/src/pool/thread.rs @@ -79,13 +79,12 @@ impl ThreadPool { T: Send + 'a, { #[cfg(not(target_arch = "wasm32"))] - return Guard::new( - self.0 - .pool - .spawn_pinned_unchecked(task) - .map_err(JoinError::into_panic) - .map_err(Panicked::from), - ); + return self + .0 + .pool + .spawn_pinned_unchecked(task) + .map_err(JoinError::into_panic) + .map_err(Panicked::from); #[cfg(target_arch = "wasm32")] { let _self = self; @@ -104,10 +103,10 @@ impl ThreadPool { .map_err(Into::into) .remote_handle(); wasm_bindgen_futures::spawn_local(remote); - Guard::new(remote_handle.map_ok(|t| { + remote_handle.map_ok(|t| { let t: *mut dyn Send = Box::into_raw(t); *Box::from_raw(t as *mut T) - })) + }) } } } @@ -125,39 +124,6 @@ impl Clone for ThreadPool { impl UnwindSafe for ThreadPool {} impl RefUnwindSafe for ThreadPool {} -#[pin_project(PinnedDrop)] -struct Guard(#[pin] Option); -impl Guard { - fn new(f: F) -> Self { - Self(Some(f)) - } -} -impl Future for Guard -where - F: Future, -{ - type Output = F::Output; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - match self.as_mut().project().0.as_pin_mut() { - Some(fut) => { - let output = ready!(fut.poll(cx)); - self.project().0.set(None); - Poll::Ready(output) - } - None => Poll::Pending, - } - } -} -#[pinned_drop] -impl PinnedDrop for Guard { - fn drop(self: Pin<&mut Self>) { - if self.project().0.is_some() { - panic!("dropped before finished polling!"); - } - } -} - fn _assert() { let _ = assert_sync_and_send::; } @@ -166,10 +132,12 @@ fn _assert() { #[cfg(not(target_arch = "wasm32"))] mod pool { use async_channel::{bounded, Sender}; - use futures::{future::RemoteHandle, FutureExt}; + use futures::{ + future::{join_all, RemoteHandle}, FutureExt + }; use std::{any::Any, future::Future, mem, panic::AssertUnwindSafe, pin::Pin}; use tokio::{ - runtime::Handle, task::{JoinError, LocalSet} + runtime::Handle, task, task::{JoinError, JoinHandle, LocalSet} }; type Request = Box Box> + Send>; @@ -177,30 +145,34 @@ mod pool { #[derive(Debug)] pub(super) struct Pool { - sender: Sender<(Request, Sender>)>, + sender: Option>)>>, + threads: Vec>, } impl Pool { pub(super) fn new(threads: usize) -> Self { let handle = Handle::current(); let handle1 = handle.clone(); let (sender, receiver) = bounded::<(Request, Sender>)>(1); - for _ in 0..threads { - let receiver = receiver.clone(); - let handle = handle.clone(); - let _ = handle1.spawn_blocking(move || { - let local = LocalSet::new(); - handle.block_on(local.run_until(async { - while let Ok((task, sender)) = receiver.recv().await { - let _ = local.spawn_local(async move { - let (remote, remote_handle) = Pin::from(task()).remote_handle(); - let _ = sender.send(remote_handle).await; - remote.await; - }); - } - })) - }); - } - Self { sender } + let threads = (0..threads) + .map(|_| { + let receiver = receiver.clone(); + let handle = handle.clone(); + handle1.spawn_blocking(move || { + let local = LocalSet::new(); + handle.block_on(local.run_until(async { + while let Ok((task, sender)) = receiver.recv().await { + let _ = local.spawn_local(async move { + let (remote, remote_handle) = Pin::from(task()).remote_handle(); + let _ = sender.send(remote_handle).await; + remote.await; + }); + } + })) + }) + }) + .collect(); + let sender = Some(sender); + Self { sender, threads } } pub(super) fn spawn_pinned( &self, task: F, @@ -210,7 +182,7 @@ mod pool { Fut: Future + 'static, T: Send + 'static, { - let sender = self.sender.clone(); + let sender = self.sender.as_ref().unwrap().clone(); async move { let task: Request = Box::new(|| { Box::new( @@ -236,7 +208,7 @@ mod pool { Fut: Future + 'a, T: Send + 'a, { - let sender = self.sender.clone(); + let sender = self.sender.as_ref().unwrap().clone(); async move { let task: Box Box> + Send> = Box::new(|| { @@ -264,6 +236,18 @@ mod pool { } } } + impl Drop for Pool { + fn drop(&mut self) { + let _ = self.sender.take().unwrap(); + task::block_in_place(|| { + let handle = Handle::current(); + handle.block_on(join_all(mem::take(&mut self.threads))) + }) + .into_iter() + .collect::>() + .unwrap(); + } + } #[cfg(test)] mod tests { @@ -274,8 +258,8 @@ mod pool { atomic::{AtomicUsize, Ordering}, Arc }; - #[tokio::test] - async fn spawn_pinned_() { + #[tokio::test(threaded_scheduler)] + async fn spawn_pinned() { const TASKS: usize = 1000; const ITERS: usize = 1000; const THREADS: usize = 4; diff --git a/src/source.rs b/src/source.rs index fec23364..d7ba9f25 100644 --- a/src/source.rs +++ b/src/source.rs @@ -1,5 +1,3 @@ -#![allow(clippy::unsafe_derive_deserialize)] - use ::serde::{Deserialize, Serialize}; use derive_new::new; use futures::Stream; @@ -206,6 +204,7 @@ where } } +#[allow(clippy::unsafe_derive_deserialize)] #[pin_project] #[derive(Serialize, Deserialize)] #[serde(transparent)] diff --git a/tests/cloudfront.rs b/tests/cloudfront.rs index 804d39bd..986b5d4e 100644 --- a/tests/cloudfront.rs +++ b/tests/cloudfront.rs @@ -3,7 +3,7 @@ use amadeus::prelude::*; use std::time::SystemTime; -#[tokio::test] +#[tokio::test(threaded_scheduler)] async fn cloudfront() { let pool = &ThreadPool::new(None).unwrap(); diff --git a/tests/commoncrawl.rs b/tests/commoncrawl.rs index 4a87a4b2..5cf95000 100644 --- a/tests/commoncrawl.rs +++ b/tests/commoncrawl.rs @@ -15,7 +15,7 @@ use std::time::{Duration, SystemTime}; use amadeus::{data::Webpage, prelude::*}; -#[tokio::test] +#[tokio::test(threaded_scheduler)] async fn commoncrawl() { let start = SystemTime::now(); diff --git a/tests/csv.rs b/tests/csv.rs index b1b56a40..319518e9 100644 --- a/tests/csv.rs +++ b/tests/csv.rs @@ -4,7 +4,7 @@ use std::{path::PathBuf, time::SystemTime}; use amadeus::prelude::*; -#[tokio::test] +#[tokio::test(threaded_scheduler)] async fn csv() { let start = SystemTime::now(); diff --git a/tests/into_par_stream.rs b/tests/into_par_stream.rs index d20dbbfb..319cfbc2 100644 --- a/tests/into_par_stream.rs +++ b/tests/into_par_stream.rs @@ -2,7 +2,7 @@ use either::Either; use amadeus::prelude::*; -#[tokio::test] +#[tokio::test(threaded_scheduler)] async fn into_par_stream() { let pool = &ThreadPool::new(None).unwrap(); diff --git a/tests/json.rs b/tests/json.rs index 93a26716..2a28164e 100644 --- a/tests/json.rs +++ b/tests/json.rs @@ -4,7 +4,7 @@ use std::{path::PathBuf, time::SystemTime}; use amadeus::prelude::*; -#[tokio::test] +#[tokio::test(threaded_scheduler)] async fn json() { let start = SystemTime::now(); diff --git a/tests/panic.rs b/tests/panic.rs index 622dcca6..a4fe2e29 100644 --- a/tests/panic.rs +++ b/tests/panic.rs @@ -3,7 +3,7 @@ use std::{panic, panic::AssertUnwindSafe, time::SystemTime}; use amadeus::prelude::*; -#[tokio::test] +#[tokio::test(threaded_scheduler)] async fn panic() { let start = SystemTime::now(); diff --git a/tests/parquet.rs b/tests/parquet.rs index c48da012..b3a9ee9b 100644 --- a/tests/parquet.rs +++ b/tests/parquet.rs @@ -9,7 +9,7 @@ use std::{collections::HashMap, path::PathBuf, time::SystemTime}; use amadeus::prelude::*; -#[tokio::test] +#[tokio::test(threaded_scheduler)] async fn parquet() { let start = SystemTime::now(); diff --git a/tests/parquet_dist.rs b/tests/parquet_dist.rs index 5ef4a762..7d68d6c5 100644 --- a/tests/parquet_dist.rs +++ b/tests/parquet_dist.rs @@ -1,4 +1,4 @@ -#![type_length_limit = "1572864"] +#![type_length_limit = "2073124"] #![allow(clippy::cognitive_complexity, clippy::type_complexity)] #[cfg(feature = "constellation")] diff --git a/tests/postgres.rs b/tests/postgres.rs index 18af1e32..83c4a0fa 100644 --- a/tests/postgres.rs +++ b/tests/postgres.rs @@ -4,7 +4,7 @@ use std::time::SystemTime; use amadeus::prelude::*; -#[tokio::test] +#[tokio::test(threaded_scheduler)] async fn postgres() { let start = SystemTime::now(); diff --git a/tests/threads.rs b/tests/threads.rs index f38f2582..af4949b0 100644 --- a/tests/threads.rs +++ b/tests/threads.rs @@ -7,7 +7,7 @@ use tokio::time::delay_for as sleep; use amadeus::dist::prelude::*; -#[tokio::test] +#[tokio::test(threaded_scheduler)] async fn threads() { let start = SystemTime::now();