single_distinct_to_groupby no longer drops qualifiers (#4050)
andygrove authored Nov 2, 2022
1 parent 396b5aa commit 8d6448e
Showing 11 changed files with 89 additions and 28 deletions.
4 changes: 2 additions & 2 deletions benchmarks/expected-plans/q16.txt
@@ -1,6 +1,6 @@
Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST
Projection: part.p_brand, part.p_type, part.p_size, COUNT(DISTINCT partsupp.ps_suppkey) AS supplier_cnt
Projection: group_alias_0 AS p_brand, group_alias_1 AS p_type, group_alias_2 AS p_size, COUNT(alias1) AS COUNT(DISTINCT partsupp.ps_suppkey)
Projection: group_alias_0 AS part.p_brand, group_alias_1 AS part.p_type, group_alias_2 AS part.p_size, COUNT(alias1) AS COUNT(DISTINCT partsupp.ps_suppkey)
Aggregate: groupBy=[[group_alias_0, group_alias_1, group_alias_2]], aggr=[[COUNT(alias1)]]
Aggregate: groupBy=[[part.p_brand AS group_alias_0, part.p_type AS group_alias_1, part.p_size AS group_alias_2, partsupp.ps_suppkey AS alias1]], aggr=[[]]
LeftAnti Join: partsupp.ps_suppkey = __sq_1.s_suppkey
@@ -10,4 +10,4 @@ Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type AS
TableScan: part projection=[p_partkey, p_brand, p_type, p_size]
Projection: supplier.s_suppkey AS s_suppkey, alias=__sq_1
Filter: supplier.s_comment LIKE Utf8("%Customer%Complaints%")
TableScan: supplier projection=[s_suppkey, s_comment]
TableScan: supplier projection=[s_suppkey, s_comment]
4 changes: 0 additions & 4 deletions benchmarks/src/bin/tpch.rs
@@ -769,10 +769,7 @@ mod tests {
#[cfg(feature = "ci")]
mod ci {
use super::*;
use datafusion::sql::TableReference;
use datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes};
use std::io::{BufRead, BufReader};
use std::sync::Arc;

async fn serde_round_trip(query: usize) -> Result<()> {
let ctx = SessionContext::default();
@@ -878,7 +875,6 @@ mod ci {
serde_round_trip(15).await
}

#[ignore] // https://github.com/apache/arrow-datafusion/issues/3820
#[tokio::test]
async fn serde_q16() -> Result<()> {
serde_round_trip(16).await
31 changes: 26 additions & 5 deletions benchmarks/src/tpch.rs
@@ -19,6 +19,7 @@ use arrow::array::{
Array, ArrayRef, Date32Array, Decimal128Array, Float64Array, Int32Array, Int64Array,
StringArray,
};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use std::fs;
use std::ops::{Div, Mul};
@@ -456,16 +457,36 @@ pub async fn transform_actual_result(
) -> Result<Vec<RecordBatch>> {
// to compare the recorded answers to the answers we got back from running the query,
// we need to round the decimal columns and trim the Utf8 columns
// we also need to rewrite the batches to use a compatible schema
let ctx = SessionContext::new();
let result_schema = result[0].schema();
let fields = result[0]
.schema()
.fields()
.iter()
.map(|f| {
let simple_name = match f.name().find('.') {
Some(i) => f.name()[i + 1..].to_string(),
_ => f.name().to_string(),
};
f.clone().with_name(simple_name)
})
.collect();
let result_schema = SchemaRef::new(Schema::new(fields));
let result = result
.iter()
.map(|b| {
RecordBatch::try_new(result_schema.clone(), b.columns().to_vec())
.map_err(|e| e.into())
})
.collect::<Result<Vec<_>>>()?;
let table = Arc::new(MemTable::try_new(result_schema.clone(), vec![result])?);
let mut df = ctx.read_table(table)?
.select(
result_schema
.fields
.iter()
.map(|field| {
match Field::data_type(field) {
match field.data_type() {
DataType::Decimal128(_, _) => {
// if decimal, then round it to 2 decimal places like the answers
// round() doesn't support the second argument for decimal places to round to
@@ -481,18 +502,18 @@
round,
DataType::Decimal128(15, 2),
))),
Field::name(field).to_string(),
field.name().to_string(),
)
}
DataType::Utf8 => {
// if string, then trim it like the answers got trimmed
Expr::Alias(
Box::new(trim(col(Field::name(field)))),
Field::name(field).to_string(),
field.name().to_string(),
)
}
_ => {
col(Field::name(field))
col(field.name())
}
}
}).collect()
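The rewrite above does two things worth calling out. First, the decimal columns are rounded to two places by scaling, rounding, and scaling back, because round() in this DataFusion version takes no precision argument. Second, the new schema-rewrite step strips the table qualifier from each column name (the diff's simple_name binding) so the result batches match the unqualified names in the recorded answers. A minimal, dependency-free Rust sketch of both ideas (helper names are illustrative):

fn simple_name(name: &str) -> String {
    // Everything up to and including the first '.' is treated as a table
    // qualifier and dropped; unqualified names pass through unchanged.
    match name.find('.') {
        Some(i) => name[i + 1..].to_string(),
        None => name.to_string(),
    }
}

fn round_2dp(x: f64) -> f64 {
    // Scale, round, scale back: the same multiply/round/divide trick the
    // benchmark applies to Decimal128 columns before casting.
    (x * 100.0).round() / 100.0
}

fn main() {
    assert_eq!(simple_name("part.p_brand"), "p_brand");
    assert_eq!(simple_name("supplier_cnt"), "supplier_cnt");
    assert!((round_2dp(2.0149) - 2.01).abs() < 1e-9);
}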
17 changes: 15 additions & 2 deletions datafusion/common/src/dfschema.rs
@@ -23,7 +23,7 @@ use std::convert::TryFrom;
use std::sync::Arc;

use crate::error::{DataFusionError, Result, SchemaError};
use crate::{field_not_found, Column};
use crate::{field_not_found, Column, TableReference};

use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
@@ -203,7 +203,20 @@ impl DFSchema {
// qualifier and name.
(Some(q), Some(field_q)) => q == field_q && field.name() == name,
// field to lookup is qualified but current field is unqualified.
(Some(_), None) => false,
(Some(qq), None) => {
// the original field may now be aliased with a name that matches the
// original qualified name
let table_ref: TableReference = field.name().as_str().into();
match table_ref {
TableReference::Partial { schema, table } => {
schema == qq && table == name
}
TableReference::Full { schema, table, .. } => {
schema == qq && table == name
}
_ => false,
}
}
// field to lookup is unqualified, no need to compare qualifier
(None, Some(_)) | (None, None) => field.name() == name,
})
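The new (Some(qq), None) arm lets a qualified lookup such as part.p_brand resolve against an unqualified field whose name is the aliased qualified string, which is exactly what the q16 plan above now produces. A rough sketch of that comparison as a free function, assuming the datafusion_common crate from this commit (the function name is invented for illustration):

use datafusion_common::TableReference;

// Does an unqualified field whose *name* is e.g. "part.p_brand" satisfy a
// lookup for qualifier "part" and column name "p_brand"?
fn aliased_field_matches(field_name: &str, qualifier: &str, name: &str) -> bool {
    let table_ref: TableReference = field_name.into();
    match table_ref {
        // "schema.table" parses as Partial, "catalog.schema.table" as Full;
        // in both cases only the schema and table parts are compared.
        TableReference::Partial { schema, table } => schema == qualifier && table == name,
        TableReference::Full { schema, table, .. } => schema == qualifier && table == name,
        // Bare names carry no embedded qualifier and cannot match.
        _ => false,
    }
}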
2 changes: 2 additions & 0 deletions datafusion/common/src/lib.rs
@@ -27,6 +27,7 @@ pub mod parsers;
mod pyarrow;
pub mod scalar;
pub mod stats;
mod table_reference;
pub mod test_util;

pub use column::Column;
@@ -35,6 +36,7 @@ pub use error::{field_not_found, DataFusionError, Result, SchemaError};
pub use parsers::parse_interval;
pub use scalar::{ScalarType, ScalarValue};
pub use stats::{ColumnStatistics, Statistics};
pub use table_reference::{ResolvedTableReference, TableReference};

/// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is
/// not possible. In normal usage of DataFusion the downcast should always succeed.
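Moving table_reference into datafusion-common (re-exported from datafusion/sql below) makes the type usable from crates like datafusion-proto without pulling in the SQL planner. A small usage sketch, assuming a dependency on the datafusion-common crate as of this commit:

use datafusion_common::TableReference;

fn main() {
    // A two-part identifier parses into the Partial variant, which is the
    // case the new match arm in dfschema.rs relies on.
    let r: TableReference = "part.p_brand".into();
    assert!(matches!(r, TableReference::Partial { .. }));
}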
File renamed without changes: datafusion/sql/src/table_reference.rs → datafusion/common/src/table_reference.rs
14 changes: 7 additions & 7 deletions datafusion/core/tests/sql/group_by.rs
@@ -654,13 +654,13 @@ async fn group_by_dictionary() {
.expect("ran plan correctly");

let expected = vec![
"+-----+------------------------+",
"| val | COUNT(DISTINCT t.dict) |",
"+-----+------------------------+",
"| 1 | 2 |",
"| 2 | 2 |",
"| 4 | 1 |",
"+-----+------------------------+",
"+-------+------------------------+",
"| t.val | COUNT(DISTINCT t.dict) |",
"+-------+------------------------+",
"| 1 | 2 |",
"| 2 | 2 |",
"| 4 | 1 |",
"+-------+------------------------+",
];
assert_batches_sorted_eq!(expected, &results);
}
12 changes: 7 additions & 5 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -137,12 +137,14 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
// - aggr expr
let mut alias_expr: Vec<Expr> = Vec::new();
for (alias, original_field) in group_expr_alias {
alias_expr.push(col(&alias).alias(original_field.name()));
alias_expr.push(col(&alias).alias(original_field.qualified_name()));
}
for (i, expr) in new_aggr_exprs.iter().enumerate() {
alias_expr.push(columnize_expr(
expr.clone()
.alias(schema.clone().fields()[i + group_expr.len()].name()),
expr.clone().alias(
schema.clone().fields()[i + group_expr.len()]
.qualified_name(),
),
&outer_aggr_schema,
));
}
@@ -362,7 +364,7 @@ mod tests {
.build()?;

// Should work
let expected = "Projection: group_alias_0 AS a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\
let expected = "Projection: group_alias_0 AS test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:UInt32, COUNT(alias1):Int64;N]\
\n Aggregate: groupBy=[[test.a AS group_alias_0, test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
@@ -409,7 +411,7 @@
)?
.build()?;
// Should work
let expected = "Projection: group_alias_0 AS a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
let expected = "Projection: group_alias_0 AS test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1), MAX(alias1)]] [group_alias_0:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\
\n Aggregate: groupBy=[[test.a AS group_alias_0, test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
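The heart of the fix is aliasing the outer projection with qualified_name() instead of name(), so the rewrite emits group_alias_0 AS test.a rather than group_alias_0 AS a and the table qualifier survives the optimizer pass. A toy stand-in for DFField showing the difference (simplified, not the real DataFusion type):

// Simplified model: qualified_name() prefixes the qualifier when present,
// where name() alone would drop it.
struct QualifiedField {
    qualifier: Option<String>,
    name: String,
}

impl QualifiedField {
    fn qualified_name(&self) -> String {
        match &self.qualifier {
            Some(q) => format!("{}.{}", q, self.name),
            None => self.name.clone(),
        }
    }
}

fn main() {
    let f = QualifiedField {
        qualifier: Some("test".into()),
        name: "a".into(),
    };
    // The projection alias is now "test.a"; before the fix it was just "a".
    assert_eq!(f.qualified_name(), "test.a");
}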
28 changes: 28 additions & 0 deletions datafusion/proto/src/lib.rs
@@ -252,6 +252,34 @@ mod roundtrip_tests {
Ok(())
}

#[tokio::test]
async fn roundtrip_single_count_distinct() -> Result<(), DataFusionError> {
let ctx = SessionContext::new();

let schema = Schema::new(vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Decimal128(15, 2), true),
]);

ctx.register_csv(
"t1",
"testdata/test.csv",
CsvReadOptions::default().schema(&schema),
)
.await?;

let query = "SELECT a, COUNT(DISTINCT b) as b_cd FROM t1 GROUP BY a";
let plan = ctx.sql(query).await?.to_logical_plan()?;

println!("{:?}", plan);

let bytes = logical_plan_to_bytes(&plan)?;
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));

Ok(())
}

#[tokio::test]
async fn roundtrip_logical_plan_with_extension() -> Result<(), DataFusionError> {
let ctx = SessionContext::new();
3 changes: 1 addition & 2 deletions datafusion/sql/src/lib.rs
@@ -20,8 +20,7 @@
pub mod parser;
pub mod planner;
mod table_reference;
pub mod utils;

pub use datafusion_common::{ResolvedTableReference, TableReference};
pub use sqlparser;
pub use table_reference::{ResolvedTableReference, TableReference};
2 changes: 1 addition & 1 deletion datafusion/sql/src/planner.rs
@@ -46,8 +46,8 @@ use std::str::FromStr;
use std::sync::Arc;
use std::{convert::TryInto, vec};

use crate::table_reference::TableReference;
use crate::utils::{make_decimal_type, normalize_ident, resolve_columns};
use datafusion_common::TableReference;
use datafusion_common::{
field_not_found, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
