From 30f37d0444b361e5f91251f069b87a34c3433432 Mon Sep 17 00:00:00 2001 From: Henry D Date: Tue, 3 Sep 2019 11:39:55 -0700 Subject: [PATCH 1/3] [SQL][JSPARK-28962] Provide index argument to filter lambda functions --- .../expressions/higherOrderFunctions.scala | 22 +++++++++++++++++-- .../HigherOrderFunctionsSuite.scala | 11 ++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index ed26bb375de2..e204347129ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -344,6 +344,8 @@ case class MapFilter( Examples: > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1); [1,3] + > SELECT _FUNC_(array(0, 2, 3), (x, i) -> x > i); + [2, 3] """, since = "2.4.0") case class ArrayFilter( @@ -357,10 +359,23 @@ case class ArrayFilter( override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { val ArrayType(elementType, containsNull) = argument.dataType - copy(function = f(function, (elementType, containsNull) :: Nil)) + function match { + case LambdaFunction(_, arguments, _) if arguments.size == 2 => + copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil)) + case _ => + copy(function = f(function, (elementType, containsNull) :: Nil)) + } } - @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function + @transient lazy val (elementVar, indexVar) = { + val LambdaFunction(_, (elementVar: NamedLambdaVariable) +: tail, _) = function + val indexVar = if (tail.nonEmpty) { + Some(tail.head.asInstanceOf[NamedLambdaVariable]) + } else { + None + } + (elementVar, indexVar) + } override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val arr = argumentValue.asInstanceOf[ArrayData] @@ -369,6 +384,9 @@ case class ArrayFilter( var i = 0 while (i < arr.numElements) { elementVar.value.set(arr.get(i, elementVar.dataType)) + if (indexVar.isDefined) { + indexVar.get.value.set(i) + } if (f.eval(inputRow).asInstanceOf[Boolean]) { buffer += elementVar.value.get } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index b83d03025d21..4cdee447fa45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -89,6 +89,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) } + def filter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayFilter(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) + } + def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { val MapType(kt, vt, vcn) = expr.dataType TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) @@ -218,9 +223,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isEven: Expression => Expression = x => x % 2 === 0 val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 + val indexIsEven: (Expression, Expression) => Expression = { case (_, idx) => idx % 2 === 0 } checkEvaluation(filter(ai0, isEven), Seq(2)) checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3)) + checkEvaluation(filter(ai0, indexIsEven), Seq(1, 3)) checkEvaluation(filter(ai1, isEven), Seq.empty) checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3)) checkEvaluation(filter(ain, isEven), null) @@ -234,13 +241,17 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val startsWithA: Expression => Expression = x => x.startsWith("a") checkEvaluation(filter(as0, startsWithA), Seq("a0", "a2")) + checkEvaluation(filter(as0, indexIsEven), Seq("a0", "a2")) checkEvaluation(filter(as1, startsWithA), Seq("a")) + checkEvaluation(filter(as1, indexIsEven), Seq("a", "c")) checkEvaluation(filter(asn, startsWithA), null) val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)), Seq(Seq(1, 3), null, Seq(5))) + checkEvaluation(transform(aai, ix => filter(ix, indexIsEven)), + Seq(Seq(1, 3), null, Seq(4))) } test("ArrayExists") { From 740293528edbae412a8accf1d2e8ddd79a49766d Mon Sep 17 00:00:00 2001 From: Henry D Date: Fri, 13 Sep 2019 12:46:58 -0700 Subject: [PATCH 2/3] fix test --- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7d044638db57..6392d292f43f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2199,9 +2199,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ).toDF("s", "i") val ex1 = intercept[AnalysisException] { - df.selectExpr("filter(s, (x, y) -> x + y)") + df.selectExpr("filter(s, (x, y, z) -> x + y)") } - assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match")) val ex2 = intercept[AnalysisException] { df.selectExpr("filter(i, x -> x)") From 93048670b427bd8c1b54f4be3459738c90806f0b Mon Sep 17 00:00:00 2001 From: Henry D Date: Sun, 29 Sep 2019 21:44:23 -0700 Subject: [PATCH 3/3] add test --- .../expressions/higherOrderFunctions.scala | 13 +++++----- .../spark/sql/DataFrameFunctionsSuite.scala | 24 +++++++++++++++++++ 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index e204347129ce..5781268148e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -345,9 +345,12 @@ case class MapFilter( > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1); [1,3] > SELECT _FUNC_(array(0, 2, 3), (x, i) -> x > i); - [2, 3] + [2,3] """, - since = "2.4.0") + since = "2.4.0", + note = """ + The inner function may use the index argument since 3.0.0. + """) case class ArrayFilter( argument: Expression, function: Expression) @@ -369,11 +372,7 @@ case class ArrayFilter( @transient lazy val (elementVar, indexVar) = { val LambdaFunction(_, (elementVar: NamedLambdaVariable) +: tail, _) = function - val indexVar = if (tail.nonEmpty) { - Some(tail.head.asInstanceOf[NamedLambdaVariable]) - } else { - None - } + val indexVar = tail.headOption.map(_.asInstanceOf[NamedLambdaVariable]) (elementVar, indexVar) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6392d292f43f..2e59ac273eff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2190,6 +2190,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { testNonPrimitiveType() } + test("filter function - index argument") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testIndexArgument(): Unit = { + checkAnswer(df.selectExpr("filter(s, (x, i) -> i % 2 == 0)"), + Seq( + Row(Seq("c", "b")), + Row(Seq("b", "c")), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testIndexArgument() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testIndexArgument() + } + test("filter function - invalid") { val df = Seq( (Seq("c", "a", "b"), 1),