Skip to content

Commit

Permalink
Support dictionary in InList (apache#3936)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Nov 1, 2022
1 parent b3a4665 commit 9973b03
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 11 deletions.
99 changes: 96 additions & 3 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,101 @@ async fn csv_in_set_test() -> Result<()> {
}

/// Regression test for <https://github.com/apache/arrow-datafusion/issues/3936>:
/// an `IN (...)` list must be evaluated against a dictionary-encoded string column.
///
/// Registers a one-column table `test` whose `c1` column is a
/// `DictionaryArray<Int32Type>` without nulls, then runs `IN` filters with
/// one, two and three literals. The expected results show that matching is
/// exact: `'Bar'` never selects the stored value `bar`.
#[tokio::test]
async fn in_list_string_dictionaries() -> Result<()> {
    // Null-free input; the null-containing variant is covered by
    // `in_list_string_dictionaries_with_null`.
    let input = vec![Some("foo"), Some("bar"), Some("fazzz")]
        .into_iter()
        .collect::<DictionaryArray<Int32Type>>();

    let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();

    let ctx = SessionContext::new();
    ctx.register_batch("test", batch)?;

    // Expected tables below are written with every cell padded to the column
    // width, so that each row is exactly as wide as its border line.

    // Only a differently-cased literal: nothing matches, empty table.
    let sql = "SELECT * FROM test WHERE c1 IN ('Bar')";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec!["++", "++"];
    assert_batches_eq!(expected, &actual);

    // Single matching literal.
    let sql = "SELECT * FROM test WHERE c1 IN ('foo')";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec!["+-----+", "| c1  |", "+-----+", "| foo |", "+-----+"];
    assert_batches_eq!(expected, &actual);

    // Two matching literals; rows come back in table order, not list order.
    let sql = "SELECT * FROM test WHERE c1 IN ('bar', 'foo')";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+-----+", "| c1  |", "+-----+", "| foo |", "| bar |", "+-----+",
    ];
    assert_batches_eq!(expected, &actual);

    // One matching and one non-matching literal.
    let sql = "SELECT * FROM test WHERE c1 IN ('Bar', 'foo')";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec!["+-----+", "| c1  |", "+-----+", "| foo |", "+-----+"];
    assert_batches_eq!(expected, &actual);

    // Three literals, two of which match.
    let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+-------+",
        "| c1    |",
        "+-------+",
        "| foo   |",
        "| fazzz |",
        "+-------+",
    ];
    assert_batches_eq!(expected, &actual);
    Ok(())
}

/// Checks `IN (...)` filtering on a dictionary-encoded string column that
/// contains a null entry.
///
/// Registers a one-column table `test` whose `c1` column is a
/// `DictionaryArray<Int32Type>` built from `foo`, `bar`, a null and `fazzz`,
/// then runs `IN` filters with one, two and three literals. None of the
/// expected results contains the null row, and `'Bar'` never selects the
/// stored value `bar`.
#[tokio::test]
async fn in_list_string_dictionaries_with_null() -> Result<()> {
    let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")]
        .into_iter()
        .collect::<DictionaryArray<Int32Type>>();

    let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();

    let ctx = SessionContext::new();
    ctx.register_batch("test", batch)?;

    // Expected tables below are written with every cell padded to the column
    // width, so that each row is exactly as wide as its border line.

    // Only a differently-cased literal: nothing matches, empty table.
    let sql = "SELECT * FROM test WHERE c1 IN ('Bar')";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec!["++", "++"];
    assert_batches_eq!(expected, &actual);

    // Single matching literal.
    let sql = "SELECT * FROM test WHERE c1 IN ('foo')";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec!["+-----+", "| c1  |", "+-----+", "| foo |", "+-----+"];
    assert_batches_eq!(expected, &actual);

    // Two matching literals; rows come back in table order, not list order.
    let sql = "SELECT * FROM test WHERE c1 IN ('bar', 'foo')";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+-----+", "| c1  |", "+-----+", "| foo |", "| bar |", "+-----+",
    ];
    assert_batches_eq!(expected, &actual);

    // One matching and one non-matching literal.
    let sql = "SELECT * FROM test WHERE c1 IN ('Bar', 'foo')";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec!["+-----+", "| c1  |", "+-----+", "| foo |", "+-----+"];
    assert_batches_eq!(expected, &actual);

    // Three literals, two of which match.
    let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+-------+",
        "| c1    |",
        "+-------+",
        "| foo   |",
        "| fazzz |",
        "+-------+",
    ];
    assert_batches_eq!(expected, &actual);
    Ok(())
}

