diff --git a/datafusion/core/src/logical_plan/expr_rewriter.rs b/datafusion/core/src/logical_plan/expr_rewriter.rs index 20e9e598cf02..d6cf4c08ba94 100644 --- a/datafusion/core/src/logical_plan/expr_rewriter.rs +++ b/datafusion/core/src/logical_plan/expr_rewriter.rs @@ -228,7 +228,7 @@ impl ExprRewritable for Expr { } Expr::GetIndexedField { expr, key } => Expr::GetIndexedField { expr: rewrite_boxed(expr, rewriter)?, - key, + key: rewrite_boxed(key, rewriter)?, }, }; diff --git a/datafusion/core/src/logical_plan/expr_visitor.rs b/datafusion/core/src/logical_plan/expr_visitor.rs index 9d6697fc762e..9a771fcfa35b 100644 --- a/datafusion/core/src/logical_plan/expr_visitor.rs +++ b/datafusion/core/src/logical_plan/expr_visitor.rs @@ -101,8 +101,11 @@ impl ExprVisitable for Expr { | Expr::Negative(expr) | Expr::Cast { expr, .. } | Expr::TryCast { expr, .. } - | Expr::Sort { expr, .. } - | Expr::GetIndexedField { expr, .. } => expr.accept(visitor), + | Expr::Sort { expr, .. } => expr.accept(visitor), + Expr::GetIndexedField { expr, key } => { + let visitor = expr.accept(visitor)?; + key.accept(visitor) + } Expr::Column(_) | Expr::OuterColumn(_, _) | Expr::ScalarVariable(_, _) diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs index 31c2070a6aba..1748f497d540 100644 --- a/datafusion/core/src/optimizer/utils.rs +++ b/datafusion/core/src/optimizer/utils.rs @@ -312,8 +312,10 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { | Expr::Alias(expr, ..) | Expr::Not(expr) | Expr::Negative(expr) - | Expr::Sort { expr, .. } - | Expr::GetIndexedField { expr, .. } => Ok(vec![expr.as_ref().to_owned()]), + | Expr::Sort { expr, .. } => Ok(vec![expr.as_ref().to_owned()]), + Expr::GetIndexedField { expr, key } => { + Ok(vec![expr.as_ref().to_owned(), key.as_ref().to_owned()]) + } Expr::ScalarFunction { args, .. } | Expr::ScalarUDF { args, .. } | Expr::TableUDF { args, .. } @@ -547,9 +549,9 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { "QualifiedWildcard expressions are not valid in a logical query plan" .to_owned(), )), - Expr::GetIndexedField { expr: _, key } => Ok(Expr::GetIndexedField { + Expr::GetIndexedField { .. } => Ok(Expr::GetIndexedField { expr: Box::new(expressions[0].clone()), - key: key.clone(), + key: Box::new(expressions[1].clone()), }), } } diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 37aa7c079dc4..9263d7b3c546 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -156,6 +156,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { } Expr::GetIndexedField { expr, key } => { let expr = create_physical_name(expr, false)?; + let key = create_physical_name(key, false)?; Ok(format!("{}[{}]", expr, key)) } Expr::ScalarFunction { fun, args, .. } => { @@ -1093,7 +1094,7 @@ pub fn create_physical_expr( )?), Expr::GetIndexedField { expr, key } => Ok(Arc::new(GetIndexedFieldExpr::new( create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, - key.clone(), + create_physical_expr(key, input_dfschema, input_schema, execution_props)?, ))), Expr::ScalarFunction { fun, args } => { diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index 310cb6c2b90e..b6ee6c8eccca 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -133,25 +133,7 @@ impl SqlToRelContext { } } -fn plan_key(key: SQLExpr) -> Result { - let scalar = match key { - SQLExpr::Value(Value::Number(s, _)) => { - ScalarValue::Int64(Some(s.parse().unwrap())) - } - SQLExpr::Value(Value::SingleQuotedString(s)) => ScalarValue::Utf8(Some(s)), - SQLExpr::Identifier(ident) => ScalarValue::Utf8(Some(ident.value)), - _ => { - return Err(DataFusionError::SQL(ParserError(format!( - "Unsuported index key expression: {:?}", - key - )))) - } - }; - - Ok(scalar) -} - -fn plan_indexed(expr: Expr, mut keys: Vec) -> Result { +fn plan_indexed(expr: Expr, mut keys: Vec) -> Result { let key = keys.pop().ok_or_else(|| { DataFusionError::SQL(ParserError( "Internal error: Missing index key expression".to_string(), @@ -166,7 +148,7 @@ fn plan_indexed(expr: Expr, mut keys: Vec) -> Result { Ok(Expr::GetIndexedField { expr: Box::new(expr), - key: plan_key(key)?, + key: Box::new(key), }) } @@ -1704,26 +1686,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - SQLExpr::MapAccess { ref column, keys } => { - if let SQLExpr::Identifier(ref id) = column.as_ref() { - plan_indexed(col(&id.value), keys) - } else { - Err(DataFusionError::NotImplemented(format!( - "map access requires an identifier, found column {} instead", - column - ))) - } - } - SQLExpr::ArrayIndex { obj, indexs } => { - if let SQLExpr::Identifier(ref id) = obj.as_ref() { - plan_indexed(col(&id.value), indexs) - } else { - Err(DataFusionError::NotImplemented(format!( - "array index access requires an identifier, found column {} instead", - obj - ))) - } + let expr = self.sql_expr_to_logical_expr(*obj, schema)?; + + plan_indexed(expr, indexs.into_iter() + .map(|e| self.sql_expr_to_logical_expr(e, schema)) + .collect::>>()?) } SQLExpr::CompoundIdentifier(ids) => { @@ -1754,7 +1722,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Access to a field of a column which is a structure, example: SELECT my_struct.key Ok(Expr::GetIndexedField { expr: Box::new(Expr::Column(field.qualified_column())), - key: ScalarValue::Utf8(Some(name)), + key: Box::new(Expr::Literal(ScalarValue::Utf8(Some(name)))), }) } else { // table.column identifier @@ -2104,7 +2072,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::DotExpr { expr, field } => { Ok(Expr::GetIndexedField { expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema)?), - key: ScalarValue::Utf8(Some(field.value)), + key: Box::new(Expr::Literal(ScalarValue::Utf8(Some(field.value)))), }) } diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs index b35ff3d16d64..2614664552e7 100644 --- a/datafusion/core/src/sql/utils.rs +++ b/datafusion/core/src/sql/utils.rs @@ -391,7 +391,7 @@ where Expr::QualifiedWildcard { .. } => Ok(expr.clone()), Expr::GetIndexedField { expr, key } => Ok(Expr::GetIndexedField { expr: Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?), - key: key.clone(), + key: Box::new(clone_with_replacement(key.as_ref(), replacement_fn)?), }), }, } diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 0526b6c32da5..62298ef66e6f 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -596,12 +596,92 @@ async fn query_nested_get_indexed_field() -> Result<()> { "+----------+", ]; assert_batches_eq!(expected, &actual); + + // nested with scalar values let sql = "SELECT some_list[0][0] as i0 FROM ints LIMIT 3"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+", ]; assert_batches_eq!(expected, &actual); + + // nested with dynamic expr in key + assert_batches_eq!(expected, &actual); + let sql = "SELECT some_list[1 - 1][1 - 1] as i0 FROM ints LIMIT 3"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn query_get_indexed_array_dynamic_key() -> Result<()> { + let ctx = SessionContext::new(); + + let list_dt = Box::new(Field::new("item", DataType::Int64, true)); + let schema = Arc::new(Schema::new(vec![ + Field::new("arr", DataType::List(list_dt), false), + Field::new("key", DataType::Int64, false), + ])); + + let array_ints_builder = PrimitiveBuilder::::new(3); + let mut arr_builder = ListBuilder::new(array_ints_builder); + let mut key_builder = PrimitiveBuilder::::new(3); + + for (int_vec, key) in vec![ + (vec![0, 1, 2, 3], 1), + (vec![4, 5, 6, 7], 2), + (vec![8, 9, 10, 11], 3), + ] { + for n in int_vec { + arr_builder.values().append_value(n)?; + } + + key_builder.append_value(key)?; + arr_builder.append(true)?; + } + + let data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arr_builder.finish()), + Arc::new(key_builder.finish()), + ], + )?; + let table = MemTable::try_new(schema, vec![vec![data]])?; + let table_a = Arc::new(table); + + ctx.register_table("array_and_keys", table_a)?; + + let sql = "SELECT arr[key], key FROM array_and_keys"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----------------------------------------+-----+", + "| array_and_keys.arr[array_and_keys.key] | key |", + "+----------------------------------------+-----+", + "| 0 | 1 |", + "| 4 | 2 |", + "| 8 | 3 |", + "+----------------------------------------+-----+", + ]; + assert_batches_eq!(expected, &actual); + + // All dynamic + let sql = "SELECT r.value[r.key] FROM (SELECT array[1,2,3] as value, 1 as key UNION ALL SELECT array[4,5,6] as value, 2 as key) as r"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----------------+", + "| r.value[r.key] |", + "+----------------+", + "| 1 |", + "| 5 |", + "+----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) } @@ -634,7 +714,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { ctx.register_table("structs", table_a)?; // Original column is micros, convert to millis and check timestamp - let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3"; + let sql = "SELECT some_struct['bar'] as l0 FROM structs LIMIT 3"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------+", @@ -661,7 +741,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { ]; assert_batches_eq!(expected, &actual); - let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3"; + let sql = "SELECT some_struct['bar'][0] as i0 FROM structs LIMIT 3"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 8 |", "+----+", diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 519ef7109bbb..5e624d005ed7 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -114,7 +114,7 @@ pub enum Expr { /// the expression to take the field from expr: Box, /// The name of the field to take - key: ScalarValue, + key: Box, }, /// Whether an expression is between a given range. Between { diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 26a5cf2034a0..e4ba11d19d5d 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -17,10 +17,14 @@ //! get field of a `ListArray` -use crate::{field_util::get_indexed_field as get_data_type_field, PhysicalExpr}; -use arrow::array::Array; +use crate::expressions::Literal; +use crate::PhysicalExpr; +use arrow::array::{ + Array, Int64Array, StringArray, UInt16Array, UInt32Array, UInt64Array, +}; use arrow::array::{ListArray, StructArray}; use arrow::compute::concat; +use arrow::datatypes::Field; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -37,12 +41,12 @@ use std::{any::Any, sync::Arc}; #[derive(Debug)] pub struct GetIndexedFieldExpr { arg: Arc, - key: ScalarValue, + key: Arc, } impl GetIndexedFieldExpr { /// Create new get field expression - pub fn new(arg: Arc, key: ScalarValue) -> Self { + pub fn new(arg: Arc, key: Arc) -> Self { Self { arg, key } } @@ -50,6 +54,41 @@ impl GetIndexedFieldExpr { pub fn arg(&self) -> &Arc { &self.arg } + + fn get_data_type_field(&self, input_schema: &Schema) -> Result { + let data_type = self.arg.data_type(input_schema)?; + match data_type { + DataType::Struct(fields) => { + if let Some(key_lit) = self.key.as_any().downcast_ref::() { + match key_lit.value() { + ScalarValue::Utf8(Some(v)) => { + let field = fields.iter().find(|f| f.name() == v); + match field { + None => return Err(DataFusionError::Execution(format!( + "Field {} not found in struct", + v + ))), + Some(f) => return Ok(f.clone()), + } + }, + _ => {}, + } + } + + Err(DataFusionError::Execution(format!( + "Only utf8 strings are valid as an indexed field in a struct, actual: {:?}", + self.key + ))) + }, + DataType::List(lt) => { + Ok(Field::new("unknown", lt.data_type().clone(), false)) + }, + _ => Err(DataFusionError::Plan( + "The expression to get an indexed field is only valid for `List` and `Struct` types" + .to_string(), + )), + } + } } impl std::fmt::Display for GetIndexedFieldExpr { @@ -64,20 +103,21 @@ impl PhysicalExpr for GetIndexedFieldExpr { } fn data_type(&self, input_schema: &Schema) -> Result { - let data_type = self.arg.data_type(input_schema)?; - get_data_type_field(&data_type, &self.key).map(|f| f.data_type().clone()) + self.get_data_type_field(input_schema) + .map(|f| f.data_type().clone()) } fn nullable(&self, input_schema: &Schema) -> Result { - let data_type = self.arg.data_type(input_schema)?; - get_data_type_field(&data_type, &self.key).map(|f| f.is_nullable()) + self.get_data_type_field(input_schema) + .map(|f| f.is_nullable()) } fn evaluate(&self, batch: &RecordBatch) -> Result { - let arg = self.arg.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => match (array.data_type(), &self.key) { - (DataType::List(_) | DataType::Struct(_), _) if self.key.is_null() => { + let left = self.arg.evaluate(batch)?; + let right = self.key.evaluate(batch)?; + match (left, right) { + (ColumnarValue::Array(array), ColumnarValue::Scalar(key)) => match (array.data_type(), &key) { + (DataType::List(_) | DataType::Struct(_), _) if key.is_null() => { let scalar_null: ScalarValue = array.data_type().try_into()?; Ok(ColumnarValue::Scalar(scalar_null)) } @@ -98,14 +138,63 @@ impl PhysicalExpr for GetIndexedFieldExpr { } (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { let as_struct_array = array.as_any().downcast_ref::().unwrap(); - match as_struct_array.column_by_name(k) { + match as_struct_array.column_by_name(&k) { None => Err(DataFusionError::Execution(format!("get indexed field {} not found in struct", k))), Some(col) => Ok(ColumnarValue::Array(col.clone())) } } (dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. Tried {} with {} index", dt, key))), }, - ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented( + (ColumnarValue::Array(array), ColumnarValue::Array(wrapper)) => match (array.data_type(), wrapper.data_type()) { + (DataType::List(_), _) if wrapper.is_null(0) => { + let scalar_null: ScalarValue = array.data_type().try_into()?; + Ok(ColumnarValue::Scalar(scalar_null)) + }, + (DataType::List(_), DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64) => { + let as_list_array = + array.as_any().downcast_ref::().unwrap(); + + if as_list_array.is_empty() { + let scalar_null: ScalarValue = array.data_type().try_into()?; + return Ok(ColumnarValue::Scalar(scalar_null)) + } + + let key = match wrapper.data_type() { + DataType::Int16 => wrapper.as_any().downcast_ref::().unwrap().value(0) as usize, + DataType::Int32 => wrapper.as_any().downcast_ref::().unwrap().value(0) as usize, + DataType::Int64 => wrapper.as_any().downcast_ref::().unwrap().value(0) as usize, + DataType::UInt16 => wrapper.as_any().downcast_ref::().unwrap().value(0) as usize, + DataType::UInt32 => wrapper.as_any().downcast_ref::().unwrap().value(0) as usize, + DataType::UInt64 => wrapper.as_any().downcast_ref::().unwrap().value(0) as usize, + _ => unreachable!(), + }; + let key = key - 1; + + let sliced_array: Vec> = as_list_array + .iter() + .filter_map(|o| o.map(|list| list.slice(key, 1))) + .collect(); + let vec = sliced_array.iter().map(|a| a.as_ref()).collect::>(); + let iter = concat(vec.as_slice()).unwrap(); + Ok(ColumnarValue::Array(iter)) + }, + (DataType::Struct(_), DataType::Utf8) => { + let key = match wrapper.data_type() { + DataType::Utf8 => wrapper.as_any().downcast_ref::().unwrap().value(0), + _ => unreachable!(), + }; + + let as_struct_array = array.as_any().downcast_ref::().unwrap(); + match as_struct_array.column_by_name(&key) { + None => Err(DataFusionError::Execution(format!("get indexed field {} not found in struct", key))), + Some(col) => Ok(ColumnarValue::Array(col.clone())) + } + } + (DataType::List(_), key) => Err(DataFusionError::NotImplemented(format!("list field access is only possible with integers indexes. Tried with {} index", key))), + (DataType::Struct(_), key) => Err(DataFusionError::NotImplemented(format!("struct field access is only possible with utf8 literals indexes. Tried with {} index", key))), + (ldt, rdt) => Err(DataFusionError::Internal(format!("field access is only possible with struct/list. Tried to access {} with {} index", ldt, rdt))), + }, + (ColumnarValue::Scalar(_), _) => Err(DataFusionError::NotImplemented( "field access is not yet implemented for scalar values".to_string(), )), } @@ -150,7 +239,7 @@ mod tests { let list_col = build_utf8_lists(list_of_lists); let expr = col("l", &schema).unwrap(); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_col)])?; - let key = ScalarValue::Int64(Some(index)); + let key = lit(ScalarValue::Int64(Some(index))); let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); let result = result @@ -196,7 +285,7 @@ mod tests { let mut lb = ListBuilder::new(builder); let expr = col("l", &schema).unwrap(); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; - let key = ScalarValue::Int64(Some(0)); + let key = lit(ScalarValue::Int64(Some(0))); let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); assert!(result.is_empty()); @@ -206,7 +295,7 @@ mod tests { fn get_indexed_field_test_failure( schema: Schema, expr: Arc, - key: ScalarValue, + key: Arc, expected: &str, ) -> Result<()> { let builder = StringBuilder::new(3); @@ -223,14 +312,14 @@ mod tests { fn get_indexed_field_invalid_scalar() -> Result<()> { let schema = list_schema("l"); let expr = lit(ScalarValue::Utf8(Some("a".to_string()))); - get_indexed_field_test_failure(schema, expr, ScalarValue::Int64(Some(0)), "This feature is not implemented: field access is not yet implemented for scalar values") + get_indexed_field_test_failure(schema, expr, lit(ScalarValue::Int64(Some(0))), "This feature is not implemented: field access is not yet implemented for scalar values") } #[test] fn get_indexed_field_invalid_list_index() -> Result<()> { let schema = list_schema("l"); let expr = col("l", &schema).unwrap(); - get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index") + get_indexed_field_test_failure(schema, expr, lit(ScalarValue::Int8(Some(0))), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index") } fn build_struct( @@ -291,7 +380,7 @@ mod tests { let struct_col_expr = col("s", &schema).unwrap(); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_col)])?; - let int_field_key = ScalarValue::Utf8(Some("foo".to_string())); + let int_field_key = lit(ScalarValue::Utf8(Some("foo".to_string()))); let get_field_expr = Arc::new(GetIndexedFieldExpr::new( struct_col_expr.clone(), int_field_key, @@ -306,7 +395,7 @@ mod tests { let expected = &Int64Array::from(expected_ints); assert_eq!(expected, result); - let list_field_key = ScalarValue::Utf8(Some("bar".to_string())); + let list_field_key = lit(ScalarValue::Utf8(Some("bar".to_string()))); let get_list_expr = Arc::new(GetIndexedFieldExpr::new(struct_col_expr, list_field_key)); let result = get_list_expr.evaluate(&batch)?.into_array(batch.num_rows()); @@ -321,7 +410,7 @@ mod tests { for (i, expected) in expected_strings.into_iter().enumerate() { let get_nested_str_expr = Arc::new(GetIndexedFieldExpr::new( get_list_expr.clone(), - ScalarValue::Int64(Some(i as i64)), + lit(ScalarValue::Int64(Some(i as i64))), )); let result = get_nested_str_expr .evaluate(&batch)? diff --git a/datafusion/physical-expr/src/field_util.rs b/datafusion/physical-expr/src/field_util.rs index 2c9411e875d4..389f9ef83e97 100644 --- a/datafusion/physical-expr/src/field_util.rs +++ b/datafusion/physical-expr/src/field_util.rs @@ -20,15 +20,16 @@ use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::Expr; /// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Struct`] /// # Error /// Errors if /// * the `data_type` is not a Struct or, /// * there is no field key is not of the required index type -pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { - match (data_type, key) { - (DataType::List(lt), ScalarValue::Int64(Some(i))) => { +pub fn get_indexed_field(data_type: &DataType, key: &Box) -> Result { + match (data_type, &**key) { + (DataType::List(lt), Expr::Literal(ScalarValue::Int64(Some(i)))) => { if *i < 0 { Err(DataFusionError::Plan(format!( "List based indexed access requires a positive int, was {0}", @@ -38,7 +39,11 @@ pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { + // Allow any kind of dynamic expressions for key + (DataType::List(lt),_) => { + Ok(Field::new("unknown", lt.data_type().clone(), false)) + } + (DataType::Struct(fields), Expr::Literal(ScalarValue::Utf8(Some(s)))) => { if s.is_empty() { Err(DataFusionError::Plan( "Struct based indexed access requires a non empty string".to_string(), @@ -54,14 +59,12 @@ pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result Err(DataFusionError::Plan( - "Only utf8 strings are valid as an indexed field in a struct".to_string(), - )), - (DataType::List(_), _) => Err(DataFusionError::Plan( - "Only ints are valid as an indexed field in a list".to_string(), - )), + (DataType::Struct(_), key) => Err(DataFusionError::Plan(format!( + "Only utf8 strings are valid as an indexed field in a struct, actual: {}", + key + ))), _ => Err(DataFusionError::Plan( - "The expression to get an indexed field is only valid for `List` types" + "The expression to get an indexed field is only valid for `List` and `Struct` types" .to_string(), )), }