Skip to content

Commit

Permalink
Enforce ambiguity check whilst normalizing columns (#5509)
Browse files Browse the repository at this point in the history
* Enforce ambiguity check whilst normalizing columns

* Update datafusion/expr/src/expr_rewriter.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* Fix clippy

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
Jefffrey and alamb authored Mar 12, 2023
1 parent 5423ba0 commit aeb593f
Show file tree
Hide file tree
Showing 12 changed files with 424 additions and 282 deletions.
192 changes: 192 additions & 0 deletions datafusion/common/src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ impl Column {
/// In this case, both `t1.id` and `t2.id` will match unqualified column `id`. To express this possibility, use
/// `using_columns`. Each entry in this array is a set of columns that are bound together via a `USING` clause. So
/// in this example this would be `[{t1.id, t2.id}]`.
#[deprecated(
since = "20.0.0",
note = "use normalize_with_schemas_and_ambiguity_check instead"
)]
pub fn normalize_with_schemas(
self,
schemas: &[&Arc<DFSchema>],
Expand Down Expand Up @@ -154,6 +158,105 @@ impl Column {
.collect(),
}))
}

/// Qualify column if not done yet.
///
/// If this column already has a [relation](Self::relation), it will be returned as is and the given parameters are
/// ignored. Otherwise this will search through the given schemas to find the column.
///
/// Will check for ambiguity at each level of `schemas`.
///
/// A schema matches if there is a single column that -- when unqualified -- matches this column. There is an
/// exception for `USING` statements, see below.
///
/// # Using columns
/// Take the following SQL statement:
///
/// ```sql
/// SELECT id FROM t1 JOIN t2 USING(id)
/// ```
///
/// In this case, both `t1.id` and `t2.id` will match unqualified column `id`. To express this possibility, use
/// `using_columns`. Each entry in this array is a set of columns that are bound together via a `USING` clause. So
/// in this example this would be `[{t1.id, t2.id}]`.
///
/// Regarding the ambiguity check, `schemas` is structured to allow levels of schemas to be passed in.
/// For example:
///
/// ```text
/// schemas = &[
/// &[schema1, schema2], // first level
/// &[schema3, schema4], // second level
/// ]
/// ```
///
/// Will search for a matching field in all schemas in the first level. If a matching field according to above
/// mentioned conditions is not found, then will check the next level. If more than one matching column is found across
/// all schemas in a level, and they are not bound by a USING column set, will return an error due to ambiguous column.
///
/// If all levels were checked and no field was found, will return a field-not-found error listing all valid fields.
pub fn normalize_with_schemas_and_ambiguity_check(
self,
schemas: &[&[&DFSchema]],
using_columns: &[HashSet<Column>],
) -> Result<Self> {
// Already qualified: nothing to resolve, return unchanged.
if self.relation.is_some() {
return Ok(self);
}

for schema_level in schemas {
// Collect every field in this level whose unqualified name matches.
let fields = schema_level
.iter()
.flat_map(|s| s.fields_with_unqualified_name(&self.name))
.collect::<Vec<_>>();
match fields.len() {
// No match at this level: fall through to the next level.
0 => continue,
// Exactly one match: unambiguous, adopt its qualifier.
1 => return Ok(fields[0].qualified_column()),
_ => {
// More than 1 field in this schema level has its name set to self.name.
//
// This should only happen when a JOIN query with USING constraint references
// join columns using unqualified column name. For example:
//
// ```sql
// SELECT id FROM t1 JOIN t2 USING(id)
// ```
//
// In this case, both `t1.id` and `t2.id` will match unqualified column `id`.
// We will use the relation from the first matched field to normalize self.

// Compare matched fields with one USING JOIN clause at a time
for using_col in using_columns {
let all_matched = fields
.iter()
.all(|f| using_col.contains(&f.qualified_column()));
// All matched fields belong to the same using column set, in other words
// the same join clause. We simply pick the qualifier from the first match.
if all_matched {
return Ok(fields[0].qualified_column());
}
}

// If not due to USING columns then due to ambiguous column name
return Err(DataFusionError::SchemaError(
SchemaError::AmbiguousReference {
qualifier: None,
name: self.name,
},
));
}
}
}

// Exhausted every level without a match: report all valid fields to the user.
Err(DataFusionError::SchemaError(SchemaError::FieldNotFound {
field: self,
valid_fields: schemas
.iter()
.flat_map(|s| s.iter())
.flat_map(|s| s.fields().iter().map(|f| f.qualified_column()))
.collect(),
}))
}
}

