diff --git a/Cargo.toml b/Cargo.toml index d030266955d3..5af182e873db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,5 @@ lto = true codegen-units = 1 [patch.crates-io] -arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", branch = "main" } -#arrow2 = { git = "https://github.com/blaze-init/arrow2.git", branch = "shuffle_ipc" } -#parquet2 = { git = "https://github.com/blaze-init/parquet2.git", branch = "meta_new" } +#arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", branch = "main" } +#parquet2 = { git = "https://github.com/jorgecarleitao/parquet2.git", branch = "main" } diff --git a/ballista-examples/Cargo.toml b/ballista-examples/Cargo.toml index d5f7d65d83ef..063ef8ae4831 100644 --- a/ballista-examples/Cargo.toml +++ b/ballista-examples/Cargo.toml @@ -33,6 +33,6 @@ datafusion = { path = "../datafusion" } ballista = { path = "../ballista/rust/client", version = "0.6.0"} prost = "0.9" tonic = "0.6" -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } +tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } futures = "0.3" num_cpus = "1.13.0" diff --git a/ballista/rust/client/Cargo.toml b/ballista/rust/client/Cargo.toml index aa8297f8d06d..4ec1abe77654 100644 --- a/ballista/rust/client/Cargo.toml +++ b/ballista/rust/client/Cargo.toml @@ -35,6 +35,7 @@ log = "0.4" tokio = "1.0" tempfile = "3" sqlparser = "0.13" +parking_lot = "0.11" datafusion = { path = "../../../datafusion", version = "6.0.0" } diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs index 3fb347bddbce..4cd5a219461e 100644 --- a/ballista/rust/client/src/context.rs +++ b/ballista/rust/client/src/context.rs @@ -17,11 +17,12 @@ //! Distributed execution context. +use parking_lot::Mutex; use sqlparser::ast::Statement; use std::collections::HashMap; use std::fs; use std::path::PathBuf; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use ballista_core::config::BallistaConfig; use ballista_core::utils::create_df_ctx_with_ballista_query_planner; @@ -142,7 +143,7 @@ impl BallistaContext { // use local DataFusion context for now but later this might call the scheduler let mut ctx = { - let guard = self.state.lock().unwrap(); + let guard = self.state.lock(); create_df_ctx_with_ballista_query_planner( &guard.scheduler_host, guard.scheduler_port, @@ -162,7 +163,7 @@ impl BallistaContext { // use local DataFusion context for now but later this might call the scheduler let mut ctx = { - let guard = self.state.lock().unwrap(); + let guard = self.state.lock(); create_df_ctx_with_ballista_query_planner( &guard.scheduler_host, guard.scheduler_port, @@ -186,7 +187,7 @@ impl BallistaContext { // use local DataFusion context for now but later this might call the scheduler let mut ctx = { - let guard = self.state.lock().unwrap(); + let guard = self.state.lock(); create_df_ctx_with_ballista_query_planner( &guard.scheduler_host, guard.scheduler_port, @@ -203,7 +204,7 @@ impl BallistaContext { name: &str, table: Arc, ) -> Result<()> { - let mut state = self.state.lock().unwrap(); + let mut state = self.state.lock(); state.tables.insert(name.to_owned(), table); Ok(()) } @@ -280,7 +281,7 @@ impl BallistaContext { /// might require the schema to be inferred. pub async fn sql(&self, sql: &str) -> Result> { let mut ctx = { - let state = self.state.lock().unwrap(); + let state = self.state.lock(); create_df_ctx_with_ballista_query_planner( &state.scheduler_host, state.scheduler_port, @@ -291,7 +292,7 @@ impl BallistaContext { let is_show = self.is_show_statement(sql).await?; // the show tables、 show columns sql can not run at scheduler because the tables is store at client if is_show { - let state = self.state.lock().unwrap(); + let state = self.state.lock(); ctx = ExecutionContext::with_config( ExecutionConfig::new().with_information_schema( state.config.default_with_information_schema(), @@ -301,7 +302,7 @@ impl BallistaContext { // register tables with DataFusion context { - let state = self.state.lock().unwrap(); + let state = self.state.lock(); for (name, prov) in &state.tables { ctx.register_table( TableReference::Bare { table: name }, @@ -483,7 +484,7 @@ mod tests { .unwrap(); { - let mut guard = context.state.lock().unwrap(); + let mut guard = context.state.lock(); let csv_table = guard.tables.get("single_nan"); if let Some(table_provide) = csv_table { diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index caa9ca84f12d..cdbbbf064371 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -50,6 +50,8 @@ arrow = { package = "arrow2", version="0.9", features = ["io_ipc", "io_flight"] datafusion = { path = "../../../datafusion", version = "6.0.0" } +parking_lot = "0.11" + [dev-dependencies] tempfile = "3" diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index ea0d15f9e8ef..b70aa5357de4 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -176,11 +176,12 @@ enum AggregateFunction { STDDEV=11; STDDEV_POP=12; CORRELATION=13; + APPROX_PERCENTILE_CONT = 14; } message AggregateExprNode { AggregateFunction aggr_function = 1; - LogicalExprNode expr = 2; + repeated LogicalExprNode expr = 2; } enum BuiltInWindowFunction { diff --git a/ballista/rust/core/src/client.rs b/ballista/rust/core/src/client.rs index 6adaa8c0ac92..9e00b08ef661 100644 --- a/ballista/rust/core/src/client.rs +++ b/ballista/rust/core/src/client.rs @@ -19,7 +19,8 @@ use arrow::io::flight::deserialize_schemas; use arrow::io::ipc::IpcSchema; -use std::sync::{Arc, Mutex}; +use parking_lot::Mutex; +use std::sync::Arc; use std::{collections::HashMap, pin::Pin}; use std::{ convert::{TryFrom, TryInto}, @@ -164,7 +165,7 @@ impl Stream for FlightDataStream { self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let mut stream = self.stream.lock().expect("mutex is bad"); + let mut stream = self.stream.lock(); stream.poll_next_unpin(cx).map(|x| match x { Some(flight_data_chunk_result) => { let converted_chunk = flight_data_chunk_result diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs index 2c4b2401b4f3..a108f8adf3d2 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs @@ -20,10 +20,11 @@ //! partition is re-partitioned and streamed to disk in Arrow IPC format. Future stages of the query //! will use the ShuffleReaderExec to read these results. +use parking_lot::Mutex; use std::fs::File; use std::iter::{FromIterator, Iterator}; use std::path::PathBuf; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::Instant; use std::{any::Any, pin::Pin}; diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 32ed6f1c1a4f..6f2dc7508c50 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1066,7 +1066,11 @@ impl TryInto for &protobuf::LogicalExprNode { Ok(Expr::AggregateFunction { fun, - args: vec![parse_required_expr(&expr.expr)?], + args: expr + .expr + .iter() + .map(|e| e.try_into()) + .collect::, _>>()?, distinct: false, //TODO }) } diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 74cf7091faf9..f732f8837710 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -25,17 +25,15 @@ mod roundtrip_tests { use crate::error::BallistaError; use arrow::datatypes::IntegerType; use core::panic; - use datafusion::arrow::datatypes::UnionMode; use datafusion::field_util::SchemaExt; - use datafusion::logical_plan::Repartition; use datafusion::{ - arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}, + arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode}, datasource::object_store::local::LocalFileSystem, logical_plan::{ col, CreateExternalTable, Expr, LogicalPlan, LogicalPlanBuilder, - Partitioning, ToDFSchema, + Partitioning, Repartition, ToDFSchema, }, - physical_plan::functions::BuiltinScalarFunction::Sqrt, + physical_plan::{aggregates, functions::BuiltinScalarFunction::Sqrt}, prelude::*, scalar::ScalarValue, sql::parser::FileType, @@ -1009,4 +1007,17 @@ mod roundtrip_tests { Ok(()) } + + #[test] + fn roundtrip_approx_percentile_cont() -> Result<()> { + let test_expr = Expr::AggregateFunction { + fun: aggregates::AggregateFunction::ApproxPercentileCont, + args: vec![col("bananas"), lit(0.42)], + distinct: false, + }; + + roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr); + + Ok(()) + } } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 304d2db1cd83..c76ddf3e53a2 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1130,6 +1130,9 @@ impl TryInto for &Expr { AggregateFunction::ApproxDistinct => { protobuf::AggregateFunction::ApproxDistinct } + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont + } AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, @@ -1155,11 +1158,13 @@ impl TryInto for &Expr { } }; - let arg = &args[0]; - let aggregate_expr = Box::new(protobuf::AggregateExprNode { + let aggregate_expr = protobuf::AggregateExprNode { aggr_function: aggr_function.into(), - expr: Some(Box::new(arg.try_into()?)), - }); + expr: args + .iter() + .map(|v| v.try_into()) + .collect::, _>>()?, + }; Ok(protobuf::LogicalExprNode { expr_type: Some(ExprType::AggregateExpr(aggregate_expr)), }) @@ -1390,6 +1395,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, AggregateFunction::Correlation => Self::Correlation, + AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index b2f3db2a6d52..cac91029f3f3 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -130,6 +130,9 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev, protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop, protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation, + protobuf::AggregateFunction::ApproxPercentileCont => { + AggregateFunction::ApproxPercentileCont + } } } } diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 1986d8114a87..520767a477ff 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -626,7 +626,6 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { let ctx_state = ExecutionContextState { catalog_list, scalar_functions: Default::default(), - var_provider: Default::default(), aggregate_functions: Default::default(), config: ExecutionConfig::new(), execution_props: ExecutionProps::new(), @@ -636,7 +635,7 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { let fun_expr = functions::create_physical_fun( &(&scalar_function).into(), - &ctx_state, + &ctx_state.execution_props, )?; Arc::new(ScalarFunctionExpr::new( diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index a30f1a25d02f..310affdc01f8 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -41,11 +41,12 @@ futures = "0.3" log = "0.4" snmalloc-rs = {version = "0.2", features= ["cache-friendly"], optional = true} tempfile = "3" -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } +tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "parking_lot"] } tokio-stream = { version = "0.1", features = ["net"] } tonic = "0.6" uuid = { version = "0.8", features = ["v4"] } hyper = "0.14.4" +parking_lot = "0.11" [dev-dependencies] diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml index 10b3723712da..fdeb7e726d57 100644 --- a/ballista/rust/scheduler/Cargo.toml +++ b/ballista/rust/scheduler/Cargo.toml @@ -53,6 +53,7 @@ tokio-stream = { version = "0.1", features = ["net"], optional = true } tonic = "0.6" tower = { version = "0.4" } warp = "0.3" +parking_lot = "0.11" [dev-dependencies] ballista-core = { path = "../core", version = "0.6.0" } diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index f9a8504c7a75..2657bf8d58b8 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -36,7 +36,7 @@ arrow = { package = "arrow2", version="0.9", features = ["io_csv", "io_json", "i datafusion = { path = "../datafusion" } ballista = { path = "../ballista/rust/client" } structopt = { version = "0.3", default-features = false } -tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread"] } +tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread", "parking_lot"] } futures = "0.3" env_logger = "0.9" mimalloc = { version = "0.1", optional = true, default-features = false } diff --git a/benchmarks/src/bin/nyctaxi.rs b/benchmarks/src/bin/nyctaxi.rs index b2f18c7c4bb1..0da5f89c5352 100644 --- a/benchmarks/src/bin/nyctaxi.rs +++ b/benchmarks/src/bin/nyctaxi.rs @@ -119,7 +119,7 @@ async fn datafusion_sql_benchmarks( } async fn execute_sql(ctx: &mut ExecutionContext, sql: &str, debug: bool) -> Result<()> { - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let plan = ctx.create_logical_plan(sql)?; let plan = ctx.optimize(&plan)?; if debug { diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 4494bb77977c..71f8f90e8258 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -54,6 +54,8 @@ use arrow::io::parquet::write::{Compression, Version, WriteOptions}; use ballista::prelude::{ BallistaConfig, BallistaContext, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, }; +use datafusion::datasource::file_format::csv::DEFAULT_CSV_EXTENSION; +use datafusion::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION; use datafusion::field_util::SchemaExt; use structopt::StructOpt; @@ -264,7 +266,7 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result { let path = format!("{}/{}", path, table); let format = ParquetFormat::default().with_enable_pruning(true); - (Arc::new(format), path, ".parquet") + (Arc::new(format), path, DEFAULT_PARQUET_EXTENSION) } other => { unimplemented!("Invalid file format '{}'", other); diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 285c8388be36..09df15b57bc6 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -29,7 +29,7 @@ rust-version = "1.58" [dependencies] clap = { version = "3", features = ["derive", "cargo"] } rustyline = "9.0" -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } +tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } datafusion = { path = "../datafusion", version = "6.0.0" } arrow = { package = "arrow2", version="0.9", features = ["io_print"] } ballista = { path = "../ballista/rust/client", version = "0.6.0" } diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index b6724ae173f0..a53be0c786e6 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -39,6 +39,6 @@ arrow = { package = "arrow2", version="0.9", features = ["io_ipc", "io_flight"] datafusion = { path = "../datafusion" } prost = "0.9" tonic = "0.6" -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } +tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } futures = "0.3" num_cpus = "1.13.0" diff --git a/datafusion-examples/examples/parquet_sql_multiple_files.rs b/datafusion-examples/examples/parquet_sql_multiple_files.rs index 50edc03df85a..a8c9b64650ff 100644 --- a/datafusion-examples/examples/parquet_sql_multiple_files.rs +++ b/datafusion-examples/examples/parquet_sql_multiple_files.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::parquet::{ + ParquetFormat, DEFAULT_PARQUET_EXTENSION, +}; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; @@ -33,7 +35,7 @@ async fn main() -> Result<()> { // Configure listing options let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions { - file_extension: ".parquet".to_owned(), + file_extension: DEFAULT_PARQUET_EXTENSION.to_owned(), format: Arc::new(file_format), table_partition_cols: vec![], collect_stat: true, diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index d80bc090a5ec..c37c204005dd 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -63,7 +63,7 @@ chrono = { version = "0.4", default-features = false, features = ["clock"] } async-trait = "0.1.41" futures = "0.3" pin-project-lite= "^0.2.7" -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs"] } +tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" log = "^0.4" md-5 = { version = "^0.10.0", optional = true } @@ -79,6 +79,7 @@ rand = "0.8" num-traits = { version = "0.2", optional = true } pyo3 = { version = "0.15", optional = true } tempfile = "3" +parking_lot = "0.11" avro-schema = { version = "0.2", optional = true } # used to print arrow arrays in a nice columnar format @@ -92,6 +93,7 @@ features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc [dev-dependencies] criterion = "0.3" doc-comment = "0.3" +fuzz-utils = { path = "fuzz-utils" } parquet-format-async-temp = "0" [[bench]] diff --git a/datafusion/benches/aggregate_query_sql.rs b/datafusion/benches/aggregate_query_sql.rs index e580f4a63507..2aa2d16c7717 100644 --- a/datafusion/benches/aggregate_query_sql.rs +++ b/datafusion/benches/aggregate_query_sql.rs @@ -23,12 +23,13 @@ use crate::criterion::Criterion; use data_utils::create_table_provider; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; -use std::sync::{Arc, Mutex}; +use parking_lot::Mutex; +use std::sync::Arc; use tokio::runtime::Runtime; fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); - let df = rt.block_on(ctx.lock().unwrap().sql(sql)).unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); criterion::black_box(rt.block_on(df.collect()).unwrap()); } diff --git a/datafusion/benches/math_query_sql.rs b/datafusion/benches/math_query_sql.rs index 0f6a697a808d..b2b62dc03f91 100644 --- a/datafusion/benches/math_query_sql.rs +++ b/datafusion/benches/math_query_sql.rs @@ -19,7 +19,8 @@ extern crate criterion; use criterion::Criterion; -use std::sync::{Arc, Mutex}; +use parking_lot::Mutex; +use std::sync::Arc; use tokio::runtime::Runtime; @@ -38,7 +39,7 @@ fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query - let df = rt.block_on(ctx.lock().unwrap().sql(sql)).unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); rt.block_on(df.collect()).unwrap(); } diff --git a/datafusion/benches/sort_limit_query_sql.rs b/datafusion/benches/sort_limit_query_sql.rs index a6bf75e4760c..7fe8e7c1f340 100644 --- a/datafusion/benches/sort_limit_query_sql.rs +++ b/datafusion/benches/sort_limit_query_sql.rs @@ -22,7 +22,8 @@ use datafusion::datasource::file_format::csv::CsvFormat; use datafusion::datasource::listing::{ListingOptions, ListingTable}; use datafusion::datasource::object_store::local::LocalFileSystem; -use std::sync::{Arc, Mutex}; +use parking_lot::Mutex; +use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema}; @@ -36,7 +37,7 @@ fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query - let df = rt.block_on(ctx.lock().unwrap().sql(sql)).unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); rt.block_on(df.collect()).unwrap(); } @@ -79,18 +80,18 @@ fn create_context() -> Arc> { rt.block_on(async { // create local execution context let mut ctx = ExecutionContext::new(); - ctx.state.lock().unwrap().config.target_partitions = 1; - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + ctx.state.lock().config.target_partitions = 1; + let runtime = ctx.state.lock().runtime_env.clone(); let mem_table = MemTable::load(Arc::new(csv), Some(partitions), runtime) .await .unwrap(); ctx.register_table("aggregate_test_100", Arc::new(mem_table)) .unwrap(); - ctx_holder.lock().unwrap().push(Arc::new(Mutex::new(ctx))) + ctx_holder.lock().push(Arc::new(Mutex::new(ctx))) }); - let ctx = ctx_holder.lock().unwrap().get(0).unwrap().clone(); + let ctx = ctx_holder.lock().get(0).unwrap().clone(); ctx } diff --git a/datafusion/benches/window_query_sql.rs b/datafusion/benches/window_query_sql.rs index bca4a38360fe..dad838eb7f62 100644 --- a/datafusion/benches/window_query_sql.rs +++ b/datafusion/benches/window_query_sql.rs @@ -25,12 +25,13 @@ use crate::criterion::Criterion; use data_utils::create_table_provider; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; -use std::sync::{Arc, Mutex}; +use parking_lot::Mutex; +use std::sync::Arc; use tokio::runtime::Runtime; fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); - let df = rt.block_on(ctx.lock().unwrap().sql(sql)).unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); criterion::black_box(rt.block_on(df.collect()).unwrap()); } diff --git a/datafusion/fuzz-utils/Cargo.toml b/datafusion/fuzz-utils/Cargo.toml new file mode 100644 index 000000000000..cb1e2e942a9e --- /dev/null +++ b/datafusion/fuzz-utils/Cargo.toml @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "fuzz-utils" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { package = "arrow2", version="0.9", features = ["io_print"] } +datafusion = { path = ".." } +rand = "0.8" +env_logger = "0.9.0" diff --git a/datafusion/fuzz-utils/src/lib.rs b/datafusion/fuzz-utils/src/lib.rs new file mode 100644 index 000000000000..81da4801f423 --- /dev/null +++ b/datafusion/fuzz-utils/src/lib.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common utils for fuzz tests +use arrow::array::Int32Array; +use rand::prelude::StdRng; +use rand::Rng; + +use datafusion::record_batch::RecordBatch; +pub use env_logger; + +/// Extracts the i32 values from the set of batches and returns them as a single Vec +pub fn batches_to_vec(batches: &[RecordBatch]) -> Vec> { + batches + .iter() + .map(|batch| { + assert_eq!(batch.num_columns(), 1); + batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.copied()) + }) + .flatten() + .collect() +} + +/// extract values from batches and sort them +pub fn partitions_to_sorted_vec(partitions: &[Vec]) -> Vec> { + let mut values: Vec<_> = partitions + .iter() + .map(|batches| batches_to_vec(batches).into_iter()) + .flatten() + .collect(); + + values.sort_unstable(); + values +} + +/// Adds a random number of empty record batches into the stream +pub fn add_empty_batches( + batches: Vec, + rng: &mut StdRng, +) -> Vec { + let schema = batches[0].schema().clone(); + + batches + .into_iter() + .map(|batch| { + // insert 0, or 1 empty batches before and after the current batch + let empty_batch = RecordBatch::new_empty(schema.clone()); + std::iter::repeat(empty_batch.clone()) + .take(rng.gen_range(0..2)) + .chain(std::iter::once(batch)) + .chain(std::iter::repeat(empty_batch).take(rng.gen_range(0..2))) + }) + .flatten() + .collect() +} diff --git a/datafusion/src/catalog/catalog.rs b/datafusion/src/catalog/catalog.rs index 7dbfa5a80c3e..d5f509f62bcc 100644 --- a/datafusion/src/catalog/catalog.rs +++ b/datafusion/src/catalog/catalog.rs @@ -19,9 +19,10 @@ //! representing collections of named schemas. use crate::catalog::schema::SchemaProvider; +use parking_lot::RwLock; use std::any::Any; use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; /// Represent a list of named catalogs pub trait CatalogList: Sync + Send { @@ -75,17 +76,17 @@ impl CatalogList for MemoryCatalogList { name: String, catalog: Arc, ) -> Option> { - let mut catalogs = self.catalogs.write().unwrap(); + let mut catalogs = self.catalogs.write(); catalogs.insert(name, catalog) } fn catalog_names(&self) -> Vec { - let catalogs = self.catalogs.read().unwrap(); + let catalogs = self.catalogs.read(); catalogs.keys().map(|s| s.to_string()).collect() } fn catalog(&self, name: &str) -> Option> { - let catalogs = self.catalogs.read().unwrap(); + let catalogs = self.catalogs.read(); catalogs.get(name).cloned() } } @@ -129,7 +130,7 @@ impl MemoryCatalogProvider { name: impl Into, schema: Arc, ) -> Option> { - let mut schemas = self.schemas.write().unwrap(); + let mut schemas = self.schemas.write(); schemas.insert(name.into(), schema) } } @@ -140,12 +141,12 @@ impl CatalogProvider for MemoryCatalogProvider { } fn schema_names(&self) -> Vec { - let schemas = self.schemas.read().unwrap(); + let schemas = self.schemas.read(); schemas.keys().cloned().collect() } fn schema(&self, name: &str) -> Option> { - let schemas = self.schemas.read().unwrap(); + let schemas = self.schemas.read(); schemas.get(name).cloned() } } diff --git a/datafusion/src/catalog/schema.rs b/datafusion/src/catalog/schema.rs index 1379eb1894eb..877ff8466e36 100644 --- a/datafusion/src/catalog/schema.rs +++ b/datafusion/src/catalog/schema.rs @@ -18,9 +18,10 @@ //! Describes the interface and built-in implementations of schemas, //! representing collections of named tables. +use parking_lot::RwLock; use std::any::Any; use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; @@ -91,12 +92,12 @@ impl SchemaProvider for MemorySchemaProvider { } fn table_names(&self) -> Vec { - let tables = self.tables.read().unwrap(); + let tables = self.tables.read(); tables.keys().cloned().collect() } fn table(&self, name: &str) -> Option> { - let tables = self.tables.read().unwrap(); + let tables = self.tables.read(); tables.get(name).cloned() } @@ -111,17 +112,17 @@ impl SchemaProvider for MemorySchemaProvider { name ))); } - let mut tables = self.tables.write().unwrap(); + let mut tables = self.tables.write(); Ok(tables.insert(name, table)) } fn deregister_table(&self, name: &str) -> Result>> { - let mut tables = self.tables.write().unwrap(); + let mut tables = self.tables.write(); Ok(tables.remove(name)) } fn table_exist(&self, name: &str) -> bool { - let tables = self.tables.read().unwrap(); + let tables = self.tables.read(); tables.contains_key(name) } } diff --git a/datafusion/src/datasource/file_format/avro.rs b/datafusion/src/datasource/file_format/avro.rs index bd83e75c4f74..0924ced74cfa 100644 --- a/datafusion/src/datasource/file_format/avro.rs +++ b/datafusion/src/datasource/file_format/avro.rs @@ -33,6 +33,8 @@ use crate::physical_plan::file_format::{AvroExec, FileScanConfig}; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; +/// The default file extension of avro files +pub const DEFAULT_AVRO_EXTENSION: &str = ".avro"; /// Avro `FileFormat` implementation. #[derive(Default, Debug)] pub struct AvroFormat; diff --git a/datafusion/src/datasource/file_format/csv.rs b/datafusion/src/datasource/file_format/csv.rs index c8897c2f011e..360754223af4 100644 --- a/datafusion/src/datasource/file_format/csv.rs +++ b/datafusion/src/datasource/file_format/csv.rs @@ -35,6 +35,8 @@ use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; +/// The default file extension of csv files +pub const DEFAULT_CSV_EXTENSION: &str = ".csv"; /// Character Separated Value `FileFormat` implementation. #[derive(Debug)] pub struct CsvFormat { diff --git a/datafusion/src/datasource/file_format/json.rs b/datafusion/src/datasource/file_format/json.rs index 5220e6f30fe7..86b28861d9ff 100644 --- a/datafusion/src/datasource/file_format/json.rs +++ b/datafusion/src/datasource/file_format/json.rs @@ -36,6 +36,8 @@ use crate::physical_plan::file_format::NdJsonExec; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; +/// The default file extension of json files +pub const DEFAULT_JSON_EXTENSION: &str = ".json"; /// New line delimited JSON `FileFormat` implementation. #[derive(Debug, Default)] pub struct JsonFormat { diff --git a/datafusion/src/datasource/file_format/parquet.rs b/datafusion/src/datasource/file_format/parquet.rs index 9af9e607dc31..c32f7b2aa9ba 100644 --- a/datafusion/src/datasource/file_format/parquet.rs +++ b/datafusion/src/datasource/file_format/parquet.rs @@ -141,7 +141,7 @@ fn summarize_min_max( if let Some(max_value) = &mut max_values[i] { if let Some(v) = stats.max_value { match max_value.update_batch(&[Arc::new( - arrow::array::$ARRAY_TYPE::from_slice(vec![v]), + arrow::array::$ARRAY_TYPE::from_slice(&[v]), )]) { Ok(_) => {} Err(_) => { @@ -153,7 +153,7 @@ fn summarize_min_max( if let Some(min_value) = &mut min_values[i] { if let Some(v) = stats.min_value { match min_value.update_batch(&[Arc::new( - arrow::array::$ARRAY_TYPE::from_slice(vec![v]), + arrow::array::$ARRAY_TYPE::from_slice(&[v]), )]) { Ok(_) => {} Err(_) => { @@ -180,7 +180,7 @@ fn summarize_min_max( if let Some(max_value) = &mut max_values[i] { if let Some(v) = stats.max_value { match max_value - .update_batch(&[Arc::new(BooleanArray::from_slice(vec![v]))]) + .update_batch(&[Arc::new(BooleanArray::from_slice(&[v]))]) { Ok(_) => {} Err(_) => { @@ -192,7 +192,7 @@ fn summarize_min_max( if let Some(min_value) = &mut min_values[i] { if let Some(v) = stats.min_value { match min_value - .update_batch(&[Arc::new(BooleanArray::from_slice(vec![v]))]) + .update_batch(&[Arc::new(BooleanArray::from_slice(&[v]))]) { Ok(_) => {} Err(_) => { @@ -262,7 +262,7 @@ fn summarize_min_max( } /// Read and parse the schema of the Parquet file at location `path` -fn fetch_schema(object_reader: Arc) -> Result { +pub fn fetch_schema(object_reader: Arc) -> Result { let mut reader = object_reader.sync_reader()?; let meta_data = read_metadata(&mut reader)?; let schema = get_schema(&meta_data)?; diff --git a/datafusion/src/datasource/listing/table.rs b/datafusion/src/datasource/listing/table.rs index b3a7122cf1ae..bda5ec996c4f 100644 --- a/datafusion/src/datasource/listing/table.rs +++ b/datafusion/src/datasource/listing/table.rs @@ -267,6 +267,8 @@ impl ListingTable { mod tests { use arrow::datatypes::DataType; + use crate::datasource::file_format::avro::DEFAULT_AVRO_EXTENSION; + use crate::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION; use crate::{ datasource::{ file_format::{avro::AvroFormat, parquet::ParquetFormat}, @@ -319,7 +321,7 @@ mod tests { let store = TestObjectStore::new_arc(&[("table/p1=v1/file.avro", 100)]); let opt = ListingOptions { - file_extension: ".avro".to_owned(), + file_extension: DEFAULT_AVRO_EXTENSION.to_owned(), format: Arc::new(AvroFormat {}), table_partition_cols: vec![String::from("p1")], target_partitions: 4, @@ -420,7 +422,7 @@ mod tests { let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/{}", testdata, name); let opt = ListingOptions { - file_extension: "parquet".to_owned(), + file_extension: DEFAULT_PARQUET_EXTENSION.to_owned(), format: Arc::new(ParquetFormat::default()), table_partition_cols: vec![], target_partitions: 2, diff --git a/datafusion/src/datasource/object_store/mod.rs b/datafusion/src/datasource/object_store/mod.rs index c581b171a57b..65f22e74c0e1 100644 --- a/datafusion/src/datasource/object_store/mod.rs +++ b/datafusion/src/datasource/object_store/mod.rs @@ -19,11 +19,12 @@ pub mod local; +use parking_lot::RwLock; use std::collections::HashMap; use std::fmt::{self, Debug}; use std::io::{Read, Seek}; use std::pin::Pin; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use async_trait::async_trait; use chrono::{DateTime, Utc}; @@ -178,12 +179,7 @@ impl fmt::Debug for ObjectStoreRegistry { f.debug_struct("ObjectStoreRegistry") .field( "schemes", - &self - .object_stores - .read() - .unwrap() - .keys() - .collect::>(), + &self.object_stores.read().keys().collect::>(), ) .finish() } @@ -214,13 +210,13 @@ impl ObjectStoreRegistry { scheme: String, store: Arc, ) -> Option> { - let mut stores = self.object_stores.write().unwrap(); + let mut stores = self.object_stores.write(); stores.insert(scheme, store) } /// Get the store registered for scheme pub fn get(&self, scheme: &str) -> Option> { - let stores = self.object_stores.read().unwrap(); + let stores = self.object_stores.read(); stores.get(scheme).cloned() } @@ -234,7 +230,7 @@ impl ObjectStoreRegistry { uri: &'a str, ) -> Result<(Arc, &'a str)> { if let Some((scheme, path)) = uri.split_once("://") { - let stores = self.object_stores.read().unwrap(); + let stores = self.object_stores.read(); let store = stores .get(&*scheme.to_lowercase()) .map(Clone::clone) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 2e70962c8360..9aa2b476bc9f 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -24,8 +24,8 @@ use crate::{ datasource::listing::{ListingOptions, ListingTable}, datasource::{ file_format::{ - avro::AvroFormat, - csv::CsvFormat, + avro::{AvroFormat, DEFAULT_AVRO_EXTENSION}, + csv::{CsvFormat, DEFAULT_CSV_EXTENSION}, parquet::{ParquetFormat, DEFAULT_PARQUET_EXTENSION}, FileFormat, }, @@ -39,13 +39,11 @@ use crate::{ }, }; use log::debug; +use parking_lot::Mutex; +use std::collections::{HashMap, HashSet}; use std::path::Path; use std::string::String; use std::sync::Arc; -use std::{ - collections::{HashMap, HashSet}, - sync::Mutex, -}; use std::{fs, path::PathBuf}; use futures::{StreamExt, TryStreamExt}; @@ -195,7 +193,6 @@ impl ExecutionContext { state: Arc::new(Mutex::new(ExecutionContextState { catalog_list, scalar_functions: HashMap::new(), - var_provider: HashMap::new(), aggregate_functions: HashMap::new(), config, execution_props: ExecutionProps::new(), @@ -207,7 +204,7 @@ impl ExecutionContext { /// Return the [RuntimeEnv] used to run queries with this [ExecutionContext] pub fn runtime_env(&self) -> Arc { - self.state.lock().unwrap().runtime_env.clone() + self.state.lock().runtime_env.clone() } /// Creates a dataframe that will execute a SQL query. @@ -224,17 +221,20 @@ impl ExecutionContext { ref file_type, ref has_header, }) => { - let file_format = match file_type { - FileType::CSV => { - Ok(Arc::new(CsvFormat::default().with_has_header(*has_header)) - as Arc) - } - FileType::Parquet => { - Ok(Arc::new(ParquetFormat::default()) as Arc) - } - FileType::Avro => { - Ok(Arc::new(AvroFormat::default()) as Arc) - } + let (file_format, file_extension) = match file_type { + FileType::CSV => Ok(( + Arc::new(CsvFormat::default().with_has_header(*has_header)) + as Arc, + DEFAULT_CSV_EXTENSION, + )), + FileType::Parquet => Ok(( + Arc::new(ParquetFormat::default()) as Arc, + DEFAULT_PARQUET_EXTENSION, + )), + FileType::Avro => Ok(( + Arc::new(AvroFormat::default()) as Arc, + DEFAULT_AVRO_EXTENSION, + )), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported file type {:?}.", file_type @@ -244,13 +244,8 @@ impl ExecutionContext { let options = ListingOptions { format: file_format, collect_stat: false, - file_extension: String::new(), - target_partitions: self - .state - .lock() - .unwrap() - .config - .target_partitions, + file_extension: file_extension.to_owned(), + target_partitions: self.state.lock().config.target_partitions, table_partition_cols: vec![], }; @@ -315,7 +310,7 @@ impl ExecutionContext { } // create a query planner - let state = self.state.lock().unwrap().clone(); + let state = self.state.lock().clone(); let query_planner = SqlToRel::new(&state); query_planner.statement_to_plan(&statements[0]) } @@ -328,9 +323,8 @@ impl ExecutionContext { ) { self.state .lock() - .unwrap() - .var_provider - .insert(variable_type, provider); + .execution_props + .add_var_provider(variable_type, provider); } /// Registers a scalar UDF within this context. @@ -343,7 +337,6 @@ impl ExecutionContext { pub fn register_udf(&mut self, f: ScalarUDF) { self.state .lock() - .unwrap() .scalar_functions .insert(f.name.clone(), Arc::new(f)); } @@ -358,7 +351,6 @@ impl ExecutionContext { pub fn register_udaf(&mut self, f: AggregateUDF) { self.state .lock() - .unwrap() .aggregate_functions .insert(f.name.clone(), Arc::new(f)); } @@ -372,7 +364,7 @@ impl ExecutionContext { ) -> Result> { let uri: String = uri.into(); let (object_store, path) = self.object_store(&uri)?; - let target_partitions = self.state.lock().unwrap().config.target_partitions; + let target_partitions = self.state.lock().config.target_partitions; Ok(Arc::new(DataFrameImpl::new( self.state.clone(), &LogicalPlanBuilder::scan_avro( @@ -403,7 +395,7 @@ impl ExecutionContext { ) -> Result> { let uri: String = uri.into(); let (object_store, path) = self.object_store(&uri)?; - let target_partitions = self.state.lock().unwrap().config.target_partitions; + let target_partitions = self.state.lock().config.target_partitions; Ok(Arc::new(DataFrameImpl::new( self.state.clone(), &LogicalPlanBuilder::scan_csv( @@ -425,7 +417,7 @@ impl ExecutionContext { ) -> Result> { let uri: String = uri.into(); let (object_store, path) = self.object_store(&uri)?; - let target_partitions = self.state.lock().unwrap().config.target_partitions; + let target_partitions = self.state.lock().config.target_partitions; let logical_plan = LogicalPlanBuilder::scan_parquet(object_store, path, None, target_partitions) .await? @@ -480,8 +472,8 @@ impl ExecutionContext { uri: &str, options: CsvReadOptions<'_>, ) -> Result<()> { - let listing_options = options - .to_listing_options(self.state.lock().unwrap().config.target_partitions); + let listing_options = + options.to_listing_options(self.state.lock().config.target_partitions); self.register_listing_table( name, @@ -498,7 +490,7 @@ impl ExecutionContext { /// executed against this context. pub async fn register_parquet(&mut self, name: &str, uri: &str) -> Result<()> { let (target_partitions, enable_pruning) = { - let m = self.state.lock().unwrap(); + let m = self.state.lock(); (m.config.target_partitions, m.config.parquet_pruning) }; let file_format = ParquetFormat::default().with_enable_pruning(enable_pruning); @@ -524,8 +516,8 @@ impl ExecutionContext { uri: &str, options: AvroReadOptions<'_>, ) -> Result<()> { - let listing_options = options - .to_listing_options(self.state.lock().unwrap().config.target_partitions); + let listing_options = + options.to_listing_options(self.state.lock().config.target_partitions); self.register_listing_table(name, uri, listing_options, options.schema) .await?; @@ -545,7 +537,7 @@ impl ExecutionContext { ) -> Option> { let name = name.into(); - let state = self.state.lock().unwrap(); + let state = self.state.lock(); let catalog = if state.config.information_schema { Arc::new(CatalogWithInformationSchema::new( Arc::downgrade(&state.catalog_list), @@ -560,7 +552,7 @@ impl ExecutionContext { /// Retrieves a `CatalogProvider` instance by name pub fn catalog(&self, name: &str) -> Option> { - self.state.lock().unwrap().catalog_list.catalog(name) + self.state.lock().catalog_list.catalog(name) } /// Registers a object store with scheme using a custom `ObjectStore` so that @@ -576,7 +568,6 @@ impl ExecutionContext { self.state .lock() - .unwrap() .object_store_registry .register_store(scheme, object_store) } @@ -588,7 +579,6 @@ impl ExecutionContext { ) -> Result<(Arc, &'a str)> { self.state .lock() - .unwrap() .object_store_registry .get_by_uri(uri) .map_err(DataFusionError::from) @@ -608,7 +598,6 @@ impl ExecutionContext { let table_ref = table_ref.into(); self.state .lock() - .unwrap() .schema_for_ref(table_ref)? .register_table(table_ref.table().to_owned(), provider) } @@ -623,7 +612,6 @@ impl ExecutionContext { let table_ref = table_ref.into(); self.state .lock() - .unwrap() .schema_for_ref(table_ref)? .deregister_table(table_ref.table()) } @@ -637,7 +625,7 @@ impl ExecutionContext { table_ref: impl Into>, ) -> Result> { let table_ref = table_ref.into(); - let schema = self.state.lock().unwrap().schema_for_ref(table_ref)?; + let schema = self.state.lock().schema_for_ref(table_ref)?; match schema.table(table_ref.table()) { Some(ref provider) => { let plan = LogicalPlanBuilder::scan( @@ -667,7 +655,6 @@ impl ExecutionContext { Ok(self .state .lock() - .unwrap() // a bare reference will always resolve to the default catalog and schema .schema_for_ref(TableReference::Bare { table: "" })? .table_names() @@ -706,7 +693,7 @@ impl ExecutionContext { logical_plan: &LogicalPlan, ) -> Result> { let (state, planner) = { - let mut state = self.state.lock().unwrap(); + let mut state = self.state.lock(); state.execution_props.start_execution(); // We need to clone `state` to release the lock that is not `Send`. We could @@ -877,7 +864,7 @@ impl ExecutionContext { where F: FnMut(&LogicalPlan, &dyn OptimizerRule), { - let state = &mut self.state.lock().unwrap(); + let state = &mut self.state.lock(); let execution_props = &mut state.execution_props.clone(); let optimizers = &state.config.optimizers; @@ -902,15 +889,15 @@ impl From>> for ExecutionContext { impl FunctionRegistry for ExecutionContext { fn udfs(&self) -> HashSet { - self.state.lock().unwrap().udfs() + self.state.lock().udfs() } fn udf(&self, name: &str) -> Result> { - self.state.lock().unwrap().udf(name) + self.state.lock().udf(name) } fn udaf(&self, name: &str) -> Result> { - self.state.lock().unwrap().udaf(name) + self.state.lock().udaf(name) } } @@ -1179,9 +1166,14 @@ impl ExecutionConfig { /// An instance of this struct is created each time a [`LogicalPlan`] is prepared for /// execution (optimized). If the same plan is optimized multiple times, a new /// `ExecutionProps` is created each time. +/// +/// It is important that this structure be cheap to create as it is +/// done so during predicate pruning and expression simplification #[derive(Clone)] pub struct ExecutionProps { pub(crate) query_execution_start_time: DateTime, + /// providers for scalar variables + pub var_providers: Option>>, } impl Default for ExecutionProps { @@ -1195,6 +1187,7 @@ impl ExecutionProps { pub fn new() -> Self { ExecutionProps { query_execution_start_time: chrono::Utc::now(), + var_providers: None, } } @@ -1203,6 +1196,32 @@ impl ExecutionProps { self.query_execution_start_time = chrono::Utc::now(); &*self } + + /// Registers a variable provider, returning the existing + /// provider, if any + pub fn add_var_provider( + &mut self, + var_type: VarType, + provider: Arc, + ) -> Option> { + let mut var_providers = self.var_providers.take().unwrap_or_default(); + + let old_provider = var_providers.insert(var_type, provider); + + self.var_providers = Some(var_providers); + + old_provider + } + + /// Returns the provider for the var_type, if any + pub fn get_var_provider( + &self, + var_type: VarType, + ) -> Option> { + self.var_providers + .as_ref() + .and_then(|var_providers| var_providers.get(&var_type).map(Arc::clone)) + } } /// Execution context for registering data sources and executing queries @@ -1212,8 +1231,6 @@ pub struct ExecutionContextState { pub catalog_list: Arc, /// Scalar functions that are registered with the context pub scalar_functions: HashMap>, - /// Variable provider that are registered with the context - pub var_provider: HashMap>, /// Aggregate functions registered in the context pub aggregate_functions: HashMap>, /// Context configuration @@ -1238,7 +1255,6 @@ impl ExecutionContextState { ExecutionContextState { catalog_list: Arc::new(MemoryCatalogList::new()), scalar_functions: HashMap::new(), - var_provider: HashMap::new(), aggregate_functions: HashMap::new(), config: ExecutionConfig::new(), execution_props: ExecutionProps::new(), @@ -1342,7 +1358,7 @@ mod tests { logical_plan::{col, create_udf, sum, Expr}, }; use crate::{ - datasource::{empty::EmptyTable, MemTable, TableType}, + datasource::{empty::EmptyTable, MemTable}, logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator, }; @@ -1547,7 +1563,7 @@ mod tests { let physical_plan = ctx.create_physical_plan(&logical_plan).await?; - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect_partitioned(physical_plan, runtime).await?; // note that the order of partitions is not deterministic @@ -1596,7 +1612,7 @@ mod tests { let tmp_dir = TempDir::new()?; let partition_count = 4; let ctx = create_ctx(&tmp_dir, partition_count).await?; - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let table = ctx.table("test")?; let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) @@ -1704,7 +1720,7 @@ mod tests { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("b", physical_plan.schema().field(0).name()); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let batches = collect(physical_plan, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); @@ -2333,121 +2349,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn aggregate_timestamps_sum() -> Result<()> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; - ctx.register_table("t", test::table_with_timestamps()) - .unwrap(); - - let results = plan_and_collect( - &mut ctx, - "SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t", - ) - .await - .unwrap_err(); - - assert_eq!(results.to_string(), "Error during planning: The function Sum does not support inputs of type Timestamp(Nanosecond, None)."); - - Ok(()) - } - - #[tokio::test] - async fn aggregate_timestamps_count() -> Result<()> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; - ctx.register_table("t", test::table_with_timestamps()) - .unwrap(); - - let results = plan_and_collect( - &mut ctx, - "SELECT count(nanos), count(micros), count(millis), count(secs) FROM t", - ) - .await - .unwrap(); - - let expected = vec![ - "+----------------+-----------------+-----------------+---------------+", - "| COUNT(t.nanos) | COUNT(t.micros) | COUNT(t.millis) | COUNT(t.secs) |", - "+----------------+-----------------+-----------------+---------------+", - "| 3 | 3 | 3 | 3 |", - "+----------------+-----------------+-----------------+---------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - - #[tokio::test] - async fn aggregate_timestamps_min() -> Result<()> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; - ctx.register_table("t", test::table_with_timestamps()) - .unwrap(); - - let results = plan_and_collect( - &mut ctx, - "SELECT min(nanos), min(micros), min(millis), min(secs) FROM t", - ) - .await - .unwrap(); - - let expected = vec![ - "+----------------------------+----------------------------+-------------------------+---------------------+", - "| MIN(t.nanos) | MIN(t.micros) | MIN(t.millis) | MIN(t.secs) |", - "+----------------------------+----------------------------+-------------------------+---------------------+", - "| 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 |", - "+----------------------------+----------------------------+-------------------------+---------------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - - #[tokio::test] - async fn aggregate_timestamps_max() -> Result<()> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; - ctx.register_table("t", test::table_with_timestamps()) - .unwrap(); - - let results = plan_and_collect( - &mut ctx, - "SELECT max(nanos), max(micros), max(millis), max(secs) FROM t", - ) - .await - .unwrap(); - - let expected = vec![ - "+-------------------------+-------------------------+-------------------------+---------------------+", - "| MAX(t.nanos) | MAX(t.micros) | MAX(t.millis) | MAX(t.secs) |", - "+-------------------------+-------------------------+-------------------------+---------------------+", - "| 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 |", - "+-------------------------+-------------------------+-------------------------+---------------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - - #[tokio::test] - async fn aggregate_timestamps_avg() -> Result<()> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; - ctx.register_table("t", test::table_with_timestamps()) - .unwrap(); - - let results = plan_and_collect( - &mut ctx, - "SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t", - ) - .await - .unwrap_err(); - - assert_eq!(results.to_string(), "Error during planning: The function Avg does not support inputs of type Timestamp(Nanosecond, None)."); - Ok(()) - } - #[tokio::test] async fn aggregate_avg_add() -> Result<()> { let results = execute( @@ -2486,56 +2387,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn join_timestamp() -> Result<()> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; - ctx.register_table("t", test::table_with_timestamps()) - .unwrap(); - - let expected = vec![ - "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", - "| nanos | micros | millis | secs | name | nanos | micros | millis | secs | name |", - "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", - "| 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 | Row 1 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 | Row 1 |", - "| 2018-11-13 17:11:10.011375885 | 2018-11-13 17:11:10.011375 | 2018-11-13 17:11:10.011 | 2018-11-13 17:11:10 | Row 0 | 2018-11-13 17:11:10.011375885 | 2018-11-13 17:11:10.011375 | 2018-11-13 17:11:10.011 | 2018-11-13 17:11:10 | Row 0 |", - "| 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 | Row 3 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 | Row 3 |", - "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", - ]; - - let results = plan_and_collect( - &mut ctx, - "SELECT * FROM t as t1 \ - JOIN (SELECT * FROM t) as t2 \ - ON t1.nanos = t2.nanos", - ) - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); - - let results = plan_and_collect( - &mut ctx, - "SELECT * FROM t as t1 \ - JOIN (SELECT * FROM t) as t2 \ - ON t1.micros = t2.micros", - ) - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); - - let results = plan_and_collect( - &mut ctx, - "SELECT * FROM t as t1 \ - JOIN (SELECT * FROM t) as t2 \ - ON t1.millis = t2.millis", - ) - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - #[tokio::test] async fn count_basic() -> Result<()> { let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 1).await?; @@ -3480,7 +3331,7 @@ mod tests { let plan = ctx.optimize(&plan)?; let plan = ctx.create_physical_plan(&plan).await?; - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let result = collect(plan, runtime).await?; let expected = vec![ @@ -3617,476 +3468,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn information_schema_tables_not_exist_by_default() { - let mut ctx = ExecutionContext::new(); - - let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") - .await - .unwrap_err(); - assert_eq!( - err.to_string(), - "Error during planning: Table or CTE with name 'information_schema.tables' not found" - ); - } - - #[tokio::test] - async fn information_schema_tables_no_tables() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - let result = - plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------------+------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+------------+------------+", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | tables | VIEW |", - "+---------------+--------------------+------------+------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - } - - #[tokio::test] - async fn information_schema_tables_tables_default_catalog() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - // Now, register an empty table - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let result = - plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------------+------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+------------+------------+", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | public | t | BASE TABLE |", - "+---------------+--------------------+------------+------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - - // Newly added tables should appear - ctx.register_table("t2", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let result = - plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------------+------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+------------+------------+", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | public | t | BASE TABLE |", - "| datafusion | public | t2 | BASE TABLE |", - "+---------------+--------------------+------------+------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - } - - #[tokio::test] - async fn information_schema_tables_tables_with_multiple_catalogs() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - schema - .register_table("t1".to_owned(), test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - schema - .register_table("t2".to_owned(), test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - catalog.register_schema("my_schema", Arc::new(schema)); - ctx.register_catalog("my_catalog", Arc::new(catalog)); - - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - schema - .register_table("t3".to_owned(), test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - catalog.register_schema("my_other_schema", Arc::new(schema)); - ctx.register_catalog("my_other_catalog", Arc::new(catalog)); - - let result = - plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+------------------+--------------------+------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+------------------+--------------------+------------+------------+", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | tables | VIEW |", - "| my_catalog | information_schema | columns | VIEW |", - "| my_catalog | information_schema | tables | VIEW |", - "| my_catalog | my_schema | t1 | BASE TABLE |", - "| my_catalog | my_schema | t2 | BASE TABLE |", - "| my_other_catalog | information_schema | columns | VIEW |", - "| my_other_catalog | information_schema | tables | VIEW |", - "| my_other_catalog | my_other_schema | t3 | BASE TABLE |", - "+------------------+--------------------+------------+------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - } - - #[tokio::test] - async fn information_schema_tables_table_types() { - struct TestTable(TableType); - - #[async_trait] - impl TableProvider for TestTable { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn table_type(&self) -> TableType { - self.0 - } - - fn schema(&self) -> SchemaRef { - unimplemented!() - } - - async fn scan( - &self, - _: &Option>, - _: &[Expr], - _: Option, - ) -> Result> { - unimplemented!() - } - } - - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - ctx.register_table("physical", Arc::new(TestTable(TableType::Base))) - .unwrap(); - ctx.register_table("query", Arc::new(TestTable(TableType::View))) - .unwrap(); - ctx.register_table("temp", Arc::new(TestTable(TableType::Temporary))) - .unwrap(); - - let result = - plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------------+------------+-----------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+------------+-----------------+", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | public | physical | BASE TABLE |", - "| datafusion | public | query | VIEW |", - "| datafusion | public | temp | LOCAL TEMPORARY |", - "+---------------+--------------------+------------+-----------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - } - - #[tokio::test] - async fn information_schema_show_tables_no_information_schema() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); - - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - // use show tables alias - let err = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap_err(); - - assert_eq!(err.to_string(), "Error during planning: SHOW TABLES is not supported unless information_schema is enabled"); - } - - #[tokio::test] - async fn information_schema_show_tables() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - // use show tables alias - let result = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap(); - - let expected = vec![ - "+---------------+--------------------+------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+------------+------------+", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | public | t | BASE TABLE |", - "+---------------+--------------------+------------+------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - - let result = plan_and_collect(&mut ctx, "SHOW tables").await.unwrap(); - - assert_batches_sorted_eq!(expected, &result); - } - - #[tokio::test] - async fn information_schema_show_columns_no_information_schema() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); - - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") - .await - .unwrap_err(); - - assert_eq!(err.to_string(), "Error during planning: SHOW COLUMNS is not supported unless information_schema is enabled"); - } - - #[tokio::test] - async fn information_schema_show_columns_like_where() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); - - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let expected = - "Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported"; - - let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t LIKE 'f'") - .await - .unwrap_err(); - assert_eq!(err.to_string(), expected); - - let err = - plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t WHERE column_name = 'bar'") - .await - .unwrap_err(); - assert_eq!(err.to_string(), expected); - } - - #[tokio::test] - async fn information_schema_show_columns() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------+------------+-------------+-----------+-------------+", - "| table_catalog | table_schema | table_name | column_name | data_type | is_nullable |", - "+---------------+--------------+------------+-------------+-----------+-------------+", - "| datafusion | public | t | i | Int32 | YES |", - "+---------------+--------------+------------+-------------+-----------+-------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - - let result = plan_and_collect(&mut ctx, "SHOW columns from t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &result); - - // This isn't ideal but it is consistent behavior for `SELECT * from T` - let err = plan_and_collect(&mut ctx, "SHOW columns from T") - .await - .unwrap_err(); - assert_eq!( - err.to_string(), - "Error during planning: Unknown relation for SHOW COLUMNS: T" - ); - } - - // test errors with WHERE and LIKE - #[tokio::test] - async fn information_schema_show_columns_full_extended() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let result = plan_and_collect(&mut ctx, "SHOW FULL COLUMNS FROM t") - .await - .unwrap(); - let expected = vec![ - "+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - "| table_catalog | table_schema | table_name | column_name | ordinal_position | column_default | is_nullable | data_type | character_maximum_length | character_octet_length | numeric_precision | numeric_precision_radix | numeric_scale | datetime_precision | interval_type |", - "+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - "| datafusion | public | t | i | 0 | | YES | Int32 | | | 32 | 2 | | | |", - "+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - - let result = plan_and_collect(&mut ctx, "SHOW EXTENDED COLUMNS FROM t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &result); - } - - #[tokio::test] - async fn information_schema_show_table_table_names() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM public.t") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------+------------+-------------+-----------+-------------+", - "| table_catalog | table_schema | table_name | column_name | data_type | is_nullable |", - "+---------------+--------------+------------+-------------+-----------+-------------+", - "| datafusion | public | t | i | Int32 | YES |", - "+---------------+--------------+------------+-------------+-----------+-------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - - let result = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &result); - - let err = plan_and_collect(&mut ctx, "SHOW columns from t2") - .await - .unwrap_err(); - assert_eq!( - err.to_string(), - "Error during planning: Unknown relation for SHOW COLUMNS: t2" - ); - - let err = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t2") - .await - .unwrap_err(); - assert_eq!(err.to_string(), "Error during planning: Unknown relation for SHOW COLUMNS: datafusion.public.t2"); - } - - #[tokio::test] - async fn show_unsupported() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); - - let err = plan_and_collect(&mut ctx, "SHOW SOMETHING_UNKNOWN") - .await - .unwrap_err(); - - assert_eq!(err.to_string(), "This feature is not implemented: SHOW SOMETHING_UNKNOWN not implemented. Supported syntax: SHOW "); - } - - #[tokio::test] - async fn information_schema_columns_not_exist_by_default() { - let mut ctx = ExecutionContext::new(); - - let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.columns") - .await - .unwrap_err(); - assert_eq!( - err.to_string(), - "Error during planning: Table or CTE with name 'information_schema.columns' not found" - ); - } - - fn table_with_many_types() -> Arc { - let schema = Schema::new(vec![ - Field::new("int32_col", DataType::Int32, false), - Field::new("float64_col", DataType::Float64, true), - Field::new("utf8_col", DataType::Utf8, true), - Field::new("large_utf8_col", DataType::LargeUtf8, false), - Field::new("binary_col", DataType::Binary, false), - Field::new("large_binary_col", DataType::LargeBinary, false), - Field::new( - "timestamp_nanos", - DataType::Timestamp(TimeUnit::Nanosecond, None), - false, - ), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![ - Arc::new(Int32Array::from_slice(&[1])), - Arc::new(Float64Array::from_slice(&[1.0])), - Arc::new(Utf8Array::::from(&[Some("foo")])), - Arc::new(Utf8Array::::from(&[Some("bar")])), - Arc::new(BinaryArray::::from_slice(&[b"foo" as &[u8]])), - Arc::new(BinaryArray::::from_slice(&[b"foo" as &[u8]])), - Arc::new( - Int64Array::from(&[Some(123)]) - .to(DataType::Timestamp(TimeUnit::Nanosecond, None)), - ), - ], - ) - .unwrap(); - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); - Arc::new(provider) - } - - #[tokio::test] - async fn information_schema_columns() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - - schema - .register_table("t1".to_owned(), test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - schema - .register_table("t2".to_owned(), table_with_many_types()) - .unwrap(); - catalog.register_schema("my_schema", Arc::new(schema)); - ctx.register_catalog("my_catalog", Arc::new(catalog)); - - let result = - plan_and_collect(&mut ctx, "SELECT * from information_schema.columns") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - "| table_catalog | table_schema | table_name | column_name | ordinal_position | column_default | is_nullable | data_type | character_maximum_length | character_octet_length | numeric_precision | numeric_precision_radix | numeric_scale | datetime_precision | interval_type |", - "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - "| my_catalog | my_schema | t1 | i | 0 | | YES | Int32 | | | 32 | 2 | | | |", - "| my_catalog | my_schema | t2 | binary_col | 4 | | NO | Binary | | 2147483647 | | | | | |", - "| my_catalog | my_schema | t2 | float64_col | 1 | | YES | Float64 | | | 24 | 2 | | | |", - "| my_catalog | my_schema | t2 | int32_col | 0 | | NO | Int32 | | | 32 | 2 | | | |", - "| my_catalog | my_schema | t2 | large_binary_col | 5 | | NO | LargeBinary | | 9223372036854775807 | | | | | |", - "| my_catalog | my_schema | t2 | large_utf8_col | 3 | | NO | LargeUtf8 | | 9223372036854775807 | | | | | |", - "| my_catalog | my_schema | t2 | timestamp_nanos | 6 | | NO | Timestamp(Nanosecond, None) | | | | | | | |", - "| my_catalog | my_schema | t2 | utf8_col | 2 | | YES | Utf8 | | 2147483647 | | | | | |", - "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - } - #[tokio::test] async fn disabled_default_catalog_and_schema() -> Result<()> { let mut ctx = ExecutionContext::with_config( @@ -4256,7 +3637,7 @@ mod tests { ctx.register_catalog("my_catalog", catalog); let catalog_list_weak = { - let state = ctx.state.lock().unwrap(); + let state = ctx.state.lock(); Arc::downgrade(&state.catalog_list) }; @@ -4300,8 +3681,8 @@ mod tests { }; // create mock record batch - let ids = Arc::new(Int32Array::from_slice(vec![i as i32])); - let names = Arc::new(Utf8Array::::from_slice(vec!["test"])); + let ids = Arc::new(Int32Array::from_slice(&[i as i32])); + let names = Arc::new(Utf8Array::::from_slice(&["test"])); let schema_ref = schema.as_ref(); let parquet_schema = to_parquet_schema(schema_ref).unwrap(); let iter = vec![Ok(Chunk::new(vec![ids as ArrayRef, names as ArrayRef]))]; diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index f097ca9bf3a3..1ad95950cdd7 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -17,8 +17,12 @@ //! Implementation of DataFrame API. -use std::sync::{Arc, Mutex}; +use parking_lot::Mutex; +use std::any::Any; +use std::sync::Arc; +use crate::arrow::datatypes::Schema; +use crate::arrow::datatypes::SchemaRef; use crate::error::Result; use crate::execution::context::{ExecutionContext, ExecutionContextState}; use crate::logical_plan::{ @@ -26,11 +30,16 @@ use crate::logical_plan::{ Partitioning, }; use crate::record_batch::RecordBatch; + +use crate::scalar::ScalarValue; use crate::{ dataframe::*, physical_plan::{collect, collect_partitioned}, }; +use crate::datasource::TableProvider; +use crate::datasource::TableType; +use crate::field_util::{FieldExt, SchemaExt}; use crate::physical_plan::{ execute_stream, execute_stream_partitioned, ExecutionPlan, SendableRecordBatchStream, }; @@ -54,13 +63,66 @@ impl DataFrameImpl { /// Create a physical plan async fn create_physical_plan(&self) -> Result> { - let state = self.ctx_state.lock().unwrap().clone(); + let state = self.ctx_state.lock().clone(); let ctx = ExecutionContext::from(Arc::new(Mutex::new(state))); let plan = ctx.optimize(&self.plan)?; ctx.create_physical_plan(&plan).await } } +#[async_trait] +impl TableProvider for DataFrameImpl { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + let schema: Schema = self.plan.schema().as_ref().into(); + Arc::new(schema) + } + + fn table_type(&self) -> TableType { + TableType::View + } + + async fn scan( + &self, + projection: &Option>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let expr = projection + .as_ref() + // construct projections + .map_or_else( + || Ok(Arc::new(Self::new(self.ctx_state.clone(), &self.plan)) as Arc<_>), + |projection| { + let schema = TableProvider::schema(self).project(projection)?; + let names = schema + .fields() + .iter() + .map(|field| field.name()) + .collect::>(); + self.select_columns(names.as_slice()) + }, + )? + // add predicates, otherwise use `true` as the predicate + .filter(filters.iter().cloned().fold( + Expr::Literal(ScalarValue::Boolean(Some(true))), + |acc, new| acc.and(new), + ))?; + // add a limit if given + Self::new( + self.ctx_state.clone(), + &limit + .map_or_else(|| Ok(expr.clone()), |n| expr.limit(n))? + .to_logical_plan(), + ) + .create_physical_plan() + .await + } +} + #[async_trait] impl DataFrame for DataFrameImpl { /// Apply a projection based on a list of column names @@ -161,7 +223,7 @@ impl DataFrame for DataFrameImpl { /// execute it, collecting all resulting batches into memory async fn collect(&self) -> Result> { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + let runtime = self.ctx_state.lock().runtime_env.clone(); Ok(collect(plan, runtime).await?) } @@ -183,7 +245,7 @@ impl DataFrame for DataFrameImpl { /// execute it, returning a stream over a single partition async fn execute_stream(&self) -> Result { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + let runtime = self.ctx_state.lock().runtime_env.clone(); execute_stream(plan, runtime).await } @@ -192,7 +254,7 @@ impl DataFrame for DataFrameImpl { /// partitioning async fn collect_partitioned(&self) -> Result>> { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + let runtime = self.ctx_state.lock().runtime_env.clone(); Ok(collect_partitioned(plan, runtime).await?) } @@ -200,7 +262,7 @@ impl DataFrame for DataFrameImpl { /// execute it, returning a stream for each partition async fn execute_stream_partitioned(&self) -> Result> { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + let runtime = self.ctx_state.lock().runtime_env.clone(); Ok(execute_stream_partitioned(plan, runtime).await?) } @@ -217,7 +279,7 @@ impl DataFrame for DataFrameImpl { } fn registry(&self) -> Arc { - let registry = self.ctx_state.lock().unwrap().clone(); + let registry = self.ctx_state.lock().clone(); Arc::new(registry) } @@ -489,6 +551,61 @@ mod tests { Ok(()) } + #[tokio::test] + async fn register_table() -> Result<()> { + let df = test_table().await?.select_columns(&["c1", "c12"])?; + let mut ctx = ExecutionContext::new(); + let df_impl = + Arc::new(DataFrameImpl::new(ctx.state.clone(), &df.to_logical_plan())); + + // register a dataframe as a table + ctx.register_table("test_table", df_impl.clone())?; + + // pull the table out + let table = ctx.table("test_table")?; + + let group_expr = vec![col("c1")]; + let aggr_expr = vec![sum(col("c12"))]; + + // check that we correctly read from the table + let df_results = &df_impl + .aggregate(group_expr.clone(), aggr_expr.clone())? + .collect() + .await?; + let table_results = &table.aggregate(group_expr, aggr_expr)?.collect().await?; + + assert_batches_sorted_eq!( + vec![ + "+----+-----------------------------+", + "| c1 | SUM(aggregate_test_100.c12) |", + "+----+-----------------------------+", + "| a | 10.238448667882977 |", + "| b | 7.797734760124923 |", + "| c | 13.860958726523545 |", + "| d | 8.793968289758968 |", + "| e | 10.206140546981722 |", + "+----+-----------------------------+", + ], + df_results + ); + + // the results are the same as the results from the view, modulo the leaf table name + assert_batches_sorted_eq!( + vec![ + "+----+---------------------+", + "| c1 | SUM(test_table.c12) |", + "+----+---------------------+", + "| a | 10.238448667882977 |", + "| b | 7.797734760124923 |", + "| c | 13.860958726523545 |", + "| d | 8.793968289758968 |", + "| e | 10.206140546981722 |", + "+----+---------------------+", + ], + table_results + ); + Ok(()) + } /// Compare the formatted string representation of two plans for equality fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) { assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2)); diff --git a/datafusion/src/execution/disk_manager.rs b/datafusion/src/execution/disk_manager.rs index 79b70f1f8b9a..c4fe6b4160fa 100644 --- a/datafusion/src/execution/disk_manager.rs +++ b/datafusion/src/execution/disk_manager.rs @@ -19,7 +19,8 @@ //! hashed among the directories listed in RuntimeConfig::local_dirs. use crate::error::{DataFusionError, Result}; -use log::{debug, info}; +use log::debug; +use parking_lot::Mutex; use rand::{thread_rng, Rng}; use std::path::PathBuf; use std::sync::Arc; @@ -67,7 +68,9 @@ impl DiskManagerConfig { /// while processing dataset larger than available memory. #[derive(Debug)] pub struct DiskManager { - local_dirs: Vec, + /// TempDirs to put temporary files in. A new OS specified + /// temporary directory will be created if this list is empty. + local_dirs: Mutex>, } impl DiskManager { @@ -75,31 +78,39 @@ impl DiskManager { pub fn try_new(config: DiskManagerConfig) -> Result> { match config { DiskManagerConfig::Existing(manager) => Ok(manager), - DiskManagerConfig::NewOs => { - let tempdir = tempfile::tempdir().map_err(DataFusionError::IoError)?; - - debug!( - "Created directory {:?} as DataFusion working directory", - tempdir - ); - Ok(Arc::new(Self { - local_dirs: vec![tempdir], - })) - } + DiskManagerConfig::NewOs => Ok(Arc::new(Self { + local_dirs: Mutex::new(vec![]), + })), DiskManagerConfig::NewSpecified(conf_dirs) => { let local_dirs = create_local_dirs(conf_dirs)?; - info!( + debug!( "Created local dirs {:?} as DataFusion working directory", local_dirs ); - Ok(Arc::new(Self { local_dirs })) + Ok(Arc::new(Self { + local_dirs: Mutex::new(local_dirs), + })) } } } /// Return a temporary file from a randomized choice in the configured locations pub fn create_tmp_file(&self) -> Result { - create_tmp_file(&self.local_dirs) + let mut local_dirs = self.local_dirs.lock(); + + // Create a temporary directory if needed + if local_dirs.is_empty() { + let tempdir = tempfile::tempdir().map_err(DataFusionError::IoError)?; + + debug!( + "Created directory '{:?}' as DataFusion tempfile directory", + tempdir.path().to_string_lossy() + ); + + local_dirs.push(tempdir); + } + + create_tmp_file(&local_dirs) } } @@ -129,10 +140,41 @@ fn create_tmp_file(local_dirs: &[TempDir]) -> Result { #[cfg(test)] mod tests { + use std::path::Path; + use super::*; use crate::error::Result; use tempfile::TempDir; + #[test] + fn lazy_temp_dir_creation() -> Result<()> { + // A default configuration should not create temp files until requested + let config = DiskManagerConfig::new(); + let dm = DiskManager::try_new(config)?; + + assert_eq!(0, local_dir_snapshot(&dm).len()); + + // can still create a tempfile however: + let actual = dm.create_tmp_file()?; + + // Now the tempdir has been created on demand + assert_eq!(1, local_dir_snapshot(&dm).len()); + + // the returned tempfile file should be in the temp directory + let local_dirs = local_dir_snapshot(&dm); + assert_path_in_dirs(actual.path(), local_dirs.iter().map(|p| p.as_path())); + + Ok(()) + } + + fn local_dir_snapshot(dm: &DiskManager) -> Vec { + dm.local_dirs + .lock() + .iter() + .map(|p| p.path().into()) + .collect() + } + #[test] fn file_in_right_dir() -> Result<()> { let local_dir1 = TempDir::new()?; @@ -147,19 +189,24 @@ mod tests { let actual = dm.create_tmp_file()?; // the file should be in one of the specified local directories - let found = local_dirs.iter().any(|p| { - actual - .path() + assert_path_in_dirs(actual.path(), local_dirs.into_iter()); + + Ok(()) + } + + /// Asserts that `file_path` is found anywhere in any of `dir` directories + fn assert_path_in_dirs<'a>( + file_path: &'a Path, + dirs: impl Iterator, + ) { + let dirs: Vec<&Path> = dirs.collect(); + + let found = dirs.iter().any(|file_path| { + file_path .ancestors() - .any(|candidate_path| *p == candidate_path) + .any(|candidate_path| *file_path == candidate_path) }); - assert!( - found, - "Can't find {:?} in specified local dirs: {:?}", - actual, local_dirs - ); - - Ok(()) + assert!(found, "Can't find {:?} in dirs: {:?}", file_path, dirs); } } diff --git a/datafusion/src/execution/memory_manager.rs b/datafusion/src/execution/memory_manager.rs index 32f79750a70d..d39eaab3c215 100644 --- a/datafusion/src/execution/memory_manager.rs +++ b/datafusion/src/execution/memory_manager.rs @@ -19,12 +19,13 @@ use crate::error::{DataFusionError, Result}; use async_trait::async_trait; -use hashbrown::HashMap; -use log::info; +use hashbrown::HashSet; +use log::debug; +use parking_lot::{Condvar, Mutex}; use std::fmt; use std::fmt::{Debug, Display, Formatter}; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Condvar, Mutex, Weak}; +use std::sync::Arc; static CONSUMER_ID: AtomicUsize = AtomicUsize::new(0); @@ -169,7 +170,7 @@ pub trait MemoryConsumer: Send + Sync { /// reached for this consumer. async fn try_grow(&self, required: usize) -> Result<()> { let current = self.mem_used(); - info!( + debug!( "trying to acquire {} whiling holding {} from consumer {}", human_readable_size(required), human_readable_size(current), @@ -181,7 +182,7 @@ pub trait MemoryConsumer: Send + Sync { .can_grow_directly(required, current) .await; if !can_grow_directly { - info!( + debug!( "Failed to grow memory of {} directly from consumer {}, spilling first ...", human_readable_size(required), self.id() @@ -245,10 +246,10 @@ The memory management architecture is the following: /// Manage memory usage during physical plan execution #[derive(Debug)] pub struct MemoryManager { - requesters: Arc>>>, - trackers: Arc>>>, + requesters: Arc>>, pool_size: usize, requesters_total: Arc>, + trackers_total: AtomicUsize, cv: Condvar, } @@ -261,16 +262,16 @@ impl MemoryManager { match config { MemoryManagerConfig::Existing(manager) => manager, MemoryManagerConfig::New { .. } => { - info!( + debug!( "Creating memory manager with initial size {}", human_readable_size(pool_size) ); Arc::new(Self { - requesters: Arc::new(Mutex::new(HashMap::new())), - trackers: Arc::new(Mutex::new(HashMap::new())), + requesters: Arc::new(Mutex::new(HashSet::new())), pool_size, requesters_total: Arc::new(Mutex::new(0)), + trackers_total: AtomicUsize::new(0), cv: Condvar::new(), }) } @@ -278,30 +279,36 @@ impl MemoryManager { } fn get_tracker_total(&self) -> usize { - let trackers = self.trackers.lock().unwrap(); - if trackers.len() > 0 { - trackers.values().fold(0usize, |acc, y| match y.upgrade() { - None => acc, - Some(t) => acc + t.mem_used(), - }) - } else { - 0 - } + self.trackers_total.load(Ordering::SeqCst) } - /// Register a new memory consumer for memory usage tracking - pub(crate) fn register_consumer(&self, consumer: &Arc) { - let id = consumer.id().clone(); - match consumer.type_() { - ConsumerType::Requesting => { - let mut requesters = self.requesters.lock().unwrap(); - requesters.insert(id, Arc::downgrade(consumer)); - } - ConsumerType::Tracking => { - let mut trackers = self.trackers.lock().unwrap(); - trackers.insert(id, Arc::downgrade(consumer)); - } - } + pub(crate) fn grow_tracker_usage(&self, delta: usize) { + self.trackers_total.fetch_add(delta, Ordering::SeqCst); + } + + pub(crate) fn shrink_tracker_usage(&self, delta: usize) { + let update = + self.trackers_total + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| { + if x >= delta { + Some(x - delta) + } else { + None + } + }); + update.expect(&*format!( + "Tracker total memory shrink by {} underflow, current value is ", + delta + )); + } + + fn get_requester_total(&self) -> usize { + *self.requesters_total.lock() + } + + /// Register a new memory requester + pub(crate) fn register_requester(&self, requester_id: &MemoryConsumerId) { + self.requesters.lock().insert(requester_id.clone()); } fn max_mem_for_requesters(&self) -> usize { @@ -311,13 +318,12 @@ impl MemoryManager { /// Grow memory attempt from a consumer, return if we could grant that much to it async fn can_grow_directly(&self, required: usize, current: usize) -> bool { - let num_rqt = self.requesters.lock().unwrap().len(); - let mut rqt_current_used = self.requesters_total.lock().unwrap(); + let num_rqt = self.requesters.lock().len(); + let mut rqt_current_used = self.requesters_total.lock(); let mut rqt_max = self.max_mem_for_requesters(); let granted; loop { - let remaining = rqt_max - *rqt_current_used; let max_per_rqt = rqt_max / num_rqt; let min_per_rqt = max_per_rqt / 2; @@ -326,6 +332,7 @@ impl MemoryManager { break; } + let remaining = rqt_max.checked_sub(*rqt_current_used).unwrap_or_default(); if remaining >= required { granted = true; *rqt_current_used += required; @@ -333,7 +340,7 @@ impl MemoryManager { } else if current < min_per_rqt { // if we cannot acquire at lease 1/2n memory, just wait for others // to spill instead spill self frequently with limited total mem - rqt_current_used = self.cv.wait(rqt_current_used).unwrap(); + self.cv.wait(&mut rqt_current_used); } else { granted = false; break; @@ -345,48 +352,39 @@ impl MemoryManager { granted } - fn record_free_then_acquire(&self, freed: usize, acquired: usize) { - let mut requesters_total = self.requesters_total.lock().unwrap(); + fn record_free_then_acquire(&self, freed: usize, acquired: usize) -> usize { + let mut requesters_total = self.requesters_total.lock(); + assert!(*requesters_total >= freed); *requesters_total -= freed; *requesters_total += acquired; self.cv.notify_all() } - /// Drop a memory consumer from memory usage tracking - pub(crate) fn drop_consumer(&self, id: &MemoryConsumerId) { + /// Drop a memory consumer and reclaim the memory + pub(crate) fn drop_consumer(&self, id: &MemoryConsumerId, mem_used: usize) { // find in requesters first { - let mut requesters = self.requesters.lock().unwrap(); - if requesters.remove(id).is_some() { - return; + let mut requesters = self.requesters.lock(); + if requesters.remove(id) { + let mut total = self.requesters_total.lock(); + assert!(*total >= mem_used); + *total -= mem_used; } } - let mut trackers = self.trackers.lock().unwrap(); - trackers.remove(id); + self.shrink_tracker_usage(mem_used); + self.cv.notify_all(); } } impl Display for MemoryManager { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let requesters = - self.requesters - .lock() - .unwrap() - .values() - .fold(vec![], |mut acc, consumer| match consumer.upgrade() { - None => acc, - Some(c) => { - acc.push(format!("{}", c)); - acc - } - }); - let tracker_mem = self.get_tracker_total(); write!(f, - "MemoryManager usage statistics: total {}, tracker used {}, total {} requesters detail: \n {},", - human_readable_size(self.pool_size), - human_readable_size(tracker_mem), - &requesters.len(), - requesters.join("\n")) + "MemoryManager usage statistics: total {}, trackers used {}, total {} requesters used: {}", + human_readable_size(self.pool_size), + human_readable_size(self.get_tracker_total()), + self.requesters.lock().len(), + human_readable_size(self.get_requester_total()), + ) } } @@ -395,7 +393,8 @@ const GB: u64 = 1 << 30; const MB: u64 = 1 << 20; const KB: u64 = 1 << 10; -fn human_readable_size(size: usize) -> String { +/// Present size in human readable form +pub fn human_readable_size(size: usize) -> String { let size = size as u64; let (value, unit) = { if size >= 2 * TB { @@ -418,6 +417,8 @@ mod tests { use super::*; use crate::error::Result; use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::execution::MemoryConsumer; + use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics}; use async_trait::async_trait; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -487,6 +488,7 @@ mod tests { impl DummyTracker { fn new(partition: usize, runtime: Arc, mem_used: usize) -> Self { + runtime.grow_tracker_usage(mem_used); Self { id: MemoryConsumerId::new(partition), runtime, @@ -528,33 +530,39 @@ mod tests { .with_memory_manager(MemoryManagerConfig::try_new_limit(100, 1.0).unwrap()); let runtime = Arc::new(RuntimeEnv::new(config).unwrap()); - let tracker1 = Arc::new(DummyTracker::new(0, runtime.clone(), 5)); - runtime.register_consumer(&(tracker1.clone() as Arc)); + DummyTracker::new(0, runtime.clone(), 5); assert_eq!(runtime.memory_manager.get_tracker_total(), 5); - let tracker2 = Arc::new(DummyTracker::new(0, runtime.clone(), 10)); - runtime.register_consumer(&(tracker2.clone() as Arc)); + let tracker1 = DummyTracker::new(0, runtime.clone(), 10); assert_eq!(runtime.memory_manager.get_tracker_total(), 15); - let tracker3 = Arc::new(DummyTracker::new(0, runtime.clone(), 15)); - runtime.register_consumer(&(tracker3.clone() as Arc)); + DummyTracker::new(0, runtime.clone(), 15); assert_eq!(runtime.memory_manager.get_tracker_total(), 30); - runtime.drop_consumer(tracker2.id()); + runtime.drop_consumer(tracker1.id(), tracker1.mem_used); + assert_eq!(runtime.memory_manager.get_tracker_total(), 20); + + // MemTrackingMetrics as an easy way to track memory + let ms = ExecutionPlanMetricsSet::new(); + let tracking_metric = MemTrackingMetrics::new_with_rt(&ms, 0, runtime.clone()); + tracking_metric.init_mem_used(15); + assert_eq!(runtime.memory_manager.get_tracker_total(), 35); + + drop(tracking_metric); assert_eq!(runtime.memory_manager.get_tracker_total(), 20); - let requester1 = Arc::new(DummyRequester::new(0, runtime.clone())); - runtime.register_consumer(&(requester1.clone() as Arc)); + let requester1 = DummyRequester::new(0, runtime.clone()); + runtime.register_requester(requester1.id()); // first requester entered, should be able to use any of the remaining 80 requester1.do_with_mem(40).await.unwrap(); requester1.do_with_mem(10).await.unwrap(); assert_eq!(requester1.get_spills(), 0); assert_eq!(requester1.mem_used(), 50); - assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 50); + assert_eq!(*runtime.memory_manager.requesters_total.lock(), 50); - let requester2 = Arc::new(DummyRequester::new(0, runtime.clone())); - runtime.register_consumer(&(requester2.clone() as Arc)); + let requester2 = DummyRequester::new(0, runtime.clone()); + runtime.register_requester(requester2.id()); requester2.do_with_mem(20).await.unwrap(); requester2.do_with_mem(30).await.unwrap(); @@ -565,7 +573,7 @@ mod tests { assert_eq!(requester1.get_spills(), 1); assert_eq!(requester1.mem_used(), 10); - assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 40); + assert_eq!(*runtime.memory_manager.requesters_total.lock(), 40); } #[tokio::test] diff --git a/datafusion/src/execution/mod.rs b/datafusion/src/execution/mod.rs index e3b42ae254a9..427c539cc75b 100644 --- a/datafusion/src/execution/mod.rs +++ b/datafusion/src/execution/mod.rs @@ -25,4 +25,6 @@ pub mod options; pub mod runtime_env; pub use disk_manager::DiskManager; -pub use memory_manager::{MemoryConsumer, MemoryConsumerId, MemoryManager}; +pub use memory_manager::{ + human_readable_size, MemoryConsumer, MemoryConsumerId, MemoryManager, +}; diff --git a/datafusion/src/execution/options.rs b/datafusion/src/execution/options.rs index 219e2fd89700..79b07536acb3 100644 --- a/datafusion/src/execution/options.rs +++ b/datafusion/src/execution/options.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::datatypes::{Schema, SchemaRef}; +use crate::datasource::file_format::json::DEFAULT_JSON_EXTENSION; use crate::datasource::{ file_format::{avro::AvroFormat, csv::CsvFormat}, listing::ListingOptions, @@ -173,7 +174,7 @@ impl<'a> Default for NdJsonReadOptions<'a> { Self { schema: None, schema_infer_max_records: 1000, - file_extension: ".json", + file_extension: DEFAULT_JSON_EXTENSION, } } } diff --git a/datafusion/src/execution/runtime_env.rs b/datafusion/src/execution/runtime_env.rs index cdcd1f71b4f5..e993b385ecd4 100644 --- a/datafusion/src/execution/runtime_env.rs +++ b/datafusion/src/execution/runtime_env.rs @@ -22,9 +22,7 @@ use crate::{ error::Result, execution::{ disk_manager::{DiskManager, DiskManagerConfig}, - memory_manager::{ - MemoryConsumer, MemoryConsumerId, MemoryManager, MemoryManagerConfig, - }, + memory_manager::{MemoryConsumerId, MemoryManager, MemoryManagerConfig}, }, }; @@ -71,13 +69,23 @@ impl RuntimeEnv { } /// Register the consumer to get it tracked - pub fn register_consumer(&self, memory_consumer: &Arc) { - self.memory_manager.register_consumer(memory_consumer); + pub fn register_requester(&self, id: &MemoryConsumerId) { + self.memory_manager.register_requester(id); } - /// Drop the consumer from get tracked - pub fn drop_consumer(&self, id: &MemoryConsumerId) { - self.memory_manager.drop_consumer(id) + /// Drop the consumer from get tracked, reclaim memory + pub fn drop_consumer(&self, id: &MemoryConsumerId, mem_used: usize) { + self.memory_manager.drop_consumer(id, mem_used) + } + + /// Grow tracker memory of `delta` + pub fn grow_tracker_usage(&self, delta: usize) { + self.memory_manager.grow_tracker_usage(delta) + } + + /// Shrink tracker memory of `delta` + pub fn shrink_tracker_usage(&self, delta: usize) { + self.memory_manager.shrink_tracker_usage(delta) } } diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 549db89035eb..3a64f7630a84 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -1148,7 +1148,7 @@ mod tests { // id column should only show up once in projection let expected = "Projection: #t1.id, #t1.first_name, #t1.last_name, #t1.state, #t1.salary, #t2.first_name, #t2.last_name, #t2.state, #t2.salary\ - \n Join: Using #t1.id = #t2.id\ + \n Inner Join: Using #t1.id = #t2.id\ \n TableScan: t1 projection=None\ \n TableScan: t2 projection=None"; diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 4d81472da9dc..2dd9f9eb3c41 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -23,10 +23,12 @@ pub use super::Operator; use arrow::{compute::cast::can_cast_types, datatypes::DataType}; use crate::error::{DataFusionError, Result}; +use crate::execution::context::ExecutionProps; use crate::field_util::{get_indexed_field, FieldExt}; use crate::logical_plan::{ plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan, }; +use crate::optimizer::simplify_expressions::{ConstEvaluator, Simplifier}; use crate::physical_plan::functions::Volatility; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, @@ -973,6 +975,58 @@ impl Expr { Ok(expr) } } + + /// Simplifies this [`Expr`]`s as much as possible, evaluating + /// constants and applying algebraic simplifications + /// + /// # Example: + /// `b > 2 AND b > 2` + /// can be written to + /// `b > 2` + /// + /// ``` + /// use datafusion::logical_plan::*; + /// use datafusion::error::Result; + /// use datafusion::execution::context::ExecutionProps; + /// + /// /// Simple implementation that provides `Simplifier` the information it needs + /// #[derive(Default)] + /// struct Info { + /// execution_props: ExecutionProps, + /// }; + /// + /// impl SimplifyInfo for Info { + /// fn is_boolean_type(&self, expr: &Expr) -> Result { + /// Ok(false) + /// } + /// fn nullable(&self, expr: &Expr) -> Result { + /// Ok(true) + /// } + /// fn execution_props(&self) -> &ExecutionProps { + /// &self.execution_props + /// } + /// } + /// + /// // b < 2 + /// let b_lt_2 = col("b").gt(lit(2)); + /// + /// // (b < 2) OR (b < 2) + /// let expr = b_lt_2.clone().or(b_lt_2.clone()); + /// + /// // (b < 2) OR (b < 2) --> (b < 2) + /// let expr = expr.simplify(&Info::default()).unwrap(); + /// assert_eq!(expr, b_lt_2); + /// ``` + pub fn simplify(self, info: &S) -> Result { + let mut rewriter = Simplifier::new(info); + let mut const_evaluator = ConstEvaluator::new(info.execution_props()); + + // TODO iterate until no changes are made during rewrite + // (evaluating constants can enable new simplifications and + // simplifications can enable new constant evaluation) + // https://github.com/apache/arrow-datafusion/issues/1160 + self.rewrite(&mut const_evaluator)?.rewrite(&mut rewriter) + } } impl Not for Expr { @@ -1094,6 +1148,20 @@ pub trait ExprRewriter: Sized { fn mutate(&mut self, expr: Expr) -> Result; } +/// The information necessary to apply algebraic simplification to an +/// [Expr]. See [SimplifyContext] for one implementation +pub trait SimplifyInfo { + /// returns true if this Expr has boolean type + fn is_boolean_type(&self, expr: &Expr) -> Result; + + /// returns true of this expr is nullable (could possibly be NULL) + fn nullable(&self, expr: &Expr) -> Result; + + /// Returns details needed for partial expression evaluation + fn execution_props(&self) -> &ExecutionProps; +} + +/// Helper struct for building [Expr::Case] pub struct CaseBuilder { expr: Option>, when_expr: Vec, @@ -1649,6 +1717,15 @@ pub fn approx_distinct(expr: Expr) -> Expr { } } +/// Calculate an approximation of the specified `percentile` for `expr`. +pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::ApproxPercentileCont, + distinct: false, + args: vec![expr, percentile], + } +} + // TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many // varying arity functions /// Create an convenience function representing a unary scalar function diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 56fec3cf1a0c..25714514d78a 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -36,17 +36,18 @@ pub use builder::{ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, approx_distinct, array, ascii, asin, atan, avg, binary_expr, - bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr, - combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, - create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, - initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, - max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random, - regexp_match, regexp_replace, repeat, replace, replace_col, reverse, + abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, + avg, binary_expr, bit_length, btrim, case, ceil, character_length, chr, col, + columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, + create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, + floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, + lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length, + or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse, rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion, + SimplifyInfo, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/datafusion/src/logical_plan/operators.rs b/datafusion/src/logical_plan/operators.rs index fdfd3f3ca267..14ccab0537bd 100644 --- a/datafusion/src/logical_plan/operators.rs +++ b/datafusion/src/logical_plan/operators.rs @@ -64,6 +64,8 @@ pub enum Operator { RegexNotMatch, /// Case insensitive regex not match RegexNotIMatch, + /// Bitwise and, like `&` + BitwiseAnd, } impl fmt::Display for Operator { @@ -90,6 +92,7 @@ impl fmt::Display for Operator { Operator::RegexNotIMatch => "!~*", Operator::IsDistinctFrom => "IS DISTINCT FROM", Operator::IsNotDistinctFrom => "IS NOT DISTINCT FROM", + Operator::BitwiseAnd => "&", }; write!(f, "{}", display) } diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 2a001c148ec8..5729c62ed4e2 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -26,6 +26,7 @@ use crate::field_util::SchemaExt; use crate::logical_plan::dfschema::DFSchemaRef; use crate::sql::parser::FileType; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use std::fmt::Formatter; use std::{ collections::HashSet, fmt::{self, Display}, @@ -49,6 +50,20 @@ pub enum JoinType { Anti, } +impl Display for JoinType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let join_type = match self { + JoinType::Inner => "Inner", + JoinType::Left => "Left", + JoinType::Right => "Right", + JoinType::Full => "Full", + JoinType::Semi => "Semi", + JoinType::Anti => "Anti", + }; + write!(f, "{}", join_type) + } +} + /// Join constraint #[derive(Debug, Clone, Copy)] pub enum JoinConstraint { @@ -934,16 +949,22 @@ impl LogicalPlan { LogicalPlan::Join(Join { on: ref keys, join_constraint, + join_type, .. }) => { let join_expr: Vec = keys.iter().map(|(l, r)| format!("{} = {}", l, r)).collect(); match join_constraint { JoinConstraint::On => { - write!(f, "Join: {}", join_expr.join(", ")) + write!(f, "{} Join: {}", join_type, join_expr.join(", ")) } JoinConstraint::Using => { - write!(f, "Join: Using {}", join_expr.join(", ")) + write!( + f, + "{} Join: Using {}", + join_type, + join_expr.join(", ") + ) } } } diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index d104e4435f53..7f631d37018f 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -1014,7 +1014,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.a <= Int64(1)\ - \n Join: #test.a = #test2.a\ + \n Inner Join: #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1022,7 +1022,7 @@ mod tests { // filter sent to side before the join let expected = "\ - Join: #test.a = #test2.a\ + Inner Join: #test.a = #test2.a\ \n Filter: #test.a <= Int64(1)\ \n TableScan: test projection=None\ \n Projection: #test2.a\ @@ -1055,7 +1055,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Inner Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1063,7 +1063,7 @@ mod tests { // filter sent to side before the join let expected = "\ - Join: Using #test.a = #test2.a\ + Inner Join: Using #test.a = #test2.a\ \n Filter: #test.a <= Int64(1)\ \n TableScan: test projection=None\ \n Projection: #test2.a\ @@ -1099,7 +1099,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.c <= #test2.b\ - \n Join: #test.a = #test2.a\ + \n Inner Join: #test.a = #test2.a\ \n Projection: #test.a, #test.c\ \n TableScan: test projection=None\ \n Projection: #test2.a, #test2.b\ @@ -1138,7 +1138,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.b <= Int64(1)\ - \n Join: #test.a = #test2.a\ + \n Inner Join: #test.a = #test2.a\ \n Projection: #test.a, #test.b\ \n TableScan: test projection=None\ \n Projection: #test2.a, #test2.c\ @@ -1146,7 +1146,7 @@ mod tests { ); let expected = "\ - Join: #test.a = #test2.a\ + Inner Join: #test.a = #test2.a\ \n Projection: #test.a, #test.b\ \n Filter: #test.b <= Int64(1)\ \n TableScan: test projection=None\ @@ -1180,7 +1180,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test2.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Left Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1189,7 +1189,7 @@ mod tests { // filter not duplicated nor pushed down - i.e. noop let expected = "\ Filter: #test2.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Left Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None"; @@ -1221,7 +1221,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Right Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1230,7 +1230,7 @@ mod tests { // filter not duplicated nor pushed down - i.e. noop let expected = "\ Filter: #test.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Right Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None"; @@ -1262,7 +1262,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Left Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1270,7 +1270,7 @@ mod tests { // filter sent to left side of the join, not the right let expected = "\ - Join: Using #test.a = #test2.a\ + Left Join: Using #test.a = #test2.a\ \n Filter: #test.a <= Int64(1)\ \n TableScan: test projection=None\ \n Projection: #test2.a\ @@ -1303,7 +1303,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test2.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Right Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1311,7 +1311,7 @@ mod tests { // filter sent to right side of join, not duplicated to the left let expected = "\ - Join: Using #test.a = #test2.a\ + Right Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n Filter: #test2.a <= Int64(1)\ diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index d2f482f6caf6..1d2a3028b6ca 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -593,7 +593,7 @@ mod tests { // make sure projections are pushed down to both table scans let expected = "Projection: #test.a, #test.b, #test2.c1\ - \n Join: #test.a = #test2.c1\ + \n Left Join: #test.a = #test2.c1\ \n TableScan: test projection=Some([0, 1])\ \n TableScan: test2 projection=Some([0])"; @@ -634,7 +634,7 @@ mod tests { // make sure projections are pushed down to both table scans let expected = "Projection: #test.a, #test.b\ - \n Join: #test.a = #test2.c1\ + \n Left Join: #test.a = #test2.c1\ \n TableScan: test projection=Some([0, 1])\ \n TableScan: test2 projection=Some([0])"; @@ -673,7 +673,7 @@ mod tests { // make sure projections are pushed down to table scan let expected = "Projection: #test.a, #test.b\ - \n Join: Using #test.a = #test2.a\ + \n Left Join: Using #test.a = #test2.a\ \n TableScan: test projection=Some([0, 1])\ \n TableScan: test2 projection=Some([0])"; diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 4583a6730536..e03babef49ef 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -22,19 +22,70 @@ use arrow::array::new_null_array; use arrow::datatypes::{DataType, Field, Schema}; use crate::error::DataFusionError; -use crate::execution::context::{ExecutionContextState, ExecutionProps}; +use crate::execution::context::ExecutionProps; use crate::field_util::SchemaExt; -use crate::logical_plan::{lit, DFSchemaRef, Expr}; -use crate::logical_plan::{DFSchema, ExprRewriter, LogicalPlan, RewriteRecursion}; +use crate::logical_plan::{ + lit, DFSchema, DFSchemaRef, Expr, ExprRewriter, LogicalPlan, RewriteRecursion, + SimplifyInfo, +}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::physical_plan::functions::Volatility; -use crate::physical_plan::planner::DefaultPhysicalPlanner; +use crate::physical_plan::planner::create_physical_expr; use crate::scalar::ScalarValue; use crate::{error::Result, logical_plan::Operator}; -/// Simplifies plans by rewriting [`Expr`]`s evaluating constants -/// and applying algebraic simplifications +/// Provides simplification information based on schema and properties +struct SimplifyContext<'a, 'b> { + schemas: Vec<&'a DFSchemaRef>, + props: &'b ExecutionProps, +} + +impl<'a, 'b> SimplifyContext<'a, 'b> { + /// Create a new SimplifyContext + pub fn new(schemas: Vec<&'a DFSchemaRef>, props: &'b ExecutionProps) -> Self { + Self { schemas, props } + } +} + +impl<'a, 'b> SimplifyInfo for SimplifyContext<'a, 'b> { + /// returns true if this Expr has boolean type + fn is_boolean_type(&self, expr: &Expr) -> Result { + for schema in &self.schemas { + if let Ok(DataType::Boolean) = expr.get_type(schema) { + return Ok(true); + } + } + + Ok(false) + } + /// Returns true if expr is nullable + fn nullable(&self, expr: &Expr) -> Result { + self.schemas + .iter() + .find_map(|schema| { + // expr may be from another input, so ignore errors + // by converting to None to keep trying + expr.nullable(schema.as_ref()).ok() + }) + .ok_or_else(|| { + // This means we weren't able to compute `Expr::nullable` with + // *any* input schemas, signalling a problem + DataFusionError::Internal(format!( + "Could not find find columns in '{}' during simplify", + expr + )) + }) + } + + fn execution_props(&self) -> &ExecutionProps { + self.props + } +} + +/// Optimizer Pass that simplifies [`LogicalPlan`]s by rewriting +/// [`Expr`]`s evaluating constants and applying algebraic +/// simplifications /// /// # Introduction /// It uses boolean algebra laws to simplify or reduce the number of terms in expressions. @@ -45,7 +96,7 @@ use crate::{error::Result, logical_plan::Operator}; /// `Filter: b > 2` /// #[derive(Default)] -pub struct SimplifyExpressions {} +pub(crate) struct SimplifyExpressions {} /// returns true if `needle` is found in a chain of search_op /// expressions. Such as: (A AND B) AND C @@ -151,9 +202,7 @@ impl OptimizerRule for SimplifyExpressions { // projected columns. With just the projected schema, it's not possible to infer types for // expressions that references non-projected columns within the same project plan or its // children plans. - let mut simplifier = Simplifier::new(plan.all_schemas()); - - let mut const_evaluator = ConstEvaluator::new(execution_props); + let info = SimplifyContext::new(plan.all_schemas(), execution_props); let new_inputs = plan .inputs() @@ -169,15 +218,8 @@ impl OptimizerRule for SimplifyExpressions { // Constant folding should not change expression name. let name = &e.name(plan.schema()); - // TODO iterate until no changes are made - // during rewrite (evaluating constants can - // enable new simplifications and - // simplifications can enable new constant - // evaluation) - let new_e = e - // fold constants and then simplify - .rewrite(&mut const_evaluator)? - .rewrite(&mut simplifier)?; + // Apply the actual simplification logic + let new_e = e.simplify(&info)?; let new_name = &new_e.name(plan.schema()); @@ -224,7 +266,7 @@ impl SimplifyExpressions { /// let rewritten = expr.rewrite(&mut const_evaluator).unwrap(); /// assert_eq!(rewritten, lit(3) + col("a")); /// ``` -pub struct ConstEvaluator { +pub struct ConstEvaluator<'a> { /// can_evaluate is used during the depth-first-search of the /// Expr tree to track if any siblings (or their descendants) were /// non evaluatable (e.g. had a column reference or volatile @@ -239,13 +281,12 @@ pub struct ConstEvaluator { /// descendants) so this Expr can be evaluated can_evaluate: Vec, - ctx_state: ExecutionContextState, - planner: DefaultPhysicalPlanner, + execution_props: &'a ExecutionProps, input_schema: DFSchema, input_batch: RecordBatch, } -impl ExprRewriter for ConstEvaluator { +impl<'a> ExprRewriter for ConstEvaluator<'a> { fn pre_visit(&mut self, expr: &Expr) -> Result { // Default to being able to evaluate this node self.can_evaluate.push(true); @@ -283,16 +324,11 @@ impl ExprRewriter for ConstEvaluator { } } -impl ConstEvaluator { +impl<'a> ConstEvaluator<'a> { /// Create a new `ConstantEvaluator`. Session constants (such as /// the time for `now()` are taken from the passed /// `execution_props`. - pub fn new(execution_props: &ExecutionProps) -> Self { - let planner = DefaultPhysicalPlanner::default(); - let ctx_state = ExecutionContextState { - execution_props: execution_props.clone(), - ..ExecutionContextState::new() - }; + pub fn new(execution_props: &'a ExecutionProps) -> Self { let input_schema = DFSchema::empty(); // The dummy column name is unused and doesn't matter as only @@ -307,8 +343,7 @@ impl ConstEvaluator { Self { can_evaluate: vec![], - ctx_state, - planner, + execution_props, input_schema, input_batch, } @@ -365,11 +400,11 @@ impl ConstEvaluator { return Ok(s); } - let phys_expr = self.planner.create_physical_expr( + let phys_expr = create_physical_expr( &expr, &self.input_schema, self.input_batch.schema(), - &self.ctx_state, + self.execution_props, )?; let col_val = phys_expr.evaluate(&self.input_batch)?; match col_val { @@ -397,52 +432,23 @@ impl ConstEvaluator { /// * `false = true` and `true = false` to `false` /// * `!!expr` to `expr` /// * `expr = null` and `expr != null` to `null` -pub(crate) struct Simplifier<'a> { - /// input schemas - schemas: Vec<&'a DFSchemaRef>, +pub(crate) struct Simplifier<'a, S> { + info: &'a S, } -impl<'a> Simplifier<'a> { - pub fn new(schemas: Vec<&'a DFSchemaRef>) -> Self { - Self { schemas } - } - - fn is_boolean_type(&self, expr: &Expr) -> bool { - for schema in &self.schemas { - if let Ok(DataType::Boolean) = expr.get_type(schema) { - return true; - } - } - - false - } - - /// Returns true if expr is nullable - fn nullable(&self, expr: &Expr) -> Result { - self.schemas - .iter() - .find_map(|schema| { - // expr may be from another input, so ignore errors - // by converting to None to keep trying - expr.nullable(schema.as_ref()).ok() - }) - .ok_or_else(|| { - // This means we weren't able to compute `Expr::nullable` with - // *any* input schemas, signalling a problem - DataFusionError::Internal(format!( - "Could not find find columns in '{}' during simplify", - expr - )) - }) +impl<'a, S> Simplifier<'a, S> { + pub fn new(info: &'a S) -> Self { + Self { info } } } -impl<'a> ExprRewriter for Simplifier<'a> { +impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { /// rewrite the expression simplifying any constant expressions fn mutate(&mut self, expr: Expr) -> Result { use Expr::*; use Operator::{And, Divide, Eq, Multiply, NotEq, Or}; + let info = self.info; let new_expr = match expr { // // Rules for Eq @@ -455,7 +461,7 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: Eq, right, - } if is_bool_lit(&left) && self.is_boolean_type(&right) => { + } if is_bool_lit(&left) && info.is_boolean_type(&right)? => { match as_bool_lit(*left) { Some(true) => *right, Some(false) => Not(right), @@ -469,7 +475,7 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: Eq, right, - } if is_bool_lit(&right) && self.is_boolean_type(&left) => { + } if is_bool_lit(&right) && info.is_boolean_type(&left)? => { match as_bool_lit(*right) { Some(true) => *left, Some(false) => Not(left), @@ -488,7 +494,7 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: NotEq, right, - } if is_bool_lit(&left) && self.is_boolean_type(&right) => { + } if is_bool_lit(&left) && info.is_boolean_type(&right)? => { match as_bool_lit(*left) { Some(true) => Not(right), Some(false) => *right, @@ -502,7 +508,7 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: NotEq, right, - } if is_bool_lit(&right) && self.is_boolean_type(&left) => { + } if is_bool_lit(&right) && info.is_boolean_type(&left)? => { match as_bool_lit(*right) { Some(true) => Not(left), Some(false) => *left, @@ -555,13 +561,13 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: Or, right, - } if !self.nullable(&right)? && is_op_with(And, &right, &left) => *left, + } if !info.nullable(&right)? && is_op_with(And, &right, &left) => *left, // (A AND B) OR A --> A (if B not null) BinaryExpr { left, op: Or, right, - } if !self.nullable(&left)? && is_op_with(And, &left, &right) => *right, + } if !info.nullable(&left)? && is_op_with(And, &left, &right) => *right, // // Rules for AND @@ -608,13 +614,13 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: And, right, - } if !self.nullable(&right)? && is_op_with(Or, &right, &left) => *left, + } if !info.nullable(&right)? && is_op_with(Or, &right, &left) => *left, // (A OR B) AND A --> A (if B not null) BinaryExpr { left, op: And, right, - } if !self.nullable(&left)? && is_op_with(Or, &left, &right) => *right, + } if !info.nullable(&left)? && is_op_with(Or, &left, &right) => *right, // // Rules for Multiply @@ -651,7 +657,7 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: Divide, right, - } if !self.nullable(&left)? && left == right => lit(1), + } if !info.nullable(&left)? && left == right => lit(1), // // Rules for Not @@ -663,6 +669,54 @@ impl<'a> ExprRewriter for Simplifier<'a> { _ => unreachable!(), }, + // + // Rules for Case + // + + // CASE + // WHEN X THEN A + // WHEN Y THEN B + // ... + // ELSE Q + // END + // + // ---> (X AND A) OR (Y AND B AND NOT X) OR ... (NOT (X OR Y) AND Q) + // + // Note: the rationale for this rewrite is that the expr can then be further + // simplified using the existing rules for AND/OR + Case { + expr: None, + when_then_expr, + else_expr, + } if !when_then_expr.is_empty() + && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number + && info.is_boolean_type(&when_then_expr[0].1)? => + { + // The disjunction of all the when predicates encountered so far + let mut filter_expr = lit(false); + // The disjunction of all the cases + let mut out_expr = lit(false); + + for (when, then) in when_then_expr { + let case_expr = when + .as_ref() + .clone() + .and(filter_expr.clone().not()) + .and(*then); + + out_expr = out_expr.or(case_expr); + filter_expr = filter_expr.or(*when); + } + + if let Some(else_expr) = else_expr { + let case_expr = filter_expr.not().and(*else_expr); + out_expr = out_expr.or(case_expr); + } + + // Do a first pass at simplification + out_expr.rewrite(self)? + } + expr => { // no additional rewrites possible expr @@ -1142,6 +1196,7 @@ mod tests { ) { let execution_props = ExecutionProps { query_execution_start_time: *date_time, + var_providers: None, }; let mut const_evaluator = ConstEvaluator::new(&execution_props); @@ -1167,15 +1222,9 @@ mod tests { fn simplify(expr: Expr) -> Expr { let schema = expr_test_schema(); - let mut rewriter = Simplifier::new(vec![&schema]); - let execution_props = ExecutionProps::new(); - let mut const_evaluator = ConstEvaluator::new(&execution_props); - - expr.rewrite(&mut rewriter) - .expect("expected to simplify") - .rewrite(&mut const_evaluator) - .expect("expected to const evaluate") + let info = SimplifyContext::new(vec![&schema], &execution_props); + expr.simplify(&info).unwrap() } fn expr_test_schema() -> DFSchemaRef { @@ -1292,6 +1341,11 @@ mod tests { #[test] fn simplify_expr_case_when_then_else() { + // CASE WHERE c2 != false THEN "ok" == "not_ok" ELSE c2 == true + // --> + // CASE WHERE c2 THEN false ELSE c2 + // --> + // false assert_eq!( simplify(Expr::Case { expr: None, @@ -1301,11 +1355,85 @@ mod tests { )], else_expr: Some(Box::new(col("c2").eq(lit(true)))), }), - Expr::Case { + col("c2").not().and(col("c2")) // #1716 + ); + + // CASE WHERE c2 != false THEN "ok" == "ok" ELSE c2 + // --> + // CASE WHERE c2 THEN true ELSE c2 + // --> + // c2 + // + // Need to call simplify 2x due to + // https://github.com/apache/arrow-datafusion/issues/1160 + assert_eq!( + simplify(simplify(Expr::Case { expr: None, - when_then_expr: vec![(Box::new(col("c2")), Box::new(lit(false)))], + when_then_expr: vec![( + Box::new(col("c2").not_eq(lit(false))), + Box::new(lit("ok").eq(lit("ok"))), + )], + else_expr: Some(Box::new(col("c2").eq(lit(true)))), + })), + col("c2").or(col("c2").not().and(col("c2"))) // #1716 + ); + + // CASE WHERE ISNULL(c2) THEN true ELSE c2 + // --> + // ISNULL(c2) OR c2 + // + // Need to call simplify 2x due to + // https://github.com/apache/arrow-datafusion/issues/1160 + assert_eq!( + simplify(simplify(Expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(col("c2").is_null()), + Box::new(lit(true)), + )], else_expr: Some(Box::new(col("c2"))), - } + })), + col("c2") + .is_null() + .or(col("c2").is_null().not().and(col("c2"))) + ); + + // CASE WHERE c1 then true WHERE c2 then false ELSE true + // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE) + // --> c1 OR (NOT(c1 OR c2)) + // --> NOT(c1) AND c2 + // + // Need to call simplify 2x due to + // https://github.com/apache/arrow-datafusion/issues/1160 + assert_eq!( + simplify(simplify(Expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(col("c1")), Box::new(lit(true)),), + (Box::new(col("c2")), Box::new(lit(false)),) + ], + else_expr: Some(Box::new(lit(true))), + })), + col("c1").or(col("c1").or(col("c2")).not()) + ); + + // CASE WHERE c1 then true WHERE c2 then true ELSE false + // --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE) + // --> c1 OR (NOT(c1) AND c2) + // --> c1 OR c2 + // + // Need to call simplify 2x due to + // https://github.com/apache/arrow-datafusion/issues/1160 + assert_eq!( + simplify(simplify(Expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(col("c1")), Box::new(lit(true)),), + (Box::new(col("c2")), Box::new(lit(false)),) + ], + else_expr: Some(Box::new(lit(true))), + })), + col("c1").or(col("c1").or(col("c2")).not()) ); } @@ -1623,6 +1751,7 @@ mod tests { let rule = SimplifyExpressions::new(); let execution_props = ExecutionProps { query_execution_start_time: *date_time, + var_providers: None, }; let err = rule @@ -1639,6 +1768,7 @@ mod tests { let rule = SimplifyExpressions::new(); let execution_props = ExecutionProps { query_execution_start_time: *date_time, + var_providers: None, }; let optimized_plan = rule diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs index fe577a644905..7590d37c745d 100644 --- a/datafusion/src/physical_optimizer/pruning.rs +++ b/datafusion/src/physical_optimizer/pruning.rs @@ -38,15 +38,16 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, }; +use crate::execution::context::ExecutionProps; use crate::field_util::{FieldExt, SchemaExt}; -use crate::physical_plan::expressions::cast::cast_with_error; +use crate::physical_plan::expressions::cast::DEFAULT_DATAFUSION_CAST_OPTIONS; +use crate::physical_plan::planner::create_physical_expr; use crate::prelude::lit; use crate::{ error::{DataFusionError, Result}, - execution::context::ExecutionContextState, logical_plan::{Column, DFSchema, Expr, Operator}, optimizer::utils, - physical_plan::{planner::DefaultPhysicalPlanner, ColumnarValue, PhysicalExpr}, + physical_plan::{ColumnarValue, PhysicalExpr}, }; /// Interface to pass statistics information to [`PruningPredicates`] @@ -132,12 +133,14 @@ impl PruningPredicate { .collect::>(); let stat_schema = Schema::new(stat_fields); let stat_dfschema = DFSchema::try_from(stat_schema.clone())?; - let execution_context_state = ExecutionContextState::new(); - let predicate_expr = DefaultPhysicalPlanner::default().create_physical_expr( + + // TODO allow these properties to be passed in + let execution_props = ExecutionProps::new(); + let predicate_expr = create_physical_expr( &logical_predicate_expr, &stat_dfschema, &stat_schema, - &execution_context_state, + &execution_props, )?; Ok(Self { schema, @@ -371,7 +374,7 @@ fn build_statistics_record_batch( // cast statistics array to required data type (e.g. parquet // provides timestamp statistics as "Int64") let array = - cast_with_error(array.as_ref(), data_type, cast::CastOptions::default())? + cast::cast(array.as_ref(), data_type, DEFAULT_DATAFUSION_CAST_OPTIONS)? .into(); fields.push(stat_field.clone()); diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index ac87c25b6101..59231fccfc65 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -27,7 +27,7 @@ //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. use super::{ - functions::{Signature, Volatility}, + functions::{Signature, TypeSignature, Volatility}, Accumulator, AggregateExpr, PhysicalExpr, }; use crate::error::{DataFusionError, Result}; @@ -80,6 +80,8 @@ pub enum AggregateFunction { CovariancePop, /// Correlation Correlation, + /// Approximate continuous percentile function + ApproxPercentileCont, } impl fmt::Display for AggregateFunction { @@ -110,6 +112,7 @@ impl FromStr for AggregateFunction { "covar_samp" => AggregateFunction::Covariance, "covar_pop" => AggregateFunction::CovariancePop, "corr" => AggregateFunction::Correlation, + "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -157,6 +160,7 @@ pub fn return_type( coerced_data_types[0].clone(), true, )))), + AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), } } @@ -331,6 +335,20 @@ pub fn create_aggregate_expr( "CORR(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::ApproxPercentileCont, false) => { + Arc::new(expressions::ApproxPercentileCont::new( + // Pass in the desired percentile expr + coerced_phy_exprs, + name, + return_type, + )?) + } + (AggregateFunction::ApproxPercentileCont, true) => { + return Err(DataFusionError::NotImplemented( + "approx_percentile_cont(DISTINCT) aggregations are not available" + .to_string(), + )); + } }) } @@ -359,7 +377,7 @@ static TIMESTAMPS: &[DataType] = &[ static DATES: &[DataType] = &[DataType::Date32, DataType::Date64]; /// the signatures supported by the function `fun`. -pub fn signature(fun: &AggregateFunction) -> Signature { +pub(super) fn signature(fun: &AggregateFunction) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. match fun { AggregateFunction::Count @@ -389,18 +407,26 @@ pub fn signature(fun: &AggregateFunction) -> Signature { AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } + AggregateFunction::ApproxPercentileCont => Signature::one_of( + // Accept any numeric value paired with a float64 percentile + NUMERICS + .iter() + .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64])) + .collect(), + Volatility::Immutable, + ), } } #[cfg(test)] mod tests { use super::*; - use crate::error::Result; use crate::field_util::SchemaExt; use crate::physical_plan::expressions::{ - ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, DistinctArrayAgg, - DistinctCount, Max, Min, Stddev, Sum, Variance, + ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count, + Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, }; + use crate::{error::Result, scalar::ScalarValue}; #[test] fn test_count_arragg_approx_expr() -> Result<()> { @@ -514,6 +540,59 @@ mod tests { Ok(()) } + #[test] + fn test_agg_approx_percentile_phy_expr() { + for data_type in NUMERICS { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![ + Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + ), + Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))), + ]; + let result_agg_phy_exprs = create_aggregate_expr( + &AggregateFunction::ApproxPercentileCont, + false, + &input_phy_exprs[..], + &input_schema, + "c1", + ) + .expect("failed to create aggregate expr"); + + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), false), + result_agg_phy_exprs.field().unwrap() + ); + } + } + + #[test] + fn test_agg_approx_percentile_invalid_phy_expr() { + for data_type in NUMERICS { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![ + Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + ), + Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), + ]; + let err = create_aggregate_expr( + &AggregateFunction::ApproxPercentileCont, + false, + &input_phy_exprs[..], + &input_schema, + "c1", + ) + .expect_err("should fail due to invalid percentile"); + + assert!(matches!(err, DataFusionError::Plan(_))); + } + } + #[test] fn test_min_max_expr() -> Result<()> { let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; diff --git a/datafusion/src/physical_plan/coalesce_batches.rs b/datafusion/src/physical_plan/coalesce_batches.rs index 3f17d4d50d92..48956efee285 100644 --- a/datafusion/src/physical_plan/coalesce_batches.rs +++ b/datafusion/src/physical_plan/coalesce_batches.rs @@ -327,9 +327,8 @@ pub fn concat_chunks( #[cfg(test)] mod tests { use super::*; - use crate::physical_plan::{memory::MemoryExec, repartition::RepartitionExec}; - use arrow::array::UInt32Array; + use crate::test::create_vec_batches; use arrow::datatypes::{DataType, Field, Schema}; #[tokio::test(flavor = "multi_thread")] @@ -357,23 +356,6 @@ mod tests { Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) } - fn create_vec_batches(schema: &Arc, num_batches: usize) -> Vec { - let batch = create_batch(schema); - let mut vec = Vec::with_capacity(num_batches); - for _ in 0..num_batches { - vec.push(batch.clone()); - } - vec - } - - fn create_batch(schema: &Arc) -> RecordBatch { - RecordBatch::try_new( - schema.clone(), - vec![Arc::new(UInt32Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]))], - ) - .unwrap() - } - async fn coalesce_batches( schema: &SchemaRef, input_partitions: Vec>, diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index 92168b9dff8f..144be87722ef 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -17,7 +17,6 @@ //! Support the coercion rule for aggregate function. -use crate::arrow::datatypes::Schema; use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ @@ -27,6 +26,10 @@ use crate::physical_plan::expressions::{ }; use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; +use crate::{ + arrow::datatypes::Schema, + physical_plan::expressions::is_approx_percentile_cont_supported_arg_type, +}; use arrow::datatypes::DataType; use std::ops::Deref; use std::sync::Arc; @@ -38,24 +41,9 @@ pub(crate) fn coerce_types( input_types: &[DataType], signature: &Signature, ) -> Result> { - match signature.type_signature { - TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { - if input_types.len() != agg_count { - return Err(DataFusionError::Plan(format!( - "The function {:?} expects {:?} arguments, but {:?} were provided", - agg_fun, - agg_count, - input_types.len() - ))); - } - } - _ => { - return Err(DataFusionError::Internal(format!( - "Aggregate functions do not support this {:?}", - signature - ))); - } - }; + // Validate input_types matches (at least one of) the func signature. + check_arg_count(agg_fun, input_types, &signature.type_signature)?; + match agg_fun { AggregateFunction::Count | AggregateFunction::ApproxDistinct => { Ok(input_types.to_vec()) @@ -151,7 +139,75 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::ApproxPercentileCont => { + if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + if !matches!(input_types[1], DataType::Float64) { + return Err(DataFusionError::Plan(format!( + "The percentile argument for {:?} must be Float64, not {:?}.", + agg_fun, input_types[1] + ))); + } + Ok(input_types.to_vec()) + } + } +} + +/// Validate the length of `input_types` matches the `signature` for `agg_fun`. +/// +/// This method DOES NOT validate the argument types - only that (at least one, +/// in the case of [`TypeSignature::OneOf`]) signature matches the desired +/// number of input types. +fn check_arg_count( + agg_fun: &AggregateFunction, + input_types: &[DataType], + signature: &TypeSignature, +) -> Result<()> { + match signature { + TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { + if input_types.len() != *agg_count { + return Err(DataFusionError::Plan(format!( + "The function {:?} expects {:?} arguments, but {:?} were provided", + agg_fun, + agg_count, + input_types.len() + ))); + } + } + TypeSignature::Exact(types) => { + if types.len() != input_types.len() { + return Err(DataFusionError::Plan(format!( + "The function {:?} expects {:?} arguments, but {:?} were provided", + agg_fun, + types.len(), + input_types.len() + ))); + } + } + TypeSignature::OneOf(variants) => { + let ok = variants + .iter() + .any(|v| check_arg_count(agg_fun, input_types, v).is_ok()); + if !ok { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not accept {:?} function arguments.", + agg_fun, + input_types.len() + ))); + } + } + _ => { + return Err(DataFusionError::Internal(format!( + "Aggregate functions do not support this {:?}", + signature + ))); + } } + Ok(()) } fn get_min_max_result_type(input_types: &[DataType]) -> Result> { @@ -267,5 +323,29 @@ mod tests { assert_eq!(*input_type, result.unwrap()); } } + + // ApproxPercentileCont input types + let input_types = vec![ + vec![DataType::Int8, DataType::Float64], + vec![DataType::Int16, DataType::Float64], + vec![DataType::Int32, DataType::Float64], + vec![DataType::Int64, DataType::Float64], + vec![DataType::UInt8, DataType::Float64], + vec![DataType::UInt16, DataType::Float64], + vec![DataType::UInt32, DataType::Float64], + vec![DataType::UInt64, DataType::Float64], + vec![DataType::Float32, DataType::Float64], + vec![DataType::Float64, DataType::Float64], + ]; + for input_type in &input_types { + let signature = + aggregates::signature(&AggregateFunction::ApproxPercentileCont); + let result = coerce_types( + &AggregateFunction::ApproxPercentileCont, + input_type, + &signature, + ); + assert_eq!(*input_type, result.unwrap()); + } } } diff --git a/datafusion/src/physical_plan/coercion_rule/binary_rule.rs b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs index cfb9828d710b..3afbc863bb1d 100644 --- a/datafusion/src/physical_plan/coercion_rule/binary_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs @@ -31,6 +31,7 @@ pub(crate) fn coerce_types( ) -> Result { // This result MUST be compatible with `binary_coerce` let result = match op { + Operator::BitwiseAnd => bitwise_coercion(lhs_type, rhs_type), Operator::And | Operator::Or => match (lhs_type, rhs_type) { // logical binary boolean operators can only be evaluated in bools (DataType::Boolean, DataType::Boolean) => Some(DataType::Boolean), @@ -72,6 +73,25 @@ pub(crate) fn coerce_types( } } +fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + + if !is_numeric(left_type) || !is_numeric(right_type) { + return None; + } + if left_type == right_type && !is_dictionary(left_type) { + return Some(left_type.clone()); + } + // TODO support other data type + match (left_type, right_type) { + (Int64, _) | (_, Int64) => Some(Int64), + (Int32, _) | (_, Int32) => Some(Int32), + (Int16, _) | (_, Int16) => Some(Int16), + (Int8, _) | (_, Int8) => Some(Int8), + _ => None, + } +} + fn comparison_eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { // can't compare dictionaries directly due to // https://github.com/apache/arrow-rs/issues/1201 diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs index 733b1ee92fbf..31cb2f56ab28 100644 --- a/datafusion/src/physical_plan/common.rs +++ b/datafusion/src/physical_plan/common.rs @@ -21,7 +21,7 @@ use super::{RecordBatchStream, SendableRecordBatchStream}; use crate::error::{DataFusionError, Result}; use crate::execution::runtime_env::RuntimeEnv; use crate::field_util::SchemaExt; -use crate::physical_plan::metrics::BaselineMetrics; +use crate::physical_plan::metrics::MemTrackingMetrics; use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics}; use crate::record_batch::RecordBatch; use arrow::compute::aggregate::estimated_bytes_size; @@ -45,7 +45,7 @@ pub struct SizedRecordBatchStream { schema: SchemaRef, batches: Vec>, index: usize, - baseline_metrics: BaselineMetrics, + metrics: MemTrackingMetrics, } impl SizedRecordBatchStream { @@ -53,13 +53,15 @@ impl SizedRecordBatchStream { pub fn new( schema: SchemaRef, batches: Vec>, - baseline_metrics: BaselineMetrics, + metrics: MemTrackingMetrics, ) -> Self { + let size = batches.iter().map(|b| batch_byte_size(b)).sum::(); + metrics.init_mem_used(size); SizedRecordBatchStream { schema, index: 0, batches, - baseline_metrics, + metrics, } } } @@ -77,7 +79,7 @@ impl Stream for SizedRecordBatchStream { } else { None }); - self.baseline_metrics.record_poll(poll) + self.metrics.record_poll(poll) } } diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/cross_join.rs index ed750eafeb62..ff799756de4f 100644 --- a/datafusion/src/physical_plan/cross_join.rs +++ b/datafusion/src/physical_plan/cross_join.rs @@ -194,7 +194,7 @@ impl ExecutionPlan for CrossJoinExec { schema: self.schema.clone(), left_data, right: stream, - right_batch: Arc::new(std::sync::Mutex::new(None)), + right_batch: Arc::new(parking_lot::Mutex::new(None)), left_index: 0, num_input_batches: 0, num_input_rows: 0, @@ -301,7 +301,7 @@ struct CrossJoinStream { /// Current value on the left left_index: usize, /// Current batch being processed from the right side - right_batch: Arc>>, + right_batch: Arc>>, /// number of input batches num_input_batches: usize, /// number of input rows @@ -356,7 +356,7 @@ impl Stream for CrossJoinStream { if self.left_index > 0 && self.left_index < self.left_data.num_rows() { let start = Instant::now(); let right_batch = { - let right_batch = self.right_batch.lock().unwrap(); + let right_batch = self.right_batch.lock(); right_batch.clone().unwrap() }; let result = @@ -391,7 +391,7 @@ impl Stream for CrossJoinStream { } self.left_index = 1; - let mut right_batch = self.right_batch.lock().unwrap(); + let mut right_batch = self.right_batch.lock(); *right_batch = Some(batch); Some(result) diff --git a/datafusion/src/physical_plan/explain.rs b/datafusion/src/physical_plan/explain.rs index eb1a3e09e5eb..66a441888c51 100644 --- a/datafusion/src/physical_plan/explain.rs +++ b/datafusion/src/physical_plan/explain.rs @@ -20,8 +20,12 @@ use std::any::Any; use std::sync::Arc; +use arrow::{array::*, datatypes::SchemaRef}; +use async_trait::async_trait; + use super::SendableRecordBatchStream; use crate::execution::runtime_env::RuntimeEnv; +use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics}; use crate::record_batch::RecordBatch; use crate::{ error::{DataFusionError, Result}, @@ -31,10 +35,6 @@ use crate::{ Statistics, }, }; -use arrow::{array::*, datatypes::SchemaRef}; - -use crate::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; -use async_trait::async_trait; /// Explain execution plan operator. This operator contains the string /// values of the various plans it has when it is created, and passes @@ -148,12 +148,12 @@ impl ExecutionPlan for ExplainExec { )?; let metrics = ExecutionPlanMetricsSet::new(); - let baseline_metrics = BaselineMetrics::new(&metrics, partition); + let tracking_metrics = MemTrackingMetrics::new(&metrics, partition); Ok(Box::pin(SizedRecordBatchStream::new( self.schema.clone(), vec![Arc::new(record_batch)], - baseline_metrics, + tracking_metrics, ))) } diff --git a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs new file mode 100644 index 000000000000..b20f6354f144 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs @@ -0,0 +1,313 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, iter, sync::Arc}; + +use arrow::{ + array::{ + ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + }, + datatypes::{DataType, Field}, +}; + +use crate::{ + error::DataFusionError, + physical_plan::{tdigest::TDigest, Accumulator, AggregateExpr, PhysicalExpr}, + scalar::ScalarValue, +}; + +use crate::error::Result; + +use super::{format_state_name, Literal}; + +/// Return `true` if `arg_type` is of a [`DataType`] that the +/// [`ApproxPercentileCont`] aggregation can operate on. +pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +/// APPROX_PERCENTILE_CONT aggregate expression +#[derive(Debug)] +pub struct ApproxPercentileCont { + name: String, + input_data_type: DataType, + expr: Arc, + percentile: f64, +} + +impl ApproxPercentileCont { + /// Create a new [`ApproxPercentileCont`] aggregate function. + pub fn new( + expr: Vec>, + name: impl Into, + input_data_type: DataType, + ) -> Result { + // Arguments should be [ColumnExpr, DesiredPercentileLiteral] + debug_assert_eq!(expr.len(), 2); + + // Extract the desired percentile literal + let lit = expr[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "desired percentile argument must be float literal".to_string(), + ) + })? + .value(); + let percentile = match lit { + ScalarValue::Float32(Some(q)) => *q as f64, + ScalarValue::Float64(Some(q)) => *q as f64, + got => return Err(DataFusionError::NotImplemented(format!( + "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", + got + ))) + }; + + // Ensure the percentile is between 0 and 1. + if !(0.0..=1.0).contains(&percentile) { + return Err(DataFusionError::Plan(format!( + "Percentile value must be between 0.0 and 1.0 inclusive, {} is invalid", + percentile + ))); + } + + Ok(Self { + name: name.into(), + input_data_type, + // The physical expr to evaluate during accumulation + expr: expr[0].clone(), + percentile, + }) + } +} + +impl AggregateExpr for ApproxPercentileCont { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.input_data_type.clone(), false)) + } + + /// See [`TDigest::to_scalar_state()`] for a description of the serialised + /// state. + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "max_size"), + DataType::UInt64, + false, + ), + Field::new( + &format_state_name(&self.name, "sum"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "count"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "max"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "min"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "centroids"), + DataType::List(Box::new(Field::new("item", DataType::Float64, true))), + false, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn create_accumulator(&self) -> Result> { + let accumulator: Box = match &self.input_data_type { + t @ (DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64) => { + Box::new(ApproxPercentileAccumulator::new(self.percentile, t.clone())) + } + other => { + return Err(DataFusionError::NotImplemented(format!( + "Support for 'APPROX_PERCENTILE_CONT' for data type {:?} is not implemented", + other + ))) + } + }; + Ok(accumulator) + } + + fn name(&self) -> &str { + &self.name + } +} + +#[derive(Debug)] +pub struct ApproxPercentileAccumulator { + digest: TDigest, + percentile: f64, + return_type: DataType, +} + +impl ApproxPercentileAccumulator { + pub fn new(percentile: f64, return_type: DataType) -> Self { + Self { + digest: TDigest::new(100), + percentile, + return_type, + } + } +} + +impl Accumulator for ApproxPercentileAccumulator { + fn state(&self) -> Result> { + Ok(self.digest.to_scalar_state()) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + debug_assert_eq!( + values.len(), + 1, + "invalid number of values in batch percentile update" + ); + let values = &values[0]; + + self.digest = match values.data_type() { + DataType::Float64 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Float32 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int64 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int32 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int16 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int8 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt64 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt32 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt16 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt8 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + e => { + return Err(DataFusionError::Internal(format!( + "APPROX_PERCENTILE_CONT is not expected to receive the type {:?}", + e + ))); + } + }; + + Ok(()) + } + + fn evaluate(&self) -> Result { + let q = self.digest.estimate_quantile(self.percentile); + + // These acceptable return types MUST match the validation in + // ApproxPercentile::create_accumulator. + Ok(match &self.return_type { + DataType::Int8 => ScalarValue::Int8(Some(q as i8)), + DataType::Int16 => ScalarValue::Int16(Some(q as i16)), + DataType::Int32 => ScalarValue::Int32(Some(q as i32)), + DataType::Int64 => ScalarValue::Int64(Some(q as i64)), + DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)), + DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)), + DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)), + DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)), + DataType::Float32 => ScalarValue::Float32(Some(q as f32)), + DataType::Float64 => ScalarValue::Float64(Some(q as f64)), + v => unreachable!("unexpected return type {:?}", v), + }) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + }; + + let states = (0..states[0].len()) + .map(|index| { + states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>() + .map(|state| TDigest::from_scalar_state(&state)) + }) + .chain(iter::once(Ok(self.digest.clone()))) + .collect::>>()?; + + self.digest = TDigest::merge_digests(&states); + + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/expressions/array_agg.rs b/datafusion/src/physical_plan/expressions/array_agg.rs index be49408bdf16..032ac6f4bd5e 100644 --- a/datafusion/src/physical_plan/expressions/array_agg.rs +++ b/datafusion/src/physical_plan/expressions/array_agg.rs @@ -168,7 +168,7 @@ mod tests { #[test] fn array_agg_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5])); + let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5])); let list = ScalarValue::List( Some(Box::new(vec![ diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index d902544a96df..732fa1d8c676 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -20,6 +20,7 @@ use std::{any::Any, convert::TryInto, sync::Arc}; use crate::record_batch::RecordBatch; use arrow::array::*; use arrow::compute; +use arrow::datatypes::DataType::Decimal; use arrow::datatypes::{DataType, Schema}; use crate::error::{DataFusionError, Result}; @@ -66,6 +67,103 @@ fn is_not_distinct_from_bool(left: &dyn Array, right: &dyn Array) -> BooleanArra .collect() } +/// The binary_bitwise_array_op macro only evaluates for integer types +/// like int64, int32. +/// It is used to do bitwise operation. +macro_rules! binary_bitwise_array_op { + ($LEFT:expr, $RIGHT:expr, $OP:tt, $ARRAY_TYPE:ident, $TYPE:ty) => {{ + let len = $LEFT.len(); + let left = $LEFT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + let right = $RIGHT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + let result = (0..len) + .into_iter() + .map(|i| { + if left.is_null(i) || right.is_null(i) { + None + } else { + Some(left.value(i) $OP right.value(i)) + } + }) + .collect::<$ARRAY_TYPE>(); + Ok(Arc::new(result)) + }}; +} + +/// The binary_bitwise_array_op macro only evaluates for integer types +/// like int64, int32. +/// It is used to do bitwise operation on an array with a scalar. +macro_rules! binary_bitwise_array_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:tt, $ARRAY_TYPE:ident, $TYPE:ty) => {{ + let len = $LEFT.len(); + let array = $LEFT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + let scalar = $RIGHT; + if scalar.is_null() { + Ok(new_null_array(array.data_type().clone(), len).into()) + } else { + let right: $TYPE = scalar.try_into().unwrap(); + let result = (0..len) + .into_iter() + .map(|i| { + if array.is_null(i) { + None + } else { + Some(array.value(i) $OP right) + } + }) + .collect::<$ARRAY_TYPE>(); + Ok(Arc::new(result) as ArrayRef) + } + }}; +} + +fn bitwise_and(left: &dyn Array, right: &dyn Array) -> Result { + match &left.data_type() { + DataType::Int8 => { + binary_bitwise_array_op!(left, right, &, Int8Array, i8) + } + DataType::Int16 => { + binary_bitwise_array_op!(left, right, &, Int16Array, i16) + } + DataType::Int32 => { + binary_bitwise_array_op!(left, right, &, Int32Array, i32) + } + DataType::Int64 => { + binary_bitwise_array_op!(left, right, &, Int64Array, i64) + } + other => Err(DataFusionError::Internal(format!( + "Data type {:?} not supported for binary operation '{}' on dyn arrays", + other, + Operator::BitwiseAnd + ))), + } +} + +fn bitwise_and_scalar( + array: &dyn Array, + scalar: ScalarValue, +) -> Option> { + let result = match array.data_type() { + DataType::Int8 => { + binary_bitwise_array_scalar!(array, scalar, &, Int8Array, i8) + } + DataType::Int16 => { + binary_bitwise_array_scalar!(array, scalar, &, Int16Array, i16) + } + DataType::Int32 => { + binary_bitwise_array_scalar!(array, scalar, &, Int32Array, i32) + } + DataType::Int64 => { + binary_bitwise_array_scalar!(array, scalar, &, Int64Array, i64) + } + other => Err(DataFusionError::Internal(format!( + "Data type {:?} not supported for binary operation '{}' on dyn arrays", + other, + Operator::BitwiseAnd + ))), + }; + Some(result) +} + /// Binary expression #[derive(Debug)] pub struct BinaryExpr { @@ -150,9 +248,24 @@ fn evaluate_regex_case_insensitive( fn evaluate(lhs: &dyn Array, op: &Operator, rhs: &dyn Array) -> Result> { use Operator::*; - if matches!(op, Plus | Minus | Divide | Multiply | Modulo) { + if matches!(op, Plus) { + let arr: ArrayRef = match (lhs.data_type(), rhs.data_type()) { + (Decimal(p1, s1), Decimal(p2, s2)) => { + let left_array = + lhs.as_any().downcast_ref::>().unwrap(); + let right_array = + rhs.as_any().downcast_ref::>().unwrap(); + Arc::new(if *p1 == *p2 && *s1 == *s2 { + compute::arithmetics::decimal::add(left_array, right_array) + } else { + compute::arithmetics::decimal::adaptive_add(left_array, right_array)? + }) + } + _ => compute::arithmetics::add(lhs, rhs).into(), + }; + Ok(arr) + } else if matches!(op, Minus | Divide | Multiply | Modulo) { let arr = match op { - Operator::Plus => compute::arithmetics::add(lhs, rhs), Operator::Minus => compute::arithmetics::sub(lhs, rhs), Operator::Divide => compute::arithmetics::div(lhs, rhs), Operator::Multiply => compute::arithmetics::mul(lhs, rhs), @@ -181,6 +294,8 @@ fn evaluate(lhs: &dyn Array, op: &Operator, rhs: &dyn Array) -> Result { @@ -374,6 +489,8 @@ fn evaluate_scalar( } else if matches!(op, Or | And) { // TODO: optimize scalar Or | And Ok(None) + } else if matches!(op, BitwiseAnd) { + bitwise_and_scalar(lhs, rhs.clone()).transpose() } else { match (lhs.data_type(), op) { (DataType::Utf8, RegexMatch) => { @@ -459,6 +576,8 @@ pub fn binary_operator_data_type( | Operator::RegexNotIMatch | Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => Ok(DataType::Boolean), + // bitwise operations return the common coerced type + Operator::BitwiseAnd => Ok(result_type), // math operations return the same value as the common coerced type Operator::Plus | Operator::Minus @@ -725,6 +844,7 @@ mod tests { use crate::error::Result; use crate::field_util::SchemaExt; use crate::physical_plan::expressions::{col, lit}; + use crate::test_util::create_decimal_array; use arrow::datatypes::{Field, SchemaRef}; use arrow::error::ArrowError; @@ -912,7 +1032,11 @@ mod tests { } fn add_decimal(left: &Int128Array, right: &Int128Array) -> Result { - let mut decimal_builder = Int128Vec::with_capacity(left.len()); + let mut decimal_builder = Int128Vec::from_data( + left.data_type().clone(), + Vec::::with_capacity(left.len()), + None, + ); for i in 0..left.len() { if left.is_null(i) || right.is_null(i) { decimal_builder.push(None); @@ -924,7 +1048,11 @@ mod tests { } fn subtract_decimal(left: &Int128Array, right: &Int128Array) -> Result { - let mut decimal_builder = Int128Vec::with_capacity(left.len()); + let mut decimal_builder = Int128Vec::from_data( + left.data_type().clone(), + Vec::::with_capacity(left.len()), + None, + ); for i in 0..left.len() { if left.is_null(i) || right.is_null(i) { decimal_builder.push(None); @@ -940,7 +1068,11 @@ mod tests { right: &Int128Array, scale: u32, ) -> Result { - let mut decimal_builder = Int128Vec::with_capacity(left.len()); + let mut decimal_builder = Int128Vec::from_data( + left.data_type().clone(), + Vec::::with_capacity(left.len()), + None, + ); let divide = 10_i128.pow(scale); for i in 0..left.len() { if left.is_null(i) || right.is_null(i) { @@ -958,7 +1090,11 @@ mod tests { right: &Int128Array, scale: i32, ) -> Result { - let mut decimal_builder = Int128Vec::with_capacity(left.len()); + let mut decimal_builder = Int128Vec::from_data( + left.data_type().clone(), + Vec::::with_capacity(left.len()), + None, + ); let mul = 10_f64.powi(scale); for i in 0..left.len() { if left.is_null(i) || right.is_null(i) { @@ -978,7 +1114,11 @@ mod tests { } fn modulus_decimal(left: &Int128Array, right: &Int128Array) -> Result { - let mut decimal_builder = Int128Vec::with_capacity(left.len()); + let mut decimal_builder = Int128Vec::from_data( + left.data_type().clone(), + Vec::::with_capacity(left.len()), + None, + ); for i in 0..left.len() { if left.is_null(i) || right.is_null(i) { decimal_builder.push(None); @@ -1211,6 +1351,11 @@ mod tests { let b = Utf8Array::::from_slice(["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"]); let c = BooleanArray::from_slice(&[false, false, false, false, true]); test_coercion!(a, b, Operator::RegexNotIMatch, c); + + let a = Int16Array::from_slice(&[1i16, 2i16, 3i16]); + let b = Int64Array::from_slice(&[10i64, 4i64, 5i64]); + let c = Int64Array::from_slice(&[0i64, 0i64, 1i64]); + test_coercion!(a, b, Operator::BitwiseAnd, c); Ok(()) } @@ -2027,25 +2172,6 @@ mod tests { assert_eq!(result.as_ref(), &expected as &dyn Array); } - fn create_decimal_array( - array: &[Option], - _precision: usize, - _scale: usize, - ) -> Result { - let mut decimal_builder = Int128Vec::with_capacity(array.len()); - for value in array { - match value { - None => { - decimal_builder.push(None); - } - Some(v) => { - decimal_builder.try_push(Some(*v))?; - } - } - } - Ok(decimal_builder.into()) - } - #[test] fn comparison_decimal_op_test() -> Result<()> { let value_i128: i128 = 123; @@ -2604,4 +2730,25 @@ mod tests { Ok(()) } + + #[test] + fn bitwise_array_test() -> Result<()> { + let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; + let right = + Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef; + let result = bitwise_and(left.as_ref(), right.as_ref())?; + let expected = Int32Vec::from(vec![Some(0), None, Some(3)]).as_arc(); + assert_eq!(result.as_ref(), expected.as_ref()); + Ok(()) + } + + #[test] + fn bitwise_scalar_test() -> Result<()> { + let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; + let right = ScalarValue::from(3i32); + let result = bitwise_and_scalar(left.as_ref(), right).unwrap()?; + let expected = Int32Vec::from(vec![Some(0), None, Some(3)]).as_arc(); + assert_eq!(result.as_ref(), expected.as_ref()); + Ok(()) + } } diff --git a/datafusion/src/physical_plan/expressions/cast.rs b/datafusion/src/physical_plan/expressions/cast.rs index 2e4a9158eeca..bef1b741825a 100644 --- a/datafusion/src/physical_plan/expressions/cast.rs +++ b/datafusion/src/physical_plan/expressions/cast.rs @@ -191,6 +191,7 @@ mod tests { use crate::error::Result; use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; + use crate::test_util::create_decimal_array_from_slice; use arrow::{array::*, datatypes::*}; type StringArray = Utf8Array; @@ -297,7 +298,7 @@ mod tests { #[test] fn test_cast_decimal_to_decimal() -> Result<()> { let array: Vec = vec![1234, 2222, 3, 4000, 5000]; - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 3), @@ -314,7 +315,7 @@ mod tests { DEFAULT_DATAFUSION_CAST_OPTIONS ); - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 3), @@ -338,7 +339,7 @@ mod tests { fn test_cast_decimal_to_numeric() -> Result<()> { let array: Vec = vec![1, 2, 3, 4, 5]; // decimal to i8 - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 0), @@ -355,7 +356,7 @@ mod tests { DEFAULT_DATAFUSION_CAST_OPTIONS ); // decimal to i16 - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 0), @@ -372,7 +373,7 @@ mod tests { DEFAULT_DATAFUSION_CAST_OPTIONS ); // decimal to i32 - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 0), @@ -389,7 +390,7 @@ mod tests { DEFAULT_DATAFUSION_CAST_OPTIONS ); // decimal to i64 - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 0), @@ -407,7 +408,7 @@ mod tests { ); // decimal to float32 let array: Vec = vec![1234, 2222, 3, 4000, 5000]; - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 3), @@ -424,7 +425,7 @@ mod tests { DEFAULT_DATAFUSION_CAST_OPTIONS ); // decimal to float64 - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 20, 6)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(20, 6), diff --git a/datafusion/src/physical_plan/expressions/correlation.rs b/datafusion/src/physical_plan/expressions/correlation.rs index 9e973b193974..0ce1c79c22ef 100644 --- a/datafusion/src/physical_plan/expressions/correlation.rs +++ b/datafusion/src/physical_plan/expressions/correlation.rs @@ -284,10 +284,10 @@ mod tests { #[test] fn correlation_f64_6() -> Result<()> { - let a = Arc::new(Float64Array::from_slice(vec![ + let a = Arc::new(Float64Array::from_slice(&[ 1_f64, 2_f64, 3_f64, 1.1_f64, 2.2_f64, 3.3_f64, ])); - let b = Arc::new(Float64Array::from_slice(vec![ + let b = Arc::new(Float64Array::from_slice(&[ 4_f64, 5_f64, 6_f64, 4.4_f64, 5.5_f64, 6.6_f64, ])); diff --git a/datafusion/src/physical_plan/expressions/covariance.rs b/datafusion/src/physical_plan/expressions/covariance.rs index d89d5736129b..c7fd95e6ee4f 100644 --- a/datafusion/src/physical_plan/expressions/covariance.rs +++ b/datafusion/src/physical_plan/expressions/covariance.rs @@ -468,10 +468,10 @@ mod tests { #[test] fn covariance_f64_6() -> Result<()> { - let a = Arc::new(Float64Array::from_slice(vec![ + let a = Arc::new(Float64Array::from_slice(&[ 1_f64, 2_f64, 3_f64, 1.1_f64, 2.2_f64, 3.3_f64, ])); - let b = Arc::new(Float64Array::from_slice(vec![ + let b = Arc::new(Float64Array::from_slice(&[ 4_f64, 5_f64, 6_f64, 4.4_f64, 5.5_f64, 6.6_f64, ])); diff --git a/datafusion/src/physical_plan/expressions/distinct_expressions.rs b/datafusion/src/physical_plan/expressions/distinct_expressions.rs index f0a741bb2f0b..dba5ea87e78b 100644 --- a/datafusion/src/physical_plan/expressions/distinct_expressions.rs +++ b/datafusion/src/physical_plan/expressions/distinct_expressions.rs @@ -878,7 +878,7 @@ mod tests { #[test] fn distinct_array_agg_i32() -> Result<()> { - let col: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 7, 4, 5, 2])); + let col: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 4, 5, 2])); let out = ScalarValue::List( Some(Box::new(vec![ diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index 344833e962cf..279f12710775 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -227,7 +227,7 @@ mod tests { fn get_indexed_field_invalid_list_index() -> Result<()> { let schema = list_schema("l"); let expr = col("l", &schema).unwrap(); - get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, metadata: {} }) with 0 index") + get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, is_nullable: true, metadata: {} }) with 0 index") } fn build_struct( diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index c83fd492932e..e63a6fc9f871 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -27,6 +27,7 @@ use crate::record_batch::RecordBatch; use arrow::compute::sort::SortOptions; mod approx_distinct; +mod approx_percentile_cont; mod array_agg; mod average; #[macro_use] @@ -65,6 +66,9 @@ pub mod helpers { } pub use approx_distinct::ApproxDistinct; +pub use approx_percentile_cont::{ + is_approx_percentile_cont_supported_arg_type, ApproxPercentileCont, +}; pub use array_agg::ArrayAgg; pub(crate) use average::is_avg_support_arg_type; pub use average::{avg_return_type, Avg, AvgAccumulator}; diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs index 72106f4cdded..d9f209a3bd74 100644 --- a/datafusion/src/physical_plan/expressions/stddev.rs +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -260,7 +260,7 @@ mod tests { #[test] fn stddev_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64])); + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64])); generic_test_op!( a, DataType::Float64, @@ -272,7 +272,7 @@ mod tests { #[test] fn stddev_f64_2() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64])); + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1.1_f64, 2_f64, 3_f64])); generic_test_op!( a, DataType::Float64, @@ -284,7 +284,7 @@ mod tests { #[test] fn stddev_f64_3() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![ + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, ])); generic_test_op!( @@ -298,7 +298,7 @@ mod tests { #[test] fn stddev_f64_4() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64])); + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1.1_f64, 2_f64, 3_f64])); generic_test_op!( a, DataType::Float64, @@ -310,7 +310,7 @@ mod tests { #[test] fn stddev_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5])); + let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5])); generic_test_op!( a, DataType::Int32, @@ -322,7 +322,7 @@ mod tests { #[test] fn stddev_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![ + let a: ArrayRef = Arc::new(UInt32Array::from_slice(&[ 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, ])); generic_test_op!( @@ -336,7 +336,7 @@ mod tests { #[test] fn stddev_f32() -> Result<()> { - let a: ArrayRef = Arc::new(Float32Array::from_slice(vec![ + let a: ArrayRef = Arc::new(Float32Array::from_slice(&[ 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, ])); generic_test_op!( @@ -361,7 +361,7 @@ mod tests { #[test] fn test_stddev_1_input() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64])); + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64])); let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; diff --git a/datafusion/src/physical_plan/expressions/try_cast.rs b/datafusion/src/physical_plan/expressions/try_cast.rs index d47270c8a3a9..a2e74bbac798 100644 --- a/datafusion/src/physical_plan/expressions/try_cast.rs +++ b/datafusion/src/physical_plan/expressions/try_cast.rs @@ -129,6 +129,7 @@ mod tests { use crate::error::Result; use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; + use crate::test_util::create_decimal_array_from_slice; use arrow::{array::*, datatypes::*}; type StringArray = Utf8Array; @@ -234,7 +235,7 @@ mod tests { fn test_try_cast_decimal_to_decimal() -> Result<()> { // try cast one decimal data type to another decimal data type let array: Vec = vec![1234, 2222, 3, 4000, 5000]; - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 3), @@ -250,7 +251,7 @@ mod tests { ] ); - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 3), @@ -274,7 +275,7 @@ mod tests { // TODO we should add function to create Int128Array with value and metadata // https://github.com/apache/arrow-rs/issues/1009 let array: Vec = vec![1, 2, 3, 4, 5]; - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?; // decimal to i8 generic_decimal_to_other_test_cast!( decimal_array, @@ -292,7 +293,7 @@ mod tests { ); // decimal to i16 - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 0), @@ -309,7 +310,7 @@ mod tests { ); // decimal to i32 - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 0), @@ -326,7 +327,7 @@ mod tests { ); // decimal to i64 - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 0), @@ -344,7 +345,7 @@ mod tests { // decimal to float32 let array: Vec = vec![1234, 2222, 3, 4000, 5000]; - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(10, 3), @@ -360,7 +361,7 @@ mod tests { ] ); // decimal to float64 - let decimal_array = Int128Array::from_slice(&array); + let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, DataType::Decimal(20, 6), diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 0ab9aa3482b4..9d851611df00 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -345,7 +345,7 @@ mod tests { #[test] fn variance_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64])); + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64])); generic_test_op!( a, DataType::Float64, @@ -357,7 +357,7 @@ mod tests { #[test] fn variance_f64_2() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![ + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, ])); generic_test_op!( @@ -371,7 +371,7 @@ mod tests { #[test] fn variance_f64_3() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![ + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, ])); generic_test_op!( @@ -385,7 +385,7 @@ mod tests { #[test] fn variance_f64_4() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64])); + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1.1_f64, 2_f64, 3_f64])); generic_test_op!( a, DataType::Float64, @@ -397,7 +397,7 @@ mod tests { #[test] fn variance_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5])); + let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5])); generic_test_op!( a, DataType::Int32, @@ -409,7 +409,7 @@ mod tests { #[test] fn variance_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![ + let a: ArrayRef = Arc::new(UInt32Array::from_slice(&[ 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, ])); generic_test_op!( @@ -424,7 +424,7 @@ mod tests { #[test] fn variance_f32() -> Result<()> { let a: ArrayRef = - Float32Vec::from_slice(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]).as_arc(); + Float32Vec::from_slice(&[1_f32, 2_f32, 3_f32, 4_f32, 5_f32]).as_arc(); generic_test_op!( a, DataType::Float32, @@ -447,7 +447,7 @@ mod tests { #[test] fn test_variance_1_input() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64])); + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64])); let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; diff --git a/datafusion/src/physical_plan/file_format/avro.rs b/datafusion/src/physical_plan/file_format/avro.rs index b5b5a829034f..894382863d95 100644 --- a/datafusion/src/physical_plan/file_format/avro.rs +++ b/datafusion/src/physical_plan/file_format/avro.rs @@ -162,13 +162,13 @@ impl ExecutionPlan for AvroExec { #[cfg(test)] #[cfg(feature = "avro")] mod tests { - use crate::datasource::file_format::{avro::AvroFormat, FileFormat}; use crate::datasource::object_store::local::{ local_object_reader_stream, local_unpartitioned_file, LocalFileSystem, }; use crate::field_util::SchemaExt; use crate::scalar::ScalarValue; + use arrow::datatypes::{DataType, Field, Schema}; use futures::StreamExt; use super::*; @@ -230,6 +230,71 @@ mod tests { Ok(()) } + #[tokio::test] + async fn avro_exec_missing_column() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let testdata = crate::test_util::arrow_test_data(); + let filename = format!("{}/avro/alltypes_plain.avro", testdata); + let actual_schema = AvroFormat {} + .infer_schema(local_object_reader_stream(vec![filename.clone()])) + .await?; + + let mut fields = actual_schema.fields().to_vec(); + fields.push(Field::new("missing_col", DataType::Int32, true)); + + let file_schema = Arc::new(Schema::new(fields)); + + let avro_exec = AvroExec::new(FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_groups: vec![vec![local_unpartitioned_file(filename.clone())]], + file_schema: file_schema.clone(), + statistics: Statistics::default(), + // Include the missing column in the projection + projection: Some(vec![0, 1, 2, file_schema.fields().len()]), + limit: None, + table_partition_cols: vec![], + }); + assert_eq!(avro_exec.output_partitioning().partition_count(), 1); + + let mut results = avro_exec + .execute(0, runtime) + .await + .expect("plan execution failed"); + let batch = results + .next() + .await + .expect("plan iterator empty") + .expect("plan iterator returned an error"); + + let expected = vec![ + "+----+----------+-------------+-------------+", + "| id | bool_col | tinyint_col | missing_col |", + "+----+----------+-------------+-------------+", + "| 4 | true | 0 | |", + "| 5 | false | 1 | |", + "| 6 | true | 0 | |", + "| 7 | false | 1 | |", + "| 2 | true | 0 | |", + "| 3 | false | 1 | |", + "| 0 | true | 0 | |", + "| 1 | false | 1 | |", + "+----+----------+-------------+-------------+", + ]; + + crate::assert_batches_eq!(expected, &[batch]); + + let batch = results.next().await; + assert!(batch.is_none()); + + let batch = results.next().await; + assert!(batch.is_none()); + + let batch = results.next().await; + assert!(batch.is_none()); + + Ok(()) + } + #[tokio::test] async fn avro_exec_with_partition() -> Result<()> { let runtime = Arc::new(RuntimeEnv::default()); diff --git a/datafusion/src/physical_plan/file_format/csv.rs b/datafusion/src/physical_plan/file_format/csv.rs index bf7c21fa567a..8fa5c85bb472 100644 --- a/datafusion/src/physical_plan/file_format/csv.rs +++ b/datafusion/src/physical_plan/file_format/csv.rs @@ -258,6 +258,7 @@ impl ExecutionPlan for CsvExec { mod tests { use super::*; use crate::field_util::SchemaExt; + use crate::test_util::aggr_test_schema_with_missing_col; use crate::{ assert_batches_eq, datasource::object_store::local::{local_unpartitioned_file, LocalFileSystem}, @@ -353,7 +354,53 @@ mod tests { "+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+", ]; - assert_batches_eq!(expected, &[batch]); + crate::assert_batches_eq!(expected, &[batch]); + + Ok(()) + } + + #[tokio::test] + async fn csv_exec_with_missing_column() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let file_schema = aggr_test_schema_with_missing_col(); + let testdata = crate::test_util::arrow_test_data(); + let filename = "aggregate_test_100.csv"; + let path = format!("{}/csv/{}", testdata, filename); + let csv = CsvExec::new( + FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_schema, + file_groups: vec![vec![local_unpartitioned_file(path)]], + statistics: Statistics::default(), + projection: None, + limit: Some(5), + table_partition_cols: vec![], + }, + true, + b',', + ); + assert_eq!(14, csv.base_config.file_schema.fields().len()); + assert_eq!(14, csv.projected_schema.fields().len()); + assert_eq!(14, csv.schema().fields().len()); + + let mut it = csv.execute(0, runtime).await?; + let batch = it.next().await.unwrap()?; + assert_eq!(14, batch.num_columns()); + assert_eq!(5, batch.num_rows()); + + let expected = vec![ + "+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+-------------+", + "| c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | c10 | c11 | c12 | c13 | missing_col |", + "+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+-------------+", + "| c | 2 | 1 | 18109 | 2033001162 | -6513304855495910254 | 25 | 43062 | 1491205016 | 5863949479783605708 | 0.110830784 | 0.9294097332465232 | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW | |", + "| d | 5 | -40 | 22614 | 706441268 | -7542719935673075327 | 155 | 14337 | 3373581039 | 11720144131976083864 | 0.69632107 | 0.3114712539863804 | C2GT5KVyOPZpgKVl110TyZO0NcJ434 | |", + "| b | 1 | 29 | -18218 | 994303988 | 5983957848665088916 | 204 | 9489 | 3275293996 | 14857091259186476033 | 0.53840446 | 0.17909035118828576 | AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz | |", + "| a | 1 | -85 | -15154 | 1171968280 | 1919439543497968449 | 77 | 52286 | 774637006 | 12101411955859039553 | 0.12285209 | 0.6864391962767343 | 0keZ5G8BffGwgF2RwQD59TFzMStxCB | |", + "| b | 5 | -82 | 22080 | 1824882165 | 7373730676428214987 | 208 | 34331 | 3342719438 | 3330177516592499461 | 0.82634634 | 0.40975383525297016 | Ig1QcuKsjHXkproePdERo2w0mYzIqd | |", + "+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+-------------+", + ]; + + crate::assert_batches_eq!(expected, &[batch]); Ok(()) } diff --git a/datafusion/src/physical_plan/file_format/json.rs b/datafusion/src/physical_plan/file_format/json.rs index fca810bc198a..f1daef9035b6 100644 --- a/datafusion/src/physical_plan/file_format/json.rs +++ b/datafusion/src/physical_plan/file_format/json.rs @@ -194,6 +194,8 @@ impl ExecutionPlan for NdJsonExec { #[cfg(test)] mod tests { + use arrow::array::Array; + use arrow::datatypes::{Field, Schema}; use futures::StreamExt; use crate::datasource::{ @@ -269,6 +271,47 @@ mod tests { Ok(()) } + #[tokio::test] + async fn nd_json_exec_file_with_missing_column() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + use arrow::datatypes::DataType; + let path = format!("{}/1.json", TEST_DATA_BASE); + + let actual_schema = infer_schema(path.clone()).await?; + + let mut fields = actual_schema.fields().to_vec(); + fields.push(Field::new("missing_col", DataType::Int32, true)); + let missing_field_idx = fields.len() - 1; + + let file_schema = Arc::new(Schema::new(fields)); + + let exec = NdJsonExec::new(FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_groups: vec![vec![local_unpartitioned_file(path.clone())]], + file_schema, + statistics: Statistics::default(), + projection: None, + limit: Some(3), + table_partition_cols: vec![], + }); + + let mut it = exec.execute(0, runtime).await?; + let batch = it.next().await.unwrap()?; + + assert_eq!(batch.num_rows(), 3); + let values = batch + .column(missing_field_idx) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.len(), 3); + assert!(values.is_null(0)); + assert!(values.is_null(1)); + assert!(values.is_null(2)); + + Ok(()) + } + #[tokio::test] async fn nd_json_exec_file_projection() -> Result<()> { let runtime = Arc::new(RuntimeEnv::default()); diff --git a/datafusion/src/physical_plan/file_format/mod.rs b/datafusion/src/physical_plan/file_format/mod.rs index 9b34e9df723c..a95a50d32e0f 100644 --- a/datafusion/src/physical_plan/file_format/mod.rs +++ b/datafusion/src/physical_plan/file_format/mod.rs @@ -35,14 +35,18 @@ pub use csv::CsvExec; pub use json::NdJsonExec; use std::iter; +use crate::error::DataFusionError; use crate::field_util::{FieldExt, SchemaExt}; use crate::{ datasource::{object_store::ObjectStore, PartitionedFile}, + error::Result, scalar::ScalarValue, }; +use arrow::array::new_null_array; use arrow::array::UInt8Array; use arrow::datatypes::IntegerType; use lazy_static::lazy_static; +use log::info; use std::{ collections::HashMap, fmt::{Display, Formatter, Result as FmtResult}, @@ -168,6 +172,89 @@ impl<'a> Display for FileGroupsDisplay<'a> { } } +/// A utility which can adapt file-level record batches to a table schema which may have a schema +/// obtained from merging multiple file-level schemas. +/// +/// This is useful for enabling schema evolution in partitioned datasets. +/// +/// This has to be done in two stages. +/// +/// 1. Before reading the file, we have to map projected column indexes from the table schema to +/// the file schema. +/// +/// 2. After reading a record batch we need to map the read columns back to the expected columns +/// indexes and insert null-valued columns wherever the file schema was missing a colum present +/// in the table schema. +#[derive(Clone, Debug)] +pub(crate) struct SchemaAdapter { + /// Schema for the table + table_schema: SchemaRef, +} + +impl SchemaAdapter { + pub(crate) fn new(table_schema: SchemaRef) -> SchemaAdapter { + Self { table_schema } + } + + /// Map projected column indexes to the file schema. This will fail if the table schema + /// and the file schema contain a field with the same name and different types. + pub fn map_projections( + &self, + file_schema: &Schema, + projections: &[usize], + ) -> Result> { + let mut mapped: Vec = vec![]; + for idx in projections { + let field = self.table_schema.field(*idx); + if let Ok(mapped_idx) = file_schema.index_of(field.name()) { + if file_schema.field(mapped_idx).data_type() == field.data_type() { + mapped.push(mapped_idx) + } else { + let msg = format!("Failed to map column projection for field {}. Incompatible data types {:?} and {:?}", field.name(), file_schema.field(mapped_idx).data_type(), field.data_type()); + info!("{}", msg); + return Err(DataFusionError::Execution(msg)); + } + } + } + Ok(mapped) + } + + /// Re-order projected columns by index in record batch to match table schema column ordering. If the record + /// batch does not contain a column for an expected field, insert a null-valued column at the + /// required column index. + pub fn adapt_batch( + &self, + batch: RecordBatch, + projections: &[usize], + ) -> Result { + let batch_rows = batch.num_rows(); + + let batch_schema = batch.schema(); + + let mut cols: Vec = Vec::with_capacity(batch.columns().len()); + let batch_cols = batch.columns().to_vec(); + + for field_idx in projections { + let table_field = &self.table_schema.fields()[*field_idx]; + if let Some((batch_idx, _name)) = + batch_schema.column_with_name(table_field.name()) + { + cols.push(batch_cols[batch_idx].clone()); + } else { + cols.push( + new_null_array(table_field.data_type().clone(), batch_rows).into(), + ) + } + } + + let projected_schema = Arc::new(self.table_schema.clone().project(projections)?); + + let merged_batch = RecordBatch::try_new(projected_schema, cols.clone())?; + + Ok(merged_batch) + } +} + /// A helper that projects partition columns into the file record batches. /// /// One interesting trick is the usage of a cache for the key buffers of the partition column @@ -459,6 +546,61 @@ mod tests { assert_batches_eq!(expected, &[projected_batch]); } + #[test] + fn schema_adapter_adapt_projections() { + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Int64, true), + Field::new("c3", DataType::Int8, true), + ])); + + let file_schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Int64, true), + ]); + + let file_schema_2 = Arc::new(Schema::new(vec![ + Field::new("c3", DataType::Int8, true), + Field::new("c2", DataType::Int64, true), + ])); + + let file_schema_3 = + Arc::new(Schema::new(vec![Field::new("c3", DataType::Float32, true)])); + + let adapter = SchemaAdapter::new(table_schema); + + let projections1: Vec = vec![0, 1, 2]; + let projections2: Vec = vec![2]; + + let mapped = adapter + .map_projections(&file_schema, projections1.as_slice()) + .expect("mapping projections"); + + assert_eq!(mapped, vec![0, 1]); + + let mapped = adapter + .map_projections(&file_schema, projections2.as_slice()) + .expect("mapping projections"); + + assert!(mapped.is_empty()); + + let mapped = adapter + .map_projections(&file_schema_2, projections1.as_slice()) + .expect("mapping projections"); + + assert_eq!(mapped, vec![1, 0]); + + let mapped = adapter + .map_projections(&file_schema_2, projections2.as_slice()) + .expect("mapping projections"); + + assert_eq!(mapped, vec![0]); + + let mapped = adapter.map_projections(&file_schema_3, projections1.as_slice()); + + assert!(mapped.is_err()); + } + // sets default for configs that play no role in projections fn config_for_projection( file_schema: SchemaRef, diff --git a/datafusion/src/physical_plan/file_format/parquet.rs b/datafusion/src/physical_plan/file_format/parquet.rs index 1903a8b7425d..063023b5488d 100644 --- a/datafusion/src/physical_plan/file_format/parquet.rs +++ b/datafusion/src/physical_plan/file_format/parquet.rs @@ -25,7 +25,7 @@ use std::{any::Any, convert::TryInto}; use crate::datasource::object_store::ObjectStore; use crate::datasource::PartitionedFile; -use crate::field_util::{FieldExt, SchemaExt}; +use crate::field_util::SchemaExt; use crate::record_batch::RecordBatch; use crate::{ error::{DataFusionError, Result}, @@ -40,13 +40,14 @@ use crate::{ }, scalar::ScalarValue, }; +use arrow::error::ArrowError; use arrow::{ array::ArrayRef, datatypes::*, error::Result as ArrowResult, io::parquet::read::{self, RowGroupMetaData}, }; -use log::{debug, info}; +use log::debug; use parquet::statistics::{ BinaryStatistics as ParquetBinaryStatistics, @@ -59,7 +60,9 @@ use tokio::{ task, }; +use crate::datasource::file_format::parquet::fetch_schema; use crate::execution::runtime_env::RuntimeEnv; +use crate::physical_plan::file_format::SchemaAdapter; use async_trait::async_trait; use super::PartitionColumnProjector; @@ -213,13 +216,14 @@ impl ExecutionPlan for ParquetExec { &self.base_config.table_partition_cols, ); - let file_schema_ref = self.base_config().file_schema.clone(); + let adapter = SchemaAdapter::new(self.base_config.file_schema.clone()); + let join_handle = task::spawn_blocking(move || { if let Err(e) = read_partition( object_store.as_ref(), - file_schema_ref, + adapter, partition_index, - partition, + &partition, metrics, &projection, &pruning_predicate, @@ -228,7 +232,10 @@ impl ExecutionPlan for ParquetExec { limit, partition_col_proj, ) { - println!("Parquet reader thread terminated due to error: {:?}", e); + println!( + "Parquet reader thread terminated due to error: {:?} for files: {:?}", + e, partition + ); } }); @@ -441,35 +448,12 @@ fn build_row_group_predicate( } } -// Map projections from the schema which merges all file schemas to projections on a particular -// file -fn map_projections( - merged_schema: &Schema, - file_schema: &Schema, - projections: &[usize], -) -> Result> { - let mut mapped: Vec = vec![]; - for idx in projections { - let field = merged_schema.field(*idx); - if let Ok(mapped_idx) = file_schema.index_of(field.name()) { - if file_schema.field(mapped_idx).data_type() == field.data_type() { - mapped.push(mapped_idx) - } else { - let msg = format!("Failed to map column projection for field {}. Incompatible data types {:?} and {:?}", field.name(), file_schema.field(mapped_idx).data_type(), field.data_type()); - info!("{}", msg); - return Err(DataFusionError::Execution(msg)); - } - } - } - Ok(mapped) -} - #[allow(clippy::too_many_arguments)] fn read_partition( object_store: &dyn ObjectStore, - file_schema: SchemaRef, + schema_adapter: SchemaAdapter, partition_index: usize, - partition: Vec, + partition: &[PartitionedFile], metrics: ExecutionPlanMetricsSet, projection: &[usize], pruning_predicate: &Option, @@ -478,7 +462,8 @@ fn read_partition( limit: Option, mut partition_column_projector: PartitionColumnProjector, ) -> Result<()> { - for partitioned_file in partition { + let mut total_rows = 0; + 'outer: for partitioned_file in partition { debug!("Reading file {}", &partitioned_file.file_meta.path()); let file_metrics = ParquetFileMetrics::new( @@ -490,16 +475,17 @@ fn read_partition( object_store.file_reader(partitioned_file.file_meta.sized_file.clone())?; let reader = object_reader.sync_reader()?; + let file_schema = fetch_schema(object_reader)?; + let adapted_projections = + schema_adapter.map_projections(&file_schema.clone(), projection)?; let mut record_reader = read::RecordReader::try_new( reader, - Some(projection.to_vec()), + Some(adapted_projections.clone()), limit, None, None, )?; - // TODO : ??? - let _mapped_projections = - map_projections(&file_schema, record_reader.schema(), projection)?; + if let Some(pruning_predicate) = pruning_predicate { record_reader.set_groups_filter(Arc::new(build_row_group_predicate( pruning_predicate, @@ -508,14 +494,40 @@ fn read_partition( ))); } - let schema = record_reader.schema().clone(); - for chunk in record_reader { - let batch = RecordBatch::new_with_chunk(&schema, chunk?); - let proj_batch = partition_column_projector - .project(batch, &partitioned_file.partition_values); - response_tx - .blocking_send(proj_batch) - .map_err(|x| DataFusionError::Execution(format!("{}", x)))?; + let read_schema = record_reader.schema().clone(); + for chunk_r in record_reader { + match chunk_r { + Ok(chunk) => { + total_rows += chunk.len(); + + let batch = RecordBatch::try_new( + read_schema.clone(), + chunk.columns().to_vec(), + )?; + + let adapted_batch = schema_adapter.adapt_batch(batch, projection)?; + + let proj_batch = partition_column_projector + .project(adapted_batch, &partitioned_file.partition_values); + response_tx + .blocking_send(proj_batch) + .map_err(|x| DataFusionError::Execution(format!("{}", x)))?; + if limit.map(|l| total_rows >= l).unwrap_or(false) { + break 'outer; + } + } + Err(e) => { + let err_msg = + format!("Error reading batch from {}: {}", partitioned_file, e); + // send error to operator + send_result( + &response_tx, + Err(ArrowError::ExternalFormat(err_msg.clone())), + )?; + // terminate thread with error + return Err(DataFusionError::Execution(err_msg)); + } + } } } @@ -555,8 +567,6 @@ mod tests { projection: Option>, schema: Option, ) -> Vec { - let runtime = Arc::new(RuntimeEnv::default()); - // When vec is dropped, temp files are deleted let files: Vec<_> = batches .into_iter() @@ -579,7 +589,7 @@ mod tests { iter.into_iter(), schema_ref, options, - vec![Encoding::Plain, Encoding::Plain], + vec![Encoding::Plain].repeat(schema_ref.fields.len()), ) .unwrap(); @@ -630,6 +640,7 @@ mod tests { None, ); + let runtime = Arc::new(RuntimeEnv::default()); collect(Arc::new(parquet_exec), runtime) .await .expect("reading parquet data") diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index ba5dc87f99b1..3a4d88fc3cbe 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -33,7 +33,7 @@ use super::{ type_coercion::{coerce, data_types}, ColumnarValue, PhysicalExpr, }; -use crate::execution::context::ExecutionContextState; +use crate::execution::context::ExecutionProps; use crate::physical_plan::array_expressions; use crate::physical_plan::datetime_expressions; use crate::physical_plan::expressions::cast::DEFAULT_DATAFUSION_CAST_OPTIONS; @@ -766,7 +766,7 @@ fn bit_length(array: &dyn Array) -> ArrowResult> { /// Create a physical scalar function. pub fn create_physical_fun( fun: &BuiltinScalarFunction, - ctx_state: &ExecutionContextState, + execution_props: &ExecutionProps, ) -> Result { Ok(match fun { // math functions @@ -865,7 +865,7 @@ pub fn create_physical_fun( BuiltinScalarFunction::Now => { // bind value for now at plan time Arc::new(datetime_expressions::make_now( - ctx_state.execution_props.query_execution_start_time, + execution_props.query_execution_start_time, )) } BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { @@ -1202,7 +1202,7 @@ pub fn create_physical_expr( fun: &BuiltinScalarFunction, input_phy_exprs: &[Arc], input_schema: &Schema, - ctx_state: &ExecutionContextState, + execution_props: &ExecutionProps, ) -> Result> { let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &signature(fun))?; @@ -1299,7 +1299,7 @@ pub fn create_physical_expr( } }), // These don't need args and input schema - _ => create_physical_fun(fun, ctx_state)?, + _ => create_physical_fun(fun, execution_props)?, }; Ok(Arc::new(ScalarFunctionExpr::new( @@ -1761,14 +1761,14 @@ mod tests { ($FUNC:ident, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $DATA_TYPE: ident, $ARRAY_TYPE:ident) => { // used to provide type annotation let expected: Result> = $EXPECTED; - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); // any type works here: we evaluate against a literal of `value` let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let columns: Vec = vec![Arc::new(Int32Array::from_slice(&[1]))]; let expr = - create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &ctx_state)?; + create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &execution_props)?; // type is correct assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE); @@ -3624,6 +3624,18 @@ mod tests { StringArray ); #[cfg(feature = "unicode_expressions")] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(-5))), + ], + Ok(Some("joséésoj")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -3722,6 +3734,61 @@ mod tests { StringArray ); #[cfg(feature = "unicode_expressions")] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(0))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from 5 (10 + -5) + #[cfg(feature = "unicode_expressions")] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(-5))), + lit(ScalarValue::Int64(Some(10))), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from -1 (4 + -5) + #[cfg(feature = "unicode_expressions")] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(-5))), + lit(ScalarValue::Int64(Some(4))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + // starting from 0 (5 + -5) + #[cfg(feature = "unicode_expressions")] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(-5))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -3930,7 +3997,7 @@ mod tests { #[test] fn test_empty_arguments_error() -> Result<()> { - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); // pick some arbitrary functions to test @@ -3942,7 +4009,7 @@ mod tests { ]; for fun in funs.iter() { - let expr = create_physical_expr(fun, &[], &schema, &ctx_state); + let expr = create_physical_expr(fun, &[], &schema, &execution_props); match expr { Ok(..) => { @@ -3973,13 +4040,13 @@ mod tests { #[test] fn test_empty_arguments() -> Result<()> { - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let funs = [BuiltinScalarFunction::Now, BuiltinScalarFunction::Random]; for fun in funs.iter() { - create_physical_expr(fun, &[], &schema, &ctx_state)?; + create_physical_expr(fun, &[], &schema, &execution_props)?; } Ok(()) } @@ -3995,13 +4062,13 @@ mod tests { Field::new("b", value2.data_type().clone(), false), ]); let columns: Vec = vec![value1, value2]; - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); let expr = create_physical_expr( &BuiltinScalarFunction::Array, &[col("a", &schema)?, col("b", &schema)?], &schema, - &ctx_state, + &execution_props, )?; // evaluate works @@ -4048,7 +4115,7 @@ mod tests { fn test_regexp_match() -> Result<()> { use arrow::array::ListArray; let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); // concat(value, value) let col_value: ArrayRef = Arc::new(StringArray::from_slice(&["aaa-555"])); @@ -4058,7 +4125,7 @@ mod tests { &BuiltinScalarFunction::RegexpMatch, &[col("a", &schema)?, pattern], &schema, - &ctx_state, + &execution_props, )?; // type is correct @@ -4088,7 +4155,7 @@ mod tests { fn test_regexp_match_all_literals() -> Result<()> { use arrow::array::ListArray; let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string()))); let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); @@ -4097,7 +4164,7 @@ mod tests { &BuiltinScalarFunction::RegexpMatch, &[col_value, pattern], &schema, - &ctx_state, + &execution_props, )?; // type is correct diff --git a/datafusion/src/physical_plan/metrics/aggregated.rs b/datafusion/src/physical_plan/metrics/aggregated.rs deleted file mode 100644 index c55cc1601768..000000000000 --- a/datafusion/src/physical_plan/metrics/aggregated.rs +++ /dev/null @@ -1,155 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Metrics common for complex operators with multiple steps. - -use crate::physical_plan::metrics::{ - BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricsSet, Time, -}; -use std::sync::Arc; -use std::time::Duration; - -#[derive(Debug, Clone)] -/// Aggregates all metrics during a complex operation, which is composed of multiple steps and -/// each stage reports its statistics separately. -/// Give sort as an example, when the dataset is more significant than available memory, it will report -/// multiple in-mem sort metrics and final merge-sort metrics from `SortPreservingMergeStream`. -/// Therefore, We need a separation of metrics for which are final metrics (for output_rows accumulation), -/// and which are intermediate metrics that we only account for elapsed_compute time. -pub struct AggregatedMetricsSet { - intermediate: Arc>>, - final_: Arc>>, -} - -impl Default for AggregatedMetricsSet { - fn default() -> Self { - Self::new() - } -} - -impl AggregatedMetricsSet { - /// Create a new aggregated set - pub fn new() -> Self { - Self { - intermediate: Arc::new(std::sync::Mutex::new(vec![])), - final_: Arc::new(std::sync::Mutex::new(vec![])), - } - } - - /// create a new intermediate baseline - pub fn new_intermediate_baseline(&self, partition: usize) -> BaselineMetrics { - let ms = ExecutionPlanMetricsSet::new(); - let result = BaselineMetrics::new(&ms, partition); - self.intermediate.lock().unwrap().push(ms); - result - } - - /// create a new final baseline - pub fn new_final_baseline(&self, partition: usize) -> BaselineMetrics { - let ms = ExecutionPlanMetricsSet::new(); - let result = BaselineMetrics::new(&ms, partition); - self.final_.lock().unwrap().push(ms); - result - } - - fn merge_compute_time(&self, dest: &Time) { - let time1 = self - .intermediate - .lock() - .unwrap() - .iter() - .map(|es| { - es.clone_inner() - .elapsed_compute() - .map_or(0u64, |v| v as u64) - }) - .sum(); - let time2 = self - .final_ - .lock() - .unwrap() - .iter() - .map(|es| { - es.clone_inner() - .elapsed_compute() - .map_or(0u64, |v| v as u64) - }) - .sum(); - dest.add_duration(Duration::from_nanos(time1)); - dest.add_duration(Duration::from_nanos(time2)); - } - - fn merge_spill_count(&self, dest: &Count) { - let count1 = self - .intermediate - .lock() - .unwrap() - .iter() - .map(|es| es.clone_inner().spill_count().map_or(0, |v| v)) - .sum(); - let count2 = self - .final_ - .lock() - .unwrap() - .iter() - .map(|es| es.clone_inner().spill_count().map_or(0, |v| v)) - .sum(); - dest.add(count1); - dest.add(count2); - } - - fn merge_spilled_bytes(&self, dest: &Count) { - let count1 = self - .intermediate - .lock() - .unwrap() - .iter() - .map(|es| es.clone_inner().spilled_bytes().map_or(0, |v| v)) - .sum(); - let count2 = self - .final_ - .lock() - .unwrap() - .iter() - .map(|es| es.clone_inner().spilled_bytes().map_or(0, |v| v)) - .sum(); - dest.add(count1); - dest.add(count2); - } - - fn merge_output_count(&self, dest: &Count) { - let count = self - .final_ - .lock() - .unwrap() - .iter() - .map(|es| es.clone_inner().output_rows().map_or(0, |v| v)) - .sum(); - dest.add(count); - } - - /// Aggregate all metrics into a one - pub fn aggregate_all(&self) -> MetricsSet { - let metrics = ExecutionPlanMetricsSet::new(); - let baseline = BaselineMetrics::new(&metrics, 0); - self.merge_compute_time(baseline.elapsed_compute()); - self.merge_spill_count(baseline.spill_count()); - self.merge_spilled_bytes(baseline.spilled_bytes()); - self.merge_output_count(baseline.output_rows()); - metrics.clone_inner() - } -} diff --git a/datafusion/src/physical_plan/metrics/baseline.rs b/datafusion/src/physical_plan/metrics/baseline.rs index a095360ef54c..6810cf18e56c 100644 --- a/datafusion/src/physical_plan/metrics/baseline.rs +++ b/datafusion/src/physical_plan/metrics/baseline.rs @@ -113,7 +113,7 @@ impl BaselineMetrics { /// Records the fact that this operator's execution is complete /// (recording the `end_time` metric). /// - /// Note care should be taken to call `done()` maually if + /// Note care should be taken to call `done()` manually if /// `BaselineMetrics` is not `drop`ped immediately upon operator /// completion, as async streams may not be dropped immediately /// depending on the consumer. @@ -129,6 +129,13 @@ impl BaselineMetrics { self.output_rows.add(num_rows); } + /// If not previously recorded `done()`, record + pub fn try_done(&self) { + if self.end_time.value().is_none() { + self.end_time.record() + } + } + /// Process a poll result of a stream producing output for an /// operator, recording the output rows and stream done time and /// returning the same poll result @@ -151,10 +158,7 @@ impl BaselineMetrics { impl Drop for BaselineMetrics { fn drop(&mut self) { - // if not previously recorded, record - if self.end_time.value().is_none() { - self.end_time.record() - } + self.try_done() } } diff --git a/datafusion/src/physical_plan/metrics/composite.rs b/datafusion/src/physical_plan/metrics/composite.rs new file mode 100644 index 000000000000..cd4d5c38a9ec --- /dev/null +++ b/datafusion/src/physical_plan/metrics/composite.rs @@ -0,0 +1,205 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Metrics common for complex operators with multiple steps. + +use crate::execution::runtime_env::RuntimeEnv; +use crate::physical_plan::metrics::tracker::MemTrackingMetrics; +use crate::physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricValue, MetricsSet, Time, + Timestamp, +}; +use crate::physical_plan::Metric; +use chrono::{TimeZone, Utc}; +use std::sync::Arc; +use std::time::Duration; + +#[derive(Debug, Clone)] +/// Collects all metrics during a complex operation, which is composed of multiple steps and +/// each stage reports its statistics separately. +/// Give sort as an example, when the dataset is more significant than available memory, it will report +/// multiple in-mem sort metrics and final merge-sort metrics from `SortPreservingMergeStream`. +/// Therefore, We need a separation of metrics for which are final metrics (for output_rows accumulation), +/// and which are intermediate metrics that we only account for elapsed_compute time. +pub struct CompositeMetricsSet { + mid: ExecutionPlanMetricsSet, + final_: ExecutionPlanMetricsSet, +} + +impl Default for CompositeMetricsSet { + fn default() -> Self { + Self::new() + } +} + +impl CompositeMetricsSet { + /// Create a new aggregated set + pub fn new() -> Self { + Self { + mid: ExecutionPlanMetricsSet::new(), + final_: ExecutionPlanMetricsSet::new(), + } + } + + /// create a new intermediate baseline + pub fn new_intermediate_baseline(&self, partition: usize) -> BaselineMetrics { + BaselineMetrics::new(&self.mid, partition) + } + + /// create a new final baseline + pub fn new_final_baseline(&self, partition: usize) -> BaselineMetrics { + BaselineMetrics::new(&self.final_, partition) + } + + /// create a new intermediate memory tracking metrics + pub fn new_intermediate_tracking( + &self, + partition: usize, + runtime: Arc, + ) -> MemTrackingMetrics { + MemTrackingMetrics::new_with_rt(&self.mid, partition, runtime) + } + + /// create a new final memory tracking metrics + pub fn new_final_tracking( + &self, + partition: usize, + runtime: Arc, + ) -> MemTrackingMetrics { + MemTrackingMetrics::new_with_rt(&self.final_, partition, runtime) + } + + fn merge_compute_time(&self, dest: &Time) { + let time1 = self + .mid + .clone_inner() + .elapsed_compute() + .map_or(0u64, |v| v as u64); + let time2 = self + .final_ + .clone_inner() + .elapsed_compute() + .map_or(0u64, |v| v as u64); + dest.add_duration(Duration::from_nanos(time1)); + dest.add_duration(Duration::from_nanos(time2)); + } + + fn merge_spill_count(&self, dest: &Count) { + let count1 = self.mid.clone_inner().spill_count().map_or(0, |v| v); + let count2 = self.final_.clone_inner().spill_count().map_or(0, |v| v); + dest.add(count1); + dest.add(count2); + } + + fn merge_spilled_bytes(&self, dest: &Count) { + let count1 = self.mid.clone_inner().spilled_bytes().map_or(0, |v| v); + let count2 = self.final_.clone_inner().spill_count().map_or(0, |v| v); + dest.add(count1); + dest.add(count2); + } + + fn merge_output_count(&self, dest: &Count) { + let count = self.final_.clone_inner().output_rows().map_or(0, |v| v); + dest.add(count); + } + + fn merge_start_time(&self, dest: &Timestamp) { + let start1 = self + .mid + .clone_inner() + .sum(|metric| matches!(metric.value(), MetricValue::StartTimestamp(_))) + .map(|v| v.as_usize()); + let start2 = self + .final_ + .clone_inner() + .sum(|metric| matches!(metric.value(), MetricValue::StartTimestamp(_))) + .map(|v| v.as_usize()); + match (start1, start2) { + (Some(start1), Some(start2)) => { + dest.set(Utc.timestamp_nanos(start1.min(start2) as i64)) + } + (Some(start1), None) => dest.set(Utc.timestamp_nanos(start1 as i64)), + (None, Some(start2)) => dest.set(Utc.timestamp_nanos(start2 as i64)), + (None, None) => {} + } + } + + fn merge_end_time(&self, dest: &Timestamp) { + let start1 = self + .mid + .clone_inner() + .sum(|metric| matches!(metric.value(), MetricValue::EndTimestamp(_))) + .map(|v| v.as_usize()); + let start2 = self + .final_ + .clone_inner() + .sum(|metric| matches!(metric.value(), MetricValue::EndTimestamp(_))) + .map(|v| v.as_usize()); + match (start1, start2) { + (Some(start1), Some(start2)) => { + dest.set(Utc.timestamp_nanos(start1.max(start2) as i64)) + } + (Some(start1), None) => dest.set(Utc.timestamp_nanos(start1 as i64)), + (None, Some(start2)) => dest.set(Utc.timestamp_nanos(start2 as i64)), + (None, None) => {} + } + } + + /// Aggregate all metrics into a one + pub fn aggregate_all(&self) -> MetricsSet { + let mut metrics = MetricsSet::new(); + let elapsed_time = Time::new(); + let spill_count = Count::new(); + let spilled_bytes = Count::new(); + let output_count = Count::new(); + let start_time = Timestamp::new(); + let end_time = Timestamp::new(); + + metrics.push(Arc::new(Metric::new( + MetricValue::ElapsedCompute(elapsed_time.clone()), + None, + ))); + metrics.push(Arc::new(Metric::new( + MetricValue::SpillCount(spill_count.clone()), + None, + ))); + metrics.push(Arc::new(Metric::new( + MetricValue::SpilledBytes(spilled_bytes.clone()), + None, + ))); + metrics.push(Arc::new(Metric::new( + MetricValue::OutputRows(output_count.clone()), + None, + ))); + metrics.push(Arc::new(Metric::new( + MetricValue::StartTimestamp(start_time.clone()), + None, + ))); + metrics.push(Arc::new(Metric::new( + MetricValue::EndTimestamp(end_time.clone()), + None, + ))); + + self.merge_compute_time(&elapsed_time); + self.merge_spill_count(&spill_count); + self.merge_spilled_bytes(&spilled_bytes); + self.merge_output_count(&output_count); + self.merge_start_time(&start_time); + self.merge_end_time(&end_time); + metrics + } +} diff --git a/datafusion/src/physical_plan/metrics/mod.rs b/datafusion/src/physical_plan/metrics/mod.rs index d48959974e8d..021f2df823ae 100644 --- a/datafusion/src/physical_plan/metrics/mod.rs +++ b/datafusion/src/physical_plan/metrics/mod.rs @@ -17,23 +17,26 @@ //! Metrics for recording information about execution -mod aggregated; mod baseline; mod builder; +mod composite; +mod tracker; mod value; +use parking_lot::Mutex; use std::{ borrow::Cow, fmt::{Debug, Display}, - sync::{Arc, Mutex}, + sync::Arc, }; use hashbrown::HashMap; // public exports -pub use aggregated::AggregatedMetricsSet; pub use baseline::{BaselineMetrics, RecordOutput}; pub use builder::MetricBuilder; +pub use composite::CompositeMetricsSet; +pub use tracker::MemTrackingMetrics; pub use value::{Count, Gauge, MetricValue, ScopedTimerGuard, Time, Timestamp}; /// Something that tracks a value of interest (metric) of a DataFusion @@ -337,12 +340,12 @@ impl ExecutionPlanMetricsSet { /// Add the specified metric to the underlying metric set pub fn register(&self, metric: Arc) { - self.inner.lock().expect("not poisoned").push(metric) + self.inner.lock().push(metric) } /// Return a clone of the inner MetricsSet pub fn clone_inner(&self) -> MetricsSet { - let guard = self.inner.lock().expect("not poisoned"); + let guard = self.inner.lock(); (*guard).clone() } } diff --git a/datafusion/src/physical_plan/metrics/tracker.rs b/datafusion/src/physical_plan/metrics/tracker.rs new file mode 100644 index 000000000000..b14e9a6f72c5 --- /dev/null +++ b/datafusion/src/physical_plan/metrics/tracker.rs @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Metrics with memory usage tracking capability + +use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::MemoryConsumerId; +use crate::physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, Time, +}; +use std::sync::Arc; +use std::task::Poll; + +use crate::record_batch::RecordBatch; +use arrow::error::ArrowError; + +/// Simplified version of tracking memory consumer, +/// see also: [`Tracking`](crate::execution::memory_manager::ConsumerType::Tracking) +/// +/// You could use this to replace [BaselineMetrics], report the memory, +/// and get the memory usage bookkeeping in the memory manager easily. +#[derive(Debug)] +pub struct MemTrackingMetrics { + id: MemoryConsumerId, + runtime: Option>, + metrics: BaselineMetrics, +} + +/// Delegates most of the metrics functionalities to the inner BaselineMetrics, +/// intercept memory metrics functionalities and do memory manager bookkeeping. +impl MemTrackingMetrics { + /// Create metrics similar to [BaselineMetrics] + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + let id = MemoryConsumerId::new(partition); + Self { + id, + runtime: None, + metrics: BaselineMetrics::new(metrics, partition), + } + } + + /// Create memory tracking metrics with reference to runtime + pub fn new_with_rt( + metrics: &ExecutionPlanMetricsSet, + partition: usize, + runtime: Arc, + ) -> Self { + let id = MemoryConsumerId::new(partition); + Self { + id, + runtime: Some(runtime), + metrics: BaselineMetrics::new(metrics, partition), + } + } + + /// return the metric for cpu time spend in this operator + pub fn elapsed_compute(&self) -> &Time { + self.metrics.elapsed_compute() + } + + /// return the size for current memory usage + pub fn mem_used(&self) -> usize { + self.metrics.mem_used().value() + } + + /// setup initial memory usage and register it with memory manager + pub fn init_mem_used(&self, size: usize) { + self.metrics.mem_used().set(size); + if let Some(rt) = self.runtime.as_ref() { + rt.memory_manager.grow_tracker_usage(size); + } + } + + /// return the metric for the total number of output rows produced + pub fn output_rows(&self) -> &Count { + self.metrics.output_rows() + } + + /// Records the fact that this operator's execution is complete + /// (recording the `end_time` metric). + /// + /// Note care should be taken to call `done()` manually if + /// `MemTrackingMetrics` is not `drop`ped immediately upon operator + /// completion, as async streams may not be dropped immediately + /// depending on the consumer. + pub fn done(&self) { + self.metrics.done() + } + + /// Record that some number of rows have been produced as output + /// + /// See the [`RecordOutput`] for conveniently recording record + /// batch output for other thing + pub fn record_output(&self, num_rows: usize) { + self.metrics.record_output(num_rows) + } + + /// Process a poll result of a stream producing output for an + /// operator, recording the output rows and stream done time and + /// returning the same poll result + pub fn record_poll( + &self, + poll: Poll>>, + ) -> Poll>> { + self.metrics.record_poll(poll) + } +} + +impl Drop for MemTrackingMetrics { + fn drop(&mut self) { + self.metrics.try_done(); + if self.mem_used() != 0 { + if let Some(rt) = self.runtime.as_ref() { + rt.drop_consumer(&self.id, self.mem_used()); + } + } + } +} diff --git a/datafusion/src/physical_plan/metrics/value.rs b/datafusion/src/physical_plan/metrics/value.rs index 6ac282a496ee..43a0ad236500 100644 --- a/datafusion/src/physical_plan/metrics/value.rs +++ b/datafusion/src/physical_plan/metrics/value.rs @@ -22,11 +22,13 @@ use std::{ fmt::Display, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, Mutex, + Arc, }, time::{Duration, Instant}, }; +use parking_lot::Mutex; + use chrono::{DateTime, Utc}; /// A counter to record things such as number of input or output rows @@ -229,7 +231,7 @@ impl Timestamp { /// Sets the timestamps value to a specified time pub fn set(&self, now: DateTime) { - *self.timestamp.lock().unwrap() = Some(now); + *self.timestamp.lock() = Some(now); } /// return the timestamps value at the last time `record()` was @@ -237,7 +239,7 @@ impl Timestamp { /// /// Returns `None` if `record()` has not been called pub fn value(&self) -> Option> { - *self.timestamp.lock().unwrap() + *self.timestamp.lock() } /// sets the value of this timestamp to the minimum of this and other @@ -249,7 +251,7 @@ impl Timestamp { (Some(v1), Some(v2)) => Some(if v1 < v2 { v1 } else { v2 }), }; - *self.timestamp.lock().unwrap() = min; + *self.timestamp.lock() = min; } /// sets the value of this timestamp to the maximum of this and other @@ -261,7 +263,7 @@ impl Timestamp { (Some(v1), Some(v2)) => Some(if v1 < v2 { v2 } else { v1 }), }; - *self.timestamp.lock().unwrap() = max; + *self.timestamp.lock() = max; } } diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 79f0aa499c33..5d605a02abe9 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -62,7 +62,7 @@ pub type SendableRecordBatchStream = Pin usize { use Partitioning::*; match self { - RoundRobinBatch(n) => *n, - Hash(_, n) => *n, - UnknownPartitioning(n) => *n, + RoundRobinBatch(n) | Hash(_, n) | UnknownPartitioning(n) => *n, } } } @@ -669,6 +667,7 @@ pub mod repartition; pub mod sorts; pub mod stream; pub mod string_expressions; +pub(crate) mod tdigest; pub mod type_coercion; pub mod udaf; pub mod udf; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 84821a067179..38658e14e2e9 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -22,7 +22,7 @@ use super::{ aggregates, empty::EmptyExec, expressions::binary, functions, hash_join::PartitionMode, udaf, union::UnionExec, values::ValuesExec, windows, }; -use crate::execution::context::ExecutionContextState; +use crate::execution::context::{ExecutionContextState, ExecutionProps}; use crate::field_util::{FieldExt, SchemaExt}; use crate::logical_plan::plan::{ Aggregate, EmptyRelation, Filter, Join, Projection, Sort, TableScan, Window, @@ -227,9 +227,9 @@ pub trait PhysicalPlanner { /// /// `expr`: the expression to convert /// - /// `input_dfschema`: the logical plan schema for evaluating `e` + /// `input_dfschema`: the logical plan schema for evaluating `expr` /// - /// `input_schema`: the physical schema for evaluating `e` + /// `input_schema`: the physical schema for evaluating `expr` fn create_physical_expr( &self, expr: &Expr, @@ -300,12 +300,11 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { input_schema: &Schema, ctx_state: &ExecutionContextState, ) -> Result> { - DefaultPhysicalPlanner::create_physical_expr( - self, + create_physical_expr( expr, input_dfschema, input_schema, - ctx_state, + &ctx_state.execution_props, ) } } @@ -441,7 +440,7 @@ impl DefaultPhysicalPlanner { expr, asc, nulls_first, - } => self.create_physical_sort_expr( + } => create_physical_sort_expr( expr, logical_input_schema, &physical_input_schema, @@ -449,7 +448,7 @@ impl DefaultPhysicalPlanner { descending: !*asc, nulls_first: *nulls_first, }, - ctx_state, + &ctx_state.execution_props, ), _ => unreachable!(), }) @@ -465,11 +464,11 @@ impl DefaultPhysicalPlanner { let window_expr = window_expr .iter() .map(|e| { - self.create_window_expr( + create_window_expr( e, logical_input_schema, &physical_input_schema, - ctx_state, + &ctx_state.execution_props, ) }) .collect::>>()?; @@ -508,11 +507,11 @@ impl DefaultPhysicalPlanner { let aggregates = aggr_expr .iter() .map(|e| { - self.create_aggregate_expr( + create_aggregate_expr( e, logical_input_schema, &physical_input_schema, - ctx_state, + &ctx_state.execution_props, ) }) .collect::>>()?; @@ -689,7 +688,7 @@ impl DefaultPhysicalPlanner { expr, asc, nulls_first, - } => self.create_physical_sort_expr( + } => create_physical_sort_expr( expr, input_dfschema, &input_schema, @@ -697,7 +696,7 @@ impl DefaultPhysicalPlanner { descending: !*asc, nulls_first: *nulls_first, }, - ctx_state, + &ctx_state.execution_props, ), _ => Err(DataFusionError::Plan( "Sort only accepts sort expressions".to_string(), @@ -867,517 +866,487 @@ impl DefaultPhysicalPlanner { exec_plan }.boxed() } +} - /// Create a physical expression from a logical expression - pub fn create_physical_expr( - &self, - e: &Expr, - input_dfschema: &DFSchema, - input_schema: &Schema, - ctx_state: &ExecutionContextState, - ) -> Result> { - match e { - Expr::Alias(expr, ..) => Ok(self.create_physical_expr( - expr, - input_dfschema, - input_schema, - ctx_state, - )?), - Expr::Column(c) => { - let idx = input_dfschema.index_of_column(c)?; - Ok(Arc::new(Column::new(&c.name, idx))) - } - Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), - Expr::ScalarVariable(variable_names) => { - if &variable_names[0][0..2] == "@@" { - match ctx_state.var_provider.get(&VarType::System) { - Some(provider) => { - let scalar_value = - provider.get_value(variable_names.clone())?; - Ok(Arc::new(Literal::new(scalar_value))) - } - _ => Err(DataFusionError::Plan( - "No system variable provider found".to_string(), - )), +/// Create a physical expression from a logical expression ([Expr]) +pub fn create_physical_expr( + e: &Expr, + input_dfschema: &DFSchema, + input_schema: &Schema, + execution_props: &ExecutionProps, +) -> Result> { + match e { + Expr::Alias(expr, ..) => Ok(create_physical_expr( + expr, + input_dfschema, + input_schema, + execution_props, + )?), + Expr::Column(c) => { + let idx = input_dfschema.index_of_column(c)?; + Ok(Arc::new(Column::new(&c.name, idx))) + } + Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), + Expr::ScalarVariable(variable_names) => { + if &variable_names[0][0..2] == "@@" { + match execution_props.get_var_provider(VarType::System) { + Some(provider) => { + let scalar_value = provider.get_value(variable_names.clone())?; + Ok(Arc::new(Literal::new(scalar_value))) } - } else { - match ctx_state.var_provider.get(&VarType::UserDefined) { - Some(provider) => { - let scalar_value = - provider.get_value(variable_names.clone())?; - Ok(Arc::new(Literal::new(scalar_value))) - } - _ => Err(DataFusionError::Plan( - "No user defined variable provider found".to_string(), - )), + _ => Err(DataFusionError::Plan( + "No system variable provider found".to_string(), + )), + } + } else { + match execution_props.get_var_provider(VarType::UserDefined) { + Some(provider) => { + let scalar_value = provider.get_value(variable_names.clone())?; + Ok(Arc::new(Literal::new(scalar_value))) } + _ => Err(DataFusionError::Plan( + "No user defined variable provider found".to_string(), + )), } } - Expr::BinaryExpr { left, op, right } => { - let lhs = self.create_physical_expr( - left, - input_dfschema, - input_schema, - ctx_state, - )?; - let rhs = self.create_physical_expr( - right, + } + Expr::BinaryExpr { left, op, right } => { + let lhs = create_physical_expr( + left, + input_dfschema, + input_schema, + execution_props, + )?; + let rhs = create_physical_expr( + right, + input_dfschema, + input_schema, + execution_props, + )?; + binary(lhs, *op, rhs, input_schema) + } + Expr::Case { + expr, + when_then_expr, + else_expr, + .. + } => { + let expr: Option> = if let Some(e) = expr { + Some(create_physical_expr( + e.as_ref(), input_dfschema, input_schema, - ctx_state, - )?; - binary(lhs, *op, rhs, input_schema) - } - Expr::Case { - expr, - when_then_expr, - else_expr, - .. - } => { - let expr: Option> = if let Some(e) = expr { - Some(self.create_physical_expr( - e.as_ref(), + execution_props, + )?) + } else { + None + }; + let when_expr = when_then_expr + .iter() + .map(|(w, _)| { + create_physical_expr( + w.as_ref(), input_dfschema, input_schema, - ctx_state, - )?) - } else { - None - }; - let when_expr = when_then_expr - .iter() - .map(|(w, _)| { - self.create_physical_expr( - w.as_ref(), - input_dfschema, - input_schema, - ctx_state, - ) - }) - .collect::>>()?; - let then_expr = when_then_expr - .iter() - .map(|(_, t)| { - self.create_physical_expr( - t.as_ref(), - input_dfschema, - input_schema, - ctx_state, - ) - }) - .collect::>>()?; - let when_then_expr: Vec<(Arc, Arc)> = - when_expr - .iter() - .zip(then_expr.iter()) - .map(|(w, t)| (w.clone(), t.clone())) - .collect(); - let else_expr: Option> = if let Some(e) = else_expr - { - Some(self.create_physical_expr( - e.as_ref(), + execution_props, + ) + }) + .collect::>>()?; + let then_expr = when_then_expr + .iter() + .map(|(_, t)| { + create_physical_expr( + t.as_ref(), input_dfschema, input_schema, - ctx_state, - )?) - } else { - None - }; - Ok(Arc::new(CaseExpr::try_new( - expr, - &when_then_expr, - else_expr, - )?)) - } - Expr::Cast { expr, data_type } => expressions::cast( - self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, - input_schema, - data_type.clone(), - ), - Expr::TryCast { expr, data_type } => expressions::try_cast( - self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, - input_schema, - data_type.clone(), - ), - Expr::Not(expr) => expressions::not( - self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, - input_schema, - ), - Expr::Negative(expr) => expressions::negative( - self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, - input_schema, - ), - Expr::IsNull(expr) => expressions::is_null(self.create_physical_expr( + execution_props, + ) + }) + .collect::>>()?; + let when_then_expr: Vec<(Arc, Arc)> = + when_expr + .iter() + .zip(then_expr.iter()) + .map(|(w, t)| (w.clone(), t.clone())) + .collect(); + let else_expr: Option> = if let Some(e) = else_expr { + Some(create_physical_expr( + e.as_ref(), + input_dfschema, + input_schema, + execution_props, + )?) + } else { + None + }; + Ok(Arc::new(CaseExpr::try_new( expr, - input_dfschema, + &when_then_expr, + else_expr, + )?)) + } + Expr::Cast { expr, data_type } => expressions::cast( + create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + input_schema, + data_type.clone(), + ), + Expr::TryCast { expr, data_type } => expressions::try_cast( + create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + input_schema, + data_type.clone(), + ), + Expr::Not(expr) => expressions::not( + create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + input_schema, + ), + Expr::Negative(expr) => expressions::negative( + create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + input_schema, + ), + Expr::IsNull(expr) => expressions::is_null(create_physical_expr( + expr, + input_dfschema, + input_schema, + execution_props, + )?), + Expr::IsNotNull(expr) => expressions::is_not_null(create_physical_expr( + expr, + input_dfschema, + input_schema, + execution_props, + )?), + Expr::GetIndexedField { expr, key } => Ok(Arc::new(GetIndexedFieldExpr::new( + create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + key.clone(), + ))), + + Expr::ScalarFunction { fun, args } => { + let physical_args = args + .iter() + .map(|e| { + create_physical_expr(e, input_dfschema, input_schema, execution_props) + }) + .collect::>>()?; + functions::create_physical_expr( + fun, + &physical_args, input_schema, - ctx_state, - )?), - Expr::IsNotNull(expr) => expressions::is_not_null( - self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, - ), - Expr::GetIndexedField { expr, key } => { - Ok(Arc::new(GetIndexedFieldExpr::new( - self.create_physical_expr( - expr, - input_dfschema, - input_schema, - ctx_state, - )?, - key.clone(), - ))) - } - - Expr::ScalarFunction { fun, args } => { - let physical_args = args - .iter() - .map(|e| { - self.create_physical_expr( - e, - input_dfschema, - input_schema, - ctx_state, - ) - }) - .collect::>>()?; - functions::create_physical_expr( - fun, - &physical_args, + execution_props, + ) + } + Expr::ScalarUDF { fun, args } => { + let mut physical_args = vec![]; + for e in args { + physical_args.push(create_physical_expr( + e, + input_dfschema, input_schema, - ctx_state, - ) + execution_props, + )?); } - Expr::ScalarUDF { fun, args } => { - let mut physical_args = vec![]; - for e in args { - physical_args.push(self.create_physical_expr( - e, - input_dfschema, - input_schema, - ctx_state, - )?); - } - udf::create_physical_expr( - fun.clone().as_ref(), - &physical_args, - input_schema, - ) - } - Expr::Between { + udf::create_physical_expr(fun.clone().as_ref(), &physical_args, input_schema) + } + Expr::Between { + expr, + negated, + low, + high, + } => { + let value_expr = create_physical_expr( expr, - negated, - low, + input_dfschema, + input_schema, + execution_props, + )?; + let low_expr = + create_physical_expr(low, input_dfschema, input_schema, execution_props)?; + let high_expr = create_physical_expr( high, - } => { - let value_expr = self.create_physical_expr( + input_dfschema, + input_schema, + execution_props, + )?; + + // rewrite the between into the two binary operators + let binary_expr = binary( + binary(value_expr.clone(), Operator::GtEq, low_expr, input_schema)?, + Operator::And, + binary(value_expr.clone(), Operator::LtEq, high_expr, input_schema)?, + input_schema, + ); + + if *negated { + expressions::not(binary_expr?, input_schema) + } else { + binary_expr + } + } + Expr::InList { + expr, + list, + negated, + } => match expr.as_ref() { + Expr::Literal(ScalarValue::Utf8(None)) => { + Ok(expressions::lit(ScalarValue::Boolean(None))) + } + _ => { + let value_expr = create_physical_expr( expr, input_dfschema, input_schema, - ctx_state, + execution_props, )?; - let low_expr = self.create_physical_expr( - low, - input_dfschema, - input_schema, - ctx_state, - )?; - let high_expr = self.create_physical_expr( - high, - input_dfschema, - input_schema, - ctx_state, - )?; - - // rewrite the between into the two binary operators - let binary_expr = binary( - binary(value_expr.clone(), Operator::GtEq, low_expr, input_schema)?, - Operator::And, - binary(value_expr.clone(), Operator::LtEq, high_expr, input_schema)?, - input_schema, - ); - - if *negated { - expressions::not(binary_expr?, input_schema) - } else { - binary_expr - } - } - Expr::InList { - expr, - list, - negated, - } => match expr.as_ref() { - Expr::Literal(ScalarValue::Utf8(None)) => { - Ok(expressions::lit(ScalarValue::Boolean(None))) - } - _ => { - let value_expr = self.create_physical_expr( - expr, - input_dfschema, - input_schema, - ctx_state, - )?; - let value_expr_data_type = value_expr.data_type(input_schema)?; - - let list_exprs = list - .iter() - .map(|expr| match expr { - Expr::Literal(ScalarValue::Utf8(None)) => self - .create_physical_expr( - expr, - input_dfschema, - input_schema, - ctx_state, - ), - _ => { - let list_expr = self.create_physical_expr( - expr, - input_dfschema, - input_schema, - ctx_state, - )?; - let list_expr_data_type = - list_expr.data_type(input_schema)?; - - if list_expr_data_type == value_expr_data_type { - Ok(list_expr) - } else if can_cast_types( - &list_expr_data_type, - &value_expr_data_type, - ) { - expressions::cast( - list_expr, - input_schema, - value_expr.data_type(input_schema)?, - ) - } else { - Err(DataFusionError::Plan(format!( - "Unsupported CAST from {:?} to {:?}", - list_expr_data_type, value_expr_data_type - ))) - } - } - }) - .collect::>>()?; + let value_expr_data_type = value_expr.data_type(input_schema)?; - expressions::in_list(value_expr, list_exprs, negated) - } - }, - other => Err(DataFusionError::NotImplemented(format!( - "Physical plan does not support logical expression {:?}", - other - ))), - } - } - - /// Create a window expression with a name from a logical expression - pub fn create_window_expr_with_name( - &self, - e: &Expr, - name: impl Into, - logical_input_schema: &DFSchema, - physical_input_schema: &Schema, - ctx_state: &ExecutionContextState, - ) -> Result> { - let name = name.into(); - match e { - Expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - } => { - let args = args - .iter() - .map(|e| { - self.create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - ctx_state, - ) - }) - .collect::>>()?; - let partition_by = partition_by + let list_exprs = list .iter() - .map(|e| { - self.create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - ctx_state, - ) - }) - .collect::>>()?; - let order_by = order_by - .iter() - .map(|e| match e { - Expr::Sort { - expr, - asc, - nulls_first, - } => self.create_physical_sort_expr( + .map(|expr| match expr { + Expr::Literal(ScalarValue::Utf8(None)) => create_physical_expr( expr, - logical_input_schema, - physical_input_schema, - SortOptions { - descending: !*asc, - nulls_first: *nulls_first, - }, - ctx_state, + input_dfschema, + input_schema, + execution_props, ), - _ => Err(DataFusionError::Plan( - "Sort only accepts sort expressions".to_string(), - )), + _ => { + let list_expr = create_physical_expr( + expr, + input_dfschema, + input_schema, + execution_props, + )?; + let list_expr_data_type = + list_expr.data_type(input_schema)?; + + if list_expr_data_type == value_expr_data_type { + Ok(list_expr) + } else if can_cast_types( + &list_expr_data_type, + &value_expr_data_type, + ) { + expressions::cast( + list_expr, + input_schema, + value_expr.data_type(input_schema)?, + ) + } else { + Err(DataFusionError::Plan(format!( + "Unsupported CAST from {:?} to {:?}", + list_expr_data_type, value_expr_data_type + ))) + } + } }) .collect::>>()?; - if window_frame.is_some() { - return Err(DataFusionError::NotImplemented( - "window expression with window frame definition is not yet supported" - .to_owned(), - )); - } - windows::create_window_expr( - fun, - name, - &args, - &partition_by, - &order_by, - *window_frame, - physical_input_schema, - ) + + expressions::in_list(value_expr, list_exprs, negated) } - other => Err(DataFusionError::Internal(format!( - "Invalid window expression '{:?}'", - other - ))), - } + }, + other => Err(DataFusionError::NotImplemented(format!( + "Physical plan does not support logical expression {:?}", + other + ))), } +} - /// Create a window expression from a logical expression or an alias - pub fn create_window_expr( - &self, - e: &Expr, - logical_input_schema: &DFSchema, - physical_input_schema: &Schema, - ctx_state: &ExecutionContextState, - ) -> Result> { - // unpack aliased logical expressions, e.g. "sum(col) over () as total" - let (name, e) = match e { - Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), - _ => (physical_name(e)?, e), - }; - self.create_window_expr_with_name( - e, - name, - logical_input_schema, - physical_input_schema, - ctx_state, - ) +/// Create a window expression with a name from a logical expression +pub fn create_window_expr_with_name( + e: &Expr, + name: impl Into, + logical_input_schema: &DFSchema, + physical_input_schema: &Schema, + execution_props: &ExecutionProps, +) -> Result> { + let name = name.into(); + match e { + Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + } => { + let args = args + .iter() + .map(|e| { + create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?; + let partition_by = partition_by + .iter() + .map(|e| { + create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?; + let order_by = order_by + .iter() + .map(|e| match e { + Expr::Sort { + expr, + asc, + nulls_first, + } => create_physical_sort_expr( + expr, + logical_input_schema, + physical_input_schema, + SortOptions { + descending: !*asc, + nulls_first: *nulls_first, + }, + execution_props, + ), + _ => Err(DataFusionError::Plan( + "Sort only accepts sort expressions".to_string(), + )), + }) + .collect::>>()?; + if window_frame.is_some() { + return Err(DataFusionError::NotImplemented( + "window expression with window frame definition is not yet supported" + .to_owned(), + )); + } + windows::create_window_expr( + fun, + name, + &args, + &partition_by, + &order_by, + *window_frame, + physical_input_schema, + ) + } + other => Err(DataFusionError::Internal(format!( + "Invalid window expression '{:?}'", + other + ))), } +} - /// Create an aggregate expression with a name from a logical expression - pub fn create_aggregate_expr_with_name( - &self, - e: &Expr, - name: impl Into, - logical_input_schema: &DFSchema, - physical_input_schema: &Schema, - ctx_state: &ExecutionContextState, - ) -> Result> { - match e { - Expr::AggregateFunction { +/// Create a window expression from a logical expression or an alias +pub fn create_window_expr( + e: &Expr, + logical_input_schema: &DFSchema, + physical_input_schema: &Schema, + execution_props: &ExecutionProps, +) -> Result> { + // unpack aliased logical expressions, e.g. "sum(col) over () as total" + let (name, e) = match e { + Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), + _ => (physical_name(e)?, e), + }; + create_window_expr_with_name( + e, + name, + logical_input_schema, + physical_input_schema, + execution_props, + ) +} + +/// Create an aggregate expression with a name from a logical expression +pub fn create_aggregate_expr_with_name( + e: &Expr, + name: impl Into, + logical_input_schema: &DFSchema, + physical_input_schema: &Schema, + execution_props: &ExecutionProps, +) -> Result> { + match e { + Expr::AggregateFunction { + fun, + distinct, + args, + .. + } => { + let args = args + .iter() + .map(|e| { + create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?; + aggregates::create_aggregate_expr( fun, - distinct, - args, - .. - } => { - let args = args - .iter() - .map(|e| { - self.create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - ctx_state, - ) - }) - .collect::>>()?; - aggregates::create_aggregate_expr( - fun, - *distinct, - &args, - physical_input_schema, - name, - ) - } - Expr::AggregateUDF { fun, args, .. } => { - let args = args - .iter() - .map(|e| { - self.create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - ctx_state, - ) - }) - .collect::>>()?; + *distinct, + &args, + physical_input_schema, + name, + ) + } + Expr::AggregateUDF { fun, args, .. } => { + let args = args + .iter() + .map(|e| { + create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?; - udaf::create_aggregate_expr(fun, &args, physical_input_schema, name) - } - other => Err(DataFusionError::Internal(format!( - "Invalid aggregate expression '{:?}'", - other - ))), + udaf::create_aggregate_expr(fun, &args, physical_input_schema, name) } + other => Err(DataFusionError::Internal(format!( + "Invalid aggregate expression '{:?}'", + other + ))), } +} - /// Create an aggregate expression from a logical expression or an alias - pub fn create_aggregate_expr( - &self, - e: &Expr, - logical_input_schema: &DFSchema, - physical_input_schema: &Schema, - ctx_state: &ExecutionContextState, - ) -> Result> { - // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" - let (name, e) = match e { - Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), - _ => (physical_name(e)?, e), - }; - - self.create_aggregate_expr_with_name( - e, - name, - logical_input_schema, - physical_input_schema, - ctx_state, - ) - } +/// Create an aggregate expression from a logical expression or an alias +pub fn create_aggregate_expr( + e: &Expr, + logical_input_schema: &DFSchema, + physical_input_schema: &Schema, + execution_props: &ExecutionProps, +) -> Result> { + // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" + let (name, e) = match e { + Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), + _ => (physical_name(e)?, e), + }; - /// Create a physical sort expression from a logical expression - pub fn create_physical_sort_expr( - &self, - e: &Expr, - input_dfschema: &DFSchema, - input_schema: &Schema, - options: SortOptions, - ctx_state: &ExecutionContextState, - ) -> Result { - Ok(PhysicalSortExpr { - expr: self.create_physical_expr( - e, - input_dfschema, - input_schema, - ctx_state, - )?, - options, - }) - } + create_aggregate_expr_with_name( + e, + name, + logical_input_schema, + physical_input_schema, + execution_props, + ) +} + +/// Create a physical sort expression from a logical expression +pub fn create_physical_sort_expr( + e: &Expr, + input_dfschema: &DFSchema, + input_schema: &Schema, + options: SortOptions, + execution_props: &ExecutionProps, +) -> Result { + Ok(PhysicalSortExpr { + expr: create_physical_expr(e, input_dfschema, input_schema, execution_props)?, + options, + }) +} +impl DefaultPhysicalPlanner { /// Handles capturing the various plans for EXPLAIN queries /// /// Returns @@ -1653,14 +1622,14 @@ mod tests { DFField { qualifier: None, field: Field { \ name: \"a\", \ data_type: Int32, \ - nullable: false, \ + is_nullable: false, \ metadata: {} } }\ ] }, \ ExecutionPlan schema: Schema { fields: [\ Field { \ name: \"b\", \ data_type: Int32, \ - nullable: false, \ + is_nullable: false, \ metadata: {} }\ ], metadata: {} }"; match plan { @@ -1699,7 +1668,7 @@ mod tests { .build()?; let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8 - let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8 }], negated: false }"; + let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { wrapped: false, partial: false } }], negated: false }"; assert!(format!("{:?}", execution_plan).contains(expected)); // expression: "a in (true, 'a')" diff --git a/datafusion/src/physical_plan/regex_expressions.rs b/datafusion/src/physical_plan/regex_expressions.rs index 71c0901a677e..470010d5491e 100644 --- a/datafusion/src/physical_plan/regex_expressions.rs +++ b/datafusion/src/physical_plan/regex_expressions.rs @@ -302,7 +302,7 @@ mod tests { ]); let patterns = Utf8Array::::from_slice(&vec![r"x.*-(\d*)-.*"; 4]); - let flags = Utf8Array::::from_slice(vec!["i"; 4]); + let flags = Utf8Array::::from_slice(&["i"; 4]); let result = regexp_matches(&array, &patterns, Some(&flags))?; @@ -317,9 +317,9 @@ mod tests { #[test] fn test_case_sensitive_regexp_match() { - let values = StringArray::from_slice(vec!["abc"; 5]); + let values = StringArray::from_slice(&["abc"; 5]); let patterns = - StringArray::from_slice(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + StringArray::from_slice(&["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); let expected = vec![ Some(vec![Some("a")]), None, @@ -337,10 +337,10 @@ mod tests { #[test] fn test_case_insensitive_regexp_match() { - let values = StringArray::from_slice(vec!["abc"; 5]); + let values = StringArray::from_slice(&["abc"; 5]); let patterns = - StringArray::from_slice(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); - let flags = StringArray::from_slice(vec!["i"; 5]); + StringArray::from_slice(&["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + let flags = StringArray::from_slice(&["i"; 5]); let expected = vec![ Some(vec![Some("a")]), diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index 75b857a78361..9345b5b553a6 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -452,7 +452,7 @@ struct RepartitionStream { /// Number of input partitions that have finished sending batches to this output channel num_input_partitions_processed: usize, - /// Schema + /// Schema wrapped by Arc schema: SchemaRef, /// channel containing the repartitioned batches @@ -503,6 +503,7 @@ mod tests { use super::*; use crate::field_util::SchemaExt; use crate::record_batch::RecordBatch; + use crate::test::create_vec_batches; use crate::{ assert_batches_sorted_eq, physical_plan::{collect, expressions::col, memory::MemoryExec}, @@ -514,7 +515,7 @@ mod tests { }, }, }; - use arrow::array::{ArrayRef, UInt32Array, Utf8Array}; + use arrow::array::{ArrayRef, Utf8Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; use futures::FutureExt; @@ -606,23 +607,6 @@ mod tests { Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) } - fn create_vec_batches(schema: &Arc, n: usize) -> Vec { - let batch = create_batch(schema); - let mut vec = Vec::with_capacity(n); - for _ in 0..n { - vec.push(batch.clone()); - } - vec - } - - fn create_batch(schema: &Arc) -> RecordBatch { - RecordBatch::try_new( - schema.clone(), - vec![Arc::new(UInt32Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]))], - ) - .unwrap() - } - async fn repartition( schema: &SchemaRef, input_partitions: Vec>, @@ -983,7 +967,7 @@ mod tests { let runtime = Arc::new(RuntimeEnv::default()); let batch = RecordBatch::try_from_iter(vec![( "a", - Arc::new(StringArray::from_slice(vec!["foo"])) as ArrayRef, + Arc::new(StringArray::from_slice(&["foo"])) as ArrayRef, )]) .unwrap(); let partitioning = Partitioning::Hash( diff --git a/datafusion/src/physical_plan/sorts/mod.rs b/datafusion/src/physical_plan/sorts/mod.rs index fdde229f9ca9..c92546a1d1de 100644 --- a/datafusion/src/physical_plan/sorts/mod.rs +++ b/datafusion/src/physical_plan/sorts/mod.rs @@ -29,11 +29,12 @@ use futures::channel::mpsc; use futures::stream::FusedStream; use futures::Stream; use hashbrown::HashMap; +use parking_lot::RwLock; use std::borrow::BorrowMut; use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; use std::pin::Pin; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::task::{Context, Poll}; pub mod sort; @@ -136,7 +137,7 @@ impl SortKeyCursor { .collect::>(); self.init_cmp_if_needed(other, &zipped)?; - let map = self.batch_comparators.read().unwrap(); + let map = self.batch_comparators.read(); let cmp = map.get(&other.batch_id).ok_or_else(|| { DataFusionError::Execution(format!( "Failed to find comparator for {} cmp {}", @@ -173,10 +174,10 @@ impl SortKeyCursor { other: &SortKeyCursor, zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)], ) -> Result<()> { - let hm = self.batch_comparators.read().unwrap(); + let hm = self.batch_comparators.read(); if !hm.contains_key(&other.batch_id) { drop(hm); - let mut map = self.batch_comparators.write().unwrap(); + let mut map = self.batch_comparators.write(); let cmp = map .borrow_mut() .entry(other.batch_id) @@ -249,15 +250,6 @@ enum StreamWrapper { Stream(Option), } -impl StreamWrapper { - fn mem_used(&self) -> usize { - match &self { - StreamWrapper::Stream(Some(s)) => s.mem_used, - _ => 0, - } - } -} - impl Stream for StreamWrapper { type Item = ArrowResult; diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs index a39ddd3950ae..d46236126072 100644 --- a/datafusion/src/physical_plan/sorts/sort.rs +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ -21,12 +21,14 @@ use crate::error::{DataFusionError, Result}; use crate::execution::memory_manager::{ - ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, + human_readable_size, ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, }; use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::common::{batch_byte_size, IPCWriter, SizedRecordBatchStream}; use crate::physical_plan::expressions::PhysicalSortExpr; -use crate::physical_plan::metrics::{AggregatedMetricsSet, BaselineMetrics, MetricsSet}; +use crate::physical_plan::metrics::{ + BaselineMetrics, CompositeMetricsSet, MemTrackingMetrics, MetricsSet, +}; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream; use crate::physical_plan::sorts::{SortColumn, SortedStream}; use crate::physical_plan::stream::RecordBatchReceiverStream; @@ -45,7 +47,7 @@ use arrow::io::ipc::read::{read_file_metadata, FileReader}; use async_trait::async_trait; use futures::lock::Mutex; use futures::StreamExt; -use log::{error, info}; +use log::{debug, error}; use std::any::Any; use std::fmt; use std::fmt::{Debug, Formatter}; @@ -74,8 +76,8 @@ struct ExternalSorter { /// Sort expressions expr: Vec, runtime: Arc, - metrics: AggregatedMetricsSet, - inner_metrics: BaselineMetrics, + metrics_set: CompositeMetricsSet, + metrics: BaselineMetrics, } impl ExternalSorter { @@ -83,10 +85,10 @@ impl ExternalSorter { partition_id: usize, schema: SchemaRef, expr: Vec, - metrics: AggregatedMetricsSet, + metrics_set: CompositeMetricsSet, runtime: Arc, ) -> Self { - let inner_metrics = metrics.new_intermediate_baseline(partition_id); + let metrics = metrics_set.new_intermediate_baseline(partition_id); Self { id: MemoryConsumerId::new(partition_id), schema, @@ -94,8 +96,8 @@ impl ExternalSorter { spills: Mutex::new(vec![]), expr, runtime, + metrics_set, metrics, - inner_metrics, } } @@ -103,7 +105,7 @@ impl ExternalSorter { if input.num_rows() > 0 { let size = batch_byte_size(&input); self.try_grow(size).await?; - self.inner_metrics.mem_used().add(size); + self.metrics.mem_used().add(size); let mut in_mem_batches = self.in_mem_batches.lock().await; in_mem_batches.push(input); } @@ -121,16 +123,18 @@ impl ExternalSorter { let mut in_mem_batches = self.in_mem_batches.lock().await; if self.spilled_before().await { - let baseline_metrics = self.metrics.new_intermediate_baseline(partition); + let tracking_metrics = self + .metrics_set + .new_intermediate_tracking(partition, self.runtime.clone()); let mut streams: Vec = vec![]; if in_mem_batches.len() > 0 { let in_mem_stream = in_mem_partial_sort( &mut *in_mem_batches, self.schema.clone(), &self.expr, - baseline_metrics, + tracking_metrics, )?; - let prev_used = self.inner_metrics.mem_used().set(0); + let prev_used = self.metrics.mem_used().set(0); streams.push(SortedStream::new(in_mem_stream, prev_used)); } @@ -140,25 +144,28 @@ impl ExternalSorter { let stream = read_spill_as_stream(spill, self.schema.clone())?; streams.push(SortedStream::new(stream, 0)); } - let baseline_metrics = self.metrics.new_final_baseline(partition); + let tracking_metrics = self + .metrics_set + .new_final_tracking(partition, self.runtime.clone()); Ok(Box::pin(SortPreservingMergeStream::new_from_streams( streams, self.schema.clone(), &self.expr, - baseline_metrics, - partition, + tracking_metrics, self.runtime.clone(), ))) } else if in_mem_batches.len() > 0 { - let baseline_metrics = self.metrics.new_final_baseline(partition); + let tracking_metrics = self + .metrics_set + .new_final_tracking(partition, self.runtime.clone()); let result = in_mem_partial_sort( &mut *in_mem_batches, self.schema.clone(), &self.expr, - baseline_metrics, + tracking_metrics, ); - self.inner_metrics.mem_used().set(0); - // TODO: the result size is not tracked + // Report to the memory manager we are no longer using memory + self.metrics.mem_used().set(0); result } else { Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone()))) @@ -166,15 +173,15 @@ impl ExternalSorter { } fn used(&self) -> usize { - self.inner_metrics.mem_used().value() + self.metrics.mem_used().value() } fn spilled_bytes(&self) -> usize { - self.inner_metrics.spilled_bytes().value() + self.metrics.spilled_bytes().value() } fn spill_count(&self) -> usize { - self.inner_metrics.spill_count().value() + self.metrics.spill_count().value() } } @@ -189,6 +196,12 @@ impl Debug for ExternalSorter { } } +impl Drop for ExternalSorter { + fn drop(&mut self) { + self.runtime.drop_consumer(self.id(), self.used()); + } +} + #[async_trait] impl MemoryConsumer for ExternalSorter { fn name(&self) -> String { @@ -208,7 +221,7 @@ impl MemoryConsumer for ExternalSorter { } async fn spill(&self) -> Result { - info!( + debug!( "{}[{}] spilling sort data of {} to disk while inserting ({} time(s) so far)", self.name(), self.id(), @@ -223,27 +236,29 @@ impl MemoryConsumer for ExternalSorter { return Ok(0); } - let baseline_metrics = self.metrics.new_intermediate_baseline(partition); + let tracking_metrics = self + .metrics_set + .new_intermediate_tracking(partition, self.runtime.clone()); let spillfile = self.runtime.disk_manager.create_tmp_file()?; let stream = in_mem_partial_sort( &mut *in_mem_batches, self.schema.clone(), &*self.expr, - baseline_metrics, + tracking_metrics, ); spill_partial_sorted_stream(&mut stream?, spillfile.path(), self.schema.clone()) .await?; let mut spills = self.spills.lock().await; - let used = self.inner_metrics.mem_used().set(0); - self.inner_metrics.record_spill(used); + let used = self.metrics.mem_used().set(0); + self.metrics.record_spill(used); spills.push(spillfile); Ok(used) } fn mem_used(&self) -> usize { - self.inner_metrics.mem_used().value() + self.metrics.mem_used().value() } } @@ -252,14 +267,14 @@ fn in_mem_partial_sort( buffered_batches: &mut Vec, schema: SchemaRef, expressions: &[PhysicalSortExpr], - baseline_metrics: BaselineMetrics, + tracking_metrics: MemTrackingMetrics, ) -> Result { assert_ne!(buffered_batches.len(), 0); let result = { // NB timer records time taken on drop, so there are no // calls to `timer.done()` below. - let _timer = baseline_metrics.elapsed_compute().timer(); + let _timer = tracking_metrics.elapsed_compute().timer(); let pre_sort = if buffered_batches.len() == 1 { buffered_batches.pop() @@ -277,7 +292,7 @@ fn in_mem_partial_sort( Ok(Box::pin(SizedRecordBatchStream::new( schema, vec![Arc::new(result.unwrap())], - baseline_metrics, + tracking_metrics, ))) } @@ -310,9 +325,8 @@ fn read_spill_as_stream( Sender>, Receiver>, ) = tokio::sync::mpsc::channel(2); - let schema_ref = schema.clone(); let join_handle = task::spawn_blocking(move || { - if let Err(e) = read_spill(sender, path.path(), schema_ref) { + if let Err(e) = read_spill(sender, path.path()) { error!("Failure while reading spill file: {:?}. Error: {}", path, e); } }); @@ -333,23 +347,22 @@ fn write_sorted( writer.write(&batch?)?; } writer.finish()?; - info!( + debug!( "Spilled {} batches of total {} rows to disk, memory released {}", - writer.num_batches, writer.num_rows, writer.num_bytes + writer.num_batches, + writer.num_rows, + human_readable_size(writer.num_bytes as usize), ); Ok(()) } -fn read_spill( - sender: Sender>, - path: &Path, - schena: SchemaRef, -) -> Result<()> { +fn read_spill(sender: Sender>, path: &Path) -> Result<()> { let mut file = BufReader::new(File::open(&path)?); let metadata = read_file_metadata(&mut file)?; let reader = FileReader::new(file, metadata, None); + let reader_schema = Arc::new(reader.schema().clone()); for chunk in reader { - let rb = RecordBatch::try_new(schena.clone(), chunk?.into_arrays()); + let rb = RecordBatch::try_new(reader_schema.clone(), chunk?.into_arrays()); sender .blocking_send(rb) .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; @@ -365,7 +378,7 @@ pub struct SortExec { /// Sort expressions expr: Vec, /// Containing all metrics set created during sort - all_metrics: AggregatedMetricsSet, + metrics_set: CompositeMetricsSet, /// Preserve partitions of input plan preserve_partitioning: bool, } @@ -389,7 +402,7 @@ impl SortExec { Self { expr, input, - all_metrics: AggregatedMetricsSet::new(), + metrics_set: CompositeMetricsSet::new(), preserve_partitioning, } } @@ -478,14 +491,14 @@ impl ExecutionPlan for SortExec { input, partition, self.expr.clone(), - self.all_metrics.clone(), + self.metrics_set.clone(), runtime, ) .await } fn metrics(&self) -> Option { - Some(self.all_metrics.aggregate_all()) + Some(self.metrics_set.aggregate_all()) } fn fmt_as( @@ -539,27 +552,23 @@ async fn do_sort( mut input: SendableRecordBatchStream, partition_id: usize, expr: Vec, - metrics: AggregatedMetricsSet, + metrics_set: CompositeMetricsSet, runtime: Arc, ) -> Result { let schema = input.schema(); - let sorter = Arc::new(ExternalSorter::new( + let sorter = ExternalSorter::new( partition_id, schema.clone(), expr, - metrics, + metrics_set, runtime.clone(), - )); - runtime.register_consumer(&(sorter.clone() as Arc)); - + ); + runtime.register_requester(sorter.id()); while let Some(batch) = input.next().await { let batch = batch?; sorter.insert_batch(batch).await?; } - - let result = sorter.sort().await; - runtime.drop_consumer(sorter.id()); - result + sorter.sort().await } #[cfg(test)] diff --git a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs index f641701bad3a..df8e70d9747f 100644 --- a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs @@ -19,13 +19,14 @@ use crate::physical_plan::common::AbortOnDropMany; use crate::physical_plan::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, + ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet, }; +use parking_lot::Mutex; use std::any::Any; use std::collections::{BinaryHeap, VecDeque}; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::task::{Context, Poll}; use crate::record_batch::RecordBatch; @@ -40,9 +41,7 @@ use futures::stream::FusedStream; use futures::{Stream, StreamExt}; use crate::error::{DataFusionError, Result}; -use crate::execution::memory_manager::ConsumerType; use crate::execution::runtime_env::RuntimeEnv; -use crate::execution::{MemoryConsumer, MemoryConsumerId, MemoryManager}; use crate::field_util::SchemaExt; use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream, StreamWrapper}; use crate::physical_plan::{ @@ -160,7 +159,7 @@ impl ExecutionPlan for SortPreservingMergeExec { ))); } - let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let tracking_metrics = MemTrackingMetrics::new(&self.metrics, partition); let input_partitions = self.input.output_partitioning().partition_count(); match input_partitions { @@ -192,8 +191,7 @@ impl ExecutionPlan for SortPreservingMergeExec { AbortOnDropMany(join_handles), self.schema(), &self.expr, - baseline_metrics, - partition, + tracking_metrics, runtime, ))) } @@ -222,36 +220,19 @@ impl ExecutionPlan for SortPreservingMergeExec { } } +#[derive(Debug)] struct MergingStreams { - /// ConsumerId - id: MemoryConsumerId, /// The sorted input streams to merge together streams: Mutex>, /// number of streams num_streams: usize, - /// Runtime - runtime: Arc, -} - -impl Debug for MergingStreams { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - f.debug_struct("MergingStreams") - .field("id", &self.id()) - .finish() - } } impl MergingStreams { - fn new( - partition: usize, - input_streams: Vec, - runtime: Arc, - ) -> Self { + fn new(input_streams: Vec) -> Self { Self { - id: MemoryConsumerId::new(partition), num_streams: input_streams.len(), streams: Mutex::new(input_streams), - runtime, } } @@ -260,45 +241,13 @@ impl MergingStreams { } } -#[async_trait] -impl MemoryConsumer for MergingStreams { - fn name(&self) -> String { - "MergingStreams".to_owned() - } - - fn id(&self) -> &MemoryConsumerId { - &self.id - } - - fn memory_manager(&self) -> Arc { - self.runtime.memory_manager.clone() - } - - fn type_(&self) -> &ConsumerType { - &ConsumerType::Tracking - } - - async fn spill(&self) -> Result { - return Err(DataFusionError::Internal(format!( - "Calling spill on a tracking only consumer {}, {}", - self.name(), - self.id, - ))); - } - - fn mem_used(&self) -> usize { - let streams = self.streams.lock().unwrap(); - streams.iter().map(StreamWrapper::mem_used).sum::() - } -} - #[derive(Debug)] pub(crate) struct SortPreservingMergeStream { /// The schema of the RecordBatches yielded by this stream schema: SchemaRef, /// The sorted input streams to merge together - streams: Arc, + streams: MergingStreams, /// Drop helper for tasks feeding the [`receivers`](Self::receivers) _drop_helper: AbortOnDropMany<()>, @@ -323,7 +272,7 @@ pub(crate) struct SortPreservingMergeStream { sort_options: Arc>, /// used to record execution metrics - baseline_metrics: BaselineMetrics, + tracking_metrics: MemTrackingMetrics, /// If the stream has encountered an error aborted: bool, @@ -334,25 +283,17 @@ pub(crate) struct SortPreservingMergeStream { /// min heap for record comparison min_heap: BinaryHeap, - /// runtime - runtime: Arc, -} - -impl Drop for SortPreservingMergeStream { - fn drop(&mut self) { - self.runtime.drop_consumer(self.streams.id()) - } + /// target batch size + batch_size: usize, } impl SortPreservingMergeStream { - #[allow(clippy::too_many_arguments)] pub(crate) fn new_from_receivers( receivers: Vec>>, _drop_helper: AbortOnDropMany<()>, schema: SchemaRef, expressions: &[PhysicalSortExpr], - baseline_metrics: BaselineMetrics, - partition: usize, + tracking_metrics: MemTrackingMetrics, runtime: Arc, ) -> Self { let stream_count = receivers.len(); @@ -361,23 +302,21 @@ impl SortPreservingMergeStream { .map(|_| VecDeque::new()) .collect(); let wrappers = receivers.into_iter().map(StreamWrapper::Receiver).collect(); - let streams = Arc::new(MergingStreams::new(partition, wrappers, runtime.clone())); - runtime.register_consumer(&(streams.clone() as Arc)); SortPreservingMergeStream { schema, batches, cursor_finished: vec![true; stream_count], - streams, + streams: MergingStreams::new(wrappers), _drop_helper, column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()), - baseline_metrics, + tracking_metrics, aborted: false, in_progress: vec![], next_batch_id: 0, min_heap: BinaryHeap::with_capacity(stream_count), - runtime, + batch_size: runtime.batch_size(), } } @@ -385,8 +324,7 @@ impl SortPreservingMergeStream { streams: Vec, schema: SchemaRef, expressions: &[PhysicalSortExpr], - baseline_metrics: BaselineMetrics, - partition: usize, + tracking_metrics: MemTrackingMetrics, runtime: Arc, ) -> Self { let stream_count = streams.len(); @@ -394,27 +332,26 @@ impl SortPreservingMergeStream { .into_iter() .map(|_| VecDeque::new()) .collect(); + tracking_metrics.init_mem_used(streams.iter().map(|s| s.mem_used).sum()); let wrappers = streams .into_iter() .map(|s| StreamWrapper::Stream(Some(s))) .collect(); - let streams = Arc::new(MergingStreams::new(partition, wrappers, runtime.clone())); - runtime.register_consumer(&(streams.clone() as Arc)); Self { schema, batches, cursor_finished: vec![true; stream_count], - streams, + streams: MergingStreams::new(wrappers), _drop_helper: AbortOnDropMany(vec![]), column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()), - baseline_metrics, + tracking_metrics, aborted: false, in_progress: vec![], next_batch_id: 0, min_heap: BinaryHeap::with_capacity(stream_count), - runtime, + batch_size: runtime.batch_size(), } } @@ -432,7 +369,7 @@ impl SortPreservingMergeStream { } let mut empty_batch = false; { - let mut streams = self.streams.streams.lock().unwrap(); + let mut streams = self.streams.streams.lock(); let stream = &mut streams[idx]; if stream.is_terminated() { @@ -578,7 +515,7 @@ impl Stream for SortPreservingMergeStream { cx: &mut Context<'_>, ) -> Poll> { let poll = self.poll_next_inner(cx); - self.baseline_metrics.record_poll(poll) + self.tracking_metrics.record_poll(poll) } } @@ -607,7 +544,7 @@ impl SortPreservingMergeStream { loop { // NB timer records time taken on drop, so there are no // calls to `timer.done()` below. - let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let elapsed_compute = self.tracking_metrics.elapsed_compute().clone(); let _timer = elapsed_compute.timer(); match self.min_heap.pop() { @@ -631,7 +568,7 @@ impl SortPreservingMergeStream { row_idx, }); - if self.in_progress.len() == self.runtime.batch_size() { + if self.in_progress.len() == self.batch_size { return Poll::Ready(Some(self.build_record_batch())); } @@ -1277,7 +1214,7 @@ mod tests { } let metrics = ExecutionPlanMetricsSet::new(); - let baseline_metrics = BaselineMetrics::new(&metrics, 0); + let tracking_metrics = MemTrackingMetrics::new(&metrics, 0); let merge_stream = SortPreservingMergeStream::new_from_receivers( receivers, @@ -1285,8 +1222,7 @@ mod tests { AbortOnDropMany(vec![]), batches.schema(), sort.as_slice(), - baseline_metrics, - 0, + tracking_metrics, runtime.clone(), ); diff --git a/datafusion/src/physical_plan/tdigest/mod.rs b/datafusion/src/physical_plan/tdigest/mod.rs new file mode 100644 index 000000000000..603cfd867c48 --- /dev/null +++ b/datafusion/src/physical_plan/tdigest/mod.rs @@ -0,0 +1,819 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with this +// work for additional information regarding copyright ownership. The ASF +// licenses this file to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +//! An implementation of the [TDigest sketch algorithm] providing approximate +//! quantile calculations. +//! +//! The TDigest code in this module is modified from +//! https://github.com/MnO2/t-digest, itself a rust reimplementation of +//! [Facebook's Folly TDigest] implementation. +//! +//! Alterations include reduction of runtime heap allocations, broader type +//! support, (de-)serialisation support, reduced type conversions and null value +//! tolerance. +//! +//! [TDigest sketch algorithm]: https://arxiv.org/abs/1902.04023 +//! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h + +use arrow::datatypes::DataType; +use ordered_float::OrderedFloat; +use std::cmp::Ordering; + +use crate::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; + +// Cast a non-null [`ScalarValue::Float64`] to an [`OrderedFloat`], or +// panic. +macro_rules! cast_scalar_f64 { + ($value:expr ) => { + match &$value { + ScalarValue::Float64(Some(v)) => OrderedFloat::from(*v), + v => panic!("invalid type {:?}", v), + } + }; +} + +/// This trait is implemented for each type a [`TDigest`] can operate on, +/// allowing it to support both numerical rust types (obtained from +/// `PrimitiveArray` instances), and [`ScalarValue`] instances. +pub(crate) trait TryIntoOrderedF64 { + /// A fallible conversion of a possibly null `self` into a [`OrderedFloat`]. + /// + /// If `self` is null, this method must return `Ok(None)`. + /// + /// If `self` cannot be coerced to the desired type, this method must return + /// an `Err` variant. + fn try_as_f64(&self) -> Result>>; +} + +/// Generate an infallible conversion from `type` to an [`OrderedFloat`]. +macro_rules! impl_try_ordered_f64 { + ($type:ty) => { + impl TryIntoOrderedF64 for $type { + fn try_as_f64(&self) -> Result>> { + Ok(Some(OrderedFloat::from(*self as f64))) + } + } + }; +} + +impl_try_ordered_f64!(f64); +impl_try_ordered_f64!(f32); +impl_try_ordered_f64!(i64); +impl_try_ordered_f64!(i32); +impl_try_ordered_f64!(i16); +impl_try_ordered_f64!(i8); +impl_try_ordered_f64!(u64); +impl_try_ordered_f64!(u32); +impl_try_ordered_f64!(u16); +impl_try_ordered_f64!(u8); + +impl TryIntoOrderedF64 for ScalarValue { + fn try_as_f64(&self) -> Result>> { + match self { + ScalarValue::Float32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Float64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int8(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int16(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt8(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt16(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + + got => { + return Err(DataFusionError::NotImplemented(format!( + "Support for 'APPROX_PERCENTILE_CONT' for data type {} is not implemented", + got + ))) + } + } + } +} + +/// Centroid implementation to the cluster mentioned in the paper. +#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) struct Centroid { + mean: OrderedFloat, + weight: OrderedFloat, +} + +impl PartialOrd for Centroid { + fn partial_cmp(&self, other: &Centroid) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Centroid { + fn cmp(&self, other: &Centroid) -> Ordering { + self.mean.cmp(&other.mean) + } +} + +impl Centroid { + pub(crate) fn new( + mean: impl Into>, + weight: impl Into>, + ) -> Self { + Centroid { + mean: mean.into(), + weight: weight.into(), + } + } + + #[inline] + pub(crate) fn mean(&self) -> OrderedFloat { + self.mean + } + + #[inline] + pub(crate) fn weight(&self) -> OrderedFloat { + self.weight + } + + pub(crate) fn add( + &mut self, + sum: impl Into>, + weight: impl Into>, + ) -> f64 { + let new_sum = sum.into() + self.weight * self.mean; + let new_weight = self.weight + weight.into(); + self.weight = new_weight; + self.mean = new_sum / new_weight; + new_sum.into_inner() + } +} + +impl Default for Centroid { + fn default() -> Self { + Centroid { + mean: OrderedFloat::from(0.0), + weight: OrderedFloat::from(1.0), + } + } +} + +/// T-Digest to be operated on. +#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) struct TDigest { + centroids: Vec, + max_size: usize, + sum: OrderedFloat, + count: OrderedFloat, + max: OrderedFloat, + min: OrderedFloat, +} + +impl TDigest { + pub(crate) fn new(max_size: usize) -> Self { + TDigest { + centroids: Vec::new(), + max_size, + sum: OrderedFloat::from(0.0), + count: OrderedFloat::from(0.0), + max: OrderedFloat::from(std::f64::NAN), + min: OrderedFloat::from(std::f64::NAN), + } + } + + #[inline] + pub(crate) fn count(&self) -> f64 { + self.count.into_inner() + } + + #[inline] + pub(crate) fn max(&self) -> f64 { + self.max.into_inner() + } + + #[inline] + pub(crate) fn min(&self) -> f64 { + self.min.into_inner() + } + + #[inline] + pub(crate) fn max_size(&self) -> usize { + self.max_size + } +} + +impl Default for TDigest { + fn default() -> Self { + TDigest { + centroids: Vec::new(), + max_size: 100, + sum: OrderedFloat::from(0.0), + count: OrderedFloat::from(0.0), + max: OrderedFloat::from(std::f64::NAN), + min: OrderedFloat::from(std::f64::NAN), + } + } +} + +impl TDigest { + fn k_to_q(k: f64, d: f64) -> OrderedFloat { + let k_div_d = k / d; + if k_div_d >= 0.5 { + let base = 1.0 - k_div_d; + 1.0 - 2.0 * base * base + } else { + 2.0 * k_div_d * k_div_d + } + .into() + } + + fn clamp( + v: OrderedFloat, + lo: OrderedFloat, + hi: OrderedFloat, + ) -> OrderedFloat { + if v > hi { + hi + } else if v < lo { + lo + } else { + v + } + } + + pub(crate) fn merge_unsorted( + &self, + unsorted_values: impl IntoIterator, + ) -> Result { + let mut values = unsorted_values + .into_iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?; + + values.sort(); + + Ok(self.merge_sorted_f64(&values)) + } + + fn merge_sorted_f64(&self, sorted_values: &[OrderedFloat]) -> TDigest { + #[cfg(debug_assertions)] + debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest"); + + if sorted_values.is_empty() { + return self.clone(); + } + + let mut result = TDigest::new(self.max_size()); + result.count = OrderedFloat::from(self.count() + (sorted_values.len() as f64)); + + let maybe_min = *sorted_values.first().unwrap(); + let maybe_max = *sorted_values.last().unwrap(); + + if self.count() > 0.0 { + result.min = std::cmp::min(self.min, maybe_min); + result.max = std::cmp::max(self.max, maybe_max); + } else { + result.min = maybe_min; + result.max = maybe_max; + } + + let mut compressed: Vec = Vec::with_capacity(self.max_size); + + let mut k_limit: f64 = 1.0; + let mut q_limit_times_count = + Self::k_to_q(k_limit, self.max_size as f64) * result.count(); + k_limit += 1.0; + + let mut iter_centroids = self.centroids.iter().peekable(); + let mut iter_sorted_values = sorted_values.iter().peekable(); + + let mut curr: Centroid = if let Some(c) = iter_centroids.peek() { + let curr = **iter_sorted_values.peek().unwrap(); + if c.mean() < curr { + iter_centroids.next().unwrap().clone() + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + } + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + }; + + let mut weight_so_far = curr.weight(); + + let mut sums_to_merge = OrderedFloat::from(0.0); + let mut weights_to_merge = OrderedFloat::from(0.0); + + while iter_centroids.peek().is_some() || iter_sorted_values.peek().is_some() { + let next: Centroid = if let Some(c) = iter_centroids.peek() { + if iter_sorted_values.peek().is_none() + || c.mean() < **iter_sorted_values.peek().unwrap() + { + iter_centroids.next().unwrap().clone() + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + } + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + }; + + let next_sum = next.mean() * next.weight(); + weight_so_far += next.weight(); + + if weight_so_far <= q_limit_times_count { + sums_to_merge += next_sum; + weights_to_merge += next.weight(); + } else { + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + sums_to_merge = 0.0.into(); + weights_to_merge = 0.0.into(); + + compressed.push(curr.clone()); + q_limit_times_count = + Self::k_to_q(k_limit, self.max_size as f64) * result.count(); + k_limit += 1.0; + curr = next; + } + } + + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + compressed.push(curr); + compressed.shrink_to_fit(); + compressed.sort(); + + result.centroids = compressed; + result + } + + fn external_merge( + centroids: &mut Vec, + first: usize, + middle: usize, + last: usize, + ) { + let mut result: Vec = Vec::with_capacity(centroids.len()); + + let mut i = first; + let mut j = middle; + + while i < middle && j < last { + match centroids[i].cmp(¢roids[j]) { + Ordering::Less => { + result.push(centroids[i].clone()); + i += 1; + } + Ordering::Greater => { + result.push(centroids[j].clone()); + j += 1; + } + Ordering::Equal => { + result.push(centroids[i].clone()); + i += 1; + } + } + } + + while i < middle { + result.push(centroids[i].clone()); + i += 1; + } + + while j < last { + result.push(centroids[j].clone()); + j += 1; + } + + i = first; + for centroid in result.into_iter() { + centroids[i] = centroid; + i += 1; + } + } + + // Merge multiple T-Digests + pub(crate) fn merge_digests(digests: &[TDigest]) -> TDigest { + let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum(); + if n_centroids == 0 { + return TDigest::default(); + } + + let max_size = digests.first().unwrap().max_size; + let mut centroids: Vec = Vec::with_capacity(n_centroids); + let mut starts: Vec = Vec::with_capacity(digests.len()); + + let mut count: f64 = 0.0; + let mut min = OrderedFloat::from(std::f64::INFINITY); + let mut max = OrderedFloat::from(std::f64::NEG_INFINITY); + + let mut start: usize = 0; + for digest in digests.iter() { + starts.push(start); + + let curr_count: f64 = digest.count(); + if curr_count > 0.0 { + min = std::cmp::min(min, digest.min); + max = std::cmp::max(max, digest.max); + count += curr_count; + for centroid in &digest.centroids { + centroids.push(centroid.clone()); + start += 1; + } + } + } + + let mut digests_per_block: usize = 1; + while digests_per_block < starts.len() { + for i in (0..starts.len()).step_by(digests_per_block * 2) { + if i + digests_per_block < starts.len() { + let first = starts[i]; + let middle = starts[i + digests_per_block]; + let last = if i + 2 * digests_per_block < starts.len() { + starts[i + 2 * digests_per_block] + } else { + centroids.len() + }; + + debug_assert!(first <= middle && middle <= last); + Self::external_merge(&mut centroids, first, middle, last); + } + } + + digests_per_block *= 2; + } + + let mut result = TDigest::new(max_size); + let mut compressed: Vec = Vec::with_capacity(max_size); + + let mut k_limit: f64 = 1.0; + let mut q_limit_times_count = + Self::k_to_q(k_limit, max_size as f64) * (count as f64); + + let mut iter_centroids = centroids.iter_mut(); + let mut curr = iter_centroids.next().unwrap(); + let mut weight_so_far = curr.weight(); + let mut sums_to_merge = OrderedFloat::from(0.0); + let mut weights_to_merge = OrderedFloat::from(0.0); + + for centroid in iter_centroids { + weight_so_far += centroid.weight(); + + if weight_so_far <= q_limit_times_count { + sums_to_merge += centroid.mean() * centroid.weight(); + weights_to_merge += centroid.weight(); + } else { + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + sums_to_merge = OrderedFloat::from(0.0); + weights_to_merge = OrderedFloat::from(0.0); + compressed.push(curr.clone()); + q_limit_times_count = + Self::k_to_q(k_limit, max_size as f64) * (count as f64); + k_limit += 1.0; + curr = centroid; + } + } + + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + compressed.push(curr.clone()); + compressed.shrink_to_fit(); + compressed.sort(); + + result.count = OrderedFloat::from(count as f64); + result.min = min; + result.max = max; + result.centroids = compressed; + result + } + + /// To estimate the value located at `q` quantile + pub(crate) fn estimate_quantile(&self, q: f64) -> f64 { + if self.centroids.is_empty() { + return 0.0; + } + + let count_ = self.count; + let rank = OrderedFloat::from(q) * count_; + + let mut pos: usize; + let mut t; + if q > 0.5 { + if q >= 1.0 { + return self.max(); + } + + pos = 0; + t = count_; + + for (k, centroid) in self.centroids.iter().enumerate().rev() { + t -= centroid.weight(); + + if rank >= t { + pos = k; + break; + } + } + } else { + if q <= 0.0 { + return self.min(); + } + + pos = self.centroids.len() - 1; + t = OrderedFloat::from(0.0); + + for (k, centroid) in self.centroids.iter().enumerate() { + if rank < t + centroid.weight() { + pos = k; + break; + } + + t += centroid.weight(); + } + } + + let mut delta = OrderedFloat::from(0.0); + let mut min = self.min; + let mut max = self.max; + + if self.centroids.len() > 1 { + if pos == 0 { + delta = self.centroids[pos + 1].mean() - self.centroids[pos].mean(); + max = self.centroids[pos + 1].mean(); + } else if pos == (self.centroids.len() - 1) { + delta = self.centroids[pos].mean() - self.centroids[pos - 1].mean(); + min = self.centroids[pos - 1].mean(); + } else { + delta = (self.centroids[pos + 1].mean() - self.centroids[pos - 1].mean()) + / 2.0; + min = self.centroids[pos - 1].mean(); + max = self.centroids[pos + 1].mean(); + } + } + + let value = self.centroids[pos].mean() + + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta; + Self::clamp(value, min, max).into_inner() + } + + /// This method decomposes the [`TDigest`] and its [`Centroid`] instances + /// into a series of primitive scalar values. + /// + /// First the values of the TDigest are packed, followed by the variable + /// number of centroids packed into a [`ScalarValue::List`] of + /// [`ScalarValue::Float64`]: + /// + /// ```text + /// + /// ┌────────┬────────┬────────┬───────┬────────┬────────┐ + /// │max_size│ sum │ count │ max │ min │centroid│ + /// └────────┴────────┴────────┴───────┴────────┴────────┘ + /// │ + /// ┌─────────────────────┘ + /// ▼ + /// ┌ List ───┐ + /// │┌ ─ ─ ─ ┐│ + /// │ mean │ + /// │├ ─ ─ ─ ┼│─ ─ Centroid 1 + /// │ weight │ + /// │└ ─ ─ ─ ┘│ + /// │ │ + /// │┌ ─ ─ ─ ┐│ + /// │ mean │ + /// │├ ─ ─ ─ ┼│─ ─ Centroid 2 + /// │ weight │ + /// │└ ─ ─ ─ ┘│ + /// │ │ + /// ... + /// + /// ``` + /// + /// The [`TDigest::from_scalar_state()`] method reverses this processes, + /// consuming the output of this method and returning an unpacked + /// [`TDigest`]. + pub(crate) fn to_scalar_state(&self) -> Vec { + // Gather up all the centroids + let centroids: Vec<_> = self + .centroids + .iter() + .flat_map(|c| [c.mean().into_inner(), c.weight().into_inner()]) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + + vec![ + ScalarValue::UInt64(Some(self.max_size as u64)), + ScalarValue::Float64(Some(self.sum.into_inner())), + ScalarValue::Float64(Some(self.count.into_inner())), + ScalarValue::Float64(Some(self.max.into_inner())), + ScalarValue::Float64(Some(self.min.into_inner())), + ScalarValue::List(Some(Box::new(centroids)), Box::new(DataType::Float64)), + ] + } + + /// Unpack the serialised state of a [`TDigest`] produced by + /// [`Self::to_scalar_state()`]. + /// + /// # Correctness + /// + /// Providing input to this method that was not obtained from + /// [`Self::to_scalar_state()`] results in undefined behaviour and may + /// panic. + pub(crate) fn from_scalar_state(state: &[ScalarValue]) -> Self { + assert_eq!(state.len(), 6, "invalid TDigest state"); + + let max_size = match &state[0] { + ScalarValue::UInt64(Some(v)) => *v as usize, + v => panic!("invalid max_size type {:?}", v), + }; + + let centroids: Vec<_> = match &state[5] { + ScalarValue::List(Some(c), d) if **d == DataType::Float64 => c + .chunks(2) + .map(|v| Centroid::new(cast_scalar_f64!(v[0]), cast_scalar_f64!(v[1]))) + .collect(), + v => panic!("invalid centroids type {:?}", v), + }; + + let max = cast_scalar_f64!(&state[3]); + let min = cast_scalar_f64!(&state[4]); + assert!(max >= min); + + Self { + max_size, + sum: cast_scalar_f64!(state[1]), + count: cast_scalar_f64!(&state[2]), + max, + min, + centroids, + } + } +} + +#[cfg(debug_assertions)] +fn is_sorted(values: &[OrderedFloat]) -> bool { + values.windows(2).all(|w| w[0] <= w[1]) +} + +#[cfg(test)] +mod tests { + use std::iter; + + use super::*; + + // A macro to assert the specified `quantile` estimated by `t` is within the + // allowable relative error bound. + macro_rules! assert_error_bounds { + ($t:ident, quantile = $quantile:literal, want = $want:literal) => { + assert_error_bounds!( + $t, + quantile = $quantile, + want = $want, + allowable_error = 0.01 + ) + }; + ($t:ident, quantile = $quantile:literal, want = $want:literal, allowable_error = $re:literal) => { + let ans = $t.estimate_quantile($quantile); + let expected: f64 = $want; + let percentage: f64 = (expected - ans).abs() / expected; + assert!( + percentage < $re, + "relative error {} is more than {}% (got quantile {}, want {})", + percentage, + $re, + ans, + expected + ); + }; + } + + macro_rules! assert_state_roundtrip { + ($t:ident) => { + let state = $t.to_scalar_state(); + let other = TDigest::from_scalar_state(&state); + assert_eq!($t, other); + }; + } + + #[test] + fn test_int64_uniform() { + let values = (1i64..=1000).map(|v| ScalarValue::Int64(Some(v))); + + let t = TDigest::new(100); + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 0.1, want = 100.0); + assert_error_bounds!(t, quantile = 0.5, want = 500.0); + assert_error_bounds!(t, quantile = 0.9, want = 900.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_int64_uniform_with_nulls() { + let values = (1i64..=1000).map(|v| ScalarValue::Int64(Some(v))); + // Prepend some NULLs + let values = iter::repeat(ScalarValue::Int64(None)) + .take(10) + .chain(values); + // Append some more NULLs + let values = values.chain(iter::repeat(ScalarValue::Int64(None)).take(10)); + + let t = TDigest::new(100); + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 0.1, want = 100.0); + assert_error_bounds!(t, quantile = 0.5, want = 500.0); + assert_error_bounds!(t, quantile = 0.9, want = 900.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_centroid_addition_regression() { + //https://github.com/MnO2/t-digest/pull/1 + + let vals = vec![1.0, 1.0, 1.0, 2.0, 1.0, 1.0]; + let mut t = TDigest::new(10); + + for v in vals { + t = t.merge_unsorted([ScalarValue::Float64(Some(v))]).unwrap(); + } + + assert_error_bounds!(t, quantile = 0.5, want = 1.0); + assert_error_bounds!(t, quantile = 0.95, want = 2.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_merge_unsorted_against_uniform_distro() { + let t = TDigest::new(100); + let values: Vec<_> = (1..=1_000_000) + .map(f64::from) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 1.0, want = 1_000_000.0); + assert_error_bounds!(t, quantile = 0.99, want = 990_000.0); + assert_error_bounds!(t, quantile = 0.01, want = 10_000.0); + assert_error_bounds!(t, quantile = 0.0, want = 1.0); + assert_error_bounds!(t, quantile = 0.5, want = 500_000.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_merge_unsorted_against_skewed_distro() { + let t = TDigest::new(100); + let mut values: Vec<_> = (1..=600_000) + .map(f64::from) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + for _ in 0..400_000 { + values.push(ScalarValue::Float64(Some(1_000_000.0))); + } + + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 0.99, want = 1_000_000.0); + assert_error_bounds!(t, quantile = 0.01, want = 10_000.0); + assert_error_bounds!(t, quantile = 0.5, want = 500_000.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_merge_digests() { + let mut digests: Vec = Vec::new(); + + for _ in 1..=100 { + let t = TDigest::new(100); + let values: Vec<_> = (1..=1_000) + .map(f64::from) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + let t = t.merge_unsorted(values).unwrap(); + digests.push(t) + } + + let t = TDigest::merge_digests(&digests); + + assert_error_bounds!(t, quantile = 1.0, want = 1000.0); + assert_error_bounds!(t, quantile = 0.99, want = 990.0); + assert_error_bounds!(t, quantile = 0.01, want = 10.0, allowable_error = 0.2); + assert_error_bounds!(t, quantile = 0.0, want = 1.0); + assert_error_bounds!(t, quantile = 0.5, want = 500.0); + assert_state_roundtrip!(t); + } +} diff --git a/datafusion/src/physical_plan/unicode_expressions.rs b/datafusion/src/physical_plan/unicode_expressions.rs index c55eb7e0e4df..7e9842247524 100644 --- a/datafusion/src/physical_plan/unicode_expressions.rs +++ b/datafusion/src/physical_plan/unicode_expressions.rs @@ -447,21 +447,29 @@ pub fn substr(args: &[ArrayRef]) -> Result { start, count ))) - } else if start <= 0 { - Ok(Some(string.to_string())) } else { let graphemes = string.graphemes(true).collect::>(); - let start_pos = start as usize - 1; - let count_usize = count as usize; - if graphemes.len() < start_pos { + let (start_pos, end_pos) = if start <= 0 { + let end_pos = start + count - 1; + ( + 0_usize, + if end_pos < 0 { + // we use 0 as workaround for usize to return empty string + 0 + } else { + end_pos as usize + }, + ) + } else { + ((start - 1) as usize, (start + count - 1) as usize) + }; + + if end_pos == 0 || graphemes.len() < start_pos { Ok(Some("".to_string())) - } else if graphemes.len() < start_pos + count_usize { + } else if graphemes.len() < end_pos { Ok(Some(graphemes[start_pos..].concat())) } else { - Ok(Some( - graphemes[start_pos..start_pos + count_usize] - .concat(), - )) + Ok(Some(graphemes[start_pos..end_pos].concat())) } } } diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs index abc75829ea17..0aff006c7896 100644 --- a/datafusion/src/prelude.rs +++ b/datafusion/src/prelude.rs @@ -30,10 +30,10 @@ pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::execution::options::AvroReadOptions; pub use crate::execution::options::{CsvReadOptions, NdJsonReadOptions}; pub use crate::logical_plan::{ - array, ascii, avg, bit_length, btrim, character_length, chr, col, concat, concat_ws, - count, create_udf, date_part, date_trunc, digest, in_list, initcap, left, length, - lit, lower, lpad, ltrim, max, md5, min, now, octet_length, random, regexp_match, - regexp_replace, repeat, replace, reverse, right, rpad, rtrim, sha224, sha256, sha384, - sha512, split_part, starts_with, strpos, substr, sum, to_hex, translate, trim, upper, - Column, JoinType, Partitioning, + approx_percentile_cont, array, ascii, avg, bit_length, btrim, character_length, chr, + col, concat, concat_ws, count, create_udf, date_part, date_trunc, digest, in_list, + initcap, left, length, lit, lower, lpad, ltrim, max, md5, min, now, octet_length, + random, regexp_match, regexp_replace, repeat, replace, reverse, right, rpad, rtrim, + sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr, sum, to_hex, + translate, trim, upper, Column, JoinType, Partitioning, }; diff --git a/datafusion/src/pyarrow.rs b/datafusion/src/pyarrow.rs index d06e37f9e770..88ab2e4dade5 100644 --- a/datafusion/src/pyarrow.rs +++ b/datafusion/src/pyarrow.rs @@ -21,7 +21,6 @@ use pyo3::exceptions::{PyException, PyNotImplementedError}; use pyo3::ffi::Py_uintptr_t; use pyo3::prelude::*; use pyo3::types::PyList; -use pyo3::PyNativeType; use std::sync::Arc; use crate::error::DataFusionError; diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 0c0afc2f5b3e..847a9ddd65fd 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -2631,7 +2631,7 @@ mod tests { // Convert to length-2 array let array = scalar.to_array_of_size(2); let expected_vals = vec![ - (field_a.clone(), Int32Vec::from_slice(vec![23, 23]).as_arc()), + (field_a.clone(), Int32Vec::from_slice(&[23, 23]).as_arc()), ( field_b.clone(), Arc::new(BooleanArray::from_slice(&vec![false, false])) as ArrayRef, @@ -2645,8 +2645,8 @@ mod tests { Arc::new(StructArray::from_data( DataType::Struct(vec![field_e.clone(), field_f.clone()]), vec![ - Int16Vec::from_slice(vec![2, 2]).as_arc(), - Int64Vec::from_slice(vec![3, 3]).as_arc(), + Int16Vec::from_slice(&[2, 2]).as_arc(), + Int64Vec::from_slice(&[3, 3]).as_arc(), ], None, )) as ArrayRef, @@ -2722,7 +2722,7 @@ mod tests { let array: ArrayRef = ScalarValue::iter_to_array(scalars).unwrap(); let expected = Arc::new(struct_array_from(vec![ - (field_a, Int32Vec::from_slice(vec![23, 7, -1000]).as_arc()), + (field_a, Int32Vec::from_slice(&[23, 7, -1000]).as_arc()), ( field_b, Arc::new(BooleanArray::from_slice(&vec![false, true, true])) as ArrayRef, @@ -2737,8 +2737,8 @@ mod tests { Arc::new(StructArray::from_data( DataType::Struct(vec![field_e, field_f]), vec![ - Int16Vec::from_slice(vec![2, 4, 6]).as_arc(), - Int64Vec::from_slice(vec![3, 5, 7]).as_arc(), + Int16Vec::from_slice(&[2, 4, 6]).as_arc(), + Int64Vec::from_slice(&[3, 5, 7]).as_arc(), ], None, )) as ArrayRef, diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 4d99a056cc10..e07b41ff764c 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -700,14 +700,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - /// Generate a logic plan from an SQL select - fn select_to_plan( + /// Generate a logic plan from selection clause, the function contain optimization for cross join to inner join + /// Related PR: https://github.com/apache/arrow-datafusion/pull/1566 + fn plan_selection( &self, select: &Select, - ctes: &mut HashMap, - alias: Option, + plans: Vec, ) -> Result { - let plans = self.plan_from_tables(&select.from, ctes)?; let plan = match &select.selection { Some(predicate_expr) => { // build join schema @@ -825,9 +824,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } }; - let plan = plan?; + plan + } - // The SELECT expressions, with wildcards expanded. + /// Generate a logic plan from an SQL select + fn select_to_plan( + &self, + select: &Select, + ctes: &mut HashMap, + alias: Option, + ) -> Result { + // process `from` clause + let plans = self.plan_from_tables(&select.from, ctes)?; + + // process `where` clause + let plan = self.plan_selection(select, plans)?; + + // process the SELECT expressions, with wildcards expanded. let select_exprs = self.prepare_select_exprs(&plan, select)?; // having and group by clause may reference aliases defined in select projection @@ -876,6 +889,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // All of the aggregate expressions (deduplicated). let aggr_exprs = find_aggregate_exprs(&aggr_expr_haystack); + // All of the group by expressions let group_by_exprs = select .group_by .iter() @@ -894,6 +908,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect::>>()?; + // process group by, aggregation or having let (plan, select_exprs_post_aggr, having_expr_post_aggr_opt) = if !group_by_exprs .is_empty() || !aggr_exprs.is_empty() @@ -934,7 +949,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - // window function + // process window function let window_func_exprs = find_window_exprs(&select_exprs_post_aggr); let plan = if window_func_exprs.is_empty() { @@ -943,6 +958,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { LogicalPlanBuilder::window_plan(plan, window_func_exprs)? }; + // process distinct clause let plan = if select.distinct { return LogicalPlanBuilder::from(plan) .aggregate(select_exprs_post_aggr, iter::empty::())? @@ -950,6 +966,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { plan }; + + // generate the final projection plan project_with_alias(plan, select_exprs_post_aggr, alias) } @@ -1261,6 +1279,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch), BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch), BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch), + BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported SQL binary operator {:?}", op @@ -3292,7 +3311,7 @@ mod tests { JOIN orders \ ON id = customer_id"; let expected = "Projection: #person.id, #orders.order_id\ - \n Join: #person.id = #orders.customer_id\ + \n Inner Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -3306,7 +3325,7 @@ mod tests { ON id = customer_id AND order_id > 1 "; let expected = "Projection: #person.id, #orders.order_id\ \n Filter: #orders.order_id > Int64(1)\ - \n Join: #person.id = #orders.customer_id\ + \n Inner Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -3319,7 +3338,7 @@ mod tests { LEFT JOIN orders \ ON id = customer_id AND order_id > 1"; let expected = "Projection: #person.id, #orders.order_id\ - \n Join: #person.id = #orders.customer_id\ + \n Left Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n Filter: #orders.order_id > Int64(1)\ \n TableScan: orders projection=None"; @@ -3333,7 +3352,7 @@ mod tests { RIGHT JOIN orders \ ON id = customer_id AND id > 1"; let expected = "Projection: #person.id, #orders.order_id\ - \n Join: #person.id = #orders.customer_id\ + \n Right Join: #person.id = #orders.customer_id\ \n Filter: #person.id > Int64(1)\ \n TableScan: person projection=None\ \n TableScan: orders projection=None"; @@ -3347,7 +3366,7 @@ mod tests { JOIN orders \ ON person.id = orders.customer_id"; let expected = "Projection: #person.id, #orders.order_id\ - \n Join: #person.id = #orders.customer_id\ + \n Inner Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -3360,7 +3379,7 @@ mod tests { JOIN person as person2 \ USING (id)"; let expected = "Projection: #person.first_name, #person.id\ - \n Join: Using #person.id = #person2.id\ + \n Inner Join: Using #person.id = #person2.id\ \n TableScan: person projection=None\ \n TableScan: person2 projection=None"; quick_test(sql, expected); @@ -3373,7 +3392,7 @@ mod tests { JOIN lineitem as lineitem2 \ USING (l_item_id)"; let expected = "Projection: #lineitem.l_item_id, #lineitem.l_description, #lineitem.price, #lineitem2.l_description, #lineitem2.price\ - \n Join: Using #lineitem.l_item_id = #lineitem2.l_item_id\ + \n Inner Join: Using #lineitem.l_item_id = #lineitem2.l_item_id\ \n TableScan: lineitem projection=None\ \n TableScan: lineitem2 projection=None"; quick_test(sql, expected); @@ -3387,8 +3406,8 @@ mod tests { JOIN lineitem ON o_item_id = l_item_id"; let expected = "Projection: #person.id, #orders.order_id, #lineitem.l_description\ - \n Join: #orders.o_item_id = #lineitem.l_item_id\ - \n Join: #person.id = #orders.customer_id\ + \n Inner Join: #orders.o_item_id = #lineitem.l_item_id\ + \n Inner Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None\ \n TableScan: lineitem projection=None"; @@ -3921,8 +3940,8 @@ mod tests { fn cross_join_to_inner_join() { let sql = "select person.id from person, orders, lineitem where person.id = lineitem.l_item_id and orders.o_item_id = lineitem.l_description;"; let expected = "Projection: #person.id\ - \n Join: #lineitem.l_description = #orders.o_item_id\ - \n Join: #person.id = #lineitem.l_item_id\ + \n Inner Join: #lineitem.l_description = #orders.o_item_id\ + \n Inner Join: #person.id = #lineitem.l_item_id\ \n TableScan: person projection=None\ \n TableScan: lineitem projection=None\ \n TableScan: orders projection=None"; diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs index 6e0e44e5e147..d845c01d63d1 100644 --- a/datafusion/src/test/mod.rs +++ b/datafusion/src/test/mod.rs @@ -17,6 +17,7 @@ //! Common unit test utility methods +use crate::arrow::array::UInt32Array; use crate::datasource::object_store::local::local_unpartitioned_file; use crate::datasource::{MemTable, PartitionedFile, TableProvider}; use crate::error::Result; @@ -183,14 +184,6 @@ pub fn make_partition(sz: i32) -> RecordBatch { RecordBatch::try_new(schema, vec![arr]).unwrap() } -/// Return a new table provider containing all of the supported timestamp types -pub fn table_with_timestamps() -> Arc { - let batch = make_timestamps(); - let schema = batch.schema().clone(); - let partitions = vec![vec![batch]]; - Arc::new(MemTable::try_new(schema, partitions).unwrap()) -} - /// Return a new table which provide this decimal column pub fn table_with_decimal() -> Arc { let batch_decimal = make_decimal(); @@ -212,87 +205,6 @@ fn make_decimal() -> RecordBatch { RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap() } -/// Return record batch with all of the supported timestamp types -/// values -/// -/// Columns are named: -/// "nanos" --> TimestampNanosecondArray -/// "micros" --> TimestampMicrosecondArray -/// "millis" --> TimestampMillisecondArray -/// "secs" --> TimestampSecondArray -/// "names" --> StringArray -pub fn make_timestamps() -> RecordBatch { - let ts_strings = vec![ - Some("2018-11-13T17:11:10.011375885995"), - Some("2011-12-13T11:13:10.12345"), - None, - Some("2021-1-1T05:11:10.432"), - ]; - - let ts_nanos = ts_strings - .into_iter() - .map(|t| { - t.map(|t| { - t.parse::() - .unwrap() - .timestamp_nanos() - }) - }) - .collect::>(); - - let ts_micros = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000)) - .collect::>(); - - let ts_millis = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000)) - .collect::>(); - - let ts_secs = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000000)) - .collect::>(); - - let names = ts_nanos - .iter() - .enumerate() - .map(|(i, _)| format!("Row {}", i)); - - let arr_names = Utf8Array::::from_trusted_len_values_iter(names); - - let arr_nanos = - Int64Array::from(ts_nanos).to(DataType::Timestamp(TimeUnit::Nanosecond, None)); - let arr_micros = - Int64Array::from(ts_micros).to(DataType::Timestamp(TimeUnit::Microsecond, None)); - let arr_millis = - Int64Array::from(ts_millis).to(DataType::Timestamp(TimeUnit::Millisecond, None)); - let arr_secs = - Int64Array::from(ts_secs).to(DataType::Timestamp(TimeUnit::Second, None)); - - let schema = Schema::new(vec![ - Field::new("nanos", arr_nanos.data_type().clone(), true), - Field::new("micros", arr_micros.data_type().clone(), true), - Field::new("millis", arr_millis.data_type().clone(), true), - Field::new("secs", arr_secs.data_type().clone(), true), - Field::new("name", arr_names.data_type().clone(), true), - ]); - let schema = Arc::new(schema); - - RecordBatch::try_new( - schema, - vec![ - Arc::new(arr_nanos), - Arc::new(arr_micros), - Arc::new(arr_millis), - Arc::new(arr_secs), - Arc::new(arr_names), - ], - ) - .unwrap() -} - /// Asserts that given future is pending. pub fn assert_is_pending<'a, T>(fut: &mut Pin + Send + 'a>>) { let waker = futures::task::noop_waker(); @@ -302,6 +214,25 @@ pub fn assert_is_pending<'a, T>(fut: &mut Pin + Send assert!(poll.is_pending()); } +/// Create vector batches +pub fn create_vec_batches(schema: &Arc, n: usize) -> Vec { + let batch = create_batch(schema); + let mut vec = Vec::with_capacity(n); + for _ in 0..n { + vec.push(batch.clone()); + } + vec +} + +/// Create batch +fn create_batch(schema: &Arc) -> RecordBatch { + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(UInt32Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]))], + ) + .unwrap() +} + pub mod exec; pub mod object_store; pub mod user_defined; diff --git a/datafusion/src/test_util.rs b/datafusion/src/test_util.rs index 429539ec1f53..edb0f60c6d53 100644 --- a/datafusion/src/test_util.rs +++ b/datafusion/src/test_util.rs @@ -20,8 +20,8 @@ use std::collections::BTreeMap; use std::{env, error::Error, path::PathBuf, sync::Arc}; -use crate::field_util::SchemaExt; -use arrow::datatypes::{DataType, Field, Schema}; +use crate::field_util::{FieldExt, SchemaExt}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; /// Compares formatted output of a record batch with an expected /// vector of strings, with the result of pretty formatting record @@ -254,6 +254,32 @@ pub fn aggr_test_schema() -> Arc { Arc::new(schema) } +/// Get the schema for the aggregate_test_* csv files with an additional filed not present in the files. +pub fn aggr_test_schema_with_missing_col() -> SchemaRef { + let mut f1 = Field::new("c1", DataType::Utf8, false); + f1.set_metadata(Some(BTreeMap::from_iter( + vec![("testing".into(), "test".into())].into_iter(), + ))); + let schema = Schema::new(vec![ + f1, + Field::new("c2", DataType::UInt32, false), + Field::new("c3", DataType::Int8, false), + Field::new("c4", DataType::Int16, false), + Field::new("c5", DataType::Int32, false), + Field::new("c6", DataType::Int64, false), + Field::new("c7", DataType::UInt8, false), + Field::new("c8", DataType::UInt16, false), + Field::new("c9", DataType::UInt32, false), + Field::new("c10", DataType::UInt64, false), + Field::new("c11", DataType::Float32, false), + Field::new("c12", DataType::Float64, false), + Field::new("c13", DataType::Utf8, false), + Field::new("missing_col", DataType::Int64, true), + ]); + + Arc::new(schema) +} + #[cfg(test)] mod tests { use super::*; @@ -308,3 +334,40 @@ mod tests { assert!(PathBuf::from(res).is_dir()); } } + +#[cfg(test)] +pub fn create_decimal_array( + array: &[Option], + precision: usize, + scale: usize, +) -> crate::error::Result { + use arrow::array::{Int128Vec, TryPush}; + let mut decimal_builder = Int128Vec::from_data( + DataType::Decimal(precision, scale), + Vec::::with_capacity(array.len()), + None, + ); + + for value in array { + match value { + None => { + decimal_builder.push(None); + } + Some(v) => { + decimal_builder.try_push(Some(*v))?; + } + } + } + Ok(decimal_builder.into()) +} + +#[cfg(test)] +pub fn create_decimal_array_from_slice( + array: &[i128], + precision: usize, + scale: usize, +) -> crate::error::Result { + let decimal_array_values: Vec> = + array.into_iter().map(|v| Some(*v)).collect(); + create_decimal_array(&decimal_array_values, precision, scale) +} diff --git a/datafusion/tests/custom_sources.rs b/datafusion/tests/custom_sources.rs index d9a73c98a035..aadd0024892e 100644 --- a/datafusion/tests/custom_sources.rs +++ b/datafusion/tests/custom_sources.rs @@ -242,7 +242,7 @@ async fn custom_source_dataframe() -> Result<()> { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("c2", physical_plan.schema().field(0).name()); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let batches = collect(physical_plan, runtime).await?; let origin_rec_batch = TEST_CUSTOM_RECORD_BATCH!()?; assert_eq!(1, batches.len()); @@ -289,7 +289,7 @@ async fn optimizers_catch_all_statistics() { ) .unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let actual = collect(physical_plan, runtime).await.unwrap(); assert_eq!(actual.len(), 1); diff --git a/datafusion/tests/dataframe_functions.rs b/datafusion/tests/dataframe_functions.rs index 2437140197a1..9a502742ad89 100644 --- a/datafusion/tests/dataframe_functions.rs +++ b/datafusion/tests/dataframe_functions.rs @@ -44,13 +44,13 @@ fn create_test_table() -> Result> { let batch = RecordBatch::try_new( schema.clone(), vec![ - Arc::new(Utf8Array::::from_slice(vec![ + Arc::new(Utf8Array::::from_slice(&[ "abcDEF", "abc123", "CBAdef", "123AbcDef", ])), - Arc::new(Int32Array::from_slice(vec![1, 10, 10, 100])), + Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])), ], )?; @@ -152,6 +152,26 @@ async fn test_fn_btrim_with_chars() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_fn_approx_percentile_cont() -> Result<()> { + let expr = approx_percentile_cont(col("b"), lit(0.5)); + + let expected = vec![ + "+-------------------------------------------+", + "| APPROXPERCENTILECONT(test.b,Float64(0.5)) |", + "+-------------------------------------------+", + "| 10 |", + "+-------------------------------------------+", + ]; + + let df = create_test_table()?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_batches_eq!(expected, &batches); + + Ok(()) +} + #[tokio::test] async fn test_fn_character_length() -> Result<()> { let expr = character_length(col("a")); diff --git a/datafusion/tests/merge_fuzz.rs b/datafusion/tests/merge_fuzz.rs index cf8e66dbb116..f759c8d0f0fd 100644 --- a/datafusion/tests/merge_fuzz.rs +++ b/datafusion/tests/merge_fuzz.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Fuzz Test for various corner cases merging streams of RecordBatchs +//! Fuzz Test for various corner cases merging streams of RecordBatches use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array}; @@ -30,6 +30,7 @@ use datafusion::{ sorts::sort_preserving_merge::SortPreservingMergeExec, }, }; +use fuzz_utils::{add_empty_batches, batches_to_vec, partitions_to_sorted_vec}; use rand::{prelude::StdRng, Rng, SeedableRng}; #[tokio::test] @@ -145,36 +146,6 @@ async fn run_merge_test(input: Vec>) { } } -/// Extracts the i32 values from the set of batches and returns them as a single Vec -fn batches_to_vec(batches: &[RecordBatch]) -> Vec> { - batches - .iter() - .map(|batch| { - assert_eq!(batch.num_columns(), 1); - batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap() - .iter() - }) - .flatten() - .map(|v| v.copied()) - .collect() -} - -// extract values from batches and sort them -fn partitions_to_sorted_vec(partitions: &[Vec]) -> Vec> { - let mut values: Vec<_> = partitions - .iter() - .map(|batches| batches_to_vec(batches).into_iter()) - .flatten() - .collect(); - - values.sort_unstable(); - values -} - /// Return the values `low..high` in order, in randomly sized /// record batches in a field named 'x' of type `Int32` fn make_staggered_batches(low: i32, high: i32, seed: u64) -> Vec { @@ -198,24 +169,6 @@ fn make_staggered_batches(low: i32, high: i32, seed: u64) -> Vec { add_empty_batches(batches, &mut rng) } -/// Adds a random number of empty record batches into the stream -fn add_empty_batches(batches: Vec, rng: &mut StdRng) -> Vec { - let schema = batches[0].schema().clone(); - - batches - .into_iter() - .map(|batch| { - // insert 0, or 1 empty batches before and after the current batch - let empty_batch = RecordBatch::new_empty(schema.clone()); - std::iter::repeat(empty_batch.clone()) - .take(rng.gen_range(0..2)) - .chain(std::iter::once(batch)) - .chain(std::iter::repeat(empty_batch).take(rng.gen_range(0..2))) - }) - .flatten() - .collect() -} - fn concat(mut v1: Vec, v2: Vec) -> Vec { v1.extend(v2); v1 diff --git a/datafusion/tests/order_spill_fuzz.rs b/datafusion/tests/order_spill_fuzz.rs new file mode 100644 index 000000000000..9fd38f1e5b4a --- /dev/null +++ b/datafusion/tests/order_spill_fuzz.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fuzz Test for various corner cases sorting RecordBatches exceeds available memory and should spill + +use arrow::array::{ArrayRef, Int32Array}; +use arrow::compute::sort::SortOptions; +use datafusion::execution::memory_manager::MemoryManagerConfig; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::physical_plan::expressions::{col, PhysicalSortExpr}; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::record_batch::RecordBatch; +use fuzz_utils::{add_empty_batches, batches_to_vec, partitions_to_sorted_vec}; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; + +#[tokio::test] +async fn test_sort_1k_mem() { + run_sort(1024, vec![(5, false), (2000, true), (1000000, true)]).await +} + +#[tokio::test] +async fn test_sort_100k_mem() { + run_sort(102400, vec![(5, false), (2000, false), (1000000, true)]).await +} + +#[tokio::test] +async fn test_sort_unlimited_mem() { + run_sort( + usize::MAX, + vec![(5, false), (2000, false), (1000000, false)], + ) + .await +} + +/// Sort the input using SortExec and ensure the results are correct according to `Vec::sort` +async fn run_sort(pool_size: usize, size_spill: Vec<(usize, bool)>) { + for (size, spill) in size_spill { + let input = vec![make_staggered_batches(size)]; + let first_batch = input + .iter() + .map(|p| p.iter()) + .flatten() + .next() + .expect("at least one batch"); + let schema = first_batch.schema().clone(); + + let sort = vec![PhysicalSortExpr { + expr: col("x", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + + let exec = MemoryExec::try_new(&input, schema, None).unwrap(); + let sort = Arc::new(SortExec::try_new(sort, Arc::new(exec)).unwrap()); + + let runtime_config = RuntimeConfig::new().with_memory_manager( + MemoryManagerConfig::try_new_limit(pool_size, 1.0).unwrap(), + ); + let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); + let collected = collect(sort.clone(), runtime).await.unwrap(); + + let expected = partitions_to_sorted_vec(&input); + let actual = batches_to_vec(&collected); + + if spill { + assert_ne!(sort.metrics().unwrap().spill_count().unwrap(), 0); + } else { + assert_eq!(sort.metrics().unwrap().spill_count().unwrap(), 0); + } + + assert_eq!(expected, actual, "failure in @ pool_size {}", pool_size); + } +} + +/// Return randomly sized record batches in a field named 'x' of type `Int32` +/// with randomized i32 content +fn make_staggered_batches(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut input: Vec = vec![0; len]; + rng.fill(&mut input[..]); + let input = Int32Array::from_values(input.into_iter()); + + // split into several record batches + let mut remainder = + RecordBatch::try_from_iter(vec![("x", Arc::new(input) as ArrayRef)]).unwrap(); + + let mut batches = vec![]; + + // use a random number generator to pick a random sized output + let mut rng = StdRng::seed_from_u64(42); + while remainder.num_rows() > 0 { + let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + + batches.push(remainder.slice(0, batch_size)); + remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); + } + + add_empty_batches(batches, &mut rng) +} diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs index abba09671cc9..c41095d0e9eb 100644 --- a/datafusion/tests/parquet_pruning.rs +++ b/datafusion/tests/parquet_pruning.rs @@ -540,7 +540,7 @@ impl ContextWithParquet { .await .expect("creating physical plan"); - let runtime = self.ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = self.ctx.state.lock().runtime_env.clone(); let results = datafusion::physical_plan::collect(physical_plan.clone(), runtime) .await .expect("Running"); diff --git a/datafusion/tests/provider_filter_pushdown.rs b/datafusion/tests/provider_filter_pushdown.rs index c3da1f3544ea..5e13dbfa538e 100644 --- a/datafusion/tests/provider_filter_pushdown.rs +++ b/datafusion/tests/provider_filter_pushdown.rs @@ -25,7 +25,7 @@ use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::field_util::SchemaExt; use datafusion::logical_plan::Expr; use datafusion::physical_plan::common::SizedRecordBatchStream; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics}; use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; @@ -85,11 +85,11 @@ impl ExecutionPlan for CustomPlan { _runtime: Arc, ) -> Result { let metrics = ExecutionPlanMetricsSet::new(); - let baseline_metrics = BaselineMetrics::new(&metrics, partition); + let tracking_metrics = MemTrackingMetrics::new(&metrics, partition); Ok(Box::pin(SizedRecordBatchStream::new( self.schema(), self.batches.clone(), - baseline_metrics, + tracking_metrics, ))) } diff --git a/datafusion/tests/simplification.rs b/datafusion/tests/simplification.rs new file mode 100644 index 000000000000..7bd62401d4fa --- /dev/null +++ b/datafusion/tests/simplification.rs @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This program demonstrates the DataFusion expression simplification API. + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::field_util::SchemaExt; +use datafusion::{ + error::Result, + execution::context::ExecutionProps, + logical_plan::{DFSchema, Expr, SimplifyInfo}, + prelude::*, +}; + +/// In order to simplify expressions, DataFusion must have information +/// about the expressions. +/// +/// You can provide that information using DataFusion [DFSchema] +/// objects or from some other implemention +struct MyInfo { + /// The input schema + schema: DFSchema, + + /// Execution specific details needed for constant evaluation such + /// as the current time for `now()` and [VariableProviders] + execution_props: ExecutionProps, +} + +impl SimplifyInfo for MyInfo { + fn is_boolean_type(&self, expr: &Expr) -> Result { + Ok(matches!(expr.get_type(&self.schema)?, DataType::Boolean)) + } + + fn nullable(&self, expr: &Expr) -> Result { + expr.nullable(&self.schema) + } + + fn execution_props(&self) -> &ExecutionProps { + &self.execution_props + } +} + +impl From for MyInfo { + fn from(schema: DFSchema) -> Self { + Self { + schema, + execution_props: ExecutionProps::new(), + } + } +} + +/// A schema like: +/// +/// a: Int32 (possibly with nulls) +/// b: Int32 +/// s: Utf8 +fn schema() -> DFSchema { + Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, false), + Field::new("s", DataType::Utf8, false), + ]) + .try_into() + .unwrap() +} + +#[test] +fn basic() { + let info: MyInfo = schema().into(); + + // The `Expr` is a core concept in DataFusion, and DataFusion can + // help simplify it. + + // For example 'a < (2 + 3)' can be rewritten into the easier to + // optimize form `a < 5` automatically + let expr = col("a").lt(lit(2i32) + lit(3i32)); + + let simplified = expr.simplify(&info).unwrap(); + assert_eq!(simplified, col("a").lt(lit(5i32))); +} + +#[test] +fn fold_and_simplify() { + let info: MyInfo = schema().into(); + + // What will it do with the expression `concat('foo', 'bar') == 'foobar')`? + let expr = concat(&[lit("foo"), lit("bar")]).eq(lit("foobar")); + + // Since datafusion applies both simplification *and* rewriting + // some expressions can be entirely simplified + let simplified = expr.simplify(&info).unwrap(); + assert_eq!(simplified, lit(true)) +} diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 9d72752b091d..fd1d15cc0ca7 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -26,7 +26,7 @@ async fn csv_query_avg_multi_batch() -> Result<()> { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(plan, runtime).await.unwrap(); let batch = &results[0]; let column = batch.column(0); @@ -354,6 +354,95 @@ async fn csv_query_approx_count() -> Result<()> { Ok(()) } +// This test executes the APPROX_PERCENTILE_CONT aggregation against the test +// data, asserting the estimated quantiles are ±5% their actual values. +// +// Actual quantiles calculated with: +// +// ```r +// read_csv("./testing/data/csv/aggregate_test_100.csv") |> +// select_if(is.numeric) |> +// summarise_all(~ quantile(., c(0.1, 0.5, 0.9))) +// ``` +// +// Giving: +// +// ```text +// c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 +// +// 1 1 -95.3 -22925. -1882606710 -7.25e18 18.9 2671. 472608672. 1.83e18 0.109 0.0714 +// 2 3 15.5 4599 377164262 1.13e18 134. 30634 2365817608. 9.30e18 0.491 0.551 +// 3 5 102. 25334. 1991374996. 7.37e18 231 57518. 3776538487. 1.61e19 0.834 0.946 +// ``` +// +// Column `c12` is omitted due to a large relative error (~10%) due to the small +// float values. +#[tokio::test] +async fn csv_query_approx_percentile_cont() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + + // Generate an assertion that the estimated $percentile value for $column is + // within 5% of the $actual percentile value. + macro_rules! percentile_test { + ($ctx:ident, column=$column:literal, percentile=$percentile:literal, actual=$actual:literal) => { + let sql = format!("SELECT (ABS(1 - CAST(approx_percentile_cont({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $percentile, $actual); + let actual = execute_to_batches(&mut ctx, &sql).await; + // + // "+------+", + // "| q |", + // "+------+", + // "| true |", + // "+------+", + // + let want = ["+------+", "| q |", "+------+", "| true |", "+------+"]; + assert_batches_eq!(want, &actual); + }; + } + + percentile_test!(ctx, column = "c2", percentile = 0.1, actual = 1.0); + percentile_test!(ctx, column = "c2", percentile = 0.5, actual = 3.0); + percentile_test!(ctx, column = "c2", percentile = 0.9, actual = 5.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c3", percentile = 0.1, actual = -95.3); + percentile_test!(ctx, column = "c3", percentile = 0.5, actual = 15.5); + percentile_test!(ctx, column = "c3", percentile = 0.9, actual = 102.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c4", percentile = 0.1, actual = -22925.0); + percentile_test!(ctx, column = "c4", percentile = 0.5, actual = 4599.0); + percentile_test!(ctx, column = "c4", percentile = 0.9, actual = 25334.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c5", percentile = 0.1, actual = -1882606710.0); + percentile_test!(ctx, column = "c5", percentile = 0.5, actual = 377164262.0); + percentile_test!(ctx, column = "c5", percentile = 0.9, actual = 1991374996.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c6", percentile = 0.1, actual = -7.25e18); + percentile_test!(ctx, column = "c6", percentile = 0.5, actual = 1.13e18); + percentile_test!(ctx, column = "c6", percentile = 0.9, actual = 7.37e18); + //////////////////////////////////// + percentile_test!(ctx, column = "c7", percentile = 0.1, actual = 18.9); + percentile_test!(ctx, column = "c7", percentile = 0.5, actual = 134.0); + percentile_test!(ctx, column = "c7", percentile = 0.9, actual = 231.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c8", percentile = 0.1, actual = 2671.0); + percentile_test!(ctx, column = "c8", percentile = 0.5, actual = 30634.0); + percentile_test!(ctx, column = "c8", percentile = 0.9, actual = 57518.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c9", percentile = 0.1, actual = 472608672.0); + percentile_test!(ctx, column = "c9", percentile = 0.5, actual = 2365817608.0); + percentile_test!(ctx, column = "c9", percentile = 0.9, actual = 3776538487.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c10", percentile = 0.1, actual = 1.83e18); + percentile_test!(ctx, column = "c10", percentile = 0.5, actual = 9.30e18); + percentile_test!(ctx, column = "c10", percentile = 0.9, actual = 1.61e19); + //////////////////////////////////// + percentile_test!(ctx, column = "c11", percentile = 0.1, actual = 0.109); + percentile_test!(ctx, column = "c11", percentile = 0.5, actual = 0.491); + percentile_test!(ctx, column = "c11", percentile = 0.9, actual = 0.834); + + Ok(()) +} + #[tokio::test] async fn query_count_without_from() -> Result<()> { let mut ctx = ExecutionContext::new(); @@ -473,3 +562,105 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn aggregate_timestamps_sum() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_timestamps()).unwrap(); + + let results = plan_and_collect( + &mut ctx, + "SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t", + ) + .await + .unwrap_err(); + + assert_eq!(results.to_string(), "Error during planning: The function Sum does not support inputs of type Timestamp(Nanosecond, None)."); + + Ok(()) +} + +#[tokio::test] +async fn aggregate_timestamps_count() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_timestamps()).unwrap(); + + let results = execute_to_batches( + &mut ctx, + "SELECT count(nanos), count(micros), count(millis), count(secs) FROM t", + ) + .await; + + let expected = vec![ + "+----------------+-----------------+-----------------+---------------+", + "| COUNT(t.nanos) | COUNT(t.micros) | COUNT(t.millis) | COUNT(t.secs) |", + "+----------------+-----------------+-----------------+---------------+", + "| 3 | 3 | 3 | 3 |", + "+----------------+-----------------+-----------------+---------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn aggregate_timestamps_min() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_timestamps()).unwrap(); + + let results = execute_to_batches( + &mut ctx, + "SELECT min(nanos), min(micros), min(millis), min(secs) FROM t", + ) + .await; + + let expected = vec![ + "+----------------------------+----------------------------+-------------------------+---------------------+", + "| MIN(t.nanos) | MIN(t.micros) | MIN(t.millis) | MIN(t.secs) |", + "+----------------------------+----------------------------+-------------------------+---------------------+", + "| 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 |", + "+----------------------------+----------------------------+-------------------------+---------------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn aggregate_timestamps_max() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_timestamps()).unwrap(); + + let results = execute_to_batches( + &mut ctx, + "SELECT max(nanos), max(micros), max(millis), max(secs) FROM t", + ) + .await; + + let expected = vec![ + "+-------------------------+-------------------------+-------------------------+---------------------+", + "| MAX(t.nanos) | MAX(t.micros) | MAX(t.millis) | MAX(t.secs) |", + "+-------------------------+-------------------------+-------------------------+---------------------+", + "| 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 |", + "+-------------------------+-------------------------+-------------------------+---------------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn aggregate_timestamps_avg() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_timestamps()).unwrap(); + + let results = plan_and_collect( + &mut ctx, + "SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t", + ) + .await + .unwrap_err(); + + assert_eq!(results.to_string(), "Error during planning: The function Avg does not support inputs of type Timestamp(Nanosecond, None)."); + Ok(()) +} diff --git a/datafusion/tests/sql/avro.rs b/datafusion/tests/sql/avro.rs index d0cdf71b0868..82d91a0bd481 100644 --- a/datafusion/tests/sql/avro.rs +++ b/datafusion/tests/sql/avro.rs @@ -124,7 +124,7 @@ async fn avro_single_nan_schema() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(plan, runtime).await.unwrap(); for batch in results { assert_eq!(1, batch.num_rows()); diff --git a/datafusion/tests/sql/errors.rs b/datafusion/tests/sql/errors.rs index 05ca0642bae0..92b634dd5e96 100644 --- a/datafusion/tests/sql/errors.rs +++ b/datafusion/tests/sql/errors.rs @@ -37,7 +37,7 @@ async fn test_cast_expressions_error() -> Result<()> { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let result = collect(plan, runtime).await; match result { diff --git a/datafusion/tests/sql/explain_analyze.rs b/datafusion/tests/sql/explain_analyze.rs index 25e0cd6b0bda..a4371dbfc578 100644 --- a/datafusion/tests/sql/explain_analyze.rs +++ b/datafusion/tests/sql/explain_analyze.rs @@ -41,7 +41,7 @@ async fn explain_analyze_baseline_metrics() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(physical_plan.clone(), runtime).await.unwrap(); let formatted = print::write(&results); println!("Query Output:\n\n{}", formatted); @@ -327,7 +327,7 @@ async fn csv_explain_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(plan, runtime).await.expect(&msg); let actual = result_vec(&results); // flatten to a single string @@ -525,7 +525,7 @@ async fn csv_explain_verbose_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(plan, runtime).await.expect(&msg); let actual = result_vec(&results); // flatten to a single string @@ -610,9 +610,9 @@ order by Sort: #revenue DESC NULLS FIRST\ \n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\ \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * Int64(1) - #lineitem.l_discount)]]\ - \n Join: #customer.c_nationkey = #nation.n_nationkey\ - \n Join: #orders.o_orderkey = #lineitem.l_orderkey\ - \n Join: #customer.c_custkey = #orders.o_custkey\ + \n Inner Join: #customer.c_nationkey = #nation.n_nationkey\ + \n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\ + \n Inner Join: #customer.c_custkey = #orders.o_custkey\ \n TableScan: customer projection=Some([0, 1, 2, 3, 4, 5, 7])\ \n Filter: #orders.o_orderdate >= Date32(\"8674\") AND #orders.o_orderdate < Date32(\"8766\")\ \n TableScan: orders projection=Some([0, 1, 4]), filters=[#orders.o_orderdate >= Date32(\"8674\"), #orders.o_orderdate < Date32(\"8766\")]\ diff --git a/datafusion/tests/sql/information_schema.rs b/datafusion/tests/sql/information_schema.rs new file mode 100644 index 000000000000..d93f0d7328d3 --- /dev/null +++ b/datafusion/tests/sql/information_schema.rs @@ -0,0 +1,502 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use async_trait::async_trait; +use datafusion::{ + catalog::{ + catalog::MemoryCatalogProvider, + schema::{MemorySchemaProvider, SchemaProvider}, + }, + datasource::{TableProvider, TableType}, + logical_plan::Expr, +}; + +use super::*; + +#[tokio::test] +async fn information_schema_tables_not_exist_by_default() { + let mut ctx = ExecutionContext::new(); + + let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Table or CTE with name 'information_schema.tables' not found" + ); +} + +#[tokio::test] +async fn information_schema_tables_no_tables() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------------+------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+------------+------------+", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | information_schema | tables | VIEW |", + "+---------------+--------------------+------------+------------+", + ]; + assert_batches_sorted_eq!(expected, &result); +} + +#[tokio::test] +async fn information_schema_tables_tables_default_catalog() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + // Now, register an empty table + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------------+------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+------------+------------+", + "| datafusion | information_schema | tables | VIEW |", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | public | t | BASE TABLE |", + "+---------------+--------------------+------------+------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + + // Newly added tables should appear + ctx.register_table("t2", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------------+------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+------------+------------+", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | information_schema | tables | VIEW |", + "| datafusion | public | t | BASE TABLE |", + "| datafusion | public | t2 | BASE TABLE |", + "+---------------+--------------------+------------+------------+", + ]; + assert_batches_sorted_eq!(expected, &result); +} + +#[tokio::test] +async fn information_schema_tables_tables_with_multiple_catalogs() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + let catalog = MemoryCatalogProvider::new(); + let schema = MemorySchemaProvider::new(); + schema + .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap()) + .unwrap(); + schema + .register_table("t2".to_owned(), table_with_sequence(1, 1).unwrap()) + .unwrap(); + catalog.register_schema("my_schema", Arc::new(schema)); + ctx.register_catalog("my_catalog", Arc::new(catalog)); + + let catalog = MemoryCatalogProvider::new(); + let schema = MemorySchemaProvider::new(); + schema + .register_table("t3".to_owned(), table_with_sequence(1, 1).unwrap()) + .unwrap(); + catalog.register_schema("my_other_schema", Arc::new(schema)); + ctx.register_catalog("my_other_catalog", Arc::new(catalog)); + + let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + .await + .unwrap(); + + let expected = vec![ + "+------------------+--------------------+------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+------------------+--------------------+------------+------------+", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | information_schema | tables | VIEW |", + "| my_catalog | information_schema | columns | VIEW |", + "| my_catalog | information_schema | tables | VIEW |", + "| my_catalog | my_schema | t1 | BASE TABLE |", + "| my_catalog | my_schema | t2 | BASE TABLE |", + "| my_other_catalog | information_schema | columns | VIEW |", + "| my_other_catalog | information_schema | tables | VIEW |", + "| my_other_catalog | my_other_schema | t3 | BASE TABLE |", + "+------------------+--------------------+------------+------------+", + ]; + assert_batches_sorted_eq!(expected, &result); +} + +#[tokio::test] +async fn information_schema_tables_table_types() { + struct TestTable(TableType); + + #[async_trait] + impl TableProvider for TestTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn table_type(&self) -> TableType { + self.0 + } + + fn schema(&self) -> SchemaRef { + unimplemented!() + } + + async fn scan( + &self, + _: &Option>, + _: &[Expr], + _: Option, + ) -> Result> { + unimplemented!() + } + } + + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + ctx.register_table("physical", Arc::new(TestTable(TableType::Base))) + .unwrap(); + ctx.register_table("query", Arc::new(TestTable(TableType::View))) + .unwrap(); + ctx.register_table("temp", Arc::new(TestTable(TableType::Temporary))) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------------+------------+-----------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+------------+-----------------+", + "| datafusion | information_schema | tables | VIEW |", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | public | physical | BASE TABLE |", + "| datafusion | public | query | VIEW |", + "| datafusion | public | temp | LOCAL TEMPORARY |", + "+---------------+--------------------+------------+-----------------+", + ]; + assert_batches_sorted_eq!(expected, &result); +} + +#[tokio::test] +async fn information_schema_show_tables_no_information_schema() { + let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + // use show tables alias + let err = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap_err(); + + assert_eq!(err.to_string(), "Error during planning: SHOW TABLES is not supported unless information_schema is enabled"); +} + +#[tokio::test] +async fn information_schema_show_tables() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + // use show tables alias + let result = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap(); + + let expected = vec![ + "+---------------+--------------------+------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+------------+------------+", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | information_schema | tables | VIEW |", + "| datafusion | public | t | BASE TABLE |", + "+---------------+--------------------+------------+------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let result = plan_and_collect(&mut ctx, "SHOW tables").await.unwrap(); + + assert_batches_sorted_eq!(expected, &result); +} + +#[tokio::test] +async fn information_schema_show_columns_no_information_schema() { + let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") + .await + .unwrap_err(); + + assert_eq!(err.to_string(), "Error during planning: SHOW COLUMNS is not supported unless information_schema is enabled"); +} + +#[tokio::test] +async fn information_schema_show_columns_like_where() { + let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let expected = + "Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported"; + + let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t LIKE 'f'") + .await + .unwrap_err(); + assert_eq!(err.to_string(), expected); + + let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t WHERE column_name = 'bar'") + .await + .unwrap_err(); + assert_eq!(err.to_string(), expected); +} + +#[tokio::test] +async fn information_schema_show_columns() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------+------------+-------------+-----------+-------------+", + "| table_catalog | table_schema | table_name | column_name | data_type | is_nullable |", + "+---------------+--------------+------------+-------------+-----------+-------------+", + "| datafusion | public | t | i | Int32 | YES |", + "+---------------+--------------+------------+-------------+-----------+-------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let result = plan_and_collect(&mut ctx, "SHOW columns from t") + .await + .unwrap(); + assert_batches_sorted_eq!(expected, &result); + + // This isn't ideal but it is consistent behavior for `SELECT * from T` + let err = plan_and_collect(&mut ctx, "SHOW columns from T") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Unknown relation for SHOW COLUMNS: T" + ); +} + +// test errors with WHERE and LIKE +#[tokio::test] +async fn information_schema_show_columns_full_extended() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "SHOW FULL COLUMNS FROM t") + .await + .unwrap(); + let expected = vec![ + "+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", + "| table_catalog | table_schema | table_name | column_name | ordinal_position | column_default | is_nullable | data_type | character_maximum_length | character_octet_length | numeric_precision | numeric_precision_radix | numeric_scale | datetime_precision | interval_type |", + "+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", + "| datafusion | public | t | i | 0 | | YES | Int32 | | | 32 | 2 | | | |", + "+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let result = plan_and_collect(&mut ctx, "SHOW EXTENDED COLUMNS FROM t") + .await + .unwrap(); + assert_batches_sorted_eq!(expected, &result); +} + +#[tokio::test] +async fn information_schema_show_table_table_names() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM public.t") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------+------------+-------------+-----------+-------------+", + "| table_catalog | table_schema | table_name | column_name | data_type | is_nullable |", + "+---------------+--------------+------------+-------------+-----------+-------------+", + "| datafusion | public | t | i | Int32 | YES |", + "+---------------+--------------+------------+-------------+-----------+-------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let result = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t") + .await + .unwrap(); + assert_batches_sorted_eq!(expected, &result); + + let err = plan_and_collect(&mut ctx, "SHOW columns from t2") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Unknown relation for SHOW COLUMNS: t2" + ); + + let err = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t2") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Unknown relation for SHOW COLUMNS: datafusion.public.t2" + ); +} + +#[tokio::test] +async fn show_unsupported() { + let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + + let err = plan_and_collect(&mut ctx, "SHOW SOMETHING_UNKNOWN") + .await + .unwrap_err(); + + assert_eq!(err.to_string(), "This feature is not implemented: SHOW SOMETHING_UNKNOWN not implemented. Supported syntax: SHOW "); +} + +#[tokio::test] +async fn information_schema_columns_not_exist_by_default() { + let mut ctx = ExecutionContext::new(); + + let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.columns") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Table or CTE with name 'information_schema.columns' not found" + ); +} + +fn table_with_many_types() -> Arc { + let schema = Schema::new(vec![ + Field::new("int32_col", DataType::Int32, false), + Field::new("float64_col", DataType::Float64, true), + Field::new("utf8_col", DataType::Utf8, true), + Field::new("large_utf8_col", DataType::LargeUtf8, false), + Field::new("binary_col", DataType::Binary, false), + Field::new("large_binary_col", DataType::LargeBinary, false), + Field::new( + "timestamp_nanos", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from_slice(&[1])), + Arc::new(Float64Array::from_slice(&[1.0])), + Arc::new(StringArray::from(vec![Some("foo")])), + Arc::new(LargeStringArray::from(vec![Some("bar")])), + Arc::new(BinaryArray::from_slice(&[b"foo" as &[u8]])), + Arc::new(LargeBinaryArray::from_slice(&[b"foo" as &[u8]])), + Arc::new(TimestampNanosecondArray::from_opt_vec( + vec![Some(123)], + None, + )), + ], + ) + .unwrap(); + let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); + Arc::new(provider) +} + +#[tokio::test] +async fn information_schema_columns() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + let catalog = MemoryCatalogProvider::new(); + let schema = MemorySchemaProvider::new(); + + schema + .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap()) + .unwrap(); + + schema + .register_table("t2".to_owned(), table_with_many_types()) + .unwrap(); + catalog.register_schema("my_schema", Arc::new(schema)); + ctx.register_catalog("my_catalog", Arc::new(catalog)); + + let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.columns") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", + "| table_catalog | table_schema | table_name | column_name | ordinal_position | column_default | is_nullable | data_type | character_maximum_length | character_octet_length | numeric_precision | numeric_precision_radix | numeric_scale | datetime_precision | interval_type |", + "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", + "| my_catalog | my_schema | t1 | i | 0 | | YES | Int32 | | | 32 | 2 | | | |", + "| my_catalog | my_schema | t2 | binary_col | 4 | | NO | Binary | | 2147483647 | | | | | |", + "| my_catalog | my_schema | t2 | float64_col | 1 | | YES | Float64 | | | 24 | 2 | | | |", + "| my_catalog | my_schema | t2 | int32_col | 0 | | NO | Int32 | | | 32 | 2 | | | |", + "| my_catalog | my_schema | t2 | large_binary_col | 5 | | NO | LargeBinary | | 9223372036854775807 | | | | | |", + "| my_catalog | my_schema | t2 | large_utf8_col | 3 | | NO | LargeUtf8 | | 9223372036854775807 | | | | | |", + "| my_catalog | my_schema | t2 | timestamp_nanos | 6 | | NO | Timestamp(Nanosecond, None) | | | | | | | |", + "| my_catalog | my_schema | t2 | utf8_col | 2 | | YES | Utf8 | | 2147483647 | | | | | |", + "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", + ]; + assert_batches_sorted_eq!(expected, &result); +} + +/// Execute SQL and return results +async fn plan_and_collect( + ctx: &mut ExecutionContext, + sql: &str, +) -> Result> { + ctx.sql(sql).await?.collect().await +} diff --git a/datafusion/tests/sql/joins.rs b/datafusion/tests/sql/joins.rs index 1f7599a80f16..38761d7570d7 100644 --- a/datafusion/tests/sql/joins.rs +++ b/datafusion/tests/sql/joins.rs @@ -16,7 +16,6 @@ // under the License. use super::*; -use datafusion::from_slice::FromSlice; #[tokio::test] async fn equijoin() -> Result<()> { @@ -881,3 +880,51 @@ async fn join_tables_with_duplicated_column_name_not_in_on_constraint() -> Resul assert_batches_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn join_timestamp() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_timestamps()).unwrap(); + + let expected = vec![ + "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", + "| nanos | micros | millis | secs | name | nanos | micros | millis | secs | name |", + "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", + "| 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 | Row 1 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 | Row 1 |", + "| 2018-11-13 17:11:10.011375885 | 2018-11-13 17:11:10.011375 | 2018-11-13 17:11:10.011 | 2018-11-13 17:11:10 | Row 0 | 2018-11-13 17:11:10.011375885 | 2018-11-13 17:11:10.011375 | 2018-11-13 17:11:10.011 | 2018-11-13 17:11:10 | Row 0 |", + "| 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 | Row 3 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 | Row 3 |", + "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", + ]; + + let results = execute_to_batches( + &mut ctx, + "SELECT * FROM t as t1 \ + JOIN (SELECT * FROM t) as t2 \ + ON t1.nanos = t2.nanos", + ) + .await; + + assert_batches_sorted_eq!(expected, &results); + + let results = execute_to_batches( + &mut ctx, + "SELECT * FROM t as t1 \ + JOIN (SELECT * FROM t) as t2 \ + ON t1.micros = t2.micros", + ) + .await; + + assert_batches_sorted_eq!(expected, &results); + + let results = execute_to_batches( + &mut ctx, + "SELECT * FROM t as t1 \ + JOIN (SELECT * FROM t) as t2 \ + ON t1.millis = t2.millis", + ) + .await; + + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index 4685447258e7..26374cd5151b 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -25,7 +25,7 @@ use datafusion::assert_batches_eq; use datafusion::assert_batches_sorted_eq; use datafusion::assert_contains; use datafusion::assert_not_contains; -use datafusion::from_slice::FromSlice; +use datafusion::datasource::TableProvider; use datafusion::logical_plan::plan::{Aggregate, Projection}; use datafusion::logical_plan::LogicalPlan; use datafusion::logical_plan::TableScan; @@ -93,6 +93,7 @@ pub mod udf; pub mod union; pub mod window; +pub mod information_schema; #[cfg_attr(not(feature = "unicode_expressions"), ignore)] pub mod unicode; @@ -517,8 +518,15 @@ async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { Ok(()) } -/// Execute query and return result set as 2-d table of Vecs -/// `result[row][column]` +/// Execute SQL and return results as a RecordBatch +async fn plan_and_collect( + ctx: &mut ExecutionContext, + sql: &str, +) -> Result> { + ctx.sql(sql).await?.collect().await +} + +/// Execute query and return results as a Vec of RecordBatches async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec { let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx.create_logical_plan(sql).expect(&msg); @@ -532,7 +540,7 @@ async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec Result> { make_timestamp_table(TimeUnit::Nanosecond) } +/// Return a new table provider that has a single Int32 column with +/// values between `seq_start` and `seq_end` +pub fn table_with_sequence( + seq_start: i32, + seq_end: i32, +) -> Result> { + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::>())); + let partitions = vec![vec![RecordBatch::try_new( + schema.clone(), + vec![arr as ArrayRef], + )?]]; + Ok(Arc::new(MemTable::try_new(schema, partitions)?)) +} + // Normalizes parts of an explain plan that vary from run to run (such as path) fn normalize_for_explain(s: &str) -> String { // Convert things like /Users/alamb/Software/arrow/testing/data/csv/aggregate_test_100.csv @@ -666,6 +689,93 @@ fn normalize_vec_for_explain(v: Vec>) -> Vec> { .collect::>() } +/// Return a new table provider containing all of the supported timestamp types +pub fn table_with_timestamps() -> Arc { + let batch = make_timestamps(); + let schema = batch.schema(); + let partitions = vec![vec![batch]]; + Arc::new(MemTable::try_new(schema, partitions).unwrap()) +} + +/// Return record batch with all of the supported timestamp types +/// values +/// +/// Columns are named: +/// "nanos" --> TimestampNanosecondArray +/// "micros" --> TimestampMicrosecondArray +/// "millis" --> TimestampMillisecondArray +/// "secs" --> TimestampSecondArray +/// "names" --> StringArray +pub fn make_timestamps() -> RecordBatch { + let ts_strings = vec![ + Some("2018-11-13T17:11:10.011375885995"), + Some("2011-12-13T11:13:10.12345"), + None, + Some("2021-1-1T05:11:10.432"), + ]; + + let ts_nanos = ts_strings + .into_iter() + .map(|t| { + t.map(|t| { + t.parse::() + .unwrap() + .timestamp_nanos() + }) + }) + .collect::>(); + + let ts_micros = ts_nanos + .iter() + .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000)) + .collect::>(); + + let ts_millis = ts_nanos + .iter() + .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000)) + .collect::>(); + + let ts_secs = ts_nanos + .iter() + .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000000)) + .collect::>(); + + let names = ts_nanos + .iter() + .enumerate() + .map(|(i, _)| format!("Row {}", i)) + .collect::>(); + + let arr_nanos = TimestampNanosecondArray::from_opt_vec(ts_nanos, None); + let arr_micros = TimestampMicrosecondArray::from_opt_vec(ts_micros, None); + let arr_millis = TimestampMillisecondArray::from_opt_vec(ts_millis, None); + let arr_secs = TimestampSecondArray::from_opt_vec(ts_secs, None); + + let names = names.iter().map(|s| s.as_str()).collect::>(); + let arr_names = StringArray::from(names); + + let schema = Schema::new(vec![ + Field::new("nanos", arr_nanos.data_type().clone(), true), + Field::new("micros", arr_micros.data_type().clone(), true), + Field::new("millis", arr_millis.data_type().clone(), true), + Field::new("secs", arr_secs.data_type().clone(), true), + Field::new("name", arr_names.data_type().clone(), true), + ]); + let schema = Arc::new(schema); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(arr_nanos), + Arc::new(arr_micros), + Arc::new(arr_millis), + Arc::new(arr_secs), + Arc::new(arr_names), + ], + ) + .unwrap() +} + #[tokio::test] async fn nyc() -> Result<()> { // schema for nyxtaxi csv files diff --git a/datafusion/tests/sql/parquet.rs b/datafusion/tests/sql/parquet.rs index f3f798dc072c..9e6e419a6246 100644 --- a/datafusion/tests/sql/parquet.rs +++ b/datafusion/tests/sql/parquet.rs @@ -54,7 +54,7 @@ async fn parquet_single_nan_schema() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(plan, runtime).await.unwrap(); for batch in results { assert_eq!(1, batch.num_rows()); @@ -91,7 +91,7 @@ async fn parquet_list_columns() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(plan, runtime).await.unwrap(); // int64_list utf8_list diff --git a/datafusion/tests/sql/select.rs b/datafusion/tests/sql/select.rs index ca9a33b26810..89fd6f2b1571 100644 --- a/datafusion/tests/sql/select.rs +++ b/datafusion/tests/sql/select.rs @@ -16,7 +16,6 @@ // under the License. use super::*; -use datafusion::from_slice::FromSlice; #[tokio::test] async fn all_where_empty() -> Result<()> { diff --git a/datafusion/tests/sql/timestamp.rs b/datafusion/tests/sql/timestamp.rs index 28a5c5d09a2b..3dfcf552cabb 100644 --- a/datafusion/tests/sql/timestamp.rs +++ b/datafusion/tests/sql/timestamp.rs @@ -16,7 +16,6 @@ // under the License. use super::*; -use datafusion::from_slice::FromSlice; #[tokio::test] async fn query_cast_timestamp_millis() -> Result<()> { @@ -388,7 +387,7 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { let plan = ctx.create_physical_plan(&plan).await.expect(&msg); let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let res = collect(plan, runtime).await.expect(&msg); let actual = result_vec(&res);