#[tokio::test]
async fn in_set_string_dictionaries() -> Result<()> {
let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")]
.into_iter()
Expand All @@ -440,7 +533,7 @@ async fn in_set_string_dictionaries() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_batch("test", batch)?;

let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazz')";
let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-------+",
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ hashbrown = { version = "0.12", features = ["raw"] }
itertools = { version = "0.10", features = ["use_std"] }
lazy_static = { version = "^1.4.0" }
md-5 = { version = "^0.10.0", optional = true }
num-traits = { version = "0.2", default-features = false }
ordered-float = "3.0"
paste = "^1.0"
rand = "0.8"
Expand Down
30 changes: 22 additions & 8 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ use crate::physical_expr::down_cast_any_ref;
use crate::utils::expr_list_eq_any_order;
use crate::PhysicalExpr;
use arrow::array::*;
use arrow::compute::take;
use arrow::datatypes::*;
use arrow::downcast_primitive_array;
use arrow::record_batch::RecordBatch;
use arrow::util::bit_iterator::BitIndexIterator;
use arrow::{downcast_dictionary_array, downcast_primitive_array};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use hashbrown::hash_map::RawEntryMut;
Expand All @@ -57,7 +58,7 @@ impl Debug for InListExpr {

/// A type-erased container of array elements
trait Set: Send + Sync {
    /// Tests the elements of `v` for membership in this set and returns the
    /// outcome as a [`BooleanArray`].
    ///
    /// `negated` is forwarded from the owning `InListExpr`; presumably it
    /// selects `NOT IN` semantics (TODO confirm against the implementation).
    ///
    /// # Errors
    ///
    /// The call is fallible so that implementations can propagate failures
    /// from the kernels they use; callers are expected to apply `?`.
    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
}

struct ArrayHashSet {
Expand Down Expand Up @@ -92,13 +93,22 @@ where
for<'a> &'a T: ArrayAccessor,
for<'a> <&'a T as ArrayAccessor>::Item: PartialEq + HashValue,
{
fn contains(&self, v: &dyn Array, negated: bool) -> BooleanArray {
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
downcast_dictionary_array! {
v => {
let values_contains = self.contains(v.values().as_ref(), negated)?;
let result = take(&values_contains, v.keys(), None)?;
return Ok(BooleanArray::from(result.data().clone()))
}
_ => {}
}

let v = v.as_any().downcast_ref::<T>().unwrap();
let in_data = self.array.data();
let in_array = &self.array;
let has_nulls = in_data.null_count() != 0;

ArrayIter::new(v)
Ok(ArrayIter::new(v)
.map(|v| {
v.and_then(|v| {
let hash = v.hash_one(&self.hash_set.state);
Expand All @@ -116,7 +126,7 @@ where
}
})
})
.collect()
.collect())
}
}

Expand Down Expand Up @@ -188,10 +198,12 @@ fn make_set(array: &dyn Array) -> Result<Box<dyn Set>> {
let array = as_generic_binary_array::<i64>(array);
Box::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::Dictionary(_, _) => unreachable!("dictionary should have been flattened"),
d => return Err(DataFusionError::NotImplemented(format!("DataType::{} not supported in InList", d)))
})
}

/// Evaluates the list of expressions into an array, flattening any dictionaries
fn evaluate_list(
list: &[Arc<dyn PhysicalExpr>],
batch: &RecordBatch,
Expand All @@ -203,6 +215,8 @@ fn evaluate_list(
ColumnarValue::Array(_) => Err(DataFusionError::Execution(
"InList expression must evaluate to a scalar".to_string(),
)),
// Flatten dictionary values
ColumnarValue::Scalar(ScalarValue::Dictionary(_, v)) => Ok(*v),
ColumnarValue::Scalar(s) => Ok(s),
})
})
Expand Down Expand Up @@ -286,10 +300,10 @@ impl PhysicalExpr for InListExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let value = self.expr.evaluate(batch)?.into_array(1);
let r = match &self.static_filter {
Some(f) => f.contains(value.as_ref(), self.negated),
Some(f) => f.contains(value.as_ref(), self.negated)?,
None => {
let list = evaluate_list(&self.list, batch)?;
make_set(list.as_ref())?.contains(value.as_ref(), self.negated)
make_set(list.as_ref())?.contains(value.as_ref(), self.negated)?
}
};
Ok(ColumnarValue::Array(Arc::new(r)))
Expand Down Expand Up @@ -947,7 +961,7 @@ mod tests {
let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();

let array = Int64Array::from(vec![1, 2, 3, 4]);
let r = result.contains(&array, false);
let r = result.contains(&array, false).unwrap();
assert_eq!(r, BooleanArray::from(vec![true, true, true, false]));

try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
Expand Down

0 comments on commit 9973b03

Please sign in to comment.