22 changes: 18 additions & 4 deletions datafusion/functions-nested/src/planner.rs
@@ -17,17 +17,20 @@

//! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`]

use std::sync::Arc;

use arrow::datatypes::DataType;
use datafusion_common::ExprSchema;
use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, ScalarFunction};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::AggregateUDF;
use datafusion_expr::{
planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
sqlparser, Expr, ExprSchemable, GetFieldAccess,
};
use datafusion_functions::core::get_field as get_field_inner;
use datafusion_functions::expr_fn::get_field;
use datafusion_functions_aggregate::nth_value::nth_value_udaf;
use std::sync::Arc;

use crate::map::map_udf;
use crate::{
@@ -140,7 +143,7 @@ impl ExprPlanner for FieldAccessPlanner {
fn plan_field_access(
&self,
expr: RawFieldAccessExpr,
_schema: &DFSchema,
schema: &DFSchema,
) -> Result<PlannerResult<RawFieldAccessExpr>> {
let RawFieldAccessExpr { expr, field_access } = expr;

@@ -173,6 +176,17 @@ impl ExprPlanner for FieldAccessPlanner {
null_treatment,
)),
)),
// special case for map access: plan indexing of a Map-typed column as get_field
Expr::Column(ref c)
if matches!(schema.data_type(c)?, DataType::Map(_, _)) =>
{
Ok(PlannerResult::Planned(Expr::ScalarFunction(
ScalarFunction::new_udf(
get_field_inner(),
vec![expr, *index],
),
)))
}
_ => Ok(PlannerResult::Planned(array_element(expr, *index))),
}
}
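A minimal illustration of what this planner change enables (hypothetical table name; exact plan text may differ): indexing a Map-typed column is now routed to get_field instead of array_element, so the map handling added in getfield.rs below performs the lookup.

-- assuming a table t whose column1 is a Map keyed by integers
-- before this change the access fell through to array_element(column1, 1);
-- with this change it is planned as get_field(column1, 1)
SELECT column1[1] FROM t;
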
111 changes: 68 additions & 43 deletions datafusion/functions/src/core/getfield.rs
@@ -16,9 +16,12 @@
// under the License.

use arrow::array::{
make_array, Array, Capacities, MutableArrayData, Scalar, StringArray,
make_array, make_comparator, Array, BooleanArray, Capacities, MutableArrayData,
Scalar,
};
use arrow::compute::SortOptions;
use arrow::datatypes::DataType;
use arrow_buffer::NullBuffer;
use datafusion_common::cast::{as_map_array, as_struct_array};
use datafusion_common::{
exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result,
@@ -104,26 +107,17 @@ impl ScalarUDFImpl for GetFieldFunc {

let name = match field_name {
Expr::Literal(name) => name,
_ => {
return exec_err!(
"get_field function requires the argument field_name to be a string"
);
}
other => &ScalarValue::Utf8(Some(other.schema_name().to_string())),
};

Ok(format!("{base}[{name}]"))
}

fn schema_name(&self, args: &[Expr]) -> Result<String> {
let [base, field_name] = take_function_args(self.name(), args)?;

let name = match field_name {
Expr::Literal(name) => name,
_ => {
return exec_err!(
"get_field function requires the argument field_name to be a string"
);
}
other => &ScalarValue::Utf8(Some(other.schema_name().to_string())),
};

Ok(format!("{}[{}]", base.schema_name(), name))
@@ -184,7 +178,6 @@ impl ScalarUDFImpl for GetFieldFunc {
let arrays =
ColumnarValue::values_to_arrays(&[base.clone(), field_name.clone()])?;
let array = Arc::clone(&arrays[0]);

let name = match field_name {
ColumnarValue::Scalar(name) => name,
_ => {
@@ -194,38 +187,70 @@
}
};

fn process_map_array(
array: Arc<dyn Array>,
key_array: Arc<dyn Array>,
) -> Result<ColumnarValue> {
let map_array = as_map_array(array.as_ref())?;
let keys = if key_array.data_type().is_nested() {
let comparator = make_comparator(
map_array.keys().as_ref(),
key_array.as_ref(),
SortOptions::default(),
)?;
let len = map_array.keys().len().min(key_array.len());
let values = (0..len).map(|i| comparator(i, i).is_eq()).collect();
let nulls =
NullBuffer::union(map_array.keys().nulls(), key_array.nulls());
BooleanArray::new(values, nulls)
} else {
let be_compared = Scalar::new(key_array);
arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())?
};

let original_data = map_array.entries().column(1).to_data();
let capacity = Capacities::Array(original_data.len());
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true, capacity);

for entry in 0..map_array.len() {
let start = map_array.value_offsets()[entry] as usize;
let end = map_array.value_offsets()[entry + 1] as usize;

let maybe_matched = keys
.slice(start, end - start)
.iter()
.enumerate()
.find(|(_, t)| t.unwrap());
Contributor: I don't understand why this panic is never called, but then I see you just moved the code.

if maybe_matched.is_none() {
mutable.extend_nulls(1);
continue;
}
let (match_offset, _) = maybe_matched.unwrap();
mutable.extend(0, start + match_offset, start + match_offset + 1);
}

let data = mutable.freeze();
let data = make_array(data);
Ok(ColumnarValue::Array(data))
}

match (array.data_type(), name) {
(DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
let map_array = as_map_array(array.as_ref())?;
let key_scalar: Scalar<arrow::array::GenericByteArray<arrow::datatypes::GenericStringType<i32>>> = Scalar::new(StringArray::from(vec![k.clone()]));
let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?;

// note that this array has more entries than the expected output/input size
// because map_array is flattened
let original_data = map_array.entries().column(1).to_data();
let capacity = Capacities::Array(original_data.len());
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true,
capacity);

for entry in 0..map_array.len(){
let start = map_array.value_offsets()[entry] as usize;
let end = map_array.value_offsets()[entry + 1] as usize;

let maybe_matched =
keys.slice(start, end-start).
iter().enumerate().
find(|(_, t)| t.unwrap());
if maybe_matched.is_none() {
mutable.extend_nulls(1);
continue
}
let (match_offset,_) = maybe_matched.unwrap();
mutable.extend(0, start + match_offset, start + match_offset + 1);
(DataType::Map(_, _), ScalarValue::List(arr)) => {
let key_array: Arc<dyn Array> = Arc::new((**arr).clone());
process_map_array(array, key_array)
}
(DataType::Map(_, _), ScalarValue::Struct(arr)) => {
process_map_array(array, Arc::clone(arr) as Arc<dyn Array>)
}
(DataType::Map(_, _), other) => {
let data_type = other.data_type();
if data_type.is_nested() {
exec_err!("unsupported type {:?} for map access", data_type)
} else {
process_map_array(array, other.to_array()?)
}
let data = mutable.freeze();
let data = make_array(data);
Ok(ColumnarValue::Array(data))
}
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
let as_struct_array = as_struct_array(&array)?;
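A rough sketch of the kind of query this new map path accepts (hypothetical table name; the sqllogictest additions below exercise the same patterns): keys are no longer limited to strings, and nested key types such as lists and structs are compared row by row with make_comparator rather than the scalar eq kernel.

-- assuming a table t whose column m is a Map keyed by arrays, e.g. MAP{[1,2,3]: 1}
SELECT m[make_array(1, 2, 3)] FROM t;
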
62 changes: 62 additions & 0 deletions datafusion/sqllogictest/test_files/map.slt
@@ -592,6 +592,43 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7)
[NULL] [NULL] [[1, NULL, 3]]
[NULL] [NULL] [NULL]

Contributor: Could you also please add some error tests for out-of-bounds access? For example, column[0], column[-1] and column[1000], to test the boundary conditions. Also, it seems like the code in this PR supports queries like select column1[column2], where column2 is an array of integers as well. Can you add a test for that syntax too?

query ?

select column1[1] from map_array_table_1;
----
[1, NULL, 3]
NULL
NULL
NULL

query ?
select column1[-1000 + 1001] from map_array_table_1;
----
[1, NULL, 3]
NULL
NULL
NULL

# out-of-bounds access should return NULL
query ?
SELECT column1[-1] FROM map_array_table_1;
----
NULL
NULL
NULL
NULL

query ?
SELECT column1[1000] FROM map_array_table_1;
----
NULL
NULL
NULL
NULL


query error DataFusion error: Arrow error: Invalid argument error
SELECT column1[NULL] FROM map_array_table_1;

query ???
select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_1;
----
@@ -722,3 +759,28 @@ drop table map_array_table_1;

statement ok
drop table map_array_table_2;


statement ok
create table tt as values(MAP{[1,2,3]:1}, MAP {{'a':1, 'b':2}:2}, MAP{true: 3});

# accessing using an array
query I
select column1[make_array(1, 2, 3)] from tt;
----
1

# accessing using a struct
query I
select column2[{a:1, b: 2}] from tt;
----
2

# accessing using Bool
query I
select column3[true] from tt;
----
3

statement ok
drop table tt;