fix(batch): Clean batch task execution when job finished. (#7791)
As the title says: when a job has finished, we should clean up its batch task executions.

Approved-By: BowenXiao1999
liurenjie1024 authored Feb 9, 2023
1 parent 67ef1fc commit 037f51a
Showing 11 changed files with 198 additions and 85 deletions.
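At a high level, the fix makes the frontend's `QueryRunner` count completed stages and, once every stage of the query has finished, stop all stage executions so the compute nodes can drop the corresponding batch task executions; a new `batch_task_num` gauge tracks how many task executions remain in memory. Below is a minimal sketch of that counting-and-cleanup pattern. The types are simplified stand-ins (the real code lives in src/frontend/src/scheduler/distributed/query.rs and passes a `SchedulerError` rather than a plain string):

```rust
use std::collections::HashMap;

// Simplified stand-ins for the real QueryRunner / StageExecution types.
struct StageExecution;

impl StageExecution {
    fn stop(&self, _error: Option<String>) {
        // In the real code this aborts the stage's batch tasks on the compute nodes.
    }
}

struct QueryRunner {
    stage_executions: HashMap<u32, StageExecution>,
    finished_stage_cnt: usize,
}

impl QueryRunner {
    /// Called for every `StageEvent::Completed` message.
    fn on_stage_completed(&mut self) {
        self.finished_stage_cnt += 1;
        assert!(self.finished_stage_cnt <= self.stage_executions.len());
        if self.finished_stage_cnt == self.stage_executions.len() {
            // All stages finished: clean everything up instead of leaving
            // task executions dangling on the compute nodes.
            self.clean_all_stages(None);
        }
    }

    /// Stops every stage. `error` is `Some(..)` on cancellation/failure, `None` on success.
    fn clean_all_stages(&self, error: Option<String>) {
        for stage_execution in self.stage_executions.values() {
            stage_execution.stop(error.clone());
        }
    }
}

fn main() {
    let mut runner = QueryRunner {
        stage_executions: HashMap::from([(0, StageExecution), (1, StageExecution)]),
        finished_stage_cnt: 0,
    };
    runner.on_stage_completed();
    runner.on_stage_completed(); // the last completion triggers clean_all_stages
}
```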
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Makefile.toml
@@ -874,7 +874,7 @@ description = "Run all streaming e2e tests"
[tasks.slt-batch]
category = "RiseDev - SQLLogicTest"
extend = "slt"
args = ["${@}", "./e2e_test/batch/**/*.slt"]
args = ["${@}", "./e2e_test/batch/*.slt"]
description = "Run all batch e2e tests"

[tasks.slt-generated]
12 changes: 11 additions & 1 deletion grafana/risingwave-dashboard.dashboard.py
@@ -1461,7 +1461,7 @@ def section_batch_exchange(outer_panels):
panels = outer_panels.sub_panel()
return [
outer_panels.row_collapsed(
"Batch Exchange",
"Batch Metrics",
[
panels.timeseries_row(
"Exchange Recv Row Number",
@@ -1473,6 +1473,16 @@
),
],
),
panels.timeseries_row(
"Batch Mpp Task Number",
"",
[
panels.target(
f"{metric('batch_task_num')}",
"",
),
],
),
],
),
]
2 changes: 1 addition & 1 deletion grafana/risingwave-dashboard.json

Large diffs are not rendered by default.

21 changes: 20 additions & 1 deletion src/batch/src/executor/monitor/stats.rs
@@ -17,7 +17,7 @@ use std::sync::Arc;
use prometheus::core::{AtomicF64, AtomicU64, Collector, Desc, GenericCounterVec, GenericGaugeVec};
use prometheus::{
exponential_buckets, opts, proto, GaugeVec, HistogramOpts, HistogramVec, IntCounterVec,
Registry,
IntGauge, Registry,
};

use crate::task::TaskId;
@@ -207,3 +207,22 @@ impl BatchTaskMetricsWithTaskLabels {
self.task_labels.iter().map(AsRef::as_ref).collect()
}
}

#[derive(Clone)]
pub struct BatchManagerMetrics {
pub task_num: IntGauge,
}

impl BatchManagerMetrics {
pub fn new(registry: Registry) -> Self {
let task_num = IntGauge::new("batch_task_num", "Number of batch task in memory").unwrap();

registry.register(Box::new(task_num.clone())).unwrap();
Self { task_num }
}

#[cfg(test)]
pub fn for_test() -> Self {
Self::new(Registry::new())
}
}
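The new `BatchManagerMetrics` above is a thin wrapper around a single prometheus `IntGauge`. For reference, a standalone sketch of the register/inc/dec lifecycle the manager relies on, using the prometheus crate directly (gauge name and help text taken from the diff above):

```rust
use prometheus::{IntGauge, Registry};

fn main() {
    let registry = Registry::new();
    let task_num = IntGauge::new("batch_task_num", "Number of batch task in memory").unwrap();
    // Register a clone: the manager keeps a handle while the registry owns a collector.
    registry.register(Box::new(task_num.clone())).unwrap();

    task_num.inc(); // add_task: a new task execution is stored in the manager
    assert_eq!(task_num.get(), 1);

    task_num.dec(); // abort / finish: the task execution is removed from the manager
    assert_eq!(task_num.get(), 0);
}
```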
7 changes: 6 additions & 1 deletion src/batch/src/task/env.rs
@@ -91,8 +91,13 @@ impl BatchEnvironment {
use risingwave_source::dml_manager::DmlManager;
use risingwave_storage::monitor::MonitoredStorageMetrics;

use crate::executor::monitor::BatchManagerMetrics;

BatchEnvironment {
task_manager: Arc::new(BatchManager::new(BatchConfig::default())),
task_manager: Arc::new(BatchManager::new(
BatchConfig::default(),
BatchManagerMetrics::for_test(),
)),
server_addr: "127.0.0.1:5688".parse().unwrap(),
config: Arc::new(BatchConfig::default()),
worker_id: WorkerNodeId::default(),
42 changes: 17 additions & 25 deletions src/batch/src/task/task_manager.rs
@@ -30,6 +30,7 @@ use tokio::runtime::Runtime;
use tokio::sync::mpsc::Sender;
use tonic::Status;

use crate::executor::BatchManagerMetrics;
use crate::rpc::service::exchange::GrpcExchangeWriter;
use crate::rpc::service::task_service::TaskInfoResponseResult;
use crate::task::{
@@ -52,10 +53,13 @@ pub struct BatchManager {
/// When each task context reports its own usage, it applies the diff to this total mem
/// value for all tasks.
total_mem_val: Arc<TrAdder<i64>>,

/// Metrics for batch manager.
metrics: BatchManagerMetrics,
}

impl BatchManager {
pub fn new(config: BatchConfig) -> Self {
pub fn new(config: BatchConfig, metrics: BatchManagerMetrics) -> Self {
let runtime = {
let mut builder = tokio::runtime::Builder::new_multi_thread();
if let Some(worker_threads_num) = config.worker_threads_num {
@@ -75,6 +79,7 @@ impl BatchManager {
runtime: Box::leak(Box::new(runtime)),
config,
total_mem_val: TrAdder::new().into(),
metrics,
}
}

@@ -95,6 +100,7 @@
// In theory, it's possible that the parent task id is not found.
let ret = if let hash_map::Entry::Vacant(e) = self.tasks.lock().entry(task_id.clone()) {
e.insert(task.clone());
self.metrics.task_num.inc();
Ok(())
} else {
Err(ErrorCode::InternalError(format!(
@@ -144,25 +150,17 @@

pub fn abort_task(&self, sid: &ProstTaskId) {
let sid = TaskId::from(sid);
match self.tasks.lock().get(&sid) {
Some(task) => task.abort_task(),
match self.tasks.lock().remove(&sid) {
Some(task) => {
task.abort_task();
self.metrics.task_num.dec()
}
None => {
warn!("Task id not found for abort task")
}
};
}

pub fn remove_task(
&self,
sid: &ProstTaskId,
) -> Result<Option<Arc<BatchTaskExecution<ComputeNodeContext>>>> {
let task_id = TaskId::from(sid);
match self.tasks.lock().remove(&task_id) {
Some(t) => Ok(Some(t)),
None => Err(TaskNotFound.into()),
}
}

/// Returns error if task is not running.
pub fn check_if_task_running(&self, task_id: &TaskId) -> Result<()> {
match self.tasks.lock().get(task_id) {
@@ -253,12 +251,6 @@
}
}

impl Default for BatchManager {
fn default() -> Self {
BatchManager::new(BatchConfig::default())
}
}

#[cfg(test)]
mod tests {
use risingwave_common::config::BatchConfig;
@@ -275,12 +267,13 @@ mod tests {
use risingwave_pb::expr::TableFunction;
use tonic::Code;

use crate::executor::BatchManagerMetrics;
use crate::task::{BatchManager, ComputeNodeContext, StateReporter, TaskId};

#[test]
fn test_task_not_found() {
use tonic::Status;
let manager = BatchManager::new(BatchConfig::default());
let manager = BatchManager::new(BatchConfig::default(), BatchManagerMetrics::for_test());
let task_id = TaskId {
task_id: 0,
stage_id: 0,
@@ -308,7 +301,7 @@

#[tokio::test]
async fn test_task_id_conflict() {
let manager = BatchManager::new(BatchConfig::default());
let manager = BatchManager::new(BatchConfig::default(), BatchManagerMetrics::for_test());
let plan = PlanFragment {
root: Some(PlanNode {
children: vec![],
@@ -356,7 +349,7 @@

#[tokio::test]
async fn test_task_aborted() {
let manager = BatchManager::new(BatchConfig::default());
let manager = BatchManager::new(BatchConfig::default(), BatchManagerMetrics::for_test());
let plan = PlanFragment {
root: Some(PlanNode {
children: vec![],
@@ -398,7 +391,6 @@
.unwrap();
manager.abort_task(&task_id);
let task_id = TaskId::from(&task_id);
let res = manager.wait_until_task_aborted(&task_id).await;
assert_eq!(res, Ok(()));
assert!(!manager.tasks.lock().contains_key(&task_id));
}
}
8 changes: 6 additions & 2 deletions src/compute/src/server.rs
@@ -18,7 +18,7 @@ use std::time::Duration;

use async_stack_trace::StackTraceManager;
use pretty_bytes::converter::convert;
use risingwave_batch::executor::BatchTaskMetrics;
use risingwave_batch::executor::{BatchManagerMetrics, BatchTaskMetrics};
use risingwave_batch::rpc::service::task_service::BatchServiceImpl;
use risingwave_batch::task::{BatchEnvironment, BatchManager};
use risingwave_common::config::{
@@ -121,6 +121,7 @@ pub async fn compute_node_serve(
let hummock_metrics = Arc::new(HummockMetrics::new(registry.clone()));
let streaming_metrics = Arc::new(StreamingMetrics::new(registry.clone()));
let batch_task_metrics = Arc::new(BatchTaskMetrics::new(registry.clone()));
let batch_manager_metrics = BatchManagerMetrics::new(registry.clone());
let exchange_srv_metrics = Arc::new(ExchangeServiceMetrics::new(registry.clone()));

// Initialize state store.
@@ -215,7 +216,10 @@
};

// Initialize the managers.
let batch_mgr = Arc::new(BatchManager::new(config.batch.clone()));
let batch_mgr = Arc::new(BatchManager::new(
config.batch.clone(),
batch_manager_metrics,
));
let stream_mgr = Arc::new(LocalStreamManager::new(
advertise_addr.clone(),
state_store.clone(),
1 change: 1 addition & 0 deletions src/frontend/Cargo.toml
@@ -33,6 +33,7 @@ num-traits = "0.2"
parking_lot = "0.12"
parse-display = "0.6"
paste = "1"
petgraph = "0.6"
pgwire = { path = "../utils/pgwire" }
pin-project-lite = "0.2"
postgres-types = { version = "0.2.4" }
80 changes: 57 additions & 23 deletions src/frontend/src/scheduler/distributed/query.rs
@@ -238,6 +238,8 @@ impl QueryRunner {
let has_lookup_join_stage = self.query.has_lookup_join_stage();
// To convince the compiler that `pinned_snapshot` will only be dropped once.
let mut pinned_snapshot_to_drop = Some(pinned_snapshot);

let mut finished_stage_cnt = 0usize;
while let Some(msg_inner) = self.msg_receiver.recv().await {
match msg_inner {
Stage(Scheduled(stage_id)) => {
@@ -280,20 +282,50 @@
self.query.query_id, id, reason
);

self.handle_cancel_or_failed_stage(reason).await;
self.clean_all_stages(Some(reason)).await;
// One stage failed, not necessary to execute schedule stages.
break;
}
Stage(StageEvent::Completed(_)) => {
finished_stage_cnt += 1;
assert!(finished_stage_cnt <= self.stage_executions.len());
if finished_stage_cnt == self.stage_executions.len() {
// Now all stages have completed, so remove all of them.
self.clean_all_stages(None).await;
break;
}
}
QueryMessage::CancelQuery => {
self.handle_cancel_or_failed_stage(SchedulerError::QueryCancelError)
self.clean_all_stages(Some(SchedulerError::QueryCancelError))
.await;
// One stage failed, not necessary to execute schedule stages.
break;
}
rest => {
unimplemented!("unsupported message \"{:?}\" for QueryRunner.run", rest);
}
}

// {
//     let mut graph = Graph::<String, String>::new();
//     let mut stage_id_to_node_id = HashMap::new();
//     for stage in &self.stage_executions {
//         let node_id = graph.add_node(format!("{} {}", stage.0, stage.1.state().await));
//         stage_id_to_node_id.insert(stage.0, node_id);
//     }
//
//     for stage in &self.stage_executions {
//         let stage_id = stage.0;
//         if let Some(child_stages) = self.query.stage_graph.get_child_stages(stage_id) {
//             for child_stage in child_stages {
//                 graph.add_edge(
//                     stage_id_to_node_id.get(stage_id).unwrap().clone(),
//                     stage_id_to_node_id.get(child_stage).unwrap().clone(),
//                     "".to_string(),
//                 );
//             }
//         }
//     }
//
//     println!("Printing query execution {} states:", self.query.query_id());
//     println!("{}", Dot::with_config(&graph, &[Config::EdgeNoLabel]));
// }
}
}

@@ -346,30 +378,32 @@

/// Handle ctrl-c query or failed execution. Should stop all executions and send error to query
/// result fetcher.
async fn handle_cancel_or_failed_stage(mut self, reason: SchedulerError) {
let err_str = reason.to_string();
// Consume sender here and send error to root stage.
let root_stage_sender = mem::take(&mut self.root_stage_sender);
// It's possible we receive stage failed event message multi times and the
// sender has been consumed in first failed event.
if let Some(sender) = root_stage_sender {
if let Err(e) = sender.send(Err(reason)) {
warn!("Query execution dropped: {:?}", e);
} else {
debug!(
"Root stage failure event for {:?} sent.",
self.query.query_id
);
async fn clean_all_stages(mut self, error: Option<SchedulerError>) {
let error_msg = error.as_ref().map(|e| e.to_string());
if let Some(reason) = error {
// Consume sender here and send error to root stage.
let root_stage_sender = mem::take(&mut self.root_stage_sender);
// It's possible we receive stage failed event message multi times and the
// sender has been consumed in first failed event.
if let Some(sender) = root_stage_sender {
if let Err(e) = sender.send(Err(reason)) {
warn!("Query execution dropped: {:?}", e);
} else {
debug!(
"Root stage failure event for {:?} sent.",
self.query.query_id
);
}
}
}

// If root stage has been taken (None), then root stage is responsible for send error to
// Query Result Fetcher.
// If root stage has been taken (None), then root stage is responsible for send error to
// Query Result Fetcher.
}

// Stop all running stages.
for stage_execution in self.stage_executions.values() {
// The stop returns immediately, so there is no need to spawn tasks.
stage_execution.stop(err_str.clone()).await;
stage_execution.stop(error_msg.clone()).await;
}
}
}
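The commented-out debug block earlier in this file uses petgraph's `Graph` and `Dot` types (`petgraph = "0.6"` is added to src/frontend/Cargo.toml in this commit) to print the stage DAG with stage states. A standalone sketch of that DOT-rendering pattern, with made-up stage ids and states standing in for the real stage executions:

```rust
use std::collections::HashMap;

use petgraph::dot::{Config, Dot};
use petgraph::graph::Graph;

fn main() {
    // Node weights mimic the "<stage id> <state>" strings built in the commented-out code.
    let mut graph = Graph::<String, String>::new();
    let mut stage_id_to_node_id = HashMap::new();
    for (stage_id, state) in [(0u32, "Completed"), (1, "Running"), (2, "Running")] {
        let node_id = graph.add_node(format!("{} {}", stage_id, state));
        stage_id_to_node_id.insert(stage_id, node_id);
    }

    // Edges point from a stage to its child stages, as in the query's stage graph.
    for (parent, child) in [(0u32, 1u32), (0, 2)] {
        graph.add_edge(
            stage_id_to_node_id[&parent],
            stage_id_to_node_id[&child],
            String::new(),
        );
    }

    // Print a Graphviz DOT description of the DAG, omitting edge labels.
    println!("{}", Dot::with_config(&graph, &[Config::EdgeNoLabel]));
}
```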