From ff116c3da69897358f210a3ea944c8e51dcb7b52 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 27 Jun 2024 16:40:16 +0800 Subject: [PATCH] Support filter for List (#11091) * support basic list cmp Signed-off-by: jayzhan211 * add more ops Signed-off-by: jayzhan211 * add distinct Signed-off-by: jayzhan211 * nested Signed-off-by: jayzhan211 * add comment Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/physical-expr-common/src/datum.rs | 180 ++++++++++++++++++ datafusion/physical-expr-common/src/lib.rs | 1 + .../physical-expr/src/expressions/binary.rs | 9 +- .../physical-expr/src/expressions/datum.rs | 58 ------ .../physical-expr/src/expressions/like.rs | 2 +- .../physical-expr/src/expressions/mod.rs | 1 - .../sqllogictest/test_files/array_query.slt | 128 ++++++++++++- 7 files changed, 312 insertions(+), 67 deletions(-) create mode 100644 datafusion/physical-expr-common/src/datum.rs delete mode 100644 datafusion/physical-expr/src/expressions/datum.rs diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs new file mode 100644 index 000000000000..f4ce0eebc081 --- /dev/null +++ b/datafusion/physical-expr-common/src/datum.rs @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// UnLt required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::BooleanArray; +use arrow::array::{make_comparator, ArrayRef, Datum}; +use arrow::buffer::NullBuffer; +use arrow::compute::SortOptions; +use arrow::error::ArrowError; +use datafusion_common::internal_err; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, Operator}; +use std::sync::Arc; + +/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs` +/// +/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction +pub fn apply( + lhs: &ColumnarValue, + rhs: &ColumnarValue, + f: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + match (&lhs, &rhs) { + (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { + Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) + } + (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( + ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), + ), + (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( + ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), + ), + (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { + let array = f(&left.to_scalar()?, &right.to_scalar()?)?; + let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } + } +} + +/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` +pub fn apply_cmp( + lhs: &ColumnarValue, + rhs: &ColumnarValue, + f: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) +} + +/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like +/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type +pub fn apply_cmp_for_nested( + op: Operator, + lhs: &ColumnarValue, + rhs: &ColumnarValue, +) -> Result { + if matches!( + op, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::Gt + | Operator::LtEq + | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom + ) { + apply(lhs, rhs, |l, r| { + Ok(Arc::new(compare_op_for_nested(op, l, r)?)) + }) + } else { + internal_err!("invalid operator for nested") + } +} + +/// Compare on nested type List, Struct, and so on +fn compare_op_for_nested( + op: Operator, + lhs: &dyn Datum, + rhs: &dyn Datum, +) -> Result { + let (l, is_l_scalar) = lhs.get(); + let (r, is_r_scalar) = rhs.get(); + let l_len = l.len(); + let r_len = r.len(); + + if l_len != r_len && !is_l_scalar && !is_r_scalar { + return internal_err!("len mismatch"); + } + + let len = match is_l_scalar { + true => r_len, + false => l_len, + }; + + // fast path, if compare with one null and operator is not 'distinct', then we can return null array directly + if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) + && (is_l_scalar && l.null_count() == 1 || is_r_scalar && r.null_count() == 1) + { + return Ok(BooleanArray::new_null(len)); + } + + // TODO: make SortOptions configurable + // we choose the default behaviour from arrow-rs which has null-first that follow spark's behaviour + let cmp = make_comparator(l, r, SortOptions::default())?; + + let cmp_with_op = |i, j| match op { + Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(), + Operator::Lt => cmp(i, j).is_lt(), + Operator::Gt => cmp(i, j).is_gt(), + Operator::LtEq => !cmp(i, j).is_gt(), + Operator::GtEq => !cmp(i, j).is_lt(), + Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(), + _ => unreachable!("unexpected operator found"), + }; + + let values = match (is_l_scalar, is_r_scalar) { + (false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(), + (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(), + (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(), + (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(), + }; + + // Distinct understand how to compare with NULL + // i.e NULL is distinct from NULL -> false + if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) { + Ok(BooleanArray::new(values, None)) + } else { + // If one of the side is NULL, we returns NULL + // i.e. NULL eq NULL -> NULL + let nulls = NullBuffer::union(l.nulls(), r.nulls()); + Ok(BooleanArray::new(values, nulls)) + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{make_comparator, Array, BooleanArray, ListArray}, + buffer::NullBuffer, + compute::SortOptions, + datatypes::Int32Type, + }; + + #[test] + fn test123() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + let a = ListArray::from_iter_primitive::(data); + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + let b = ListArray::from_iter_primitive::(data); + let cmp = make_comparator(&a, &b, SortOptions::default()).unwrap(); + let len = a.len().min(b.len()); + let values = (0..len).map(|i| cmp(i, i).is_eq()).collect(); + let nulls = NullBuffer::union(a.nulls(), b.nulls()); + println!("res: {:?}", BooleanArray::new(values, nulls)); + } +} diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index 0ddb84141a07..8d50e0b964e5 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -17,6 +17,7 @@ pub mod aggregate; pub mod binary_map; +pub mod datum; pub mod expressions; pub mod physical_expr; pub mod sort_expr; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 98df0cba9f3e..3a8f7ee56ace 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -20,7 +20,6 @@ mod kernels; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::expressions::datum::{apply, apply_cmp}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; @@ -40,6 +39,7 @@ use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; +use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested}; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, @@ -265,6 +265,13 @@ impl PhysicalExpr for BinaryExpr { let schema = batch.schema(); let input_schema = schema.as_ref(); + if left_data_type.is_nested() { + if right_data_type != left_data_type { + return internal_err!("type mismatch"); + } + return apply_cmp_for_nested(self.op, &lhs, &rhs); + } + match self.op { Operator::Plus => return apply(&lhs, &rhs, add_wrapping), Operator::Minus => return apply(&lhs, &rhs, sub_wrapping), diff --git a/datafusion/physical-expr/src/expressions/datum.rs b/datafusion/physical-expr/src/expressions/datum.rs deleted file mode 100644 index 2bb79922cfec..000000000000 --- a/datafusion/physical-expr/src/expressions/datum.rs +++ /dev/null @@ -1,58 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::{ArrayRef, Datum}; -use arrow::error::ArrowError; -use arrow_array::BooleanArray; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::ColumnarValue; -use std::sync::Arc; - -/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs` -/// -/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction -pub(crate) fn apply( - lhs: &ColumnarValue, - rhs: &ColumnarValue, - f: impl Fn(&dyn Datum, &dyn Datum) -> Result, -) -> Result { - match (&lhs, &rhs) { - (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { - Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) - } - (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( - ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), - ), - (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( - ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), - ), - (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { - let array = f(&left.to_scalar()?, &right.to_scalar()?)?; - let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; - Ok(ColumnarValue::Scalar(scalar)) - } - } -} - -/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` -pub(crate) fn apply_cmp( - lhs: &ColumnarValue, - rhs: &ColumnarValue, - f: impl Fn(&dyn Datum, &dyn Datum) -> Result, -) -> Result { - apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) -} diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index d18651c641fd..e0c02b0a90e9 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -20,11 +20,11 @@ use std::{any::Any, sync::Arc}; use crate::{physical_expr::down_cast_any_ref, PhysicalExpr}; -use crate::expressions::datum::apply_cmp; use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, Schema}; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::datum::apply_cmp; // Like expression #[derive(Debug, Hash)] diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c98bcc56ad97..608609b81d82 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -21,7 +21,6 @@ mod binary; mod case; mod column; -mod datum; mod in_list; mod is_not_null; mod is_null; diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt index 24c99fc849b6..b29b5f5efd98 100644 --- a/datafusion/sqllogictest/test_files/array_query.slt +++ b/datafusion/sqllogictest/test_files/array_query.slt @@ -41,17 +41,68 @@ SELECT * FROM data; # Filtering ########### -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I rowsort SELECT * FROM data WHERE column1 = [1,2,3]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) -SELECT * FROM data WHERE column1 = column2 - -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I SELECT * FROM data WHERE column1 != [1,2,3]; +---- +[2, 3] [2, 3] 1 -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I SELECT * FROM data WHERE column1 != column2 +---- +[1, 2, 3] [4, 5] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 < [1,2,3,4]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 <= [2, 3]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 > [1,2]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 >= [1, 2, 3]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +# test with scalar null +query ??I +SELECT * FROM data WHERE column2 = null; +---- + +query ??I +SELECT * FROM data WHERE null = column2; +---- + +query ??I +SELECT * FROM data WHERE column2 is distinct from null; +---- +[2, 3] [2, 3] 1 +[1, 2, 3] [4, 5] 1 + +query ??I +SELECT * FROM data WHERE column2 is not distinct from null; +---- +[1, 2, 3] NULL 1 ########### # Aggregates @@ -158,3 +209,68 @@ SELECT * FROM data ORDER BY column1, column3, column2; statement ok drop table data + + +# test filter column with all nulls +statement ok +create table data (a int) as values (null), (null), (null); + +query I +select * from data where a = null; +---- + +query I +select * from data where a is not distinct from null; +---- +NULL +NULL +NULL + +statement ok +drop table data; + +statement ok +create table data (a int[][], b int) as values ([[1,2,3]], 1), ([[2,3], [4,5]], 2), (null, 3); + +query ?I +select * from data; +---- +[[1, 2, 3]] 1 +[[2, 3], [4, 5]] 2 +NULL 3 + +query ?I +select * from data where a = [[1,2,3]]; +---- +[[1, 2, 3]] 1 + +query ?I +select * from data where a > [[1,2,3]]; +---- +[[2, 3], [4, 5]] 2 + +query ?I +select * from data where a > [[1,2]]; +---- +[[1, 2, 3]] 1 +[[2, 3], [4, 5]] 2 + +query ?I +select * from data where a < [[2, 3]]; +---- +[[1, 2, 3]] 1 + +# compare with null with eq results in null +query ?I +select * from data where a = null; +---- + +query ?I +select * from data where a != null; +---- + +# compare with null with distinct results in true/false +query ?I +select * from data where a is not distinct from null; +---- +NULL 3