diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs index c34233972c0e..9d5facebc0c1 100644 --- a/datafusion/src/field_util.rs +++ b/datafusion/src/field_util.rs @@ -31,21 +31,19 @@ pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { if *i < 0 { - Err(DataFusionError::Plan( - format!("List based indexed access requires a positive int, was {0}", i), - )) + Err(DataFusionError::Plan(format!( + "List based indexed access requires a positive int, was {0}", + i + ))) } else { Ok(Field::new(&i.to_string(), lt.data_type().clone(), false)) } } - (DataType::List(_), _) => { - Err(DataFusionError::Plan( - "Only ints are valid as an indexed field in a list" - .to_string(), - )) - } + (DataType::List(_), _) => Err(DataFusionError::Plan( + "Only ints are valid as an indexed field in a list".to_string(), + )), _ => Err(DataFusionError::Plan( - "The expression to get an indexed field is only valid for `List` or 'Dictionary'" + "The expression to get an indexed field is only valid for `List` types" .to_string(), )), } diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index b91aef9077f3..fa140cee38f9 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -231,7 +231,7 @@ pub mod variable; pub use arrow; pub use parquet; -pub mod field_util; +pub(crate) mod field_util; #[cfg(test)] pub mod test; pub mod test_util; diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index fe91ba0c3ec8..499a8c720dba 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -246,7 +246,7 @@ pub enum Expr { IsNull(Box), /// arithmetic negation of an expression, the operand must be of a signed numeric data type Negative(Box), - /// Returns the field of a [`ListArray`] or ['DictionaryArray'] by name + /// Returns the field of a [`ListArray`] by key GetIndexedField { /// the expression to take the field from expr: Box, diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index 25ec41cfbd79..8a9191e9c346 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! get field of a struct array +//! get field of a `ListArray` +use std::convert::TryInto; use std::{any::Any, sync::Arc}; use arrow::{ @@ -80,25 +81,145 @@ impl PhysicalExpr for GetIndexedFieldExpr { let arg = self.arg.evaluate(batch)?; match arg { ColumnarValue::Array(array) => match (array.data_type(), &self.key) { + (DataType::List(_), _) if self.key.is_null() => { + let scalar_null: ScalarValue = array.data_type().try_into()?; + Ok(ColumnarValue::Scalar(scalar_null)) + } (DataType::List(_), ScalarValue::Int64(Some(i))) => { let as_list_array = array.as_any().downcast_ref::().unwrap(); - let x: Vec> = as_list_array + if as_list_array.is_empty() { + let scalar_null: ScalarValue = array.data_type().try_into()?; + return Ok(ColumnarValue::Scalar(scalar_null)) + } + let sliced_array: Vec> = as_list_array .iter() .filter_map(|o| o.map(|list| list.slice(*i as usize, 1))) .collect(); - let vec = x.iter().map(|a| a.as_ref()).collect::>(); + let vec = sliced_array.iter().map(|a| a.as_ref()).collect::>(); let iter = concat(vec.as_slice()).unwrap(); Ok(ColumnarValue::Array(iter)) } - (dt, _) => Err(DataFusionError::NotImplemented(format!( - "get indexed field is not implemented for {}", - dt - ))), + (dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. Tried {} with {} index", dt, key))), }, ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented( - "field is not yet implemented for scalar values".to_string(), + "field access is not yet implemented for scalar values".to_string(), )), } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use crate::physical_plan::expressions::{col, lit}; + use arrow::array::{ListBuilder, StringBuilder}; + use arrow::{array::StringArray, datatypes::Field}; + + fn get_indexed_field_test( + list_of_lists: Vec>>, + index: i64, + expected: Vec>, + ) -> Result<()> { + let schema = list_schema("l"); + let builder = StringBuilder::new(3); + let mut lb = ListBuilder::new(builder); + for values in list_of_lists { + let builder = lb.values(); + for value in values { + match value { + None => builder.append_null(), + Some(v) => builder.append_value(v), + } + .unwrap() + } + lb.append(true).unwrap(); + } + + let expr = col("l", &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; + + let key = ScalarValue::Int64(Some(index)); + let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to StringArray"); + let expected = &StringArray::from(expected); + assert_eq!(expected, result); + Ok(()) + } + + fn list_schema(col: &str) -> Schema { + Schema::new(vec![Field::new( + col, + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))), + true, + )]) + } + + #[test] + fn get_indexed_field_list() -> Result<()> { + let list_of_lists = vec![ + vec![Some("a"), Some("b"), None], + vec![None, Some("c"), Some("d")], + vec![Some("e"), None, Some("f")], + ]; + let expected_list = vec![ + vec![Some("a"), None, Some("e")], + vec![Some("b"), Some("c"), None], + vec![None, Some("d"), Some("f")], + ]; + + for (i, expected) in expected_list.into_iter().enumerate() { + get_indexed_field_test(list_of_lists.clone(), i as i64, expected)?; + } + Ok(()) + } + + #[test] + fn get_indexed_field_empty_list() -> Result<()> { + let schema = list_schema("l"); + let builder = StringBuilder::new(0); + let mut lb = ListBuilder::new(builder); + let expr = col("l", &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; + let key = ScalarValue::Int64(Some(0)); + let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + assert!(result.is_empty()); + Ok(()) + } + + fn get_indexed_field_test_failure( + schema: Schema, + expr: Arc, + key: ScalarValue, + expected: &str, + ) -> Result<()> { + let builder = StringBuilder::new(3); + let mut lb = ListBuilder::new(builder); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; + let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); + let r = expr.evaluate(&batch).map(|_| ()); + assert!(r.is_err()); + assert_eq!(format!("{}", r.unwrap_err()), expected); + Ok(()) + } + + #[test] + fn get_indexed_field_invalid_scalar() -> Result<()> { + let schema = list_schema("l"); + let expr = lit(ScalarValue::Utf8(Some("a".to_string()))); + get_indexed_field_test_failure(schema, expr, ScalarValue::Int64(Some(0)), "This feature is not implemented: field access is not yet implemented for scalar values") + } + + #[test] + fn get_indexed_field_invalid_list_index() -> Result<()> { + let schema = list_schema("l"); + let expr = col("l", &schema).unwrap(); + get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index") + } +}