-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Map access supports constant-resolvable expressions #14712
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,9 +16,12 @@ | |
// under the License. | ||
|
||
use arrow::array::{ | ||
make_array, Array, Capacities, MutableArrayData, Scalar, StringArray, | ||
make_array, make_comparator, Array, BooleanArray, Capacities, MutableArrayData, | ||
Scalar, | ||
}; | ||
use arrow::compute::SortOptions; | ||
use arrow::datatypes::DataType; | ||
use arrow_buffer::NullBuffer; | ||
use datafusion_common::cast::{as_map_array, as_struct_array}; | ||
use datafusion_common::{ | ||
exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result, | ||
|
@@ -104,26 +107,17 @@ impl ScalarUDFImpl for GetFieldFunc { | |
|
||
let name = match field_name { | ||
Expr::Literal(name) => name, | ||
_ => { | ||
return exec_err!( | ||
"get_field function requires the argument field_name to be a string" | ||
); | ||
} | ||
other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), | ||
}; | ||
|
||
Ok(format!("{base}[{name}]")) | ||
} | ||
|
||
fn schema_name(&self, args: &[Expr]) -> Result<String> { | ||
let [base, field_name] = take_function_args(self.name(), args)?; | ||
|
||
let name = match field_name { | ||
Expr::Literal(name) => name, | ||
_ => { | ||
return exec_err!( | ||
"get_field function requires the argument field_name to be a string" | ||
); | ||
} | ||
other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), | ||
}; | ||
|
||
Ok(format!("{}[{}]", base.schema_name(), name)) | ||
|
@@ -184,7 +178,6 @@ impl ScalarUDFImpl for GetFieldFunc { | |
let arrays = | ||
ColumnarValue::values_to_arrays(&[base.clone(), field_name.clone()])?; | ||
let array = Arc::clone(&arrays[0]); | ||
|
||
let name = match field_name { | ||
ColumnarValue::Scalar(name) => name, | ||
_ => { | ||
|
@@ -194,38 +187,70 @@ impl ScalarUDFImpl for GetFieldFunc { | |
} | ||
}; | ||
|
||
fn process_map_array( | ||
array: Arc<dyn Array>, | ||
key_array: Arc<dyn Array>, | ||
) -> Result<ColumnarValue> { | ||
let map_array = as_map_array(array.as_ref())?; | ||
let keys = if key_array.data_type().is_nested() { | ||
let comparator = make_comparator( | ||
map_array.keys().as_ref(), | ||
key_array.as_ref(), | ||
SortOptions::default(), | ||
)?; | ||
let len = map_array.keys().len().min(key_array.len()); | ||
let values = (0..len).map(|i| comparator(i, i).is_eq()).collect(); | ||
let nulls = | ||
NullBuffer::union(map_array.keys().nulls(), key_array.nulls()); | ||
BooleanArray::new(values, nulls) | ||
} else { | ||
let be_compared = Scalar::new(key_array); | ||
arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())? | ||
}; | ||
|
||
let original_data = map_array.entries().column(1).to_data(); | ||
let capacity = Capacities::Array(original_data.len()); | ||
let mut mutable = | ||
MutableArrayData::with_capacities(vec![&original_data], true, capacity); | ||
|
||
for entry in 0..map_array.len() { | ||
let start = map_array.value_offsets()[entry] as usize; | ||
let end = map_array.value_offsets()[entry + 1] as usize; | ||
|
||
let maybe_matched = keys | ||
.slice(start, end - start) | ||
.iter() | ||
.enumerate() | ||
.find(|(_, t)| t.unwrap()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand why this panic is never called, but then I see you just moved the code |
||
|
||
if maybe_matched.is_none() { | ||
mutable.extend_nulls(1); | ||
continue; | ||
} | ||
let (match_offset, _) = maybe_matched.unwrap(); | ||
mutable.extend(0, start + match_offset, start + match_offset + 1); | ||
} | ||
|
||
let data = mutable.freeze(); | ||
let data = make_array(data); | ||
Ok(ColumnarValue::Array(data)) | ||
} | ||
|
||
match (array.data_type(), name) { | ||
(DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { | ||
let map_array = as_map_array(array.as_ref())?; | ||
let key_scalar: Scalar<arrow::array::GenericByteArray<arrow::datatypes::GenericStringType<i32>>> = Scalar::new(StringArray::from(vec![k.clone()])); | ||
let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; | ||
|
||
// note that this array has more entries than the expected output/input size | ||
// because map_array is flattened | ||
let original_data = map_array.entries().column(1).to_data(); | ||
let capacity = Capacities::Array(original_data.len()); | ||
let mut mutable = | ||
MutableArrayData::with_capacities(vec![&original_data], true, | ||
capacity); | ||
|
||
for entry in 0..map_array.len(){ | ||
let start = map_array.value_offsets()[entry] as usize; | ||
let end = map_array.value_offsets()[entry + 1] as usize; | ||
|
||
let maybe_matched = | ||
keys.slice(start, end-start). | ||
iter().enumerate(). | ||
find(|(_, t)| t.unwrap()); | ||
if maybe_matched.is_none() { | ||
mutable.extend_nulls(1); | ||
continue | ||
} | ||
let (match_offset,_) = maybe_matched.unwrap(); | ||
mutable.extend(0, start + match_offset, start + match_offset + 1); | ||
(DataType::Map(_, _), ScalarValue::List(arr)) => { | ||
let key_array: Arc<dyn Array> = Arc::new((**arr).clone()); | ||
process_map_array(array, key_array) | ||
} | ||
(DataType::Map(_, _), ScalarValue::Struct(arr)) => { | ||
process_map_array(array, Arc::clone(arr) as Arc<dyn Array>) | ||
} | ||
(DataType::Map(_, _), other) => { | ||
let data_type = other.data_type(); | ||
if data_type.is_nested() { | ||
exec_err!("unsupported type {:?} for map access", data_type) | ||
} else { | ||
process_map_array(array, other.to_array()?) | ||
} | ||
let data = mutable.freeze(); | ||
let data = make_array(data); | ||
Ok(ColumnarValue::Array(data)) | ||
} | ||
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { | ||
let as_struct_array = as_struct_array(&array)?; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -592,6 +592,43 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) | |
[NULL] [NULL] [[1, NULL, 3]] | ||
[NULL] [NULL] [NULL] | ||
|
||
query ? | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you also please add some error tests for out-of-bounds access? Also, it seems like the code in this PR supports queries like `select column1[column2]`, where column2 is an array of integers, as well. Can you add a test for that syntax too? |
||
select column1[1] from map_array_table_1; | ||
---- | ||
[1, NULL, 3] | ||
NULL | ||
NULL | ||
NULL | ||
|
||
query ? | ||
select column1[-1000 + 1001] from map_array_table_1; | ||
---- | ||
[1, NULL, 3] | ||
NULL | ||
NULL | ||
NULL | ||
|
||
# test for negative scenario | ||
query ? | ||
SELECT column1[-1] FROM map_array_table_1; | ||
---- | ||
NULL | ||
NULL | ||
NULL | ||
NULL | ||
|
||
query ? | ||
SELECT column1[1000] FROM map_array_table_1; | ||
---- | ||
NULL | ||
NULL | ||
NULL | ||
NULL | ||
|
||
|
||
query error DataFusion error: Arrow error: Invalid argument error | ||
SELECT column1[NULL] FROM map_array_table_1; | ||
|
||
query ??? | ||
select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_1; | ||
---- | ||
|
@@ -722,3 +759,28 @@ drop table map_array_table_1; | |
|
||
statement ok | ||
drop table map_array_table_2; | ||
|
||
|
||
statement ok | ||
create table tt as values(MAP{[1,2,3]:1}, MAP {{'a':1, 'b':2}:2}, MAP{true: 3}); | ||
|
||
# accessing using an array | ||
query I | ||
select column1[make_array(1, 2, 3)] from tt; | ||
---- | ||
1 | ||
|
||
# accessing using a struct | ||
query I | ||
select column2[{a:1, b: 2}] from tt; | ||
---- | ||
2 | ||
|
||
# accessing using Bool | ||
query I | ||
select column3[true] from tt; | ||
---- | ||
3 | ||
|
||
statement ok | ||
drop table tt; |
Uh oh!
There was an error while loading. Please reload this page.