diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index ef04e8825811..dddf81965216 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -103,16 +103,17 @@ class EquivalentExpressions { // There are some special expressions that we should not recurse into all of its children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) - // 2. If: common subexpressions will always be evaluated at the beginning, but the true and + // 2. LambdaFunction: it's children operate in the context of local lambdas and can't be split + // 3. If: common subexpressions will always be evaluated at the beginning, but the true and // false expressions in `If` may not get accessed, according to the predicate // expression. We should only recurse into the predicate expression. - // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain + // 4. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain // condition. We should only recurse into the first condition expression as it // will always get accessed. - // 4. Coalesce: it's also a conditional expression, we should only recurse into the first + // 5. Coalesce: it's also a conditional expression, we should only recurse into the first // children, because others may not get accessed. private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { - case _: CodegenFallback => Nil + case _: CodegenFallback | _: LambdaFunction => Nil case i: If => i.predicate :: Nil case c: CaseWhen => c.children.head :: Nil case c: Coalesce => c.children.head :: Nil @@ -122,7 +123,7 @@ class EquivalentExpressions { // For some special expressions we cannot just recurse into all of its children, but we can // recursively add the common expressions shared between all of its children. private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match { - case _: CodegenFallback => Nil + case _: CodegenFallback | _: LambdaFunction => Nil case i: If => Seq(Seq(i.trueValue, i.falseValue)) case c: CaseWhen => // We look at subexpressions in conditions and values of `CaseWhen` separately. It is diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index bbcd3b49572d..6aff69a1de19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -22,9 +22,11 @@ import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.mutable +import org.apache.spark.sql.catalyst.CatalystTypeConverters.isPrimitive import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ @@ -76,8 +78,7 @@ case class NamedLambdaVariable( exprId: ExprId = NamedExpression.newExprId, value: AtomicReference[Any] = new AtomicReference()) extends LeafExpression - with NamedExpression - with CodegenFallback { + with NamedExpression { override def qualifier: Seq[String] = Seq.empty @@ -98,6 +99,31 @@ case class NamedLambdaVariable( override def simpleString(maxFields: Int): String = { s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}" } + + // We need to include the Expr ID in the Codegen variable name since several tests bypass + // `UnresolvedNamedLambdaVariable.freshVarName` + lazy val variableName = s"${name}_${exprId.id}" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val atomicRef = ctx.addReferenceObj(variableName, value) + val tmpAtomic = ctx.freshName("tmpAtomic") + val boxedType = CodeGenerator.boxedType(dataType) + + if (nullable) { + ev.copy(code = code""" + Object $tmpAtomic = $atomicRef.get(); + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + boolean ${ev.isNull} = $tmpAtomic == null; + if (!${ev.isNull}) { + ${ev.value} = ($boxedType)$tmpAtomic; + } + """) + } else { + ev.copy(code = code""" + ${CodeGenerator.javaType(dataType)} ${ev.value} = ($boxedType)$atomicRef.get(); + """, isNull = FalseLiteral) + } + } } /** @@ -109,7 +135,7 @@ case class LambdaFunction( function: Expression, arguments: Seq[NamedExpression], hidden: Boolean = false) - extends Expression with CodegenFallback { + extends Expression { override def children: Seq[Expression] = function +: arguments override def dataType: DataType = function.dataType @@ -127,6 +153,23 @@ case class LambdaFunction( override def eval(input: InternalRow): Any = function.eval(input) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val functionCode = function.genCode(ctx) + + if (nullable) { + ev.copy(code = code""" + |${functionCode.code} + |boolean ${ev.isNull} = ${functionCode.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value}; + """.stripMargin) + } else { + ev.copy(code = code""" + |${functionCode.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value}; + """.stripMargin, isNull = FalseLiteral) + } + } + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): LambdaFunction = copy( @@ -224,6 +267,21 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { val canonicalizedChildren = cleaned.children.map(_.canonicalized) Canonicalize.execute(withNewChildren(canonicalizedChildren)) } + + protected def assignAtomic(atomicRef: String, value: String, isNull: String = FalseLiteral, + nullable: Boolean = false) = { + if (nullable) { + s""" + if ($isNull) { + $atomicRef.set(null); + } else { + $atomicRef.set($value); + } + """ + } else { + s"$atomicRef.set($value);" + } + } } /** @@ -269,10 +327,49 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expr } } + protected def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: String => String): ExprCode = { + val argumentGen = argument.genCode(ctx) + val resultCode = f(argumentGen.value) + + if (nullable) { + val nullSafeEval = ctx.nullSafeExec(argument.nullable, argumentGen.isNull)(resultCode) + ev.copy(code = code""" + |${argumentGen.code} + |boolean ${ev.isNull} = ${argumentGen.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$nullSafeEval + """) + } else { + ev.copy(code = code""" + |${argumentGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """, isNull = FalseLiteral) + } + } } trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { override def argumentType: AbstractDataType = ArrayType + + protected def assignElement(ctx: CodegenContext, arrayName: String, + elementVar: NamedLambdaVariable, index: String): String = { + val elementType = elementVar.dataType + val elementAtomic = ctx.addReferenceObj(elementVar.variableName, elementVar.value) + val extractElement = CodeGenerator.getValue(arrayName, elementType, index) + + assignAtomic(elementAtomic, extractElement, s"$arrayName.isNullAt($index)", + elementVar.nullable) + } + + protected def assignIndex(ctx: CodegenContext, indexVar: NamedLambdaVariable, + index: String): String = { + val indexAtomic = ctx.addReferenceObj(indexVar.variableName, indexVar.value) + assignAtomic(indexAtomic, index) + } } trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { @@ -297,7 +394,7 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { case class ArrayTransform( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction { override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) @@ -338,6 +435,43 @@ case class ArrayTransform( result } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val arrayData = ctx.freshName("arrayData") + val i = ctx.freshName("i") + + val initialization = CodeGenerator.createArrayData( + arrayData, dataType.elementType, numElements, s" $prettyName failed.") + + val functionCode = function.genCode(ctx) + + val elementAssignment = assignElement(ctx, arg, elementVar, i) + val indexAssignment = indexVar.map(c => assignIndex(ctx, c, i)) + val varAssignments = (Seq(elementAssignment) ++: indexAssignment).mkString("\n") + + // Some expressions return internal buffers that we have to copy + val copy = if (isPrimitive(function.dataType)) { + s"${functionCode.value}" + } else { + s"InternalRow.copyValue(${functionCode.value})" + } + val resultAssignment = CodeGenerator.setArrayElement(arrayData, dataType.elementType, + i, copy, isNull = Some(functionCode.isNull)) + + s""" + |final int $numElements = ${arg}.numElements(); + |$initialization + |for (int $i = 0; $i < $numElements; $i++) { + | $varAssignments + | ${functionCode.code} + | $resultAssignment + |} + |${ev.value} = $arrayData; + """.stripMargin + }) + } + override def prettyName: String = "transform" override protected def withNewChildrenInternal( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 3fe367be5545..ab039ae6e02b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -251,6 +251,15 @@ object QueryExecutionErrors { new IllegalArgumentException(s"$funcName is not matched at addNewFunction") } + def lambdaVariableAlreadyDefinedError(name: String): Throwable = { + new IllegalArgumentException(s"Lambda variable $name cannot be redefined") + } + + def lambdaVariableNotDefinedError(name: String): Throwable = { + new IllegalArgumentException( + s"Lambda variable $name is not defined in the current codegen scope") + } + def cannotGenerateCodeForUncomparableTypeError( codeType: String, dataType: DataType): Throwable = { new IllegalArgumentException(