Skip to content

Commit

Permalink
Add unit tests for the physical plan of get_indexed_field
Browse files Browse the repository at this point in the history
  • Loading branch information
Igosuki committed Oct 29, 2021
1 parent b2012b6 commit 24bac8b
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 20 deletions.
18 changes: 8 additions & 10 deletions datafusion/src/field_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,19 @@ pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result<Fiel
match (data_type, key) {
(DataType::List(lt), ScalarValue::Int64(Some(i))) => {
if *i < 0 {
Err(DataFusionError::Plan(
format!("List based indexed access requires a positive int, was {0}", i),
))
Err(DataFusionError::Plan(format!(
"List based indexed access requires a positive int, was {0}",
i
)))
} else {
Ok(Field::new(&i.to_string(), lt.data_type().clone(), false))
}
}
(DataType::List(_), _) => {
Err(DataFusionError::Plan(
"Only ints are valid as an indexed field in a list"
.to_string(),
))
}
(DataType::List(_), _) => Err(DataFusionError::Plan(
"Only ints are valid as an indexed field in a list".to_string(),
)),
_ => Err(DataFusionError::Plan(
"The expression to get an indexed field is only valid for `List` or 'Dictionary'"
"The expression to get an indexed field is only valid for `List` types"
.to_string(),
)),
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ pub mod variable;
pub use arrow;
pub use parquet;

pub mod field_util;
pub(crate) mod field_util;
#[cfg(test)]
pub mod test;
pub mod test_util;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ pub enum Expr {
IsNull(Box<Expr>),
/// arithmetic negation of an expression, the operand must be of a signed numeric data type
Negative(Box<Expr>),
/// Returns the field of a [`ListArray`] or ['DictionaryArray'] by name
/// Returns the field of a [`ListArray`] by key
GetIndexedField {
/// the expression to take the field from
expr: Box<Expr>,
Expand Down
137 changes: 129 additions & 8 deletions datafusion/src/physical_plan/expressions/get_indexed_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
// specific language governing permissions and limitations
// under the License.

//! get field of a struct array
//! get field of a `ListArray`
use std::convert::TryInto;
use std::{any::Any, sync::Arc};

use arrow::{
Expand Down Expand Up @@ -80,25 +81,145 @@ impl PhysicalExpr for GetIndexedFieldExpr {
let arg = self.arg.evaluate(batch)?;
match arg {
ColumnarValue::Array(array) => match (array.data_type(), &self.key) {
(DataType::List(_), _) if self.key.is_null() => {
let scalar_null: ScalarValue = array.data_type().try_into()?;
Ok(ColumnarValue::Scalar(scalar_null))
}
(DataType::List(_), ScalarValue::Int64(Some(i))) => {
let as_list_array =
array.as_any().downcast_ref::<ListArray>().unwrap();
let x: Vec<Arc<dyn Array>> = as_list_array
if as_list_array.is_empty() {
let scalar_null: ScalarValue = array.data_type().try_into()?;
return Ok(ColumnarValue::Scalar(scalar_null))
}
let sliced_array: Vec<Arc<dyn Array>> = as_list_array
.iter()
.filter_map(|o| o.map(|list| list.slice(*i as usize, 1)))
.collect();
let vec = x.iter().map(|a| a.as_ref()).collect::<Vec<&dyn Array>>();
let vec = sliced_array.iter().map(|a| a.as_ref()).collect::<Vec<&dyn Array>>();
let iter = concat(vec.as_slice()).unwrap();
Ok(ColumnarValue::Array(iter))
}
(dt, _) => Err(DataFusionError::NotImplemented(format!(
"get indexed field is not implemented for {}",
dt
))),
(dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. Tried {} with {} index", dt, key))),
},
ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented(
"field is not yet implemented for scalar values".to_string(),
"field access is not yet implemented for scalar values".to_string(),
)),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::error::Result;
use crate::physical_plan::expressions::{col, lit};
use arrow::array::{ListBuilder, StringBuilder};
use arrow::{array::StringArray, datatypes::Field};

fn get_indexed_field_test(
list_of_lists: Vec<Vec<Option<&str>>>,
index: i64,
expected: Vec<Option<&str>>,
) -> Result<()> {
let schema = list_schema("l");
let builder = StringBuilder::new(3);
let mut lb = ListBuilder::new(builder);
for values in list_of_lists {
let builder = lb.values();
for value in values {
match value {
None => builder.append_null(),
Some(v) => builder.append_value(v),
}
.unwrap()
}
lb.append(true).unwrap();
}

let expr = col("l", &schema).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?;

let key = ScalarValue::Int64(Some(index));
let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
.downcast_ref::<StringArray>()
.expect("failed to downcast to StringArray");
let expected = &StringArray::from(expected);
assert_eq!(expected, result);
Ok(())
}

fn list_schema(col: &str) -> Schema {
Schema::new(vec![Field::new(
col,
DataType::List(Box::new(Field::new("item", DataType::Utf8, true))),
true,
)])
}

#[test]
fn get_indexed_field_list() -> Result<()> {
let list_of_lists = vec![
vec![Some("a"), Some("b"), None],
vec![None, Some("c"), Some("d")],
vec![Some("e"), None, Some("f")],
];
let expected_list = vec![
vec![Some("a"), None, Some("e")],
vec![Some("b"), Some("c"), None],
vec![None, Some("d"), Some("f")],
];

for (i, expected) in expected_list.into_iter().enumerate() {
get_indexed_field_test(list_of_lists.clone(), i as i64, expected)?;
}
Ok(())
}

#[test]
fn get_indexed_field_empty_list() -> Result<()> {
let schema = list_schema("l");
let builder = StringBuilder::new(0);
let mut lb = ListBuilder::new(builder);
let expr = col("l", &schema).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?;
let key = ScalarValue::Int64(Some(0));
let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
assert!(result.is_empty());
Ok(())
}

fn get_indexed_field_test_failure(
schema: Schema,
expr: Arc<dyn PhysicalExpr>,
key: ScalarValue,
expected: &str,
) -> Result<()> {
let builder = StringBuilder::new(3);
let mut lb = ListBuilder::new(builder);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?;
let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
let r = expr.evaluate(&batch).map(|_| ());
assert!(r.is_err());
assert_eq!(format!("{}", r.unwrap_err()), expected);
Ok(())
}

#[test]
fn get_indexed_field_invalid_scalar() -> Result<()> {
let schema = list_schema("l");
let expr = lit(ScalarValue::Utf8(Some("a".to_string())));
get_indexed_field_test_failure(schema, expr, ScalarValue::Int64(Some(0)), "This feature is not implemented: field access is not yet implemented for scalar values")
}

#[test]
fn get_indexed_field_invalid_list_index() -> Result<()> {
let schema = list_schema("l");
let expr = col("l", &schema).unwrap();
get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index")
}
}

0 comments on commit 24bac8b

Please sign in to comment.