From 644f86fdc2d791f3106c753ae83c97d90e217e5b Mon Sep 17 00:00:00 2001 From: Yanxin Xiang Date: Sun, 16 Feb 2025 22:37:07 -0800 Subject: [PATCH 1/3] Map access supports constant-resolvable expressions --- datafusion/functions-nested/src/planner.rs | 19 ++++- datafusion/functions/src/core/getfield.rs | 90 ++++++++++++---------- datafusion/sqllogictest/test_files/map.slt | 18 +++++ 3 files changed, 82 insertions(+), 45 deletions(-) diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index d55176a42c9a..b42f6e55901d 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -17,8 +17,8 @@ //! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`] -use std::sync::Arc; - +use arrow::datatypes::DataType; +use datafusion_common::ExprSchema; use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result}; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, ScalarFunction}; use datafusion_expr::AggregateUDF; @@ -26,8 +26,10 @@ use datafusion_expr::{ planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, sqlparser, Expr, ExprSchemable, GetFieldAccess, }; +use datafusion_functions::core::get_field as get_field_inner; use datafusion_functions::expr_fn::get_field; use datafusion_functions_aggregate::nth_value::nth_value_udaf; +use std::sync::Arc; use crate::map::map_udf; use crate::{ @@ -140,7 +142,7 @@ impl ExprPlanner for FieldAccessPlanner { fn plan_field_access( &self, expr: RawFieldAccessExpr, - _schema: &DFSchema, + schema: &DFSchema, ) -> Result> { let RawFieldAccessExpr { expr, field_access } = expr; @@ -173,6 +175,17 @@ impl ExprPlanner for FieldAccessPlanner { null_treatment, )), )), + // special case for map access with + Expr::Column(ref c) + if matches!(schema.data_type(c)?, DataType::Map(_, _)) => + { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf( + get_field_inner(), + vec![expr, *index], + ), + ))) + } _ => Ok(PlannerResult::Planned(array_element(expr, *index))), } } diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index d971001dbf78..66a325f21eaf 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::Int64Array; use arrow::array::{ make_array, Array, Capacities, MutableArrayData, Scalar, StringArray, }; @@ -104,11 +105,7 @@ impl ScalarUDFImpl for GetFieldFunc { let name = match field_name { Expr::Literal(name) => name, - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } + other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; Ok(format!("{base}[{name}]")) @@ -116,14 +113,9 @@ impl ScalarUDFImpl for GetFieldFunc { fn schema_name(&self, args: &[Expr]) -> Result { let [base, field_name] = take_function_args(self.name(), args)?; - let name = match field_name { Expr::Literal(name) => name, - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } + other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; Ok(format!("{}[{}]", base.schema_name(), name)) @@ -184,7 +176,6 @@ impl ScalarUDFImpl for GetFieldFunc { let arrays = ColumnarValue::values_to_arrays(&[base.clone(), field_name.clone()])?; let array = Arc::clone(&arrays[0]); - let name = match field_name { ColumnarValue::Scalar(name) => name, _ => { @@ -194,39 +185,54 @@ impl ScalarUDFImpl for GetFieldFunc { } }; + fn process_map_array( + array: Arc, + key_scalar: Scalar, + ) -> Result + where + K: Array + 'static, + { + let map_array = as_map_array(array.as_ref())?; + let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; + + let original_data = map_array.entries().column(1).to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); + + for entry in 0..map_array.len() { + let start = map_array.value_offsets()[entry] as usize; + let end = map_array.value_offsets()[entry + 1] as usize; + + let maybe_matched = keys + .slice(start, end - start) + .iter() + .enumerate() + .find(|(_, t)| t.unwrap()); + + if maybe_matched.is_none() { + mutable.extend_nulls(1); + continue; + } + let (match_offset, _) = maybe_matched.unwrap(); + mutable.extend(0, start + match_offset, start + match_offset + 1); + } + + let data = mutable.freeze(); + let data = make_array(data); + Ok(ColumnarValue::Array(data)) + } + match (array.data_type(), name) { (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { - let map_array = as_map_array(array.as_ref())?; - let key_scalar: Scalar>> = Scalar::new(StringArray::from(vec![k.clone()])); - let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; - - // note that this array has more entries than the expected output/input size - // because map_array is flattened - let original_data = map_array.entries().column(1).to_data(); - let capacity = Capacities::Array(original_data.len()); - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], true, - capacity); - - for entry in 0..map_array.len(){ - let start = map_array.value_offsets()[entry] as usize; - let end = map_array.value_offsets()[entry + 1] as usize; - - let maybe_matched = - keys.slice(start, end-start). - iter().enumerate(). - find(|(_, t)| t.unwrap()); - if maybe_matched.is_none() { - mutable.extend_nulls(1); - continue - } - let (match_offset,_) = maybe_matched.unwrap(); - mutable.extend(0, start + match_offset, start + match_offset + 1); - } - let data = mutable.freeze(); - let data = make_array(data); - Ok(ColumnarValue::Array(data)) + let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); + process_map_array::(array, key_scalar) + } + (DataType::Map(_, _), ScalarValue::Int64(Some(k))) => { + let key_scalar = Scalar::new(Int64Array::from(vec![*k])); + process_map_array::(array, key_scalar) } + (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { let as_struct_array = as_struct_array(&array)?; match as_struct_array.column_by_name(k) { diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 71296b6f6474..996d3f78adac 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -592,6 +592,24 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) [NULL] [NULL] [[1, NULL, 3]] [NULL] [NULL] [NULL] +query ? +select column1[1] from map_array_table_1; +---- +[1, NULL, 3] +NULL +NULL +NULL + +query ? +select column1[-1000 + 1001] from map_array_table_1; +---- +[1, NULL, 3] +NULL +NULL +NULL + + + query ??? select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_1; ---- From 1265bba5fecf1cec7ffbc241fcc21033aff1ee62 Mon Sep 17 00:00:00 2001 From: Yanxin Xiang Date: Mon, 17 Feb 2025 13:47:06 -0800 Subject: [PATCH 2/3] adding tests fix clippy fix clippy fix clippy --- datafusion/functions-nested/src/planner.rs | 3 +- datafusion/functions/src/core/getfield.rs | 49 ++++++++++++++++++---- datafusion/sqllogictest/test_files/map.slt | 44 +++++++++++++++++++ 3 files changed, 86 insertions(+), 10 deletions(-) diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index b42f6e55901d..369eaecb1905 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -20,7 +20,8 @@ use arrow::datatypes::DataType; use datafusion_common::ExprSchema; use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result}; -use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, ScalarFunction}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; use datafusion_expr::AggregateUDF; use datafusion_expr::{ planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 66a325f21eaf..47014c7ce6fb 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::Int64Array; use arrow::array::{ - make_array, Array, Capacities, MutableArrayData, Scalar, StringArray, + make_array, make_comparator, Array, BooleanArray, Capacities, Datum, + MutableArrayData, Scalar, StringArray, StructArray, }; +use arrow::array::{Int64Array, ListArray}; +use arrow::compute::SortOptions; use arrow::datatypes::DataType; +use arrow_buffer::NullBuffer; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result, @@ -187,13 +190,27 @@ impl ScalarUDFImpl for GetFieldFunc { fn process_map_array( array: Arc, - key_scalar: Scalar, + key_array: Arc, ) -> Result where K: Array + 'static, { let map_array = as_map_array(array.as_ref())?; - let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; + let keys = if key_array.data_type().is_nested() { + let comparator = make_comparator( + map_array.keys().as_ref(), + key_array.as_ref(), + SortOptions::default(), + )?; + let len = map_array.keys().len().min(key_array.len()); + let values = (0..len).map(|i| comparator(i, i).is_eq()).collect(); + let nulls = + NullBuffer::union(map_array.keys().nulls(), key_array.nulls()); + BooleanArray::new(values, nulls) + } else { + let be_compared = Scalar::new(key_array); + arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())? + }; let original_data = map_array.entries().column(1).to_data(); let capacity = Capacities::Array(original_data.len()); @@ -225,14 +242,28 @@ impl ScalarUDFImpl for GetFieldFunc { match (array.data_type(), name) { (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { - let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); - process_map_array::(array, key_scalar) + let key_array: Arc = Arc::new(StringArray::from(vec![k.clone()])); + process_map_array::(array, key_array) } (DataType::Map(_, _), ScalarValue::Int64(Some(k))) => { - let key_scalar = Scalar::new(Int64Array::from(vec![*k])); - process_map_array::(array, key_scalar) + let key_array: Arc = Arc::new(Int64Array::from(vec![*k])); + process_map_array::(array, key_array) + } + (DataType::Map(_, _), ScalarValue::List(arr)) => { + let key_array: Arc = Arc::new((**arr).clone()); + process_map_array::(array, key_array) + } + (DataType::Map(_, _), ScalarValue::Struct(arr)) => { + process_map_array::(array, Arc::new(arr.clone() as Arc)) + } + (DataType::Map(_, _), other) => { + let data_type = other.data_type(); + if data_type.is_nested() { + return exec_err!("unsupported type {:?} for map access", data_type); + } else { + process_map_array::>(array, other.to_array()?) + } } - (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { let as_struct_array = as_struct_array(&array)?; match as_struct_array.column_by_name(k) { diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 996d3f78adac..42a4ba621801 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -608,7 +608,26 @@ NULL NULL NULL +# test for negative scenario +query ? +SELECT column1[-1] FROM map_array_table_1; +---- +NULL +NULL +NULL +NULL +query ? +SELECT column1[1000] FROM map_array_table_1; +---- +NULL +NULL +NULL +NULL + + +query error DataFusion error: Arrow error: Invalid argument error +SELECT column1[NULL] FROM map_array_table_1; query ??? select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_1; @@ -740,3 +759,28 @@ drop table map_array_table_1; statement ok drop table map_array_table_2; + + +statement ok +create table tt as values(MAP{[1,2,3]:1}, MAP {{'a':1, 'b':2}:2}, MAP{true: 3}); + +# accessing using an array +query I +select column1[make_array(1, 2, 3)] from tt; +---- +1 + +# accessing using a struct +query I +select column2[{a:1, b: 2}] from tt; +---- +2 + +# accessing using Bool +query I +select column3[true] from tt; +---- +3 + +statement ok +drop table tt; From 2d64c9101cc37bfe4e438c26c2e1c73d81953f0a Mon Sep 17 00:00:00 2001 From: Yanxin Xiang Date: Mon, 17 Feb 2025 14:37:41 -0800 Subject: [PATCH 3/3] fix clippy --- datafusion/functions/src/core/getfield.rs | 28 +++++++---------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 47014c7ce6fb..7c196d0ba69e 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -16,10 +16,9 @@ // under the License. use arrow::array::{ - make_array, make_comparator, Array, BooleanArray, Capacities, Datum, - MutableArrayData, Scalar, StringArray, StructArray, + make_array, make_comparator, Array, BooleanArray, Capacities, MutableArrayData, + Scalar, }; -use arrow::array::{Int64Array, ListArray}; use arrow::compute::SortOptions; use arrow::datatypes::DataType; use arrow_buffer::NullBuffer; @@ -188,13 +187,10 @@ impl ScalarUDFImpl for GetFieldFunc { } }; - fn process_map_array( + fn process_map_array( array: Arc, key_array: Arc, - ) -> Result - where - K: Array + 'static, - { + ) -> Result { let map_array = as_map_array(array.as_ref())?; let keys = if key_array.data_type().is_nested() { let comparator = make_comparator( @@ -241,27 +237,19 @@ impl ScalarUDFImpl for GetFieldFunc { } match (array.data_type(), name) { - (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { - let key_array: Arc = Arc::new(StringArray::from(vec![k.clone()])); - process_map_array::(array, key_array) - } - (DataType::Map(_, _), ScalarValue::Int64(Some(k))) => { - let key_array: Arc = Arc::new(Int64Array::from(vec![*k])); - process_map_array::(array, key_array) - } (DataType::Map(_, _), ScalarValue::List(arr)) => { let key_array: Arc = Arc::new((**arr).clone()); - process_map_array::(array, key_array) + process_map_array(array, key_array) } (DataType::Map(_, _), ScalarValue::Struct(arr)) => { - process_map_array::(array, Arc::new(arr.clone() as Arc)) + process_map_array(array, Arc::clone(arr) as Arc) } (DataType::Map(_, _), other) => { let data_type = other.data_type(); if data_type.is_nested() { - return exec_err!("unsupported type {:?} for map access", data_type); + exec_err!("unsupported type {:?} for map access", data_type) } else { - process_map_array::>(array, other.to_array()?) + process_map_array(array, other.to_array()?) } } (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {