diff --git a/datafusion/spark/src/function/map/map_function.rs b/datafusion/spark/src/function/map/map_function.rs
new file mode 100644
index 000000000000..25e264df8c01
--- /dev/null
+++ b/datafusion/spark/src/function/map/map_function.rs
@@ -0,0 +1,197 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, MapArray, StructArray};
+use arrow::buffer::OffsetBuffer;
+use arrow::compute::interleave;
+use arrow::datatypes::{DataType, Field, Fields};
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+
+/// Views a flat `[k0, v0, k1, v1, ...]` argument list as separate key and
+/// value sequences.
+trait KeyValue<T>
+where
+    T: 'static,
+{
+    /// Iterates over the elements at even positions (the keys).
+    fn keys(&self) -> impl Iterator<Item = &T>;
+    /// Iterates over the elements at odd positions (the values).
+    fn values(&self) -> impl Iterator<Item = &T>;
+}
+
+impl<T> KeyValue<T> for &[T]
+where
+    T: 'static,
+{
+    fn keys(&self) -> impl Iterator<Item = &T> {
+        self.iter().step_by(2)
+    }
+
+    fn values(&self) -> impl Iterator<Item = &T> {
+        self.iter().skip(1).step_by(2)
+    }
+}
+
+/// Builds a [`MapArray`] from alternating key and value arrays.
+///
+/// Row `i` of the result holds one entry per key/value pair, taken from row
+/// `i` of each pair of input arrays, in argument order.
+///
+/// # Errors
+///
+/// Returns an execution error if the number of arguments is odd, if the
+/// arrays differ in length, if the key (or value) arrays do not all share one
+/// data type, or if the total number of entries does not fit in the `i32`
+/// offsets of a map array. Errors from the underlying Arrow kernels and
+/// array constructors are propagated.
+fn to_map_array(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if !args.len().is_multiple_of(2) {
+        return exec_err!("map requires an even number of arguments");
+    }
+    let num_entries = args.len() / 2;
+    let num_rows = args.first().map(|a| a.len()).unwrap_or(0);
+    if args.iter().any(|a| a.len() != num_rows) {
+        return exec_err!("map requires all arrays to have the same length");
+    }
+    let key_type = args
+        .first()
+        .map(|a| a.data_type())
+        .unwrap_or(&DataType::Null);
+    let value_type = args
+        .get(1)
+        .map(|a| a.data_type())
+        .unwrap_or(&DataType::Null);
+    let keys = args.keys().map(|a| a.as_ref()).collect::<Vec<_>>();
+    let values = args.values().map(|a| a.as_ref()).collect::<Vec<_>>();
+    if keys.iter().any(|a| a.data_type() != key_type) {
+        return exec_err!("map requires all key types to be the same");
+    }
+    if values.iter().any(|a| a.data_type() != value_type) {
+        return exec_err!("map requires all value types to be the same");
+    }
+    // Every row holds exactly `num_entries` entries, so row `i` starts at
+    // offset `i * num_entries`. Use checked arithmetic so that an oversized
+    // input is reported as an error instead of silently wrapping the `i32`
+    // offsets, and do it before any entries are materialized.
+    let offsets = (0..=num_rows)
+        .map(|i| {
+            i.checked_mul(num_entries)
+                .and_then(|offset| i32::try_from(offset).ok())
+        })
+        .collect::<Option<Vec<i32>>>();
+    let Some(offsets) = offsets else {
+        return exec_err!("map has too many entries to be stored in a map array");
+    };
+    let offsets = OffsetBuffer::new(offsets.into());
+    // TODO: avoid materializing the indices
+    // Each index is `(array, row)`: for every row, take that row from each
+    // key (or value) array in argument order.
+    let indices = (0..num_rows)
+        .flat_map(|i| (0..num_entries).map(move |j| (j, i)))
+        .collect::<Vec<_>>();
+    let keys = interleave(keys.as_slice(), indices.as_slice())?;
+    let values = interleave(values.as_slice(), indices.as_slice())?;
+    let fields = Fields::from(vec![
+        Field::new("key", key_type.clone(), false),
+        Field::new("value", value_type.clone(), true),
+    ]);
+    let entries = StructArray::try_new(fields.clone(), vec![keys, values], None)?;
+    let field = Arc::new(Field::new("entries", DataType::Struct(fields), false));
+    Ok(Arc::new(MapArray::try_new(
+        field, offsets, entries, None, false,
+    )?))
+}
+
+/// Scalar function `map` that builds a map from alternating key and value
+/// arguments.
+#[derive(Debug, Clone)]
+pub struct MapFunction {
+    signature: Signature,
+}
+
+impl Default for MapFunction {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl MapFunction {
+    /// Creates the function with a variadic, immutable signature.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for MapFunction {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "map"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Returns the map type built from the first key and value types.
+    ///
+    /// Fails if the argument count is odd or if the key (or value) types are
+    /// not all identical.
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if !arg_types.len().is_multiple_of(2) {
+            return exec_err!("map requires an even number of arguments");
+        }
+        let key_type = arg_types.first().unwrap_or(&DataType::Null);
+        let value_type = arg_types.get(1).unwrap_or(&DataType::Null);
+        // TODO: support type coercion
+        if arg_types.keys().any(|dt| dt != key_type) {
+            return exec_err!("map requires all key types to be the same");
+        }
+        if arg_types.values().any(|dt| dt != value_type) {
+            return exec_err!("map requires all value types to be the same");
+        }
+        Ok(DataType::Map(
+            Arc::new(Field::new(
+                "entries",
+                DataType::Struct(Fields::from(vec![
+                    // the key must not be nullable
+                    Field::new("key", key_type.clone(), false),
+                    Field::new("value", value_type.clone(), true),
+                ])),
+                false, // the entry is not nullable
+            )),
+            false, // the keys are not sorted
+        ))
+    }
+
+    /// Materializes all arguments as arrays and assembles the map array.
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        let arrays = ColumnarValue::values_to_arrays(&args)?;
+        Ok(ColumnarValue::Array(to_map_array(arrays.as_slice())?))
+    }
+}
diff --git a/datafusion/spark/src/function/map/mod.rs b/datafusion/spark/src/function/map/mod.rs
index a87df9a2c87a..e2f1ecea386f 100644
--- a/datafusion/spark/src/function/map/mod.rs
+++ b/datafusion/spark/src/function/map/mod.rs
@@ -15,11 +15,20 @@
 // specific language governing permissions and limitations
 // under the License.
 