impl From<&str> for Column {
Expand Down Expand Up @@ -189,3 +292,92 @@ impl fmt::Display for Column {
write!(f, "{}", self.flat_name())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::DFField;
use arrow::datatypes::DataType;
use std::collections::HashMap;

/// Build a `DFSchema` of nullable Boolean fields from `(qualifier, name)`
/// pairs. The data type is irrelevant to name-resolution tests.
fn create_schema(names: &[(Option<&str>, &str)]) -> Result<DFSchema> {
let fields = names
.iter()
.map(|(qualifier, name)| {
DFField::new(qualifier.to_owned(), name, DataType::Boolean, true)
})
.collect::<Vec<_>>();
DFSchema::new_with_metadata(fields, HashMap::new())
}

#[test]
fn test_normalize_with_schemas_and_ambiguity_check() -> Result<()> {
// t1/t2 overlap with t3 on some names so both the level ordering and
// the ambiguity detection can be exercised below.
let schema1 = create_schema(&[(Some("t1"), "a"), (Some("t1"), "b")])?;
let schema2 = create_schema(&[(Some("t2"), "c"), (Some("t2"), "d")])?;
let schema3 = create_schema(&[
(Some("t3"), "a"),
(Some("t3"), "b"),
(Some("t3"), "c"),
(Some("t3"), "d"),
(Some("t3"), "e"),
])?;

// already normalized: a qualified column is returned unchanged even
// with no schemas to search
let col = Column::new(Some("t1"), "a");
let col = col.normalize_with_schemas_and_ambiguity_check(&[], &[])?;
assert_eq!(col, Column::new(Some("t1"), "a"));

// should find in first level (schema1)
let col = Column::from_name("a");
let col = col.normalize_with_schemas_and_ambiguity_check(
&[&[&schema1, &schema2], &[&schema3]],
&[],
)?;
assert_eq!(col, Column::new(Some("t1"), "a"));

// should find in second level (schema3)
let col = Column::from_name("e");
let col = col.normalize_with_schemas_and_ambiguity_check(
&[&[&schema1, &schema2], &[&schema3]],
&[],
)?;
assert_eq!(col, Column::new(Some("t3"), "e"));

// using column in first level (pick schema1): both matches are in the
// same USING set, so the first match's qualifier wins instead of erroring
let mut using_columns = HashSet::new();
using_columns.insert(Column::new(Some("t1"), "a"));
using_columns.insert(Column::new(Some("t3"), "a"));
let col = Column::from_name("a");
let col = col.normalize_with_schemas_and_ambiguity_check(
&[&[&schema1, &schema3], &[&schema2]],
&[using_columns],
)?;
assert_eq!(col, Column::new(Some("t1"), "a"));

// not found in any level: the error must list every valid field
let col = Column::from_name("z");
let err = col
.normalize_with_schemas_and_ambiguity_check(
&[&[&schema1, &schema2], &[&schema3]],
&[],
)
.expect_err("should've failed to find field");
let expected = "Schema error: No field named 'z'. \
Valid fields are 't1'.'a', 't1'.'b', 't2'.'c', \
't2'.'d', 't3'.'a', 't3'.'b', 't3'.'c', 't3'.'d', 't3'.'e'.";
assert_eq!(err.to_string(), expected);

// ambiguous column reference: two matches at one level with no USING set
let col = Column::from_name("a");
let err = col
.normalize_with_schemas_and_ambiguity_check(
&[&[&schema1, &schema3], &[&schema2]],
&[],
)
.expect_err("should've found ambiguous field");
let expected = "Schema error: Ambiguous reference to unqualified field 'a'";
assert_eq!(err.to_string(), expected);

Ok(())
}
}
5 changes: 2 additions & 3 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1393,8 +1393,7 @@ mod tests {
let join = left
.join_on(right, JoinType::Inner, [col("c1").eq(col("c1"))])
.expect_err("join didn't fail check");
let expected =
"Error during planning: reference 'c1' is ambiguous, could be a.c1,b.c1;";
let expected = "Schema error: Ambiguous reference to unqualified field 'c1'";
assert_eq!(join.to_string(), expected);

Ok(())
Expand Down Expand Up @@ -1861,7 +1860,7 @@ mod tests {
)]));

