Add wait_drained to SchedulerServer and Executor (#41)
mpurins-coralogix authored and ch-sc committed Mar 30, 2023
1 parent beff946 commit a8f1cae
Showing 4 changed files with 79 additions and 30 deletions.
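
Both sides of this change use the same drain mechanism: a tokio watch channel sits next to the bookkeeping structure, the sender fires whenever the last piece of active work is removed, and wait_drained loops between checking the count and awaiting a change notification. The following is a minimal, self-contained sketch of that pattern, with an AtomicUsize standing in for the executor's DashMap of abort handles and the scheduler's active-job cache; the Drain type and its method names are illustrative, not from the commit.

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::watch;

/// Minimal stand-in for the bookkeeping in Executor / TaskManager: a counter of
/// active work plus a watch channel used purely as a "something was removed" signal.
struct Drain {
    active: AtomicUsize,
    drained: watch::Sender<()>,
    check_drained: watch::Receiver<()>,
}

impl Drain {
    fn new() -> Arc<Self> {
        let (drained, check_drained) = watch::channel(());
        Arc::new(Self {
            active: AtomicUsize::new(0),
            drained,
            check_drained,
        })
    }

    fn task_started(&self) {
        self.active.fetch_add(1, Ordering::SeqCst);
    }

    fn task_finished(&self) {
        // Mirrors remove_handle / the active-job-cache removal: notify only when
        // the last piece of active work disappears.
        if self.active.fetch_sub(1, Ordering::SeqCst) == 1 {
            self.drained.send_replace(());
        }
    }

    /// Mirrors the wait_drained loop added in this commit: re-check the count
    /// after every change notification and return once it reaches zero.
    async fn wait_drained(&self) {
        let mut check_drained = self.check_drained.clone();
        loop {
            if self.active.load(Ordering::SeqCst) == 0 {
                break;
            }
            if check_drained.changed().await.is_err() {
                break;
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let drain = Drain::new();
    drain.task_started();

    let worker = drain.clone();
    tokio::spawn(async move {
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        worker.task_finished();
    });

    drain.wait_drained().await;
    println!("all active work drained");
}

Note that send_replace stores the new value unconditionally, so the notification cannot fail even when no wait_drained caller currently holds a receiver clone (unlike watch::Sender::send, which errors once all receivers are gone), and any spurious wakeup is harmless because the loop re-checks the count before returning.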
61 changes: 40 additions & 21 deletions ballista/executor/src/executor.rs
@@ -33,24 +33,8 @@ use datafusion::physical_plan::udf::ScalarUDF;
use datafusion::physical_plan::Partitioning;
use futures::future::AbortHandle;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

pub struct TasksDrainedFuture(pub Arc<Executor>);

impl Future for TasksDrainedFuture {
type Output = ();

fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.0.abort_handles.len() > 0 {
Poll::Pending
} else {
Poll::Ready(())
}
}
}
use tokio::sync::watch;

type AbortHandles = Arc<DashMap<(usize, PartitionId), AbortHandle>>;

@@ -84,6 +68,9 @@ pub struct Executor {
/// Execution engine that the executor will delegate to
/// for executing query stages
pub(crate) execution_engine: Arc<dyn ExecutionEngine>,

drained: Arc<watch::Sender<()>>,
check_drained: watch::Receiver<()>,
}

impl Executor {
@@ -96,6 +83,7 @@ impl Executor {
concurrent_tasks: usize,
execution_engine: Option<Arc<dyn ExecutionEngine>>,
) -> Self {
let (drained, check_drained) = watch::channel(());
Self {
metadata,
work_dir: work_dir.to_owned(),
@@ -108,6 +96,8 @@
abort_handles: Default::default(),
execution_engine: execution_engine
.unwrap_or_else(|| Arc::new(DefaultExecutionEngine {})),
drained: Arc::new(drained),
check_drained
}
}
}
@@ -131,9 +121,11 @@ impl Executor {
self.abort_handles
.insert((task_id, partition.clone()), abort_handle);

let partitions = task.await??;
let partitions = task.await;

self.remove_handle(task_id, partition.clone());

self.abort_handles.remove(&(task_id, partition.clone()));
let partitions = partitions??;

self.metrics_collector.record_stage(
&partition.job_id,
@@ -152,14 +144,14 @@
stage_id: usize,
partition_id: usize,
) -> Result<bool, BallistaError> {
if let Some((_, handle)) = self.abort_handles.remove(&(
if let Some((_, handle)) = self.remove_handle(
task_id,
PartitionId {
job_id,
stage_id,
partition_id,
},
)) {
) {
handle.abort();
Ok(true)
} else {
@@ -174,6 +166,33 @@
pub fn active_task_count(&self) -> usize {
self.abort_handles.len()
}

pub async fn wait_drained(&self) {
let mut check_drained = self.check_drained.clone();
loop {
if self.active_task_count() == 0 {
break;
}

if check_drained.changed().await.is_err() {
break;
};
}
}

fn remove_handle(
&self,
task_id: usize,
partition: PartitionId,
) -> Option<((usize, PartitionId), AbortHandle)> {
let removed = self.abort_handles.remove(&(task_id, partition));

if self.active_task_count() == 0 {
self.drained.send_replace(());
}

removed
}
}

#[cfg(test)]
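
Two details of the executor change are easy to miss: abort-handle removal is now centralized in remove_handle, which signals the drained watch whenever the map becomes empty (the cancellation path goes through the same helper), and the task-execution path defers the ?? on the task result until after that cleanup, so the handle is removed even when the task failed or was aborted. The removed TasksDrainedFuture, by contrast, returned Poll::Pending without registering a waker, so it only made progress when something else happened to re-poll it; the watch channel provides an explicit wakeup instead. Below is a generic illustration of the hold-the-result-then-clean-up reordering, not code from the repository.

/// Generic illustration of the reordering in the task-execution path: hold the
/// result, do the bookkeeping, then propagate errors, so the abort handle is
/// removed (and the drain signal can fire) even when the task failed or was aborted.
async fn run_then_cleanup<T, E>(
    task: impl std::future::Future<Output = Result<T, E>>,
    cleanup: impl FnOnce(),
) -> Result<T, E> {
    let result = task.await; // don't apply `?` yet
    cleanup();               // runs on success and failure alike
    result                   // propagate the outcome afterwards
}

#[tokio::main]
async fn main() {
    let out = run_then_cleanup(async { Ok::<u32, String>(42) }, || {
        println!("handle removed, drain checked")
    })
    .await;
    println!("{out:?}");
}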
2 changes: 1 addition & 1 deletion ballista/executor/src/executor_process.rs
@@ -306,7 +306,7 @@ pub async fn start_executor_process(opt: ExecutorProcessConfig) -> Result<()> {
shutdown_noti.subscribe_for_shutdown(),
)));

let tasks_drained = TasksDrainedFuture(executor);
let tasks_drained = executor.wait_drained();

// Concurrently run the service checking and listen for the `shutdown` signal and wait for the stop request coming.
// The check_services runs until an error is encountered, so under normal circumstances, this `select!` statement runs
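
With wait_drained available on the executor, the hand-rolled TasksDrainedFuture wrapper is no longer needed: executor.wait_drained() yields an ordinary future that the shutdown path can await. The sketch below shows one plausible shape for that sequencing; the actual select! in start_executor_process is not part of this diff, so the branch structure and names here are assumptions.

use tokio::sync::oneshot;

// Sketch only: the real select! in start_executor_process is not shown in this
// diff. It illustrates the role of `tasks_drained`: once a stop has been
// requested, the process still waits for in-flight tasks before exiting.
async fn run_until_stopped(
    service_check: impl std::future::Future<Output = Result<(), String>>,
    stop_requested: oneshot::Receiver<()>,
    tasks_drained: impl std::future::Future<Output = ()>,
) {
    tokio::select! {
        res = service_check => eprintln!("service loop ended: {res:?}"),
        _ = stop_requested => println!("stop requested"),
    }
    // After leaving the select!, wait for in-flight tasks before exiting.
    tasks_drained.await;
    println!("executor drained, exiting");
}

#[tokio::main]
async fn main() {
    let (stop_tx, stop_rx) = oneshot::channel::<()>();
    let service_check = std::future::pending::<Result<(), String>>();
    let tasks_drained = async {}; // stand-in for executor.wait_drained()
    tokio::spawn(async move {
        let _ = stop_tx.send(());
    });
    run_until_stopped(service_check, stop_rx, tasks_drained).await;
}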
4 changes: 4 additions & 0 deletions ballista/scheduler/src/scheduler_server/mod.rs
@@ -363,6 +363,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T
pub fn session_manager(&self) -> SessionManager {
self.state.session_manager.clone()
}

pub async fn wait_drained(&self) {
self.state.task_manager.wait_drained().await;
}
}

pub fn timestamp_secs() -> u64 {
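
SchedulerServer::wait_drained simply forwards to TaskManager::wait_drained on the scheduler state. Because the executor's and the scheduler's drain waits are both plain futures resolving to (), a caller can bound either of them the same way; the helper below is illustrative only and not part of the commit.

use std::future::Future;
use std::time::Duration;

/// Illustrative helper (not part of the commit): bound a drain wait with a
/// timeout so a wedged task or job cannot stall shutdown forever. Works for
/// both Executor::wait_drained and SchedulerServer::wait_drained.
async fn drain_with_timeout(wait_drained: impl Future<Output = ()>, limit: Duration) -> bool {
    tokio::time::timeout(limit, wait_drained).await.is_ok()
}

#[tokio::main]
async fn main() {
    // Stand-in for a wait_drained() call; resolves almost immediately here.
    let fake_drain = async { tokio::time::sleep(Duration::from_millis(10)).await };
    let in_time = drain_with_timeout(fake_drain, Duration::from_secs(30)).await;
    println!("drained before the deadline: {in_time}");
}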
42 changes: 34 additions & 8 deletions ballista/scheduler/src/state/task_manager.rs
@@ -48,7 +48,7 @@ use std::ops::Deref;
use std::sync::Arc;
use std::time::Duration;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tokio::sync::{watch, RwLock};

use crate::scheduler_server::timestamp_millis;
use tracing::trace;
@@ -131,6 +131,8 @@ pub struct TaskManager<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
// Cache for active jobs curated by this scheduler
active_job_cache: ActiveJobCache,
launcher: Arc<dyn TaskLauncher>,
drained: Arc<watch::Sender<()>>,
check_drained: watch::Receiver<()>,
}

#[derive(Clone)]
@@ -165,13 +167,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
codec: BallistaCodec<T, U>,
scheduler_id: String,
) -> Self {
Self {
Self::with_launcher(
state,
codec,
scheduler_id: scheduler_id.clone(),
active_job_cache: Arc::new(DashMap::new()),
launcher: Arc::new(DefaultTaskLauncher::new(scheduler_id)),
}
scheduler_id.clone(),
Arc::new(DefaultTaskLauncher::new(scheduler_id)),
)
}

#[allow(dead_code)]
@@ -181,12 +182,16 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
scheduler_id: String,
launcher: Arc<dyn TaskLauncher>,
) -> Self {
let (drained, check_drained) = watch::channel(());

Self {
state,
codec,
scheduler_id,
active_job_cache: Arc::new(DashMap::new()),
launcher,
drained: Arc::new(drained),
check_drained,
}
}

@@ -701,9 +706,16 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
&self,
job_id: &str,
) -> Option<Arc<RwLock<ExecutionGraph>>> {
self.active_job_cache
let removed = self
.active_job_cache
.remove(job_id)
.map(|value| value.1.execution_graph)
.map(|value| value.1.execution_graph);

if self.get_active_job_count() == 0 {
self.drained.send_replace(());
}

removed
}

/// Generate a new random Job ID
@@ -732,6 +744,20 @@
}
});
}

pub async fn wait_drained(&self) {
let mut check_drained = self.check_drained.clone();

loop {
if self.get_active_job_count() == 0 {
break;
}

if check_drained.changed().await.is_err() {
break;
};
}
}
}

pub struct JobOverview {
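
On the scheduler side, the drained signal fires when an active execution graph is removed and get_active_job_count() drops to zero, and wait_drained resolves the first time that count is zero. One practical consequence for callers, sketched below under stated assumptions: stop admitting new jobs before awaiting the drain, since jobs submitted afterwards are not covered by the wait. The stop_accepting_jobs hook is invented for illustration and does not exist in the codebase.

use std::future::Future;

// Hypothetical shutdown ordering, not shown in the diff: because wait_drained
// resolves the first time the active-job count reaches zero, a caller would
// normally stop admitting new jobs before awaiting it. `stop_accepting_jobs`
// is invented for illustration.
async fn graceful_scheduler_shutdown(
    stop_accepting_jobs: impl FnOnce(),
    wait_drained: impl Future<Output = ()>,
) {
    stop_accepting_jobs();
    wait_drained.await;
    println!("active job cache is empty; scheduler can stop");
}

#[tokio::main]
async fn main() {
    // Stand-ins for SchedulerServer admission control and wait_drained().
    graceful_scheduler_shutdown(|| println!("no new jobs accepted"), async {}).await;
}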
