148 changes: 148 additions & 0 deletions native/spark-expr/src/array_funcs/array_position.rs
@@ -0,0 +1,148 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use std::any::Any;
use std::sync::Arc;

/// Spark `array_position()`: returns the 1-based position of an element in an array.
/// Returns 0 if the element is not found (Spark behavior; DataFusion's `array_position` returns null instead).
pub fn spark_array_position(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    if args.len() != 2 {
        return exec_err!("array_position function takes exactly two arguments");
    }

    // Determine the output length: if no argument is an array, all inputs are scalars
    // and the result is a scalar as well.
    let len = args
        .iter()
        .fold(Option::<usize>::None, |acc, arg| match arg {
            ColumnarValue::Scalar(_) => acc,
            ColumnarValue::Array(a) => Some(a.len()),
        });

    let is_scalar = len.is_none();

    // Convert all arguments to arrays for consistent processing.
    let arrays = ColumnarValue::values_to_arrays(args)?;

    let result = array_position_inner(&arrays)?;

    if is_scalar {
        let scalar = ScalarValue::try_from_array(&result, 0)?;
        Ok(ColumnarValue::Scalar(scalar))
    } else {
        Ok(ColumnarValue::Array(result))
    }
}

fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
    let array = &args[0];
    let element = &args[1];

    match array.data_type() {
        DataType::List(_) => generic_array_position::<i32>(array, element),
        DataType::LargeList(_) => generic_array_position::<i64>(array, element),
        other => exec_err!("array_position does not support type '{other:?}'"),
    }
}

fn generic_array_position<O: OffsetSizeTrait>(
    array: &ArrayRef,
    element: &ArrayRef,
) -> Result<ArrayRef, DataFusionError> {
    let list_array = array
        .as_any()
        .downcast_ref::<GenericListArray<O>>()
        .unwrap();

    let mut data = Vec::with_capacity(list_array.len());

    for row_index in 0..list_array.len() {
        if list_array.is_null(row_index) {
            // A null array yields a null position (matching Spark).
            data.push(None);
        } else if element.is_null(row_index) {
            // Searching for a null element yields null in Spark.
            data.push(None);
        } else {
            let list_array_row = list_array.value(row_index);

            // Get the search element as a scalar.
            let element_scalar = ScalarValue::try_from_array(element, row_index)?;

            // Compare the element to each item in the list.
            let mut position: i64 = 0;
            for i in 0..list_array_row.len() {
                let list_item_scalar = ScalarValue::try_from_array(&list_array_row, i)?;

                // null never matches anything in Spark's array_position.
                if !list_item_scalar.is_null() && element_scalar == list_item_scalar {
                    position = (i + 1) as i64; // 1-based index
                    break;
                }
            }

            data.push(Some(position));
        }
    }

    Ok(Arc::new(Int64Array::from(data)))
}

#[derive(Debug, Hash, Eq, PartialEq)]
pub struct SparkArrayPositionFunc {
    signature: Signature,
}

impl Default for SparkArrayPositionFunc {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkArrayPositionFunc {
    pub fn new() -> Self {
        Self {
            signature: Signature::variadic_any(Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for SparkArrayPositionFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "spark_array_position"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> {
        Ok(DataType::Int64)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
        spark_array_position(&args.args)
    }
}
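
A minimal test sketch illustrating the intended semantics (illustrative only, not part of this diff; the test module, name, and row values are arbitrary):

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, ListArray};
    use arrow::datatypes::Int32Type;

    #[test]
    fn spark_array_position_semantics() -> Result<(), DataFusionError> {
        // Rows: [1, 2, 3], [4, 5, 6], NULL
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5), Some(6)]),
            None,
        ]);
        // Search element per row: 2 (present), 7 (absent), 1 (array is NULL)
        let elements = Int32Array::from(vec![Some(2), Some(7), Some(1)]);

        let args = [
            ColumnarValue::Array(Arc::new(list)),
            ColumnarValue::Array(Arc::new(elements)),
        ];
        let ColumnarValue::Array(result) = spark_array_position(&args)? else {
            panic!("expected an array result");
        };
        let result = result.as_any().downcast_ref::<Int64Array>().unwrap();

        assert_eq!(result.value(0), 2); // 1-based position
        assert_eq!(result.value(1), 0); // not found -> 0 (Spark), not NULL (DataFusion)
        assert!(result.is_null(2)); // NULL array -> NULL
        Ok(())
    }
}
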
2 changes: 2 additions & 0 deletions native/spark-expr/src/array_funcs/mod.rs
@@ -16,11 +16,13 @@
// under the License.

mod array_insert;
mod array_position;
mod get_array_struct_fields;
mod list_extract;
mod size;

pub use array_insert::ArrayInsert;
pub use array_position::{spark_array_position, SparkArrayPositionFunc};
pub use get_array_struct_fields::GetArrayStructFields;
pub use list_extract::ListExtract;
pub use size::{spark_size, SparkSizeFunc};
5 changes: 3 additions & 2 deletions native/spark-expr/src/comet_scalar_funcs.rs
@@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
    spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
    spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
    spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff,
    SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace,
    spark_unscaled_value, EvalMode, SparkArrayPositionFunc, SparkBitwiseCount, SparkContains,
    SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -191,6 +191,7 @@ pub fn create_comet_physical_fun_with_eval_mode(

fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
    vec![
        Arc::new(ScalarUDF::new_from_impl(SparkArrayPositionFunc::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkContains::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())),
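
Aside (not part of this diff): registering the UDF here is what makes the name spark_array_position, emitted by the Scala serde below, resolvable at native execution time. A hedged standalone sketch of the same registration against a plain DataFusion SessionContext — the crate path datafusion_comet_spark_expr and the tokio runtime are assumptions:

use datafusion::error::Result;
use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;
use datafusion_comet_spark_expr::SparkArrayPositionFunc; // assumed crate name

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Registration makes the UDF resolvable by its name(), "spark_array_position".
    ctx.register_udf(ScalarUDF::new_from_impl(SparkArrayPositionFunc::default()));

    // make_array is a DataFusion built-in; positions are 1-based, 0 when the element is absent.
    let df = ctx
        .sql("SELECT spark_array_position(make_array(3, 2, 1), 1) AS pos")
        .await?;
    df.show().await?; // pos = 3
    Ok(())
}
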
@@ -56,6 +56,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
    classOf[ArrayJoin] -> CometArrayJoin,
    classOf[ArrayMax] -> CometArrayMax,
    classOf[ArrayMin] -> CometArrayMin,
    classOf[ArrayPosition] -> CometArrayPosition,
    classOf[ArrayRemove] -> CometArrayRemove,
    classOf[ArrayRepeat] -> CometArrayRepeat,
    classOf[ArraysOverlap] -> CometArraysOverlap,
32 changes: 31 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde

import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -657,6 +657,36 @@ object CometSize extends CometExpressionSerde[Size] {

}

object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with ArraysBase {

  override def convert(
      expr: ArrayPosition,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    if (expr.children.forall(_.foldable)) {
      withInfo(expr, "all arguments are literals, falling back to Spark")
      return None
    }
    // Check that the input types are supported
    val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet
    for (dt <- inputTypes) {
      if (!isTypeSupported(dt)) {
        withInfo(expr, s"data type not supported: $dt")
        return None
      }
    }

    val arrayExprProto = exprToProto(expr.left, inputs, binding)
    val elementExprProto = exprToProto(expr.right, inputs, binding)

    // Use spark_array_position, which returns Int64 and 0 when the element is not found
    // (matching Spark's behavior).
    val optExpr =
      scalarFunctionExprToProto("spark_array_position", arrayExprProto, elementExprProto)
    optExprWithInfo(optExpr, expr, expr.left, expr.right)
  }
}

trait ArraysBase {

  def isTypeSupported(dt: DataType): Boolean = {
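
For reference, a hedged end-to-end sketch of the behavior the serde above targets, using only public Spark APIs (illustrative only, not part of this diff; the object name and values are arbitrary):

import org.apache.spark.sql.SparkSession

object ArrayPositionSemantics {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("array_position example").master("local[1]").getOrCreate()

    // 1-based position; 0 when the element is absent; NULL when either argument is NULL.
    val df = spark.sql(
      "SELECT array_position(array(3, 2, 1), 1) AS found, " +
        "array_position(array(3, 2, 1), 7) AS missing")
    df.show()
    // found = 3, missing = 0

    spark.stop()
  }
}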