let data = RecordBatch::try_new(
schema.clone(),
schema,
vec![
Arc::new(arrow::array::StringArray::from(vec![
Some("2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"),
Expand Down
101 changes: 87 additions & 14 deletions datafusion/core/tests/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion::prelude::JoinType;
use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::{avg, col, count, lit, sum, Expr, ExprSchemable};
use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};

#[tokio::test]
async fn describe() -> Result<()> {
Expand Down Expand Up @@ -201,6 +201,7 @@ async fn sort_on_distinct_columns() -> Result<()> {
assert_batches_eq!(expected, &results);
Ok(())
}

#[tokio::test]
async fn sort_on_distinct_unprojected_columns() -> Result<()> {
let schema = Schema::new(vec![
Expand All @@ -214,26 +215,98 @@ async fn sort_on_distinct_unprojected_columns() -> Result<()> {
Arc::new(Int32Array::from_slice([1, 10, 10, 100])),
Arc::new(Int32Array::from_slice([2, 3, 4, 5])),
],
)
.unwrap();
)?;

// Cannot sort on a column after distinct that would add a new column
let ctx = SessionContext::new();
ctx.register_batch("t", batch).unwrap();
ctx.register_batch("t", batch)?;
let err = ctx
.table("t")
.await
.unwrap()
.select(vec![col("a")])
.unwrap()
.distinct()
.unwrap()
.await?
.select(vec![col("a")])?
.distinct()?
.sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))])
.unwrap_err();
assert_eq!(err.to_string(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions b must appear in select list");
Ok(())
}

#[tokio::test]
async fn sort_on_ambiguous_column() -> Result<()> {
// After joining t1 and t2 (both with column `b`), sorting on the
// unqualified `b` must fail with an ambiguity error.
let left = create_test_table("t1").await?;
let right = create_test_table("t2").await?;
let joined = left.join(right, JoinType::Inner, &["a"], &["a"], None)?;

let err = joined
.sort(vec![col("b").sort(true, true)])
.unwrap_err();

assert_eq!(
err.to_string(),
"Schema error: Ambiguous reference to unqualified field 'b'"
);
Ok(())
}

#[tokio::test]
async fn group_by_ambiguous_column() -> Result<()> {
// After joining t1 and t2 (both with column `b`), grouping by the
// unqualified `b` must fail with an ambiguity error.
let left = create_test_table("t1").await?;
let right = create_test_table("t2").await?;
let joined = left.join(right, JoinType::Inner, &["a"], &["a"], None)?;

let err = joined
.aggregate(vec![col("b")], vec![max(col("a"))])
.unwrap_err();

assert_eq!(
err.to_string(),
"Schema error: Ambiguous reference to unqualified field 'b'"
);
Ok(())
}

#[tokio::test]
async fn filter_on_ambiguous_column() -> Result<()> {
// After joining t1 and t2 (both with column `b`), filtering on the
// unqualified `b` must fail with an ambiguity error.
let left = create_test_table("t1").await?;
let right = create_test_table("t2").await?;
let joined = left.join(right, JoinType::Inner, &["a"], &["a"], None)?;

let err = joined.filter(col("b").eq(lit(1))).unwrap_err();

assert_eq!(
err.to_string(),
"Schema error: Ambiguous reference to unqualified field 'b'"
);
Ok(())
}

#[tokio::test]
async fn select_ambiguous_column() -> Result<()> {
// After joining t1 and t2 (both with column `b`), selecting the
// unqualified `b` must fail with an ambiguity error.
let left = create_test_table("t1").await?;
let right = create_test_table("t2").await?;
let joined = left.join(right, JoinType::Inner, &["a"], &["a"], None)?;

let err = joined.select(vec![col("b")]).unwrap_err();

assert_eq!(
err.to_string(),
"Schema error: Ambiguous reference to unqualified field 'b'"
);
Ok(())
}

#[tokio::test]
async fn filter_with_alias_overwrite() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand Down Expand Up @@ -316,7 +389,7 @@ async fn test_grouping_sets() -> Result<()> {
vec![col("a"), col("b")],
]));

let df = create_test_table()
let df = create_test_table("test")
.await?
.aggregate(vec![grouping_set_expr], vec![count(col("a"))])?
.sort(vec![
Expand Down Expand Up @@ -746,7 +819,7 @@ async fn unnest_aggregate_columns() -> Result<()> {
Ok(())
}

async fn create_test_table() -> Result<DataFrame> {
async fn create_test_table(name: &str) -> Result<DataFrame> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Field::new("b", DataType::Int32, false),
Expand All @@ -768,9 +841,9 @@ async fn create_test_table() -> Result<DataFrame> {

let ctx = SessionContext::new();

ctx.register_batch("test", batch)?;
ctx.register_batch(name, batch)?;

ctx.table("test").await
ctx.table(name).await
}

async fn aggregates_table(ctx: &SessionContext) -> Result<DataFrame> {
Expand Down
Loading

0 comments on commit aeb593f

Please sign in to comment.