From 26bf37960a1534da8e2119181dce5794e3b48172 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Fri, 3 Aug 2018 15:25:08 +0900
Subject: [PATCH] Add `ArrayAggregate`.

---
 .../catalyst/analysis/FunctionRegistry.scala  |   1 +
 .../expressions/higherOrderFunctions.scala    | 118 +++++++++++++++--
 .../HigherOrderFunctionsSuite.scala           |  50 ++++++++
 .../inputs/higher-order-functions.sql         |  12 ++
 .../results/higher-order-functions.sql.out    |  40 +++++-
 .../spark/sql/DataFrameFunctionsSuite.scala   | 121 ++++++++++++++++++
 6 files changed, 333 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f7517486e541..ac99fd7a304e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -441,6 +441,7 @@ object FunctionRegistry {
     expression[ArrayRemove]("array_remove"),
     expression[ArrayDistinct]("array_distinct"),
     expression[ArrayTransform]("transform"),
+    expression[ArrayAggregate]("aggregate"),
     CreateStruct.registryEntry,
 
     // misc functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index c5c3482afa13..d44c12d6c715 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.util.concurrent.atomic.AtomicReference
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
@@ -74,6 +75,13 @@ case class LambdaFunction(
   override def eval(input: InternalRow): Any = function.eval(input)
 }
 
+object LambdaFunction {
+  val identity: LambdaFunction = {
+    val id = UnresolvedAttribute.quoted("id")
+    LambdaFunction(id, Seq(id))
+  }
+}
+
 /**
  * A higher order function takes one or more (lambda) functions and applies these to some objects.
  * The function produces a number of variables which can be consumed by some lambda function.
@@ -140,6 +148,18 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu
   @transient lazy val functionForEval: Expression = functionsForEval.head
 }
 
+object ArrayBasedHigherOrderFunction {
+
+  def elementArgumentType(dt: DataType): (DataType, Boolean) = {
+    dt match {
+      case ArrayType(elementType, containsNull) => (elementType, containsNull)
+      case _ =>
+        val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
+        (elementType, containsNull)
+    }
+  }
+}
+
 /**
  * Transform elements in an array using the transform function. This is similar to
  * a `map` in functional programming.
@@ -164,17 +184,12 @@ case class ArrayTransform(
   override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)
 
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
-    val (elementType, containsNull) = input.dataType match {
-      case ArrayType(elementType, containsNull) => (elementType, containsNull)
-      case _ =>
-        val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
-        (elementType, containsNull)
-    }
+    val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
     function match {
       case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
-        copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
+        copy(function = f(function, elem :: (IntegerType, false) :: Nil))
       case _ =>
-        copy(function = f(function, (elementType, containsNull) :: Nil))
+        copy(function = f(function, elem :: Nil))
     }
   }
 
@@ -210,3 +225,90 @@ case class ArrayTransform(
 
   override def prettyName: String = "transform"
 }
+
+/**
+ * Applies a binary operator to a start value and all elements in the array.
+ */
+@ExpressionDescription(
+  usage =
+    """
+      _FUNC_(expr, start, merge, finish) - Applies a binary operator to an initial state and all
+      elements in the array, and reduces this to a single state. The final state is converted
+      into the final result by applying a finish function.
+    """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x);
+       6
+      > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10);
+       60
+  """,
+  since = "2.4.0")
+case class ArrayAggregate(
+    input: Expression,
+    zero: Expression,
+    merge: Expression,
+    finish: Expression)
+  extends HigherOrderFunction with CodegenFallback {
+
+  def this(input: Expression, zero: Expression, merge: Expression) = {
+    this(input, zero, merge, LambdaFunction.identity)
+  }
+
+  override def inputs: Seq[Expression] = input :: zero :: Nil
+
+  override def functions: Seq[Expression] = merge :: finish :: Nil
+
+  override def nullable: Boolean = input.nullable || finish.nullable
+
+  override def dataType: DataType = finish.dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!ArrayType.acceptsType(input.dataType)) {
+      TypeCheckResult.TypeCheckFailure(
+        s"argument 1 requires ${ArrayType.simpleString} type, " +
+          s"however, '${input.sql}' is of ${input.dataType.catalogString} type.")
+    } else if (!DataType.equalsStructurally(
+        zero.dataType, merge.dataType, ignoreNullability = true)) {
+      TypeCheckResult.TypeCheckFailure(
+        s"argument 3 requires ${zero.dataType.simpleString} type, " +
+          s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = {
+    // Be very conservative with nullable. We cannot be sure that the accumulator does not
+    // evaluate to null. So we always set nullable to true here.
+    val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
+    val acc = zero.dataType -> true
+    val newMerge = f(merge, acc :: elem :: Nil)
+    val newFinish = f(finish, acc :: Nil)
+    copy(merge = newMerge, finish = newFinish)
+  }
+
+  @transient lazy val LambdaFunction(_,
+    Seq(accForMergeVar: NamedLambdaVariable, elementVar: NamedLambdaVariable), _) = merge
+  @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish
+
+  override def eval(input: InternalRow): Any = {
+    val arr = this.input.eval(input).asInstanceOf[ArrayData]
+    if (arr == null) {
+      null
+    } else {
+      val Seq(mergeForEval, finishForEval) = functionsForEval
+      accForMergeVar.value.set(zero.eval(input))
+      var i = 0
+      while (i < arr.numElements()) {
+        elementVar.value.set(arr.get(i, elementVar.dataType))
+        accForMergeVar.value.set(mergeForEval.eval(input))
+        i += 1
+      }
+      accForFinishVar.value.set(accForMergeVar.value.get)
+      finishForEval.eval(input)
+    }
+  }
+
+  override def prettyName: String = "aggregate"
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index e987ea5b8a4d..cd21892b05e1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -54,6 +54,27 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     ArrayTransform(expr, createLambda(at.elementType, at.containsNull, IntegerType, false, f))
   }
 
+  def aggregate(
+      expr: Expression,
+      zero: Expression,
+      merge: (Expression, Expression) => Expression,
+      finish: Expression => Expression): Expression = {
+    val at = expr.dataType.asInstanceOf[ArrayType]
+    val zeroType = zero.dataType
+    ArrayAggregate(
+      expr,
+      zero,
+      createLambda(zeroType, true, at.elementType, at.containsNull, merge),
+      createLambda(zeroType, true, finish))
+  }
+
+  def aggregate(
+      expr: Expression,
+      zero: Expression,
+      merge: (Expression, Expression) => Expression): Expression = {
+    aggregate(expr, zero, merge, identity)
+  }
+
   test("ArrayTransform") {
     val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
     val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
@@ -94,4 +115,33 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)),
       Seq("[1, 3, 5]", null, "[4, 6]"))
   }
+
+  test("ArrayAggregate") {
+    val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
+    val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
+    val ai2 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
+    val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false))
+
+    checkEvaluation(aggregate(ai0, 0, (acc, elem) => acc + elem, acc => acc * 10), 60)
+    checkEvaluation(aggregate(ai1, 0, (acc, elem) => acc + coalesce(elem, 0), acc => acc * 10), 40)
+    checkEvaluation(aggregate(ai2, 0, (acc, elem) => acc + elem, acc => acc * 10), 0)
+    checkEvaluation(aggregate(ain, 0, (acc, elem) => acc + elem, acc => acc * 10), null)
+
+    val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false))
+    val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true))
+    val as2 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false))
+    val asn = Literal.create(null, ArrayType(StringType, containsNull = false))
+
+    checkEvaluation(aggregate(as0, "", (acc, elem) => Concat(Seq(acc, elem))), "abc")
+    checkEvaluation(aggregate(as1, "", (acc, elem) => Concat(Seq(acc, coalesce(elem, "x")))), "axc")
+    checkEvaluation(aggregate(as2, "", (acc, elem) => Concat(Seq(acc, elem))), "")
+    checkEvaluation(aggregate(asn, "", (acc, elem) => Concat(Seq(acc, elem))), null)
+
+    val aai = Literal.create(Seq[Seq[Integer]](Seq(1, 2, 3), null, Seq(4, 5)),
+      ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
+    checkEvaluation(
+      aggregate(aai, 0,
+        (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)),
+      15)
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index 8e928a41f08e..5dfa1749f8a2 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -24,3 +24,15 @@ select transform(ys, 0) as v from nested;
 
 -- Transform a null array
 select transform(cast(null as array<int>), x -> x + 1) as v;
+
+-- Aggregate.
+select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested;
+
+-- Aggregate average.
+select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested;
+
+-- Aggregate nested arrays
+select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested;
+
+-- Aggregate a null array
+select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v;
diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
index ca2c3c35333c..47f9a0d940ea 100644
--- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 8
+-- Number of queries: 12
 
 
 -- !query 0
@@ -79,3 +79,41 @@ select transform(cast(null as array<int>), x -> x + 1) as v
 struct<v:array<int>>
 -- !query 7 output
 NULL
+
+
+-- !query 8
+select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested
+-- !query 8 schema
+struct<v:int>
+-- !query 8 output
+131
+15
+5
+
+
+-- !query 9
+select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested
+-- !query 9 schema
+struct<v:double>
+-- !query 9 output
+0.5
+12.0
+64.5
+
+
+-- !query 10
+select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested
+-- !query 10 schema
+struct<v:array<int>>
+-- !query 10 output
+[1010880,8]
+[17]
+[4752,20664,1]
+
+
+-- !query 11
+select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v
+-- !query 11 schema
+struct<v:int>
+-- !query 11 output
+NULL
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 923482024b03..d56e7f11b5c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1800,6 +1800,127 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
   }
 
+  test("aggregate function - array for primitive type not containing null") {
+    val df = Seq(
+      Seq(1, 9, 8, 7),
+      Seq(5, 8, 9, 7, 2),
+      Seq.empty,
+      null
+    ).toDF("i")
+
+    def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
+      checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"),
+        Seq(
+          Row(25),
+          Row(31),
+          Row(0),
+          Row(null)))
+      checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10)"),
+        Seq(
+          Row(250),
+          Row(310),
+          Row(0),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArrayOfPrimitiveTypeNotContainsNull()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testArrayOfPrimitiveTypeNotContainsNull()
+  }
+
+  test("aggregate function - array for primitive type containing null") {
+    val df = Seq[Seq[Integer]](
+      Seq(1, 9, 8, 7),
+      Seq(5, null, 8, 9, 7, 2),
+      Seq.empty,
+      null
+    ).toDF("i")
+
+    def testArrayOfPrimitiveTypeContainsNull(): Unit = {
+      checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"),
+        Seq(
+          Row(25),
+          Row(null),
+          Row(0),
+          Row(null)))
+      checkAnswer(
+        df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> coalesce(acc, 0) * 10)"),
+        Seq(
+          Row(250),
+          Row(0),
+          Row(0),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArrayOfPrimitiveTypeContainsNull()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testArrayOfPrimitiveTypeContainsNull()
+  }
+
+  test("aggregate function - array for non-primitive type") {
+    val df = Seq(
+      (Seq("c", "a", "b"), "a"),
+      (Seq("b", null, "c", null), "b"),
+      (Seq.empty, "c"),
+      (null, "d")
+    ).toDF("ss", "s")
+
+    def testNonPrimitiveType(): Unit = {
+      checkAnswer(df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x))"),
+        Seq(
+          Row("acab"),
+          Row(null),
+          Row("c"),
+          Row(null)))
+      checkAnswer(
+        df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x), acc -> coalesce(acc, ''))"),
+        Seq(
+          Row("acab"),
+          Row(""),
+          Row("c"),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testNonPrimitiveType()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testNonPrimitiveType()
+  }
+
+  test("aggregate function - invalid") {
+    val df = Seq(
+      (Seq("c", "a", "b"), 1),
+      (Seq("b", null, "c", null), 2),
+      (Seq.empty, 3),
+      (null, 4)
+    ).toDF("s", "i")
+
+    val ex1 = intercept[AnalysisException] {
+      df.selectExpr("aggregate(s, '', x -> x)")
+    }
+    assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match"))
+
+    val ex2 = intercept[AnalysisException] {
+      df.selectExpr("aggregate(s, '', (acc, x) -> x, (acc, x) -> x)")
+    }
+    assert(ex2.getMessage.contains("The number of lambda function arguments '2' does not match"))
+
+    val ex3 = intercept[AnalysisException] {
+      df.selectExpr("aggregate(i, 0, (acc, x) -> x)")
+    }
+    assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type"))
+
+    val ex4 = intercept[AnalysisException] {
+      df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
+    }
+    assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type"))
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
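
Usage sketch (not part of the patch): a minimal spark-shell example of the new
`aggregate` function once this change is applied. The DataFrame and the column
name `nums` below are hypothetical; this patch wires `aggregate` into the SQL
function registry only, so it is reached through SQL or `selectExpr`, exactly
as the tests above do:

  import spark.implicits._

  // Two rows, each holding an integer array.
  val df = Seq(Seq(1, 2, 3), Seq(4, 5)).toDF("nums")

  // Fold each array: start from 0, add each element via the merge lambda,
  // then scale the final state by 10 via the optional finish lambda.
  df.selectExpr("aggregate(nums, 0, (acc, x) -> acc + x, acc -> acc * 10) as v").show()
  // Expected values of v: 60 for [1, 2, 3] and 90 for [4, 5].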