Skip to content

Commit

Permalink
Switch to a single threaded tokio executor, remove some unsafe
Browse files Browse the repository at this point in the history
  • Loading branch information
pkolaczk committed Aug 17, 2024
1 parent 14da8a0 commit 6899b68
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 56 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ status-line = "0.2.0"
strum = { version = "0.26", features = ["derive"] }
time = "0.3"
thiserror = "1.0.26"
tokio = { version = "1", features = ["rt", "rt-multi-thread", "time", "parking_lot", "signal"] }
tokio = { version = "1", features = ["rt", "time", "parking_lot", "signal"] }
tokio-stream = "0.1"
tracing = "0.1"
tracing-appender = "0.2"
Expand Down
63 changes: 35 additions & 28 deletions src/exec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::runtime::Builder;
use tokio::signal::ctrl_c;
use tokio::task::LocalSet;
use tokio::time::MissedTickBehavior;
use tokio_stream::wrappers::IntervalStream;

Expand Down Expand Up @@ -107,35 +109,40 @@ fn spawn_stream(
) -> Receiver<Result<WorkloadStats>> {
let (tx, rx) = channel(1);

tokio::spawn(async move {
match rate {
Some(rate) => {
let stream = interval_stream(rate);
run_stream(
stream,
workload,
iter_counter,
concurrency,
sampling,
progress,
tx,
)
.await
let rt = Builder::new_current_thread().enable_all().build().unwrap();
std::thread::spawn(move || {
let local = LocalSet::new();
local.spawn_local(async move {
match rate {
Some(rate) => {
let stream = interval_stream(rate);
run_stream(
stream,
workload,
iter_counter,
concurrency,
sampling,
progress,
tx,
)
.await
}
None => {
let stream = futures::stream::repeat_with(|| ());
run_stream(
stream,
workload,
iter_counter,
concurrency,
sampling,
progress,
tx,
)
.await
}
}
None => {
let stream = futures::stream::repeat_with(|| ());
run_stream(
stream,
workload,
iter_counter,
concurrency,
sampling,
progress,
tx,
)
.await
}
}
});
rt.block_on(local);
});
rx
}
Expand Down
9 changes: 3 additions & 6 deletions src/exec/workload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,14 @@ impl Program {
/// If execution fails, emits diagnostic messages, e.g. stacktrace to standard error stream.
/// Also signals an error if the function execution succeeds, but the function returns
/// an error value.
pub async fn async_call(
&self,
fun: &FnRef,
args: impl Args + Send,
) -> Result<Value, LatteError> {
pub async fn async_call(&self, fun: &FnRef, args: impl Args) -> Result<Value, LatteError> {
let handle_err = |e: VmError| {
let mut out = StandardStream::stderr(ColorChoice::Auto);
let _ = e.emit(&mut out, &self.sources);
LatteError::ScriptExecError(fun.name.to_string(), e)
};
let execution = self.vm().send_execute(fun.hash, args).map_err(handle_err)?;
let mut vm = self.vm();
let mut execution = vm.execute(fun.hash, args).map_err(handle_err)?;
let result = execution
.async_complete()
.await
Expand Down
18 changes: 3 additions & 15 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,15 +507,8 @@ fn edit_workload(workload: PathBuf) -> Result<()> {
Ok(())
}

fn init_runtime(thread_count: usize) -> std::io::Result<Runtime> {
if thread_count == 1 {
Builder::new_current_thread().enable_all().build()
} else {
Builder::new_multi_thread()
.worker_threads(thread_count)
.enable_all()
.build()
}
fn init_runtime() -> std::io::Result<Runtime> {
Builder::new_current_thread().enable_all().build()
}

fn setup_logging(run_id: &str, config: &AppConfig) -> Result<WorkerGuard> {
Expand Down Expand Up @@ -560,12 +553,7 @@ fn main() {
};

let command = config.command;
let thread_count = match &command {
Command::Run(cmd) => cmd.threads.get(),
Command::Load(cmd) => cmd.threads.get(),
_ => 1,
};
let runtime = init_runtime(thread_count);
let runtime = init_runtime();
if let Err(e) = runtime.unwrap().block_on(async_main(run_id, command)) {
eprintln!("error: {e}");
exit(128);
Expand Down
6 changes: 1 addition & 5 deletions src/scripting/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,9 @@ pub struct GlobalContext {
// Needed, because Rune `Value` is !Send, as it may contain some internal pointers.
// Therefore, it is not safe to pass a `Value` to another thread by cloning it, because
// both objects could accidentally share some unprotected, `!Sync` data.
// To make it safe, the same `Context` is never used by more than one thread at once, and
// we make sure in `clone` to make a deep copy of the `data` field by serializing
// To make it safe, we make sure in `clone` to make a deep copy of the `data` field by serializing
// and deserializing it, so no pointers could get through.
unsafe impl Send for LocalContext {}
unsafe impl Sync for LocalContext {}
unsafe impl Send for GlobalContext {}
unsafe impl Sync for GlobalContext {}

impl GlobalContext {
pub fn new(session: scylla::Session, retry_strategy: RetryStrategy) -> Self {
Expand Down

0 comments on commit 6899b68

Please sign in to comment.