From 8cfefef16417a0eef4a4069eb4049aeab3bee785 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Tue, 25 Nov 2025 21:05:21 +0000 Subject: [PATCH] Add codegen support to array-based higher order functions --- .../expressions/EquivalentExpressions.scala | 4 + .../expressions/codegen/CodeGenerator.scala | 35 ++ .../expressions/higherOrderFunctions.scala | 408 +++++++++++++++++- .../sql/errors/QueryExecutionErrors.scala | 9 + .../HigherOrderFunctionsSuite.scala | 22 + .../HigherOrderFunctionsBenchmark.scala | 73 ++++ 6 files changed, 543 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 78f73f8778b8..43d29ab27e15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -146,9 +146,13 @@ class EquivalentExpressions( // There are some special expressions that we should not recurse into all of its children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) // 2. ConditionalExpression: use its children that will always be evaluated. + // 3. HigherOrderFunction: lambda functions operate in the context of local lambdas and can't + // be called outside of that scope, only the arguments can be evaluated ahead of + // time. 
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
     case _: CodegenFallback => Nil
     case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut)
+    case h: HigherOrderFunction => h.arguments
     case other => skipForShortcut(other).children
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 13b1d329f7ec..d2568a5903c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -174,6 +174,60 @@ class CodegenContext extends Logging {
    */
   var currentVars: Seq[ExprCode] = null
 
+  /**
+   * Lambda variables that are in scope for the code currently being generated, keyed by the
+   * id of the `exprId` of their [[NamedLambdaVariable]].
+   */
+  var currentLambdaVars: mutable.Map[Long, ExprCode] = mutable.HashMap.empty
+
+  /**
+   * Registers a variable for every entry of `namedLambdas`, runs `f` with the registered
+   * variables (in the order of `namedLambdas`) and unregisters them again afterwards.
+   *
+   * Unregistering happens in a `finally` block, so a failure while registering or inside `f`
+   * never leaves stale entries in `currentLambdaVars`.
+   *
+   * Throws `QueryExecutionErrors.lambdaVariableAlreadyDefinedError` if a variable is already
+   * registered.
+   */
+  def withLambdaVars(
+      namedLambdas: Seq[NamedLambdaVariable],
+      f: Seq[ExprCode] => ExprCode): ExprCode = {
+    val registered = mutable.ArrayBuffer.empty[Long]
+    try {
+      val lambdaVars = namedLambdas.map { lambda =>
+        val id = lambda.exprId.id
+        if (currentLambdaVars.contains(id)) {
+          throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(id)
+        }
+        val isNull = if (lambda.nullable) {
+          JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "lambdaIsNull"))
+        } else {
+          FalseLiteral
+        }
+        val value = addMutableState(javaType(lambda.dataType), "lambdaValue")
+        val lambdaVar = ExprCode(isNull, JavaCode.global(value, lambda.dataType))
+        currentLambdaVars.put(id, lambdaVar)
+        registered += id
+        lambdaVar
+      }
+      f(lambdaVars)
+    } finally {
+      // Only remove what this call added; outer scopes clean up their own variables.
+      registered.foreach(currentLambdaVars.remove)
+    }
+  }
+
+  /**
+   * Returns the variable registered for lambda variable `id`, failing with
+   * `QueryExecutionErrors.lambdaVariableNotDefinedError` when it is not in scope.
+   */
+  def getLambdaVar(id: Long): ExprCode = {
+    currentLambdaVars.getOrElse(
+      id,
+      throw QueryExecutionErrors.lambdaVariableNotDefinedError(id))
+  }
+
   /**
    * Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a
    * 2-tuple: java type,
variable name. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 2a5a38e93706..9222b585914e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -81,8 +82,7 @@ case class NamedLambdaVariable( exprId: ExprId = NamedExpression.newExprId, value: AtomicReference[Any] = new AtomicReference()) extends LeafExpression - with NamedExpression - with CodegenFallback { + with NamedExpression { override def qualifier: Seq[String] = Seq.empty @@ -103,6 +103,10 @@ case class NamedLambdaVariable( override def simpleString(maxFields: Int): String = { s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}" } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.getLambdaVar(exprId.id) + } } /** @@ -114,7 +118,7 @@ case class LambdaFunction( function: Expression, arguments: Seq[NamedExpression], hidden: Boolean = false) - extends Expression with CodegenFallback { + extends Expression { override def children: Seq[Expression] = function +: arguments override def dataType: DataType = function.dataType @@ -132,6 +136,10 @@ case class 
LambdaFunction( override def eval(input: InternalRow): Any = function.eval(input) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + function.genCode(ctx) + } + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): LambdaFunction = copy( @@ -239,6 +247,63 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { val canonicalizedChildren = cleaned.children.map(_.canonicalized) withNewChildren(canonicalizedChildren) } + + + protected def assignAtomic( + atomicRef: String, + value: String, + isNull: String = FalseLiteral, + nullable: Boolean = false) = { + if (nullable) { + s""" + if ($isNull) { + $atomicRef.set(null); + } else { + $atomicRef.set($value); + } + """ + } else { + s"$atomicRef.set($value);" + } + } + + protected def assignArrayElement( + ctx: CodegenContext, + arrayName: String, + elementCode: ExprCode, + elementVar: NamedLambdaVariable, + index: String): String = { + val elementType = elementVar.dataType + val elementAtomic = ctx.addReferenceObj(elementVar.name, elementVar.value) + val extractElement = CodeGenerator.getValue(arrayName, elementType, index) + val atomicAssign = assignAtomic(elementAtomic, elementCode.value, + elementCode.isNull, elementVar.nullable) + + if (elementVar.nullable) { + s""" + ${elementCode.value} = $extractElement; + ${elementCode.isNull} = $arrayName.isNullAt($index); + $atomicAssign + """ + } else { + s""" + ${elementCode.value} = $extractElement; + $atomicAssign + """ + } + } + + protected def assignIndex( + ctx: CodegenContext, + indexCode: ExprCode, + indexVar: NamedLambdaVariable, + index: String): String = { + val indexAtomic = ctx.addReferenceObj(indexVar.name, indexVar.value) + s""" + ${indexCode.value} = $index; + ${assignAtomic(indexAtomic, indexCode.value)} + """ + } } /** @@ -284,6 +349,29 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expr } } + protected def nullSafeCodeGen( + ctx: CodegenContext, + ev: 
ExprCode,
+      f: String => String): ExprCode = {
+    val argumentGen = argument.genCode(ctx)
+    val resultCode = f(argumentGen.value)
+    // `stripMargin` keeps the `|` margins below out of the generated Java code.
+    if (nullable) {
+      val nullSafeEval = ctx.nullSafeExec(argument.nullable, argumentGen.isNull)(resultCode)
+      ev.copy(code = code"""
+        |${argumentGen.code}
+        |boolean ${ev.isNull} = ${argumentGen.isNull};
+        |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        |$nullSafeEval
+      """.stripMargin)
+    } else {
+      ev.copy(code = code"""
+        |${argumentGen.code}
+        |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        |$resultCode
+      """.stripMargin, isNull = FalseLiteral)
+    }
+  }
 }
 
 trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
@@ -312,7 +400,7 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
 case class ArrayTransform(
     argument: Expression,
     function: Expression)
-  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
+  extends ArrayBasedSimpleHigherOrderFunction {
 
   override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)
 
@@ -354,6 +442,49 @@ case class ArrayTransform(
     result
   }
 
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    ctx.withLambdaVars(Seq(elementVar) ++ indexVar, varCodes => {
+      val elementCode = varCodes.head
+      val indexCode = varCodes.tail.headOption
+
+      nullSafeCodeGen(ctx, ev, arg => {
+        val numElements = ctx.freshName("numElements")
+        val arrayData = ctx.freshName("arrayData")
+        val i = ctx.freshName("i")
+
+        val initialization = CodeGenerator.createArrayData(
+          arrayData, dataType.elementType, numElements, s" $prettyName failed.")
+
+        val functionCode = function.genCode(ctx)
+
+        val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i)
+        val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i))
+        val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n")
+
+        // Some expressions return
internal buffers that we have to copy + val copy = if (CodeGenerator.isPrimitiveType(function.dataType)) { + s"${functionCode.value}" + } else { + s"InternalRow.copyValue(${functionCode.value})" + } + val resultNull = if (function.nullable) Some(functionCode.isNull.toString) else None + val resultAssignment = CodeGenerator.setArrayElement(arrayData, dataType.elementType, + i, copy, isNull = resultNull) + + s""" + |final int $numElements = $arg.numElements(); + |$initialization + |for (int $i = 0; $i < $numElements; $i++) { + | $varAssignments + | ${functionCode.code} + | $resultAssignment + |} + |${ev.value} = $arrayData; + """.stripMargin + }) + }) + } + override def nodeName: String = "transform" override protected def withNewChildrenInternal( @@ -581,7 +712,7 @@ case class MapFilter( case class ArrayFilter( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction { override def dataType: DataType = argument.dataType @@ -622,6 +753,72 @@ case class ArrayFilter( new GenericArrayData(buffer) } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar) ++ indexVar, varCodes => { + val elementCode = varCodes.head + val indexCode = varCodes.tail.headOption + + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val count = ctx.freshName("count") + val arrayTracker = ctx.freshName("arrayTracker") + val arrayData = ctx.freshName("arrayData") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + + val arrayType = dataType.asInstanceOf[ArrayType] + + val trackerInit = CodeGenerator.createArrayData( + arrayTracker, BooleanType, numElements, s" $prettyName failed.") + val resultInit = CodeGenerator.createArrayData( + arrayData, arrayType.elementType, count, s" $prettyName failed.") + + val functionCode = function.genCode(ctx) + + val elementAssignment = assignArrayElement(ctx, arg, 
elementCode, elementVar, i)
+        val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i))
+        val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n")
+        // A null predicate result rejects the element; `value` is only valid when not null.
+        val resultAssignment = CodeGenerator.setArrayElement(arrayTracker, BooleanType,
+          i, s"!${functionCode.isNull} && ${functionCode.value}", isNull = None)
+
+        val getTrackerValue = CodeGenerator.getValue(arrayTracker, BooleanType, i)
+        val copy = CodeGenerator.createArrayAssignment(arrayData, arrayType.elementType, arg,
+          j, i, arrayType.containsNull)
+
+        // This takes two passes to avoid evaluating the predicate multiple times
+        // The first pass evaluates each element in the array, tracks how many elements
+        // returned true, and tracks the result of each element in a boolean array `arrayTracker`.
+        // The second pass copies elements from the original array to the new array created
+        // based on the number of elements matching the first pass.
+
+        s"""
+          |final int $numElements = $arg.numElements();
+          |$trackerInit
+          |int $count = 0;
+          |for (int $i = 0; $i < $numElements; $i++) {
+          |  $varAssignments
+          |  ${functionCode.code}
+          |  $resultAssignment
+          |  if (!${functionCode.isNull} && ${functionCode.value}) {
+          |    $count++;
+          |  }
+          |}
+          |
+          |$resultInit
+          |int $j = 0;
+          |for (int $i = 0; $i < $numElements; $i++) {
+          |  if ($getTrackerValue) {
+          |    $copy
+          |    $j++;
+          |  }
+          |}
+          |${ev.value} = $arrayData;
+        """.stripMargin
+      })
+    })
+  }
+
   override def nodeName: String = "filter"
 
   override protected def withNewChildrenInternal(
@@ -653,7 +850,7 @@ case class ArrayExists(
     argument: Expression,
     function: Expression,
     followThreeValuedLogic: Boolean)
-  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback with Predicate {
+  extends ArrayBasedSimpleHigherOrderFunction with Predicate {
 
   def this(argument: Expression, function: Expression) = {
     this(
@@ -706,6 +903,50 @@ case class ArrayExists(
     }
   }
 
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    ctx.withLambdaVars(Seq(elementVar), { case
Seq(elementCode) => + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val exists = ctx.freshName("exists") + val foundNull = ctx.freshName("foundNull") + val i = ctx.freshName("i") + + val functionCode = function.genCode(ctx) + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val threeWayLogic = if (followThreeValuedLogic) TrueLiteral else FalseLiteral + + val nullCheck = if (nullable) { + s""" + if ($threeWayLogic && !$exists && $foundNull) { + ${ev.isNull} = true; + } + """ + } else { + "" + } + + s""" + |final int $numElements = ${arg}.numElements(); + |boolean $exists = false; + |boolean $foundNull = false; + |int $i = 0; + |while ($i < $numElements && !$exists) { + | $elementAssignment + | ${functionCode.code} + | if (${functionCode.isNull}) { + | $foundNull = true; + | } else if (${functionCode.value}) { + | $exists = true; + | } + | $i++; + |} + |$nullCheck + |${ev.value} = $exists; + """.stripMargin + }) + }) + } + override def nodeName: String = "exists" override protected def withNewChildrenInternal( @@ -740,7 +981,7 @@ object ArrayExists { case class ArrayForAll( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback with Predicate { + extends ArrayBasedSimpleHigherOrderFunction with Predicate { override def nullable: Boolean = super.nullable || function.nullable @@ -785,6 +1026,49 @@ case class ArrayForAll( } } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar), { case Seq(elementCode) => + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val forall = ctx.freshName("forall") + val foundNull = ctx.freshName("foundNull") + val i = ctx.freshName("i") + + val functionCode = function.genCode(ctx) + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + + val nullCheck = if (nullable) { + s""" + if 
($forall && $foundNull) {
+              ${ev.isNull} = true;
+            }
+          """
+        } else {
+          ""
+        }
+
+        s"""
+          |final int $numElements = ${arg}.numElements();
+          |boolean $forall = true;
+          |boolean $foundNull = false;
+          |int $i = 0;
+          |while ($i < $numElements && $forall) {
+          |  $elementAssignment
+          |  ${functionCode.code}
+          |  if (${functionCode.isNull}) {
+          |    $foundNull = true;
+          |  } else if (!${functionCode.value}) {
+          |    $forall = false;
+          |  }
+          |  $i++;
+          |}
+          |$nullCheck
+          |${ev.value} = $forall;
+        """.stripMargin
+      })
+    })
+  }
+
   override def nodeName: String = "forall"
 
   override protected def withNewChildrenInternal(
@@ -816,7 +1100,7 @@ case class ArrayAggregate(
     zero: Expression,
     merge: Expression,
     finish: Expression)
-  extends HigherOrderFunction with CodegenFallback with QuaternaryLike[Expression] {
+  extends HigherOrderFunction with QuaternaryLike[Expression] {
 
   def this(argument: Expression, zero: Expression, merge: Expression) = {
     this(argument, zero, merge, LambdaFunction.identity)
@@ -886,6 +1170,114 @@ case class ArrayAggregate(
     }
   }
 
+  protected def nullSafeCodeGen(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      f: String => String): ExprCode = {
+    val argumentGen = argument.genCode(ctx)
+    val resultCode = f(argumentGen.value)
+    // `stripMargin` keeps the `|` margins below out of the generated Java code.
+    if (nullable) {
+      val nullSafeEval = ctx.nullSafeExec(argument.nullable, argumentGen.isNull)(resultCode)
+      ev.copy(code = code"""
+        |${argumentGen.code}
+        |boolean ${ev.isNull} = ${argumentGen.isNull};
+        |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        |$nullSafeEval
+      """.stripMargin)
+    } else {
+      ev.copy(code = code"""
+        |${argumentGen.code}
+        |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        |$resultCode
+      """.stripMargin, isNull = FalseLiteral)
+    }
+  }
+
+  protected def assignVar(
+      varCode: ExprCode,
+      atomicVar: String,
+      value: String,
+      isNull: String,
+      nullable: Boolean): String = {
+    val atomicAssign = assignAtomic(atomicVar, value, isNull, nullable)
+    if (nullable) {
+      s"""
+
${varCode.value} = $value;
+        ${varCode.isNull} = $isNull;
+        $atomicAssign
+      """
+    } else {
+      // The source is never null, but the variable lives in mutable state and may still carry
+      // the null flag of an earlier assignment, so the flag has to be cleared explicitly.
+      val clearNull = if (varCode.isNull == FalseLiteral) "" else s"${varCode.isNull} = false;"
+      s"""
+        ${varCode.value} = $value;
+        $clearNull
+        $atomicAssign
+      """
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    ctx.withLambdaVars(Seq(elementVar, accForMergeVar, accForFinishVar), varCodes => {
+      val Seq(elementCode, accForMergeCode, accForFinishCode) = varCodes
+
+      nullSafeCodeGen(ctx, ev, arg => {
+        val numElements = ctx.freshName("numElements")
+        val i = ctx.freshName("i")
+
+        val zeroCode = zero.genCode(ctx)
+        val mergeCode = merge.genCode(ctx)
+        val finishCode = finish.genCode(ctx)
+
+        val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i)
+        val mergeAtomic = ctx.addReferenceObj(accForMergeVar.name, accForMergeVar.value)
+        val finishAtomic = ctx.addReferenceObj(accForFinishVar.name, accForFinishVar.value)
+
+        val mergeJavaType = CodeGenerator.javaType(accForMergeVar.dataType)
+
+        // Some expressions return internal buffers that we have to copy
+        val mergeCopy = if (CodeGenerator.isPrimitiveType(merge.dataType)) {
+          s"${mergeCode.value}"
+        } else {
+          s"($mergeJavaType)InternalRow.copyValue(${mergeCode.value})"
+        }
+
+        val nullCheck = if (nullable) s"${ev.isNull} = ${finishCode.isNull};" else ""
+
+        // Seed the accumulator with `zero`, then fold every element into it with `merge`.
+        val initialAssignment = assignVar(accForMergeCode, mergeAtomic, zeroCode.value,
+          zeroCode.isNull, zero.nullable)
+
+        val mergeAssignment = assignVar(accForMergeCode, mergeAtomic, mergeCopy,
+          mergeCode.isNull, merge.nullable)
+
+        // The accumulator can be null even when `merge` is not nullable (null `zero` and an
+        // empty array), so its null flag is always propagated to the `finish` variable.
+        val finishAssignment = assignVar(accForFinishCode, finishAtomic, accForMergeCode.value,
+          accForMergeCode.isNull, accForMergeVar.nullable)
+
+        s"""
+          |final int $numElements = ${arg}.numElements();
+          |${zeroCode.code}
+          |$initialAssignment
+          |
+          |for (int $i = 0; $i < $numElements; $i++) {
+          |  $elementAssignment
+          |  ${mergeCode.code}
+          |  $mergeAssignment
+          |}
+          |
+          |$finishAssignment
+          |${finishCode.code}
+
|${ev.value} = ${finishCode.value}; + |$nullCheck + """.stripMargin + }) + }) + } + override def nodeName: String = "aggregate" override def first: Expression = argument diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 989ad8b0dc41..fe88d964da7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -447,6 +447,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE s"failed to match ${toSQLId(funcName)} at `addNewFunction`.") } + def lambdaVariableAlreadyDefinedError(id: Long): Throwable = { + new IllegalArgumentException(s"Lambda variable $id cannot be redefined") + } + + def lambdaVariableNotDefinedError(id: Long): Throwable = { + new IllegalArgumentException( + s"Lambda variable $id is not defined in the current codegen scope") + } + def cannotGenerateCodeForIncomparableTypeError( codeType: String, dataType: DataType): Throwable = { SparkException.internalError( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index cc36cd73d6d7..bc608b7afecf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch 
import org.apache.spark.sql.catalyst.expressions.Cast._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -149,6 +151,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val plusOne: Expression => Expression = x => x + 1 val plusIndex: (Expression, Expression) => Expression = (x, i) => x + i + val plusOneFallback: Expression => Expression = x => CodegenFallbackExpr(x + 1) checkEvaluation(transform(ai0, plusOne), Seq(2, 3, 4)) checkEvaluation(transform(ai0, plusIndex), Seq(1, 3, 5)) @@ -158,6 +161,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transform(transform(ai1, plusIndex), plusOne), Seq(2, null, 6)) checkEvaluation(transform(ain, plusOne), null) + checkEvaluation(transform(ai0, plusOneFallback), Seq(2, 3, 4)) + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) @@ -277,6 +282,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isEven: Expression => Expression = x => x % 2 === 0 val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val indexIsEven: (Expression, Expression) => Expression = { case (_, idx) => idx % 2 === 0 } + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) checkEvaluation(filter(ai0, isEven), Seq(2)) checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3)) @@ -286,6 +292,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(filter(ain, isEven), null) checkEvaluation(filter(ain, isNullOrOdd), null) + checkEvaluation(filter(ai0, isEvenFallback), Seq(2)) + val as0 = Literal.create(Seq("a0", "b1", "a2", "c3"), 
ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) @@ -321,6 +329,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType) + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) for (followThreeValuedLogic <- Seq(false, true)) { withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key @@ -337,6 +346,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(exists(ain, isNullOrOdd), null) checkEvaluation(exists(ain, alwaysFalse), null) checkEvaluation(exists(ain, alwaysNull), null) + checkEvaluation(exists(ai0, isEvenFallback), true) } } @@ -383,6 +393,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType) + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) checkEvaluation(forall(ai0, isEven), true) checkEvaluation(forall(ai0, isNullOrOdd), false) @@ -401,6 +412,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(forall(ain, alwaysFalse), null) checkEvaluation(forall(ain, alwaysNull), null) + checkEvaluation(forall(ai0, isEvenFallback), true) + val as0 = Literal.create(Seq("a0", "a1", "a2", "a3"), ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true)) @@ -886,3 +899,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper 
))) } } + +case class CodegenFallbackExpr(child: Expression) extends UnaryExpression with CodegenFallback { + override def nullable: Boolean = child.nullable + override def dataType: DataType = child.dataType + override lazy val resolved = child.resolved + override def eval(input: InternalRow): Any = child.eval(input) + override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpr = + copy(child = newChild) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala new file mode 100644 index 000000000000..960607625559 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Synthetic benchmark for higher order functions.
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt:
+ *      bin/spark-submit --class <this class>
+ *        --jars <spark core test jar>,<spark catalyst test jar> <spark sql test jar>
+ *   2. build/sbt "sql/Test/runMain <this class>"
+ *   3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain <this class>"
+ *      Results will be written to "benchmarks/HigherOrderFunctionsBenchmark-results.txt".
+ * }}}
+ */
+object HigherOrderFunctionsBenchmark extends SqlBasedBenchmark {
+  private val N = 10_000_000 // rows produced by `spark.range` for the benchmarked DataFrame
+  private val M = 10 // second argument of every `addCase` call below
+
+  private val df = spark.range(N).select(array(col("id"), col("id"), col("id")).alias("arr"))
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    runBenchmark("Higher order functions") {
+      def benchFunction(name: String, col: Column): Unit = {
+        val benchmark = new Benchmark(name, N, output = output)
+        benchmark.addCase("codegen", M) { _ =>
+          withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key ->
+              CodegenObjectFactoryMode.CODEGEN_ONLY.toString()) {
+            df.select(col).noop()
+          }
+        }
+
+        benchmark.addCase("interpreted", M) { _ =>
+          withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key ->
+              CodegenObjectFactoryMode.NO_CODEGEN.toString()) {
+            df.select(col).noop()
+          }
+        }
+        benchmark.run()
+      }
+
+      benchFunction("transform", transform(col("arr"), x => x + 1))
+      benchFunction("filter", filter(col("arr"), x => x > 1))
+      benchFunction("forall - fast", forall(col("arr"), x => x < 0))
+      benchFunction("forall - slow", forall(col("arr"), x => x >= 0))
+      benchFunction("exists - fast", exists(col("arr"), x => x >= 0))
+      benchFunction("exists - slow", exists(col("arr"), x => x < 0))
+      benchFunction("aggregate", aggregate(col("arr"), lit(0L), (acc, x) => acc + x))
+    }
+  }
+}