Skip to content

Commit

Permalink
Support filter for List (#11091)
Browse files Browse the repository at this point in the history
* support basic list cmp

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add more ops

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add distinct

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* nested

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add comment

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 authored Jun 27, 2024
1 parent 2d1e850 commit ff116c3
Show file tree
Hide file tree
Showing 7 changed files with 312 additions and 67 deletions.
180 changes: 180 additions & 0 deletions datafusion/physical-expr-common/src/datum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// UnLt required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::BooleanArray;
use arrow::array::{make_comparator, ArrayRef, Datum};
use arrow::buffer::NullBuffer;
use arrow::compute::SortOptions;
use arrow::error::ArrowError;
use datafusion_common::internal_err;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Operator};
use std::sync::Arc;

/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs`
///
/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction
pub fn apply(
lhs: &ColumnarValue,
rhs: &ColumnarValue,
f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
) -> Result<ColumnarValue> {
match (&lhs, &rhs) {
(ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?))
}
(ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok(
ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?),
),
(ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok(
ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?),
),
(ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => {
let array = f(&left.to_scalar()?, &right.to_scalar()?)?;
let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?;
Ok(ColumnarValue::Scalar(scalar))
}
}
}

/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs`
pub fn apply_cmp(
lhs: &ColumnarValue,
rhs: &ColumnarValue,
f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>,
) -> Result<ColumnarValue> {
apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
}

/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like
/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type
pub fn apply_cmp_for_nested(
op: Operator,
lhs: &ColumnarValue,
rhs: &ColumnarValue,
) -> Result<ColumnarValue> {
if matches!(
op,
Operator::Eq
| Operator::NotEq
| Operator::Lt
| Operator::Gt
| Operator::LtEq
| Operator::GtEq
| Operator::IsDistinctFrom
| Operator::IsNotDistinctFrom
) {
apply(lhs, rhs, |l, r| {
Ok(Arc::new(compare_op_for_nested(op, l, r)?))
})
} else {
internal_err!("invalid operator for nested")
}
}

/// Compare on nested type List, Struct, and so on
fn compare_op_for_nested(
op: Operator,
lhs: &dyn Datum,
rhs: &dyn Datum,
) -> Result<BooleanArray> {
let (l, is_l_scalar) = lhs.get();
let (r, is_r_scalar) = rhs.get();
let l_len = l.len();
let r_len = r.len();

if l_len != r_len && !is_l_scalar && !is_r_scalar {
return internal_err!("len mismatch");
}

let len = match is_l_scalar {
true => r_len,
false => l_len,
};

// fast path, if compare with one null and operator is not 'distinct', then we can return null array directly
if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom)
&& (is_l_scalar && l.null_count() == 1 || is_r_scalar && r.null_count() == 1)
{
return Ok(BooleanArray::new_null(len));
}

// TODO: make SortOptions configurable
// we choose the default behaviour from arrow-rs which has null-first that follow spark's behaviour
let cmp = make_comparator(l, r, SortOptions::default())?;

let cmp_with_op = |i, j| match op {
Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(),
Operator::Lt => cmp(i, j).is_lt(),
Operator::Gt => cmp(i, j).is_gt(),
Operator::LtEq => !cmp(i, j).is_gt(),
Operator::GtEq => !cmp(i, j).is_lt(),
Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(),
_ => unreachable!("unexpected operator found"),
};

let values = match (is_l_scalar, is_r_scalar) {
(false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(),
(true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(),
(false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(),
(true, true) => std::iter::once(cmp_with_op(0, 0)).collect(),
};

// Distinct understand how to compare with NULL
// i.e NULL is distinct from NULL -> false
if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) {
Ok(BooleanArray::new(values, None))
} else {
// If one of the side is NULL, we returns NULL
// i.e. NULL eq NULL -> NULL
let nulls = NullBuffer::union(l.nulls(), r.nulls());
Ok(BooleanArray::new(values, nulls))
}
}

#[cfg(test)]
mod tests {
use arrow::{
array::{make_comparator, Array, BooleanArray, ListArray},
buffer::NullBuffer,
compute::SortOptions,
datatypes::Int32Type,
};

#[test]
fn test123() {
let data = vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), None, Some(5)]),
Some(vec![Some(6), Some(7)]),
];
let a = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
let data = vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), None, Some(5)]),
Some(vec![Some(6), Some(7)]),
];
let b = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
let cmp = make_comparator(&a, &b, SortOptions::default()).unwrap();
let len = a.len().min(b.len());
let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
let nulls = NullBuffer::union(a.nulls(), b.nulls());
println!("res: {:?}", BooleanArray::new(values, nulls));
}
}
1 change: 1 addition & 0 deletions datafusion/physical-expr-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

pub mod aggregate;
pub mod binary_map;
pub mod datum;
pub mod expressions;
pub mod physical_expr;
pub mod sort_expr;
Expand Down
9 changes: 8 additions & 1 deletion datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ mod kernels;
use std::hash::{Hash, Hasher};
use std::{any::Any, sync::Arc};

use crate::expressions::datum::{apply, apply_cmp};
use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;
Expand All @@ -40,6 +39,7 @@ use datafusion_expr::interval_arithmetic::{apply_operator, Interval};
use datafusion_expr::sort_properties::ExprProperties;
use datafusion_expr::type_coercion::binary::get_result_type;
use datafusion_expr::{ColumnarValue, Operator};
use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};

use kernels::{
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
Expand Down Expand Up @@ -265,6 +265,13 @@ impl PhysicalExpr for BinaryExpr {
let schema = batch.schema();
let input_schema = schema.as_ref();

if left_data_type.is_nested() {
if right_data_type != left_data_type {
return internal_err!("type mismatch");
}
return apply_cmp_for_nested(self.op, &lhs, &rhs);
}

match self.op {
Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
Expand Down
58 changes: 0 additions & 58 deletions datafusion/physical-expr/src/expressions/datum.rs

This file was deleted.

2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/expressions/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ use std::{any::Any, sync::Arc};

use crate::{physical_expr::down_cast_any_ref, PhysicalExpr};

use crate::expressions::datum::apply_cmp;
use arrow::record_batch::RecordBatch;
use arrow_schema::{DataType, Schema};
use datafusion_common::{internal_err, Result};
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr_common::datum::apply_cmp;

// Like expression
#[derive(Debug, Hash)]
Expand Down
1 change: 0 additions & 1 deletion datafusion/physical-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
mod binary;
mod case;
mod column;
mod datum;
mod in_list;
mod is_not_null;
mod is_null;
Expand Down
Loading

0 comments on commit ff116c3

Please sign in to comment.