diff --git a/native/spark-expr/src/array_funcs/array_position.rs b/native/spark-expr/src/array_funcs/array_position.rs
new file mode 100644
index 0000000000..868e3307b6
--- /dev/null
+++ b/native/spark-expr/src/array_funcs/array_position.rs
@@ -0,0 +1,148 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait};
+use arrow::datatypes::DataType;
+use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Spark array_position() function that returns the 1-based position of an element in an array.
+/// Returns 0 if the element is not found (Spark behavior differs from DataFusion, which returns null).
+pub fn spark_array_position(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    if args.len() != 2 {
+        return exec_err!("array_position function takes exactly two arguments");
+    }
+
+    // Convert all arguments to arrays for consistent processing
+    let len = args
+        .iter()
+        .fold(Option::<usize>::None, |acc, arg| match arg {
+            ColumnarValue::Scalar(_) => acc,
+            ColumnarValue::Array(a) => Some(a.len()),
+        });
+
+    let is_scalar = len.is_none();
+    let arrays = ColumnarValue::values_to_arrays(args)?;
+
+    let result = array_position_inner(&arrays)?;
+
+    if is_scalar {
+        let scalar = ScalarValue::try_from_array(&result, 0)?;
+        Ok(ColumnarValue::Scalar(scalar))
+    } else {
+        Ok(ColumnarValue::Array(result))
+    }
+}
+
+fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
+    let array = &args[0];
+    let element = &args[1];
+
+    match array.data_type() {
+        DataType::List(_) => generic_array_position::<i32>(array, element),
+        DataType::LargeList(_) => generic_array_position::<i64>(array, element),
+        other => exec_err!("array_position does not support type '{other:?}'"),
+    }
+}
+
+fn generic_array_position<O: OffsetSizeTrait>(
+    array: &ArrayRef,
+    element: &ArrayRef,
+) -> Result<ArrayRef, DataFusionError> {
+    let list_array = array
+        .as_any()
+        .downcast_ref::<GenericListArray<O>>()
+        .unwrap();
+
+    let mut data = Vec::with_capacity(list_array.len());
+
+    for row_index in 0..list_array.len() {
+        if list_array.is_null(row_index) {
+            // Null array returns null position (same as Spark)
+            data.push(None);
+        } else if element.is_null(row_index) {
+            // Searching for a null element returns null in Spark
+            data.push(None);
+        } else {
+            let list_array_row = list_array.value(row_index);
+
+            // Get the search element as a scalar
+            let element_scalar = ScalarValue::try_from_array(element, row_index)?;
+
+            // Compare the element to each item in the list
+            let mut position: i64 = 0;
+            for i in 0..list_array_row.len() {
+                let list_item_scalar = ScalarValue::try_from_array(&list_array_row, i)?;
+
+                // null != anything in Spark array_position
+                if !list_item_scalar.is_null() && element_scalar == list_item_scalar {
+                    position = (i + 1) as i64; // 1-indexed
+                    break;
+                }
+            }
+
+            data.push(Some(position));
+        }
+    }
+
+    Ok(Arc::new(Int64Array::from(data)))
+}
+
+#[derive(Debug, Hash, Eq, PartialEq)]
+pub struct SparkArrayPositionFunc {
+    signature: Signature,
+}
+
+impl Default for SparkArrayPositionFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkArrayPositionFunc {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkArrayPositionFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "spark_array_position"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> {
+        Ok(DataType::Int64)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
+        spark_array_position(&args.args)
+    }
+}
diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs
index 3ef50a252f..407cd4661b 100644
--- a/native/spark-expr/src/array_funcs/mod.rs
+++ b/native/spark-expr/src/array_funcs/mod.rs
@@ -16,11 +16,13 @@
 // under the License.
 
 mod array_insert;
+mod array_position;
 mod get_array_struct_fields;
 mod list_extract;
 mod size;
 
 pub use array_insert::ArrayInsert;
+pub use array_position::{spark_array_position, SparkArrayPositionFunc};
 pub use get_array_struct_fields::GetArrayStructFields;
 pub use list_extract::ListExtract;
 pub use size::{spark_size, SparkSizeFunc};
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 4bfdef7096..be6996cd77 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
     spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
-    spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff,
-    SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace,
+    spark_unscaled_value, EvalMode, SparkArrayPositionFunc, SparkBitwiseCount, SparkContains,
+    SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -191,6 +191,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
 
 fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
     vec![
+        Arc::new(ScalarUDF::new_from_impl(SparkArrayPositionFunc::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkContains::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())),
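A quick illustration of the semantics the new kernel implements (not part of the diff): positions are 1-based, a missing element yields 0 rather than NULL, and a NULL array row or NULL search element yields NULL. The crate name `datafusion_comet_spark_expr` and the import path are assumptions; only `spark_array_position` itself comes from the change above.

use std::sync::Arc;

use arrow::array::{Array, Int32Array, ListArray};
use arrow::datatypes::Int32Type;
use datafusion::logical_expr::ColumnarValue;
use datafusion_comet_spark_expr::spark_array_position; // assumed import path

fn main() -> datafusion::common::Result<()> {
    // Three list rows: [1, 2, 3], [4, 5], NULL
    let arrays = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
        Some(vec![Some(1), Some(2), Some(3)]),
        Some(vec![Some(4), Some(5)]),
        None,
    ]);
    // One search element per row: 2, 9, 1
    let elements = Int32Array::from(vec![Some(2), Some(9), Some(1)]);

    let result = spark_array_position(&[
        ColumnarValue::Array(Arc::new(arrays)),
        ColumnarValue::Array(Arc::new(elements)),
    ])?;

    // Expected Int64 column: [2, 0, NULL]
    if let ColumnarValue::Array(positions) = result {
        assert_eq!(positions.len(), 3);
        println!("{positions:?}");
    }
    Ok(())
}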
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index f627b0c465..742687a18d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -56,6 +56,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[ArrayJoin] -> CometArrayJoin,
     classOf[ArrayMax] -> CometArrayMax,
     classOf[ArrayMin] -> CometArrayMin,
+    classOf[ArrayPosition] -> CometArrayPosition,
     classOf[ArrayRemove] -> CometArrayRemove,
     classOf[ArrayRepeat] -> CometArrayRepeat,
     classOf[ArraysOverlap] -> CometArraysOverlap,
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index b7ebb9ba7b..e3dc7ba56e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import scala.annotation.tailrec
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
+import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -657,6 +657,36 @@ object CometSize extends CometExpressionSerde[Size] {
 
 }
 
+object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with ArraysBase {
+
+  override def convert(
+      expr: ArrayPosition,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (expr.children.forall(_.foldable)) {
+      withInfo(expr, "all arguments are literals, falling back to Spark")
+      return None
+    }
+    // Check if input types are supported
+    val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet
+    for (dt <- inputTypes) {
+      if (!isTypeSupported(dt)) {
+        withInfo(expr, s"data type not supported: $dt")
+        return None
+      }
+    }
+
+    val arrayExprProto = exprToProto(expr.left, inputs, binding)
+    val elementExprProto = exprToProto(expr.right, inputs, binding)
+
+    // Use spark_array_position which returns Int64 and 0 when not found
+    // (matching Spark's behavior)
+    val optExpr =
+      scalarFunctionExprToProto("spark_array_position", arrayExprProto, elementExprProto)
+    optExprWithInfo(optExpr, expr, expr.left, expr.right)
+  }
+}
+
 trait ArraysBase {
 
   def isTypeSupported(dt: DataType): Boolean = {
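The serde above falls back to Spark only when every argument is foldable, so a literal search element paired with an array column still runs natively. In that case `ColumnarValue::values_to_arrays` inside `spark_array_position` broadcasts the scalar to the column length; only when both inputs are scalars does the `is_scalar` branch return a scalar. A minimal sketch, under the same assumed crate path as above:

use std::sync::Arc;

use arrow::array::{Array, ListArray};
use arrow::datatypes::Int32Type;
use datafusion::common::ScalarValue;
use datafusion::logical_expr::ColumnarValue;
use datafusion_comet_spark_expr::spark_array_position; // assumed import path

fn main() -> datafusion::common::Result<()> {
    // Two list rows: [1, 2, 3] and [7, 8]
    let arrays = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
        Some(vec![Some(1), Some(2), Some(3)]),
        Some(vec![Some(7), Some(8)]),
    ]);

    // The literal element 2 is broadcast to both rows.
    let result = spark_array_position(&[
        ColumnarValue::Array(Arc::new(arrays)),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
    ])?;

    // Because one input was a column, the result is a column: [2, 0].
    if let ColumnarValue::Array(positions) = result {
        assert_eq!(positions.len(), 2);
    }
    Ok(())
}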
diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql
new file mode 100644
index 0000000000..13873d6f4c
--- /dev/null
+++ b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql
@@ -0,0 +1,206 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+--
+--   http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing,
+-- software distributed under the License is distributed on an
+-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+-- KIND, either express or implied. See the License for the
+-- specific language governing permissions and limitations
+-- under the License.
+
+-- ConfigMatrix: parquet.enable.dictionary=false,true
+
+statement
+CREATE TABLE test_array_position(int_arr array<int>, str_arr array<string>, val int, str_val string) USING parquet
+
+statement
+INSERT INTO test_array_position VALUES
+  (array(1, 2, 3, 4), array('a', 'b', 'c'), 2, 'b'),
+  (array(1, 2, NULL, 3), array('a', NULL, 'c'), 3, 'c'),
+  (array(10, 20, 30), array('x', 'y', 'z'), 99, 'w'),
+  (array(), array(), 1, 'a'),
+  (NULL, NULL, 1, 'a'),
+  (array(1, 1, 1), array('a', 'a', 'a'), 1, 'a'),
+  (array(5, 6, 7), array('p', 'q', 'r'), NULL, NULL)
+
+-- literal args fall back to Spark
+query spark_answer_only
+SELECT array_position(array(1, 2, 3, 4), 3)
+
+query spark_answer_only
+SELECT array_position(array(1, 2, 3, 4), 5)
+
+query spark_answer_only
+SELECT array_position(array('a', 'b', 'c'), 'b')
+
+query spark_answer_only
+SELECT array_position(array(1, 2, NULL, 3), 3)
+
+query spark_answer_only
+SELECT array_position(array(1, 2, 3), cast(NULL as int))
+
+query spark_answer_only
+SELECT array_position(cast(NULL as array<int>), 1)
+
+query spark_answer_only
+SELECT array_position(array(), 1)
+
+query spark_answer_only
+SELECT array_position(array(1, 2, 1, 3), 1)
+
+-- column array + column value (includes NULL val row)
+query
+SELECT array_position(int_arr, val) FROM test_array_position
+
+-- column array + literal value
+query
+SELECT array_position(int_arr, 3) FROM test_array_position
+
+-- literal array + column value
+query
+SELECT array_position(array(1, 2, 3), val) FROM test_array_position
+
+-- string column array + column value (includes NULL str_val row)
+query
+SELECT array_position(str_arr, str_val) FROM test_array_position
+
+-- string column array + literal value
+query
+SELECT array_position(str_arr, 'c') FROM test_array_position
+
+-- expressions in array construction
+query
+SELECT array_position(array(val, val + 1, val + 2), val) FROM test_array_position
+
+-- boolean arrays
+statement
+CREATE TABLE test_ap_bool(arr array<boolean>, val boolean) USING parquet
+
+statement
+INSERT INTO test_ap_bool VALUES
+  (array(true, false, true), false),
+  (array(true, true), false),
+  (array(false, false), true),
+  (NULL, true),
+  (array(true, false), NULL)
+
+query
+SELECT array_position(arr, val) FROM test_ap_bool
+
+-- tinyint arrays
+statement
+CREATE TABLE test_ap_byte(arr array<tinyint>, val tinyint) USING parquet
+
+statement
+INSERT INTO test_ap_byte VALUES
+  (array(cast(1 as tinyint), cast(2 as tinyint), cast(3 as tinyint)), cast(2 as tinyint)),
+  (array(cast(-128 as tinyint), cast(127 as tinyint)), cast(127 as tinyint)),
+  (NULL, cast(1 as tinyint)),
+  (array(cast(1 as tinyint)), NULL)
+
+query
+SELECT array_position(arr, val) FROM test_ap_byte
+
+-- smallint arrays
+statement
+CREATE TABLE test_ap_short(arr array<smallint>, val smallint) USING parquet
+
+statement
+INSERT INTO test_ap_short VALUES
+  (array(cast(100 as smallint), cast(200 as smallint), cast(300 as smallint)), cast(200 as smallint)),
+  (NULL, cast(1 as smallint)),
+  (array(cast(1 as smallint)), NULL)
+
+query
+SELECT array_position(arr, val) FROM test_ap_short
+
+-- bigint arrays
+statement
+CREATE TABLE test_ap_long(arr array<bigint>, val bigint) USING parquet
+
+statement
+INSERT INTO test_ap_long VALUES
+  (array(cast(1000000000000 as bigint), cast(2000000000000 as bigint)), cast(2000000000000 as bigint)),
+  (array(cast(-1 as bigint), cast(0 as bigint), cast(1 as bigint)), cast(0 as bigint)),
+  (NULL, cast(1 as bigint)),
+  (array(cast(1 as bigint)), NULL)
+
+query
+SELECT array_position(arr, val) FROM test_ap_long
+
+-- float arrays
+statement
+CREATE TABLE test_ap_float(arr array<float>, val float) USING parquet
+
+statement
+INSERT INTO test_ap_float VALUES
+  (array(cast(1.1 as float), cast(2.2 as float), cast(3.3 as float)), cast(2.2 as float)),
+  (array(cast(0.0 as float), cast(-1.5 as float)), cast(-1.5 as float)),
+  (NULL, cast(1.0 as float)),
+  (array(cast(1.0 as float)), NULL)
+
+query
+SELECT array_position(arr, val) FROM test_ap_float
+
+-- double arrays
+statement
+CREATE TABLE test_ap_double(arr array<double>, val double) USING parquet
+
+statement
+INSERT INTO test_ap_double VALUES
+  (array(1.1, 2.2, 3.3), 2.2),
+  (array(0.0, -1.5), -1.5),
+  (NULL, 1.0),
+  (array(1.0), NULL)
+
+query
+SELECT array_position(arr, val) FROM test_ap_double
+
+-- decimal arrays
+statement
+CREATE TABLE test_ap_decimal(arr array<decimal(10,2)>, val decimal(10,2)) USING parquet
+
+statement
+INSERT INTO test_ap_decimal VALUES
+  (array(cast(1.10 as decimal(10,2)), cast(2.20 as decimal(10,2)), cast(3.30 as decimal(10,2))), cast(2.20 as decimal(10,2))),
+  (array(cast(0.00 as decimal(10,2))), cast(0.00 as decimal(10,2))),
+  (NULL, cast(1.00 as decimal(10,2))),
+  (array(cast(1.00 as decimal(10,2))), NULL)
+
+query
+SELECT array_position(arr, val) FROM test_ap_decimal
+
+-- date arrays
+statement
+CREATE TABLE test_ap_date(arr array<date>, val date) USING parquet
+
+statement
+INSERT INTO test_ap_date VALUES
+  (array(date '2024-01-01', date '2024-06-15', date '2024-12-31'), date '2024-06-15'),
+  (array(date '2000-01-01'), date '1999-12-31'),
+  (NULL, date '2024-01-01'),
+  (array(date '2024-01-01'), NULL)
+
+query
+SELECT array_position(arr, val) FROM test_ap_date
+
+-- timestamp arrays
+statement
+CREATE TABLE test_ap_ts(arr array<timestamp>, val timestamp) USING parquet
+
+statement
+INSERT INTO test_ap_ts VALUES
+  (array(timestamp '2024-01-01 00:00:00', timestamp '2024-06-15 12:30:00'), timestamp '2024-06-15 12:30:00'),
+  (array(timestamp '2000-01-01 00:00:00'), timestamp '1999-12-31 23:59:59'),
+  (NULL, timestamp '2024-01-01 00:00:00'),
+  (array(timestamp '2024-01-01 00:00:00'), NULL)
+
+query
+SELECT array_position(arr, val) FROM test_ap_ts