diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index ed26bb375de2..5781268148e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -344,8 +344,13 @@ case class MapFilter(
     Examples:
       > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1);
        [1,3]
+      > SELECT _FUNC_(array(0, 2, 3), (x, i) -> x > i);
+       [2,3]
   """,
-  since = "2.4.0")
+  since = "2.4.0",
+  note = """
+    The inner function may use the index argument since 3.0.0.
+  """)
 case class ArrayFilter(
     argument: Expression,
     function: Expression)
@@ -357,10 +362,19 @@ case class ArrayFilter(
 
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
     val ArrayType(elementType, containsNull) = argument.dataType
-    copy(function = f(function, (elementType, containsNull) :: Nil))
+    function match {
+      case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
+        copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
+      case _ =>
+        copy(function = f(function, (elementType, containsNull) :: Nil))
+    }
   }
 
-  @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
+  @transient lazy val (elementVar, indexVar) = {
+    val LambdaFunction(_, (elementVar: NamedLambdaVariable) +: tail, _) = function
+    val indexVar = tail.headOption.map(_.asInstanceOf[NamedLambdaVariable])
+    (elementVar, indexVar)
+  }
 
   override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
     val arr = argumentValue.asInstanceOf[ArrayData]
@@ -369,6 +383,9 @@ case class ArrayFilter(
     var i = 0
     while (i < arr.numElements) {
       elementVar.value.set(arr.get(i, elementVar.dataType))
+      if (indexVar.isDefined) {
+        indexVar.get.value.set(i)
+      }
       if (f.eval(inputRow).asInstanceOf[Boolean]) {
         buffer += elementVar.value.get
       }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index b83d03025d21..4cdee447fa45 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -89,6 +89,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding)
   }
 
+  def filter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
+    val ArrayType(et, cn) = expr.dataType
+    ArrayFilter(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding)
+  }
+
   def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
     val MapType(kt, vt, vcn) = expr.dataType
     TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
@@ -218,9 +223,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
     val isEven: Expression => Expression = x => x % 2 === 0
     val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
+    val indexIsEven: (Expression, Expression) => Expression = { case (_, idx) => idx % 2 === 0 }
 
     checkEvaluation(filter(ai0, isEven), Seq(2))
     checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3))
+    checkEvaluation(filter(ai0, indexIsEven), Seq(1, 3))
     checkEvaluation(filter(ai1, isEven), Seq.empty)
     checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3))
     checkEvaluation(filter(ain, isEven), null)
@@ -234,13 +241,17 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     val startsWithA: Expression => Expression = x => x.startsWith("a")
 
     checkEvaluation(filter(as0, startsWithA), Seq("a0", "a2"))
+    checkEvaluation(filter(as0, indexIsEven), Seq("a0", "a2"))
     checkEvaluation(filter(as1, startsWithA), Seq("a"))
+    checkEvaluation(filter(as1, indexIsEven), Seq("a", "c"))
     checkEvaluation(filter(asn, startsWithA), null)
 
     val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)),
       ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
     checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)),
       Seq(Seq(1, 3), null, Seq(5)))
+    checkEvaluation(transform(aai, ix => filter(ix, indexIsEven)),
+      Seq(Seq(1, 3), null, Seq(4)))
   }
 
   test("ArrayExists") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 7d044638db57..2e59ac273eff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2190,6 +2190,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     testNonPrimitiveType()
   }
 
+  test("filter function - index argument") {
+    val df = Seq(
+      Seq("c", "a", "b"),
+      Seq("b", null, "c", null),
+      Seq.empty,
+      null
+    ).toDF("s")
+
+    def testIndexArgument(): Unit = {
+      checkAnswer(df.selectExpr("filter(s, (x, i) -> i % 2 == 0)"),
+        Seq(
+          Row(Seq("c", "b")),
+          Row(Seq("b", "c")),
+          Row(Seq.empty),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testIndexArgument()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testIndexArgument()
+  }
+
   test("filter function - invalid") {
     val df = Seq(
       (Seq("c", "a", "b"), 1),
@@ -2199,9 +2223,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     ).toDF("s", "i")
 
     val ex1 = intercept[AnalysisException] {
-      df.selectExpr("filter(s, (x, y) -> x + y)")
+      df.selectExpr("filter(s, (x, y, z) -> x + y)")
     }
-    assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))
+    assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match"))
 
     val ex2 = intercept[AnalysisException] {
      df.selectExpr("filter(i, x -> x)")
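
As an end-to-end sketch of the behavior this patch enables (not part of the patch itself: the object name, session setup, and sample data below are hypothetical, and it assumes a Spark build containing this change, i.e. 3.0.0 or later):

import org.apache.spark.sql.SparkSession

// Hypothetical standalone example exercising the new index argument of filter().
object FilterIndexSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("filter-index-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(Seq("c", "a", "b")).toDF("s")

    // One-argument lambda: pre-existing behavior, the predicate sees only the element.
    df.selectExpr("filter(s, x -> x != 'a') AS by_value").show()
    // by_value: [c, b]

    // Two-argument lambda (new in this patch): the second argument is the
    // 0-based index of the element, so this keeps elements at even positions,
    // matching the expected rows in the DataFrameFunctionsSuite test above.
    df.selectExpr("filter(s, (x, i) -> i % 2 == 0) AS by_index").show()
    // by_index: [c, b]

    spark.stop()
  }
}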