diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 764546fb68f9..d54d15b80f53 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -34,7 +34,7 @@ snmalloc = ["snmalloc-rs"] [dependencies] arrow = "27.0.0" -datafusion = { path = "../datafusion/core", version = "14.0.0" } +datafusion = { path = "../datafusion/core", version = "14.0.0", features = ["scheduler"] } env_logger = "0.10" futures = "0.3" mimalloc = { version = "0.1", optional = true, default-features = false } diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 7be0afcfe5e9..57c9c578096f 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -49,6 +49,8 @@ use datafusion::datasource::file_format::csv::DEFAULT_CSV_EXTENSION; use datafusion::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION; use datafusion::datasource::listing::ListingTableUrl; use datafusion::execution::context::SessionState; +use datafusion::scheduler::Scheduler; +use futures::TryStreamExt; use serde::Serialize; use structopt::StructOpt; @@ -101,6 +103,10 @@ struct DataFusionBenchmarkOpt { /// Whether to disable collection of statistics (and cost based optimizations) or not. #[structopt(short = "S", long = "disable-statistics")] disable_statistics: bool, + + /// Enable scheduler + #[structopt(short = "e", long = "enable-scheduler")] + enable_scheduler: bool, } #[derive(Debug, StructOpt)] @@ -235,14 +241,16 @@ async fn benchmark_query( if query_id == 15 { for (n, query) in sql.iter().enumerate() { if n == 1 { - result = execute_query(&ctx, query, opt.debug).await?; + result = execute_query(&ctx, query, opt.debug, opt.enable_scheduler) + .await?; } else { - execute_query(&ctx, query, opt.debug).await?; + execute_query(&ctx, query, opt.debug, opt.enable_scheduler).await?; } } } else { for query in sql { - result = execute_query(&ctx, query, opt.debug).await?; + result = + execute_query(&ctx, query, opt.debug, opt.enable_scheduler).await?; } } @@ -317,6 +325,7 @@ async fn execute_query( ctx: &SessionContext, sql: &str, debug: bool, + enable_scheduler: bool, ) -> Result> { let plan = ctx.sql(sql).await?; let plan = plan.to_unoptimized_plan(); @@ -337,7 +346,13 @@ async fn execute_query( ); } let task_ctx = ctx.task_ctx(); - let result = collect(physical_plan.clone(), task_ctx).await?; + let result = if enable_scheduler { + let scheduler = Scheduler::new(num_cpus::get()); + let results = scheduler.schedule(physical_plan.clone(), task_ctx).unwrap(); + results.stream().try_collect().await? + } else { + collect(physical_plan.clone(), task_ctx).await? + }; if debug { println!( "=== Physical plan with metrics ===\n{}\n", @@ -813,7 +828,7 @@ mod tests { let sql = &get_query_sql(n)?; for query in sql { - execute_query(&ctx, query, false).await?; + execute_query(&ctx, query, false, false).await?; } Ok(()) @@ -841,6 +856,7 @@ mod ci { mem_table: false, output_path: None, disable_statistics: false, + enable_scheduler: false, }; register_tables(&opt, &ctx).await?; let queries = get_query_sql(query)?; @@ -1153,6 +1169,7 @@ mod ci { mem_table: false, output_path: None, disable_statistics: false, + enable_scheduler: false, }; let mut results = benchmark_datafusion(opt).await?; assert_eq!(results.len(), 1);