+pub mod map_function;
+
 use datafusion_expr::ScalarUDF;
+use datafusion_functions::make_udf_function;
 use std::sync::Arc;
 
-pub mod expr_fn {}
+make_udf_function!(map_function::MapFunction, map);
+
+pub mod expr_fn {
+    use datafusion_functions::export_functions;
+
+    export_functions!((map, "Creates a map with the given key/value pairs.", args));
+}
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![]
+    vec![map()]
 }
diff --git a/datafusion/sqllogictest/test_files/spark/map/map.slt b/datafusion/sqllogictest/test_files/spark/map/map.slt
new file mode 100644
index 000000000000..c071df5e3c5c
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/spark/map/map.slt
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+query ?
+SELECT map(1.0, '2', 3.0, '4');
+----
+{1.0: 2, 3.0: 4}
+
+query ?
+SELECT map('a', 1, 'b', 2);
+----
+{a: 1, b: 2}
+
+query ?
+SELECT map(TRUE, 'yes', FALSE, 'no');
+----
+{true: yes, false: no}
+
+
+statement error DataFusion error: Execution error: map requires an even number of arguments
+SELECT map(1, 'a', 2);
+
+statement error DataFusion error: Execution error: map requires an even number of arguments
+SELECT map('key_only');
+
+
+statement error DataFusion error: Execution error: map requires all value types to be the same
+SELECT map('inner', map('a', 1), 'b', 2);
+
+
+query ?
+SELECT map('', 'empty', 'non-empty', 'val');
+----
+{: empty, non-empty: val}