From 14d807a7fde85a02b58abeede58ba8d5b8b7a64a Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Sun, 26 Jun 2022 15:47:27 -0600 Subject: [PATCH 01/40] Failing tests --- datafusion/core/tests/sql/timestamp.rs | 80 ++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs index 1e475fb175bd..912b22fbd7c5 100644 --- a/datafusion/core/tests/sql/timestamp.rs +++ b/datafusion/core/tests/sql/timestamp.rs @@ -814,3 +814,83 @@ async fn group_by_timestamp_millis() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn interval_year() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-01-01' + interval '1' year as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1995-01-01 |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn add_interval_month() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-01-31' + interval '1' month as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1994-02-28 |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn sub_interval_month() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-03-31' - interval '1' month as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1994-02-28 |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn sub_month_wrap() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-01-15' - interval '1' month as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1993-12-15 |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} From 88f5d7fa8f07c012da3c2ad3c286599bc9477b4b Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Sun, 26 Jun 2022 15:58:09 -0600 Subject: [PATCH 02/40] Add month/year arithmetic --- datafusion/common/src/scalar.rs | 5 +- .../physical-expr/src/expressions/datetime.rs | 168 +++++++++++------- 2 files changed, 106 insertions(+), 67 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 553dc5eae570..ccd3b6ca37ff 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -37,6 +37,7 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; /// Represents a dynamically typed, nullable single value. /// This is the single-valued counter-part of arrow’s `Array`. +/// https://arrow.apache.org/docs/python/api/datatypes.html #[derive(Clone)] pub enum ScalarValue { /// represents `DataType::Null` (castable to/from any other type) @@ -75,9 +76,9 @@ pub enum ScalarValue { LargeBinary(Option>), /// list of nested ScalarValue List(Option>, Box), - /// Date stored as a signed 32bit int + /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), - /// Date stored as a signed 64bit int + /// Date stored as a signed 64bit int days since UNIX epoch 1970-01-01 Date64(Option), /// Timestamp Second TimestampSecond(Option, Option), diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index 3d84e79f2cc9..fa1abdc3be70 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -18,11 +18,14 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; +use chrono::{Datelike, NaiveDate}; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, Operator}; use std::any::Any; +use std::cmp::min; use std::fmt::{Display, Formatter}; +use std::ops::{Add, Sub}; use std::sync::Arc; /// Perform DATE +/ INTERVAL math @@ -86,76 +89,111 @@ impl PhysicalExpr for DateIntervalExpr { let dates = self.lhs.evaluate(batch)?; let intervals = self.rhs.evaluate(batch)?; - let interval = match intervals { - ColumnarValue::Scalar(interval) => match interval { - ScalarValue::IntervalDayTime(Some(interval)) => interval as i32, - ScalarValue::IntervalYearMonth(Some(_)) => { - return Err(DataFusionError::Execution( - "DateIntervalExpr does not support IntervalYearMonth".to_string(), - )) - } - ScalarValue::IntervalMonthDayNano(Some(_)) => { - return Err(DataFusionError::Execution( - "DateIntervalExpr does not support IntervalMonthDayNano" - .to_string(), - )) - } - other => { - return Err(DataFusionError::Execution(format!( - "DateIntervalExpr does not support non-interval type {:?}", - other - ))) - } - }, - _ => { - return Err(DataFusionError::Execution( - "Columnar execution is not yet supported for DateIntervalExpr" - .to_string(), - )) + // Unwrap days since epoch + let operand = match dates { + ColumnarValue::Scalar(scalar) => scalar, + _ => Err(DataFusionError::Execution( + "Columnar execution is not yet supported for DateIntervalExpr" + .to_string(), + ))?, + }; + + // Convert to NaiveDate + let epoch = NaiveDate::from_ymd(1970, 1, 1); + let prior = match operand { + ScalarValue::Date32(Some(date)) => { + epoch.add(chrono::Duration::days(date as i64)) } + ScalarValue::Date64(Some(date)) => epoch.add(chrono::Duration::days(date)), + _ => Err(DataFusionError::Execution(format!( + "Invalid lhs type for DateIntervalExpr: {:?}", + operand + )))?, }; - match dates { - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Date32(Some(date)) => match &self.op { - Operator::Plus => Ok(ColumnarValue::Scalar(ScalarValue::Date32( - Some(date + interval), - ))), - Operator::Minus => Ok(ColumnarValue::Scalar(ScalarValue::Date32( - Some(date - interval), - ))), - _ => { - // this should be unreachable because we check the operators in `try_new` - Err(DataFusionError::Execution( - "Invalid operator for DateIntervalExpr".to_string(), - )) - } - }, - ScalarValue::Date64(Some(date)) => match &self.op { - Operator::Plus => Ok(ColumnarValue::Scalar(ScalarValue::Date64( - Some(date + interval as i64), - ))), - Operator::Minus => Ok(ColumnarValue::Scalar(ScalarValue::Date64( - Some(date - interval as i64), - ))), - _ => { - // this should be unreachable because we check the operators in `try_new` - Err(DataFusionError::Execution( - "Invalid operator for DateIntervalExpr".to_string(), - )) - } - }, - _ => { - // this should be unreachable because we check the types in `try_new` - Err(DataFusionError::Execution( - "Invalid lhs type for DateIntervalExpr".to_string(), - )) - } - }, + // Unwrap interval to add + let scalar = match &intervals { + ColumnarValue::Scalar(interval) => interval, _ => Err(DataFusionError::Execution( "Columnar execution is not yet supported for DateIntervalExpr" .to_string(), - )), - } + ))?, + }; + + // Negate for subtraction + let interval = match &scalar { + ScalarValue::IntervalDayTime(Some(interval)) => *interval, + ScalarValue::IntervalYearMonth(Some(interval)) => *interval as i64, + ScalarValue::IntervalMonthDayNano(Some(_interval)) => { + Err(DataFusionError::Execution( + "DateIntervalExpr does not support IntervalMonthDayNano".to_string(), + ))? + } + other => Err(DataFusionError::Execution(format!( + "DateIntervalExpr does not support non-interval type {:?}", + other + )))?, + }; + let interval = match &self.op { + Operator::Plus => interval, + Operator::Minus => interval * -1, + _ => { + // this should be unreachable because we check the operators in `try_new` + Err(DataFusionError::Execution( + "Invalid operator for DateIntervalExpr".to_string(), + ))? + } + }; + + // Add interval + let posterior = match scalar { + ScalarValue::IntervalDayTime(Some(_)) => { + prior.add(chrono::Duration::days(interval)) + } + ScalarValue::IntervalYearMonth(Some(_)) => { + let target = add_months(prior, interval); + let target_plus = add_months(target, 1); + let last_day = target_plus.sub(chrono::Duration::days(1)); + let day = min(prior.day(), last_day.day()); + NaiveDate::from_ymd(target.year(), target.month(), day) + } + ScalarValue::IntervalMonthDayNano(Some(_)) => { + Err(DataFusionError::Execution( + "DateIntervalExpr does not support IntervalMonthDayNano".to_string(), + ))? + } + other => Err(DataFusionError::Execution(format!( + "DateIntervalExpr does not support non-interval type {:?}", + other + )))?, + }; + + // convert back + let posterior = posterior.sub(epoch).num_days(); + let res = match operand { + ScalarValue::Date32(Some(_)) => { + let casted = + i32::try_from(posterior).context("Date arithmetic out of bounds!")?; + ColumnarValue::Scalar(ScalarValue::Date32(Some(casted))) + } + ScalarValue::Date64(Some(_)) => { + ColumnarValue::Scalar(ScalarValue::Date64(Some(posterior))) + } + _ => Err(DataFusionError::Execution(format!( + "Invalid lhs type for DateIntervalExpr: {}", + scalar + )))?, + }; + Ok(res) } } + +fn add_months(dt: NaiveDate, delta: i64) -> NaiveDate { + let ay = dt.year(); + let am = dt.month() as i32 - 1; // zero-based for modulo operations + let bm = am + delta as i32; + let by = ay + if bm < 0 { bm / 12 - 1 } else { bm / 12 }; + let cm = bm % 12; + let dm = if cm < 0 { cm + 12 } else { cm }; + return NaiveDate::from_ymd(by, dm as u32 + 1, 1); +} From d2f43c9471b26548f226ddc36805c588456790e5 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Sun, 26 Jun 2022 16:56:04 -0600 Subject: [PATCH 03/40] Fix tests? --- datafusion/physical-expr/src/expressions/datetime.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index fa1abdc3be70..a76c1d963137 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -172,8 +172,11 @@ impl PhysicalExpr for DateIntervalExpr { let posterior = posterior.sub(epoch).num_days(); let res = match operand { ScalarValue::Date32(Some(_)) => { - let casted = - i32::try_from(posterior).context("Date arithmetic out of bounds!")?; + let casted = i32::try_from(posterior).map_err(|_| { + DataFusionError::Execution( + "Date arithmetic out of bounds!".to_string(), + ) + })?; ColumnarValue::Scalar(ScalarValue::Date32(Some(casted))) } ScalarValue::Date64(Some(_)) => { From e34705e8683c5f6dfdef182caf4a993e30e631d0 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Sun, 26 Jun 2022 17:29:02 -0600 Subject: [PATCH 04/40] Fix clippy? --- datafusion/physical-expr/src/expressions/datetime.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index a76c1d963137..f7eb1e5236e1 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -136,7 +136,7 @@ impl PhysicalExpr for DateIntervalExpr { }; let interval = match &self.op { Operator::Plus => interval, - Operator::Minus => interval * -1, + Operator::Minus => -interval, _ => { // this should be unreachable because we check the operators in `try_new` Err(DataFusionError::Execution( @@ -198,5 +198,5 @@ fn add_months(dt: NaiveDate, delta: i64) -> NaiveDate { let by = ay + if bm < 0 { bm / 12 - 1 } else { bm / 12 }; let cm = bm % 12; let dm = if cm < 0 { cm + 12 } else { cm }; - return NaiveDate::from_ymd(by, dm as u32 + 1, 1); + NaiveDate::from_ymd(by, dm as u32 + 1, 1) } From c37d29e579a5c96d6a26209ebc1c61bab921e100 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 13:02:40 -0600 Subject: [PATCH 05/40] Update datafusion/common/src/scalar.rs Co-authored-by: Andrew Lamb --- datafusion/common/src/scalar.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index ccd3b6ca37ff..cb6680b89c57 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -78,7 +78,7 @@ pub enum ScalarValue { List(Option>, Box), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), - /// Date stored as a signed 64bit int days since UNIX epoch 1970-01-01 + /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 Date64(Option), /// Timestamp Second TimestampSecond(Option, Option), From 874a5edd8e0ff5492166295864d927977a0ba079 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 09:45:05 -0600 Subject: [PATCH 06/40] Add support for all types, fix math --- datafusion/common/src/scalar.rs | 10 +- .../physical-expr/src/expressions/datetime.rs | 97 +++++++++---------- 2 files changed, 54 insertions(+), 53 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index cb6680b89c57..a579a7e74aec 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -38,6 +38,7 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; /// Represents a dynamically typed, nullable single value. /// This is the single-valued counter-part of arrow’s `Array`. /// https://arrow.apache.org/docs/python/api/datatypes.html +/// https://github.com/apache/arrow/blob/master/format/Schema.fbs#L354-L375 #[derive(Clone)] pub enum ScalarValue { /// represents `DataType::Null` (castable to/from any other type) @@ -88,11 +89,14 @@ pub enum ScalarValue { TimestampMicrosecond(Option, Option), /// Timestamp Nanoseconds TimestampNanosecond(Option, Option), - /// Interval with YearMonth unit + /// Number of elapsed whole months since epoch IntervalYearMonth(Option), - /// Interval with DayTime unit + /// Number of elapsed days and milliseconds since epoch (no leap seconds) + /// stored as 2 contiguous 32-bit signed integers IntervalDayTime(Option), - /// Interval with MonthDayNano unit + /// A triple of the number of elapsed months, days, and nanoseconds. + /// Months and days are encoded as 32-bit signed integers. + /// Nanoseconds is encoded as a 64-bit signed integer (no leap seconds). IntervalMonthDayNano(Option), /// struct of nested ScalarValue Struct(Option>, Box>), diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index f7eb1e5236e1..f735eede5233 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -18,7 +18,7 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use chrono::{Datelike, NaiveDate}; +use chrono::{Datelike, Duration, NaiveDate}; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, Operator}; @@ -77,15 +77,15 @@ impl PhysicalExpr for DateIntervalExpr { self } - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + fn data_type(&self, input_schema: &Schema) -> Result { self.lhs.data_type(input_schema) } - fn nullable(&self, input_schema: &Schema) -> datafusion_common::Result { + fn nullable(&self, input_schema: &Schema) -> Result { self.lhs.nullable(input_schema) } - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + fn evaluate(&self, batch: &RecordBatch) -> Result { let dates = self.lhs.evaluate(batch)?; let intervals = self.rhs.evaluate(batch)?; @@ -101,10 +101,8 @@ impl PhysicalExpr for DateIntervalExpr { // Convert to NaiveDate let epoch = NaiveDate::from_ymd(1970, 1, 1); let prior = match operand { - ScalarValue::Date32(Some(date)) => { - epoch.add(chrono::Duration::days(date as i64)) - } - ScalarValue::Date64(Some(date)) => epoch.add(chrono::Duration::days(date)), + ScalarValue::Date32(Some(d)) => epoch.add(Duration::days(d as i64)), + ScalarValue::Date64(Some(ms)) => epoch.add(Duration::milliseconds(ms)), _ => Err(DataFusionError::Execution(format!( "Invalid lhs type for DateIntervalExpr: {:?}", operand @@ -120,23 +118,10 @@ impl PhysicalExpr for DateIntervalExpr { ))?, }; - // Negate for subtraction - let interval = match &scalar { - ScalarValue::IntervalDayTime(Some(interval)) => *interval, - ScalarValue::IntervalYearMonth(Some(interval)) => *interval as i64, - ScalarValue::IntervalMonthDayNano(Some(_interval)) => { - Err(DataFusionError::Execution( - "DateIntervalExpr does not support IntervalMonthDayNano".to_string(), - ))? - } - other => Err(DataFusionError::Execution(format!( - "DateIntervalExpr does not support non-interval type {:?}", - other - )))?, - }; - let interval = match &self.op { - Operator::Plus => interval, - Operator::Minus => -interval, + // Invert sign for subtraction + let sign = match &self.op { + Operator::Plus => 1, + Operator::Minus => -1, _ => { // this should be unreachable because we check the operators in `try_new` Err(DataFusionError::Execution( @@ -145,23 +130,11 @@ impl PhysicalExpr for DateIntervalExpr { } }; - // Add interval + // Do math let posterior = match scalar { - ScalarValue::IntervalDayTime(Some(_)) => { - prior.add(chrono::Duration::days(interval)) - } - ScalarValue::IntervalYearMonth(Some(_)) => { - let target = add_months(prior, interval); - let target_plus = add_months(target, 1); - let last_day = target_plus.sub(chrono::Duration::days(1)); - let day = min(prior.day(), last_day.day()); - NaiveDate::from_ymd(target.year(), target.month(), day) - } - ScalarValue::IntervalMonthDayNano(Some(_)) => { - Err(DataFusionError::Execution( - "DateIntervalExpr does not support IntervalMonthDayNano".to_string(), - ))? - } + ScalarValue::IntervalDayTime(Some(i)) => add_day_time(prior, *i, sign), + ScalarValue::IntervalYearMonth(Some(i)) => add_months(prior, *i * sign), + ScalarValue::IntervalMonthDayNano(Some(i)) => add_m_d_nano(prior, *i, sign), other => Err(DataFusionError::Execution(format!( "DateIntervalExpr does not support non-interval type {:?}", other @@ -169,18 +142,14 @@ impl PhysicalExpr for DateIntervalExpr { }; // convert back - let posterior = posterior.sub(epoch).num_days(); let res = match operand { ScalarValue::Date32(Some(_)) => { - let casted = i32::try_from(posterior).map_err(|_| { - DataFusionError::Execution( - "Date arithmetic out of bounds!".to_string(), - ) - })?; - ColumnarValue::Scalar(ScalarValue::Date32(Some(casted))) + let days = posterior.sub(epoch).num_days() as i32; + ColumnarValue::Scalar(ScalarValue::Date32(Some(days))) } ScalarValue::Date64(Some(_)) => { - ColumnarValue::Scalar(ScalarValue::Date64(Some(posterior))) + let ms = posterior.sub(epoch).num_milliseconds(); + ColumnarValue::Scalar(ScalarValue::Date64(Some(ms))) } _ => Err(DataFusionError::Execution(format!( "Invalid lhs type for DateIntervalExpr: {}", @@ -191,7 +160,35 @@ impl PhysicalExpr for DateIntervalExpr { } } -fn add_months(dt: NaiveDate, delta: i64) -> NaiveDate { +fn add_m_d_nano(prior: NaiveDate, interval: i128, sign: i32) -> NaiveDate { + let interval = interval as u128; + let months = (interval >> 96) as i32 * sign; + let days = (interval >> 64) as i32 * sign; + let nanos = interval as i64 * sign as i64; + let a = add_months(prior, months); + let b = a.add(Duration::days(days as i64)); + let c = b.add(Duration::nanoseconds(nanos)); + c +} + +fn add_day_time(prior: NaiveDate, interval: i64, sign: i32) -> NaiveDate { + let interval = interval as u64; + let ms = (interval >> 32) as i32 * sign; + let days = interval as i32 * sign; + let intermediate = prior.add(Duration::days(days as i64)); + let posterior = intermediate.add(Duration::milliseconds(ms as i64)); + posterior +} + +fn add_months(prior: NaiveDate, interval: i32) -> NaiveDate { + let target = chrono_add_months(prior, interval); + let target_plus = chrono_add_months(target, 1); + let last_day = target_plus.sub(chrono::Duration::days(1)); + let day = min(prior.day(), last_day.day()); + NaiveDate::from_ymd(target.year(), target.month(), day) +} + +fn chrono_add_months(dt: NaiveDate, delta: i32) -> NaiveDate { let ay = dt.year(); let am = dt.month() as i32 - 1; // zero-based for modulo operations let bm = am + delta as i32; From ee1c7565a8b7a06e38e5ab8a6238953e361a8438 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 09:48:11 -0600 Subject: [PATCH 07/40] Fix doc --- datafusion/common/src/scalar.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index a579a7e74aec..e3442d749a49 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -89,9 +89,9 @@ pub enum ScalarValue { TimestampMicrosecond(Option, Option), /// Timestamp Nanoseconds TimestampNanosecond(Option, Option), - /// Number of elapsed whole months since epoch + /// Number of elapsed whole months IntervalYearMonth(Option), - /// Number of elapsed days and milliseconds since epoch (no leap seconds) + /// Number of elapsed days and milliseconds (no leap seconds) /// stored as 2 contiguous 32-bit signed integers IntervalDayTime(Option), /// A triple of the number of elapsed months, days, and nanoseconds. From 5ea1c282aa713651a2f54ccaa6bf421886086049 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 10:32:17 -0600 Subject: [PATCH 08/40] Fix test that relied on previous flawed implementation --- datafusion/core/tests/sql/timestamp.rs | 40 +++++++++++++++++++ .../optimizer/src/simplify_expressions.rs | 8 ++-- .../physical-expr/src/expressions/datetime.rs | 4 +- 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs index 912b22fbd7c5..9acc3f3cbe28 100644 --- a/datafusion/core/tests/sql/timestamp.rs +++ b/datafusion/core/tests/sql/timestamp.rs @@ -894,3 +894,43 @@ async fn sub_month_wrap() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn add_interval_day() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-01-15' + interval '1' day as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1994-01-16 |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn sub_interval_day() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-01-01' - interval '1' day as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1993-12-31 |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index aa089a00a6a9..da4dfa9eece1 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -1951,7 +1951,7 @@ mod tests { let date_plus_interval_expr = to_timestamp_expr(ts_string) .cast_to(&DataType::Date32, schema) .unwrap() - + Expr::Literal(ScalarValue::IntervalDayTime(Some(123))); + + Expr::Literal(ScalarValue::IntervalDayTime(Some(123i64 << 32))); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![date_plus_interval_expr]) @@ -1963,10 +1963,10 @@ mod tests { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = "Projection: Date32(\"18636\") AS CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Date32) + IntervalDayTime(\"123\")\ - \n TableScan: test"; + let expected = r#"Projection: Date32("18636") AS CAST(totimestamp(Utf8("2020-09-08T12:05:00+00:00")) AS Date32) + IntervalDayTime("528280977408") + TableScan: test"#; let actual = get_optimized_plan_formatted(&plan, &time); - assert_eq!(expected, actual); + assert_eq!(actual, expected); } } diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index f735eede5233..036d0012c4dc 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -173,8 +173,8 @@ fn add_m_d_nano(prior: NaiveDate, interval: i128, sign: i32) -> NaiveDate { fn add_day_time(prior: NaiveDate, interval: i64, sign: i32) -> NaiveDate { let interval = interval as u64; - let ms = (interval >> 32) as i32 * sign; - let days = interval as i32 * sign; + let days = (interval >> 32) as i32 * sign; + let ms = interval as i32 * sign; let intermediate = prior.add(Duration::days(days as i64)); let posterior = intermediate.add(Duration::milliseconds(ms as i64)); posterior From 83484707952caf125f70c695f6cefe2d766f9a1c Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 10:57:35 -0600 Subject: [PATCH 09/40] Appease clippy --- datafusion/physical-expr/src/expressions/datetime.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index 036d0012c4dc..6b4b2e571fe1 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -167,8 +167,7 @@ fn add_m_d_nano(prior: NaiveDate, interval: i128, sign: i32) -> NaiveDate { let nanos = interval as i64 * sign as i64; let a = add_months(prior, months); let b = a.add(Duration::days(days as i64)); - let c = b.add(Duration::nanoseconds(nanos)); - c + b.add(Duration::nanoseconds(nanos)) } fn add_day_time(prior: NaiveDate, interval: i64, sign: i32) -> NaiveDate { @@ -176,8 +175,7 @@ fn add_day_time(prior: NaiveDate, interval: i64, sign: i32) -> NaiveDate { let days = (interval >> 32) as i32 * sign; let ms = interval as i32 * sign; let intermediate = prior.add(Duration::days(days as i64)); - let posterior = intermediate.add(Duration::milliseconds(ms as i64)); - posterior + intermediate.add(Duration::milliseconds(ms as i64)) } fn add_months(prior: NaiveDate, interval: i32) -> NaiveDate { From cd999c755b3737772d882be823a5b30b273394e6 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Sat, 25 Jun 2022 15:14:45 -0600 Subject: [PATCH 10/40] Failing test case for TPC-H query 20 --- datafusion/core/tests/sql/mod.rs | 33 +++++++++++++++- datafusion/core/tests/sql/subqueries.rs | 42 +++++++++++++++++++++ datafusion/core/tests/tpch-csv/part.csv | 2 + datafusion/core/tests/tpch-csv/partsupp.csv | 2 + datafusion/core/tests/tpch-csv/supplier.csv | 2 + 5 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 datafusion/core/tests/sql/subqueries.rs create mode 100644 datafusion/core/tests/tpch-csv/part.csv create mode 100644 datafusion/core/tests/tpch-csv/partsupp.csv create mode 100644 datafusion/core/tests/tpch-csv/supplier.csv diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 0e3e08873cce..95259c9ca63a 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -110,6 +110,7 @@ pub mod information_schema; mod partitioned_csv; #[cfg(feature = "unicode_expressions")] pub mod unicode; +mod subqueries; fn assert_float_eq(expected: &[Vec], received: &[Vec]) where @@ -483,7 +484,37 @@ fn get_tpch_table_schema(table: &str) -> Schema { Field::new("n_comment", DataType::Utf8, false), ]), - _ => unimplemented!(), + "supplier" => Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_name", DataType::Utf8, false), + Field::new("s_address", DataType::Utf8, false), + Field::new("s_nationkey", DataType::Int64, false), + Field::new("s_phone", DataType::Utf8, false), + Field::new("s_acctbal", DataType::Float64, false), + Field::new("s_comment", DataType::Utf8, false), + ]), + + "partsupp" => Schema::new(vec![ + Field::new("ps_partkey", DataType::Int64, false), + Field::new("ps_suppkey", DataType::Int64, false), + Field::new("ps_availqty", DataType::Int32, false), + Field::new("ps_supplycost", DataType::Float64, false), + Field::new("ps_comment", DataType::Utf8, false), + ]), + + "part" => Schema::new(vec![ + Field::new("p_partkey", DataType::Int64, false), + Field::new("p_name", DataType::Utf8, false), + Field::new("p_mfgr", DataType::Utf8, false), + Field::new("p_brand", DataType::Utf8, false), + Field::new("p_type", DataType::Utf8, false), + Field::new("p_size", DataType::Int32, false), + Field::new("p_container", DataType::Utf8, false), + Field::new("p_retailprice", DataType::Float64, false), + Field::new("p_comment", DataType::Utf8, false), + ]), + + _ => unimplemented!("Table: {}", table), } } diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs new file mode 100644 index 000000000000..9f39f09b3ad5 --- /dev/null +++ b/datafusion/core/tests/sql/subqueries.rs @@ -0,0 +1,42 @@ +use super::*; +use datafusion::assert_batches_eq; +use datafusion::prelude::SessionContext; +use crate::sql::{execute_to_batches}; + +#[tokio::test] +async fn select_all() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "supplier").await?; + register_tpch_csv(&ctx, "nation").await?; + register_tpch_csv(&ctx, "partsupp").await?; + register_tpch_csv(&ctx, "part").await?; + register_tpch_csv(&ctx, "lineitem").await?; + + let sql = r#" + select s_name, s_address + from supplier, nation + where s_suppkey in ( + select ps_suppkey from partsupp + where ps_partkey in ( select p_partkey from part where p_name like 'forest%' ) + and ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem + where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate >= date '1994-01-01' + and l_shipdate < 'date 1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey and n_name = 'CANADA' + order by s_name; + "#; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00005 |", + "+---------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/core/tests/tpch-csv/part.csv b/datafusion/core/tests/tpch-csv/part.csv new file mode 100644 index 000000000000..f790f07bc2fe --- /dev/null +++ b/datafusion/core/tests/tpch-csv/part.csv @@ -0,0 +1,2 @@ +p_partkey,p_name,p_mfgr,p_brand,p_type,p_size,p_container,p_retailprice,p_comment +1,goldenrod lavender spring chocolate lace,Manufacturer#1,Brand#13,PROMO BURNISHED COPPER,7,JUMBO PKG,901.00,ly. slyly ironi diff --git a/datafusion/core/tests/tpch-csv/partsupp.csv b/datafusion/core/tests/tpch-csv/partsupp.csv new file mode 100644 index 000000000000..833789312fc1 --- /dev/null +++ b/datafusion/core/tests/tpch-csv/partsupp.csv @@ -0,0 +1,2 @@ +ps_partkey,ps_suppkey,ps_availqty,ps_supplycost,ps_comment +1,252,8076,993.49,ven ideas. quickly even packages print. pending multipliers must have to are fluff diff --git a/datafusion/core/tests/tpch-csv/supplier.csv b/datafusion/core/tests/tpch-csv/supplier.csv new file mode 100644 index 000000000000..768096c7ffa6 --- /dev/null +++ b/datafusion/core/tests/tpch-csv/supplier.csv @@ -0,0 +1,2 @@ +s_suppkey,s_name,s_address,s_nationkey,s_phone,s_acctbal,s_comment +1,Supplier#000000001, N kD4on9OM Ipw3,gf0JBoQDd7tgrzrddZ,17,27-918-335-1736,5755.94,each slyly above the careful From ccdb98fabe19276ff78913140ab8ce6f0f6a7d71 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Sat, 25 Jun 2022 15:16:05 -0600 Subject: [PATCH 11/40] Fix name --- datafusion/core/tests/sql/subqueries.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 9f39f09b3ad5..ef202bf3ceae 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -3,8 +3,9 @@ use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use crate::sql::{execute_to_batches}; +/// https://github.com/apache/arrow-datafusion/issues/171 #[tokio::test] -async fn select_all() -> Result<()> { +async fn tpch_q20() -> Result<()> { let ctx = SessionContext::new(); register_tpch_csv(&ctx, "supplier").await?; register_tpch_csv(&ctx, "nation").await?; From e7fcb2f8211ec85f9637d7c72e1e8243f2f794cb Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Sat, 25 Jun 2022 16:51:24 -0600 Subject: [PATCH 12/40] Broken test for adding intervals to dates --- datafusion/core/tests/sql/mod.rs | 2 ++ datafusion/core/tests/sql/subqueries.rs | 40 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 95259c9ca63a..92926fbbd712 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -649,6 +649,7 @@ async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec .map_err(|e| format!("{:?} at {}", e, msg)) .unwrap(); let logical_schema = plan.schema(); + println!("Logical {}", plan.display_indent()); // TODO: remove let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); let plan = ctx @@ -657,6 +658,7 @@ async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec .unwrap(); let optimized_logical_schema = plan.schema(); + println!("Optimized {}", plan.display_indent()); // TODO: remove let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); let plan = ctx .create_physical_plan(&plan) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index ef202bf3ceae..176fb98ba536 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -41,3 +41,43 @@ async fn tpch_q20() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn scalar_subquery() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select * from (values (1)) where column1 > ( select 0.5 );"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00005 |", + "+---------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn interval_year() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-01-01' + interval '1' year;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00005 |", + "+---------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} From 9b51e46013f926002c8eb6311e18d14c7ff2d17a Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Sun, 26 Jun 2022 14:43:33 -0600 Subject: [PATCH 13/40] Tests pass --- datafusion/core/tests/sql/subqueries.rs | 20 ------------------- .../physical-expr/src/expressions/datetime.rs | 1 + 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 176fb98ba536..19d01ebaf59f 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -61,23 +61,3 @@ async fn scalar_subquery() -> Result<()> { Ok(()) } - -#[tokio::test] -async fn interval_year() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select date '1994-01-01' + interval '1' year;"; - let results = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+---------+", - "| c1 |", - "+---------+", - "| 0.00005 |", - "+---------+", - ]; - - assert_batches_eq!(expected, &results); - - Ok(()) -} diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index 6b4b2e571fe1..6469962b7a94 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -27,6 +27,7 @@ use std::cmp::min; use std::fmt::{Display, Formatter}; use std::ops::{Add, Sub}; use std::sync::Arc; +use chrono::{Datelike, NaiveDate}; /// Perform DATE +/ INTERVAL math #[derive(Debug)] From de8ae11b3a22320fa8f033195c517efd0fbad28a Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 10:55:41 -0600 Subject: [PATCH 14/40] Fix rebase --- datafusion/physical-expr/src/expressions/datetime.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index 6469962b7a94..6b4b2e571fe1 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -27,7 +27,6 @@ use std::cmp::min; use std::fmt::{Display, Formatter}; use std::ops::{Add, Sub}; use std::sync::Arc; -use chrono::{Datelike, NaiveDate}; /// Perform DATE +/ INTERVAL math #[derive(Debug)] From 8dd2b160ed67d865787b10719e188a775396ae43 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 12:03:05 -0600 Subject: [PATCH 15/40] Fix query --- datafusion/core/tests/sql/subqueries.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 19d01ebaf59f..00e09f08177b 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -21,7 +21,7 @@ async fn tpch_q20() -> Result<()> { where ps_partkey in ( select p_partkey from part where p_name like 'forest%' ) and ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate >= date '1994-01-01' - and l_shipdate < 'date 1994-01-01' + interval '1' year + and l_shipdate < date '1994-01-01' + interval '1' year ) ) and s_nationkey = n_nationkey and n_name = 'CANADA' From 34b29088603494266e80c081ba7e61c06e24ca36 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 14:13:28 -0600 Subject: [PATCH 16/40] Additional tests --- datafusion/core/tests/sql/subqueries.rs | 68 +++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 00e09f08177b..c0e184e46ed3 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -42,6 +42,74 @@ async fn tpch_q20() -> Result<()> { Ok(()) } +#[tokio::test] +async fn tpch_q20_correlated() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "supplier").await?; + register_tpch_csv(&ctx, "nation").await?; + register_tpch_csv(&ctx, "partsupp").await?; + register_tpch_csv(&ctx, "part").await?; + register_tpch_csv(&ctx, "lineitem").await?; + + let sql = r#" + select ps_suppkey from partsupp + where ps_partkey in ( select p_partkey from part where p_name like 'forest%' ) + and ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem + where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) order by ps_suppkey; + "#; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00005 |", + "+---------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn tpch_q20_decorrelated() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "supplier").await?; + register_tpch_csv(&ctx, "nation").await?; + register_tpch_csv(&ctx, "partsupp").await?; + register_tpch_csv(&ctx, "part").await?; + register_tpch_csv(&ctx, "lineitem").await?; + + let sql = r#" + select ps_suppkey + from partsupp ps + inner join part on ps.ps_partkey = part.p_partkey and p_name like 'forest%' + inner join ( + select l_partkey, l_suppkey, 0.5 * sum(l_quantity) as threshold from lineitem + where l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + group by l_partkey, l_suppkey + ) av on av.l_suppkey=ps.ps_suppkey and av.l_partkey=ps.ps_partkey and ps.ps_availqty > av.threshold + order by ps_suppkey; + "#; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00005 |", + "+---------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + #[tokio::test] async fn scalar_subquery() -> Result<()> { let ctx = SessionContext::new(); From 6a759ce407ac5a54c3de13d1e0163af2b542d127 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 14:17:46 -0600 Subject: [PATCH 17/40] Reduce to minimum failing (and passing) cases --- datafusion/core/tests/sql/subqueries.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index c0e184e46ed3..ad621557f173 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -53,10 +53,8 @@ async fn tpch_q20_correlated() -> Result<()> { let sql = r#" select ps_suppkey from partsupp - where ps_partkey in ( select p_partkey from part where p_name like 'forest%' ) - and ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem - where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate >= date '1994-01-01' - and l_shipdate < date '1994-01-01' + interval '1' year + where ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem + where l_partkey = ps_partkey and l_suppkey = ps_suppkey ) order by ps_suppkey; "#; let results = execute_to_batches(&ctx, sql).await; @@ -86,11 +84,8 @@ async fn tpch_q20_decorrelated() -> Result<()> { let sql = r#" select ps_suppkey from partsupp ps - inner join part on ps.ps_partkey = part.p_partkey and p_name like 'forest%' inner join ( select l_partkey, l_suppkey, 0.5 * sum(l_quantity) as threshold from lineitem - where l_shipdate >= date '1994-01-01' - and l_shipdate < date '1994-01-01' + interval '1' year group by l_partkey, l_suppkey ) av on av.l_suppkey=ps.ps_suppkey and av.l_partkey=ps.ps_partkey and ps.ps_availqty > av.threshold order by ps_suppkey; From 37a73c217464c32edba081b2e59ab94f8653e8bb Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 14:43:10 -0600 Subject: [PATCH 18/40] Adjust so data _should_ be returned, but see none --- datafusion/core/tests/sql/mod.rs | 4 ++-- datafusion/core/tests/sql/subqueries.rs | 16 ++++++++++++++-- datafusion/core/tests/tpch-csv/partsupp.csv | 2 +- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 92926fbbd712..6704d0553b51 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -108,9 +108,9 @@ mod explain; mod idenfifers; pub mod information_schema; mod partitioned_csv; +mod subqueries; #[cfg(feature = "unicode_expressions")] pub mod unicode; -mod subqueries; fn assert_float_eq(expected: &[Vec], received: &[Vec]) where @@ -513,7 +513,7 @@ fn get_tpch_table_schema(table: &str) -> Schema { Field::new("p_retailprice", DataType::Float64, false), Field::new("p_comment", DataType::Utf8, false), ]), - + _ => unimplemented!("Table: {}", table), } } diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index ad621557f173..d175791236b8 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -1,7 +1,7 @@ use super::*; +use crate::sql::execute_to_batches; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; -use crate::sql::{execute_to_batches}; /// https://github.com/apache/arrow-datafusion/issues/171 #[tokio::test] @@ -81,8 +81,20 @@ async fn tpch_q20_decorrelated() -> Result<()> { register_tpch_csv(&ctx, "part").await?; register_tpch_csv(&ctx, "lineitem").await?; + /* + #suppkey + Sort: #ps.ps_suppkey ASC NULLS LAST + Projection: #ps.ps_suppkey AS suppkey, #ps.ps_suppkey + Inner Join: #ps.ps_suppkey = #av.l_suppkey, #ps.ps_partkey = #av.l_partkey Filter: #ps.ps_availqty > #av.threshold + SubqueryAlias: ps + TableScan: partsupp projection=Some([ps_partkey, ps_suppkey, ps_availqty]) + Projection: #av.l_partkey, #av.l_suppkey, #av.threshold, alias=av + Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Float64(0.5) * #SUM(lineitem.l_quantity) AS threshold, alias=av + Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]] + TableScan: lineitem projection=Some([l_partkey, l_suppkey, l_quantity]) + */ let sql = r#" - select ps_suppkey + select ps_suppkey as suppkey from partsupp ps inner join ( select l_partkey, l_suppkey, 0.5 * sum(l_quantity) as threshold from lineitem diff --git a/datafusion/core/tests/tpch-csv/partsupp.csv b/datafusion/core/tests/tpch-csv/partsupp.csv index 833789312fc1..66392afa0641 100644 --- a/datafusion/core/tests/tpch-csv/partsupp.csv +++ b/datafusion/core/tests/tpch-csv/partsupp.csv @@ -1,2 +1,2 @@ ps_partkey,ps_suppkey,ps_availqty,ps_supplycost,ps_comment -1,252,8076,993.49,ven ideas. quickly even packages print. pending multipliers must have to are fluff +1,67310,7311,993.49,ven ideas. quickly even packages print. pending multipliers must have to are fluff From 1db5c8d33b2f2f06e5ff11f841389d62a25819f1 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 14:47:18 -0600 Subject: [PATCH 19/40] Fixed data, decorrelated test passes --- datafusion/core/tests/sql/subqueries.rs | 4 ++-- datafusion/core/tests/tpch-csv/partsupp.csv | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index d175791236b8..b0e8ce02f726 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -106,9 +106,9 @@ async fn tpch_q20_decorrelated() -> Result<()> { let expected = vec![ "+---------+", - "| c1 |", + "| suppkey |", "+---------+", - "| 0.00005 |", + "| 7311 |", "+---------+", ]; diff --git a/datafusion/core/tests/tpch-csv/partsupp.csv b/datafusion/core/tests/tpch-csv/partsupp.csv index 66392afa0641..d7db83d03042 100644 --- a/datafusion/core/tests/tpch-csv/partsupp.csv +++ b/datafusion/core/tests/tpch-csv/partsupp.csv @@ -1,2 +1,2 @@ ps_partkey,ps_suppkey,ps_availqty,ps_supplycost,ps_comment -1,67310,7311,993.49,ven ideas. quickly even packages print. pending multipliers must have to are fluff +67310,7311,100,993.49,ven ideas. quickly even packages print. pending multipliers must have to are fluff From f3ee70c65c7de0bbceeff3d8c04b9df4e5bbb5e0 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 14:55:47 -0600 Subject: [PATCH 20/40] Check in plans --- datafusion/core/tests/sql/subqueries.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index b0e8ce02f726..ed5f293a80d9 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -51,6 +51,15 @@ async fn tpch_q20_correlated() -> Result<()> { register_tpch_csv(&ctx, "part").await?; register_tpch_csv(&ctx, "lineitem").await?; + /* +#partsupp.ps_suppkey ASC NULLS LAST + Projection: #partsupp.ps_suppkey + Filter: #partsupp.ps_availqty > (Subquery: Projection: Float64(0.5) * #SUM(lineitem.l_quantity) + Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_quantity)]] + Filter: #lineitem.l_partkey = #partsupp.ps_partkey AND #lineitem.l_suppkey = #partsupp.ps_suppkey + TableScan: lineitem projection=None) + TableScan: partsupp projection=None + */ let sql = r#" select ps_suppkey from partsupp where ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem @@ -82,7 +91,7 @@ async fn tpch_q20_decorrelated() -> Result<()> { register_tpch_csv(&ctx, "lineitem").await?; /* - #suppkey + #suppkey Sort: #ps.ps_suppkey ASC NULLS LAST Projection: #ps.ps_suppkey AS suppkey, #ps.ps_suppkey Inner Join: #ps.ps_suppkey = #av.l_suppkey, #ps.ps_partkey = #av.l_partkey Filter: #ps.ps_availqty > #av.threshold From b08da9737f01f450fbe896cea117c682ffa84037 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 14:58:30 -0600 Subject: [PATCH 21/40] Put real assertion in place --- datafusion/core/tests/sql/subqueries.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index ed5f293a80d9..43f1208d9072 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -70,9 +70,9 @@ async fn tpch_q20_correlated() -> Result<()> { let expected = vec![ "+---------+", - "| c1 |", + "| suppkey |", "+---------+", - "| 0.00005 |", + "| 7311 |", "+---------+", ]; From f22c0791168b3dc49b81bbd80c1b92583588b1e7 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 15:07:22 -0600 Subject: [PATCH 22/40] Add test for already working subquery optimizer --- datafusion/core/tests/sql/subqueries.rs | 46 +++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 43f1208d9072..24b4290adb31 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -145,3 +145,49 @@ async fn scalar_subquery() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn filter_to_join() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "customer").await?; + register_tpch_csv(&ctx, "nation").await?; + + /* +Sort: #customer.c_custkey ASC NULLS LAST + Projection: #customer.c_custkey + Filter: #customer.c_nationkey IN (Subquery: Projection: #nation.n_nationkey + TableScan: nation projection=None) + TableScan: customer projection=None + */ + let sql = r#" + select c_custkey from customer + where c_nationkey in (select n_nationkey from nation) + order by c_custkey; + "#; + let results = execute_to_batches(&ctx, sql).await; + /* +Sort: #customer.c_custkey ASC NULLS LAST + Projection: #customer.c_custkey + Semi Join: #customer.c_nationkey = #nation.n_nationkey + TableScan: customer projection=Some([c_custkey, c_nationkey]) + Projection: #nation.n_nationkey + TableScan: nation projection=Some([n_nationkey]) + */ + + let expected = vec![ + "+-----------+", + "| c_custkey |", + "+-----------+", + "| 3 |", + "| 4 |", + "| 5 |", + "| 9 |", + "| 10 |", + "+-----------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + From 0e0e0c72a4e22f60334458c6baeeb0dd5a752f0f Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 15:49:29 -0600 Subject: [PATCH 23/40] Add decorellator --- datafusion/core/src/execution/context.rs | 2 + datafusion/core/tests/sql/subqueries.rs | 66 +++++++++---------- datafusion/optimizer/src/lib.rs | 1 + .../optimizer/src/subquery_decorrelate.rs | 28 ++++++++ 4 files changed, 64 insertions(+), 33 deletions(-) create mode 100644 datafusion/optimizer/src/subquery_decorrelate.rs diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 458b915260a6..393e8aef34ec 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -99,6 +99,7 @@ use chrono::{DateTime, Utc}; use datafusion_common::ScalarValue; use datafusion_expr::TableSource; use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; +use datafusion_optimizer::subquery_decorrelate::SubqueryDecorrelate; use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, @@ -1239,6 +1240,7 @@ impl SessionState { // of applying other optimizations Arc::new(SimplifyExpressions::new()), Arc::new(SubqueryFilterToJoin::new()), + Arc::new(SubqueryDecorrelate::new()), Arc::new(EliminateFilter::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 24b4290adb31..9a523ac5b3f6 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -52,14 +52,16 @@ async fn tpch_q20_correlated() -> Result<()> { register_tpch_csv(&ctx, "lineitem").await?; /* -#partsupp.ps_suppkey ASC NULLS LAST - Projection: #partsupp.ps_suppkey - Filter: #partsupp.ps_availqty > (Subquery: Projection: Float64(0.5) * #SUM(lineitem.l_quantity) - Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_quantity)]] - Filter: #lineitem.l_partkey = #partsupp.ps_partkey AND #lineitem.l_suppkey = #partsupp.ps_suppkey - TableScan: lineitem projection=None) - TableScan: partsupp projection=None - */ + #partsupp.ps_suppkey ASC NULLS LAST + Projection: #partsupp.ps_suppkey + Filter: #partsupp.ps_availqty > ( + Subquery: Projection: Float64(0.5) * #SUM(lineitem.l_quantity) + Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_quantity)]] + Filter: #lineitem.l_partkey = #partsupp.ps_partkey AND #lineitem.l_suppkey = #partsupp.ps_suppkey + TableScan: lineitem projection=None + ) + TableScan: partsupp projection=None + */ let sql = r#" select ps_suppkey from partsupp where ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem @@ -91,17 +93,17 @@ async fn tpch_q20_decorrelated() -> Result<()> { register_tpch_csv(&ctx, "lineitem").await?; /* - #suppkey - Sort: #ps.ps_suppkey ASC NULLS LAST - Projection: #ps.ps_suppkey AS suppkey, #ps.ps_suppkey - Inner Join: #ps.ps_suppkey = #av.l_suppkey, #ps.ps_partkey = #av.l_partkey Filter: #ps.ps_availqty > #av.threshold - SubqueryAlias: ps - TableScan: partsupp projection=Some([ps_partkey, ps_suppkey, ps_availqty]) - Projection: #av.l_partkey, #av.l_suppkey, #av.threshold, alias=av - Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Float64(0.5) * #SUM(lineitem.l_quantity) AS threshold, alias=av - Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]] - TableScan: lineitem projection=Some([l_partkey, l_suppkey, l_quantity]) - */ + #suppkey + Sort: #ps.ps_suppkey ASC NULLS LAST + Projection: #ps.ps_suppkey AS suppkey, #ps.ps_suppkey + Inner Join: #ps.ps_suppkey = #av.l_suppkey, #ps.ps_partkey = #av.l_partkey Filter: #ps.ps_availqty > #av.threshold + SubqueryAlias: ps + TableScan: partsupp projection=Some([ps_partkey, ps_suppkey, ps_availqty]) + Projection: #av.l_partkey, #av.l_suppkey, #av.threshold, alias=av + Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Float64(0.5) * #SUM(lineitem.l_quantity) AS threshold, alias=av + Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]] + TableScan: lineitem projection=Some([l_partkey, l_suppkey, l_quantity]) + */ let sql = r#" select ps_suppkey as suppkey from partsupp ps @@ -153,12 +155,11 @@ async fn filter_to_join() -> Result<()> { register_tpch_csv(&ctx, "nation").await?; /* -Sort: #customer.c_custkey ASC NULLS LAST - Projection: #customer.c_custkey - Filter: #customer.c_nationkey IN (Subquery: Projection: #nation.n_nationkey - TableScan: nation projection=None) - TableScan: customer projection=None - */ + Sort: #customer.c_custkey ASC NULLS LAST + Projection: #customer.c_custkey + Filter: #customer.c_nationkey IN (Subquery: Projection: #nation.n_nationkey TableScan: nation projection=None) + TableScan: customer projection=None + */ let sql = r#" select c_custkey from customer where c_nationkey in (select n_nationkey from nation) @@ -166,13 +167,13 @@ Sort: #customer.c_custkey ASC NULLS LAST "#; let results = execute_to_batches(&ctx, sql).await; /* -Sort: #customer.c_custkey ASC NULLS LAST - Projection: #customer.c_custkey - Semi Join: #customer.c_nationkey = #nation.n_nationkey - TableScan: customer projection=Some([c_custkey, c_nationkey]) - Projection: #nation.n_nationkey - TableScan: nation projection=Some([n_nationkey]) - */ + Sort: #customer.c_custkey ASC NULLS LAST + Projection: #customer.c_custkey + Semi Join: #customer.c_nationkey = #nation.n_nationkey + TableScan: customer projection=Some([c_custkey, c_nationkey]) + Projection: #nation.n_nationkey + TableScan: nation projection=Some([n_nationkey]) + */ let expected = vec![ "+-----------+", @@ -190,4 +191,3 @@ Sort: #customer.c_custkey ASC NULLS LAST Ok(()) } - diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index a6b7cfcbb8fb..9afe6af0ca13 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -27,6 +27,7 @@ pub mod projection_push_down; pub mod reduce_outer_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; +pub mod subquery_decorrelate; pub mod subquery_filter_to_join; pub mod utils; diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs new file mode 100644 index 000000000000..41c9b328fe10 --- /dev/null +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -0,0 +1,28 @@ +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_expr::LogicalPlan; + +/// Optimizer rule for rewriting subquery filters to joins +#[derive(Default)] +pub struct SubqueryDecorrelate {} + +impl SubqueryDecorrelate { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for SubqueryDecorrelate { + fn optimize( + &self, + plan: &LogicalPlan, + optimizer_config: &OptimizerConfig, + ) -> datafusion_common::Result { + println!("{}", plan.display_indent()); + return Ok(plan.clone()); + } + + fn name(&self) -> &str { + "subquery_decorrelate" + } +} From 308b67c7ac19fae7860e7fbd77c62e4a054ce5df Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 27 Jun 2022 16:05:36 -0600 Subject: [PATCH 24/40] Check in broken test --- .../optimizer/src/subquery_decorrelate.rs | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index 41c9b328fe10..9a17a82a2b38 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -16,7 +16,7 @@ impl OptimizerRule for SubqueryDecorrelate { fn optimize( &self, plan: &LogicalPlan, - optimizer_config: &OptimizerConfig, + _optimizer_config: &OptimizerConfig, ) -> datafusion_common::Result { println!("{}", plan.display_indent()); return Ok(plan.clone()); @@ -26,3 +26,52 @@ impl OptimizerRule for SubqueryDecorrelate { "subquery_decorrelate" } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use super::*; + use crate::test::*; + use datafusion_expr::{ + col, in_subquery, logical_plan::LogicalPlanBuilder, + }; + + #[test] + fn in_subquery_simple() -> datafusion_common::Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery(col("c"), test_subquery_with_name("sq")?))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: #test.b [b:UInt32]\ + \n Semi Join: #test.c = #sq.c [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: #sq.c [c:UInt32]\ + \n TableScan: sq projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + // TODO: deduplicate with subquery_filter_to_join + fn test_subquery_with_name(name: &str) -> datafusion_common::Result> { + let table_scan = test_table_scan_with_name(name)?; + Ok(Arc::new( + LogicalPlanBuilder::from(table_scan) + .project(vec![col("c")])? + .build()?, + )) + } + + // TODO: deduplicate with subquery_filter_to_join + fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { + let rule = SubqueryDecorrelate::new(); + let optimized_plan = rule + .optimize(plan, &OptimizerConfig::new()) + .expect("failed to optimize plan"); + let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); + assert_eq!(formatted_plan, expected); + } + +} From 0c5ed1a9147023afbe8b8dfafe72b2c0344b93c0 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Tue, 28 Jun 2022 11:00:37 -0600 Subject: [PATCH 25/40] Add some passing and failing tests to see scope of problem --- datafusion/core/tests/sql/subqueries.rs | 179 +++++++++++++++++++++++- 1 file changed, 172 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 9a523ac5b3f6..7faae3c366dd 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -45,10 +45,7 @@ async fn tpch_q20() -> Result<()> { #[tokio::test] async fn tpch_q20_correlated() -> Result<()> { let ctx = SessionContext::new(); - register_tpch_csv(&ctx, "supplier").await?; - register_tpch_csv(&ctx, "nation").await?; register_tpch_csv(&ctx, "partsupp").await?; - register_tpch_csv(&ctx, "part").await?; register_tpch_csv(&ctx, "lineitem").await?; /* @@ -83,13 +80,14 @@ async fn tpch_q20_correlated() -> Result<()> { Ok(()) } +// 0. recurse down to most deeply nested subquery +// 1. find references to outer scope (ps_partkey, ps_suppkey), if none, bail - not correlated +// 2. remove correlated fields from filter +// 3. add correlated fields as group by & to projection #[tokio::test] async fn tpch_q20_decorrelated() -> Result<()> { let ctx = SessionContext::new(); - register_tpch_csv(&ctx, "supplier").await?; - register_tpch_csv(&ctx, "nation").await?; register_tpch_csv(&ctx, "partsupp").await?; - register_tpch_csv(&ctx, "part").await?; register_tpch_csv(&ctx, "lineitem").await?; /* @@ -128,6 +126,171 @@ async fn tpch_q20_decorrelated() -> Result<()> { Ok(()) } +#[tokio::test] +async fn tpch_q4_correlated() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "orders").await?; + register_tpch_csv(&ctx, "lineitem").await?; + + /* +#orders.o_orderpriority ASC NULLS LAST + Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Filter: EXISTS ( + Subquery: Projection: #lineitem.l_orderkey, #lineitem.l_partkey, #lineitem.l_suppkey, #lineitem.l_linenumber, #lineitem.l_quantity, #lineitem.l_extendedprice, #lineitem.l_discount, #lineitem.l_tax, + #lineitem.l_returnflag, #lineitem.l_linestatus, #lineitem.l_shipdate, #lineitem.l_commitdate, #lineitem.l_receiptdate, #lineitem.l_shipinstruct, #lineitem.l_shipmode, #lineitem.l_comment + Filter: #lineitem.l_orderkey = #orders.o_orderkey + TableScan: lineitem projection=None + ) + TableScan: orders projection=None + */ + let sql = r#" + select o_orderpriority, count(*) as order_count + from orders + where exists ( select * from lineitem where l_orderkey = o_orderkey ) + group by o_orderpriority + order by o_orderpriority; + "#; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| suppkey |", + "+---------+", + "| 7311 |", + "+---------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn tpch_q4_decorrelated() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "orders").await?; + register_tpch_csv(&ctx, "lineitem").await?; + + /* +#o.o_orderpriority ASC NULLS LAST + Projection: #o.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#o.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Inner Join: #o.o_orderkey = #l.l_orderkey + SubqueryAlias: o + TableScan: orders projection=Some([o_orderkey, o_orderpriority]) + Projection: #l.l_orderkey, alias=l + Projection: #lineitem.l_orderkey, alias=l + Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]] + TableScan: lineitem projection=Some([l_orderkey]) + */ + let sql = r#" + select o_orderpriority, count(*) as order_count + from orders o + inner join ( select l_orderkey from lineitem group by l_orderkey ) l on l.l_orderkey = o_orderkey + group by o_orderpriority + order by o_orderpriority; + "#; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+-----------------+-------------+", + "| o_orderpriority | order_count |", + "+-----------------+-------------+", + "| 1-URGENT | 1 |", + "| 5-LOW | 1 |", + "+-----------------+-------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn tpch_q17_correlated() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "lineitem").await?; + register_tpch_csv(&ctx, "part").await?; + + /* +#SUM(lineitem.l_extendedprice) / Float64(7) AS avg_yearly + Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]] + Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = Utf8("MED BOX") AND #lineitem.l_quantity < ( + Subquery: Projection: Float64(0.2) * #AVG(lineitem.l_quantity) + Aggregate: groupBy=[[]], aggr=[[AVG(#lineitem.l_quantity)]] + Filter: #lineitem.l_partkey = #part.p_partkey + TableScan: lineitem projection=None + ) + Inner Join: #lineitem.l_partkey = #part.p_partkey + TableScan: lineitem projection=None + TableScan: part projection=None + */ + let sql = r#" + select sum(l_extendedprice) / 7.0 as avg_yearly + from lineitem, part + where p_partkey = l_partkey and p_brand = 'Brand#23' and p_container = 'MED BOX' + and l_quantity < ( select 0.2 * avg(l_quantity) from lineitem where l_partkey = p_partkey + ); + "#; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| suppkey |", + "+---------+", + "| 7311 |", + "+---------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn tpch_q17_decorrelated() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "lineitem").await?; + register_tpch_csv(&ctx, "part").await?; + + /* + #SUM(lineitem.l_extendedprice) / Float64(7) AS avg_yearly + Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]] + Filter: #lineitem.l_quantity < #li.qty + Inner Join: #part.p_partkey = #li.l_partkey + Inner Join: #lineitem.l_partkey = #part.p_partkey + TableScan: lineitem projection=Some([l_partkey, l_quantity, l_extendedprice]) + Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = Utf8("MED BOX") + TableScan: part projection=Some([p_partkey, p_brand, p_container]), partial_filters=[#part.p_brand = Utf8("Brand#23"), #part.p_container = Utf8("MED BOX")] + Projection: #li.l_partkey, #li.qty, alias=li + Projection: #lineitem.l_partkey, Float64(0.2) * #AVG(lineitem.l_quantity) AS qty, alias=li + Aggregate: groupBy=[[#lineitem.l_partkey]], aggr=[[AVG(#lineitem.l_quantity)]] + TableScan: lineitem projection=Some([l_partkey, l_quantity, l_extendedprice]) + */ + let sql = r#" + select sum(l_extendedprice) / 7.0 as avg_yearly + from lineitem + inner join part on p_partkey = l_partkey + inner join ( select l_partkey, 0.2 * avg(l_quantity) as qty from lineitem group by l_partkey + ) li on li.l_partkey = p_partkey + where p_brand = 'Brand#23' and p_container = 'MED BOX' and l_quantity < li.qty; + "#; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| avg_yearly |", + "+------------+", + "| |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + #[tokio::test] async fn scalar_subquery() -> Result<()> { let ctx = SessionContext::new(); @@ -157,7 +320,9 @@ async fn filter_to_join() -> Result<()> { /* Sort: #customer.c_custkey ASC NULLS LAST Projection: #customer.c_custkey - Filter: #customer.c_nationkey IN (Subquery: Projection: #nation.n_nationkey TableScan: nation projection=None) + Filter: #customer.c_nationkey IN ( + Subquery: Projection: #nation.n_nationkey TableScan: nation projection=None + ) TableScan: customer projection=None */ let sql = r#" From d11d7f95e107ac2a7e4b1ff82e8b30cd87bcd7b6 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Tue, 28 Jun 2022 12:57:08 -0600 Subject: [PATCH 26/40] Have almost all inputs needed for optimization, but need to catch 1 level earlier in tree --- datafusion/core/tests/sql/subqueries.rs | 17 +++-- .../optimizer/src/subquery_decorrelate.rs | 66 +++++++++++++++++-- 2 files changed, 69 insertions(+), 14 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 7faae3c366dd..5ed091711bd1 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -134,15 +134,14 @@ async fn tpch_q4_correlated() -> Result<()> { /* #orders.o_orderpriority ASC NULLS LAST - Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count - Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] - Filter: EXISTS ( - Subquery: Projection: #lineitem.l_orderkey, #lineitem.l_partkey, #lineitem.l_suppkey, #lineitem.l_linenumber, #lineitem.l_quantity, #lineitem.l_extendedprice, #lineitem.l_discount, #lineitem.l_tax, - #lineitem.l_returnflag, #lineitem.l_linestatus, #lineitem.l_shipdate, #lineitem.l_commitdate, #lineitem.l_receiptdate, #lineitem.l_shipinstruct, #lineitem.l_shipmode, #lineitem.l_comment - Filter: #lineitem.l_orderkey = #orders.o_orderkey - TableScan: lineitem projection=None - ) - TableScan: orders projection=None + Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Filter: EXISTS ( -- plan + Subquery: Projection: * -- proj + Filter: #lineitem.l_orderkey = #orders.o_orderkey -- filter + TableScan: lineitem projection=None -- filter.input + ) + TableScan: orders projection=None */ let sql = r#" select o_orderpriority, count(*) as order_count diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index 9a17a82a2b38..c66189ddb2e5 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -1,5 +1,9 @@ -use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_expr::LogicalPlan; +use std::hash::Hash; +use std::sync::Arc; +use hashbrown::HashSet; +use datafusion_expr::logical_plan::Filter; +use crate::{OptimizerConfig, OptimizerRule, utils}; +use datafusion_expr::{Expr, LogicalPlan}; /// Optimizer rule for rewriting subquery filters to joins #[derive(Default)] @@ -16,10 +20,62 @@ impl OptimizerRule for SubqueryDecorrelate { fn optimize( &self, plan: &LogicalPlan, - _optimizer_config: &OptimizerConfig, + optimizer_config: &OptimizerConfig, ) -> datafusion_common::Result { - println!("{}", plan.display_indent()); - return Ok(plan.clone()); + match plan { + LogicalPlan::Filter(Filter { predicate, input }) => { + // Apply optimizer rule to current input + let optimized_input = self.optimize(input, optimizer_config)?; + + match predicate { + // TODO: arbitrary expression trees, Expr::InSubQuery + Expr::Exists { subquery,negated } => { + let text = format!("{:?}", subquery); + match &*subquery.subquery { + LogicalPlan::Projection(proj) => { + println!("proj"); + } + _ => return Ok(plan.clone()) + } + for input in subquery.subquery.inputs() { + match input { + LogicalPlan::Filter(filter) => { + match &filter.predicate { + Expr::BinaryExpr { left, op, right } => { + let lcol = match &**left { + Expr::Column(col) => col, + _ => return Ok(plan.clone()) + }; + let rcol = match &**right { + Expr::Column(col) => col, + _ => return Ok(plan.clone()) + }; + let cols = vec![lcol, rcol]; + let cols: HashSet<_> = cols.iter().map(|c| &c.name).collect(); + let fields: HashSet<_> = input.schema().fields().iter().map(|f| f.name()).collect(); + + let found: Vec<_> = cols.intersection(&fields).map(|it| (*it).clone()).collect(); + let closed_upon: Vec<_> = cols.difference(&fields).map(|it| (*it).clone()).collect(); + + println!("{:?} {:?}", found, closed_upon); + }, + _ => return Ok(plan.clone()) + } + }, + _ => return Ok(plan.clone()) + } + } + }, + _ => return Ok(plan.clone()) + } + + return Ok(plan.clone()) + }, + _ => { + // Apply the optimization to all inputs of the plan + utils::optimize_children(self, plan, optimizer_config) + } + } } fn name(&self) -> &str { From 6ab68941803e10627ef211cf6a2dec914837ae60 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Tue, 28 Jun 2022 13:05:30 -0600 Subject: [PATCH 27/40] Collected all inputs, now we just need to optimize --- datafusion/core/tests/sql/subqueries.rs | 2 +- datafusion/optimizer/src/subquery_decorrelate.rs | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 5ed091711bd1..88ba43641cf2 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -141,7 +141,7 @@ async fn tpch_q4_correlated() -> Result<()> { Filter: #lineitem.l_orderkey = #orders.o_orderkey -- filter TableScan: lineitem projection=None -- filter.input ) - TableScan: orders projection=None + TableScan: orders projection=None -- plan.inputs */ let sql = r#" select o_orderpriority, count(*) as order_count diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index c66189ddb2e5..d16e6cace3a3 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -27,6 +27,10 @@ impl OptimizerRule for SubqueryDecorrelate { // Apply optimizer rule to current input let optimized_input = self.optimize(input, optimizer_config)?; + for input in plan.inputs() { + println!("{}", input.display_indent()); + } + match predicate { // TODO: arbitrary expression trees, Expr::InSubQuery Expr::Exists { subquery,negated } => { From b281c8c8c3cefb09cfc3cf0c2d6826c5b132e8bd Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Tue, 28 Jun 2022 14:54:20 -0600 Subject: [PATCH 28/40] Successfully decorrelated query 4 --- datafusion/core/tests/sql/subqueries.rs | 58 +++++++++++++------ .../optimizer/src/subquery_decorrelate.rs | 39 ++++++++++--- 2 files changed, 71 insertions(+), 26 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 88ba43641cf2..372b15605cce 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -150,16 +150,39 @@ async fn tpch_q4_correlated() -> Result<()> { group by o_orderpriority order by o_orderpriority; "#; - let results = execute_to_batches(&ctx, sql).await; + // assert plan + let plan = ctx + .create_logical_plan(sql) + .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + let plan = ctx + .optimize(&plan) + .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + let actual = format!("{}", plan.display_indent()); + let expected = +r#"Sort: #orders.o_orderpriority ASC NULLS LAST + Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Inner Join: #orders.o_orderkey = #lineitem.l_orderkey + TableScan: orders projection=Some([o_orderkey, o_orderpriority]) + Projection: #lineitem.l_orderkey + Projection: #lineitem.l_orderkey + Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]] + TableScan: lineitem projection=Some([l_orderkey])"#.to_string(); + assert_eq!(expected, actual); + + // assert data + let results = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+---------+", - "| suppkey |", - "+---------+", - "| 7311 |", - "+---------+", + "+-----------------+-------------+", + "| o_orderpriority | order_count |", + "+-----------------+-------------+", + "| 1-URGENT | 1 |", + "| 5-LOW | 1 |", + "+-----------------+-------------+", ]; - assert_batches_eq!(expected, &results); Ok(()) @@ -172,21 +195,20 @@ async fn tpch_q4_decorrelated() -> Result<()> { register_tpch_csv(&ctx, "lineitem").await?; /* -#o.o_orderpriority ASC NULLS LAST - Projection: #o.o_orderpriority, #COUNT(UInt8(1)) AS order_count - Aggregate: groupBy=[[#o.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] - Inner Join: #o.o_orderkey = #l.l_orderkey - SubqueryAlias: o - TableScan: orders projection=Some([o_orderkey, o_orderpriority]) - Projection: #l.l_orderkey, alias=l - Projection: #lineitem.l_orderkey, alias=l +Sort: #orders.o_orderpriority ASC NULLS LAST + Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Inner Join: #orders.o_orderkey = #lineitem.l_orderkey + TableScan: orders projection=Some([o_orderkey, o_orderpriority]) + Projection: #lineitem.l_orderkey + Projection: #lineitem.l_orderkey Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]] TableScan: lineitem projection=Some([l_orderkey]) - */ + */ let sql = r#" select o_orderpriority, count(*) as order_count - from orders o - inner join ( select l_orderkey from lineitem group by l_orderkey ) l on l.l_orderkey = o_orderkey + from orders + inner join ( select l_orderkey from lineitem group by l_orderkey ) on l_orderkey = o_orderkey group by o_orderpriority order by o_orderpriority; "#; diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index d16e6cace3a3..4e47c32eb396 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -1,9 +1,9 @@ -use std::hash::Hash; -use std::sync::Arc; +use std::borrow::Borrow; use hashbrown::HashSet; -use datafusion_expr::logical_plan::Filter; +use datafusion_common::Column; +use datafusion_expr::logical_plan::{Filter, JoinType}; use crate::{OptimizerConfig, OptimizerRule, utils}; -use datafusion_expr::{Expr, LogicalPlan}; +use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; /// Optimizer rule for rewriting subquery filters to joins #[derive(Default)] @@ -25,7 +25,16 @@ impl OptimizerRule for SubqueryDecorrelate { match plan { LogicalPlan::Filter(Filter { predicate, input }) => { // Apply optimizer rule to current input - let optimized_input = self.optimize(input, optimizer_config)?; + // let optimized_input = self.optimize(input, optimizer_config)?; + + if plan.inputs().len() != 1 { + return Ok(plan.clone()); + } + let first = if let Some(f) = plan.inputs().get(0) { + (*f).clone() + } else { + return Ok(plan.clone()); + }; for input in plan.inputs() { println!("{}", input.display_indent()); @@ -34,9 +43,9 @@ impl OptimizerRule for SubqueryDecorrelate { match predicate { // TODO: arbitrary expression trees, Expr::InSubQuery Expr::Exists { subquery,negated } => { - let text = format!("{:?}", subquery); + let _text = format!("{:?}", subquery); match &*subquery.subquery { - LogicalPlan::Projection(proj) => { + LogicalPlan::Projection(_proj) => { println!("proj"); } _ => return Ok(plan.clone()) @@ -61,7 +70,21 @@ impl OptimizerRule for SubqueryDecorrelate { let found: Vec<_> = cols.intersection(&fields).map(|it| (*it).clone()).collect(); let closed_upon: Vec<_> = cols.difference(&fields).map(|it| (*it).clone()).collect(); - println!("{:?} {:?}", found, closed_upon); + println!("------{:?} {:?}", found, closed_upon); + + let a = vec![Column::from_qualified_name(closed_upon.get(0).unwrap())]; + let b = vec![Column::from_qualified_name(found.get(0).unwrap())]; + let c = found.get(0).unwrap().clone(); + let d = vec![Expr::Column(c.as_str().into())]; + let e: Vec = vec![]; + let r = LogicalPlanBuilder::from((*filter.input).clone()) + .aggregate(d, e).unwrap() + .project(vec![Expr::Column(c.as_str().into())]).unwrap() + .project(vec![Expr::Column(c.as_str().into())]).unwrap() + .build().unwrap(); + return LogicalPlanBuilder::from(first) + .join(&r, JoinType::Inner, (a.clone(), b.clone()), None).unwrap() + .build(); }, _ => return Ok(plan.clone()) } From 6a08eb1ab3924eb45766e7f3652991f053831a2c Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 13:10:29 -0600 Subject: [PATCH 29/40] refactor --- datafusion/core/tests/sql/mod.rs | 2 - datafusion/core/tests/sql/subqueries.rs | 82 ++--------- .../optimizer/src/subquery_decorrelate.rs | 135 ++++++++---------- 3 files changed, 71 insertions(+), 148 deletions(-) diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 6704d0553b51..fd5a189a9f66 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -649,7 +649,6 @@ async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec .map_err(|e| format!("{:?} at {}", e, msg)) .unwrap(); let logical_schema = plan.schema(); - println!("Logical {}", plan.display_indent()); // TODO: remove let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); let plan = ctx @@ -658,7 +657,6 @@ async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec .unwrap(); let optimized_logical_schema = plan.schema(); - println!("Optimized {}", plan.display_indent()); // TODO: remove let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); let plan = ctx .create_physical_plan(&plan) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 372b15605cce..af388dfd8d28 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -4,7 +4,7 @@ use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; /// https://github.com/apache/arrow-datafusion/issues/171 -#[tokio::test] +// #[tokio::test] async fn tpch_q20() -> Result<()> { let ctx = SessionContext::new(); register_tpch_csv(&ctx, "supplier").await?; @@ -42,7 +42,7 @@ async fn tpch_q20() -> Result<()> { Ok(()) } -#[tokio::test] +// #[tokio::test] async fn tpch_q20_correlated() -> Result<()> { let ctx = SessionContext::new(); register_tpch_csv(&ctx, "partsupp").await?; @@ -84,7 +84,7 @@ async fn tpch_q20_correlated() -> Result<()> { // 1. find references to outer scope (ps_partkey, ps_suppkey), if none, bail - not correlated // 2. remove correlated fields from filter // 3. add correlated fields as group by & to projection -#[tokio::test] +// #[tokio::test] async fn tpch_q20_decorrelated() -> Result<()> { let ctx = SessionContext::new(); register_tpch_csv(&ctx, "partsupp").await?; @@ -166,12 +166,12 @@ r#"Sort: #orders.o_orderpriority ASC NULLS LAST Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] Inner Join: #orders.o_orderkey = #lineitem.l_orderkey - TableScan: orders projection=Some([o_orderkey, o_orderpriority]) + TableScan: orders projection=[o_orderkey, o_orderpriority] Projection: #lineitem.l_orderkey Projection: #lineitem.l_orderkey Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]] - TableScan: lineitem projection=Some([l_orderkey])"#.to_string(); - assert_eq!(expected, actual); + TableScan: lineitem projection=[l_orderkey]"#.to_string(); + assert_eq!(actual, expected); // assert data let results = execute_to_batches(&ctx, sql).await; @@ -228,7 +228,7 @@ Sort: #orders.o_orderpriority ASC NULLS LAST Ok(()) } -#[tokio::test] +// #[tokio::test] async fn tpch_q17_correlated() -> Result<()> { let ctx = SessionContext::new(); register_tpch_csv(&ctx, "lineitem").await?; @@ -269,7 +269,7 @@ async fn tpch_q17_correlated() -> Result<()> { Ok(()) } -#[tokio::test] +// #[tokio::test] async fn tpch_q17_decorrelated() -> Result<()> { let ctx = SessionContext::new(); register_tpch_csv(&ctx, "lineitem").await?; @@ -311,69 +311,3 @@ async fn tpch_q17_decorrelated() -> Result<()> { Ok(()) } - -#[tokio::test] -async fn scalar_subquery() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select * from (values (1)) where column1 > ( select 0.5 );"; - let results = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+---------+", - "| c1 |", - "+---------+", - "| 0.00005 |", - "+---------+", - ]; - - assert_batches_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn filter_to_join() -> Result<()> { - let ctx = SessionContext::new(); - register_tpch_csv(&ctx, "customer").await?; - register_tpch_csv(&ctx, "nation").await?; - - /* - Sort: #customer.c_custkey ASC NULLS LAST - Projection: #customer.c_custkey - Filter: #customer.c_nationkey IN ( - Subquery: Projection: #nation.n_nationkey TableScan: nation projection=None - ) - TableScan: customer projection=None - */ - let sql = r#" - select c_custkey from customer - where c_nationkey in (select n_nationkey from nation) - order by c_custkey; - "#; - let results = execute_to_batches(&ctx, sql).await; - /* - Sort: #customer.c_custkey ASC NULLS LAST - Projection: #customer.c_custkey - Semi Join: #customer.c_nationkey = #nation.n_nationkey - TableScan: customer projection=Some([c_custkey, c_nationkey]) - Projection: #nation.n_nationkey - TableScan: nation projection=Some([n_nationkey]) - */ - - let expected = vec![ - "+-----------+", - "| c_custkey |", - "+-----------+", - "| 3 |", - "| 4 |", - "| 5 |", - "| 9 |", - "| 10 |", - "+-----------+", - ]; - - assert_batches_eq!(expected, &results); - - Ok(()) -} diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index 4e47c32eb396..f003eca76a6f 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -1,7 +1,7 @@ -use std::borrow::Borrow; +use std::sync::Arc; use hashbrown::HashSet; use datafusion_common::Column; -use datafusion_expr::logical_plan::{Filter, JoinType}; +use datafusion_expr::logical_plan::{Filter, JoinType, Subquery}; use crate::{OptimizerConfig, OptimizerRule, utils}; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; @@ -24,79 +24,13 @@ impl OptimizerRule for SubqueryDecorrelate { ) -> datafusion_common::Result { match plan { LogicalPlan::Filter(Filter { predicate, input }) => { - // Apply optimizer rule to current input - // let optimized_input = self.optimize(input, optimizer_config)?; - - if plan.inputs().len() != 1 { - return Ok(plan.clone()); - } - let first = if let Some(f) = plan.inputs().get(0) { - (*f).clone() - } else { - return Ok(plan.clone()); - }; - - for input in plan.inputs() { - println!("{}", input.display_indent()); - } - - match predicate { + return match predicate { // TODO: arbitrary expression trees, Expr::InSubQuery - Expr::Exists { subquery,negated } => { - let _text = format!("{:?}", subquery); - match &*subquery.subquery { - LogicalPlan::Projection(_proj) => { - println!("proj"); - } - _ => return Ok(plan.clone()) - } - for input in subquery.subquery.inputs() { - match input { - LogicalPlan::Filter(filter) => { - match &filter.predicate { - Expr::BinaryExpr { left, op, right } => { - let lcol = match &**left { - Expr::Column(col) => col, - _ => return Ok(plan.clone()) - }; - let rcol = match &**right { - Expr::Column(col) => col, - _ => return Ok(plan.clone()) - }; - let cols = vec![lcol, rcol]; - let cols: HashSet<_> = cols.iter().map(|c| &c.name).collect(); - let fields: HashSet<_> = input.schema().fields().iter().map(|f| f.name()).collect(); - - let found: Vec<_> = cols.intersection(&fields).map(|it| (*it).clone()).collect(); - let closed_upon: Vec<_> = cols.difference(&fields).map(|it| (*it).clone()).collect(); - - println!("------{:?} {:?}", found, closed_upon); - - let a = vec![Column::from_qualified_name(closed_upon.get(0).unwrap())]; - let b = vec![Column::from_qualified_name(found.get(0).unwrap())]; - let c = found.get(0).unwrap().clone(); - let d = vec![Expr::Column(c.as_str().into())]; - let e: Vec = vec![]; - let r = LogicalPlanBuilder::from((*filter.input).clone()) - .aggregate(d, e).unwrap() - .project(vec![Expr::Column(c.as_str().into())]).unwrap() - .project(vec![Expr::Column(c.as_str().into())]).unwrap() - .build().unwrap(); - return LogicalPlanBuilder::from(first) - .join(&r, JoinType::Inner, (a.clone(), b.clone()), None).unwrap() - .build(); - }, - _ => return Ok(plan.clone()) - } - }, - _ => return Ok(plan.clone()) - } - } + Expr::Exists { subquery, negated: _ } => { + optimize_exists(plan, subquery, input) }, - _ => return Ok(plan.clone()) + _ => Ok(plan.clone()) } - - return Ok(plan.clone()) }, _ => { // Apply the optimization to all inputs of the plan @@ -110,6 +44,63 @@ impl OptimizerRule for SubqueryDecorrelate { } } +fn optimize_exists( + plan: &LogicalPlan, + subquery: &Subquery, + input: &Arc, +) -> datafusion_common::Result { + let _text = format!("{:?}", subquery); + match &*subquery.subquery { + LogicalPlan::Projection(_proj) => { + println!("proj"); + } + _ => return Ok(plan.clone()) + } + for sub_input in subquery.subquery.inputs() { + match sub_input { + LogicalPlan::Filter(filter) => { + match &filter.predicate { + Expr::BinaryExpr { left, op: _, right } => { + let lcol = match &**left { + Expr::Column(col) => col, + _ => return Ok(plan.clone()) + }; + let rcol = match &**right { + Expr::Column(col) => col, + _ => return Ok(plan.clone()) + }; + let cols = vec![lcol, rcol]; + let cols: HashSet<_> = cols.iter().map(|c| &c.name).collect(); + let fields: HashSet<_> = sub_input.schema().fields().iter().map(|f| f.name()).collect(); + + let found: Vec<_> = cols.intersection(&fields).map(|it| (*it).clone()).collect(); + let closed_upon: Vec<_> = cols.difference(&fields).map(|it| (*it).clone()).collect(); + + println!("------{:?} {:?}", found, closed_upon); + + let a = vec![Column::from_qualified_name(closed_upon.get(0).unwrap())]; + let b = vec![Column::from_qualified_name(found.get(0).unwrap())]; + let c = found.get(0).unwrap().clone(); + let d = vec![Expr::Column(c.as_str().into())]; + let e: Vec = vec![]; + let r = LogicalPlanBuilder::from((*filter.input).clone()) + .aggregate(d, e).unwrap() + .project(vec![Expr::Column(c.as_str().into())]).unwrap() + .project(vec![Expr::Column(c.as_str().into())]).unwrap() + .build().unwrap(); + return LogicalPlanBuilder::from((**input).clone()) + .join(&r, JoinType::Inner, (a.clone(), b.clone()), None).unwrap() + .build(); + }, + _ => return Ok(plan.clone()) + } + }, + _ => return Ok(plan.clone()) + } + } + return Ok(plan.clone()); +} + #[cfg(test)] mod tests { use std::sync::Arc; From 7e025450e8d449c1bf97d17d5ce1addd0b00949c Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 13:41:11 -0600 Subject: [PATCH 30/40] Pass test 4 --- datafusion/core/tests/sql/subqueries.rs | 99 ++++++----- .../optimizer/src/subquery_decorrelate.rs | 154 +++++++++++------- 2 files changed, 140 insertions(+), 113 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index af388dfd8d28..6d965ab816d0 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -133,16 +133,16 @@ async fn tpch_q4_correlated() -> Result<()> { register_tpch_csv(&ctx, "lineitem").await?; /* -#orders.o_orderpriority ASC NULLS LAST - Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count - Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] - Filter: EXISTS ( -- plan - Subquery: Projection: * -- proj - Filter: #lineitem.l_orderkey = #orders.o_orderkey -- filter - TableScan: lineitem projection=None -- filter.input - ) - TableScan: orders projection=None -- plan.inputs - */ + #orders.o_orderpriority ASC NULLS LAST + Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Filter: EXISTS ( -- plan + Subquery: Projection: * -- proj + Filter: #lineitem.l_orderkey = #orders.o_orderkey -- filter + TableScan: lineitem projection=None -- filter.input + ) + TableScan: orders projection=None -- plan.inputs + */ let sql = r#" select o_orderpriority, count(*) as order_count from orders @@ -161,16 +161,15 @@ async fn tpch_q4_correlated() -> Result<()> { .map_err(|e| format!("{:?} at {}", e, "error")) .unwrap(); let actual = format!("{}", plan.display_indent()); - let expected = -r#"Sort: #orders.o_orderpriority ASC NULLS LAST + let expected = r#"Sort: #orders.o_orderpriority ASC NULLS LAST Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] Inner Join: #orders.o_orderkey = #lineitem.l_orderkey TableScan: orders projection=[o_orderkey, o_orderpriority] Projection: #lineitem.l_orderkey - Projection: #lineitem.l_orderkey - Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]] - TableScan: lineitem projection=[l_orderkey]"#.to_string(); + Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]] + TableScan: lineitem projection=[l_orderkey]"# + .to_string(); assert_eq!(actual, expected); // assert data @@ -195,16 +194,16 @@ async fn tpch_q4_decorrelated() -> Result<()> { register_tpch_csv(&ctx, "lineitem").await?; /* -Sort: #orders.o_orderpriority ASC NULLS LAST - Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count - Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] - Inner Join: #orders.o_orderkey = #lineitem.l_orderkey - TableScan: orders projection=Some([o_orderkey, o_orderpriority]) - Projection: #lineitem.l_orderkey - Projection: #lineitem.l_orderkey - Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]] - TableScan: lineitem projection=Some([l_orderkey]) - */ + Sort: #orders.o_orderpriority ASC NULLS LAST + Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Inner Join: #orders.o_orderkey = #lineitem.l_orderkey + TableScan: orders projection=Some([o_orderkey, o_orderpriority]) + Projection: #lineitem.l_orderkey + Projection: #lineitem.l_orderkey + Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]] + TableScan: lineitem projection=Some([l_orderkey]) + */ let sql = r#" select o_orderpriority, count(*) as order_count from orders @@ -235,18 +234,18 @@ async fn tpch_q17_correlated() -> Result<()> { register_tpch_csv(&ctx, "part").await?; /* -#SUM(lineitem.l_extendedprice) / Float64(7) AS avg_yearly - Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]] - Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = Utf8("MED BOX") AND #lineitem.l_quantity < ( - Subquery: Projection: Float64(0.2) * #AVG(lineitem.l_quantity) - Aggregate: groupBy=[[]], aggr=[[AVG(#lineitem.l_quantity)]] - Filter: #lineitem.l_partkey = #part.p_partkey - TableScan: lineitem projection=None - ) - Inner Join: #lineitem.l_partkey = #part.p_partkey - TableScan: lineitem projection=None - TableScan: part projection=None - */ + #SUM(lineitem.l_extendedprice) / Float64(7) AS avg_yearly + Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]] + Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = Utf8("MED BOX") AND #lineitem.l_quantity < ( + Subquery: Projection: Float64(0.2) * #AVG(lineitem.l_quantity) + Aggregate: groupBy=[[]], aggr=[[AVG(#lineitem.l_quantity)]] + Filter: #lineitem.l_partkey = #part.p_partkey + TableScan: lineitem projection=None + ) + Inner Join: #lineitem.l_partkey = #part.p_partkey + TableScan: lineitem projection=None + TableScan: part projection=None + */ let sql = r#" select sum(l_extendedprice) / 7.0 as avg_yearly from lineitem, part @@ -276,19 +275,19 @@ async fn tpch_q17_decorrelated() -> Result<()> { register_tpch_csv(&ctx, "part").await?; /* - #SUM(lineitem.l_extendedprice) / Float64(7) AS avg_yearly - Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]] - Filter: #lineitem.l_quantity < #li.qty - Inner Join: #part.p_partkey = #li.l_partkey - Inner Join: #lineitem.l_partkey = #part.p_partkey - TableScan: lineitem projection=Some([l_partkey, l_quantity, l_extendedprice]) - Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = Utf8("MED BOX") - TableScan: part projection=Some([p_partkey, p_brand, p_container]), partial_filters=[#part.p_brand = Utf8("Brand#23"), #part.p_container = Utf8("MED BOX")] - Projection: #li.l_partkey, #li.qty, alias=li - Projection: #lineitem.l_partkey, Float64(0.2) * #AVG(lineitem.l_quantity) AS qty, alias=li - Aggregate: groupBy=[[#lineitem.l_partkey]], aggr=[[AVG(#lineitem.l_quantity)]] - TableScan: lineitem projection=Some([l_partkey, l_quantity, l_extendedprice]) - */ + #SUM(lineitem.l_extendedprice) / Float64(7) AS avg_yearly + Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]] + Filter: #lineitem.l_quantity < #li.qty + Inner Join: #part.p_partkey = #li.l_partkey + Inner Join: #lineitem.l_partkey = #part.p_partkey + TableScan: lineitem projection=Some([l_partkey, l_quantity, l_extendedprice]) + Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = Utf8("MED BOX") + TableScan: part projection=Some([p_partkey, p_brand, p_container]), partial_filters=[#part.p_brand = Utf8("Brand#23"), #part.p_container = Utf8("MED BOX")] + Projection: #li.l_partkey, #li.qty, alias=li + Projection: #lineitem.l_partkey, Float64(0.2) * #AVG(lineitem.l_quantity) AS qty, alias=li + Aggregate: groupBy=[[#lineitem.l_partkey]], aggr=[[AVG(#lineitem.l_quantity)]] + TableScan: lineitem projection=Some([l_partkey, l_quantity, l_extendedprice]) + */ let sql = r#" select sum(l_extendedprice) / 7.0 as avg_yearly from lineitem diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index f003eca76a6f..9086e6cf8c44 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -1,9 +1,9 @@ -use std::sync::Arc; -use hashbrown::HashSet; +use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::Column; use datafusion_expr::logical_plan::{Filter, JoinType, Subquery}; -use crate::{OptimizerConfig, OptimizerRule, utils}; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +use hashbrown::HashSet; +use std::sync::Arc; /// Optimizer rule for rewriting subquery filters to joins #[derive(Default)] @@ -26,12 +26,13 @@ impl OptimizerRule for SubqueryDecorrelate { LogicalPlan::Filter(Filter { predicate, input }) => { return match predicate { // TODO: arbitrary expression trees, Expr::InSubQuery - Expr::Exists { subquery, negated: _ } => { - optimize_exists(plan, subquery, input) - }, - _ => Ok(plan.clone()) - } - }, + Expr::Exists { + subquery, + negated: _, + } => optimize_exists(plan, subquery, input), + _ => Ok(plan.clone()), + }; + } _ => { // Apply the optimization to all inputs of the plan utils::optimize_children(self, plan, optimizer_config) @@ -49,66 +50,92 @@ fn optimize_exists( subquery: &Subquery, input: &Arc, ) -> datafusion_common::Result { - let _text = format!("{:?}", subquery); - match &*subquery.subquery { - LogicalPlan::Projection(_proj) => { - println!("proj"); - } - _ => return Ok(plan.clone()) + // Only operate if there is one input + let sub_inputs = subquery.subquery.inputs(); + if sub_inputs.len() != 1 { + return Ok(plan.clone()); } - for sub_input in subquery.subquery.inputs() { - match sub_input { - LogicalPlan::Filter(filter) => { - match &filter.predicate { - Expr::BinaryExpr { left, op: _, right } => { - let lcol = match &**left { - Expr::Column(col) => col, - _ => return Ok(plan.clone()) - }; - let rcol = match &**right { - Expr::Column(col) => col, - _ => return Ok(plan.clone()) - }; - let cols = vec![lcol, rcol]; - let cols: HashSet<_> = cols.iter().map(|c| &c.name).collect(); - let fields: HashSet<_> = sub_input.schema().fields().iter().map(|f| f.name()).collect(); - - let found: Vec<_> = cols.intersection(&fields).map(|it| (*it).clone()).collect(); - let closed_upon: Vec<_> = cols.difference(&fields).map(|it| (*it).clone()).collect(); - - println!("------{:?} {:?}", found, closed_upon); - - let a = vec![Column::from_qualified_name(closed_upon.get(0).unwrap())]; - let b = vec![Column::from_qualified_name(found.get(0).unwrap())]; - let c = found.get(0).unwrap().clone(); - let d = vec![Expr::Column(c.as_str().into())]; - let e: Vec = vec![]; - let r = LogicalPlanBuilder::from((*filter.input).clone()) - .aggregate(d, e).unwrap() - .project(vec![Expr::Column(c.as_str().into())]).unwrap() - .project(vec![Expr::Column(c.as_str().into())]).unwrap() - .build().unwrap(); - return LogicalPlanBuilder::from((**input).clone()) - .join(&r, JoinType::Inner, (a.clone(), b.clone()), None).unwrap() - .build(); - }, - _ => return Ok(plan.clone()) - } - }, - _ => return Ok(plan.clone()) - } + let sub_input = if let Some(i) = sub_inputs.get(0) { + i + } else { + return Ok(plan.clone()); + }; + + // Only operate on subqueries that are trying to filter on an expression from an outer query + let filter = if let LogicalPlan::Filter(f) = sub_input { + f + } else { + return Ok(plan.clone()); + }; + + // Only operate on a single binary expression (for now) + let (left, _op, right) = + if let Expr::BinaryExpr { left, op, right } = &filter.predicate { + (left, op, right) + } else { + return Ok(plan.clone()); + }; + + // collect list of columns + let lcol = match &**left { + Expr::Column(col) => col, + _ => return Ok(plan.clone()), + }; + let rcol = match &**right { + Expr::Column(col) => col, + _ => return Ok(plan.clone()), + }; + let cols = vec![lcol, rcol]; + let cols: HashSet<_> = cols.iter().map(|c| &c.name).collect(); + let fields: HashSet<_> = sub_input + .schema() + .fields() + .iter() + .map(|f| f.name()) + .collect(); + + // Only operate if one column is present and the other closed upon from outside scope + let found: Vec<_> = cols.intersection(&fields).map(|it| (*it).clone()).collect(); + let closed_upon: Vec<_> = cols.difference(&fields).map(|it| (*it).clone()).collect(); + if found.len() != 1 || closed_upon.len() != 1 { + return Ok(plan.clone()); } - return Ok(plan.clone()); + let found = if let Some(it) = found.get(0) { + it + } else { + return Ok(plan.clone()); + }; + let closed_upon = if let Some(it) = closed_upon.get(0) { + it + } else { + return Ok(plan.clone()); + }; + + let c_col = vec![Column::from_qualified_name(closed_upon)]; + let f_col = vec![Column::from_qualified_name(found)]; + let expr = vec![Expr::Column(found.as_str().into())]; + let group_expr = vec![Expr::Column(found.as_str().into())]; + let aggr_expr: Vec = vec![]; + let join_keys = (c_col.clone(), f_col.clone()); + let right = LogicalPlanBuilder::from((*filter.input).clone()) + .aggregate(group_expr, aggr_expr) + .unwrap() + .project(expr) + .unwrap() + .build() + .unwrap(); + return LogicalPlanBuilder::from((**input).clone()) + .join(&right, JoinType::Inner, join_keys, None) + .unwrap() + .build(); } #[cfg(test)] mod tests { - use std::sync::Arc; use super::*; use crate::test::*; - use datafusion_expr::{ - col, in_subquery, logical_plan::LogicalPlanBuilder, - }; + use datafusion_expr::{col, in_subquery, logical_plan::LogicalPlanBuilder}; + use std::sync::Arc; #[test] fn in_subquery_simple() -> datafusion_common::Result<()> { @@ -129,7 +156,9 @@ mod tests { } // TODO: deduplicate with subquery_filter_to_join - fn test_subquery_with_name(name: &str) -> datafusion_common::Result> { + fn test_subquery_with_name( + name: &str, + ) -> datafusion_common::Result> { let table_scan = test_table_scan_with_name(name)?; Ok(Arc::new( LogicalPlanBuilder::from(table_scan) @@ -147,5 +176,4 @@ mod tests { let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); assert_eq!(formatted_plan, expected); } - } From ea3f219afbbb651d9d9f07260eadfe60b5a6b92f Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 13:58:50 -0600 Subject: [PATCH 31/40] Ready for PR? --- datafusion/core/tests/sql/subqueries.rs | 247 ------------------ .../optimizer/src/subquery_decorrelate.rs | 85 ++---- 2 files changed, 22 insertions(+), 310 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 6d965ab816d0..758a8408ec91 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -3,129 +3,6 @@ use crate::sql::execute_to_batches; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; -/// https://github.com/apache/arrow-datafusion/issues/171 -// #[tokio::test] -async fn tpch_q20() -> Result<()> { - let ctx = SessionContext::new(); - register_tpch_csv(&ctx, "supplier").await?; - register_tpch_csv(&ctx, "nation").await?; - register_tpch_csv(&ctx, "partsupp").await?; - register_tpch_csv(&ctx, "part").await?; - register_tpch_csv(&ctx, "lineitem").await?; - - let sql = r#" - select s_name, s_address - from supplier, nation - where s_suppkey in ( - select ps_suppkey from partsupp - where ps_partkey in ( select p_partkey from part where p_name like 'forest%' ) - and ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem - where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate >= date '1994-01-01' - and l_shipdate < date '1994-01-01' + interval '1' year - ) - ) - and s_nationkey = n_nationkey and n_name = 'CANADA' - order by s_name; - "#; - let results = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+---------+", - "| c1 |", - "+---------+", - "| 0.00005 |", - "+---------+", - ]; - - assert_batches_eq!(expected, &results); - - Ok(()) -} - -// #[tokio::test] -async fn tpch_q20_correlated() -> Result<()> { - let ctx = SessionContext::new(); - register_tpch_csv(&ctx, "partsupp").await?; - register_tpch_csv(&ctx, "lineitem").await?; - - /* - #partsupp.ps_suppkey ASC NULLS LAST - Projection: #partsupp.ps_suppkey - Filter: #partsupp.ps_availqty > ( - Subquery: Projection: Float64(0.5) * #SUM(lineitem.l_quantity) - Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_quantity)]] - Filter: #lineitem.l_partkey = #partsupp.ps_partkey AND #lineitem.l_suppkey = #partsupp.ps_suppkey - TableScan: lineitem projection=None - ) - TableScan: partsupp projection=None - */ - let sql = r#" - select ps_suppkey from partsupp - where ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem - where l_partkey = ps_partkey and l_suppkey = ps_suppkey - ) order by ps_suppkey; - "#; - let results = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+---------+", - "| suppkey |", - "+---------+", - "| 7311 |", - "+---------+", - ]; - - assert_batches_eq!(expected, &results); - - Ok(()) -} - -// 0. recurse down to most deeply nested subquery -// 1. find references to outer scope (ps_partkey, ps_suppkey), if none, bail - not correlated -// 2. remove correlated fields from filter -// 3. add correlated fields as group by & to projection -// #[tokio::test] -async fn tpch_q20_decorrelated() -> Result<()> { - let ctx = SessionContext::new(); - register_tpch_csv(&ctx, "partsupp").await?; - register_tpch_csv(&ctx, "lineitem").await?; - - /* - #suppkey - Sort: #ps.ps_suppkey ASC NULLS LAST - Projection: #ps.ps_suppkey AS suppkey, #ps.ps_suppkey - Inner Join: #ps.ps_suppkey = #av.l_suppkey, #ps.ps_partkey = #av.l_partkey Filter: #ps.ps_availqty > #av.threshold - SubqueryAlias: ps - TableScan: partsupp projection=Some([ps_partkey, ps_suppkey, ps_availqty]) - Projection: #av.l_partkey, #av.l_suppkey, #av.threshold, alias=av - Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Float64(0.5) * #SUM(lineitem.l_quantity) AS threshold, alias=av - Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]] - TableScan: lineitem projection=Some([l_partkey, l_suppkey, l_quantity]) - */ - let sql = r#" - select ps_suppkey as suppkey - from partsupp ps - inner join ( - select l_partkey, l_suppkey, 0.5 * sum(l_quantity) as threshold from lineitem - group by l_partkey, l_suppkey - ) av on av.l_suppkey=ps.ps_suppkey and av.l_partkey=ps.ps_partkey and ps.ps_availqty > av.threshold - order by ps_suppkey; - "#; - let results = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+---------+", - "| suppkey |", - "+---------+", - "| 7311 |", - "+---------+", - ]; - - assert_batches_eq!(expected, &results); - - Ok(()) -} - #[tokio::test] async fn tpch_q4_correlated() -> Result<()> { let ctx = SessionContext::new(); @@ -186,127 +63,3 @@ async fn tpch_q4_correlated() -> Result<()> { Ok(()) } - -#[tokio::test] -async fn tpch_q4_decorrelated() -> Result<()> { - let ctx = SessionContext::new(); - register_tpch_csv(&ctx, "orders").await?; - register_tpch_csv(&ctx, "lineitem").await?; - - /* - Sort: #orders.o_orderpriority ASC NULLS LAST - Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count - Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] - Inner Join: #orders.o_orderkey = #lineitem.l_orderkey - TableScan: orders projection=Some([o_orderkey, o_orderpriority]) - Projection: #lineitem.l_orderkey - Projection: #lineitem.l_orderkey - Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]] - TableScan: lineitem projection=Some([l_orderkey]) - */ - let sql = r#" - select o_orderpriority, count(*) as order_count - from orders - inner join ( select l_orderkey from lineitem group by l_orderkey ) on l_orderkey = o_orderkey - group by o_orderpriority - order by o_orderpriority; - "#; - let results = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+-----------------+-------------+", - "| o_orderpriority | order_count |", - "+-----------------+-------------+", - "| 1-URGENT | 1 |", - "| 5-LOW | 1 |", - "+-----------------+-------------+", - ]; - - assert_batches_eq!(expected, &results); - - Ok(()) -} - -// #[tokio::test] -async fn tpch_q17_correlated() -> Result<()> { - let ctx = SessionContext::new(); - register_tpch_csv(&ctx, "lineitem").await?; - register_tpch_csv(&ctx, "part").await?; - - /* - #SUM(lineitem.l_extendedprice) / Float64(7) AS avg_yearly - Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]] - Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = Utf8("MED BOX") AND #lineitem.l_quantity < ( - Subquery: Projection: Float64(0.2) * #AVG(lineitem.l_quantity) - Aggregate: groupBy=[[]], aggr=[[AVG(#lineitem.l_quantity)]] - Filter: #lineitem.l_partkey = #part.p_partkey - TableScan: lineitem projection=None - ) - Inner Join: #lineitem.l_partkey = #part.p_partkey - TableScan: lineitem projection=None - TableScan: part projection=None - */ - let sql = r#" - select sum(l_extendedprice) / 7.0 as avg_yearly - from lineitem, part - where p_partkey = l_partkey and p_brand = 'Brand#23' and p_container = 'MED BOX' - and l_quantity < ( select 0.2 * avg(l_quantity) from lineitem where l_partkey = p_partkey - ); - "#; - let results = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+---------+", - "| suppkey |", - "+---------+", - "| 7311 |", - "+---------+", - ]; - - assert_batches_eq!(expected, &results); - - Ok(()) -} - -// #[tokio::test] -async fn tpch_q17_decorrelated() -> Result<()> { - let ctx = SessionContext::new(); - register_tpch_csv(&ctx, "lineitem").await?; - register_tpch_csv(&ctx, "part").await?; - - /* - #SUM(lineitem.l_extendedprice) / Float64(7) AS avg_yearly - Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]] - Filter: #lineitem.l_quantity < #li.qty - Inner Join: #part.p_partkey = #li.l_partkey - Inner Join: #lineitem.l_partkey = #part.p_partkey - TableScan: lineitem projection=Some([l_partkey, l_quantity, l_extendedprice]) - Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = Utf8("MED BOX") - TableScan: part projection=Some([p_partkey, p_brand, p_container]), partial_filters=[#part.p_brand = Utf8("Brand#23"), #part.p_container = Utf8("MED BOX")] - Projection: #li.l_partkey, #li.qty, alias=li - Projection: #lineitem.l_partkey, Float64(0.2) * #AVG(lineitem.l_quantity) AS qty, alias=li - Aggregate: groupBy=[[#lineitem.l_partkey]], aggr=[[AVG(#lineitem.l_quantity)]] - TableScan: lineitem projection=Some([l_partkey, l_quantity, l_extendedprice]) - */ - let sql = r#" - select sum(l_extendedprice) / 7.0 as avg_yearly - from lineitem - inner join part on p_partkey = l_partkey - inner join ( select l_partkey, 0.2 * avg(l_quantity) as qty from lineitem group by l_partkey - ) li on li.l_partkey = p_partkey - where p_brand = 'Brand#23' and p_container = 'MED BOX' and l_quantity < li.qty; - "#; - let results = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+------------+", - "| avg_yearly |", - "+------------+", - "| |", - "+------------+", - ]; - - assert_batches_eq!(expected, &results); - - Ok(()) -} diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index 9086e6cf8c44..fbde4bdaf622 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -25,11 +25,13 @@ impl OptimizerRule for SubqueryDecorrelate { match plan { LogicalPlan::Filter(Filter { predicate, input }) => { return match predicate { - // TODO: arbitrary expression trees, Expr::InSubQuery - Expr::Exists { - subquery, - negated: _, - } => optimize_exists(plan, subquery, input), + // TODO: arbitrary expressions + Expr::Exists { subquery, negated } => { + if *negated { + return Ok(plan.clone()); + } + optimize_exists(plan, subquery, input) + } _ => Ok(plan.clone()), }; } @@ -45,6 +47,14 @@ impl OptimizerRule for SubqueryDecorrelate { } } +/// Takes a query like: +/// +/// select c.id from customers c where exists (select * from orders o where o.c_id = c.id) +/// +/// and optimizes it into: +/// +/// select c.id from customers c +/// inner join (select o.c_id from orders o group by o.c_id) o on o.c_id = c.c_id fn optimize_exists( plan: &LogicalPlan, subquery: &Subquery, @@ -118,62 +128,11 @@ fn optimize_exists( let aggr_expr: Vec = vec![]; let join_keys = (c_col.clone(), f_col.clone()); let right = LogicalPlanBuilder::from((*filter.input).clone()) - .aggregate(group_expr, aggr_expr) - .unwrap() - .project(expr) - .unwrap() - .build() - .unwrap(); - return LogicalPlanBuilder::from((**input).clone()) - .join(&right, JoinType::Inner, join_keys, None) - .unwrap() - .build(); -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test::*; - use datafusion_expr::{col, in_subquery, logical_plan::LogicalPlanBuilder}; - use std::sync::Arc; - - #[test] - fn in_subquery_simple() -> datafusion_common::Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(in_subquery(col("c"), test_subquery_with_name("sq")?))? - .project(vec![col("test.b")])? - .build()?; - - let expected = "Projection: #test.b [b:UInt32]\ - \n Semi Join: #test.c = #sq.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]\ - \n Projection: #sq.c [c:UInt32]\ - \n TableScan: sq projection=None [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq(&plan, expected); - Ok(()) - } - - // TODO: deduplicate with subquery_filter_to_join - fn test_subquery_with_name( - name: &str, - ) -> datafusion_common::Result> { - let table_scan = test_table_scan_with_name(name)?; - Ok(Arc::new( - LogicalPlanBuilder::from(table_scan) - .project(vec![col("c")])? - .build()?, - )) - } - - // TODO: deduplicate with subquery_filter_to_join - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let rule = SubqueryDecorrelate::new(); - let optimized_plan = rule - .optimize(plan, &OptimizerConfig::new()) - .expect("failed to optimize plan"); - let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); - assert_eq!(formatted_plan, expected); - } + .aggregate(group_expr, aggr_expr)? + .project(expr)? + .build()?; + let new_plan = LogicalPlanBuilder::from((**input).clone()) + .join(&right, JoinType::Inner, join_keys, None)? + .build()?; + Ok(new_plan) } From 50b35498db4b5d0cdd2058533bab640a03e775ea Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 14:17:41 -0600 Subject: [PATCH 32/40] Only operate on equality expressions --- datafusion/optimizer/src/subquery_decorrelate.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index fbde4bdaf622..757ebac14ffa 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -1,7 +1,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::Column; use datafusion_expr::logical_plan::{Filter, JoinType, Subquery}; -use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Operator}; use hashbrown::HashSet; use std::sync::Arc; @@ -78,13 +78,17 @@ fn optimize_exists( return Ok(plan.clone()); }; - // Only operate on a single binary expression (for now) - let (left, _op, right) = + // Only operate on a single binary equality expression (for now) + let (left, op, right) = if let Expr::BinaryExpr { left, op, right } = &filter.predicate { (left, op, right) } else { return Ok(plan.clone()); }; + match op { + Operator::Eq => {}, + _ => return Ok(plan.clone()) + } // collect list of columns let lcol = match &**left { From f90d95af5d9c3d8ab700a9d8361afff3dd7e0847 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 14:21:30 -0600 Subject: [PATCH 33/40] Lint error --- datafusion/optimizer/src/subquery_decorrelate.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index 757ebac14ffa..ef0516f73575 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -86,8 +86,8 @@ fn optimize_exists( return Ok(plan.clone()); }; match op { - Operator::Eq => {}, - _ => return Ok(plan.clone()) + Operator::Eq => {} + _ => return Ok(plan.clone()), } // collect list of columns From 9377cdf30db3f665d2e642113daa85607a800626 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 16:33:26 -0600 Subject: [PATCH 34/40] Tests still pass because we are losing remaining predicate --- datafusion/core/tests/sql/subqueries.rs | 3 +- datafusion/core/tests/tpch-csv/region.csv | 0 .../optimizer/src/subquery_decorrelate.rs | 121 ++++++++++-------- 3 files changed, 73 insertions(+), 51 deletions(-) create mode 100644 datafusion/core/tests/tpch-csv/region.csv diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 758a8408ec91..9270f35f5de7 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -23,7 +23,8 @@ async fn tpch_q4_correlated() -> Result<()> { let sql = r#" select o_orderpriority, count(*) as order_count from orders - where exists ( select * from lineitem where l_orderkey = o_orderkey ) + where exists ( + select * from lineitem where l_orderkey = o_orderkey and l_commitdate < l_receiptdate) group by o_orderpriority order by o_orderpriority; "#; diff --git a/datafusion/core/tests/tpch-csv/region.csv b/datafusion/core/tests/tpch-csv/region.csv new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index ef0516f73575..b1cae40521d1 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -1,5 +1,5 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; -use datafusion_common::Column; +use datafusion_common::{Column, DFSchemaRef}; use datafusion_expr::logical_plan::{Filter, JoinType, Subquery}; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Operator}; use hashbrown::HashSet; @@ -78,61 +78,30 @@ fn optimize_exists( return Ok(plan.clone()); }; - // Only operate on a single binary equality expression (for now) - let (left, op, right) = - if let Expr::BinaryExpr { left, op, right } = &filter.predicate { - (left, op, right) - } else { - return Ok(plan.clone()); - }; - match op { - Operator::Eq => {} - _ => return Ok(plan.clone()), - } - - // collect list of columns - let lcol = match &**left { - Expr::Column(col) => col, - _ => return Ok(plan.clone()), - }; - let rcol = match &**right { - Expr::Column(col) => col, - _ => return Ok(plan.clone()), - }; - let cols = vec![lcol, rcol]; - let cols: HashSet<_> = cols.iter().map(|c| &c.name).collect(); - let fields: HashSet<_> = sub_input - .schema() - .fields() - .iter() - .map(|f| f.name()) - .collect(); + // split into filters + let mut filters = vec![]; + utils::split_conjunction(&filter.predicate, &mut filters); - // Only operate if one column is present and the other closed upon from outside scope - let found: Vec<_> = cols.intersection(&fields).map(|it| (*it).clone()).collect(); - let closed_upon: Vec<_> = cols.difference(&fields).map(|it| (*it).clone()).collect(); - if found.len() != 1 || closed_upon.len() != 1 { + // Grab column names to join on + let cols = find_join_exprs(filters, sub_input.schema()); + if cols.is_empty() { return Ok(plan.clone()); } - let found = if let Some(it) = found.get(0) { - it - } else { - return Ok(plan.clone()); - }; - let closed_upon = if let Some(it) = closed_upon.get(0) { - it - } else { - return Ok(plan.clone()); - }; - let c_col = vec![Column::from_qualified_name(closed_upon)]; - let f_col = vec![Column::from_qualified_name(found)]; - let expr = vec![Expr::Column(found.as_str().into())]; - let group_expr = vec![Expr::Column(found.as_str().into())]; + // Only operate if one column is present and the other closed upon from outside scope + let l_col: Vec<_> = cols.iter() + .map(|it| &it.0) + .map(|it| Column::from_qualified_name(it.as_str())) + .collect(); + let r_col: Vec<_> = cols.iter() + .map(|it| &it.1) + .map(|it| Column::from_qualified_name(it.as_str())) + .collect(); + let expr: Vec<_> = r_col.iter().map(|it| Expr::Column(it.clone())).collect(); let aggr_expr: Vec = vec![]; - let join_keys = (c_col.clone(), f_col.clone()); + let join_keys = (l_col.clone(), r_col.clone()); let right = LogicalPlanBuilder::from((*filter.input).clone()) - .aggregate(group_expr, aggr_expr)? + .aggregate(expr.clone(), aggr_expr)? .project(expr)? .build()?; let new_plan = LogicalPlanBuilder::from((**input).clone()) @@ -140,3 +109,55 @@ fn optimize_exists( .build()?; Ok(new_plan) } + +fn find_join_exprs(filters: Vec<&Expr>, schema: &DFSchemaRef) -> Vec<(String, String)> { + // only process equals expressions for joins + let equals: Vec<_> = filters.iter().map(|it| { + match it { + Expr::BinaryExpr { left, op, right } => { + match op { + Operator::Eq => Some((*left.clone(), *right.clone())), + _ => None, + } + } + _ => None + } + }).flatten().collect(); + + // only process column expressions for joins + let cols: Vec<_> = equals.iter().map(|it| { + let l = match &it.0 { + Expr::Column(col) => col, + _ => return None, + }; + let r = match &it.1 { + Expr::Column(col) => col, + _ => return None, + }; + Some((l.name.clone(), r.name.clone())) + }).flatten().collect(); + + // get names of fields TODO: Must fully qualify these! + let fields: HashSet<_> = schema + .fields() + .iter() + .map(|f| f.name()) + .collect(); + + // Ensure closed-upon fields are always on left, and in-scope on the right + let sorted: Vec<_> = cols.iter().map(|it| { + if fields.contains(&it.0) && fields.contains(&it.1) { + return None; // Need one of each + } + if !fields.contains(&it.0) && !fields.contains(&it.1) { + return None; // Need one of each + } + if fields.contains(&it.0) { + Some((it.1.clone(), it.0.clone())) + } else { + Some((it.0.clone(), it.1.clone())) + } + }).flatten().collect(); + + sorted +} \ No newline at end of file From 23b0ffb2f7fab28fd9ed025b8908a25b0abc257a Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 17:26:02 -0600 Subject: [PATCH 35/40] Don't lose remaining expressions --- datafusion/optimizer/Cargo.toml | 1 + .../optimizer/src/subquery_decorrelate.rs | 97 ++++++++++--------- 2 files changed, 50 insertions(+), 48 deletions(-) diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 4bd3868793fe..51b41a1c5291 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -44,4 +44,5 @@ datafusion-common = { path = "../common", version = "9.0.0" } datafusion-expr = { path = "../expr", version = "9.0.0" } datafusion-physical-expr = { path = "../physical-expr", version = "9.0.0" } hashbrown = { version = "0.12", features = ["raw"] } +itertools = "0.10" log = "^0.4" diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index b1cae40521d1..025f0eb4d308 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -1,9 +1,10 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, DFSchemaRef}; +use datafusion_common::{Column}; use datafusion_expr::logical_plan::{Filter, JoinType, Subquery}; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Operator}; use hashbrown::HashSet; use std::sync::Arc; +use itertools::{Either, Itertools}; /// Optimizer rule for rewriting subquery filters to joins #[derive(Default)] @@ -82,10 +83,17 @@ fn optimize_exists( let mut filters = vec![]; utils::split_conjunction(&filter.predicate, &mut filters); + // get names of fields TODO: Must fully qualify these! + let fields: HashSet<_> = sub_input.schema() + .fields() + .iter() + .map(|f| f.name()) + .collect(); + // Grab column names to join on - let cols = find_join_exprs(filters, sub_input.schema()); + let (cols, others) = find_join_exprs(filters, &fields); if cols.is_empty() { - return Ok(plan.clone()); + return Ok(plan.clone()); // no joins found } // Only operate if one column is present and the other closed upon from outside scope @@ -110,54 +118,47 @@ fn optimize_exists( Ok(new_plan) } -fn find_join_exprs(filters: Vec<&Expr>, schema: &DFSchemaRef) -> Vec<(String, String)> { - // only process equals expressions for joins - let equals: Vec<_> = filters.iter().map(|it| { - match it { - Expr::BinaryExpr { left, op, right } => { - match op { - Operator::Eq => Some((*left.clone(), *right.clone())), - _ => None, +fn find_join_exprs( + filters: Vec<&Expr>, + fields: &HashSet<&String>, +) -> (Vec<(String, String)>, Vec) { + let (joins, others): (Vec<_>, Vec<_>) = filters.iter() + .partition_map(|filter| { + let (left, op, right) = match filter { + Expr::BinaryExpr { left, op, right } => { + (*left.clone(), op.clone(), *right.clone()) + } + _ => { + return Either::Right((*filter).clone()) } + }; + match op { + Operator::Eq => {} + _ => return Either::Right((*filter).clone()), + } + let left = match left { + Expr::Column(c) => c, + _ => return Either::Right((*filter).clone()), + }; + let right = match right { + Expr::Column(c) => c, + _ => return Either::Right((*filter).clone()), + }; + if fields.contains(&left.name) && fields.contains(&right.name) { + return Either::Right((*filter).clone()); // Need one of each + } + if !fields.contains(&left.name) && !fields.contains(&right.name) { + return Either::Right((*filter).clone()); // Need one of each } - _ => None - } - }).flatten().collect(); - - // only process column expressions for joins - let cols: Vec<_> = equals.iter().map(|it| { - let l = match &it.0 { - Expr::Column(col) => col, - _ => return None, - }; - let r = match &it.1 { - Expr::Column(col) => col, - _ => return None, - }; - Some((l.name.clone(), r.name.clone())) - }).flatten().collect(); - // get names of fields TODO: Must fully qualify these! - let fields: HashSet<_> = schema - .fields() - .iter() - .map(|f| f.name()) - .collect(); + let sorted = if fields.contains(&left.name) { + (right.name.clone(), left.name.clone()) + } else { + (left.name.clone(), right.name.clone()) + }; - // Ensure closed-upon fields are always on left, and in-scope on the right - let sorted: Vec<_> = cols.iter().map(|it| { - if fields.contains(&it.0) && fields.contains(&it.1) { - return None; // Need one of each - } - if !fields.contains(&it.0) && !fields.contains(&it.1) { - return None; // Need one of each - } - if fields.contains(&it.0) { - Some((it.1.clone(), it.0.clone())) - } else { - Some((it.0.clone(), it.1.clone())) - } - }).flatten().collect(); + Either::Left(sorted) + }); - sorted + (joins, others) } \ No newline at end of file From 858b284e149a9f9e4885da56fc46432dd2af0a23 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 29 Jun 2022 17:47:28 -0600 Subject: [PATCH 36/40] Update test to expect remaining filter clause --- datafusion/core/tests/sql/subqueries.rs | 3 +- .../optimizer/src/subquery_decorrelate.rs | 96 ++++++++++--------- 2 files changed, 52 insertions(+), 47 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 9270f35f5de7..9ff1d34fc2e9 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -46,7 +46,8 @@ async fn tpch_q4_correlated() -> Result<()> { TableScan: orders projection=[o_orderkey, o_orderpriority] Projection: #lineitem.l_orderkey Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]] - TableScan: lineitem projection=[l_orderkey]"# + Filter: #lineitem.l_commitdate < #lineitem.l_receiptdate + TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate], partial_filters=[#lineitem.l_commitdate < #lineitem.l_receiptdate]"# .to_string(); assert_eq!(actual, expected); diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index 025f0eb4d308..88b30329fdf8 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -1,10 +1,10 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column}; +use datafusion_common::Column; use datafusion_expr::logical_plan::{Filter, JoinType, Subquery}; -use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Operator}; +use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder, Operator}; use hashbrown::HashSet; -use std::sync::Arc; use itertools::{Either, Itertools}; +use std::sync::Arc; /// Optimizer rule for rewriting subquery filters to joins #[derive(Default)] @@ -25,7 +25,7 @@ impl OptimizerRule for SubqueryDecorrelate { ) -> datafusion_common::Result { match plan { LogicalPlan::Filter(Filter { predicate, input }) => { - return match predicate { + match predicate { // TODO: arbitrary expressions Expr::Exists { subquery, negated } => { if *negated { @@ -34,7 +34,7 @@ impl OptimizerRule for SubqueryDecorrelate { optimize_exists(plan, subquery, input) } _ => Ok(plan.clone()), - }; + } } _ => { // Apply the optimization to all inputs of the plan @@ -84,7 +84,8 @@ fn optimize_exists( utils::split_conjunction(&filter.predicate, &mut filters); // get names of fields TODO: Must fully qualify these! - let fields: HashSet<_> = sub_input.schema() + let fields: HashSet<_> = sub_input + .schema() .fields() .iter() .map(|f| f.name()) @@ -97,18 +98,26 @@ fn optimize_exists( } // Only operate if one column is present and the other closed upon from outside scope - let l_col: Vec<_> = cols.iter() + let l_col: Vec<_> = cols + .iter() .map(|it| &it.0) .map(|it| Column::from_qualified_name(it.as_str())) .collect(); - let r_col: Vec<_> = cols.iter() + let r_col: Vec<_> = cols + .iter() .map(|it| &it.1) .map(|it| Column::from_qualified_name(it.as_str())) .collect(); let expr: Vec<_> = r_col.iter().map(|it| Expr::Column(it.clone())).collect(); let aggr_expr: Vec = vec![]; - let join_keys = (l_col.clone(), r_col.clone()); - let right = LogicalPlanBuilder::from((*filter.input).clone()) + let join_keys = (l_col, r_col); + let right = LogicalPlanBuilder::from((*filter.input).clone()); + let right = if let Some(expr) = combine_filters(&others) { + right.filter(expr)? + } else { + right + }; + let right = right .aggregate(expr.clone(), aggr_expr)? .project(expr)? .build()?; @@ -122,43 +131,38 @@ fn find_join_exprs( filters: Vec<&Expr>, fields: &HashSet<&String>, ) -> (Vec<(String, String)>, Vec) { - let (joins, others): (Vec<_>, Vec<_>) = filters.iter() - .partition_map(|filter| { - let (left, op, right) = match filter { - Expr::BinaryExpr { left, op, right } => { - (*left.clone(), op.clone(), *right.clone()) - } - _ => { - return Either::Right((*filter).clone()) - } - }; - match op { - Operator::Eq => {} - _ => return Either::Right((*filter).clone()), - } - let left = match left { - Expr::Column(c) => c, - _ => return Either::Right((*filter).clone()), - }; - let right = match right { - Expr::Column(c) => c, - _ => return Either::Right((*filter).clone()), - }; - if fields.contains(&left.name) && fields.contains(&right.name) { - return Either::Right((*filter).clone()); // Need one of each - } - if !fields.contains(&left.name) && !fields.contains(&right.name) { - return Either::Right((*filter).clone()); // Need one of each - } + let (joins, others): (Vec<_>, Vec<_>) = filters.iter().partition_map(|filter| { + let (left, op, right) = match filter { + Expr::BinaryExpr { left, op, right } => (*left.clone(), *op, *right.clone()), + _ => return Either::Right((*filter).clone()), + }; + match op { + Operator::Eq => {} + _ => return Either::Right((*filter).clone()), + } + let left = match left { + Expr::Column(c) => c, + _ => return Either::Right((*filter).clone()), + }; + let right = match right { + Expr::Column(c) => c, + _ => return Either::Right((*filter).clone()), + }; + if fields.contains(&left.name) && fields.contains(&right.name) { + return Either::Right((*filter).clone()); // Need one of each + } + if !fields.contains(&left.name) && !fields.contains(&right.name) { + return Either::Right((*filter).clone()); // Need one of each + } - let sorted = if fields.contains(&left.name) { - (right.name.clone(), left.name.clone()) - } else { - (left.name.clone(), right.name.clone()) - }; + let sorted = if fields.contains(&left.name) { + (right.name, left.name) + } else { + (left.name, right.name) + }; - Either::Left(sorted) - }); + Either::Left(sorted) + }); (joins, others) -} \ No newline at end of file +} From 00a661b6bf1d729a5da0aee4dd62d456b498377b Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 30 Jun 2022 13:27:19 -0600 Subject: [PATCH 37/40] Debugging --- datafusion/core/src/execution/context.rs | 4 +-- .../optimizer/src/subquery_decorrelate.rs | 29 +++++++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 393e8aef34ec..1698c43810f3 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1244,13 +1244,13 @@ impl SessionState { Arc::new(EliminateFilter::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), - Arc::new(ProjectionPushDown::new()), + Arc::new(ProjectionPushDown::new()), // Removes needed fields ]; if config.config_options.get_bool(OPT_FILTER_NULL_JOIN_KEYS) { rules.push(Arc::new(FilterNullJoinKeys::default())); } rules.push(Arc::new(ReduceOuterJoin::new())); - rules.push(Arc::new(FilterPushDown::new())); + rules.push(Arc::new(FilterPushDown::new())); // Fixes the expression rules.push(Arc::new(LimitPushDown::new())); rules.push(Arc::new(SingleDistinctToGroupBy::new())); diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index 88b30329fdf8..10558702c488 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -25,6 +25,13 @@ impl OptimizerRule for SubqueryDecorrelate { ) -> datafusion_common::Result { match plan { LogicalPlan::Filter(Filter { predicate, input }) => { + let fields: HashSet<_> = plan + .schema() + .fields() + .iter() + .map(|f| f.name()) + .collect(); + println!("{:?}", fields); match predicate { // TODO: arbitrary expressions Expr::Exists { subquery, negated } => { @@ -48,6 +55,18 @@ impl OptimizerRule for SubqueryDecorrelate { } } +/* +#orders.o_orderpriority ASC NULLS LAST + Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Filter: EXISTS ( -- plan + Subquery: Projection: * -- proj + Filter: #lineitem.l_orderkey = #orders.o_orderkey -- filter + TableScan: lineitem projection=None -- filter.input + ) + TableScan: orders projection=None -- plan.inputs + */ + /// Takes a query like: /// /// select c.id from customers c where exists (select * from orders o where o.c_id = c.id) @@ -90,6 +109,7 @@ fn optimize_exists( .iter() .map(|f| f.name()) .collect(); + println!("{:?}", fields); // Grab column names to join on let (cols, others) = find_join_exprs(filters, &fields); @@ -121,10 +141,15 @@ fn optimize_exists( .aggregate(expr.clone(), aggr_expr)? .project(expr)? .build()?; + println!("Joining:\n{}\nto:\n{}\non:\n{:?}", right.display_indent(), input.display_indent(), join_keys); let new_plan = LogicalPlanBuilder::from((**input).clone()) .join(&right, JoinType::Inner, join_keys, None)? - .build()?; - Ok(new_plan) + .build(); + if let Err(e) = &new_plan { + println!("wtf"); + } + // println!("Optimized:\n{}\n\ninto:\n\n{}", plan.display_indent(), new_plan.display_indent()); + new_plan } fn find_join_exprs( From 1708415aacbcafc8f4a2ae6bc6c052f32d0b5de8 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 30 Jun 2022 13:55:05 -0600 Subject: [PATCH 38/40] Can run query 4 --- .../optimizer/src/subquery_decorrelate.rs | 51 +++++++++++++------ 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index 10558702c488..b7f9d7e8cbe8 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -1,5 +1,5 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; -use datafusion_common::Column; +use datafusion_common::{Column, DataFusionError}; use datafusion_expr::logical_plan::{Filter, JoinType, Subquery}; use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder, Operator}; use hashbrown::HashSet; @@ -25,6 +25,30 @@ impl OptimizerRule for SubqueryDecorrelate { ) -> datafusion_common::Result { match plan { LogicalPlan::Filter(Filter { predicate, input }) => { + let mut filters = vec![]; + utils::split_conjunction(&predicate, &mut filters); + + let (subqueries, others): (Vec<_>, Vec<_>) = filters.iter() + .partition_map(|f| { + match f { + Expr::Exists { subquery, negated } => { + if *negated { // TODO: not exists + Either::Right((*f).clone()) + } else { + Either::Left(subquery.clone()) + } + } + _ => Either::Right((*f).clone()) + } + }); + if subqueries.len() != 1 { + return Ok(plan.clone()); // TODO: >1 subquery + } + let subquery = match subqueries.get(0) { + Some(q) => q, + _ => return Ok(plan.clone()) + }; + let fields: HashSet<_> = plan .schema() .fields() @@ -32,16 +56,8 @@ impl OptimizerRule for SubqueryDecorrelate { .map(|f| f.name()) .collect(); println!("{:?}", fields); - match predicate { - // TODO: arbitrary expressions - Expr::Exists { subquery, negated } => { - if *negated { - return Ok(plan.clone()); - } - optimize_exists(plan, subquery, input) - } - _ => Ok(plan.clone()), - } + + optimize_exists(plan, subquery, input, &others) } _ => { // Apply the optimization to all inputs of the plan @@ -79,6 +95,7 @@ fn optimize_exists( plan: &LogicalPlan, subquery: &Subquery, input: &Arc, + outer_others: &Vec, ) -> datafusion_common::Result { // Only operate if there is one input let sub_inputs = subquery.subquery.inputs(); @@ -143,11 +160,13 @@ fn optimize_exists( .build()?; println!("Joining:\n{}\nto:\n{}\non:\n{:?}", right.display_indent(), input.display_indent(), join_keys); let new_plan = LogicalPlanBuilder::from((**input).clone()) - .join(&right, JoinType::Inner, join_keys, None)? - .build(); - if let Err(e) = &new_plan { - println!("wtf"); - } + .join(&right, JoinType::Inner, join_keys, None)?; + let new_plan = if let Some(expr) = combine_filters(&outer_others) { + new_plan.filter(expr)? + } else { + new_plan + }; + let new_plan = new_plan.build(); // println!("Optimized:\n{}\n\ninto:\n\n{}", plan.display_indent(), new_plan.display_indent()); new_plan } From 60a6e582f6f01bb0569d6996af40cb537ebb2a27 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 30 Jun 2022 13:59:54 -0600 Subject: [PATCH 39/40] Remove debugging code --- datafusion/core/src/execution/context.rs | 4 ++-- datafusion/optimizer/src/subquery_decorrelate.rs | 13 +------------ 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 1698c43810f3..393e8aef34ec 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1244,13 +1244,13 @@ impl SessionState { Arc::new(EliminateFilter::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), - Arc::new(ProjectionPushDown::new()), // Removes needed fields + Arc::new(ProjectionPushDown::new()), ]; if config.config_options.get_bool(OPT_FILTER_NULL_JOIN_KEYS) { rules.push(Arc::new(FilterNullJoinKeys::default())); } rules.push(Arc::new(ReduceOuterJoin::new())); - rules.push(Arc::new(FilterPushDown::new())); // Fixes the expression + rules.push(Arc::new(FilterPushDown::new())); rules.push(Arc::new(LimitPushDown::new())); rules.push(Arc::new(SingleDistinctToGroupBy::new())); diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index b7f9d7e8cbe8..feecc8298259 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -1,5 +1,5 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, DataFusionError}; +use datafusion_common::{Column}; use datafusion_expr::logical_plan::{Filter, JoinType, Subquery}; use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder, Operator}; use hashbrown::HashSet; @@ -49,14 +49,6 @@ impl OptimizerRule for SubqueryDecorrelate { _ => return Ok(plan.clone()) }; - let fields: HashSet<_> = plan - .schema() - .fields() - .iter() - .map(|f| f.name()) - .collect(); - println!("{:?}", fields); - optimize_exists(plan, subquery, input, &others) } _ => { @@ -126,7 +118,6 @@ fn optimize_exists( .iter() .map(|f| f.name()) .collect(); - println!("{:?}", fields); // Grab column names to join on let (cols, others) = find_join_exprs(filters, &fields); @@ -158,7 +149,6 @@ fn optimize_exists( .aggregate(expr.clone(), aggr_expr)? .project(expr)? .build()?; - println!("Joining:\n{}\nto:\n{}\non:\n{:?}", right.display_indent(), input.display_indent(), join_keys); let new_plan = LogicalPlanBuilder::from((**input).clone()) .join(&right, JoinType::Inner, join_keys, None)?; let new_plan = if let Some(expr) = combine_filters(&outer_others) { @@ -167,7 +157,6 @@ fn optimize_exists( new_plan }; let new_plan = new_plan.build(); - // println!("Optimized:\n{}\n\ninto:\n\n{}", plan.display_indent(), new_plan.display_indent()); new_plan } From b8c0808392e59337eb7329f2caf6e54dac82da71 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 30 Jun 2022 14:20:05 -0600 Subject: [PATCH 40/40] Clippy --- datafusion/optimizer/src/subquery_decorrelate.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs index feecc8298259..cd19ea0bc7da 100644 --- a/datafusion/optimizer/src/subquery_decorrelate.rs +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -26,7 +26,7 @@ impl OptimizerRule for SubqueryDecorrelate { match plan { LogicalPlan::Filter(Filter { predicate, input }) => { let mut filters = vec![]; - utils::split_conjunction(&predicate, &mut filters); + utils::split_conjunction(predicate, &mut filters); let (subqueries, others): (Vec<_>, Vec<_>) = filters.iter() .partition_map(|f| { @@ -87,7 +87,7 @@ fn optimize_exists( plan: &LogicalPlan, subquery: &Subquery, input: &Arc, - outer_others: &Vec, + outer_others: &[Expr], ) -> datafusion_common::Result { // Only operate if there is one input let sub_inputs = subquery.subquery.inputs(); @@ -151,13 +151,12 @@ fn optimize_exists( .build()?; let new_plan = LogicalPlanBuilder::from((**input).clone()) .join(&right, JoinType::Inner, join_keys, None)?; - let new_plan = if let Some(expr) = combine_filters(&outer_others) { + let new_plan = if let Some(expr) = combine_filters(outer_others) { new_plan.filter(expr)? } else { new_plan }; - let new_plan = new_plan.build(); - new_plan + new_plan.build() } fn find_join_exprs(