From 97fd9fff3353eff3d4fd4441c29cd91c19fce939 Mon Sep 17 00:00:00 2001 From: Gang Liao Date: Sat, 5 Jun 2021 10:52:14 -0400 Subject: [PATCH 01/25] Support modulus op --- .../src/physical_plan/expressions/binary.rs | 40 ++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index 5c2d9ce02f51..5635ab8ae4b7 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -19,7 +19,7 @@ use std::{any::Any, sync::Arc}; use arrow::array::*; use arrow::compute::kernels::arithmetic::{ - add, divide, divide_scalar, multiply, subtract, + add, divide, divide_scalar, multiply, subtract, modulus, modulus_scalar }; use arrow::compute::kernels::boolean::{and_kleene, or_kleene}; use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; @@ -341,14 +341,9 @@ fn common_binary_type( } // for math expressions, the final value of the coercion is also the return type // because coercion favours higher information types - Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => { + Operator::Plus | Operator::Minus | Operator::Modulus | Operator::Divide | Operator::Multiply => { numerical_coercion(lhs_type, rhs_type) } - Operator::Modulus => { - return Err(DataFusionError::NotImplemented( - "Modulus operator is still not supported".to_string(), - )) - } }; // re-write the error message of failed coercions to include the operator's information @@ -389,12 +384,9 @@ pub fn binary_operator_data_type( | Operator::GtEq | Operator::LtEq => Ok(DataType::Boolean), // math operations return the same value as the common coerced type - Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => { + Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply | Operator::Modulus => { Ok(common_type) } - Operator::Modulus => Err(DataFusionError::NotImplemented( - "Modulus operator is still not supported".to_string(), - )), } } @@ -454,6 +446,9 @@ impl PhysicalExpr for BinaryExpr { Operator::Divide => { binary_primitive_array_op_scalar!(array, scalar.clone(), divide) } + Operator::Modulus => { + binary_primitive_array_op_scalar!(array, scalar.clone(), modulus) + } // if scalar operation is not supported - fallback to array implementation _ => None, } @@ -503,6 +498,7 @@ impl PhysicalExpr for BinaryExpr { Operator::Minus => binary_primitive_array_op!(left, right, subtract), Operator::Multiply => binary_primitive_array_op!(left, right, multiply), Operator::Divide => binary_primitive_array_op!(left, right, divide), + Operator::Modulus => binary_primitive_array_op!(left, right, modulus), Operator::And => { if left_data_type == DataType::Boolean { boolean_op!(left, right, and_kleene) @@ -525,9 +521,6 @@ impl PhysicalExpr for BinaryExpr { ))); } } - Operator::Modulus => Err(DataFusionError::NotImplemented( - "Modulus operator is still not supported".to_string(), - )), }; result.map(|a| ColumnarValue::Array(a)) } @@ -964,6 +957,25 @@ mod tests { Ok(()) } + #[test] + fn modulus_op() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048])); + let b = Arc::new(Int32Array::from(vec![2, 4, 7, 14, 32])); + + apply_arithmetic::( + schema, + vec![a, b], + Operator::Modulus, + Int32Array::from(vec![0, 0, 2, 8, 0]), + )?; + + Ok(()) + } + 
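For context, a minimal sketch (not part of the patch) of the arrow arithmetic kernels this change wires into BinaryExpr; the values are illustrative and the kernel signatures are assumed to match the arrow release this DataFusion version pins:

use arrow::array::Int32Array;
use arrow::compute::kernels::arithmetic::{modulus, modulus_scalar};

fn main() -> arrow::error::Result<()> {
    let a = Int32Array::from(vec![8, 32, 128]);
    let b = Int32Array::from(vec![2, 4, 7]);
    // element-wise a % b, mirroring Operator::Modulus on two array inputs
    assert_eq!(modulus(&a, &b)?, Int32Array::from(vec![0, 0, 2]));
    // a % scalar, the fast path used when the right-hand side is a literal
    assert_eq!(modulus_scalar(&a, 5)?, Int32Array::from(vec![3, 2, 3]));
    Ok(())
}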
fn apply_arithmetic( schema: SchemaRef, data: Vec, From b84789afc5a67e3f70cd8903bf96993b13414aaf Mon Sep 17 00:00:00 2001 From: sathis Date: Sun, 6 Jun 2021 03:04:24 +0530 Subject: [PATCH 02/25] Optimize cast function during planning stage (#513) Co-authored-by: Sathis Kumar --- datafusion/src/optimizer/constant_folding.rs | 58 ++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 97cc23264bda..d2ac5ce2f383 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -30,6 +30,7 @@ use crate::optimizer::utils; use crate::physical_plan::functions::BuiltinScalarFunction; use crate::scalar::ScalarValue; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; +use arrow::compute::{kernels, DEFAULT_CAST_OPTIONS}; /// Optimizer that simplifies comparison expressions involving boolean literals. /// @@ -247,6 +248,25 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { } } } + Expr::Cast { + expr: inner, + data_type, + } => match inner.as_ref() { + Expr::Literal(val) => { + let scalar_array = val.to_array(); + let cast_array = kernels::cast::cast_with_options( + &scalar_array, + &data_type, + &DEFAULT_CAST_OPTIONS, + )?; + let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; + Expr::Literal(cast_scalar) + } + _ => Expr::Cast { + expr: inner, + data_type, + }, + }, expr => { // no rewrite possible expr @@ -724,6 +744,44 @@ mod tests { assert_eq!(expected, actual); } + #[test] + fn cast_expr() { + let table_scan = test_table_scan().unwrap(); + let proj = vec![Expr::Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("0".to_string())))), + data_type: DataType::Int32, + }]; + let plan = LogicalPlanBuilder::from(&table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + let expected = "Projection: Int32(0)\ + \n TableScan: test projection=None"; + let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); + assert_eq!(expected, actual); + } + + #[test] + fn cast_expr_wrong_arg() { + let table_scan = test_table_scan().unwrap(); + let proj = vec![Expr::Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("".to_string())))), + data_type: DataType::Int32, + }]; + let plan = LogicalPlanBuilder::from(&table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + let expected = "Projection: Int32(NULL)\ + \n TableScan: test projection=None"; + let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); + assert_eq!(expected, actual); + } + #[test] fn single_now_expr() { let table_scan = test_table_scan().unwrap(); From ee2b9ef049954173231b987f86b4d8eace0d3e79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 6 Jun 2021 09:56:43 +0200 Subject: [PATCH 03/25] Fix display of execution time (#514) --- datafusion-cli/src/lib.rs | 7 +++---- datafusion-cli/src/main.rs | 5 ++++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/datafusion-cli/src/lib.rs b/datafusion-cli/src/lib.rs index 5bd16e333030..5b110d315364 100644 --- a/datafusion-cli/src/lib.rs +++ b/datafusion-cli/src/lib.rs @@ -29,17 +29,16 @@ pub struct PrintOptions { fn print_timing_info(row_count: usize, now: Instant) { println!( - "{} {} in set. Query took {} seconds.", + "{} {} in set. 
Query took {:.3} seconds.", row_count, if row_count == 1 { "row" } else { "rows" }, - now.elapsed().as_secs() + now.elapsed().as_secs_f64() ); } impl PrintOptions { /// print the batches to stdout using the specified format - pub fn print_batches(&self, batches: &[RecordBatch]) -> Result<()> { - let now = Instant::now(); + pub fn print_batches(&self, batches: &[RecordBatch], now: Instant) -> Result<()> { if batches.is_empty() { if !self.quiet { print_timing_info(0, now); diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 083710f6dd19..39ce02ffbfd8 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -30,6 +30,7 @@ use std::fs::File; use std::io::prelude::*; use std::io::BufReader; use std::path::Path; +use std::time::Instant; #[tokio::main] pub async fn main() { @@ -238,7 +239,9 @@ async fn exec_and_print( sql: String, ) -> Result<()> { let df = ctx.sql(&sql)?; + let now = Instant::now(); let results = df.collect().await?; - print_options.print_batches(&results)?; + + print_options.print_batches(&results, now)?; Ok(()) } From 767eeb0a8bf17916aafb9a88abd52e7350acb596 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 7 Jun 2021 18:14:25 +0800 Subject: [PATCH 04/25] closing up type checks (#506) --- ballista/rust/core/Cargo.toml | 2 +- ballista/rust/core/proto/ballista.proto | 6 +- .../core/src/serde/logical_plan/from_proto.rs | 49 +-- .../core/src/serde/logical_plan/to_proto.rs | 56 ++- .../src/serde/physical_plan/from_proto.rs | 1 + datafusion/src/logical_plan/expr.rs | 50 ++- datafusion/src/optimizer/utils.rs | 5 +- datafusion/src/physical_plan/mod.rs | 1 + datafusion/src/physical_plan/planner.rs | 3 +- datafusion/src/physical_plan/window_frames.rs | 337 ++++++++++++++++++ datafusion/src/sql/planner.rs | 52 ++- datafusion/src/sql/utils.rs | 12 + 12 files changed, 512 insertions(+), 62 deletions(-) create mode 100644 datafusion/src/physical_plan/window_frames.rs diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index 99822cfe2aee..1f23a2a42e2a 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -35,7 +35,7 @@ futures = "0.3" log = "0.4" prost = "0.7" serde = {version = "1", features = ["derive"]} -sqlparser = "0.8" +sqlparser = "0.9.0" tokio = "1.0" tonic = "0.4" uuid = { version = "0.8", features = ["v4"] } diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 0ed9f243fd0a..38d87e934e5f 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -177,9 +177,9 @@ message WindowExprNode { // repeated LogicalExprNode partition_by = 5; repeated LogicalExprNode order_by = 6; // repeated LogicalExprNode filter = 7; - // oneof window_frame { - // WindowFrame frame = 8; - // } + oneof window_frame { + WindowFrame frame = 8; + } } message BetweenNode { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 662d9d0a929a..4a198174a2ba 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -20,12 +20,6 @@ use crate::error::BallistaError; use crate::serde::{proto_error, protobuf}; use crate::{convert_box_required, convert_required}; -use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits}; -use std::{ - convert::{From, TryInto}, - unimplemented, -}; - use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use 
datafusion::logical_plan::{ abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin, @@ -33,10 +27,17 @@ use datafusion::logical_plan::{ }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::csv::CsvReadOptions; +use datafusion::physical_plan::window_frames::{ + WindowFrame, WindowFrameBound, WindowFrameUnits, +}; use datafusion::physical_plan::window_functions::BuiltInWindowFunction; use datafusion::scalar::ScalarValue; use protobuf::logical_plan_node::LogicalPlanType; use protobuf::{logical_expr_node::ExprType, scalar_type}; +use std::{ + convert::{From, TryInto}, + unimplemented, +}; // use uuid::Uuid; @@ -83,20 +84,6 @@ impl TryInto for &protobuf::LogicalPlanNode { .iter() .map(|expr| expr.try_into()) .collect::, _>>()?; - - // let partition_by_expr = window - // .partition_by_expr - // .iter() - // .map(|expr| expr.try_into()) - // .collect::, _>>()?; - // let order_by_expr = window - // .order_by_expr - // .iter() - // .map(|expr| expr.try_into()) - // .collect::, _>>()?; - // // FIXME: add filter by expr - // // FIXME: parse the window_frame data - // let window_frame = None; LogicalPlanBuilder::from(&input) .window(window_expr)? .build() @@ -929,6 +916,15 @@ impl TryInto for &protobuf::LogicalExprNode { .map(|e| e.try_into()) .into_iter() .collect::, _>>()?; + let window_frame = expr + .window_frame + .as_ref() + .map::, _>(|e| match e { + window_expr_node::WindowFrame::Frame(frame) => { + frame.clone().try_into() + } + }) + .transpose()?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { let aggr_function = protobuf::AggregateFunction::from_i32(*i) @@ -945,6 +941,7 @@ impl TryInto for &protobuf::LogicalExprNode { ), args: vec![parse_required_expr(&expr.expr)?], order_by, + window_frame, }) } window_expr_node::WindowFunction::BuiltInFunction(i) => { @@ -964,6 +961,7 @@ impl TryInto for &protobuf::LogicalExprNode { ), args: vec![parse_required_expr(&expr.expr)?], order_by, + window_frame, }) } } @@ -1333,8 +1331,15 @@ impl TryFrom for WindowFrame { ) })? .try_into()?; - // FIXME parse end bound - let end_bound = None; + let end_bound = window + .end_bound + .map(|end_bound| match end_bound { + protobuf::window_frame::EndBound::Bound(end_bound) => { + end_bound.try_into() + } + }) + .transpose()? 
+ .unwrap_or(WindowFrameBound::CurrentRow); Ok(WindowFrame { units, start_bound, diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index d7734f05da56..56270030b59f 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -24,12 +24,17 @@ use std::{ convert::{TryFrom, TryInto}, }; +use super::super::proto_error; use crate::datasource::DfTableAdapter; use crate::serde::{protobuf, BallistaError}; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::datasource::CsvFile; use datafusion::logical_plan::{Expr, JoinType, LogicalPlan}; use datafusion::physical_plan::aggregates::AggregateFunction; +use datafusion::physical_plan::functions::BuiltinScalarFunction; +use datafusion::physical_plan::window_frames::{ + WindowFrame, WindowFrameBound, WindowFrameUnits, +}; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; @@ -38,10 +43,6 @@ use protobuf::{ arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, PrimitiveScalarType, ScalarListValue, ScalarType, }; -use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits}; - -use super::super::proto_error; -use datafusion::physical_plan::functions::BuiltinScalarFunction; impl protobuf::IntervalUnit { pub fn from_arrow_interval_unit(interval_unit: &IntervalUnit) -> Self { @@ -1007,6 +1008,7 @@ impl TryInto for &Expr { ref fun, ref args, ref order_by, + ref window_frame, .. } => { let window_function = match fun { @@ -1026,10 +1028,16 @@ impl TryInto for &Expr { .iter() .map(|e| e.try_into()) .collect::, _>>()?; + let window_frame = window_frame.map(|window_frame| { + protobuf::window_expr_node::WindowFrame::Frame( + window_frame.clone().into(), + ) + }); let window_expr = Box::new(protobuf::WindowExprNode { expr: Some(Box::new(arg.try_into()?)), window_function: Some(window_function), order_by, + window_frame, }); Ok(protobuf::LogicalExprNode { expr_type: Some(ExprType::WindowExpr(window_expr)), @@ -1256,23 +1264,35 @@ impl From for protobuf::WindowFrameUnits { } } -impl TryFrom for protobuf::WindowFrameBound { - type Error = BallistaError; - - fn try_from(_bound: WindowFrameBound) -> Result { - Err(BallistaError::NotImplemented( - "WindowFrameBound => protobuf::WindowFrameBound".to_owned(), - )) +impl From for protobuf::WindowFrameBound { + fn from(bound: WindowFrameBound) -> Self { + match bound { + WindowFrameBound::CurrentRow => protobuf::WindowFrameBound { + window_frame_bound_type: protobuf::WindowFrameBoundType::CurrentRow + .into(), + bound_value: None, + }, + WindowFrameBound::Preceding(v) => protobuf::WindowFrameBound { + window_frame_bound_type: protobuf::WindowFrameBoundType::Preceding.into(), + bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value), + }, + WindowFrameBound::Following(v) => protobuf::WindowFrameBound { + window_frame_bound_type: protobuf::WindowFrameBoundType::Following.into(), + bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value), + }, + } } } -impl TryFrom for protobuf::WindowFrame { - type Error = BallistaError; - - fn try_from(_window: WindowFrame) -> Result { - Err(BallistaError::NotImplemented( - "WindowFrame => protobuf::WindowFrame".to_owned(), - )) +impl From for protobuf::WindowFrame { + fn from(window: WindowFrame) -> Self { + protobuf::WindowFrame { + window_frame_units: protobuf::WindowFrameUnits::from(window.units).into(), + start_bound: 
Some(window.start_bound.into()), + end_bound: Some(protobuf::window_frame::EndBound::Bound( + window.end_bound.into(), + )), + } } } diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 22944313666f..5fcc971527c6 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -237,6 +237,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { fun, args, order_by, + .. } => { let arg = df_planner .create_physical_expr( diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 5103d5dc5051..bbc6ffabe928 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -19,22 +19,19 @@ //! such as `col = 5` or `SUM(col)`. See examples on the [`Expr`] struct. pub use super::Operator; - -use std::fmt; -use std::sync::Arc; - -use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; -use arrow::{compute::can_cast_types, datatypes::DataType}; - use crate::error::{DataFusionError, Result}; use crate::logical_plan::{DFField, DFSchema}; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, - window_functions, + window_frames, window_functions, }; use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; +use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; +use arrow::{compute::can_cast_types, datatypes::DataType}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; use std::collections::HashSet; +use std::fmt; +use std::sync::Arc; /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS @@ -199,6 +196,8 @@ pub enum Expr { args: Vec, /// List of order by expressions order_by: Vec, + /// Window frame + window_frame: Option, }, /// aggregate function AggregateUDF { @@ -735,10 +734,12 @@ impl Expr { args, fun, order_by, + window_frame, } => Expr::WindowFunction { args: rewrite_vec(args, rewriter)?, fun, order_by: rewrite_vec(order_by, rewriter)?, + window_frame, }, Expr::AggregateFunction { args, @@ -1283,8 +1284,23 @@ impl fmt::Debug for Expr { Expr::ScalarUDF { fun, ref args, .. } => { fmt_function(f, &fun.name, false, args) } - Expr::WindowFunction { fun, ref args, .. } => { - fmt_function(f, &fun.to_string(), false, args) + Expr::WindowFunction { + fun, + ref args, + window_frame, + .. + } => { + fmt_function(f, &fun.to_string(), false, args)?; + if let Some(window_frame) = window_frame { + write!( + f, + " {} BETWEEN {} AND {}", + window_frame.units, + window_frame.start_bound, + window_frame.end_bound + )?; + } + Ok(()) } Expr::AggregateFunction { fun, @@ -1401,8 +1417,18 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { Expr::ScalarUDF { fun, args, .. } => { create_function_name(&fun.name, false, args, input_schema) } - Expr::WindowFunction { fun, args, .. } => { - create_function_name(&fun.to_string(), false, args, input_schema) + Expr::WindowFunction { + fun, + args, + window_frame, + .. 
+ } => { + let fun_name = + create_function_name(&fun.to_string(), false, args, input_schema)?; + Ok(match window_frame { + Some(window_frame) => format!("{} {}", fun_name, window_frame), + None => fun_name, + }) } Expr::AggregateFunction { fun, diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 2cb65066feb9..65c95bee20d4 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -337,7 +337,9 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { fun: fun.clone(), args: expressions.to_vec(), }), - Expr::WindowFunction { fun, .. } => { + Expr::WindowFunction { + fun, window_frame, .. + } => { let index = expressions .iter() .position(|expr| { @@ -353,6 +355,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { fun: fun.clone(), args: expressions[..index].to_vec(), order_by: expressions[index + 1..].to_vec(), + window_frame: *window_frame, }) } Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction { diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index af6969c43cbd..490e02875c42 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -617,5 +617,6 @@ pub mod udf; #[cfg(feature = "unicode_expressions")] pub mod unicode_expressions; pub mod union; +pub mod window_frames; pub mod window_functions; pub mod windows; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 754ace08de6a..d7451c787096 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -17,8 +17,6 @@ //! Physical query planner -use std::sync::Arc; - use super::{ aggregates, cross_join::CrossJoinExec, empty::EmptyExec, expressions::binary, functions, hash_join::PartitionMode, udaf, union::UnionExec, windows, @@ -56,6 +54,7 @@ use arrow::datatypes::{Schema, SchemaRef}; use arrow::{compute::can_cast_types, datatypes::DataType}; use expressions::col; use log::debug; +use std::sync::Arc; /// This trait exposes the ability to plan an [`ExecutionPlan`] out of a [`LogicalPlan`]. pub trait ExtensionPlanner { diff --git a/datafusion/src/physical_plan/window_frames.rs b/datafusion/src/physical_plan/window_frames.rs new file mode 100644 index 000000000000..f0be5a221fbf --- /dev/null +++ b/datafusion/src/physical_plan/window_frames.rs @@ -0,0 +1,337 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Window frame +//! +//! The frame-spec determines which output rows are read by an aggregate window function. The frame-spec consists of four parts: +//! - A frame type - either ROWS, RANGE or GROUPS, +//! - A starting frame boundary, +//! - An ending frame boundary, +//! - An EXCLUDE clause. 
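As a usage sketch (not part of the patch): a SQL frame-spec such as "ROWS BETWEEN 3 PRECEDING AND CURRENT ROW" maps onto the types introduced below roughly as follows; the assertion relies on the Display impls defined later in this file:

use datafusion::physical_plan::window_frames::{
    WindowFrame, WindowFrameBound, WindowFrameUnits,
};

fn main() {
    // ROWS frame type, starting 3 rows before the current row and ending at it
    let frame = WindowFrame {
        units: WindowFrameUnits::Rows,
        start_bound: WindowFrameBound::Preceding(Some(3)),
        end_bound: WindowFrameBound::CurrentRow,
    };
    assert_eq!(frame.to_string(), "ROWS BETWEEN 3 PRECEDING AND CURRENT ROW");
}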
+ +use crate::error::{DataFusionError, Result}; +use sqlparser::ast; +use std::cmp::Ordering; +use std::convert::{From, TryFrom}; +use std::fmt; + +/// The frame-spec determines which output rows are read by an aggregate window function. +/// +/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the +/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to +/// CURRENT ROW. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct WindowFrame { + /// A frame type - either ROWS, RANGE or GROUPS + pub units: WindowFrameUnits, + /// A starting frame boundary + pub start_bound: WindowFrameBound, + /// An ending frame boundary + pub end_bound: WindowFrameBound, +} + +impl fmt::Display for WindowFrame { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} BETWEEN {} AND {}", + self.units, self.start_bound, self.end_bound + )?; + Ok(()) + } +} + +impl TryFrom for WindowFrame { + type Error = DataFusionError; + + fn try_from(value: ast::WindowFrame) -> Result { + let start_bound = value.start_bound.into(); + let end_bound = value + .end_bound + .map(WindowFrameBound::from) + .unwrap_or(WindowFrameBound::CurrentRow); + + if let WindowFrameBound::Following(None) = start_bound { + Err(DataFusionError::Execution( + "Invalid window frame: start bound cannot be unbounded following" + .to_owned(), + )) + } else if let WindowFrameBound::Preceding(None) = end_bound { + Err(DataFusionError::Execution( + "Invalid window frame: end bound cannot be unbounded preceding" + .to_owned(), + )) + } else if start_bound > end_bound { + Err(DataFusionError::Execution(format!( + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + start_bound, end_bound + ))) + } else { + let units = value.units.into(); + Ok(Self { + units, + start_bound, + end_bound, + }) + } + } +} + +impl Default for WindowFrame { + fn default() -> Self { + WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(None), + end_bound: WindowFrameBound::CurrentRow, + } + } +} + +/// There are five ways to describe starting and ending frame boundaries: +/// +/// 1. UNBOUNDED PRECEDING +/// 2. PRECEDING +/// 3. CURRENT ROW +/// 4. FOLLOWING +/// 5. UNBOUNDED FOLLOWING +/// +/// in this implementation we'll only allow to be u64 (i.e. no dynamic boundary) +#[derive(Debug, Clone, Copy, Eq)] +pub enum WindowFrameBound { + /// 1. UNBOUNDED PRECEDING + /// The frame boundary is the first row in the partition. + /// + /// 2. PRECEDING + /// must be a non-negative constant numeric expression. The boundary is a row that + /// is "units" prior to the current row. + Preceding(Option), + /// 3. The current row. + /// + /// For RANGE and GROUPS frame types, peers of the current row are also + /// included in the frame, unless specifically excluded by the EXCLUDE clause. + /// This is true regardless of whether CURRENT ROW is used as the starting or ending frame + /// boundary. + CurrentRow, + /// 4. This is the same as " PRECEDING" except that the boundary is units after the + /// current rather than before the current row. + /// + /// 5. UNBOUNDED FOLLOWING + /// The frame boundary is the last row in the partition. 
+ Following(Option), +} + +impl From for WindowFrameBound { + fn from(value: ast::WindowFrameBound) -> Self { + match value { + ast::WindowFrameBound::Preceding(v) => Self::Preceding(v), + ast::WindowFrameBound::Following(v) => Self::Following(v), + ast::WindowFrameBound::CurrentRow => Self::CurrentRow, + } + } +} + +impl fmt::Display for WindowFrameBound { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), + WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"), + WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"), + WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n), + WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n), + } + } +} + +impl PartialEq for WindowFrameBound { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl PartialOrd for WindowFrameBound { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for WindowFrameBound { + fn cmp(&self, other: &Self) -> Ordering { + self.get_rank().cmp(&other.get_rank()) + } +} + +impl WindowFrameBound { + /// get the rank of this window frame bound. + /// + /// the rank is a tuple of (u8, u64) because we'll firstly compare the kind and then the value + /// which requires special handling e.g. with preceding the larger the value the smaller the + /// rank and also for 0 preceding / following it is the same as current row + fn get_rank(&self) -> (u8, u64) { + match self { + WindowFrameBound::Preceding(None) => (0, 0), + WindowFrameBound::Following(None) => (4, 0), + WindowFrameBound::Preceding(Some(0)) + | WindowFrameBound::CurrentRow + | WindowFrameBound::Following(Some(0)) => (2, 0), + WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), + WindowFrameBound::Following(Some(v)) => (3, *v), + } + } +} + +/// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the +/// starting and ending boundaries of the frame are measured. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WindowFrameUnits { + /// The ROWS frame type means that the starting and ending boundaries for the frame are + /// determined by counting individual rows relative to the current row. + Rows, + /// The RANGE frame type requires that the ORDER BY clause of the window have exactly one + /// term. Call that term "X". With the RANGE frame type, the elements of the frame are + /// determined by computing the value of expression X for all rows in the partition and framing + /// those rows for which the value of X is within a certain range of the value of X for the + /// current row. + Range, + /// The GROUPS frame type means that the starting and ending boundaries are determine + /// by counting "groups" relative to the current group. A "group" is a set of rows that all have + /// equivalent values for all all terms of the window ORDER BY clause. 
+ Groups, +} + +impl fmt::Display for WindowFrameUnits { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match self { + WindowFrameUnits::Rows => "ROWS", + WindowFrameUnits::Range => "RANGE", + WindowFrameUnits::Groups => "GROUPS", + }) + } +} + +impl From for WindowFrameUnits { + fn from(value: ast::WindowFrameUnits) -> Self { + match value { + ast::WindowFrameUnits::Range => Self::Range, + ast::WindowFrameUnits::Groups => Self::Groups, + ast::WindowFrameUnits::Rows => Self::Rows, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_window_frame_creation() -> Result<()> { + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Following(None), + end_bound: None, + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Preceding(None), + end_bound: Some(ast::WindowFrameBound::Preceding(None)), + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Preceding(Some(1)), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))), + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() + ); + Ok(()) + } + + #[test] + fn test_eq() { + assert_eq!( + WindowFrameBound::Preceding(Some(0)), + WindowFrameBound::CurrentRow + ); + assert_eq!( + WindowFrameBound::CurrentRow, + WindowFrameBound::Following(Some(0)) + ); + assert_eq!( + WindowFrameBound::Following(Some(2)), + WindowFrameBound::Following(Some(2)) + ); + assert_eq!( + WindowFrameBound::Following(None), + WindowFrameBound::Following(None) + ); + assert_eq!( + WindowFrameBound::Preceding(Some(2)), + WindowFrameBound::Preceding(Some(2)) + ); + assert_eq!( + WindowFrameBound::Preceding(None), + WindowFrameBound::Preceding(None) + ); + } + + #[test] + fn test_ord() { + assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow); + // ! yes this is correct! 
+ assert!( + WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1)) + ); + assert!( + WindowFrameBound::Preceding(Some(u64::MAX)) + < WindowFrameBound::Preceding(Some(u64::MAX - 1)) + ); + assert!( + WindowFrameBound::Preceding(None) + < WindowFrameBound::Preceding(Some(1000000)) + ); + assert!( + WindowFrameBound::Preceding(None) + < WindowFrameBound::Preceding(Some(u64::MAX)) + ); + assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0))); + assert!( + WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1)) + ); + assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1))); + assert!( + WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2)) + ); + assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None)); + assert!( + WindowFrameBound::Following(Some(u64::MAX)) + < WindowFrameBound::Following(None) + ); + } +} diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index aa6b5a93f483..6bf7b776c8db 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1121,13 +1121,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // then, window function if let Some(window) = &function.over { - if window.partition_by.is_empty() && window.window_frame.is_none() { + if window.partition_by.is_empty() { let order_by = window .order_by .iter() .map(|e| self.order_by_to_sort_expr(e)) .into_iter() .collect::>>()?; + let window_frame = window + .window_frame + .as_ref() + .map(|window_frame| window_frame.clone().try_into()) + .transpose()?; let fun = window_functions::WindowFunction::from_str(&name); if let Ok(window_functions::WindowFunction::AggregateFunction( aggregate_fun, @@ -1140,6 +1145,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { args: self .aggregate_fn_to_expr(&aggregate_fun, function)?, order_by, + window_frame, }); } else if let Ok( window_functions::WindowFunction::BuiltInWindowFunction( @@ -1151,8 +1157,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun: window_functions::WindowFunction::BuiltInWindowFunction( window_fun, ), - args:self.function_args_to_expr(function)?, - order_by + args: self.function_args_to_expr(function)?, + order_by, + window_frame, }); } } @@ -2806,6 +2813,45 @@ mod tests { quick_test(sql, expected); } + #[test] + fn over_order_by_with_window_frame_double_end() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id DESC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + #[test] + fn over_order_by_with_window_frame_single_end() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id DESC NULLS FIRST\ + \n TableScan: orders 
projection=None"; + quick_test(sql, expected); + } + + #[test] + fn over_order_by_with_window_frame_single_end_groups() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id DESC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + /// psql result /// ``` /// QUERY PLAN diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 80a25d04468f..7a5dc0da1b53 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -239,6 +239,7 @@ where fun, args, order_by, + window_frame, } => Ok(Expr::WindowFunction { fun: fun.clone(), args: args @@ -249,6 +250,7 @@ where .iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, + window_frame: *window_frame, }), Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF { fun: fun.clone(), @@ -453,21 +455,25 @@ mod tests { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let max2 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let min3 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Min), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let sum4 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], order_by: vec![], + window_frame: None, }; // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; @@ -500,21 +506,25 @@ mod tests { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![age_asc.clone(), name_desc.clone()], + window_frame: None, }; let max2 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let min3 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Min), args: vec![col("name")], order_by: vec![age_asc.clone(), name_desc.clone()], + window_frame: None, }; let sum4 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], order_by: vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], + window_frame: None, }; // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; @@ -551,6 +561,7 @@ mod tests { nulls_first: true, }, ], + window_frame: None, }, Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), @@ -572,6 +583,7 @@ mod tests { nulls_first: true, }, ], + window_frame: None, }, ]; let expected = vec![ From 5773a03fe6f03f00d5aa78b219cc46009611cca7 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 7 Jun 2021 22:51:04 +0800 Subject: [PATCH 05/25] refactor sort exec stream and combine batches (#515) --- datafusion/src/physical_plan/common.rs | 96 +++++++++++++++++++++++--- datafusion/src/physical_plan/sort.rs | 86 +++++++++-------------- integration-tests/sqls/simple_sort.sql | 22 ++++++ 
integration-tests/test_psql_parity.py | 2 +- 4 files changed, 140 insertions(+), 66 deletions(-) create mode 100644 integration-tests/sqls/simple_sort.sql diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs index e60963bbb5b7..2482bfc0872c 100644 --- a/datafusion/src/physical_plan/common.rs +++ b/datafusion/src/physical_plan/common.rs @@ -17,24 +17,22 @@ //! Defines common code used in execution plans -use std::fs; -use std::fs::metadata; -use std::sync::Arc; -use std::task::{Context, Poll}; - +use super::{RecordBatchStream, SendableRecordBatchStream}; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::ExecutionPlan; +use arrow::compute::concat; use arrow::datatypes::SchemaRef; +use arrow::error::ArrowError; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use futures::channel::mpsc; use futures::{SinkExt, Stream, StreamExt, TryStreamExt}; +use std::fs; +use std::fs::metadata; +use std::sync::Arc; +use std::task::{Context, Poll}; use tokio::task::JoinHandle; -use crate::arrow::error::ArrowError; -use crate::error::{DataFusionError, Result}; -use crate::physical_plan::ExecutionPlan; - -use super::{RecordBatchStream, SendableRecordBatchStream}; - /// Stream of record batches pub struct SizedRecordBatchStream { schema: SchemaRef, @@ -83,6 +81,32 @@ pub async fn collect(stream: SendableRecordBatchStream) -> Result ArrowResult> { + if batches.is_empty() { + Ok(None) + } else { + let columns = schema + .fields() + .iter() + .enumerate() + .map(|(i, _)| { + concat( + &batches + .iter() + .map(|batch| batch.column(i).as_ref()) + .collect::>(), + ) + }) + .collect::>>()?; + Ok(Some(RecordBatch::try_new(schema.clone(), columns)?)) + } +} + /// Recursively builds a list of files in a directory with a given extension pub fn build_file_list(dir: &str, ext: &str) -> Result> { let mut filenames: Vec = Vec::new(); @@ -144,3 +168,53 @@ pub(crate) fn spawn_execution( } }) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::{ + array::{Float32Array, Float64Array}, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, + }; + + #[test] + fn test_combine_batches_empty() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("f32", DataType::Float32, false), + Field::new("f64", DataType::Float64, false), + ])); + let result = combine_batches(&[], schema)?; + assert!(result.is_none()); + Ok(()) + } + + #[test] + fn test_combine_batches() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("f32", DataType::Float32, false), + Field::new("f64", DataType::Float64, false), + ])); + + let batch_count = 1000; + let batch_size = 10; + let batches = (0..batch_count) + .map(|i| { + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Float32Array::from(vec![i as f32; batch_size])), + Arc::new(Float64Array::from(vec![i as f64; batch_size])), + ], + ) + .unwrap() + }) + .collect::>(); + + let result = combine_batches(&batches, schema)?; + assert!(result.is_some()); + let result = result.unwrap(); + assert_eq!(batch_count * batch_size, result.num_rows()); + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs index c5b838c6e84b..7747030d8a93 100644 --- a/datafusion/src/physical_plan/sort.rs +++ b/datafusion/src/physical_plan/sort.rs @@ -17,32 +17,28 @@ //! 
Defines the SORT plan -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::time::Instant; - -use async_trait::async_trait; -use futures::stream::Stream; -use futures::Future; -use hashbrown::HashMap; - -use pin_project_lite::pin_project; - -pub use arrow::compute::SortOptions; -use arrow::compute::{concat, lexsort_to_indices, take, SortColumn, TakeOptions}; -use arrow::datatypes::SchemaRef; -use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; -use arrow::{array::ArrayRef, error::ArrowError}; - use super::{RecordBatchStream, SendableRecordBatchStream}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::{ common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, SQLMetric, }; +pub use arrow::compute::SortOptions; +use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions}; +use arrow::datatypes::SchemaRef; +use arrow::error::Result as ArrowResult; +use arrow::record_batch::RecordBatch; +use arrow::{array::ArrayRef, error::ArrowError}; +use async_trait::async_trait; +use futures::stream::Stream; +use futures::Future; +use hashbrown::HashMap; +use pin_project_lite::pin_project; +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Instant; /// Sort execution plan #[derive(Debug)] @@ -190,47 +186,25 @@ impl ExecutionPlan for SortExec { } } -fn sort_batches( - batches: &[RecordBatch], - schema: &SchemaRef, +fn sort_batch( + batch: RecordBatch, + schema: SchemaRef, expr: &[PhysicalSortExpr], -) -> ArrowResult> { - if batches.is_empty() { - return Ok(None); - } - // combine all record batches into one for each column - let combined_batch = RecordBatch::try_new( - schema.clone(), - schema - .fields() - .iter() - .enumerate() - .map(|(i, _)| { - concat( - &batches - .iter() - .map(|batch| batch.column(i).as_ref()) - .collect::>(), - ) - }) - .collect::>>()?, - )?; - - // sort combined record batch +) -> ArrowResult { // TODO: pushup the limit expression to sort let indices = lexsort_to_indices( &expr .iter() - .map(|e| e.evaluate_to_sort_column(&combined_batch)) + .map(|e| e.evaluate_to_sort_column(&batch)) .collect::>>() .map_err(DataFusionError::into_arrow_external_error)?, None, )?; // reorder all rows based on sorted indices - let sorted_batch = RecordBatch::try_new( - schema.clone(), - combined_batch + RecordBatch::try_new( + schema, + batch .columns() .iter() .map(|column| { @@ -245,8 +219,7 @@ fn sort_batches( ) }) .collect::>>()?, - ); - sorted_batch.map(Some) + ) } pin_project! { @@ -277,9 +250,14 @@ impl SortStream { .map_err(DataFusionError::into_arrow_external_error) .and_then(move |batches| { let now = Instant::now(); - let result = sort_batches(&batches, &schema, &expr); + // combine all record batches into one for each column + let combined = common::combine_batches(&batches, schema.clone())?; + // sort combined record batch + let result = combined + .map(|batch| sort_batch(batch, schema, &expr)) + .transpose()?; sort_time.add(now.elapsed().as_nanos() as usize); - result + Ok(result) }); tx.send(sorted_batch) diff --git a/integration-tests/sqls/simple_sort.sql b/integration-tests/sqls/simple_sort.sql new file mode 100644 index 000000000000..50fb12dfdc70 --- /dev/null +++ b/integration-tests/sqls/simple_sort.sql @@ -0,0 +1,22 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. 
See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +SELECT + c2, + c3, + c10 +FROM test +ORDER BY c2 ASC, c3 DESC, c10; diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py index 5bd308180e59..51861c583f8a 100644 --- a/integration-tests/test_psql_parity.py +++ b/integration-tests/test_psql_parity.py @@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase): def test_parity(self): root = Path(os.path.dirname(__file__)) / "sqls" files = set(root.glob("*.sql")) - self.assertEqual(len(files), 5, msg="tests are missed") + self.assertEqual(len(files), 6, msg="tests are missed") for fname in files: with self.subTest(fname=fname): datafusion_output = pd.read_csv( From 63accf8630e734cd96ba11baa9a89b437703acc5 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 7 Jun 2021 22:56:39 +0800 Subject: [PATCH 06/25] closing up type checks (#518) --- .../core/src/serde/logical_plan/from_proto.rs | 6 +- .../core/src/serde/logical_plan/to_proto.rs | 17 +- datafusion/src/logical_plan/expr.rs | 4 +- datafusion/src/logical_plan/mod.rs | 1 + datafusion/src/logical_plan/window_frames.rs | 337 ++++++++++++++++++ 5 files changed, 351 insertions(+), 14 deletions(-) create mode 100644 datafusion/src/logical_plan/window_frames.rs diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 4a198174a2ba..36a37a1e472c 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -21,15 +21,15 @@ use crate::error::BallistaError; use crate::serde::{proto_error, protobuf}; use crate::{convert_box_required, convert_required}; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use datafusion::logical_plan::window_frames::{ + WindowFrame, WindowFrameBound, WindowFrameUnits, +}; use datafusion::logical_plan::{ abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin, sqrt, tan, trunc, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::csv::CsvReadOptions; -use datafusion::physical_plan::window_frames::{ - WindowFrame, WindowFrameBound, WindowFrameUnits, -}; use datafusion::physical_plan::window_functions::BuiltInWindowFunction; use datafusion::scalar::ScalarValue; use protobuf::logical_plan_node::LogicalPlanType; diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 56270030b59f..fb1383daab3a 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -19,22 +19,17 @@ //! buffer format, allowing DataFusion logical plans to be serialized and transmitted between //! processes. 
-use std::{ - boxed, - convert::{TryFrom, TryInto}, -}; - use super::super::proto_error; use crate::datasource::DfTableAdapter; use crate::serde::{protobuf, BallistaError}; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::datasource::CsvFile; -use datafusion::logical_plan::{Expr, JoinType, LogicalPlan}; +use datafusion::logical_plan::{ + window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits}, + Expr, JoinType, LogicalPlan, +}; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::functions::BuiltinScalarFunction; -use datafusion::physical_plan::window_frames::{ - WindowFrame, WindowFrameBound, WindowFrameUnits, -}; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; @@ -43,6 +38,10 @@ use protobuf::{ arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, PrimitiveScalarType, ScalarListValue, ScalarType, }; +use std::{ + boxed, + convert::{TryFrom, TryInto}, +}; impl protobuf::IntervalUnit { pub fn from_arrow_interval_unit(interval_unit: &IntervalUnit) -> Self { diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index bbc6ffabe928..d5c92dbd2143 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -20,10 +20,10 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; -use crate::logical_plan::{DFField, DFSchema}; +use crate::logical_plan::{window_frames, DFField, DFSchema}; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, - window_frames, window_functions, + window_functions, }; use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index f948770e6437..4a39e114d53f 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -29,6 +29,7 @@ mod extension; mod operators; mod plan; mod registry; +pub mod window_frames; pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; diff --git a/datafusion/src/logical_plan/window_frames.rs b/datafusion/src/logical_plan/window_frames.rs new file mode 100644 index 000000000000..f0be5a221fbf --- /dev/null +++ b/datafusion/src/logical_plan/window_frames.rs @@ -0,0 +1,337 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Window frame +//! +//! The frame-spec determines which output rows are read by an aggregate window function. The frame-spec consists of four parts: +//! - A frame type - either ROWS, RANGE or GROUPS, +//! 
- A starting frame boundary, +//! - An ending frame boundary, +//! - An EXCLUDE clause. + +use crate::error::{DataFusionError, Result}; +use sqlparser::ast; +use std::cmp::Ordering; +use std::convert::{From, TryFrom}; +use std::fmt; + +/// The frame-spec determines which output rows are read by an aggregate window function. +/// +/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the +/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to +/// CURRENT ROW. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct WindowFrame { + /// A frame type - either ROWS, RANGE or GROUPS + pub units: WindowFrameUnits, + /// A starting frame boundary + pub start_bound: WindowFrameBound, + /// An ending frame boundary + pub end_bound: WindowFrameBound, +} + +impl fmt::Display for WindowFrame { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} BETWEEN {} AND {}", + self.units, self.start_bound, self.end_bound + )?; + Ok(()) + } +} + +impl TryFrom for WindowFrame { + type Error = DataFusionError; + + fn try_from(value: ast::WindowFrame) -> Result { + let start_bound = value.start_bound.into(); + let end_bound = value + .end_bound + .map(WindowFrameBound::from) + .unwrap_or(WindowFrameBound::CurrentRow); + + if let WindowFrameBound::Following(None) = start_bound { + Err(DataFusionError::Execution( + "Invalid window frame: start bound cannot be unbounded following" + .to_owned(), + )) + } else if let WindowFrameBound::Preceding(None) = end_bound { + Err(DataFusionError::Execution( + "Invalid window frame: end bound cannot be unbounded preceding" + .to_owned(), + )) + } else if start_bound > end_bound { + Err(DataFusionError::Execution(format!( + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + start_bound, end_bound + ))) + } else { + let units = value.units.into(); + Ok(Self { + units, + start_bound, + end_bound, + }) + } + } +} + +impl Default for WindowFrame { + fn default() -> Self { + WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(None), + end_bound: WindowFrameBound::CurrentRow, + } + } +} + +/// There are five ways to describe starting and ending frame boundaries: +/// +/// 1. UNBOUNDED PRECEDING +/// 2. PRECEDING +/// 3. CURRENT ROW +/// 4. FOLLOWING +/// 5. UNBOUNDED FOLLOWING +/// +/// in this implementation we'll only allow to be u64 (i.e. no dynamic boundary) +#[derive(Debug, Clone, Copy, Eq)] +pub enum WindowFrameBound { + /// 1. UNBOUNDED PRECEDING + /// The frame boundary is the first row in the partition. + /// + /// 2. PRECEDING + /// must be a non-negative constant numeric expression. The boundary is a row that + /// is "units" prior to the current row. + Preceding(Option), + /// 3. The current row. + /// + /// For RANGE and GROUPS frame types, peers of the current row are also + /// included in the frame, unless specifically excluded by the EXCLUDE clause. + /// This is true regardless of whether CURRENT ROW is used as the starting or ending frame + /// boundary. + CurrentRow, + /// 4. This is the same as " PRECEDING" except that the boundary is units after the + /// current rather than before the current row. + /// + /// 5. UNBOUNDED FOLLOWING + /// The frame boundary is the last row in the partition. 
+ Following(Option), +} + +impl From for WindowFrameBound { + fn from(value: ast::WindowFrameBound) -> Self { + match value { + ast::WindowFrameBound::Preceding(v) => Self::Preceding(v), + ast::WindowFrameBound::Following(v) => Self::Following(v), + ast::WindowFrameBound::CurrentRow => Self::CurrentRow, + } + } +} + +impl fmt::Display for WindowFrameBound { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), + WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"), + WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"), + WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n), + WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n), + } + } +} + +impl PartialEq for WindowFrameBound { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl PartialOrd for WindowFrameBound { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for WindowFrameBound { + fn cmp(&self, other: &Self) -> Ordering { + self.get_rank().cmp(&other.get_rank()) + } +} + +impl WindowFrameBound { + /// get the rank of this window frame bound. + /// + /// the rank is a tuple of (u8, u64) because we'll firstly compare the kind and then the value + /// which requires special handling e.g. with preceding the larger the value the smaller the + /// rank and also for 0 preceding / following it is the same as current row + fn get_rank(&self) -> (u8, u64) { + match self { + WindowFrameBound::Preceding(None) => (0, 0), + WindowFrameBound::Following(None) => (4, 0), + WindowFrameBound::Preceding(Some(0)) + | WindowFrameBound::CurrentRow + | WindowFrameBound::Following(Some(0)) => (2, 0), + WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), + WindowFrameBound::Following(Some(v)) => (3, *v), + } + } +} + +/// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the +/// starting and ending boundaries of the frame are measured. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WindowFrameUnits { + /// The ROWS frame type means that the starting and ending boundaries for the frame are + /// determined by counting individual rows relative to the current row. + Rows, + /// The RANGE frame type requires that the ORDER BY clause of the window have exactly one + /// term. Call that term "X". With the RANGE frame type, the elements of the frame are + /// determined by computing the value of expression X for all rows in the partition and framing + /// those rows for which the value of X is within a certain range of the value of X for the + /// current row. + Range, + /// The GROUPS frame type means that the starting and ending boundaries are determine + /// by counting "groups" relative to the current group. A "group" is a set of rows that all have + /// equivalent values for all all terms of the window ORDER BY clause. 
+ Groups, +} + +impl fmt::Display for WindowFrameUnits { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match self { + WindowFrameUnits::Rows => "ROWS", + WindowFrameUnits::Range => "RANGE", + WindowFrameUnits::Groups => "GROUPS", + }) + } +} + +impl From for WindowFrameUnits { + fn from(value: ast::WindowFrameUnits) -> Self { + match value { + ast::WindowFrameUnits::Range => Self::Range, + ast::WindowFrameUnits::Groups => Self::Groups, + ast::WindowFrameUnits::Rows => Self::Rows, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_window_frame_creation() -> Result<()> { + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Following(None), + end_bound: None, + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Preceding(None), + end_bound: Some(ast::WindowFrameBound::Preceding(None)), + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Preceding(Some(1)), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))), + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() + ); + Ok(()) + } + + #[test] + fn test_eq() { + assert_eq!( + WindowFrameBound::Preceding(Some(0)), + WindowFrameBound::CurrentRow + ); + assert_eq!( + WindowFrameBound::CurrentRow, + WindowFrameBound::Following(Some(0)) + ); + assert_eq!( + WindowFrameBound::Following(Some(2)), + WindowFrameBound::Following(Some(2)) + ); + assert_eq!( + WindowFrameBound::Following(None), + WindowFrameBound::Following(None) + ); + assert_eq!( + WindowFrameBound::Preceding(Some(2)), + WindowFrameBound::Preceding(Some(2)) + ); + assert_eq!( + WindowFrameBound::Preceding(None), + WindowFrameBound::Preceding(None) + ); + } + + #[test] + fn test_ord() { + assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow); + // ! yes this is correct! 
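+        // A larger PRECEDING offset is further from the current row, so it compares
+        // as smaller: get_rank() maps Preceding(Some(v)) to (1, u64::MAX - v), which
+        // is what the assertions below rely on.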
+ assert!( + WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1)) + ); + assert!( + WindowFrameBound::Preceding(Some(u64::MAX)) + < WindowFrameBound::Preceding(Some(u64::MAX - 1)) + ); + assert!( + WindowFrameBound::Preceding(None) + < WindowFrameBound::Preceding(Some(1000000)) + ); + assert!( + WindowFrameBound::Preceding(None) + < WindowFrameBound::Preceding(Some(u64::MAX)) + ); + assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0))); + assert!( + WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1)) + ); + assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1))); + assert!( + WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2)) + ); + assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None)); + assert!( + WindowFrameBound::Following(Some(u64::MAX)) + < WindowFrameBound::Following(None) + ); + } +} From 2f73e795d3ae68638d6509bfa02388bfa3727381 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Tue, 8 Jun 2021 00:43:09 +0800 Subject: [PATCH 07/25] Refactor window aggregation, simplify batch processing logic (#516) * refactor sort exec stream and combine batches * refactor async function --- datafusion/src/physical_plan/sort.rs | 1 - datafusion/src/physical_plan/windows.rs | 149 +++++++++++------------- 2 files changed, 71 insertions(+), 79 deletions(-) diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs index 7747030d8a93..437519a7d2a2 100644 --- a/datafusion/src/physical_plan/sort.rs +++ b/datafusion/src/physical_plan/sort.rs @@ -241,7 +241,6 @@ impl SortStream { sort_time: Arc, ) -> Self { let (tx, rx) = futures::channel::oneshot::channel(); - let schema = input.schema(); tokio::spawn(async move { let schema = input.schema(); diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index 659d2183819d..7eb14943facf 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -19,7 +19,7 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ - aggregates, + aggregates, common, expressions::{Literal, NthValue, RowNumber}, type_coercion::coerce, window_functions::signature_for_built_in, @@ -29,20 +29,18 @@ use crate::physical_plan::{ RecordBatchStream, SendableRecordBatchStream, WindowAccumulator, WindowExpr, }; use crate::scalar::ScalarValue; -use arrow::compute::concat; use arrow::{ - array::{Array, ArrayRef}, + array::ArrayRef, datatypes::{Field, Schema, SchemaRef}, error::{ArrowError, Result as ArrowResult}, record_batch::RecordBatch, }; use async_trait::async_trait; -use futures::stream::{Stream, StreamExt}; +use futures::stream::Stream; use futures::Future; use pin_project_lite::pin_project; use std::any::Any; use std::convert::TryInto; -use std::iter; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -339,22 +337,15 @@ fn window_aggregate_batch( window_accumulators: &mut [WindowAccumulatorItem], expressions: &[Vec>], ) -> Result>> { - // 1.1 iterate accumulators and respective expressions together - // 1.2 evaluate expressions - // 1.3 update / merge window accumulators with the expressions' values - - // 1.1 window_accumulators .iter_mut() .zip(expressions) .map(|(window_acc, expr)| { - // 1.2 let values = &expr .iter() - .map(|e| e.evaluate(batch)) + .map(|e| e.evaluate(&batch)) .map(|r| r.map(|v| v.into_array(batch.num_rows()))) .collect::>>()?; - window_acc.scan_batch(batch.num_rows(), 
values) }) .into_iter() @@ -380,60 +371,50 @@ fn create_window_accumulators( .collect::>>() } -async fn compute_window_aggregate( - schema: SchemaRef, +/// Compute the window aggregate columns +/// +/// 1. get a list of window accumulators +/// 2. evaluate the args +/// 3. scan args with window functions +/// 4. concat with final aggregations +/// +/// FIXME so far this fn does not support: +/// 1. partition by +/// 2. order by +/// 3. window frame +/// +/// which will require further work: +/// 1. inter-partition order by using vec partition-point (https://github.com/apache/arrow-datafusion/issues/360) +/// 2. inter-partition parallelism using one-shot channel (https://github.com/apache/arrow-datafusion/issues/299) +/// 3. convert aggregation based window functions to be self-contain so that: (https://github.com/apache/arrow-datafusion/issues/361) +/// a. some can be grow-only window-accumulating +/// b. some can be grow-and-shrink window-accumulating +/// c. some can be based on segment tree +fn compute_window_aggregates( window_expr: Vec>, - mut input: SendableRecordBatchStream, -) -> ArrowResult { - let mut window_accumulators = create_window_accumulators(&window_expr) - .map_err(DataFusionError::into_arrow_external_error)?; - - let expressions = window_expressions(&window_expr) - .map_err(DataFusionError::into_arrow_external_error)?; - - let expressions = Arc::new(expressions); - - // TODO each element shall have some size hint - let mut accumulator: Vec> = - iter::repeat(vec![]).take(window_expr.len()).collect(); - - let mut original_batches: Vec = vec![]; - - let mut total_num_rows = 0; - - while let Some(batch) = input.next().await { - let batch = batch?; - total_num_rows += batch.num_rows(); - original_batches.push(batch.clone()); - - let batch_aggregated = - window_aggregate_batch(&batch, &mut window_accumulators, &expressions) - .map_err(DataFusionError::into_arrow_external_error)?; - accumulator.iter_mut().zip(batch_aggregated).for_each( - |(acc_for_window, window_batch)| { - if let Some(data) = window_batch { - acc_for_window.push(data); - } - }, - ); + batch: &RecordBatch, +) -> Result> { + let mut window_accumulators = create_window_accumulators(&window_expr)?; + let expressions = Arc::new(window_expressions(&window_expr)?); + let num_rows = batch.num_rows(); + let window_aggregates = + window_aggregate_batch(batch, &mut window_accumulators, &expressions)?; + let final_aggregates = finalize_window_aggregation(&window_accumulators)?; + + // both must equal to window_expr.len() + if window_aggregates.len() != final_aggregates.len() { + return Err(DataFusionError::Internal( + "Impossibly got len mismatch".to_owned(), + )); } - let aggregated_mapped = finalize_window_aggregation(&window_accumulators) - .map_err(DataFusionError::into_arrow_external_error)?; - - let mut columns: Vec = accumulator + window_aggregates .iter() - .zip(aggregated_mapped) - .map(|(acc, agg)| { - Ok(match (acc, agg) { - (acc, Some(scalar_value)) if acc.is_empty() => { - scalar_value.to_array_of_size(total_num_rows) - } - (acc, None) if !acc.is_empty() => { - let vec_array: Vec<&dyn Array> = - acc.iter().map(|arc| arc.as_ref()).collect(); - concat(&vec_array)? 
- } + .zip(final_aggregates) + .map(|(wa, fa)| { + Ok(match (wa, fa) { + (None, Some(fa)) => fa.to_array_of_size(num_rows), + (Some(wa), None) if wa.len() == num_rows => wa.clone(), _ => { return Err(DataFusionError::Execution( "Invalid window function behavior".to_owned(), @@ -441,20 +422,7 @@ async fn compute_window_aggregate( } }) }) - .collect::>>() - .map_err(DataFusionError::into_arrow_external_error)?; - - for i in 0..(schema.fields().len() - window_expr.len()) { - let col = concat( - &original_batches - .iter() - .map(|batch| batch.column(i).as_ref()) - .collect::>(), - )?; - columns.push(col); - } - - RecordBatch::try_new(schema.clone(), columns) + .collect() } impl WindowAggStream { @@ -467,7 +435,8 @@ impl WindowAggStream { let (tx, rx) = futures::channel::oneshot::channel(); let schema_clone = schema.clone(); tokio::spawn(async move { - let result = compute_window_aggregate(schema_clone, window_expr, input).await; + let schema = schema_clone.clone(); + let result = WindowAggStream::process(input, window_expr, schema).await; tx.send(result) }); @@ -477,6 +446,30 @@ impl WindowAggStream { schema, } } + + async fn process( + input: SendableRecordBatchStream, + window_expr: Vec>, + schema: SchemaRef, + ) -> ArrowResult { + let input_schema = input.schema(); + let batches = common::collect(input) + .await + .map_err(DataFusionError::into_arrow_external_error)?; + let batch = common::combine_batches(&batches, input_schema.clone())?; + if let Some(batch) = batch { + // calculate window cols + let mut columns = compute_window_aggregates(window_expr, &batch) + .map_err(DataFusionError::into_arrow_external_error)?; + // combine with the original cols + // note the setup of window aggregates is that they newly calculated window + // expressions are always prepended to the columns + columns.extend_from_slice(batch.columns()); + RecordBatch::try_new(schema, columns) + } else { + Ok(RecordBatch::new_empty(schema)) + } + } } impl Stream for WindowAggStream { From e39f3116684b836829a2e02c9013d8a84d87b82e Mon Sep 17 00:00:00 2001 From: Rich Date: Tue, 8 Jun 2021 23:26:03 +0800 Subject: [PATCH 08/25] 110 support group by positions (#519) * 110 support group by positions * try resolve positions via array, not map * Add comment for i64 and simplify the pattern match * combine match and if condition, add more test cases * replace '0 as i64' with 0_i64 --- datafusion/src/sql/planner.rs | 42 +++++++++++++++++++++++++++++++---- datafusion/src/sql/utils.rs | 22 ++++++++++++++++++ 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 6bf7b776c8db..7df0068c5f54 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -56,6 +56,7 @@ use super::{ can_columns_satisfy_exprs, expand_wildcard, expr_as_column_expr, extract_aliases, find_aggregate_exprs, find_column_exprs, find_window_exprs, group_window_expr_by_sort_keys, rebase_expr, resolve_aliases_to_exprs, + resolve_positions_to_exprs, }, }; @@ -582,15 +583,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // All of the aggregate expressions (deduplicated). 
let aggr_exprs = find_aggregate_exprs(&aggr_expr_haystack); + let alias_map = extract_aliases(&select_exprs); let group_by_exprs = select .group_by .iter() .map(|e| { let group_by_expr = self.sql_expr_to_logical_expr(e)?; - let group_by_expr = resolve_aliases_to_exprs( - &group_by_expr, - &extract_aliases(&select_exprs), - )?; + let group_by_expr = resolve_aliases_to_exprs(&group_by_expr, &alias_map)?; + let group_by_expr = + resolve_positions_to_exprs(&group_by_expr, &select_exprs)?; self.validate_schema_satisfies_exprs( plan.schema(), &[group_by_expr.clone()], @@ -2326,6 +2327,39 @@ mod tests { ); } + #[test] + fn select_simple_aggregate_with_groupby_can_use_positions() { + quick_test( + "SELECT state, age AS b, COUNT(1) FROM person GROUP BY 1, 2", + "Projection: #state, #age AS b, #COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[#state, #age]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person projection=None", + ); + quick_test( + "SELECT state, age AS b, COUNT(1) FROM person GROUP BY 2, 1", + "Projection: #state, #age AS b, #COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[#age, #state]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person projection=None", + ); + } + + #[test] + fn select_simple_aggregate_with_groupby_position_out_of_range() { + let sql = "SELECT state, MIN(age) FROM person GROUP BY 0"; + let err = logical_plan(sql).expect_err("query should have failed"); + assert_eq!( + "Plan(\"Projection references non-aggregate values\")", + format!("{:?}", err) + ); + + let sql2 = "SELECT state, MIN(age) FROM person GROUP BY 5"; + let err2 = logical_plan(sql2).expect_err("query should have failed"); + assert_eq!( + "Plan(\"Projection references non-aggregate values\")", + format!("{:?}", err2) + ); + } + #[test] fn select_simple_aggregate_with_groupby_can_use_alias() { quick_test( diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 7a5dc0da1b53..848fb3ee31fc 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -18,6 +18,7 @@ //! SQL Utility Functions use crate::logical_plan::{DFSchema, Expr, LogicalPlan}; +use crate::scalar::ScalarValue; use crate::{ error::{DataFusionError, Result}, logical_plan::{ExpressionVisitor, Recursion}, @@ -392,6 +393,27 @@ pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap { .collect::>() } +pub(crate) fn resolve_positions_to_exprs( + expr: &Expr, + select_exprs: &[Expr], +) -> Result { + match expr { + // sql_expr_to_logical_expr maps number to i64 + // https://github.com/apache/arrow-datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887 + Expr::Literal(ScalarValue::Int64(Some(position))) + if position > &0_i64 && position <= &(select_exprs.len() as i64) => + { + let index = (position - 1) as usize; + let select_expr = &select_exprs[index]; + match select_expr { + Expr::Alias(nested_expr, _alias_name) => Ok(*nested_expr.clone()), + _ => Ok(select_expr.clone()), + } + } + _ => Ok(expr.clone()), + } +} + /// Rebuilds an `Expr` with columns that refer to aliases replaced by the /// alias' underlying `Expr`. pub(crate) fn resolve_aliases_to_exprs( From 8495f95d7b510109c70cf2b4b606ba020bffd27a Mon Sep 17 00:00:00 2001 From: Javier Goday Date: Tue, 8 Jun 2021 23:32:19 +0200 Subject: [PATCH 09/25] Wrong aggregation arguments error. 
(#505) * Fix aggregate fn with invalid column * Fix error message * Fix error message * fix clippy * fix message and test * avoid unwrap in test_aggregation_with_bad_arguments * Update datafusion/tests/sql.rs Co-authored-by: Andrew Lamb * Fix test_aggregation_with_bad_arguments Co-authored-by: Andrew Lamb --- datafusion/src/physical_plan/aggregates.rs | 9 ++++++++- datafusion/tests/sql.rs | 12 ++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 3607f29debba..60025a316228 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -113,7 +113,14 @@ pub fn create_aggregate_expr( name: String, ) -> Result> { // coerce - let arg = coerce(args, input_schema, &signature(fun))?[0].clone(); + let arg = coerce(args, input_schema, &signature(fun))?; + if arg.is_empty() { + return Err(DataFusionError::Plan(format!( + "Invalid or wrong number of arguments passed to aggregate: '{}'", + name, + ))); + } + let arg = arg[0].clone(); let arg_types = args .iter() diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index d77671e7f4ff..5ce1884049d8 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -3437,3 +3437,15 @@ async fn test_physical_plan_display_indent_multi_children() { expected, actual ); } + +#[tokio::test] +async fn test_aggregation_with_bad_arguments() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; + let logical_plan = ctx.create_logical_plan(&sql)?; + let physical_plan = ctx.create_physical_plan(&logical_plan); + let err = physical_plan.unwrap_err(); + assert_eq!(err.to_string(), "Error during planning: Invalid or wrong number of arguments passed to aggregate: 'COUNT(DISTINCT )'"); + Ok(()) +} From 42f908e2b5d2bd2abd4e396ade1a94fb0ff28ba1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 9 Jun 2021 20:23:23 +0200 Subject: [PATCH 10/25] Define the unittests using pytest (#493) * Use pytest * Formatting * Update GHA conf * Remove TODO note * Format * Test requirements file * Update workflow file * Merge requirements file * Update workflow file --- .github/workflows/python_test.yaml | 2 +- dev/release/rat_exclude_files.txt | 1 + python/requirements.in | 1 + python/requirements.txt | 47 ++-- python/tests/generic.py | 51 ++-- python/tests/test_df.py | 136 +++++----- python/tests/test_sql.py | 416 +++++++++++------------------ python/tests/test_udaf.py | 86 +++--- 8 files changed, 328 insertions(+), 412 deletions(-) diff --git a/.github/workflows/python_test.yaml b/.github/workflows/python_test.yaml index 13516ff699da..e689396b5dcd 100644 --- a/.github/workflows/python_test.yaml +++ b/.github/workflows/python_test.yaml @@ -53,7 +53,7 @@ jobs: pip install -r requirements.txt maturin develop - python -m unittest discover tests + pytest -v . 
env: CARGO_HOME: "/home/runner/.cargo" CARGO_TARGET_DIR: "/home/runner/target" diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 6126699bbc1f..96beccd0af81 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -105,3 +105,4 @@ benchmarks/queries/q*.sql ballista/rust/scheduler/testdata/* ballista/ui/scheduler/yarn.lock python/rust-toolchain +python/requirements*.txt diff --git a/python/requirements.in b/python/requirements.in index 3ef9f18966d4..4ff7f4ee618b 100644 --- a/python/requirements.in +++ b/python/requirements.in @@ -17,3 +17,4 @@ maturin toml pyarrow +pytest diff --git a/python/requirements.txt b/python/requirements.txt index ff02b80cf6fc..f7ede1ebd58e 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,25 +1,17 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. # # This file is autogenerated by pip-compile # To update, run: # -# pip-compile --generate-hashes +# pip-compile --generate-hashes requirements.in # +attrs==21.2.0 \ + --hash=sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1 \ + --hash=sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb + # via pytest +iniconfig==1.1.1 \ + --hash=sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3 \ + --hash=sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32 + # via pytest maturin==0.10.6 \ --hash=sha256:0e81496f70a4805e6ea7dda7b0425246c111ccb119a2e22c64abeff131f4dd21 \ --hash=sha256:3b5d5429bc05a816824420d99973f0cab39d8e274f6c3647bfd9afd95a030304 \ @@ -59,6 +51,18 @@ numpy==1.20.3 \ --hash=sha256:f1452578d0516283c87608a5a5548b0cdde15b99650efdfd85182102ef7a7c17 \ --hash=sha256:f39a995e47cb8649673cfa0579fbdd1cdd33ea497d1728a6cb194d6252268e48 # via pyarrow +packaging==20.9 \ + --hash=sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5 \ + --hash=sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a + # via pytest +pluggy==0.13.1 \ + --hash=sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0 \ + --hash=sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d + # via pytest +py==1.10.0 \ + --hash=sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3 \ + --hash=sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a + # via pytest pyarrow==4.0.1 \ --hash=sha256:04be0f7cb9090bd029b5b53bed628548fef569e5d0b5c6cd7f6d0106dbbc782d \ --hash=sha256:0fde9c7a3d5d37f3fe5d18c4ed015e8f585b68b26d72a10d7012cad61afe43ff \ @@ -86,9 +90,18 @@ pyarrow==4.0.1 \ --hash=sha256:fa7b165cfa97158c1e6d15c68428317b4f4ae786d1dc2dbab43f1328c1eb43aa \ --hash=sha256:fe976695318560a97c6d31bba828eeca28c44c6f6401005e54ba476a28ac0a10 # via -r requirements.in 
+pyparsing==2.4.7 \ + --hash=sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1 \ + --hash=sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b + # via packaging +pytest==6.2.4 \ + --hash=sha256:50bcad0a0b9c5a72c8e4e7c9855a3ad496ca6a881a3641b4260605450772c54b \ + --hash=sha256:91ef2131a9bd6be8f76f1f08eac5c5317221d6ad1e143ae03894b862e8976890 + # via -r requirements.in toml==0.10.2 \ --hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \ --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f # via # -r requirements.in # maturin + # pytest diff --git a/python/tests/generic.py b/python/tests/generic.py index 267d6f656ce0..e61542e6ab37 100644 --- a/python/tests/generic.py +++ b/python/tests/generic.py @@ -16,24 +16,30 @@ # under the License. import datetime -import numpy -import pyarrow + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq # used to write parquet files -import pyarrow.parquet def data(): - data = numpy.concatenate( - [numpy.random.normal(0, 0.01, size=50), numpy.random.normal(50, 0.01, size=50)] + np.random.seed(1) + data = np.concatenate( + [ + np.random.normal(0, 0.01, size=50), + np.random.normal(50, 0.01, size=50), + ] ) - return pyarrow.array(data) + return pa.array(data) def data_with_nans(): - data = numpy.random.normal(0, 0.01, size=50) - mask = numpy.random.randint(0, 2, size=50) - data[mask == 0] = numpy.NaN + np.random.seed(0) + data = np.random.normal(0, 0.01, size=50) + mask = np.random.randint(0, 2, size=50) + data[mask == 0] = np.NaN return data @@ -43,8 +49,19 @@ def data_datetime(f): datetime.datetime.now() - datetime.timedelta(days=1), datetime.datetime.now() + datetime.timedelta(days=1), ] - return pyarrow.array( - data, type=pyarrow.timestamp(f), mask=numpy.array([False, True, False]) + return pa.array( + data, type=pa.timestamp(f), mask=np.array([False, True, False]) + ) + + +def data_date32(): + data = [ + datetime.date(2000, 1, 1), + datetime.date(1980, 1, 1), + datetime.date(2030, 1, 1), + ] + return pa.array( + data, type=pa.date32(), mask=np.array([False, True, False]) ) @@ -54,16 +71,16 @@ def data_timedelta(f): datetime.timedelta(days=1), datetime.timedelta(seconds=1), ] - return pyarrow.array( - data, type=pyarrow.duration(f), mask=numpy.array([False, True, False]) + return pa.array( + data, type=pa.duration(f), mask=np.array([False, True, False]) ) def data_binary_other(): - return numpy.array([1, 0, 0], dtype="u4") + return np.array([1, 0, 0], dtype="u4") def write_parquet(path, data): - table = pyarrow.Table.from_arrays([data], names=["a"]) - pyarrow.parquet.write_table(table, path) - return path + table = pa.Table.from_arrays([data], names=["a"]) + pq.write_table(table, path) + return str(path) diff --git a/python/tests/test_df.py b/python/tests/test_df.py index fdafdfa7f509..5b6cbddbd74b 100644 --- a/python/tests/test_df.py +++ b/python/tests/test_df.py @@ -15,100 +15,98 @@ # specific language governing permissions and limitations # under the License. 
-import unittest - import pyarrow as pa -import datafusion +import pytest +from datafusion import ExecutionContext +from datafusion import functions as f + + +@pytest.fixture +def df(): + ctx = ExecutionContext() + + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) -f = datafusion.functions + return ctx.create_dataframe([[batch]]) -class TestCase(unittest.TestCase): - def _prepare(self): - ctx = datafusion.ExecutionContext() +def test_select(df): + df = df.select( + f.col("a") + f.col("b"), + f.col("a") - f.col("b"), + ) - # create a RecordBatch and a new DataFrame from it - batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array([4, 5, 6])], - names=["a", "b"], - ) - return ctx.create_dataframe([[batch]]) + # execute and collect the first (and only) batch + result = df.collect()[0] - def test_select(self): - df = self._prepare() + assert result.column(0) == pa.array([5, 7, 9]) + assert result.column(1) == pa.array([-3, -3, -3]) - df = df.select( - f.col("a") + f.col("b"), - f.col("a") - f.col("b"), - ) - # execute and collect the first (and only) batch - result = df.collect()[0] +def test_filter(df): + df = df.select( + f.col("a") + f.col("b"), + f.col("a") - f.col("b"), + ).filter(f.col("a") > f.lit(2)) - self.assertEqual(result.column(0), pa.array([5, 7, 9])) - self.assertEqual(result.column(1), pa.array([-3, -3, -3])) + # execute and collect the first (and only) batch + result = df.collect()[0] - def test_filter(self): - df = self._prepare() + assert result.column(0) == pa.array([9]) + assert result.column(1) == pa.array([-3]) - df = df.select( - f.col("a") + f.col("b"), - f.col("a") - f.col("b"), - ).filter(f.col("a") > f.lit(2)) - # execute and collect the first (and only) batch - result = df.collect()[0] +def test_sort(df): + df = df.sort([f.col("b").sort(ascending=False)]) - self.assertEqual(result.column(0), pa.array([9])) - self.assertEqual(result.column(1), pa.array([-3])) + table = pa.Table.from_batches(df.collect()) + expected = {"a": [3, 2, 1], "b": [6, 5, 4]} - def test_sort(self): - df = self._prepare() - df = df.sort([f.col("b").sort(ascending=False)]) + assert table.to_pydict() == expected - table = pa.Table.from_batches(df.collect()) - expected = {"a": [3, 2, 1], "b": [6, 5, 4]} - self.assertEqual(table.to_pydict(), expected) - def test_limit(self): - df = self._prepare() +def test_limit(df): + df = df.limit(1) - df = df.limit(1) + # execute and collect the first (and only) batch + result = df.collect()[0] - # execute and collect the first (and only) batch - result = df.collect()[0] + assert len(result.column(0)) == 1 + assert len(result.column(1)) == 1 - self.assertEqual(len(result.column(0)), 1) - self.assertEqual(len(result.column(1)), 1) - def test_udf(self): - df = self._prepare() +def test_udf(df): + # is_null is a pa function over arrays + udf = f.udf(lambda x: x.is_null(), [pa.int64()], pa.bool_()) - # is_null is a pa function over arrays - udf = f.udf(lambda x: x.is_null(), [pa.int64()], pa.bool_()) + df = df.select(udf(f.col("a"))) + result = df.collect()[0].column(0) - df = df.select(udf(f.col("a"))) + assert result == pa.array([False, False, False]) - self.assertEqual(df.collect()[0].column(0), pa.array([False, False, False])) - def test_join(self): - ctx = datafusion.ExecutionContext() +def test_join(): + ctx = ExecutionContext() - batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array([4, 5, 6])], - names=["a", "b"], - ) - 
df = ctx.create_dataframe([[batch]]) + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df = ctx.create_dataframe([[batch]]) - batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2]), pa.array([8, 10])], - names=["a", "c"], - ) - df1 = ctx.create_dataframe([[batch]]) + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array([8, 10])], + names=["a", "c"], + ) + df1 = ctx.create_dataframe([[batch]]) - df = df.join(df1, on="a", how="inner") - df = df.sort([f.col("a").sort(ascending=True)]) - table = pa.Table.from_batches(df.collect()) + df = df.join(df1, on="a", how="inner") + df = df.sort([f.col("a").sort(ascending=True)]) + table = pa.Table.from_batches(df.collect()) - expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} - self.assertEqual(table.to_pydict(), expected) + expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} + assert table.to_pydict() == expected diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index 117284973fb7..361526d06970 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -15,286 +15,182 @@ # specific language governing permissions and limitations # under the License. -import unittest -import tempfile -import datetime -import os.path -import shutil +import numpy as np +import pyarrow as pa +import pytest +from datafusion import ExecutionContext -import numpy -import pyarrow -import datafusion +from . import generic as helpers -# used to write parquet files -import pyarrow.parquet -from tests.generic import * +@pytest.fixture +def ctx(): + return ExecutionContext() -class TestCase(unittest.TestCase): - def setUp(self): - # Create a temporary directory - self.test_dir = tempfile.mkdtemp() - numpy.random.seed(1) +def test_no_table(ctx): + with pytest.raises(Exception, match="DataFusion error"): + ctx.sql("SELECT a FROM b").collect() - def tearDown(self): - # Remove the directory after the test - shutil.rmtree(self.test_dir) - def test_no_table(self): - with self.assertRaises(Exception): - datafusion.Context().sql("SELECT a FROM b").collect() +def test_register(ctx, tmp_path): + path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) + ctx.register_parquet("t", path) - def test_register(self): - ctx = datafusion.ExecutionContext() + assert ctx.tables() == {"t"} - path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data()) - ctx.register_parquet("t", path) +def test_execute(ctx, tmp_path): + data = [1, 1, 2, 2, 3, 11, 12] - self.assertEqual(ctx.tables(), {"t"}) + # single column, "a" + path = helpers.write_parquet(tmp_path / "a.parquet", pa.array(data)) + ctx.register_parquet("t", path) - def test_execute(self): - data = [1, 1, 2, 2, 3, 11, 12] + assert ctx.tables() == {"t"} - ctx = datafusion.ExecutionContext() + # count + result = ctx.sql("SELECT COUNT(a) FROM t").collect() - # single column, "a" - path = write_parquet( - os.path.join(self.test_dir, "a.parquet"), pyarrow.array(data) - ) - ctx.register_parquet("t", path) + expected = pa.array([7], pa.uint64()) + expected = [pa.RecordBatch.from_arrays([expected], ["COUNT(a)"])] + assert result == expected - self.assertEqual(ctx.tables(), {"t"}) + # where + expected = pa.array([2], pa.uint64()) + expected = [pa.RecordBatch.from_arrays([expected], ["COUNT(a)"])] + result = ctx.sql("SELECT COUNT(a) FROM t WHERE a > 10").collect() + assert result == expected - # count - result = ctx.sql("SELECT COUNT(a) FROM t").collect() + # group by + results = ctx.sql( + "SELECT CAST(a as int), COUNT(a) FROM t GROUP BY 
CAST(a as int)" + ).collect() - expected = pyarrow.array([7], pyarrow.uint64()) - expected = [pyarrow.RecordBatch.from_arrays([expected], ["COUNT(a)"])] - self.assertEqual(expected, result) + # group by returns batches + result_keys = [] + result_values = [] + for result in results: + pydict = result.to_pydict() + result_keys.extend(pydict["CAST(a AS Int32)"]) + result_values.extend(pydict["COUNT(a)"]) - # where - expected = pyarrow.array([2], pyarrow.uint64()) - expected = [pyarrow.RecordBatch.from_arrays([expected], ["COUNT(a)"])] - self.assertEqual( - expected, ctx.sql("SELECT COUNT(a) FROM t WHERE a > 10").collect() - ) + result_keys, result_values = ( + list(t) for t in zip(*sorted(zip(result_keys, result_values))) + ) - # group by - results = ctx.sql( - "SELECT CAST(a as int), COUNT(a) FROM t GROUP BY CAST(a as int)" - ).collect() - - # group by returns batches - result_keys = [] - result_values = [] - for result in results: - pydict = result.to_pydict() - result_keys.extend(pydict["CAST(a AS Int32)"]) - result_values.extend(pydict["COUNT(a)"]) - - result_keys, result_values = ( - list(t) for t in zip(*sorted(zip(result_keys, result_values))) - ) + assert result_keys == [1, 2, 3, 11, 12] + assert result_values == [2, 2, 1, 1, 1] - self.assertEqual(result_keys, [1, 2, 3, 11, 12]) - self.assertEqual(result_values, [2, 2, 1, 1, 1]) - - # order by - result = ctx.sql( - "SELECT a, CAST(a AS int) FROM t ORDER BY a DESC LIMIT 2" - ).collect() - expected_a = pyarrow.array([50.0219, 50.0152], pyarrow.float64()) - expected_cast = pyarrow.array([50, 50], pyarrow.int32()) - expected = [ - pyarrow.RecordBatch.from_arrays( - [expected_a, expected_cast], ["a", "CAST(a AS Int32)"] - ) - ] - numpy.testing.assert_equal(expected[0].column(1), expected[0].column(1)) - - def test_cast(self): - """ - Verify that we can cast - """ - ctx = datafusion.ExecutionContext() - - path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data()) - ctx.register_parquet("t", path) - - valid_types = [ - "smallint", - "int", - "bigint", - "float(32)", - "float(64)", - "float", - ] - - select = ", ".join( - [f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)] + # order by + result = ctx.sql( + "SELECT a, CAST(a AS int) FROM t ORDER BY a DESC LIMIT 2" + ).collect() + expected_a = pa.array([50.0219, 50.0152], pa.float64()) + expected_cast = pa.array([50, 50], pa.int32()) + expected = [ + pa.RecordBatch.from_arrays( + [expected_a, expected_cast], ["a", "CAST(a AS Int32)"] ) - - # can execute, which implies that we can cast - ctx.sql(f"SELECT {select} FROM t").collect() - - def _test_udf(self, udf, args, return_type, array, expected): - ctx = datafusion.ExecutionContext() - - # write to disk - path = write_parquet(os.path.join(self.test_dir, "a.parquet"), array) - ctx.register_parquet("t", path) - - ctx.register_udf("udf", udf, args, return_type) - - batches = ctx.sql("SELECT udf(a) AS tt FROM t").collect() - - result = batches[0].column(0) - - self.assertEqual(expected, result) - - def test_udf_identity(self): - self._test_udf( + ] + np.testing.assert_equal(expected[0].column(1), expected[0].column(1)) + + +def test_cast(ctx, tmp_path): + """ + Verify that we can cast + """ + path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) + ctx.register_parquet("t", path) + + valid_types = [ + "smallint", + "int", + "bigint", + "float(32)", + "float(64)", + "float", + ] + + select = ", ".join( + [f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)] + ) + + # can execute, which implies that we 
can cast + ctx.sql(f"SELECT {select} FROM t").collect() + + +@pytest.mark.parametrize( + ("fn", "input_types", "output_type", "input_values", "expected_values"), + [ + ( lambda x: x, - [pyarrow.float64()], - pyarrow.float64(), - pyarrow.array([-1.2, None, 1.2]), - pyarrow.array([-1.2, None, 1.2]), - ) - - def test_udf(self): - self._test_udf( + [pa.float64()], + pa.float64(), + [-1.2, None, 1.2], + [-1.2, None, 1.2], + ), + ( lambda x: x.is_null(), - [pyarrow.float64()], - pyarrow.bool_(), - pyarrow.array([-1.2, None, 1.2]), - pyarrow.array([False, True, False]), - ) - - -class TestIO(unittest.TestCase): - def setUp(self): - # Create a temporary directory - self.test_dir = tempfile.mkdtemp() - - def tearDown(self): - # Remove the directory after the test - shutil.rmtree(self.test_dir) - - def _test_data(self, data): - ctx = datafusion.ExecutionContext() - - # write to disk - path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data) - ctx.register_parquet("t", path) - - batches = ctx.sql("SELECT a AS tt FROM t").collect() - - result = batches[0].column(0) - - numpy.testing.assert_equal(data, result) - - def test_nans(self): - self._test_data(data_with_nans()) - - def test_utf8(self): - array = pyarrow.array( - ["a", "b", "c"], pyarrow.utf8(), numpy.array([False, True, False]) - ) - self._test_data(array) - - def test_large_utf8(self): - array = pyarrow.array( - ["a", "b", "c"], pyarrow.large_utf8(), numpy.array([False, True, False]) - ) - self._test_data(array) - - # Error from Arrow - @unittest.expectedFailure - def test_datetime_s(self): - self._test_data(data_datetime("s")) - - # C data interface missing - @unittest.expectedFailure - def test_datetime_ms(self): - self._test_data(data_datetime("ms")) - - # C data interface missing - @unittest.expectedFailure - def test_datetime_us(self): - self._test_data(data_datetime("us")) - - # Not writtable to parquet - @unittest.expectedFailure - def test_datetime_ns(self): - self._test_data(data_datetime("ns")) - - # Not writtable to parquet - @unittest.expectedFailure - def test_timedelta_s(self): - self._test_data(data_timedelta("s")) - - # Not writtable to parquet - @unittest.expectedFailure - def test_timedelta_ms(self): - self._test_data(data_timedelta("ms")) - - # Not writtable to parquet - @unittest.expectedFailure - def test_timedelta_us(self): - self._test_data(data_timedelta("us")) - - # Not writtable to parquet - @unittest.expectedFailure - def test_timedelta_ns(self): - self._test_data(data_timedelta("ns")) - - def test_date32(self): - array = pyarrow.array( - [ - datetime.date(2000, 1, 1), - datetime.date(1980, 1, 1), - datetime.date(2030, 1, 1), - ], - pyarrow.date32(), - numpy.array([False, True, False]), - ) - self._test_data(array) - - def test_binary_variable(self): - array = pyarrow.array( - [b"1", b"2", b"3"], pyarrow.binary(), numpy.array([False, True, False]) - ) - self._test_data(array) - - # C data interface missing - @unittest.expectedFailure - def test_binary_fixed(self): - array = pyarrow.array( - [b"1111", b"2222", b"3333"], - pyarrow.binary(4), - numpy.array([False, True, False]), - ) - self._test_data(array) - - def test_large_binary(self): - array = pyarrow.array( - [b"1111", b"2222", b"3333"], - pyarrow.large_binary(), - numpy.array([False, True, False]), - ) - self._test_data(array) - - def test_binary_other(self): - self._test_data(data_binary_other()) - - def test_bool(self): - array = pyarrow.array( - [False, True, True], None, numpy.array([False, True, False]) - ) - self._test_data(array) - - def 
test_u32(self): - array = pyarrow.array([0, 1, 2], None, numpy.array([False, True, False])) - self._test_data(array) + [pa.float64()], + pa.bool_(), + [-1.2, None, 1.2], + [False, True, False], + ), + ], +) +def test_udf( + ctx, tmp_path, fn, input_types, output_type, input_values, expected_values +): + # write to disk + path = helpers.write_parquet( + tmp_path / "a.parquet", pa.array(input_values) + ) + ctx.register_parquet("t", path) + ctx.register_udf("udf", fn, input_types, output_type) + + batches = ctx.sql("SELECT udf(a) AS tt FROM t").collect() + result = batches[0].column(0) + + assert result == pa.array(expected_values) + + +_null_mask = np.array([False, True, False]) + + +@pytest.mark.parametrize( + "arr", + [ + pa.array(["a", "b", "c"], pa.utf8(), _null_mask), + pa.array(["a", "b", "c"], pa.large_utf8(), _null_mask), + pa.array([b"1", b"2", b"3"], pa.binary(), _null_mask), + pa.array([b"1111", b"2222", b"3333"], pa.large_binary(), _null_mask), + pa.array([False, True, True], None, _null_mask), + pa.array([0, 1, 2], None), + helpers.data_binary_other(), + helpers.data_date32(), + helpers.data_with_nans(), + # C data interface missing + pytest.param( + pa.array([b"1111", b"2222", b"3333"], pa.binary(4), _null_mask), + marks=pytest.mark.xfail, + ), + pytest.param(helpers.data_datetime("s"), marks=pytest.mark.xfail), + pytest.param(helpers.data_datetime("ms"), marks=pytest.mark.xfail), + pytest.param(helpers.data_datetime("us"), marks=pytest.mark.xfail), + pytest.param(helpers.data_datetime("ns"), marks=pytest.mark.xfail), + # Not writtable to parquet + pytest.param(helpers.data_timedelta("s"), marks=pytest.mark.xfail), + pytest.param(helpers.data_timedelta("ms"), marks=pytest.mark.xfail), + pytest.param(helpers.data_timedelta("us"), marks=pytest.mark.xfail), + pytest.param(helpers.data_timedelta("ns"), marks=pytest.mark.xfail), + ], +) +def test_simple_select(ctx, tmp_path, arr): + path = helpers.write_parquet(tmp_path / "a.parquet", arr) + ctx.register_parquet("t", path) + + batches = ctx.sql("SELECT a AS tt FROM t").collect() + result = batches[0].column(0) + + np.testing.assert_equal(result, arr) diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index e1e4f933a9b4..b24c08dbc867 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -15,12 +15,11 @@ # specific language governing permissions and limitations # under the License. -import unittest -import pyarrow -import pyarrow.compute -import datafusion - -f = datafusion.functions +import pyarrow as pa +import pyarrow.compute as pc +import pytest +from datafusion import ExecutionContext +from datafusion import functions as f class Accumulator: @@ -29,63 +28,54 @@ class Accumulator: """ def __init__(self): - self._sum = pyarrow.scalar(0.0) + self._sum = pa.scalar(0.0) - def to_scalars(self) -> [pyarrow.Scalar]: + def to_scalars(self) -> [pa.Scalar]: return [self._sum] - def update(self, values: pyarrow.Array) -> None: - # not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pyarrow.scalar( - self._sum.as_py() + pyarrow.compute.sum(values).as_py() - ) + def update(self, values: pa.Array) -> None: + # Not nice since pyarrow scalars can't be summed yet. + # This breaks on `None` + self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) - def merge(self, states: pyarrow.Array) -> None: - # not nice since pyarrow scalars can't be summed yet. 
This breaks on `None` - self._sum = pyarrow.scalar( - self._sum.as_py() + pyarrow.compute.sum(states).as_py() - ) + def merge(self, states: pa.Array) -> None: + # Not nice since pyarrow scalars can't be summed yet. + # This breaks on `None` + self._sum = pa.scalar(self._sum.as_py() + pc.sum(states).as_py()) - def evaluate(self) -> pyarrow.Scalar: + def evaluate(self) -> pa.Scalar: return self._sum -class TestCase(unittest.TestCase): - def _prepare(self): - ctx = datafusion.ExecutionContext() +@pytest.fixture +def df(): + ctx = ExecutionContext() - # create a RecordBatch and a new DataFrame from it - batch = pyarrow.RecordBatch.from_arrays( - [pyarrow.array([1, 2, 3]), pyarrow.array([4, 4, 6])], - names=["a", "b"], - ) - return ctx.create_dataframe([[batch]]) + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 4, 6])], + names=["a", "b"], + ) + return ctx.create_dataframe([[batch]]) - def test_aggregate(self): - df = self._prepare() - udaf = f.udaf( - Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()] - ) +def test_aggregate(df): + udaf = f.udaf(Accumulator, pa.float64(), pa.float64(), [pa.float64()]) - df = df.aggregate([], [udaf(f.col("a"))]) + df = df.aggregate([], [udaf(f.col("a"))]) - # execute and collect the first (and only) batch - result = df.collect()[0] + # execute and collect the first (and only) batch + result = df.collect()[0] - self.assertEqual(result.column(0), pyarrow.array([1.0 + 2.0 + 3.0])) + assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) - def test_group_by(self): - df = self._prepare() - udaf = f.udaf( - Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()] - ) +def test_group_by(df): + udaf = f.udaf(Accumulator, pa.float64(), pa.float64(), [pa.float64()]) - df = df.aggregate([f.col("b")], [udaf(f.col("a"))]) + df = df.aggregate([f.col("b")], [udaf(f.col("a"))]) - # execute and collect the first (and only) batch - batches = df.collect() - arrays = [batch.column(1) for batch in batches] - joined = pyarrow.concat_arrays(arrays) - self.assertEqual(joined, pyarrow.array([1.0 + 2.0, 3.0])) + batches = df.collect() + arrays = [batch.column(1) for batch in batches] + joined = pa.concat_arrays(arrays) + assert joined == pa.array([1.0 + 2.0, 3.0]) From d5bca0e350d94a1e1063bed8a0da0cb09c6e3e1c Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 10 Jun 2021 02:26:01 +0800 Subject: [PATCH 11/25] Add `partition by` constructs in window functions and modify logical planning (#501) * closing up type checks * add fmt --- ballista/rust/core/proto/ballista.proto | 2 +- .../core/src/serde/logical_plan/from_proto.rs | 8 + .../core/src/serde/logical_plan/to_proto.rs | 6 + .../src/serde/physical_plan/from_proto.rs | 8 + datafusion/src/logical_plan/expr.rs | 14 +- datafusion/src/logical_plan/plan.rs | 6 +- datafusion/src/optimizer/utils.rs | 46 +++- datafusion/src/sql/planner.rs | 217 +++++++++++++----- datafusion/src/sql/utils.rs | 57 ++++- 9 files changed, 280 insertions(+), 84 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 38d87e934e5f..85af9023fb46 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -174,7 +174,7 @@ message WindowExprNode { // udaf = 3 } LogicalExprNode expr = 4; - // repeated LogicalExprNode partition_by = 5; + repeated LogicalExprNode partition_by = 5; repeated LogicalExprNode order_by = 6; // repeated LogicalExprNode filter = 7; oneof 
window_frame { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 36a37a1e472c..86daeb063c47 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -910,6 +910,12 @@ impl TryInto for &protobuf::LogicalExprNode { .window_function .as_ref() .ok_or_else(|| proto_error("Received empty window function"))?; + let partition_by = expr + .partition_by + .iter() + .map(|e| e.try_into()) + .into_iter() + .collect::, _>>()?; let order_by = expr .order_by .iter() @@ -940,6 +946,7 @@ impl TryInto for &protobuf::LogicalExprNode { AggregateFunction::from(aggr_function), ), args: vec![parse_required_expr(&expr.expr)?], + partition_by, order_by, window_frame, }) @@ -960,6 +967,7 @@ impl TryInto for &protobuf::LogicalExprNode { BuiltInWindowFunction::from(built_in_function), ), args: vec![parse_required_expr(&expr.expr)?], + partition_by, order_by, window_frame, }) diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index fb1383daab3a..5d996843d624 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1006,6 +1006,7 @@ impl TryInto for &Expr { Expr::WindowFunction { ref fun, ref args, + ref partition_by, ref order_by, ref window_frame, .. @@ -1023,6 +1024,10 @@ impl TryInto for &Expr { } }; let arg = &args[0]; + let partition_by = partition_by + .iter() + .map(|e| e.try_into()) + .collect::, _>>()?; let order_by = order_by .iter() .map(|e| e.try_into()) @@ -1035,6 +1040,7 @@ impl TryInto for &Expr { let window_expr = Box::new(protobuf::WindowExprNode { expr: Some(Box::new(arg.try_into()?)), window_function: Some(window_function), + partition_by, order_by, window_frame, }); diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 5fcc971527c6..b319d5b25f12 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -236,7 +236,9 @@ impl TryInto> for &protobuf::PhysicalPlanNode { Expr::WindowFunction { fun, args, + partition_by, order_by, + window_frame, .. } => { let arg = df_planner @@ -248,9 +250,15 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .map_err(|e| { BallistaError::General(format!("{:?}", e)) })?; + if !partition_by.is_empty() { + return Err(BallistaError::NotImplemented("Window function with partition by is not yet implemented".to_owned())); + } if !order_by.is_empty() { return Err(BallistaError::NotImplemented("Window function with order by is not yet implemented".to_owned())); } + if window_frame.is_some() { + return Err(BallistaError::NotImplemented("Window function with window frame is not yet implemented".to_owned())); + } let window_expr = create_window_expr( &fun, &[arg], diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index d5c92dbd2143..58dba16f02ef 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -194,6 +194,8 @@ pub enum Expr { fun: window_functions::WindowFunction, /// List of expressions to feed to the functions as arguments args: Vec, + /// List of partition by expressions + partition_by: Vec, /// List of order by expressions order_by: Vec, /// Window frame @@ -588,10 +590,18 @@ impl Expr { Expr::ScalarUDF { args, .. 
} => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::WindowFunction { args, order_by, .. } => { + Expr::WindowFunction { + args, + partition_by, + order_by, + .. + } => { let visitor = args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + let visitor = partition_by + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; let visitor = order_by .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; @@ -733,11 +743,13 @@ impl Expr { Expr::WindowFunction { args, fun, + partition_by, order_by, window_frame, } => Expr::WindowFunction { args: rewrite_vec(args, rewriter)?, fun, + partition_by: rewrite_vec(partition_by, rewriter)?, order_by: rewrite_vec(order_by, rewriter)?, window_frame, }, diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 25cf9e33d2ca..3344dce1d81d 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -687,11 +687,7 @@ impl LogicalPlan { LogicalPlan::Window { ref window_expr, .. } => { - write!( - f, - "WindowAggr: windowExpr=[{:?}] partitionBy=[]", - window_expr - ) + write!(f, "WindowAggr: windowExpr=[{:?}]", window_expr) } LogicalPlan::Aggregate { ref group_expr, diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 65c95bee20d4..e707d30bc9ac 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -36,6 +36,7 @@ use crate::{ const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__"; const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__"; +const WINDOW_PARTITION_MARKER: &str = "__DATAFUSION_WINDOW_PARTITION__"; const WINDOW_SORT_MARKER: &str = "__DATAFUSION_WINDOW_SORT__"; /// Recursively walk a list of expression trees, collecting the unique set of column @@ -258,9 +259,16 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::IsNotNull(e) => Ok(vec![e.as_ref().to_owned()]), Expr::ScalarFunction { args, .. } => Ok(args.clone()), Expr::ScalarUDF { args, .. } => Ok(args.clone()), - Expr::WindowFunction { args, order_by, .. } => { + Expr::WindowFunction { + args, + partition_by, + order_by, + .. + } => { let mut expr_list: Vec = vec![]; expr_list.extend(args.clone()); + expr_list.push(lit(WINDOW_PARTITION_MARKER)); + expr_list.extend(partition_by.clone()); expr_list.push(lit(WINDOW_SORT_MARKER)); expr_list.extend(order_by.clone()); Ok(expr_list) @@ -340,7 +348,20 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { Expr::WindowFunction { fun, window_frame, .. 
} => { - let index = expressions + let partition_index = expressions + .iter() + .position(|expr| { + matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str))) + if str == WINDOW_PARTITION_MARKER) + }) + .ok_or_else(|| { + DataFusionError::Internal( + "Ill-formed window function expressions: unexpected marker" + .to_owned(), + ) + })?; + + let sort_index = expressions .iter() .position(|expr| { matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str))) @@ -351,12 +372,21 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { "Ill-formed window function expressions".to_owned(), ) })?; - Ok(Expr::WindowFunction { - fun: fun.clone(), - args: expressions[..index].to_vec(), - order_by: expressions[index + 1..].to_vec(), - window_frame: *window_frame, - }) + + if partition_index >= sort_index { + Err(DataFusionError::Internal( + "Ill-formed window function expressions: partition index too large" + .to_owned(), + )) + } else { + Ok(Expr::WindowFunction { + fun: fun.clone(), + args: expressions[..partition_index].to_vec(), + partition_by: expressions[partition_index + 1..sort_index].to_vec(), + order_by: expressions[sort_index + 1..].to_vec(), + window_frame: *window_frame, + }) + } } Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction { fun: fun.clone(), diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 7df0068c5f54..53f22ecaf3f2 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1122,52 +1122,53 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // then, window function if let Some(window) = &function.over { - if window.partition_by.is_empty() { - let order_by = window - .order_by - .iter() - .map(|e| self.order_by_to_sort_expr(e)) - .into_iter() - .collect::>>()?; - let window_frame = window - .window_frame - .as_ref() - .map(|window_frame| window_frame.clone().try_into()) - .transpose()?; - let fun = window_functions::WindowFunction::from_str(&name); - if let Ok(window_functions::WindowFunction::AggregateFunction( + let partition_by = window + .partition_by + .iter() + .map(|e| self.sql_expr_to_logical_expr(e)) + .into_iter() + .collect::>>()?; + let order_by = window + .order_by + .iter() + .map(|e| self.order_by_to_sort_expr(e)) + .into_iter() + .collect::>>()?; + let window_frame = window + .window_frame + .as_ref() + .map(|window_frame| window_frame.clone().try_into()) + .transpose()?; + let fun = window_functions::WindowFunction::from_str(&name)?; + match fun { + window_functions::WindowFunction::AggregateFunction( aggregate_fun, - )) = fun - { + ) => { return Ok(Expr::WindowFunction { fun: window_functions::WindowFunction::AggregateFunction( aggregate_fun.clone(), ), args: self .aggregate_fn_to_expr(&aggregate_fun, function)?, + partition_by, order_by, window_frame, }); - } else if let Ok( - window_functions::WindowFunction::BuiltInWindowFunction( - window_fun, - ), - ) = fun - { + } + window_functions::WindowFunction::BuiltInWindowFunction( + window_fun, + ) => { return Ok(Expr::WindowFunction { fun: window_functions::WindowFunction::BuiltInWindowFunction( window_fun, ), args: self.function_args_to_expr(function)?, + partition_by, order_by, window_frame, }); } } - return Err(DataFusionError::NotImplemented(format!( - "Unsupported OVER clause ({})", - window - ))); } // next, aggregate built-ins @@ -2775,7 +2776,7 @@ mod tests { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; let expected = "\ Projection: #order_id, #MAX(order_id)\ - \n WindowAggr: 
windowExpr=[[MAX(#order_id)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#order_id)]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2785,7 +2786,7 @@ mod tests { let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid from orders"; let expected = "\ Projection: #order_id AS oid, #MAX(order_id) AS max_oid\ - \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#order_id)]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2795,7 +2796,7 @@ mod tests { let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders"; let expected = "\ Projection: #order_id, #MAX(qty Multiply Float64(1.1))\ - \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2806,20 +2807,29 @@ mod tests { "SELECT order_id, MAX(qty) OVER (), min(qty) over (), aVg(qty) OVER () from orders"; let expected = "\ Projection: #order_id, #MAX(qty), #MIN(qty), #AVG(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------- + /// WindowAgg (cost=69.83..87.33 rows=1000 width=8) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: order_id + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` #[test] - fn over_partition_by_not_supported() { - let sql = - "SELECT order_id, MAX(delivered) OVER (PARTITION BY order_id) from orders"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "NotImplemented(\"Unsupported OVER clause (PARTITION BY order_id)\")", - format!("{:?}", err) - ); + fn over_partition_by() { + let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); } /// psql result @@ -2839,9 +2849,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ Projection: #order_id, #MAX(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2852,9 +2862,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; 
quick_test(sql, expected); @@ -2865,9 +2875,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2878,9 +2888,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ Projection: #order_id, #MAX(qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2903,9 +2913,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY (order_id + 1)) from orders"; let expected = "\ Projection: #order_id, #MAX(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id Plus Int64(1) ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2929,10 +2939,10 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; let expected = "\ Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\ - \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[SUM(#qty)]]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2956,10 +2966,10 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; let expected = "\ Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\ - \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[SUM(#qty)]]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2987,15 +2997,108 @@ mod tests { let expected = "\ Sort: #order_id ASC NULLS FIRST\ \n Projection: #order_id, 
#MAX(qty), #SUM(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\ - \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[SUM(#qty)]]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------- + /// WindowAgg (cost=69.83..89.83 rows=1000 width=12) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: order_id, qty + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` + #[test] + fn over_partition_by_order_by() { + let sql = + "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------- + /// WindowAgg (cost=69.83..89.83 rows=1000 width=12) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: order_id, qty + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` + #[test] + fn over_partition_by_order_by_no_dup() { + let sql = + "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------------------- + /// WindowAgg (cost=142.16..162.16 rows=1000 width=16) + /// -> Sort (cost=142.16..144.66 rows=1000 width=12) + /// Sort Key: qty, order_id + /// -> WindowAgg (cost=69.83..92.33 rows=1000 width=12) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: order_id, qty + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` + #[test] + fn over_partition_by_order_by_mix_up() { + let sql = + "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty), MIN(qty) OVER (PARTITION BY qty ORDER BY order_id) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ + \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + /// psql result + /// ``` + /// QUERY PLAN + /// ----------------------------------------------------------------------------- + /// WindowAgg (cost=69.83..109.83 rows=1000 width=24) + /// -> WindowAgg (cost=69.83..92.33 rows=1000 width=20) + /// -> Sort (cost=69.83..72.33 rows=1000 width=16) + /// Sort Key: order_id, qty, price + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=16) + /// ``` + /// FIXME: for now we are not detecting prefix of sorting keys in order to save one sort exec phase + #[test] + fn 
over_partition_by_order_by_mix_up_prefix() { + let sql = + "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty), MIN(qty) OVER (PARTITION BY order_id, qty ORDER BY price) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST, #price ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + #[test] fn only_union_all_supported() { let sql = "SELECT order_id from orders EXCEPT SELECT order_id FROM orders"; diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 848fb3ee31fc..5e9b9526ea83 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -239,6 +239,7 @@ where Expr::WindowFunction { fun, args, + partition_by, order_by, window_frame, } => Ok(Expr::WindowFunction { @@ -247,6 +248,10 @@ where .iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, + partition_by: partition_by + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, order_by: order_by .iter() .map(|e| clone_with_replacement(e, replacement_fn)) @@ -432,19 +437,38 @@ pub(crate) fn resolve_aliases_to_exprs( }) } +type WindowSortKey = Vec; + +fn generate_sort_key(partition_by: &[Expr], order_by: &[Expr]) -> WindowSortKey { + let mut sort_key = vec![]; + partition_by.iter().for_each(|e| { + let e = e.clone().sort(true, true); + if !sort_key.contains(&e) { + sort_key.push(e); + } + }); + order_by.iter().for_each(|e| { + if !sort_key.contains(&e) { + sort_key.push(e.clone()); + } + }); + sort_key +} + /// group a slice of window expression expr by their order by expressions pub(crate) fn group_window_expr_by_sort_keys( window_expr: &[Expr], -) -> Result)>> { +) -> Result)>> { let mut result = vec![]; window_expr.iter().try_for_each(|expr| match expr { - Expr::WindowFunction { order_by, .. } => { + Expr::WindowFunction { partition_by, order_by, .. 
} => { + let sort_key = generate_sort_key(partition_by, order_by); if let Some((_, values)) = result.iter_mut().find( - |group: &&mut (&[Expr], Vec<&Expr>)| matches!(group, (key, _) if key == order_by), + |group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, (key, _) if *key == sort_key), ) { values.push(expr); } else { - result.push((order_by, vec![expr])) + result.push((sort_key, vec![expr])) } Ok(()) } @@ -466,7 +490,7 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { let result = group_window_expr_by_sort_keys(&[])?; - let expected: Vec<(&[Expr], Vec<&Expr>)> = vec![]; + let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![]; assert_eq!(expected, result); Ok(()) } @@ -476,32 +500,35 @@ mod tests { let max1 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], + partition_by: vec![], order_by: vec![], window_frame: None, }; let max2 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], + partition_by: vec![], order_by: vec![], window_frame: None, }; let min3 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Min), args: vec![col("name")], + partition_by: vec![], order_by: vec![], window_frame: None, }; let sum4 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], + partition_by: vec![], order_by: vec![], window_frame: None, }; - // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs)?; - let key = &[]; - let expected: Vec<(&[Expr], Vec<&Expr>)> = + let key = vec![]; + let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![(key, vec![&max1, &max2, &min3, &sum4])]; assert_eq!(expected, result); Ok(()) @@ -527,24 +554,28 @@ mod tests { let max1 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], + partition_by: vec![], order_by: vec![age_asc.clone(), name_desc.clone()], window_frame: None, }; let max2 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], + partition_by: vec![], order_by: vec![], window_frame: None, }; let min3 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Min), args: vec![col("name")], + partition_by: vec![], order_by: vec![age_asc.clone(), name_desc.clone()], window_frame: None, }; let sum4 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], + partition_by: vec![], order_by: vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], window_frame: None, }; @@ -552,11 +583,11 @@ mod tests { let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs)?; - let key1 = &[age_asc.clone(), name_desc.clone()]; - let key2 = &[]; - let key3 = &[name_desc, age_asc, created_at_desc]; + let key1 = vec![age_asc.clone(), name_desc.clone()]; + let key2 = vec![]; + let key3 = vec![name_desc, age_asc, created_at_desc]; - let expected: Vec<(&[Expr], Vec<&Expr>)> = vec![ + let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![ (key1, vec![&max1, &min3]), (key2, vec![&max2]), (key3, vec![&sum4]), @@ -571,6 +602,7 @@ mod tests { Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], + partition_by: vec![], order_by: 
vec![ Expr::Sort { expr: Box::new(col("age")), @@ -588,6 +620,7 @@ mod tests { Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], + partition_by: vec![], order_by: vec![ Expr::Sort { expr: Box::new(col("name")), From 8f84564edab1679163d91691f63381f38907d515 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 10 Jun 2021 09:18:15 -0400 Subject: [PATCH 12/25] Return errors properly from RepartitionExec (#521) --- datafusion/src/physical_plan/repartition.rs | 205 ++++++++++++++++++-- datafusion/src/test/exec.rs | 183 ++++++++++++++++- 2 files changed, 372 insertions(+), 16 deletions(-) diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index e5747dda88b7..37d98c7d118b 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -147,12 +147,13 @@ impl ExecutionPlan for RepartitionExec { let fetch_time = self.fetch_time_nanos.clone(); let repart_time = self.repart_time_nanos.clone(); let send_time = self.send_time_nanos.clone(); - let mut txs: HashMap<_, _> = channels + let txs: HashMap<_, _> = channels .iter() .map(|(partition, (tx, _rx))| (*partition, tx.clone())) .collect(); let partitioning = self.partitioning.clone(); - let _: JoinHandle> = tokio::spawn(async move { + let mut txs_captured = txs.clone(); + let input_task: JoinHandle> = tokio::spawn(async move { // execute the child operator let now = Instant::now(); let mut stream = input.execute(i).await?; @@ -170,13 +171,13 @@ impl ExecutionPlan for RepartitionExec { if result.is_none() { break; } - let result = result.unwrap(); + let result: ArrowResult = result.unwrap(); match &partitioning { Partitioning::RoundRobinBatch(_) => { let now = Instant::now(); let output_partition = counter % num_output_partitions; - let tx = txs.get_mut(&output_partition).unwrap(); + let tx = txs_captured.get_mut(&output_partition).unwrap(); tx.send(Some(result)).map_err(|e| { DataFusionError::Execution(e.to_string()) })?; @@ -230,7 +231,9 @@ impl ExecutionPlan for RepartitionExec { ); repart_time.add(now.elapsed().as_nanos() as usize); let now = Instant::now(); - let tx = txs.get_mut(&num_output_partition).unwrap(); + let tx = txs_captured + .get_mut(&num_output_partition) + .unwrap(); tx.send(Some(output_batch)).map_err(|e| { DataFusionError::Execution(e.to_string()) })?; @@ -249,13 +252,12 @@ impl ExecutionPlan for RepartitionExec { counter += 1; } - // notify each output partition that this input partition has no more data - for (_, tx) in txs { - tx.send(None) - .map_err(|e| DataFusionError::Execution(e.to_string()))?; - } Ok(()) }); + + // In a separate task, wait for each input to be done + // (and pass along any errors) + tokio::spawn(async move { Self::wait_for_task(input_task, txs).await }); } } @@ -308,6 +310,45 @@ impl RepartitionExec { send_time_nanos: SQLMetric::time_nanos(), }) } + + /// Waits for `input_task` which is consuming one of the inputs to + /// complete. Upon each successful completion, sends a `None` to + /// each of the output tx channels to signal one of the inputs is + /// complete. Upon error, propagates the errors to all output tx + /// channels. + async fn wait_for_task( + input_task: JoinHandle>, + txs: HashMap>>>, + ) { + // wait for completion, and propagate error + // note we ignore errors on send (.ok) as that means the receiver has already shutdown. 
+ match input_task.await { + // Error in joining task + Err(e) => { + for (_, tx) in txs { + let err = DataFusionError::Execution(format!("Join Error: {}", e)); + let err = Err(err.into_arrow_external_error()); + tx.send(Some(err)).ok(); + } + } + // Error from running input task + Ok(Err(e)) => { + for (_, tx) in txs { + // wrap it because need to send error to all output partitions + let err = DataFusionError::Execution(e.to_string()); + let err = Err(err.into_arrow_external_error()); + tx.send(Some(err)).ok(); + } + } + // Input task completed successfully + Ok(Ok(())) => { + // notify each output partition that this input partition has no more data + for (_, tx) in txs { + tx.send(None).ok(); + } + } + } + } } struct RepartitionStream { @@ -356,10 +397,17 @@ impl RecordBatchStream for RepartitionStream { #[cfg(test)] mod tests { use super::*; - use crate::physical_plan::memory::MemoryExec; - use arrow::array::UInt32Array; + use crate::{ + assert_batches_sorted_eq, + physical_plan::memory::MemoryExec, + test::exec::{ErrorExec, MockExec}, + }; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use arrow::{ + array::{ArrayRef, StringArray, UInt32Array}, + error::ArrowError, + }; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -517,4 +565,137 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn unsupported_partitioning() { + // have to send at least one batch through to provoke error + let batch = RecordBatch::try_from_iter(vec![( + "my_awesome_field", + Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef, + )]) + .unwrap(); + + let schema = batch.schema(); + let input = MockExec::new(vec![Ok(batch)], schema); + // This generates an error (partitioning type not supported) + // but only after the plan is executed. The error should be + // returned and no results produced + let partitioning = Partitioning::UnknownPartitioning(1); + let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); + let output_stream = exec.execute(0).await.unwrap(); + + // Expect that an error is returned + let result_string = crate::physical_plan::common::collect(output_stream) + .await + .unwrap_err() + .to_string(); + assert!( + result_string + .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"), + "actual: {}", + result_string + ); + } + + #[tokio::test] + async fn error_for_input_exec() { + // This generates an error on a call to execute. The error + // should be returned and no results produced. + + let input = ErrorExec::new(); + let partitioning = Partitioning::RoundRobinBatch(1); + let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); + + // Note: this should pass (the stream can be created) but the + // error when the input is executed should get passed back + let output_stream = exec.execute(0).await.unwrap(); + + // Expect that an error is returned + let result_string = crate::physical_plan::common::collect(output_stream) + .await + .unwrap_err() + .to_string(); + assert!( + result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"), + "actual: {}", + result_string + ); + } + + #[tokio::test] + async fn repartition_with_error_in_stream() { + let batch = RecordBatch::try_from_iter(vec![( + "my_awesome_field", + Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef, + )]) + .unwrap(); + + // input stream returns one good batch and then one error. The + // error should be returned. 
+ let err = Err(ArrowError::ComputeError("bad data error".to_string())); + + let schema = batch.schema(); + let input = MockExec::new(vec![Ok(batch), err], schema); + let partitioning = Partitioning::RoundRobinBatch(1); + let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); + + // Note: this should pass (the stream can be created) but the + // error when the input is executed should get passed back + let output_stream = exec.execute(0).await.unwrap(); + + // Expect that an error is returned + let result_string = crate::physical_plan::common::collect(output_stream) + .await + .unwrap_err() + .to_string(); + assert!( + result_string.contains("bad data error"), + "actual: {}", + result_string + ); + } + + #[tokio::test] + async fn repartition_with_delayed_stream() { + let batch1 = RecordBatch::try_from_iter(vec![( + "my_awesome_field", + Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef, + )]) + .unwrap(); + + let batch2 = RecordBatch::try_from_iter(vec![( + "my_awesome_field", + Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef, + )]) + .unwrap(); + + // The mock exec doesn't return immediately (instead it + // requires the input to wait at least once) + let schema = batch1.schema(); + let expected_batches = vec![batch1.clone(), batch2.clone()]; + let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema); + let partitioning = Partitioning::RoundRobinBatch(1); + + let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); + + let expected = vec![ + "+------------------+", + "| my_awesome_field |", + "+------------------+", + "| foo |", + "| bar |", + "| frob |", + "| baz |", + "+------------------+", + ]; + + assert_batches_sorted_eq!(&expected, &expected_batches); + + let output_stream = exec.execute(0).await.unwrap(); + let batches = crate::physical_plan::common::collect(output_stream) + .await + .unwrap(); + + assert_batches_sorted_eq!(&expected, &batches); + } } diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs index 04cd29530c01..bcd94dd6d639 100644 --- a/datafusion/src/test/exec.rs +++ b/datafusion/src/test/exec.rs @@ -17,14 +17,25 @@ //! Simple iterator over batches for use in testing -use std::task::{Context, Poll}; +use async_trait::async_trait; +use std::{ + any::Any, + sync::Arc, + task::{Context, Poll}, +}; use arrow::{ - datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch, + datatypes::{DataType, Field, Schema, SchemaRef}, + error::{ArrowError, Result as ArrowResult}, + record_batch::RecordBatch, }; -use futures::Stream; +use futures::{Stream, StreamExt}; +use tokio_stream::wrappers::ReceiverStream; -use crate::physical_plan::RecordBatchStream; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{ + ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, +}; /// Index into the data that has been returned so far #[derive(Debug, Default, Clone)] @@ -100,3 +111,167 @@ impl RecordBatchStream for TestStream { self.data[0].schema() } } + +/// A Mock ExecutionPlan that can be used for writing tests of other ExecutionPlans +/// +#[derive(Debug)] +pub struct MockExec { + /// the results to send back + data: Vec>, + schema: SchemaRef, +} + +impl MockExec { + /// Create a new exec with a single partition that returns the + /// record batches in this Exec. Note the batches are not produced + /// immediately (the caller has to actually yield and another task + /// must run) to ensure any poll loops are correct. 
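+    /// For illustration only, a usage sketch mirroring how the repartition
+    /// tests in this patch drive the exec (the `batch` value below is assumed
+    /// to be any pre-built `RecordBatch`):
+    ///
+    /// ```ignore
+    /// // one good batch followed by an error; both are replayed, in order,
+    /// // only after the consumer has yielded at least once
+    /// let schema = batch.schema();
+    /// let err = Err(ArrowError::ComputeError("bad data error".to_string()));
+    /// let exec = MockExec::new(vec![Ok(batch), err], schema);
+    /// ```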
+ pub fn new(data: Vec>, schema: SchemaRef) -> Self { + Self { data, schema } + } +} + +#[async_trait] +impl ExecutionPlan for MockExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(1) + } + + fn children(&self) -> Vec> { + unimplemented!() + } + + fn with_new_children( + &self, + _children: Vec>, + ) -> Result> { + unimplemented!() + } + + /// Returns a stream which yields data + async fn execute(&self, partition: usize) -> Result { + assert_eq!(partition, 0); + + let schema = self.schema(); + + // Result doesn't implement clone, so do it ourself + let data: Vec<_> = self + .data + .iter() + .map(|r| match r { + Ok(batch) => Ok(batch.clone()), + Err(e) => Err(clone_error(e)), + }) + .collect(); + + let (tx, rx) = tokio::sync::mpsc::channel(2); + + // task simply sends data in order but in a separate + // thread (to ensure the batches are not available without the + // DelayedStream yielding). + tokio::task::spawn(async move { + for batch in data { + println!("Sending batch via delayed stream"); + if let Err(e) = tx.send(batch).await { + println!("ERROR batch via delayed stream: {}", e); + } + } + }); + + // returned stream simply reads off the rx stream + let stream = DelayedStream { + schema, + inner: ReceiverStream::new(rx), + }; + Ok(Box::pin(stream)) + } +} + +fn clone_error(e: &ArrowError) -> ArrowError { + use ArrowError::*; + match e { + ComputeError(msg) => ComputeError(msg.to_string()), + _ => unimplemented!(), + } +} + +#[derive(Debug)] +pub struct DelayedStream { + schema: SchemaRef, + inner: ReceiverStream>, +} + +impl Stream for DelayedStream { + type Item = ArrowResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.inner.poll_next_unpin(cx) + } +} + +impl RecordBatchStream for DelayedStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +/// A mock execution plan that errors on a call to execute +#[derive(Debug)] +pub struct ErrorExec { + schema: SchemaRef, +} +impl ErrorExec { + pub fn new() -> Self { + let schema = Arc::new(Schema::new(vec![Field::new( + "dummy", + DataType::Int64, + true, + )])); + Self { schema } + } +} + +#[async_trait] +impl ExecutionPlan for ErrorExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(1) + } + + fn children(&self) -> Vec> { + unimplemented!() + } + + fn with_new_children( + &self, + _children: Vec>, + ) -> Result> { + unimplemented!() + } + + /// Returns a stream which yields data + async fn execute(&self, partition: usize) -> Result { + Err(DataFusionError::Internal(format!( + "ErrorExec, unsurprisingly, errored in partition {}", + partition + ))) + } +} From 77775b77967a1912b2a423618e4eaa44192bdc23 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 10 Jun 2021 22:55:59 +0800 Subject: [PATCH 13/25] add boundary check (#530) --- .../core/src/serde/logical_plan/from_proto.rs | 10 +- datafusion/src/logical_plan/window_frames.rs | 35 ++ datafusion/src/physical_plan/mod.rs | 1 - datafusion/src/physical_plan/window_frames.rs | 337 ------------------ datafusion/src/sql/planner.rs | 58 ++- 5 files changed, 95 insertions(+), 346 deletions(-) delete mode 100644 datafusion/src/physical_plan/window_frames.rs diff --git 
a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 86daeb063c47..894a5f0a7d98 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -927,10 +927,18 @@ impl TryInto for &protobuf::LogicalExprNode { .as_ref() .map::, _>(|e| match e { window_expr_node::WindowFrame::Frame(frame) => { - frame.clone().try_into() + let window_frame: WindowFrame = frame.clone().try_into()?; + if WindowFrameUnits::Range == window_frame.units + && order_by.len() != 1 + { + Err(proto_error("With window frame of type RANGE, the order by expression must be of length 1")) + } else { + Ok(window_frame) + } } }) .transpose()?; + match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { let aggr_function = protobuf::AggregateFunction::from_i32(*i) diff --git a/datafusion/src/logical_plan/window_frames.rs b/datafusion/src/logical_plan/window_frames.rs index f0be5a221fbf..8aaebd3155c1 100644 --- a/datafusion/src/logical_plan/window_frames.rs +++ b/datafusion/src/logical_plan/window_frames.rs @@ -82,6 +82,22 @@ impl TryFrom for WindowFrame { ))) } else { let units = value.units.into(); + if units == WindowFrameUnits::Range { + for bound in &[start_bound, end_bound] { + match bound { + WindowFrameBound::Preceding(Some(v)) + | WindowFrameBound::Following(Some(v)) + if *v > 0 => + { + Err(DataFusionError::NotImplemented(format!( + "With WindowFrameUnits={}, the bound cannot be {} PRECEDING or FOLLOWING at the moment", + units, v + ))) + } + _ => Ok(()), + }?; + } + } Ok(Self { units, start_bound, @@ -270,6 +286,25 @@ mod tests { result.err().unwrap().to_string(), "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Preceding(Some(2)), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "This feature is not implemented: With WindowFrameUnits=RANGE, the bound cannot be 2 PRECEDING or FOLLOWING at the moment".to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Rows, + start_bound: ast::WindowFrameBound::Preceding(Some(2)), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), + }; + let result = WindowFrame::try_from(window_frame); + assert!(result.is_ok()); Ok(()) } diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 490e02875c42..af6969c43cbd 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -617,6 +617,5 @@ pub mod udf; #[cfg(feature = "unicode_expressions")] pub mod unicode_expressions; pub mod union; -pub mod window_frames; pub mod window_functions; pub mod windows; diff --git a/datafusion/src/physical_plan/window_frames.rs b/datafusion/src/physical_plan/window_frames.rs deleted file mode 100644 index f0be5a221fbf..000000000000 --- a/datafusion/src/physical_plan/window_frames.rs +++ /dev/null @@ -1,337 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Window frame -//! -//! The frame-spec determines which output rows are read by an aggregate window function. The frame-spec consists of four parts: -//! - A frame type - either ROWS, RANGE or GROUPS, -//! - A starting frame boundary, -//! - An ending frame boundary, -//! - An EXCLUDE clause. - -use crate::error::{DataFusionError, Result}; -use sqlparser::ast; -use std::cmp::Ordering; -use std::convert::{From, TryFrom}; -use std::fmt; - -/// The frame-spec determines which output rows are read by an aggregate window function. -/// -/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the -/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to -/// CURRENT ROW. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct WindowFrame { - /// A frame type - either ROWS, RANGE or GROUPS - pub units: WindowFrameUnits, - /// A starting frame boundary - pub start_bound: WindowFrameBound, - /// An ending frame boundary - pub end_bound: WindowFrameBound, -} - -impl fmt::Display for WindowFrame { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{} BETWEEN {} AND {}", - self.units, self.start_bound, self.end_bound - )?; - Ok(()) - } -} - -impl TryFrom for WindowFrame { - type Error = DataFusionError; - - fn try_from(value: ast::WindowFrame) -> Result { - let start_bound = value.start_bound.into(); - let end_bound = value - .end_bound - .map(WindowFrameBound::from) - .unwrap_or(WindowFrameBound::CurrentRow); - - if let WindowFrameBound::Following(None) = start_bound { - Err(DataFusionError::Execution( - "Invalid window frame: start bound cannot be unbounded following" - .to_owned(), - )) - } else if let WindowFrameBound::Preceding(None) = end_bound { - Err(DataFusionError::Execution( - "Invalid window frame: end bound cannot be unbounded preceding" - .to_owned(), - )) - } else if start_bound > end_bound { - Err(DataFusionError::Execution(format!( - "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", - start_bound, end_bound - ))) - } else { - let units = value.units.into(); - Ok(Self { - units, - start_bound, - end_bound, - }) - } - } -} - -impl Default for WindowFrame { - fn default() -> Self { - WindowFrame { - units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(None), - end_bound: WindowFrameBound::CurrentRow, - } - } -} - -/// There are five ways to describe starting and ending frame boundaries: -/// -/// 1. UNBOUNDED PRECEDING -/// 2. PRECEDING -/// 3. CURRENT ROW -/// 4. FOLLOWING -/// 5. UNBOUNDED FOLLOWING -/// -/// in this implementation we'll only allow to be u64 (i.e. no dynamic boundary) -#[derive(Debug, Clone, Copy, Eq)] -pub enum WindowFrameBound { - /// 1. UNBOUNDED PRECEDING - /// The frame boundary is the first row in the partition. - /// - /// 2. PRECEDING - /// must be a non-negative constant numeric expression. 
The boundary is a row that - /// is "units" prior to the current row. - Preceding(Option), - /// 3. The current row. - /// - /// For RANGE and GROUPS frame types, peers of the current row are also - /// included in the frame, unless specifically excluded by the EXCLUDE clause. - /// This is true regardless of whether CURRENT ROW is used as the starting or ending frame - /// boundary. - CurrentRow, - /// 4. This is the same as " PRECEDING" except that the boundary is units after the - /// current rather than before the current row. - /// - /// 5. UNBOUNDED FOLLOWING - /// The frame boundary is the last row in the partition. - Following(Option), -} - -impl From for WindowFrameBound { - fn from(value: ast::WindowFrameBound) -> Self { - match value { - ast::WindowFrameBound::Preceding(v) => Self::Preceding(v), - ast::WindowFrameBound::Following(v) => Self::Following(v), - ast::WindowFrameBound::CurrentRow => Self::CurrentRow, - } - } -} - -impl fmt::Display for WindowFrameBound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), - WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"), - WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"), - WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n), - WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n), - } - } -} - -impl PartialEq for WindowFrameBound { - fn eq(&self, other: &Self) -> bool { - self.cmp(other) == Ordering::Equal - } -} - -impl PartialOrd for WindowFrameBound { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for WindowFrameBound { - fn cmp(&self, other: &Self) -> Ordering { - self.get_rank().cmp(&other.get_rank()) - } -} - -impl WindowFrameBound { - /// get the rank of this window frame bound. - /// - /// the rank is a tuple of (u8, u64) because we'll firstly compare the kind and then the value - /// which requires special handling e.g. with preceding the larger the value the smaller the - /// rank and also for 0 preceding / following it is the same as current row - fn get_rank(&self) -> (u8, u64) { - match self { - WindowFrameBound::Preceding(None) => (0, 0), - WindowFrameBound::Following(None) => (4, 0), - WindowFrameBound::Preceding(Some(0)) - | WindowFrameBound::CurrentRow - | WindowFrameBound::Following(Some(0)) => (2, 0), - WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), - WindowFrameBound::Following(Some(v)) => (3, *v), - } - } -} - -/// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the -/// starting and ending boundaries of the frame are measured. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WindowFrameUnits { - /// The ROWS frame type means that the starting and ending boundaries for the frame are - /// determined by counting individual rows relative to the current row. - Rows, - /// The RANGE frame type requires that the ORDER BY clause of the window have exactly one - /// term. Call that term "X". With the RANGE frame type, the elements of the frame are - /// determined by computing the value of expression X for all rows in the partition and framing - /// those rows for which the value of X is within a certain range of the value of X for the - /// current row. - Range, - /// The GROUPS frame type means that the starting and ending boundaries are determine - /// by counting "groups" relative to the current group. 
A "group" is a set of rows that all have - /// equivalent values for all all terms of the window ORDER BY clause. - Groups, -} - -impl fmt::Display for WindowFrameUnits { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(match self { - WindowFrameUnits::Rows => "ROWS", - WindowFrameUnits::Range => "RANGE", - WindowFrameUnits::Groups => "GROUPS", - }) - } -} - -impl From for WindowFrameUnits { - fn from(value: ast::WindowFrameUnits) -> Self { - match value { - ast::WindowFrameUnits::Range => Self::Range, - ast::WindowFrameUnits::Groups => Self::Groups, - ast::WindowFrameUnits::Rows => Self::Rows, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_window_frame_creation() -> Result<()> { - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Following(None), - end_bound: None, - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(None), - end_bound: Some(ast::WindowFrameBound::Preceding(None)), - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(Some(1)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))), - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() - ); - Ok(()) - } - - #[test] - fn test_eq() { - assert_eq!( - WindowFrameBound::Preceding(Some(0)), - WindowFrameBound::CurrentRow - ); - assert_eq!( - WindowFrameBound::CurrentRow, - WindowFrameBound::Following(Some(0)) - ); - assert_eq!( - WindowFrameBound::Following(Some(2)), - WindowFrameBound::Following(Some(2)) - ); - assert_eq!( - WindowFrameBound::Following(None), - WindowFrameBound::Following(None) - ); - assert_eq!( - WindowFrameBound::Preceding(Some(2)), - WindowFrameBound::Preceding(Some(2)) - ); - assert_eq!( - WindowFrameBound::Preceding(None), - WindowFrameBound::Preceding(None) - ); - } - - #[test] - fn test_ord() { - assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow); - // ! yes this is correct! 
- assert!( - WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1)) - ); - assert!( - WindowFrameBound::Preceding(Some(u64::MAX)) - < WindowFrameBound::Preceding(Some(u64::MAX - 1)) - ); - assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(1000000)) - ); - assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(u64::MAX)) - ); - assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0))); - assert!( - WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1)) - ); - assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1))); - assert!( - WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2)) - ); - assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None)); - assert!( - WindowFrameBound::Following(Some(u64::MAX)) - < WindowFrameBound::Following(None) - ); - } -} diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 53f22ecaf3f2..c128634091a0 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -19,6 +19,7 @@ use crate::catalog::TableReference; use crate::datasource::TableProvider; +use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits}; use crate::logical_plan::Expr::Alias; use crate::logical_plan::{ and, lit, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, Operator, PlanType, @@ -1137,7 +1138,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let window_frame = window .window_frame .as_ref() - .map(|window_frame| window_frame.clone().try_into()) + .map(|window_frame| { + let window_frame: WindowFrame = window_frame.clone().try_into()?; + if WindowFrameUnits::Range == window_frame.units + && order_by.len() != 1 + { + Err(DataFusionError::Plan(format!( + "With window frame of type RANGE, the order by expression must be of length 1, got {}", order_by.len()))) + } else { + Ok(window_frame) + } + + }) .transpose()?; let fun = window_functions::WindowFunction::from_str(&name)?; match fun { @@ -2859,10 +2871,10 @@ mod tests { #[test] fn over_order_by_with_window_frame_double_end() { - let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\ + Projection: #order_id, #MAX(qty) ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty) ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\ \n Sort: #order_id ASC NULLS FIRST\ \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ @@ -2872,10 +2884,10 @@ mod tests { #[test] fn over_order_by_with_window_frame_single_end() { - let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW]]\ + Projection: #order_id, #MAX(qty) 
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty) ROWS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ \n Sort: #order_id ASC NULLS FIRST\ \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ @@ -2883,6 +2895,38 @@ mod tests { quick_test(sql, expected); } + #[test] + fn over_order_by_with_window_frame_range_value_check() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE 3 PRECEDING) from orders"; + let err = logical_plan(sql).expect_err("query should have failed"); + assert_eq!( + "NotImplemented(\"With WindowFrameUnits=RANGE, the bound cannot be 3 PRECEDING or FOLLOWING at the moment\")", + format!("{:?}", err) + ); + } + + #[test] + fn over_order_by_with_window_frame_range_order_by_check() { + let sql = + "SELECT order_id, MAX(qty) OVER (RANGE UNBOUNDED PRECEDING) from orders"; + let err = logical_plan(sql).expect_err("query should have failed"); + assert_eq!( + "Plan(\"With window frame of type RANGE, the order by expression must be of length 1, got 0\")", + format!("{:?}", err) + ); + } + + #[test] + fn over_order_by_with_window_frame_range_order_by_check_2() { + let sql = + "SELECT order_id, MAX(qty) OVER (ORDER BY order_id, qty RANGE UNBOUNDED PRECEDING) from orders"; + let err = logical_plan(sql).expect_err("query should have failed"); + assert_eq!( + "Plan(\"With window frame of type RANGE, the order by expression must be of length 1, got 2\")", + format!("{:?}", err) + ); + } + #[test] fn over_order_by_with_window_frame_single_end_groups() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; From 5c88450a0286c98cdd4b0679f6b09b7eee1c3570 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 10 Jun 2021 22:58:19 +0800 Subject: [PATCH 14/25] remove redundant into_iter() calls (#527) --- ballista/rust/core/src/serde/logical_plan/from_proto.rs | 2 -- datafusion/src/physical_plan/windows.rs | 1 - datafusion/src/sql/planner.rs | 4 ---- 3 files changed, 7 deletions(-) diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 894a5f0a7d98..c2c1001b939c 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -914,13 +914,11 @@ impl TryInto for &protobuf::LogicalExprNode { .partition_by .iter() .map(|e| e.try_into()) - .into_iter() .collect::, _>>()?; let order_by = expr .order_by .iter() .map(|e| e.try_into()) - .into_iter() .collect::, _>>()?; let window_frame = expr .window_frame diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index 7eb14943facf..9a6b92985b51 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -348,7 +348,6 @@ fn window_aggregate_batch( .collect::>>()?; window_acc.scan_batch(batch.num_rows(), values) }) - .into_iter() .collect::>>() } diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index c128634091a0..860d21714ec6 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -714,7 +714,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let select_exprs = select_exprs .iter() .map(|expr| rebase_expr(expr, &window_exprs, &plan)) - .into_iter() .collect::>>()?; Ok((plan, select_exprs)) } @@ -811,7 +810,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let order_by_rex = order_by .iter() .map(|e| 
self.order_by_to_sort_expr(e)) - .into_iter() .collect::>>()?; LogicalPlanBuilder::from(&plan).sort(order_by_rex)?.build() @@ -1127,13 +1125,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .partition_by .iter() .map(|e| self.sql_expr_to_logical_expr(e)) - .into_iter() .collect::>>()?; let order_by = window .order_by .iter() .map(|e| self.order_by_to_sort_expr(e)) - .into_iter() .collect::>>()?; let window_frame = window .window_frame From 3ef7f3495b9501f9a14db64a6ae4d923f681c649 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 11 Jun 2021 06:16:21 +0800 Subject: [PATCH 15/25] use nightly nightly-2021-05-10 (#536) --- .env | 2 +- .github/workflows/python_build.yml | 2 +- .github/workflows/python_test.yaml | 4 ++-- python/rust-toolchain | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.env b/.env index 4fb24bed40a1..05517d00f8e4 100644 --- a/.env +++ b/.env @@ -47,7 +47,7 @@ FEDORA=33 PYTHON=3.6 LLVM=11 CLANG_TOOLS=8 -RUST=nightly-2021-03-24 +RUST=nightly-2021-05-10 GO=1.15 NODE=14 MAVEN=3.5.4 diff --git a/.github/workflows/python_build.yml b/.github/workflows/python_build.yml index eba11b8e3a41..1f083de7827f 100644 --- a/.github/workflows/python_build.yml +++ b/.github/workflows/python_build.yml @@ -39,7 +39,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: - toolchain: nightly-2021-01-06 + toolchain: nightly-2021-05-10 - name: Install dependencies run: | diff --git a/.github/workflows/python_test.yaml b/.github/workflows/python_test.yaml index e689396b5dcd..ebf5e9f594c0 100644 --- a/.github/workflows/python_test.yaml +++ b/.github/workflows/python_test.yaml @@ -25,8 +25,8 @@ jobs: - uses: actions/checkout@v2 - name: Setup Rust toolchain run: | - rustup toolchain install nightly-2021-01-06 - rustup default nightly-2021-01-06 + rustup toolchain install nightly-2021-05-10 + rustup default nightly-2021-05-10 rustup component add rustfmt - name: Cache Cargo uses: actions/cache@v2 diff --git a/python/rust-toolchain b/python/rust-toolchain index 9d0cf79d367d..6231a95e3036 100644 --- a/python/rust-toolchain +++ b/python/rust-toolchain @@ -1 +1 @@ -nightly-2021-01-06 +nightly-2021-05-10 From 63e3045c9e0dd0579ec2be92bb174401f898833f Mon Sep 17 00:00:00 2001 From: Ximo Guanter Date: Fri, 11 Jun 2021 17:45:27 +0200 Subject: [PATCH 16/25] Make BallistaContext::collect streaming (#535) --- ballista/rust/client/src/context.rs | 113 ++++++++++++++++++---------- 1 file changed, 72 insertions(+), 41 deletions(-) diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs index 4c0ab4244be3..4e5cc1a7a76b 100644 --- a/ballista/rust/client/src/context.rs +++ b/ballista/rust/client/src/context.rs @@ -24,21 +24,27 @@ use std::{collections::HashMap, convert::TryInto}; use std::{fs, time::Duration}; use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient; +use ballista_core::serde::protobuf::PartitionLocation; use ballista_core::serde::protobuf::{ execute_query_params::Query, job_status, ExecuteQueryParams, GetJobStatusParams, GetJobStatusResult, }; use ballista_core::{ - client::BallistaClient, datasource::DfTableAdapter, memory_stream::MemoryStream, - utils::create_datafusion_context, + client::BallistaClient, datasource::DfTableAdapter, utils::create_datafusion_context, }; use datafusion::arrow::datatypes::Schema; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::error::Result as ArrowResult; +use datafusion::arrow::record_batch::RecordBatch; use datafusion::catalog::TableReference; use 
datafusion::error::{DataFusionError, Result}; use datafusion::logical_plan::LogicalPlan; use datafusion::physical_plan::csv::CsvReadOptions; use datafusion::{dataframe::DataFrame, physical_plan::RecordBatchStream}; +use futures::future; +use futures::Stream; +use futures::StreamExt; use log::{error, info}; #[allow(dead_code)] @@ -68,6 +74,32 @@ impl BallistaContextState { } } +struct WrappedStream { + stream: Pin> + Send + Sync>>, + schema: SchemaRef, +} + +impl RecordBatchStream for WrappedStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl Stream for WrappedStream { + type Item = ArrowResult; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.stream.poll_next_unpin(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} + #[allow(dead_code)] pub struct BallistaContext { @@ -155,6 +187,29 @@ impl BallistaContext { ctx.sql(sql) } + async fn fetch_partition( + location: PartitionLocation, + ) -> Result>> { + let metadata = location.executor_meta.ok_or_else(|| { + DataFusionError::Internal("Received empty executor metadata".to_owned()) + })?; + let partition_id = location.partition_id.ok_or_else(|| { + DataFusionError::Internal("Received empty partition id".to_owned()) + })?; + let mut ballista_client = + BallistaClient::try_new(metadata.host.as_str(), metadata.port as u16) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + Ok(ballista_client + .fetch_partition( + &partition_id.job_id, + partition_id.stage_id as usize, + partition_id.partition_id as usize, + ) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?) + } + pub async fn collect( &self, plan: &LogicalPlan, @@ -222,45 +277,21 @@ impl BallistaContext { break Err(DataFusionError::Execution(msg)); } job_status::Status::Completed(completed) => { - // TODO: use streaming. 
Probably need to change the signature of fetch_partition to achieve that - let mut result = vec![]; - for location in completed.partition_location { - let metadata = location.executor_meta.ok_or_else(|| { - DataFusionError::Internal( - "Received empty executor metadata".to_owned(), - ) - })?; - let partition_id = location.partition_id.ok_or_else(|| { - DataFusionError::Internal( - "Received empty partition id".to_owned(), - ) - })?; - let mut ballista_client = BallistaClient::try_new( - metadata.host.as_str(), - metadata.port as u16, - ) - .await - .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; - let stream = ballista_client - .fetch_partition( - &partition_id.job_id, - partition_id.stage_id as usize, - partition_id.partition_id as usize, - ) - .await - .map_err(|e| { - DataFusionError::Execution(format!("{:?}", e)) - })?; - result.append( - &mut datafusion::physical_plan::common::collect(stream) - .await?, - ); - } - break Ok(Box::pin(MemoryStream::try_new( - result, - Arc::new(schema), - None, - )?)); + let result = future::join_all( + completed + .partition_location + .into_iter() + .map(BallistaContext::fetch_partition), + ) + .await + .into_iter() + .collect::>>()?; + + let result = WrappedStream { + stream: Box::pin(futures::stream::iter(result).flatten()), + schema: Arc::new(schema), + }; + break Ok(Box::pin(result)); } }; } From ad70a1e91667174436f2110a70e3e557c7069e9a Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sat, 12 Jun 2021 02:50:23 +0800 Subject: [PATCH 17/25] reuse datafusion physical planner in ballista building from protobuf (#532) * use logical planner in ballista building * simplify statement * fix unit test * fix per comment --- .../src/serde/physical_plan/from_proto.rs | 142 ++++-------------- datafusion/src/physical_plan/planner.rs | 116 +++++++++++--- datafusion/src/physical_plan/windows.rs | 44 ++++-- 3 files changed, 153 insertions(+), 149 deletions(-) diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index b319d5b25f12..d49d53cf8d85 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -36,7 +36,7 @@ use datafusion::execution::context::{ ExecutionConfig, ExecutionContextState, ExecutionProps, }; use datafusion::logical_plan::{DFSchema, Expr}; -use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction}; +use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::expressions::col; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use datafusion::physical_plan::hash_join::PartitionMode; @@ -45,7 +45,6 @@ use datafusion::physical_plan::planner::DefaultPhysicalPlanner; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; -use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::windows::WindowAggExec; use datafusion::physical_plan::{ coalesce_batches::CoalesceBatchesExec, @@ -205,76 +204,27 @@ impl TryInto> for &protobuf::PhysicalPlanNode { ) })? 
.clone(); - let physical_schema: SchemaRef = SchemaRef::new((&input_schema).try_into()?); - - let catalog_list = - Arc::new(MemoryCatalogList::new()) as Arc; - let ctx_state = ExecutionContextState { - catalog_list, - scalar_functions: Default::default(), - var_provider: Default::default(), - aggregate_functions: Default::default(), - config: ExecutionConfig::new(), - execution_props: ExecutionProps::new(), - }; - + let ctx_state = ExecutionContextState::new(); let window_agg_expr: Vec<(Expr, String)> = window_agg .window_expr .iter() .zip(window_agg.window_expr_name.iter()) .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone()))) .collect::, _>>()?; - - let mut physical_window_expr = vec![]; - let df_planner = DefaultPhysicalPlanner::default(); - - for (expr, name) in &window_agg_expr { - match expr { - Expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - .. - } => { - let arg = df_planner - .create_physical_expr( - &args[0], - &physical_schema, - &ctx_state, - ) - .map_err(|e| { - BallistaError::General(format!("{:?}", e)) - })?; - if !partition_by.is_empty() { - return Err(BallistaError::NotImplemented("Window function with partition by is not yet implemented".to_owned())); - } - if !order_by.is_empty() { - return Err(BallistaError::NotImplemented("Window function with order by is not yet implemented".to_owned())); - } - if window_frame.is_some() { - return Err(BallistaError::NotImplemented("Window function with window frame is not yet implemented".to_owned())); - } - let window_expr = create_window_expr( - &fun, - &[arg], - &physical_schema, - name.to_owned(), - )?; - physical_window_expr.push(window_expr); - } - _ => { - return Err(BallistaError::General( - "Invalid expression for WindowAggrExec".to_string(), - )); - } - } - } - + let physical_window_expr = window_agg_expr + .iter() + .map(|(expr, name)| { + df_planner.create_window_expr_with_name( + expr, + name.to_string(), + &physical_schema, + &ctx_state, + ) + }) + .collect::, _>>()?; Ok(Arc::new(WindowAggExec::try_new( physical_window_expr, input, @@ -297,7 +247,6 @@ impl TryInto> for &protobuf::PhysicalPlanNode { AggregateMode::FinalPartitioned } }; - let group = hash_agg .group_expr .iter() @@ -306,25 +255,13 @@ impl TryInto> for &protobuf::PhysicalPlanNode { compile_expr(expr, &input.schema()).map(|e| (e, name.to_string())) }) .collect::, _>>()?; - let logical_agg_expr: Vec<(Expr, String)> = hash_agg .aggr_expr .iter() .zip(hash_agg.aggr_expr_name.iter()) .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone()))) .collect::, _>>()?; - - let catalog_list = - Arc::new(MemoryCatalogList::new()) as Arc; - let ctx_state = ExecutionContextState { - catalog_list, - scalar_functions: Default::default(), - var_provider: Default::default(), - aggregate_functions: Default::default(), - config: ExecutionConfig::new(), - execution_props: ExecutionProps::new(), - }; - + let ctx_state = ExecutionContextState::new(); let input_schema = hash_agg .input_schema .as_ref() @@ -336,37 +273,18 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .clone(); let physical_schema: SchemaRef = SchemaRef::new((&input_schema).try_into()?); - - let mut physical_aggr_expr = vec![]; - let df_planner = DefaultPhysicalPlanner::default(); - for (expr, name) in &logical_agg_expr { - match expr { - Expr::AggregateFunction { fun, args, .. 
} => { - let arg = df_planner - .create_physical_expr( - &args[0], - &physical_schema, - &ctx_state, - ) - .map_err(|e| { - BallistaError::General(format!("{:?}", e)) - })?; - physical_aggr_expr.push(create_aggregate_expr( - &fun, - false, - &[arg], - &physical_schema, - name.to_string(), - )?); - } - _ => { - return Err(BallistaError::General( - "Invalid expression for HashAggregateExec".to_string(), - )) - } - } - } + let physical_aggr_expr = logical_agg_expr + .iter() + .map(|(expr, name)| { + df_planner.create_aggregate_expr_with_name( + expr, + name.to_string(), + &physical_schema, + &ctx_state, + ) + }) + .collect::, _>>()?; Ok(Arc::new(HashAggregateExec::try_new( agg_mode, group, @@ -484,15 +402,7 @@ fn compile_expr( schema: &Schema, ) -> Result, BallistaError> { let df_planner = DefaultPhysicalPlanner::default(); - let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc; - let state = ExecutionContextState { - catalog_list, - scalar_functions: HashMap::new(), - var_provider: HashMap::new(), - aggregate_functions: HashMap::new(), - config: ExecutionConfig::new(), - execution_props: ExecutionProps::new(), - }; + let state = ExecutionContextState::new(); let expr: Expr = expr.try_into()?; df_planner .create_physical_expr(&expr, schema, &state) diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index d7451c787096..d42948a8666c 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -731,34 +731,82 @@ impl DefaultPhysicalPlanner { } } - /// Create a window expression from a logical expression - pub fn create_window_expr( + /// Create a window expression with a name from a logical expression + pub fn create_window_expr_with_name( &self, e: &Expr, - logical_input_schema: &DFSchema, + name: String, physical_input_schema: &Schema, ctx_state: &ExecutionContextState, ) -> Result> { - // unpack aliased logical expressions, e.g. "sum(col) over () as total" - let (name, e) = match e { - Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), - _ => (e.name(logical_input_schema)?, e), - }; - match e { - Expr::WindowFunction { fun, args, .. 
} => { + Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + } => { let args = args .iter() .map(|e| { self.create_physical_expr(e, physical_input_schema, ctx_state) }) .collect::>>()?; - // if !order_by.is_empty() { - // return Err(DataFusionError::NotImplemented( - // "Window function with order by is not yet implemented".to_owned(), - // )); - // } - windows::create_window_expr(fun, &args, physical_input_schema, name) + let partition_by = partition_by + .iter() + .map(|e| { + self.create_physical_expr(e, physical_input_schema, ctx_state) + }) + .collect::>>()?; + let order_by = order_by + .iter() + .map(|e| match e { + Expr::Sort { + expr, + asc, + nulls_first, + } => self.create_physical_sort_expr( + expr, + &physical_input_schema, + SortOptions { + descending: !*asc, + nulls_first: *nulls_first, + }, + &ctx_state, + ), + _ => Err(DataFusionError::Plan( + "Sort only accepts sort expressions".to_string(), + )), + }) + .collect::>>()?; + if !partition_by.is_empty() { + return Err(DataFusionError::NotImplemented( + "window expression with non-empty partition by clause is not yet supported" + .to_owned(), + )); + } + if !order_by.is_empty() { + return Err(DataFusionError::NotImplemented( + "window expression with non-empty order by clause is not yet supported" + .to_owned(), + )); + } + if window_frame.is_some() { + return Err(DataFusionError::NotImplemented( + "window expression with window frame definition is not yet supported" + .to_owned(), + )); + } + windows::create_window_expr( + fun, + name, + &args, + &partition_by, + &order_by, + *window_frame, + physical_input_schema, + ) } other => Err(DataFusionError::Internal(format!( "Invalid window expression '{:?}'", @@ -767,20 +815,30 @@ impl DefaultPhysicalPlanner { } } - /// Create an aggregate expression from a logical expression - pub fn create_aggregate_expr( + /// Create a window expression from a logical expression or an alias + pub fn create_window_expr( &self, e: &Expr, logical_input_schema: &DFSchema, physical_input_schema: &Schema, ctx_state: &ExecutionContextState, - ) -> Result> { - // unpack aliased logical expressions, e.g. "sum(col) as total" + ) -> Result> { + // unpack aliased logical expressions, e.g. "sum(col) over () as total" let (name, e) = match e { Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), _ => (e.name(logical_input_schema)?, e), }; + self.create_window_expr_with_name(e, name, physical_input_schema, ctx_state) + } + /// Create an aggregate expression with a name from a logical expression + pub fn create_aggregate_expr_with_name( + &self, + e: &Expr, + name: String, + physical_input_schema: &Schema, + ctx_state: &ExecutionContextState, + ) -> Result> { match e { Expr::AggregateFunction { fun, @@ -819,7 +877,23 @@ impl DefaultPhysicalPlanner { } } - /// Create an aggregate expression from a logical expression + /// Create an aggregate expression from a logical expression or an alias + pub fn create_aggregate_expr( + &self, + e: &Expr, + logical_input_schema: &DFSchema, + physical_input_schema: &Schema, + ctx_state: &ExecutionContextState, + ) -> Result> { + // unpack aliased logical expressions, e.g. 
"sum(col) as total" + let (name, e) = match e { + Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), + _ => (e.name(logical_input_schema)?, e), + }; + self.create_aggregate_expr_with_name(e, name, physical_input_schema, ctx_state) + } + + /// Create a physical sort expression from a logical expression pub fn create_physical_sort_expr( &self, e: &Expr, diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index 9a6b92985b51..565a9eef2857 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -18,9 +18,11 @@ //! Execution plan for window functions use crate::error::{DataFusionError, Result}; + +use crate::logical_plan::window_frames::WindowFrame; use crate::physical_plan::{ aggregates, common, - expressions::{Literal, NthValue, RowNumber}, + expressions::{Literal, NthValue, PhysicalSortExpr, RowNumber}, type_coercion::coerce, window_functions::signature_for_built_in, window_functions::BuiltInWindowFunctionExpr, @@ -61,12 +63,18 @@ pub struct WindowAggExec { /// Create a physical expression for window function pub fn create_window_expr( fun: &WindowFunction, + name: String, args: &[Arc], + // https://github.com/apache/arrow-datafusion/issues/299 + _partition_by: &[Arc], + // https://github.com/apache/arrow-datafusion/issues/360 + _order_by: &[PhysicalSortExpr], + // https://github.com/apache/arrow-datafusion/issues/361 + _window_frame: Option, input_schema: &Schema, - name: String, ) -> Result> { - match fun { - WindowFunction::AggregateFunction(fun) => Ok(Arc::new(AggregateWindowExpr { + Ok(match fun { + WindowFunction::AggregateFunction(fun) => Arc::new(AggregateWindowExpr { aggregate: aggregates::create_aggregate_expr( fun, false, @@ -74,11 +82,11 @@ pub fn create_window_expr( input_schema, name, )?, - })), - WindowFunction::BuiltInWindowFunction(fun) => Ok(Arc::new(BuiltInWindowExpr { + }), + WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr { window: create_built_in_window_expr(fun, args, input_schema, name)?, - })), - } + }), + }) } fn create_built_in_window_expr( @@ -537,9 +545,12 @@ mod tests { let window_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), &[col("c3")], + &[], + &[], + Some(WindowFrame::default()), schema.as_ref(), - "count".to_owned(), )?], input, schema.clone(), @@ -567,21 +578,30 @@ mod tests { vec![ create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), &[col("c3")], + &[], + &[], + Some(WindowFrame::default()), schema.as_ref(), - "count".to_owned(), )?, create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Max), + "max".to_owned(), &[col("c3")], + &[], + &[], + Some(WindowFrame::default()), schema.as_ref(), - "max".to_owned(), )?, create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Min), + "min".to_owned(), &[col("c3")], + &[], + &[], + Some(WindowFrame::default()), schema.as_ref(), - "min".to_owned(), )?, ], input, From 8f4078d83f7ea0348fa43906d26156bf8a95de4c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 12 Jun 2021 06:45:06 -0600 Subject: [PATCH 18/25] ShuffleReaderExec now supports multiple locations per partition (#541) * ShuffleReaderExec now supports multiple locations per partition * Remove TODO * avoid clone --- ballista/rust/client/src/context.rs | 39 +------- ballista/rust/core/proto/ballista.proto | 7 +- 
.../src/execution_plans/shuffle_reader.rs | 94 +++++++++++-------- .../src/serde/physical_plan/from_proto.rs | 12 ++- .../core/src/serde/physical_plan/to_proto.rs | 18 ++-- ballista/rust/core/src/utils.rs | 40 +++++++- ballista/rust/scheduler/src/planner.rs | 2 +- ballista/rust/scheduler/src/state/mod.rs | 6 +- 8 files changed, 130 insertions(+), 88 deletions(-) diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs index 4e5cc1a7a76b..695045d220d0 100644 --- a/ballista/rust/client/src/context.rs +++ b/ballista/rust/client/src/context.rs @@ -29,21 +29,18 @@ use ballista_core::serde::protobuf::{ execute_query_params::Query, job_status, ExecuteQueryParams, GetJobStatusParams, GetJobStatusResult, }; +use ballista_core::utils::WrappedStream; use ballista_core::{ client::BallistaClient, datasource::DfTableAdapter, utils::create_datafusion_context, }; use datafusion::arrow::datatypes::Schema; -use datafusion::arrow::datatypes::SchemaRef; -use datafusion::arrow::error::Result as ArrowResult; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::catalog::TableReference; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_plan::LogicalPlan; use datafusion::physical_plan::csv::CsvReadOptions; use datafusion::{dataframe::DataFrame, physical_plan::RecordBatchStream}; use futures::future; -use futures::Stream; use futures::StreamExt; use log::{error, info}; @@ -74,32 +71,6 @@ impl BallistaContextState { } } -struct WrappedStream { - stream: Pin> + Send + Sync>>, - schema: SchemaRef, -} - -impl RecordBatchStream for WrappedStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -impl Stream for WrappedStream { - type Item = ArrowResult; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.stream.poll_next_unpin(cx) - } - - fn size_hint(&self) -> (usize, Option) { - self.stream.size_hint() - } -} - #[allow(dead_code)] pub struct BallistaContext { @@ -287,10 +258,10 @@ impl BallistaContext { .into_iter() .collect::>>()?; - let result = WrappedStream { - stream: Box::pin(futures::stream::iter(result).flatten()), - schema: Arc::new(schema), - }; + let result = WrappedStream::new( + Box::pin(futures::stream::iter(result).flatten()), + Arc::new(schema), + ); break Ok(Box::pin(result)); } }; diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 85af9023fb46..5aafd00cf1b0 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -489,10 +489,15 @@ message HashAggregateExecNode { } message ShuffleReaderExecNode { - repeated PartitionLocation partition_location = 1; + repeated ShuffleReaderPartition partition = 1; Schema schema = 2; } +message ShuffleReaderPartition { + // each partition of a shuffle read can read data from multiple locations + repeated PartitionLocation location = 1; +} + message GlobalLimitExecNode { PhysicalPlanNode input = 1; uint32 limit = 2; diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs index db29cf13b5fe..3a7f795f1a7f 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use std::fmt::Formatter; use std::sync::Arc; use std::{any::Any, pin::Pin}; @@ -22,35 +23,35 @@ use crate::client::BallistaClient; use crate::memory_stream::MemoryStream; use crate::serde::scheduler::PartitionLocation; +use crate::utils::WrappedStream; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::error::Result as ArrowResult; +use datafusion::arrow::record_batch::RecordBatch; use datafusion::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning}; use datafusion::{ error::{DataFusionError, Result}, physical_plan::RecordBatchStream, }; +use futures::{future, Stream, StreamExt}; use log::info; -use std::fmt::Formatter; -/// ShuffleReaderExec reads partitions that have already been materialized by an executor. +/// ShuffleReaderExec reads partitions that have already been materialized by a query stage +/// being executed by an executor #[derive(Debug, Clone)] pub struct ShuffleReaderExec { - // The query stage that is responsible for producing the shuffle partitions that - // this operator will read - pub(crate) partition_location: Vec, + /// Each partition of a shuffle can read data from multiple locations + pub(crate) partition: Vec>, pub(crate) schema: SchemaRef, } impl ShuffleReaderExec { /// Create a new ShuffleReaderExec pub fn try_new( - partition_meta: Vec, + partition: Vec>, schema: SchemaRef, ) -> Result { - Ok(Self { - partition_location: partition_meta, - schema, - }) + Ok(Self { partition, schema }) } } @@ -65,7 +66,7 @@ impl ExecutionPlan for ShuffleReaderExec { } fn output_partitioning(&self) -> Partitioning { - Partitioning::UnknownPartitioning(self.partition_location.len()) + Partitioning::UnknownPartitioning(self.partition.len()) } fn children(&self) -> Vec> { @@ -86,23 +87,18 @@ impl ExecutionPlan for ShuffleReaderExec { partition: usize, ) -> Result>> { info!("ShuffleReaderExec::execute({})", partition); - let partition_location = &self.partition_location[partition]; - - let mut client = BallistaClient::try_new( - &partition_location.executor_meta.host, - partition_location.executor_meta.port, - ) - .await - .map_err(|e| DataFusionError::Execution(format!("Ballista Error: {:?}", e)))?; - client - .fetch_partition( - &partition_location.partition_id.job_id, - partition_location.partition_id.stage_id, - partition, - ) + let partition_locations = &self.partition[partition]; + let result = future::join_all(partition_locations.iter().map(fetch_partition)) .await - .map_err(|e| DataFusionError::Execution(format!("Ballista Error: {:?}", e))) + .into_iter() + .collect::>>()?; + + let result = WrappedStream::new( + Box::pin(futures::stream::iter(result).flatten()), + Arc::new(self.schema.as_ref().clone()), + ); + Ok(Box::pin(result)) } fn fmt_as( @@ -113,22 +109,46 @@ impl ExecutionPlan for ShuffleReaderExec { match t { DisplayFormatType::Default => { let loc_str = self - .partition_location + .partition .iter() - .map(|l| { - format!( - "[executor={} part={}:{}:{} stats={:?}]", - l.executor_meta.id, - l.partition_id.job_id, - l.partition_id.stage_id, - l.partition_id.partition_id, - l.partition_stats - ) + .map(|x| { + x.iter() + .map(|l| { + format!( + "[executor={} part={}:{}:{} stats={:?}]", + l.executor_meta.id, + l.partition_id.job_id, + l.partition_id.stage_id, + l.partition_id.partition_id, + l.partition_stats + ) + }) + .collect::>() + .join(",") }) .collect::>() - .join(","); + .join("\n"); write!(f, "ShuffleReaderExec: partition_locations={}", loc_str) } } } } + +async fn fetch_partition( + location: 
&PartitionLocation, +) -> Result>> { + let metadata = &location.executor_meta; + let partition_id = &location.partition_id; + let mut ballista_client = + BallistaClient::try_new(metadata.host.as_str(), metadata.port as u16) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + Ok(ballista_client + .fetch_partition( + &partition_id.job_id, + partition_id.stage_id as usize, + partition_id.partition_id as usize, + ) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?) +} diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index d49d53cf8d85..a2c9db9ecafb 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -25,6 +25,7 @@ use crate::error::BallistaError; use crate::execution_plans::{ShuffleReaderExec, UnresolvedShuffleExec}; use crate::serde::protobuf::repartition_exec_node::PartitionMethod; use crate::serde::protobuf::LogicalExprNode; +use crate::serde::protobuf::ShuffleReaderPartition; use crate::serde::scheduler::PartitionLocation; use crate::serde::{proto_error, protobuf}; use crate::{convert_box_required, convert_required}; @@ -327,10 +328,15 @@ impl TryInto> for &protobuf::PhysicalPlanNode { } PhysicalPlanType::ShuffleReader(shuffle_reader) => { let schema = Arc::new(convert_required!(shuffle_reader.schema)?); - let partition_location: Vec = shuffle_reader - .partition_location + let partition_location: Vec> = shuffle_reader + .partition .iter() - .map(|p| p.clone().try_into()) + .map(|p| { + p.location + .iter() + .map(|l| l.clone().try_into()) + .collect::, _>>() + }) .collect::, BallistaError>>()?; let shuffle_reader = ShuffleReaderExec::try_new(partition_location, schema)?; diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 26092e74a096..15d5d4b931ff 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -57,6 +57,7 @@ use protobuf::physical_plan_node::PhysicalPlanType; use crate::execution_plans::{ShuffleReaderExec, UnresolvedShuffleExec}; use crate::serde::protobuf::repartition_exec_node::PartitionMethod; +use crate::serde::scheduler::PartitionLocation; use crate::serde::{protobuf, BallistaError}; use datafusion::physical_plan::functions::{BuiltinScalarFunction, ScalarFunctionExpr}; use datafusion::physical_plan::merge::MergeExec; @@ -268,16 +269,19 @@ impl TryInto for Arc { )), }) } else if let Some(exec) = plan.downcast_ref::() { - let partition_location = exec - .partition_location - .iter() - .map(|l| l.clone().try_into()) - .collect::>()?; - + let mut partition = vec![]; + for location in &exec.partition { + partition.push(protobuf::ShuffleReaderPartition { + location: location + .iter() + .map(|l| l.clone().try_into()) + .collect::, _>>()?, + }); + } Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ShuffleReader( protobuf::ShuffleReaderExecNode { - partition_location, + partition, schema: Some(exec.schema().as_ref().into()), }, )), diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs index 4ba6ec40fec9..b58be2800f7b 100644 --- a/ballista/rust/core/src/utils.rs +++ b/ballista/rust/core/src/utils.rs @@ -27,11 +27,12 @@ use crate::execution_plans::{QueryStageExec, UnresolvedShuffleExec}; use crate::memory_stream::MemoryStream; use crate::serde::scheduler::PartitionStats; 
+use datafusion::arrow::error::Result as ArrowResult; use datafusion::arrow::{ array::{ ArrayBuilder, ArrayRef, StructArray, StructBuilder, UInt64Array, UInt64Builder, }, - datatypes::{DataType, Field}, + datatypes::{DataType, Field, SchemaRef}, ipc::reader::FileReader, ipc::writer::FileWriter, record_batch::RecordBatch, @@ -54,7 +55,7 @@ use datafusion::physical_plan::sort::SortExec; use datafusion::physical_plan::{ AggregateExpr, ExecutionPlan, PhysicalExpr, RecordBatchStream, }; -use futures::StreamExt; +use futures::{future, Stream, StreamExt}; /// Stream data to disk in Arrow IPC format @@ -234,3 +235,38 @@ pub fn create_datafusion_context() -> ExecutionContext { .with_physical_optimizer_rules(rules); ExecutionContext::with_config(config) } + +pub struct WrappedStream { + stream: Pin> + Send + Sync>>, + schema: SchemaRef, +} + +impl WrappedStream { + pub fn new( + stream: Pin> + Send + Sync>>, + schema: SchemaRef, + ) -> Self { + Self { stream, schema } + } +} + +impl RecordBatchStream for WrappedStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl Stream for WrappedStream { + type Item = ArrowResult; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.stream.poll_next_unpin(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs index 445ef9a07787..2ac9f6121e00 100644 --- a/ballista/rust/scheduler/src/planner.rs +++ b/ballista/rust/scheduler/src/planner.rs @@ -186,7 +186,7 @@ impl DistributedPlanner { pub fn remove_unresolved_shuffles( stage: &dyn ExecutionPlan, - partition_locations: &HashMap>, + partition_locations: &HashMap>>, ) -> Result> { let mut new_children: Vec> = vec![]; for child in stage.children() { diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs index a15efd618ff1..506fd1c0db98 100644 --- a/ballista/rust/scheduler/src/state/mod.rs +++ b/ballista/rust/scheduler/src/state/mod.rs @@ -234,7 +234,7 @@ impl SchedulerState { let unresolved_shuffles = find_unresolved_shuffles(&plan)?; let mut partition_locations: HashMap< usize, - Vec, + Vec>, > = HashMap::new(); for unresolved_shuffle in unresolved_shuffles { for stage_id in unresolved_shuffle.query_stage_ids { @@ -256,7 +256,7 @@ impl SchedulerState { let empty = vec![]; let locations = partition_locations.entry(stage_id).or_insert(empty); - locations.push( + locations.push(vec![ ballista_core::serde::scheduler::PartitionLocation { partition_id: ballista_core::serde::scheduler::PartitionId { @@ -271,7 +271,7 @@ impl SchedulerState { .clone(), partition_stats: PartitionStats::default(), }, - ); + ]); } else { continue 'tasks; } From 519698a0cd792a9c263d96079d341816f746c6ec Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sun, 13 Jun 2021 19:12:26 +0800 Subject: [PATCH 19/25] Refactor hash aggregates's planner building code (#539) * refactor hash aggregates * remove stale comments --- datafusion/src/physical_plan/mod.rs | 5 +-- datafusion/src/physical_plan/planner.rs | 54 +++++++++++-------------- 2 files changed, 26 insertions(+), 33 deletions(-) diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index af6969c43cbd..ebc6fd6ce94a 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -341,9 +341,8 @@ pub async fn collect_partitioned( pub enum Partitioning { /// Allocate batches using a round-robin 
algorithm and the specified number of partitions RoundRobinBatch(usize), - /// Allocate rows based on a hash of one of more expressions and the specified - /// number of partitions - /// This partitioning scheme is not yet fully supported. See [ARROW-11011](https://issues.apache.org/jira/browse/ARROW-11011) + /// Allocate rows based on a hash of one of more expressions and the specified number of + /// partitions Hash(Vec>, usize), /// Unknown partitioning scheme with a known number of partitions UnknownPartitioning(usize), diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index d42948a8666c..adae9224a19a 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -222,11 +222,15 @@ impl DefaultPhysicalPlanner { .flat_map(|x| x.0.data_type(physical_input_schema.as_ref())) .any(|x| matches!(x, DataType::Dictionary(_, _))); - if !groups.is_empty() + let can_repartition = !groups.is_empty() && ctx_state.config.concurrency > 1 && ctx_state.config.repartition_aggregations - && !contains_dict - { + && !contains_dict; + + let (initial_aggr, next_partition_mode): ( + Arc, + AggregateMode, + ) = if can_repartition { // Divide partial hash aggregates into multiple partitions by hash key let hash_repartition = Arc::new(RepartitionExec::try_new( initial_aggr, @@ -235,35 +239,25 @@ impl DefaultPhysicalPlanner { ctx_state.config.concurrency, ), )?); - - // Combine hashaggregates within the partition - Ok(Arc::new(HashAggregateExec::try_new( - AggregateMode::FinalPartitioned, - final_group - .iter() - .enumerate() - .map(|(i, expr)| (expr.clone(), groups[i].1.clone())) - .collect(), - aggregates, - hash_repartition, - input_schema, - )?)) + // Combine hash aggregates within the partition + (hash_repartition, AggregateMode::FinalPartitioned) } else { - // construct a second aggregation, keeping the final column name equal to the first aggregation - // and the expressions corresponding to the respective aggregate + // construct a second aggregation, keeping the final column name equal to the + // first aggregation and the expressions corresponding to the respective aggregate + (initial_aggr, AggregateMode::Final) + }; - Ok(Arc::new(HashAggregateExec::try_new( - AggregateMode::Final, - final_group - .iter() - .enumerate() - .map(|(i, expr)| (expr.clone(), groups[i].1.clone())) - .collect(), - aggregates, - initial_aggr, - input_schema, - )?)) - } + Ok(Arc::new(HashAggregateExec::try_new( + next_partition_mode, + final_group + .iter() + .enumerate() + .map(|(i, expr)| (expr.clone(), groups[i].1.clone())) + .collect(), + aggregates, + initial_aggr, + input_schema, + )?)) } LogicalPlan::Projection { input, expr, .. 
} => { let input_exec = self.create_initial_plan(input, ctx_state)?; From 738f13b39de21224396ab447572d9ef573d06bc8 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sun, 13 Jun 2021 19:13:04 +0800 Subject: [PATCH 20/25] turn on clippy rule for needless borrow (#545) * turn on clippy rule for needless borrow * do a format round * use warn not deny --- .../core/src/execution_plans/query_stage.rs | 2 +- .../core/src/serde/logical_plan/to_proto.rs | 4 +- ballista/rust/executor/src/flight_service.rs | 2 +- ballista/rust/scheduler/src/state/mod.rs | 2 +- benchmarks/src/bin/tpch.rs | 16 ++--- datafusion/benches/aggregate_query_sql.rs | 2 +- datafusion/benches/filter_query_sql.rs | 2 +- datafusion/benches/math_query_sql.rs | 2 +- datafusion/benches/sort_limit_query_sql.rs | 2 +- datafusion/src/datasource/csv.rs | 2 +- datafusion/src/datasource/json.rs | 2 +- datafusion/src/execution/context.rs | 8 +-- datafusion/src/execution/dataframe_impl.rs | 2 +- datafusion/src/lib.rs | 2 +- datafusion/src/logical_plan/dfschema.rs | 4 +- datafusion/src/logical_plan/plan.rs | 32 ++++----- datafusion/src/optimizer/filter_push_down.rs | 20 +++--- .../src/optimizer/projection_push_down.rs | 14 ++-- .../src/optimizer/simplify_expressions.rs | 4 +- datafusion/src/physical_optimizer/pruning.rs | 15 ++--- datafusion/src/physical_plan/aggregates.rs | 2 +- .../src/physical_plan/expressions/case.rs | 2 +- .../physical_plan/expressions/row_number.rs | 2 +- datafusion/src/physical_plan/filter.rs | 2 +- datafusion/src/physical_plan/functions.rs | 6 +- .../src/physical_plan/hash_aggregate.rs | 8 +-- datafusion/src/physical_plan/hash_join.rs | 24 +++---- datafusion/src/physical_plan/planner.rs | 26 ++++---- datafusion/src/physical_plan/projection.rs | 2 +- datafusion/src/physical_plan/repartition.rs | 2 +- .../physical_plan/sort_preserving_merge.rs | 2 +- .../src/physical_plan/string_expressions.rs | 6 +- datafusion/src/physical_plan/type_coercion.rs | 6 +- datafusion/src/physical_plan/windows.rs | 6 +- datafusion/src/sql/planner.rs | 66 +++++++++---------- datafusion/src/sql/utils.rs | 4 +- datafusion/tests/sql.rs | 28 ++++---- 37 files changed, 160 insertions(+), 173 deletions(-) diff --git a/ballista/rust/core/src/execution_plans/query_stage.rs b/ballista/rust/core/src/execution_plans/query_stage.rs index 233dee5b9b52..264c44dc43dc 100644 --- a/ballista/rust/core/src/execution_plans/query_stage.rs +++ b/ballista/rust/core/src/execution_plans/query_stage.rs @@ -139,7 +139,7 @@ impl ExecutionPlan for QueryStageExec { info!("Writing results to {}", path); // stream results to disk - let stats = utils::write_stream_to_disk(&mut stream, &path) + let stats = utils::write_stream_to_disk(&mut stream, path) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 5d996843d624..c454d03257f0 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1033,9 +1033,7 @@ impl TryInto for &Expr { .map(|e| e.try_into()) .collect::, _>>()?; let window_frame = window_frame.map(|window_frame| { - protobuf::window_expr_node::WindowFrame::Frame( - window_frame.clone().into(), - ) + protobuf::window_expr_node::WindowFrame::Frame(window_frame.into()) }); let window_expr = Box::new(protobuf::WindowExprNode { expr: Some(Box::new(arg.try_into()?)), diff --git a/ballista/rust/executor/src/flight_service.rs 
b/ballista/rust/executor/src/flight_service.rs index d4eb1229c294..99424b6e8db4 100644 --- a/ballista/rust/executor/src/flight_service.rs +++ b/ballista/rust/executor/src/flight_service.rs @@ -279,7 +279,7 @@ fn create_flight_iter( options: &IpcWriteOptions, ) -> Box>> { let (flight_dictionaries, flight_batch) = - arrow_flight::utils::flight_data_from_arrow_batch(batch, &options); + arrow_flight::utils::flight_data_from_arrow_batch(batch, options); Box::new( flight_dictionaries .into_iter() diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs index 506fd1c0db98..75f1574ef125 100644 --- a/ballista/rust/scheduler/src/state/mod.rs +++ b/ballista/rust/scheduler/src/state/mod.rs @@ -223,7 +223,7 @@ impl SchedulerState { .collect(); let executors = self.get_executors_metadata().await?; 'tasks: for (_key, value) in kvs.iter() { - let mut status: TaskStatus = decode_protobuf(&value)?; + let mut status: TaskStatus = decode_protobuf(value)?; if status.status.is_none() { let partition = status.partition_id.as_ref().unwrap(); let plan = self diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 9ac66e136dbd..34b8d3a27b19 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -350,7 +350,7 @@ async fn execute_query( if debug { println!("Logical plan:\n{:?}", plan); } - let plan = ctx.optimize(&plan)?; + let plan = ctx.optimize(plan)?; if debug { println!("Optimized logical plan:\n{:?}", plan); } @@ -921,9 +921,9 @@ mod tests { .iter() .map(|field| { Field::new( - Field::name(&field), + Field::name(field), DataType::Utf8, - Field::is_nullable(&field), + Field::is_nullable(field), ) }) .collect::>(), @@ -939,8 +939,8 @@ mod tests { .iter() .map(|field| { Field::new( - Field::name(&field), - Field::data_type(&field).to_owned(), + Field::name(field), + Field::data_type(field).to_owned(), true, ) }) @@ -990,10 +990,10 @@ mod tests { .map(|field| { Expr::Alias( Box::new(Cast { - expr: Box::new(trim(col(Field::name(&field)))), - data_type: Field::data_type(&field).to_owned(), + expr: Box::new(trim(col(Field::name(field)))), + data_type: Field::data_type(field).to_owned(), }), - Field::name(&field).to_string(), + Field::name(field).to_string(), ) }) .collect::>(), diff --git a/datafusion/benches/aggregate_query_sql.rs b/datafusion/benches/aggregate_query_sql.rs index 8f1a97e198d3..74798ae572cd 100644 --- a/datafusion/benches/aggregate_query_sql.rs +++ b/datafusion/benches/aggregate_query_sql.rs @@ -47,7 +47,7 @@ fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query - let df = ctx.lock().unwrap().sql(&sql).unwrap(); + let df = ctx.lock().unwrap().sql(sql).unwrap(); rt.block_on(df.collect()).unwrap(); } diff --git a/datafusion/benches/filter_query_sql.rs b/datafusion/benches/filter_query_sql.rs index 8600bdc88c6a..253ef455f5af 100644 --- a/datafusion/benches/filter_query_sql.rs +++ b/datafusion/benches/filter_query_sql.rs @@ -28,7 +28,7 @@ use std::sync::Arc; async fn query(ctx: &mut ExecutionContext, sql: &str) { // execute the query - let df = ctx.sql(&sql).unwrap(); + let df = ctx.sql(sql).unwrap(); let results = df.collect().await.unwrap(); // display the relation diff --git a/datafusion/benches/math_query_sql.rs b/datafusion/benches/math_query_sql.rs index 1aaa2d3403cf..51e52e8acddb 100644 --- a/datafusion/benches/math_query_sql.rs +++ b/datafusion/benches/math_query_sql.rs @@ -40,7 +40,7 @@ fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query - let df = 
ctx.lock().unwrap().sql(&sql).unwrap(); + let df = ctx.lock().unwrap().sql(sql).unwrap(); rt.block_on(df.collect()).unwrap(); } diff --git a/datafusion/benches/sort_limit_query_sql.rs b/datafusion/benches/sort_limit_query_sql.rs index 1e8339ea31eb..5a875d3d8799 100644 --- a/datafusion/benches/sort_limit_query_sql.rs +++ b/datafusion/benches/sort_limit_query_sql.rs @@ -35,7 +35,7 @@ fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query - let df = ctx.lock().unwrap().sql(&sql).unwrap(); + let df = ctx.lock().unwrap().sql(sql).unwrap(); rt.block_on(df.collect()).unwrap(); } diff --git a/datafusion/src/datasource/csv.rs b/datafusion/src/datasource/csv.rs index e1a61595f2ee..906a1ce415f6 100644 --- a/datafusion/src/datasource/csv.rs +++ b/datafusion/src/datasource/csv.rs @@ -204,7 +204,7 @@ impl TableProvider for CsvFile { } } Source::Path(p) => { - CsvExec::try_new(&p, opts, projection.clone(), batch_size, limit)? + CsvExec::try_new(p, opts, projection.clone(), batch_size, limit)? } }; Ok(Arc::new(exec)) diff --git a/datafusion/src/datasource/json.rs b/datafusion/src/datasource/json.rs index f916f6c1e382..90fedfd6f528 100644 --- a/datafusion/src/datasource/json.rs +++ b/datafusion/src/datasource/json.rs @@ -149,7 +149,7 @@ impl TableProvider for NdJsonFile { } } Source::Path(p) => { - NdJsonExec::try_new(&p, opts, projection.clone(), batch_size, limit)? + NdJsonExec::try_new(p, opts, projection.clone(), batch_size, limit)? } }; Ok(Arc::new(exec)) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 950ba2b88691..f09d7f4f90c9 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -275,7 +275,7 @@ impl ExecutionContext { ) -> Result> { Ok(Arc::new(DataFrameImpl::new( self.state.clone(), - &LogicalPlanBuilder::scan_csv(&filename, options, None)?.build()?, + &LogicalPlanBuilder::scan_csv(filename, options, None)?.build()?, ))) } @@ -284,7 +284,7 @@ impl ExecutionContext { Ok(Arc::new(DataFrameImpl::new( self.state.clone(), &LogicalPlanBuilder::scan_parquet( - &filename, + filename, None, self.state.lock().unwrap().config.concurrency, )? @@ -328,7 +328,7 @@ impl ExecutionContext { /// executed against this context. pub fn register_parquet(&mut self, name: &str, filename: &str) -> Result<()> { let table = ParquetTable::try_new( - &filename, + filename, self.state.lock().unwrap().config.concurrency, )?; self.register_table(name, Arc::new(table))?; @@ -3205,7 +3205,7 @@ mod tests { .expect("Executing CREATE EXTERNAL TABLE"); let sql = "SELECT * from csv_with_timestamps"; - let result = plan_and_collect(&mut ctx, &sql).await.unwrap(); + let result = plan_and_collect(&mut ctx, sql).await.unwrap(); let expected = vec![ "+--------+-------------------------+", "| name | ts |", diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 19f71eb79268..a674e3cdb0f1 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -373,7 +373,7 @@ mod tests { ctx.register_csv( "aggregate_test_100", &format!("{}/csv/aggregate_test_100.csv", testdata), - CsvReadOptions::new().schema(&schema.as_ref()), + CsvReadOptions::new().schema(schema.as_ref()), )?; Ok(()) } diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index e4501a78ada4..64cc0a1349a2 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -14,7 +14,7 @@ // KIND, either express or implied. 
See the License for the // specific language governing permissions and limitations // under the License. -#![warn(missing_docs)] +#![warn(missing_docs, clippy::needless_borrow)] // Clippy lints, some should be disabled incrementally #![allow( clippy::float_cmp, diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index 9adb22b43d07..5a9167e58b05 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -325,12 +325,12 @@ impl DFField { /// Returns an immutable reference to the `DFField`'s unqualified name pub fn name(&self) -> &String { - &self.field.name() + self.field.name() } /// Returns an immutable reference to the `DFField`'s data-type pub fn data_type(&self) -> &DataType { - &self.field.data_type() + self.field.data_type() } /// Indicates whether this `DFField` supports null values diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 3344dce1d81d..a80bc54b4a2f 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -221,23 +221,23 @@ impl LogicalPlan { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { match self { - LogicalPlan::EmptyRelation { schema, .. } => &schema, + LogicalPlan::EmptyRelation { schema, .. } => schema, LogicalPlan::TableScan { projected_schema, .. - } => &projected_schema, - LogicalPlan::Projection { schema, .. } => &schema, + } => projected_schema, + LogicalPlan::Projection { schema, .. } => schema, LogicalPlan::Filter { input, .. } => input.schema(), - LogicalPlan::Window { schema, .. } => &schema, - LogicalPlan::Aggregate { schema, .. } => &schema, + LogicalPlan::Window { schema, .. } => schema, + LogicalPlan::Aggregate { schema, .. } => schema, LogicalPlan::Sort { input, .. } => input.schema(), - LogicalPlan::Join { schema, .. } => &schema, - LogicalPlan::CrossJoin { schema, .. } => &schema, + LogicalPlan::Join { schema, .. } => schema, + LogicalPlan::CrossJoin { schema, .. } => schema, LogicalPlan::Repartition { input, .. } => input.schema(), LogicalPlan::Limit { input, .. } => input.schema(), - LogicalPlan::CreateExternalTable { schema, .. } => &schema, - LogicalPlan::Explain { schema, .. } => &schema, - LogicalPlan::Extension { node } => &node.schema(), - LogicalPlan::Union { schema, .. } => &schema, + LogicalPlan::CreateExternalTable { schema, .. } => schema, + LogicalPlan::Explain { schema, .. } => schema, + LogicalPlan::Extension { node } => node.schema(), + LogicalPlan::Union { schema, .. } => schema, } } @@ -246,12 +246,12 @@ impl LogicalPlan { match self { LogicalPlan::TableScan { projected_schema, .. - } => vec![&projected_schema], + } => vec![projected_schema], LogicalPlan::Window { input, schema, .. } | LogicalPlan::Aggregate { input, schema, .. } | LogicalPlan::Projection { input, schema, .. } => { let mut schemas = input.all_schemas(); - schemas.insert(0, &schema); + schemas.insert(0, schema); schemas } LogicalPlan::Join { @@ -267,16 +267,16 @@ impl LogicalPlan { } => { let mut schemas = left.all_schemas(); schemas.extend(right.all_schemas()); - schemas.insert(0, &schema); + schemas.insert(0, schema); schemas } LogicalPlan::Union { schema, .. } => { vec![schema] } - LogicalPlan::Extension { node } => vec![&node.schema()], + LogicalPlan::Extension { node } => vec![node.schema()], LogicalPlan::Explain { schema, .. } | LogicalPlan::EmptyRelation { schema, .. } - | LogicalPlan::CreateExternalTable { schema, .. 
} => vec![&schema], + | LogicalPlan::CreateExternalTable { schema, .. } => vec![schema], LogicalPlan::Limit { input, .. } | LogicalPlan::Repartition { input, .. } | LogicalPlan::Sort { input, .. } diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index 4b1ae76927b4..85d1f812f41a 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -137,7 +137,7 @@ fn get_join_predicates<'a>( let all_in_right = right.len() == columns.len(); !all_in_left && !all_in_right }) - .map(|((ref a, ref b), _)| (a, b)) + .map(|((a, b), _)| (a, b)) .unzip(); (pushable_to_left, pushable_to_right, keep) } @@ -151,7 +151,7 @@ fn push_down(state: &State, plan: &LogicalPlan) -> Result { .collect::>>()?; let expr = plan.expressions(); - utils::from_plan(&plan, &expr, &new_inputs) + utils::from_plan(plan, &expr, &new_inputs) } /// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with @@ -225,8 +225,8 @@ fn split_members<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) { op: Operator::And, left, } => { - split_members(&left, predicates); - split_members(&right, predicates); + split_members(left, predicates); + split_members(right, predicates); } other => predicates.push(other), } @@ -297,7 +297,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { // optimize inner let new_input = optimize(input, state)?; - utils::from_plan(&plan, &expr, &[new_input]) + utils::from_plan(plan, expr, &[new_input]) } LogicalPlan::Aggregate { input, aggr_expr, .. @@ -335,7 +335,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { LogicalPlan::Join { left, right, .. } | LogicalPlan::CrossJoin { left, right, .. } => { let (pushable_to_left, pushable_to_right, keep) = - get_join_predicates(&state, &left.schema(), &right.schema()); + get_join_predicates(&state, left.schema(), right.schema()); let mut left_state = state.clone(); left_state.filters = keep_filters(&left_state.filters, &pushable_to_left); @@ -347,7 +347,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { // create a new Join with the new `left` and `right` let expr = plan.expressions(); - let plan = utils::from_plan(&plan, &expr, &[left, right])?; + let plan = utils::from_plan(plan, &expr, &[left, right])?; if keep.0.is_empty() { Ok(plan) @@ -437,11 +437,11 @@ impl FilterPushDown { /// replaces columns by its name on the projection. 
fn rewrite(expr: &Expr, projection: &HashMap) -> Result { - let expressions = utils::expr_sub_expressions(&expr)?; + let expressions = utils::expr_sub_expressions(expr)?; let expressions = expressions .iter() - .map(|e| rewrite(e, &projection)) + .map(|e| rewrite(e, projection)) .collect::>>()?; if let Expr::Column(name) = expr { @@ -450,7 +450,7 @@ fn rewrite(expr: &Expr, projection: &HashMap) -> Result { } } - utils::rewrite_expression(&expr, &expressions) + utils::rewrite_expression(expr, &expressions) } #[cfg(test)] diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index f0b364ab9852..ad795f5f5dd5 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -146,7 +146,7 @@ fn optimize_plan( let new_input = optimize_plan( optimizer, - &input, + input, &new_required_columns, true, execution_props, @@ -176,14 +176,14 @@ fn optimize_plan( Ok(LogicalPlan::Join { left: Arc::new(optimize_plan( optimizer, - &left, + left, &new_required_columns, true, execution_props, )?), right: Arc::new(optimize_plan( optimizer, - &right, + right, &new_required_columns, true, execution_props, @@ -204,7 +204,7 @@ fn optimize_plan( let mut new_window_expr = Vec::new(); { window_expr.iter().try_for_each(|expr| { - let name = &expr.name(&schema)?; + let name = &expr.name(schema)?; if required_columns.contains(name) { new_window_expr.push(expr.clone()); new_required_columns.insert(name.clone()); @@ -235,7 +235,7 @@ fn optimize_plan( window_expr: new_window_expr, input: Arc::new(optimize_plan( optimizer, - &input, + input, &new_required_columns, true, execution_props, @@ -259,7 +259,7 @@ fn optimize_plan( // Gather all columns needed for expressions in this Aggregate let mut new_aggr_expr = Vec::new(); aggr_expr.iter().try_for_each(|expr| { - let name = &expr.name(&schema)?; + let name = &expr.name(schema)?; if required_columns.contains(name) { new_aggr_expr.push(expr.clone()); @@ -286,7 +286,7 @@ fn optimize_plan( aggr_expr: new_aggr_expr, input: Arc::new(optimize_plan( optimizer, - &input, + input, &new_required_columns, true, execution_props, diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 0697d689c401..9ad7a94d8bfe 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -248,7 +248,7 @@ fn simplify(expr: &Expr) -> Expr { }) .unwrap_or_else(|| expr.clone()), Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr { - left: Box::new(simplify(&left)), + left: Box::new(simplify(left)), op: *op, right: Box::new(simplify(right)), }, @@ -267,7 +267,7 @@ fn optimize(plan: &LogicalPlan) -> Result { .into_iter() .map(|x| simplify(&x)) .collect::>(); - utils::from_plan(&plan, &expr, &new_inputs) + utils::from_plan(plan, &expr, &new_inputs) } impl OptimizerRule for SimplifyExpressions { diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs index c65733bd7526..da82d53871a8 100644 --- a/datafusion/src/physical_optimizer/pruning.rs +++ b/datafusion/src/physical_optimizer/pruning.rs @@ -420,7 +420,7 @@ impl<'a> PruningExpressionBuilder<'a> { fn min_column_expr(&mut self) -> Result { self.required_columns.min_column_expr( &self.column_name, - &self.column_expr, + self.column_expr, self.field, ) } @@ -428,7 +428,7 @@ impl<'a> PruningExpressionBuilder<'a> { fn max_column_expr(&mut self) -> Result { 
self.required_columns.max_column_expr( &self.column_name, - &self.column_expr, + self.column_expr, self.field, ) } @@ -440,7 +440,7 @@ fn rewrite_column_expr( column_old_name: &str, column_new_name: &str, ) -> Result { - let expressions = utils::expr_sub_expressions(&expr)?; + let expressions = utils::expr_sub_expressions(expr)?; let expressions = expressions .iter() .map(|e| rewrite_column_expr(e, column_old_name, column_new_name)) @@ -451,7 +451,7 @@ fn rewrite_column_expr( return Ok(Expr::Column(column_new_name.to_string())); } } - utils::rewrite_expression(&expr, &expressions) + utils::rewrite_expression(expr, &expressions) } /// Given a column reference to `column_name`, returns a pruning @@ -515,16 +515,15 @@ fn build_predicate_expression( let (left, op, right) = match expr { Expr::BinaryExpr { left, op, right } => (left, *op, right), Expr::Column(name) => { - let expr = build_single_column_expr(&name, schema, required_columns, false) + let expr = build_single_column_expr(name, schema, required_columns, false) .unwrap_or(unhandled); return Ok(expr); } // match !col (don't do so recursively) Expr::Not(input) => { if let Expr::Column(name) = input.as_ref() { - let expr = - build_single_column_expr(&name, schema, required_columns, true) - .unwrap_or(unhandled); + let expr = build_single_column_expr(name, schema, required_columns, true) + .unwrap_or(unhandled); return Ok(expr); } else { return Ok(unhandled); diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 60025a316228..897c78fd46ff 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -127,7 +127,7 @@ pub fn create_aggregate_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; - let return_type = return_type(&fun, &arg_types)?; + let return_type = return_type(fun, &arg_types)?; Ok(match (fun, distinct) { (AggregateFunction::Count, false) => { diff --git a/datafusion/src/physical_plan/expressions/case.rs b/datafusion/src/physical_plan/expressions/case.rs index 95ae5325af11..f89ea8d1e296 100644 --- a/datafusion/src/physical_plan/expressions/case.rs +++ b/datafusion/src/physical_plan/expressions/case.rs @@ -377,7 +377,7 @@ impl CaseExpr { let then_value = then_value.into_array(batch.num_rows()); current_value = Some(if_then_else( - &when_value, + when_value, then_value, current_value.unwrap(), &return_type, diff --git a/datafusion/src/physical_plan/expressions/row_number.rs b/datafusion/src/physical_plan/expressions/row_number.rs index f399995461f7..eaf9b21cbc64 100644 --- a/datafusion/src/physical_plan/expressions/row_number.rs +++ b/datafusion/src/physical_plan/expressions/row_number.rs @@ -49,7 +49,7 @@ impl BuiltInWindowFunctionExpr for RowNumber { fn field(&self) -> Result { let nullable = false; let data_type = DataType::UInt64; - Ok(Field::new(&self.name(), data_type, nullable)) + Ok(Field::new(self.name(), data_type, nullable)) } fn expressions(&self) -> Vec> { diff --git a/datafusion/src/physical_plan/filter.rs b/datafusion/src/physical_plan/filter.rs index bc2b17aa4f47..0a8c825aba1a 100644 --- a/datafusion/src/physical_plan/filter.rs +++ b/datafusion/src/physical_plan/filter.rs @@ -151,7 +151,7 @@ fn batch_filter( predicate: &Arc, ) -> ArrowResult { predicate - .evaluate(&batch) + .evaluate(batch) .map(|v| v.into_array(batch.num_rows())) .map_err(DataFusionError::into_arrow_external_error) .and_then(|array| { diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 
eb312cabd7f0..49ca79a00496 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -344,7 +344,7 @@ pub fn return_type( // or the execution panics. // verify that this is a valid set of data types for this function - data_types(&arg_types, &signature(fun))?; + data_types(arg_types, &signature(fun))?; // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. @@ -624,7 +624,7 @@ pub fn create_physical_expr( &format!("{}", fun), fun_expr, args, - &return_type(&fun, &arg_types)?, + &return_type(fun, &arg_types)?, ))); } BuiltinScalarFunction::InitCap => |args| match args[0].data_type() { @@ -953,7 +953,7 @@ pub fn create_physical_expr( &format!("{}", fun), fun_expr, args, - &return_type(&fun, &arg_types)?, + &return_type(fun, &arg_types)?, ))) } diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index ffb51b2e8a1f..453d500e98bd 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -120,8 +120,8 @@ fn create_schema( for (expr, name) in group_expr { fields.push(Field::new( name, - expr.data_type(&input_schema)?, - expr.nullable(&input_schema)?, + expr.data_type(input_schema)?, + expr.nullable(input_schema)?, )) } @@ -413,7 +413,7 @@ fn group_aggregate_batch( let mut offset_so_far = 0; for key in batch_keys.iter() { let (_, _, indices) = accumulators.get_mut(key).unwrap(); - batch_indices.append_slice(&indices)?; + batch_indices.append_slice(indices)?; offset_so_far += indices.len(); offsets.push(offset_so_far); } @@ -779,7 +779,7 @@ fn evaluate( batch: &RecordBatch, ) -> Result> { expr.iter() - .map(|expr| expr.evaluate(&batch)) + .map(|expr| expr.evaluate(batch)) .map(|r| r.map(|v| v.into_array(batch.num_rows()))) .collect::>>() } diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index d12e249cbe34..1b0322b521a5 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -133,13 +133,13 @@ impl HashJoinExec { ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); - check_join_is_valid(&left_schema, &right_schema, &on)?; + check_join_is_valid(&left_schema, &right_schema, on)?; let schema = Arc::new(build_join_schema( &left_schema, &right_schema, on, - &join_type, + join_type, )); let on = on @@ -289,7 +289,7 @@ impl ExecutionPlan for HashJoinExec { hashes_buffer.resize(batch.num_rows(), 0); update_hash( &on_left, - &batch, + batch, &mut hashmap, offset, &self.random_state, @@ -342,7 +342,7 @@ impl ExecutionPlan for HashJoinExec { hashes_buffer.resize(batch.num_rows(), 0); update_hash( &on_left, - &batch, + batch, &mut hashmap, offset, &self.random_state, @@ -436,7 +436,7 @@ fn update_hash( .collect::>>()?; // calculate the hash values - let hash_values = create_hashes(&keys_values, &random_state, hashes_buffer)?; + let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; // insert hashes to key of the hashmap for (row, hash_value) in hash_values.iter().enumerate() { @@ -538,15 +538,9 @@ fn build_batch( column_indices: &[ColumnIndex], random_state: &RandomState, ) -> ArrowResult<(RecordBatch, UInt64Array)> { - let (left_indices, right_indices) = build_join_indexes( - &left_data, - &batch, - join_type, - on_left, - on_right, - random_state, - ) - .unwrap(); + let (left_indices, right_indices) = + build_join_indexes(left_data, batch, 
join_type, on_left, on_right, random_state) + .unwrap(); if matches!(join_type, JoinType::Semi | JoinType::Anti) { return Ok(( @@ -613,7 +607,7 @@ fn build_join_indexes( }) .collect::>>()?; let hashes_buffer = &mut vec![0; keys_values[0].len()]; - let hash_values = create_hashes(&keys_values, &random_state, hashes_buffer)?; + let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; let left = &left_data.0; match join_type { diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index adae9224a19a..31b3749dd354 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -155,7 +155,7 @@ impl DefaultPhysicalPlanner { .map(|e| { self.create_window_expr( e, - &logical_input_schema, + logical_input_schema, &physical_input_schema, ctx_state, ) @@ -189,7 +189,7 @@ impl DefaultPhysicalPlanner { &physical_input_schema, ctx_state, ), - e.name(&logical_input_schema), + e.name(logical_input_schema), )) }) .collect::>>()?; @@ -198,7 +198,7 @@ impl DefaultPhysicalPlanner { .map(|e| { self.create_aggregate_expr( e, - &logical_input_schema, + logical_input_schema, &physical_input_schema, ctx_state, ) @@ -266,12 +266,8 @@ impl DefaultPhysicalPlanner { .iter() .map(|e| { tuple_err(( - self.create_physical_expr( - e, - &input_exec.schema(), - &ctx_state, - ), - e.name(&input_schema), + self.create_physical_expr(e, &input_exec.schema(), ctx_state), + e.name(input_schema), )) }) .collect::>>()?; @@ -307,7 +303,7 @@ impl DefaultPhysicalPlanner { let runtime_expr = expr .iter() .map(|e| { - self.create_physical_expr(e, &input_schema, &ctx_state) + self.create_physical_expr(e, &input_schema, ctx_state) }) .collect::>>()?; Partitioning::Hash(runtime_expr, *n) @@ -378,7 +374,7 @@ impl DefaultPhysicalPlanner { right, Partitioning::Hash(right_expr, ctx_state.config.concurrency), )?), - &keys, + keys, &physical_join_type, PartitionMode::Partitioned, )?)) @@ -386,7 +382,7 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(HashJoinExec::try_new( left, right, - &keys, + keys, &physical_join_type, PartitionMode::CollectLeft, )?)) @@ -504,7 +500,7 @@ impl DefaultPhysicalPlanner { } Expr::Column(name) => { // check that name exists - input_schema.field_with_name(&name)?; + input_schema.field_with_name(name)?; Ok(Arc::new(Column::new(name))) } Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), @@ -762,12 +758,12 @@ impl DefaultPhysicalPlanner { nulls_first, } => self.create_physical_sort_expr( expr, - &physical_input_schema, + physical_input_schema, SortOptions { descending: !*asc, nulls_first: *nulls_first, }, - &ctx_state, + ctx_state, ), _ => Err(DataFusionError::Plan( "Sort only accepts sort expressions".to_string(), diff --git a/datafusion/src/physical_plan/projection.rs b/datafusion/src/physical_plan/projection.rs index c0d78ff7168b..d4c0459c211b 100644 --- a/datafusion/src/physical_plan/projection.rs +++ b/datafusion/src/physical_plan/projection.rs @@ -166,7 +166,7 @@ fn batch_project( ) -> ArrowResult { expressions .iter() - .map(|expr| expr.evaluate(&batch)) + .map(|expr| expr.evaluate(batch)) .map(|r| r.map(|v| v.into_array(batch.num_rows()))) .collect::>>() .map_or_else( diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index 37d98c7d118b..5d1f8d7760cf 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -479,7 +479,7 @@ mod tests { partitions, Partitioning::Hash( 
vec![Arc::new(crate::physical_plan::expressions::Column::new( - &"c0", + "c0", ))], 8, ), diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs b/datafusion/src/physical_plan/sort_preserving_merge.rs index 283294a43ec7..c39acc474d31 100644 --- a/datafusion/src/physical_plan/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sort_preserving_merge.rs @@ -376,7 +376,7 @@ impl SortPreservingMergeStream { match min_cursor { None => min_cursor = Some((idx, candidate)), - Some((_, ref min)) => { + Some((_, min)) => { if min.compare(candidate, &self.sort_options)? == Ordering::Greater { diff --git a/datafusion/src/physical_plan/string_expressions.rs b/datafusion/src/physical_plan/string_expressions.rs index 882fe30502fd..09e19c4dfa47 100644 --- a/datafusion/src/physical_plan/string_expressions.rs +++ b/datafusion/src/physical_plan/string_expressions.rs @@ -299,7 +299,7 @@ pub fn concat(args: &[ColumnarValue]) -> Result { ColumnarValue::Array(v) => { if v.is_valid(index) { let v = v.as_any().downcast_ref::().unwrap(); - owned_string.push_str(&v.value(index)); + owned_string.push_str(v.value(index)); } } _ => unreachable!(), @@ -353,10 +353,10 @@ pub fn concat_ws(args: &[ArrayRef]) -> Result { for arg_index in 1..args.len() { let arg = &args[arg_index]; if !arg.is_null(index) { - owned_string.push_str(&arg.value(index)); + owned_string.push_str(arg.value(index)); // if not last push separator if arg_index != args.len() - 1 { - owned_string.push_str(&sep); + owned_string.push_str(sep); } } } diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs index 06d3739b53b2..fe87ecda872c 100644 --- a/datafusion/src/physical_plan/type_coercion.rs +++ b/datafusion/src/physical_plan/type_coercion.rs @@ -60,7 +60,7 @@ pub fn coerce( expressions .iter() .enumerate() - .map(|(i, expr)| try_cast(expr.clone(), &schema, new_types[i].clone())) + .map(|(i, expr)| try_cast(expr.clone(), schema, new_types[i].clone())) .collect::>>() } @@ -85,7 +85,7 @@ pub fn data_types( } for valid_types in valid_types { - if let Some(types) = maybe_data_types(&valid_types, ¤t_types) { + if let Some(types) = maybe_data_types(&valid_types, current_types) { return Ok(types); } } @@ -157,7 +157,7 @@ fn maybe_data_types( new_type.push(current_type.clone()) } else { // attempt to coerce - if can_coerce_from(valid_type, ¤t_type) { + if can_coerce_from(valid_type, current_type) { new_type.push(valid_type.clone()) } else { // not possible diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index 565a9eef2857..f95dd446844d 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -145,7 +145,7 @@ impl WindowExpr for BuiltInWindowExpr { } fn name(&self) -> &str { - &self.window.name() + self.window.name() } fn field(&self) -> Result { @@ -191,7 +191,7 @@ impl WindowExpr for AggregateWindowExpr { } fn name(&self) -> &str { - &self.aggregate.name() + self.aggregate.name() } fn field(&self) -> Result { @@ -351,7 +351,7 @@ fn window_aggregate_batch( .map(|(window_acc, expr)| { let values = &expr .iter() - .map(|e| e.evaluate(&batch)) + .map(|e| e.evaluate(batch)) .map(|r| r.map(|v| v.into_array(batch.num_rows()))) .collect::>>()?; window_acc.scan_batch(batch.num_rows(), values) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 860d21714ec6..7e7462ef390e 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -86,8 +86,8 @@ impl<'a, S: 
ContextProvider> SqlToRel<'a, S> { /// Generate a logical plan from an DataFusion SQL statement pub fn statement_to_plan(&self, statement: &DFStatement) -> Result { match statement { - DFStatement::CreateExternalTable(s) => self.external_table_to_plan(&s), - DFStatement::Statement(s) => self.sql_statement_to_plan(&s), + DFStatement::CreateExternalTable(s) => self.external_table_to_plan(s), + DFStatement::Statement(s) => self.sql_statement_to_plan(s), } } @@ -98,9 +98,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { verbose, statement, analyze: _, - } => self.explain_statement_to_plan(*verbose, &statement), - Statement::Query(query) => self.query_to_plan(&query), - Statement::ShowVariable { variable } => self.show_variable_to_plan(&variable), + } => self.explain_statement_to_plan(*verbose, statement), + Statement::Query(query) => self.query_to_plan(query), + Statement::ShowVariable { variable } => self.show_variable_to_plan(variable), Statement::ShowColumns { extended, full, @@ -232,7 +232,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { FileType::NdJson => {} }; - let schema = self.build_schema(&columns)?; + let schema = self.build_schema(columns)?; Ok(LogicalPlan::CreateExternalTable { schema: schema.to_dfschema_ref()?, @@ -250,7 +250,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { verbose: bool, statement: &Statement, ) -> Result { - let plan = self.sql_statement_to_plan(&statement)?; + let plan = self.sql_statement_to_plan(statement)?; let stringified_plans = vec![StringifiedPlan::new( PlanType::LogicalPlan, @@ -370,7 +370,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { left: &LogicalPlan, right: &LogicalPlan, ) -> Result { - LogicalPlanBuilder::from(&left).cross_join(&right)?.build() + LogicalPlanBuilder::from(left).cross_join(right)?.build() } fn parse_join( @@ -383,7 +383,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match constraint { JoinConstraint::On(sql_expr) => { let mut keys: Vec<(String, String)> = vec![]; - let join_schema = left.schema().join(&right.schema())?; + let join_schema = left.schema().join(right.schema())?; // parse ON expression let expr = self.sql_to_rex(sql_expr, &join_schema)?; @@ -396,14 +396,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { keys.iter().map(|pair| pair.1.as_str()).collect(); // return the logical plan representing the join - LogicalPlanBuilder::from(&left) - .join(&right, join_type, &left_keys, &right_keys)? + LogicalPlanBuilder::from(left) + .join(right, join_type, &left_keys, &right_keys)? .build() } JoinConstraint::Using(idents) => { let keys: Vec<&str> = idents.iter().map(|x| x.value.as_str()).collect(); - LogicalPlanBuilder::from(&left) - .join(&right, join_type, &keys, &keys)? + LogicalPlanBuilder::from(left) + .join(right, join_type, &keys, &keys)? .build() } JoinConstraint::Natural => { @@ -472,7 +472,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // build join schema let mut fields = vec![]; for plan in &plans { - fields.extend_from_slice(&plan.schema().fields()); + fields.extend_from_slice(plan.schema().fields()); } let join_schema = DFSchema::new(fields)?; @@ -673,16 +673,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(projection .iter() - .map(|expr| self.sql_select_to_rex(&expr, &input_schema)) + .map(|expr| self.sql_select_to_rex(expr, input_schema)) .collect::>>()? 
.iter() - .flat_map(|expr| expand_wildcard(&expr, &input_schema)) + .flat_map(|expr| expand_wildcard(expr, input_schema)) .collect::>()) } /// Wrap a plan in a projection fn project(&self, input: &LogicalPlan, expr: Vec) -> Result { - self.validate_schema_satisfies_exprs(&input.schema(), &expr)?; + self.validate_schema_satisfies_exprs(input.schema(), &expr)?; LogicalPlanBuilder::from(input).project(expr)?.build() } @@ -733,7 +733,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .cloned() .collect::>(); - let plan = LogicalPlanBuilder::from(&input) + let plan = LogicalPlanBuilder::from(input) .aggregate(group_by_exprs, aggr_exprs)? .build()?; @@ -784,14 +784,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn limit(&self, input: &LogicalPlan, limit: &Option) -> Result { match *limit { Some(ref limit_expr) => { - let n = match self.sql_to_rex(&limit_expr, &input.schema())? { + let n = match self.sql_to_rex(limit_expr, input.schema())? { Expr::Literal(ScalarValue::Int64(Some(n))) => Ok(n as usize), _ => Err(DataFusionError::Plan( "Unexpected expression for LIMIT clause".to_string(), )), }?; - LogicalPlanBuilder::from(&input).limit(n)?.build() + LogicalPlanBuilder::from(input).limit(n)?.build() } _ => Ok(input.clone()), } @@ -812,7 +812,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|e| self.order_by_to_sort_expr(e)) .collect::>>()?; - LogicalPlanBuilder::from(&plan).sort(order_by_rex)?.build() + LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() } /// convert sql OrderByExpr to Expr::Sort @@ -836,7 +836,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .iter() .try_for_each(|col| match col { Expr::Column(name) => { - schema.field_with_unqualified_name(&name).map_err(|_| { + schema.field_with_unqualified_name(name).map_err(|_| { DataFusionError::Plan(format!( "Invalid identifier '{}' for schema {}", name, @@ -854,7 +854,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match sql { SelectItem::UnnamedExpr(expr) => self.sql_to_rex(expr, schema), SelectItem::ExprWithAlias { expr, alias } => Ok(Alias( - Box::new(self.sql_to_rex(&expr, schema)?), + Box::new(self.sql_to_rex(expr, schema)?), alias.value.clone(), )), SelectItem::Wildcard => Ok(Expr::Wildcard), @@ -977,7 +977,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ref expr, ref data_type, } => Ok(Expr::Cast { - expr: Box::new(self.sql_expr_to_logical_expr(&expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(expr)?), data_type: convert_data_type(data_type)?, }), @@ -985,7 +985,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ref expr, ref data_type, } => Ok(Expr::TryCast { - expr: Box::new(self.sql_expr_to_logical_expr(&expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(expr)?), data_type: convert_data_type(data_type)?, }), @@ -1040,10 +1040,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ref low, ref high, } => Ok(Expr::Between { - expr: Box::new(self.sql_expr_to_logical_expr(&expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(expr)?), negated: *negated, - low: Box::new(self.sql_expr_to_logical_expr(&low)?), - high: Box::new(self.sql_expr_to_logical_expr(&high)?), + low: Box::new(self.sql_expr_to_logical_expr(low)?), + high: Box::new(self.sql_expr_to_logical_expr(high)?), }), SQLExpr::InList { @@ -1057,7 +1057,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>>()?; Ok(Expr::InList { - expr: Box::new(self.sql_expr_to_logical_expr(&expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(expr)?), list: list_expr, negated: *negated, }) @@ -1091,9 +1091,9 @@ impl<'a, S: 
ContextProvider> SqlToRel<'a, S> { }?; Ok(Expr::BinaryExpr { - left: Box::new(self.sql_expr_to_logical_expr(&left)?), + left: Box::new(self.sql_expr_to_logical_expr(left)?), op: operator, - right: Box::new(self.sql_expr_to_logical_expr(&right)?), + right: Box::new(self.sql_expr_to_logical_expr(right)?), }) } @@ -1209,7 +1209,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(&e), + SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(e), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported ast node {:?} in sqltorel", @@ -3167,7 +3167,7 @@ mod tests { fn logical_plan(sql: &str) -> Result { let planner = SqlToRel::new(&MockContextProvider {}); - let result = DFParser::parse_sql(&sql); + let result = DFParser::parse_sql(sql); let ast = result.unwrap(); planner.statement_to_plan(&ast[0]) } diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 5e9b9526ea83..82431c2314ab 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -146,7 +146,7 @@ where pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result { match expr { Expr::Column(_) => Ok(expr.clone()), - _ => Ok(Expr::Column(expr.name(&plan.schema())?)), + _ => Ok(Expr::Column(expr.name(plan.schema())?)), } } @@ -448,7 +448,7 @@ fn generate_sort_key(partition_by: &[Expr], order_by: &[Expr]) -> WindowSortKey } }); order_by.iter().for_each(|e| { - if !sort_key.contains(&e) { + if !sort_key.contains(e) { sort_key.push(e.clone()); } }); diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 5ce1884049d8..d9d77648c742 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -130,7 +130,7 @@ async fn parquet_single_nan_schema() { ctx.register_parquet("single_nan", &format!("{}/single_nan.parquet", testdata)) .unwrap(); let sql = "SELECT mycol FROM single_nan"; - let plan = ctx.create_logical_plan(&sql).unwrap(); + let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).unwrap(); let results = collect(plan).await.unwrap(); @@ -165,7 +165,7 @@ async fn parquet_list_columns() { ])); let sql = "SELECT int64_list, utf8_list FROM list_columns"; - let plan = ctx.create_logical_plan(&sql).unwrap(); + let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).unwrap(); let results = collect(plan).await.unwrap(); @@ -647,7 +647,7 @@ async fn csv_query_error() -> Result<()> { let mut ctx = create_ctx()?; register_aggregate_csv(&mut ctx)?; let sql = "SELECT sin(c1) FROM aggregate_test_100"; - let plan = ctx.create_logical_plan(&sql); + let plan = ctx.create_logical_plan(sql); assert!(plan.is_err()); Ok(()) } @@ -748,7 +748,7 @@ async fn csv_query_avg_multi_batch() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT avg(c12) FROM aggregate_test_100"; - let plan = ctx.create_logical_plan(&sql).unwrap(); + let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).unwrap(); let results = collect(plan).await.unwrap(); @@ -1615,7 +1615,7 @@ async fn csv_explain_plans() { // Logical plan // Create plan let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(&sql).expect(&msg); + let plan = ctx.create_logical_plan(sql).expect(&msg); let logical_schema = plan.schema(); // println!("SQL: {}", 
sql); @@ -1812,7 +1812,7 @@ async fn csv_explain_verbose_plans() { // Logical plan // Create plan let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(&sql).expect(&msg); + let plan = ctx.create_logical_plan(sql).expect(&msg); let logical_schema = plan.schema(); // println!("SQL: {}", sql); @@ -2088,7 +2088,7 @@ fn register_alltypes_parquet(ctx: &mut ExecutionContext) { /// `result[row][column]` async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec> { let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(&sql).expect(&msg); + let plan = ctx.create_logical_plan(sql).expect(&msg); let logical_schema = plan.schema(); let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); @@ -2561,7 +2561,7 @@ async fn query_cte_incorrect() -> Result<()> { // self reference let sql = "WITH t AS (SELECT * FROM t) SELECT * from u"; - let plan = ctx.create_logical_plan(&sql); + let plan = ctx.create_logical_plan(sql); assert!(plan.is_err()); assert_eq!( format!("{}", plan.unwrap_err()), @@ -2570,7 +2570,7 @@ async fn query_cte_incorrect() -> Result<()> { // forward referencing let sql = "WITH t AS (SELECT * FROM u), u AS (SELECT 1) SELECT * from u"; - let plan = ctx.create_logical_plan(&sql); + let plan = ctx.create_logical_plan(sql); assert!(plan.is_err()); assert_eq!( format!("{}", plan.unwrap_err()), @@ -2579,7 +2579,7 @@ async fn query_cte_incorrect() -> Result<()> { // wrapping should hide u let sql = "WITH t AS (WITH u as (SELECT 1) SELECT 1) SELECT * from u"; - let plan = ctx.create_logical_plan(&sql); + let plan = ctx.create_logical_plan(sql); assert!(plan.is_err()); assert_eq!( format!("{}", plan.unwrap_err()), @@ -3326,7 +3326,7 @@ async fn test_cast_expressions_error() -> Result<()> { let mut ctx = create_ctx()?; register_aggregate_csv(&mut ctx)?; let sql = "SELECT CAST(c1 AS INT) FROM aggregate_test_100"; - let plan = ctx.create_logical_plan(&sql).unwrap(); + let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).unwrap(); let result = collect(plan).await; @@ -3355,7 +3355,7 @@ async fn test_physical_plan_display_indent() { GROUP BY c1 \ ORDER BY the_min DESC \ LIMIT 10"; - let plan = ctx.create_logical_plan(&sql).unwrap(); + let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let physical_plan = ctx.create_physical_plan(&plan).unwrap(); @@ -3403,7 +3403,7 @@ async fn test_physical_plan_display_indent_multi_children() { ON c1=c2\ "; - let plan = ctx.create_logical_plan(&sql).unwrap(); + let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let physical_plan = ctx.create_physical_plan(&plan).unwrap(); @@ -3443,7 +3443,7 @@ async fn test_aggregation_with_bad_arguments() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; - let logical_plan = ctx.create_logical_plan(&sql)?; + let logical_plan = ctx.create_logical_plan(sql)?; let physical_plan = ctx.create_physical_plan(&logical_plan); let err = physical_plan.unwrap_err(); assert_eq!(err.to_string(), "Error during planning: Invalid or wrong number of arguments passed to aggregate: 'COUNT(DISTINCT )'"); From 2568323dbd85e05f2bf3e6e484f7cc39983ff26c Mon Sep 17 00:00:00 2001 From: Javier Goday Date: Sun, 13 Jun 2021 13:15:24 +0200 Subject: [PATCH 21/25] #420: Support for not_eq predicate in pruning 
predicates (#544) --- datafusion/src/physical_optimizer/pruning.rs | 28 ++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs index da82d53871a8..a7e1fb00c230 100644 --- a/datafusion/src/physical_optimizer/pruning.rs +++ b/datafusion/src/physical_optimizer/pruning.rs @@ -552,6 +552,14 @@ fn build_predicate_expression( }; let corrected_op = expr_builder.correct_operator(op); let statistics_expr = match corrected_op { + Operator::NotEq => { + // column != literal => !(min = literal && max = literal) => min != literal || literal != max + let min_column_expr = expr_builder.min_column_expr()?; + let max_column_expr = expr_builder.max_column_expr()?; + min_column_expr + .not_eq(expr_builder.scalar_expr().clone()) + .or(expr_builder.scalar_expr().clone().not_eq(max_column_expr)) + } Operator::Eq => { // column = literal => (min, max) = literal => min <= literal && literal <= max // (column / 2) = 4 => (column_min / 2) <= 4 && 4 <= (column_max / 2) @@ -929,6 +937,26 @@ mod tests { Ok(()) } + #[test] + fn row_group_predicate_not_eq() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let expected_expr = "#c1_min NotEq Int32(1) Or Int32(1) NotEq #c1_max"; + + // test column on the left + let expr = col("c1").not_eq(lit(1)); + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", predicate_expr), expected_expr); + + // test column on the right + let expr = lit(1).not_eq(col("c1")); + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", predicate_expr), expected_expr); + + Ok(()) + } + #[test] fn row_group_predicate_gt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); From d3828541a61b5681b93590a47e22d63715949136 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 13 Jun 2021 07:34:07 -0400 Subject: [PATCH 22/25] Cleanup Repartition Exec code (#538) * Cleanup RepartitionExec code * cleanup metric handling * Add elapsed_nanos --- datafusion/src/physical_plan/mod.rs | 5 + datafusion/src/physical_plan/repartition.rs | 279 +++++++++++--------- 2 files changed, 157 insertions(+), 127 deletions(-) diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index ebc6fd6ce94a..2dcba802560a 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -107,6 +107,11 @@ impl SQLMetric { self.value.fetch_add(n, Ordering::Relaxed); } + /// Add elapsed nanoseconds since `start` to self + pub fn add_elapsed(&self, start: std::time::Instant) { + self.add(start.elapsed().as_nanos() as usize) + } + /// Get the current value pub fn value(&self) -> usize { self.value.load(Ordering::Relaxed) } diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index 5d1f8d7760cf..7ef194849074 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -38,7 +38,7 @@ use futures::stream::Stream; use futures::StreamExt; use hashbrown::HashMap; use tokio::sync::{ - mpsc::{UnboundedReceiver, UnboundedSender}, + mpsc::{self, UnboundedReceiver, UnboundedSender}, Mutex, }; use tokio::task::JoinHandle; @@ -60,12 +60,40 @@ pub struct RepartitionExec { HashMap, UnboundedReceiver)>, >, >, + + /// Execution metrics + metrics: RepartitionMetrics, +} + +#[derive(Debug, Clone)] struct
RepartitionMetrics { /// Time in nanos to execute child operator and fetch batches - fetch_time_nanos: Arc, + fetch_nanos: Arc, /// Time in nanos to perform repartitioning - repart_time_nanos: Arc, + repart_nanos: Arc, /// Time in nanos for sending resulting batches to channels - send_time_nanos: Arc, + send_nanos: Arc, +} + +impl RepartitionMetrics { + fn new() -> Self { + Self { + fetch_nanos: SQLMetric::time_nanos(), + repart_nanos: SQLMetric::time_nanos(), + send_nanos: SQLMetric::time_nanos(), + } + } + /// Convert into the external metrics form + fn to_hashmap(&self) -> HashMap { + let mut metrics = HashMap::new(); + metrics.insert("fetchTime".to_owned(), self.fetch_nanos.as_ref().clone()); + metrics.insert( + "repartitionTime".to_owned(), + self.repart_nanos.as_ref().clone(), + ); + metrics.insert("sendTime".to_owned(), self.send_nanos.as_ref().clone()); + metrics + } } impl RepartitionExec { @@ -132,9 +160,8 @@ impl ExecutionPlan for RepartitionExec { // being read yet. This may cause high memory usage if the next operator is // reading output partitions in order rather than concurrently. One workaround // for this would be to add spill-to-disk capabilities. - let (sender, receiver) = tokio::sync::mpsc::unbounded_channel::< - Option>, - >(); + let (sender, receiver) = + mpsc::unbounded_channel::>>(); channels.insert(partition, (sender, receiver)); } // Use fixed random state @@ -142,122 +169,24 @@ impl ExecutionPlan for RepartitionExec { // launch one async task per *input* partition for i in 0..num_input_partitions { - let random_state = random.clone(); - let input = self.input.clone(); - let fetch_time = self.fetch_time_nanos.clone(); - let repart_time = self.repart_time_nanos.clone(); - let send_time = self.send_time_nanos.clone(); let txs: HashMap<_, _> = channels .iter() .map(|(partition, (tx, _rx))| (*partition, tx.clone())) .collect(); - let partitioning = self.partitioning.clone(); - let mut txs_captured = txs.clone(); - let input_task: JoinHandle> = tokio::spawn(async move { - // execute the child operator - let now = Instant::now(); - let mut stream = input.execute(i).await?; - fetch_time.add(now.elapsed().as_nanos() as usize); - let mut counter = 0; - let hashes_buf = &mut vec![]; - - loop { - // fetch the next batch - let now = Instant::now(); - let result = stream.next().await; - fetch_time.add(now.elapsed().as_nanos() as usize); - - if result.is_none() { - break; - } - let result: ArrowResult = result.unwrap(); - - match &partitioning { - Partitioning::RoundRobinBatch(_) => { - let now = Instant::now(); - let output_partition = counter % num_output_partitions; - let tx = txs_captured.get_mut(&output_partition).unwrap(); - tx.send(Some(result)).map_err(|e| { - DataFusionError::Execution(e.to_string()) - })?; - send_time.add(now.elapsed().as_nanos() as usize); - } - Partitioning::Hash(exprs, _) => { - let now = Instant::now(); - let input_batch = result?; - let arrays = exprs - .iter() - .map(|expr| { - Ok(expr - .evaluate(&input_batch)? 
- .into_array(input_batch.num_rows())) - }) - .collect::>>()?; - hashes_buf.clear(); - hashes_buf.resize(arrays[0].len(), 0); - // Hash arrays and compute buckets based on number of partitions - let hashes = - create_hashes(&arrays, &random_state, hashes_buf)?; - let mut indices = vec![vec![]; num_output_partitions]; - for (index, hash) in hashes.iter().enumerate() { - indices - [(*hash % num_output_partitions as u64) as usize] - .push(index as u64) - } - repart_time.add(now.elapsed().as_nanos() as usize); - for (num_output_partition, partition_indices) in - indices.into_iter().enumerate() - { - let now = Instant::now(); - let indices = partition_indices.into(); - // Produce batches based on indices - let columns = input_batch - .columns() - .iter() - .map(|c| { - take(c.as_ref(), &indices, None).map_err( - |e| { - DataFusionError::Execution( - e.to_string(), - ) - }, - ) - }) - .collect::>>>()?; - let output_batch = RecordBatch::try_new( - input_batch.schema(), - columns, - ); - repart_time.add(now.elapsed().as_nanos() as usize); - let now = Instant::now(); - let tx = txs_captured - .get_mut(&num_output_partition) - .unwrap(); - tx.send(Some(output_batch)).map_err(|e| { - DataFusionError::Execution(e.to_string()) - })?; - send_time.add(now.elapsed().as_nanos() as usize); - } - } - other => { - // this should be unreachable as long as the validation logic - // in the constructor is kept up-to-date - return Err(DataFusionError::NotImplemented(format!( - "Unsupported repartitioning scheme {:?}", - other - ))); - } - } - counter += 1; - } - - Ok(()) - }); + let input_task: JoinHandle> = + tokio::spawn(Self::pull_from_input( + random.clone(), + self.input.clone(), + i, + txs.clone(), + self.partitioning.clone(), + self.metrics.clone(), + )); // In a separate task, wait for each input to be done - // (and pass along any errors) - tokio::spawn(async move { Self::wait_for_task(input_task, txs).await }); + // (and pass along any errors, including panic!s) + tokio::spawn(Self::wait_for_task(input_task, txs)); } } @@ -272,14 +201,7 @@ impl ExecutionPlan for RepartitionExec { } fn metrics(&self) -> HashMap { - let mut metrics = HashMap::new(); - metrics.insert("fetchTime".to_owned(), (*self.fetch_time_nanos).clone()); - metrics.insert( - "repartitionTime".to_owned(), - (*self.repart_time_nanos).clone(), - ); - metrics.insert("sendTime".to_owned(), (*self.send_time_nanos).clone()); - metrics + self.metrics.to_hashmap() } fn fmt_as( @@ -305,12 +227,115 @@ impl RepartitionExec { input, partitioning, channels: Arc::new(Mutex::new(HashMap::new())), - fetch_time_nanos: SQLMetric::time_nanos(), - repart_time_nanos: SQLMetric::time_nanos(), - send_time_nanos: SQLMetric::time_nanos(), + metrics: RepartitionMetrics::new(), }) } + /// Pulls data from the specified input plan, feeding it to the + /// output partitions based on the desired partitioning + /// + /// i is the input partition index + /// + /// txs hold the output sending channels for each output partition + async fn pull_from_input( + random_state: ahash::RandomState, + input: Arc, + i: usize, + mut txs: HashMap>>>, + partitioning: Partitioning, + metrics: RepartitionMetrics, + ) -> Result<()> { + let num_output_partitions = txs.len(); + + // execute the child operator + let now = Instant::now(); + let mut stream = input.execute(i).await?; + metrics.fetch_nanos.add_elapsed(now); + + let mut counter = 0; + let hashes_buf = &mut vec![]; + + loop { + // fetch the next batch + let now = Instant::now(); + let result = stream.next().await; + 
metrics.fetch_nanos.add_elapsed(now); + + if result.is_none() { + break; + } + let result: ArrowResult = result.unwrap(); + + match &partitioning { + Partitioning::RoundRobinBatch(_) => { + let now = Instant::now(); + let output_partition = counter % num_output_partitions; + let tx = txs.get_mut(&output_partition).unwrap(); + tx.send(Some(result)) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + metrics.send_nanos.add_elapsed(now); + } + Partitioning::Hash(exprs, _) => { + let now = Instant::now(); + let input_batch = result?; + let arrays = exprs + .iter() + .map(|expr| { + Ok(expr + .evaluate(&input_batch)? + .into_array(input_batch.num_rows())) + }) + .collect::>>()?; + hashes_buf.clear(); + hashes_buf.resize(arrays[0].len(), 0); + // Hash arrays and compute buckets based on number of partitions + let hashes = create_hashes(&arrays, &random_state, hashes_buf)?; + let mut indices = vec![vec![]; num_output_partitions]; + for (index, hash) in hashes.iter().enumerate() { + indices[(*hash % num_output_partitions as u64) as usize] + .push(index as u64) + } + metrics.repart_nanos.add_elapsed(now); + for (num_output_partition, partition_indices) in + indices.into_iter().enumerate() + { + let now = Instant::now(); + let indices = partition_indices.into(); + // Produce batches based on indices + let columns = input_batch + .columns() + .iter() + .map(|c| { + take(c.as_ref(), &indices, None).map_err(|e| { + DataFusionError::Execution(e.to_string()) + }) + }) + .collect::>>>()?; + let output_batch = + RecordBatch::try_new(input_batch.schema(), columns); + metrics.repart_nanos.add_elapsed(now); + let now = Instant::now(); + let tx = txs.get_mut(&num_output_partition).unwrap(); + tx.send(Some(output_batch)) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + metrics.send_nanos.add_elapsed(now); + } + } + other => { + // this should be unreachable as long as the validation logic + // in the constructor is kept up-to-date + return Err(DataFusionError::NotImplemented(format!( + "Unsupported repartitioning scheme {:?}", + other + ))); + } + } + counter += 1; + } + + Ok(()) + } + /// Waits for `input_task` which is consuming one of the inputs to /// complete. Upon each successful completion, sends a `None` to /// each of the output tx channels to signal one of the inputs is From 91af8203faa80e959e30a6350b1486c9ddc25247 Mon Sep 17 00:00:00 2001 From: QP Hou Date: Mon, 14 Jun 2021 00:02:39 -0700 Subject: [PATCH 23/25] support table alias in join clause (#547) * support table alias in join clause * Update datafusion/src/sql/planner.rs Co-authored-by: Andrew Lamb Co-authored-by: Andrew Lamb --- datafusion/src/sql/planner.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 7e7462ef390e..e860bd74641d 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -424,7 +424,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ctes: &mut HashMap, ) -> Result { match relation { - TableFactor::Table { name, .. } => { + TableFactor::Table { name, alias, .. 
} => { let table_name = name.to_string(); let cte = ctes.get(&table_name); match ( @@ -432,10 +432,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.schema_provider.get_table_provider(name.try_into()?), ) { (Some(cte_plan), _) => Ok(cte_plan.clone()), - (_, Some(provider)) => { - LogicalPlanBuilder::scan(&table_name, provider, None)?.build() - } - (_, None) => Err(DataFusionError::Plan(format!( + (_, Some(provider)) => LogicalPlanBuilder::scan( + // take alias into account to support `JOIN table1 as table2` + alias + .as_ref() + .map(|a| a.name.value.as_str()) + .unwrap_or(&table_name), + provider, + None, + )? + .build(), + (None, None) => Err(DataFusionError::Plan(format!( "Table or CTE with name '{}' not found", name ))), From fe810e980834db2582b530188823e308ed9f097c Mon Sep 17 00:00:00 2001 From: QP Hou Date: Mon, 14 Jun 2021 06:35:09 -0700 Subject: [PATCH 24/25] reuse code for now function expr creation (#548) --- datafusion/src/physical_plan/functions.rs | 332 +++++++++++----------- 1 file changed, 168 insertions(+), 164 deletions(-) diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 49ca79a00496..1e423c367cd8 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -512,39 +512,35 @@ macro_rules! invoke_if_unicode_expressions_feature_flag { }; } -/// Create a physical (function) expression. -/// This function errors when `args`' can't be coerced to a valid argument type of the function. -pub fn create_physical_expr( +/// Create a physical scalar function. +pub fn create_physical_fun( fun: &BuiltinScalarFunction, - args: &[Arc], - input_schema: &Schema, ctx_state: &ExecutionContextState, -) -> Result> { - let fun_expr: ScalarFunctionImplementation = Arc::new(match fun { +) -> Result { + Ok(match fun { // math functions - BuiltinScalarFunction::Abs => math_expressions::abs, - BuiltinScalarFunction::Acos => math_expressions::acos, - BuiltinScalarFunction::Asin => math_expressions::asin, - BuiltinScalarFunction::Atan => math_expressions::atan, - BuiltinScalarFunction::Ceil => math_expressions::ceil, - BuiltinScalarFunction::Cos => math_expressions::cos, - BuiltinScalarFunction::Exp => math_expressions::exp, - BuiltinScalarFunction::Floor => math_expressions::floor, - BuiltinScalarFunction::Log => math_expressions::log10, - BuiltinScalarFunction::Ln => math_expressions::ln, - BuiltinScalarFunction::Log10 => math_expressions::log10, - BuiltinScalarFunction::Log2 => math_expressions::log2, - BuiltinScalarFunction::Random => math_expressions::random, - BuiltinScalarFunction::Round => math_expressions::round, - BuiltinScalarFunction::Signum => math_expressions::signum, - BuiltinScalarFunction::Sin => math_expressions::sin, - BuiltinScalarFunction::Sqrt => math_expressions::sqrt, - BuiltinScalarFunction::Tan => math_expressions::tan, - BuiltinScalarFunction::Trunc => math_expressions::trunc, - + BuiltinScalarFunction::Abs => Arc::new(math_expressions::abs), + BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos), + BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin), + BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), + BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil), + BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos), + BuiltinScalarFunction::Exp => Arc::new(math_expressions::exp), + BuiltinScalarFunction::Floor => Arc::new(math_expressions::floor), + BuiltinScalarFunction::Log => Arc::new(math_expressions::log10), + 
BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln), + BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), + BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), + BuiltinScalarFunction::Random => Arc::new(math_expressions::random), + BuiltinScalarFunction::Round => Arc::new(math_expressions::round), + BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum), + BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin), + BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt), + BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan), + BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc), // string functions - BuiltinScalarFunction::Array => array_expressions::array, - BuiltinScalarFunction::Ascii => |args| match args[0].data_type() { + BuiltinScalarFunction::Array => Arc::new(array_expressions::array), + BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::ascii::)(args) } @@ -555,8 +551,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function ascii", other, ))), - }, - BuiltinScalarFunction::BitLength => |args| match &args[0] { + }), + BuiltinScalarFunction::BitLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( @@ -567,8 +563,8 @@ pub fn create_physical_expr( )), _ => unreachable!(), }, - }, - BuiltinScalarFunction::Btrim => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Btrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::btrim::)(args) } @@ -579,55 +575,47 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function btrim", other, ))), - }, - BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int32Type, - "character_length" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int64Type, - "character_length" - ); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function character_length", - other, - ))), - }, + }), + BuiltinScalarFunction::CharacterLength => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + character_length, + Int32Type, + "character_length" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + character_length, + Int64Type, + "character_length" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function character_length", + other, + ))), + }) + } BuiltinScalarFunction::Chr => { - |args| make_scalar_function(string_expressions::chr)(args) + Arc::new(|args| make_scalar_function(string_expressions::chr)(args)) } - BuiltinScalarFunction::Concat => string_expressions::concat, + BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), BuiltinScalarFunction::ConcatWithSeparator => { - |args| make_scalar_function(string_expressions::concat_ws)(args) + Arc::new(|args| 
make_scalar_function(string_expressions::concat_ws)(args)) } - BuiltinScalarFunction::DatePart => datetime_expressions::date_part, - BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::DatePart => Arc::new(datetime_expressions::date_part), + BuiltinScalarFunction::DateTrunc => Arc::new(datetime_expressions::date_trunc), BuiltinScalarFunction::Now => { // bind value for now at plan time - let fun_expr = Arc::new(datetime_expressions::make_now( + Arc::new(datetime_expressions::make_now( ctx_state.execution_props.query_execution_start_time, - )); - - // TODO refactor code to not return here, but instead fall through below - let args = vec![]; - let arg_types = vec![]; // has no args - return Ok(Arc::new(ScalarFunctionExpr::new( - &format!("{}", fun), - fun_expr, - args, - &return_type(fun, &arg_types)?, - ))); + )) } - BuiltinScalarFunction::InitCap => |args| match args[0].data_type() { + BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::initcap::)(args) } @@ -638,8 +626,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function initcap", other, ))), - }, - BuiltinScalarFunction::Left => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); make_scalar_function(func)(args) @@ -652,9 +640,9 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function left", other, ))), - }, - BuiltinScalarFunction::Lower => string_expressions::lower, - BuiltinScalarFunction::Lpad => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Lower => Arc::new(string_expressions::lower), + BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); make_scalar_function(func)(args) @@ -667,8 +655,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function lpad", other, ))), - }, - BuiltinScalarFunction::Ltrim => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Ltrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::ltrim::)(args) } @@ -679,12 +667,12 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function ltrim", other, ))), - }, + }), BuiltinScalarFunction::MD5 => { - invoke_if_crypto_expressions_feature_flag!(md5, "md5") + Arc::new(invoke_if_crypto_expressions_feature_flag!(md5, "md5")) } - BuiltinScalarFunction::NullIf => nullif_func, - BuiltinScalarFunction::OctetLength => |args| match &args[0] { + BuiltinScalarFunction::NullIf => Arc::new(nullif_func), + BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( @@ -695,52 +683,56 @@ pub fn create_physical_expr( )), _ => unreachable!(), }, - }, - BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_match, - i32, - "regexp_match" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_match, - i64, - "regexp_match" - ); - make_scalar_function(func)(args) 
- } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function regexp_match", - other - ))), - }, - BuiltinScalarFunction::RegexpReplace => |args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_replace, - i32, - "regexp_replace" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_replace, - i64, - "regexp_replace" - ); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function regexp_replace", - other, - ))), - }, - BuiltinScalarFunction::Repeat => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::RegexpMatch => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i32, + "regexp_match" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i64, + "regexp_match" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_match", + other + ))), + }) + } + BuiltinScalarFunction::RegexpReplace => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_replace, + i32, + "regexp_replace" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_replace, + i64, + "regexp_replace" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_replace", + other, + ))), + }) + } + BuiltinScalarFunction::Repeat => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::repeat::)(args) } @@ -751,8 +743,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function repeat", other, ))), - }, - BuiltinScalarFunction::Replace => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Replace => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::replace::)(args) } @@ -763,8 +755,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function replace", other, ))), - }, - BuiltinScalarFunction::Reverse => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Reverse => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(reverse, i32, "reverse"); @@ -779,8 +771,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function reverse", other, ))), - }, - BuiltinScalarFunction::Right => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Right => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(right, i32, "right"); @@ -795,8 +787,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function right", other, ))), - }, - BuiltinScalarFunction::Rpad => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Rpad => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad"); make_scalar_function(func)(args) @@ -809,8 +801,8 @@ pub fn 
create_physical_expr( "Unsupported data type {:?} for function rpad", other, ))), - }, - BuiltinScalarFunction::Rtrim => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Rtrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::rtrim::)(args) } @@ -821,20 +813,20 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function rtrim", other, ))), - }, + }), BuiltinScalarFunction::SHA224 => { - invoke_if_crypto_expressions_feature_flag!(sha224, "sha224") + Arc::new(invoke_if_crypto_expressions_feature_flag!(sha224, "sha224")) } BuiltinScalarFunction::SHA256 => { - invoke_if_crypto_expressions_feature_flag!(sha256, "sha256") + Arc::new(invoke_if_crypto_expressions_feature_flag!(sha256, "sha256")) } BuiltinScalarFunction::SHA384 => { - invoke_if_crypto_expressions_feature_flag!(sha384, "sha384") + Arc::new(invoke_if_crypto_expressions_feature_flag!(sha384, "sha384")) } BuiltinScalarFunction::SHA512 => { - invoke_if_crypto_expressions_feature_flag!(sha512, "sha512") + Arc::new(invoke_if_crypto_expressions_feature_flag!(sha512, "sha512")) } - BuiltinScalarFunction::SplitPart => |args| match args[0].data_type() { + BuiltinScalarFunction::SplitPart => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::split_part::)(args) } @@ -845,8 +837,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function split_part", other, ))), - }, - BuiltinScalarFunction::StartsWith => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::StartsWith => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::starts_with::)(args) } @@ -857,8 +849,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function starts_with", other, ))), - }, - BuiltinScalarFunction::Strpos => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!( strpos, Int32Type, "strpos" @@ -875,8 +867,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function strpos", other, ))), - }, - BuiltinScalarFunction::Substr => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Substr => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr"); @@ -891,8 +883,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function substr", other, ))), - }, - BuiltinScalarFunction::ToHex => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::ToHex => Arc::new(|args| match args[0].data_type() { DataType::Int32 => { make_scalar_function(string_expressions::to_hex::)(args) } @@ -903,9 +895,11 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function to_hex", other, ))), - }, - BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, - BuiltinScalarFunction::Translate => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::ToTimestamp => { + Arc::new(datetime_expressions::to_timestamp) + } + BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!( translate, @@ -926,8 +920,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function translate", other, ))), - }, - BuiltinScalarFunction::Trim => |args| match 
args[0].data_type() { + }), + BuiltinScalarFunction::Trim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::btrim::)(args) } @@ -938,10 +932,20 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function trim", other, ))), - }, - BuiltinScalarFunction::Upper => string_expressions::upper, - }); - // coerce + }), + BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), + }) +} + +/// Create a physical (function) expression. +/// This function errors when `args`' can't be coerced to a valid argument type of the function. +pub fn create_physical_expr( + fun: &BuiltinScalarFunction, + args: &[Arc], + input_schema: &Schema, + ctx_state: &ExecutionContextState, +) -> Result> { + let fun_expr = create_physical_fun(fun, ctx_state)?; let args = coerce(args, input_schema, &signature(fun))?; let arg_types = args From 9c23de0fb5ed3965a1b2357e2614cf81834b0319 Mon Sep 17 00:00:00 2001 From: Gang Liao Date: Sat, 5 Jun 2021 10:52:14 -0400 Subject: [PATCH 25/25] Support modulus op --- .../src/physical_plan/expressions/binary.rs | 40 ++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index 5c2d9ce02f51..5635ab8ae4b7 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -19,7 +19,7 @@ use std::{any::Any, sync::Arc}; use arrow::array::*; use arrow::compute::kernels::arithmetic::{ - add, divide, divide_scalar, multiply, subtract, + add, divide, divide_scalar, multiply, subtract, modulus, modulus_scalar }; use arrow::compute::kernels::boolean::{and_kleene, or_kleene}; use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; @@ -341,14 +341,9 @@ fn common_binary_type( } // for math expressions, the final value of the coercion is also the return type // because coercion favours higher information types - Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => { + Operator::Plus | Operator::Minus | Operator::Modulus | Operator::Divide | Operator::Multiply => { numerical_coercion(lhs_type, rhs_type) } - Operator::Modulus => { - return Err(DataFusionError::NotImplemented( - "Modulus operator is still not supported".to_string(), - )) - } }; // re-write the error message of failed coercions to include the operator's information @@ -389,12 +384,9 @@ pub fn binary_operator_data_type( | Operator::GtEq | Operator::LtEq => Ok(DataType::Boolean), // math operations return the same value as the common coerced type - Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => { + Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply | Operator::Modulus => { Ok(common_type) } - Operator::Modulus => Err(DataFusionError::NotImplemented( - "Modulus operator is still not supported".to_string(), - )), } } @@ -454,6 +446,9 @@ impl PhysicalExpr for BinaryExpr { Operator::Divide => { binary_primitive_array_op_scalar!(array, scalar.clone(), divide) } + Operator::Modulus => { + binary_primitive_array_op_scalar!(array, scalar.clone(), modulus) + } // if scalar operation is not supported - fallback to array implementation _ => None, } @@ -503,6 +498,7 @@ impl PhysicalExpr for BinaryExpr { Operator::Minus => binary_primitive_array_op!(left, right, subtract), Operator::Multiply => binary_primitive_array_op!(left, right, multiply), Operator::Divide => binary_primitive_array_op!(left, 
right, divide), + Operator::Modulus => binary_primitive_array_op!(left, right, modulus), Operator::And => { if left_data_type == DataType::Boolean { boolean_op!(left, right, and_kleene) @@ -525,9 +521,6 @@ impl PhysicalExpr for BinaryExpr { ))); } } - Operator::Modulus => Err(DataFusionError::NotImplemented( - "Modulus operator is still not supported".to_string(), - )), }; result.map(|a| ColumnarValue::Array(a)) } @@ -964,6 +957,25 @@ mod tests { Ok(()) } + #[test] + fn modulus_op() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048])); + let b = Arc::new(Int32Array::from(vec![2, 4, 7, 14, 32])); + + apply_arithmetic::( + schema, + vec![a, b], + Operator::Modulus, + Int32Array::from(vec![0, 0, 2, 8, 0]), + )?; + + Ok(()) + } + fn apply_arithmetic( schema: SchemaRef, data: Vec,