diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 53ac3560bc3b..8ab4e8ff4062 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -242,6 +242,24 @@ trait PredicateHelper extends AliasHelper with Logging {
       None
     }
   }
+
+  // If one expression and its children are null intolerant, it is null intolerant.
+  protected def isNullIntolerant(expr: Expression): Boolean = expr match {
+    case e: NullIntolerant => e.children.forall(isNullIntolerant)
+    case _ => false
+  }
+
+  protected def outputWithNullability(
+      output: Seq[Attribute],
+      nonNullAttrExprIds: Seq[ExprId]): Seq[Attribute] = {
+    output.map { a =>
+      if (a.nullable && nonNullAttrExprIds.contains(a.exprId)) {
+        a.withNullability(false)
+      } else {
+        a
+      }
+    }
+  }
 }
 
 @ExpressionDescription(
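The two helpers above are hoisted from `FilterExec` into `PredicateHelper` so any codegen operator can reason about null intolerance and narrow output nullability. Not part of the patch: a minimal sketch of their behaviour, assuming a Spark 3.x Catalyst classpath; the demo object and its names are hypothetical.

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// Hypothetical demo object; mixes in PredicateHelper to reach the protected helpers.
object PredicateHelperSketch extends PredicateHelper {
  def demo(): Unit = {
    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()

    // Add and attribute references are NullIntolerant, so the whole tree is null intolerant;
    // Coalesce can return a non-null value for null inputs, so it is not.
    println(isNullIntolerant(Add(a, b)))           // expected: true
    println(isNullIntolerant(Coalesce(Seq(a, b)))) // expected: false

    // Attributes guarded by an IsNotNull check can be reported as non-nullable downstream.
    val narrowed = outputWithNullability(Seq(a, b), Seq(a.exprId))
    println(narrowed.map(attr => attr.name -> attr.nullable)) // expected: List((a,false), (b,true))
  }
}
```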
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 52d0450afb18..cdad9de00620 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -53,7 +53,8 @@ case class HashAggregateExec(
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends BaseAggregateExec
-  with BlockingOperatorWithCodegen {
+  with BlockingOperatorWithCodegen
+  with GeneratePredicateHelper {
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -131,10 +132,8 @@ case class HashAggregateExec(
   override def usedInputs: AttributeSet = inputSet
 
   override def supportCodegen: Boolean = {
-    // ImperativeAggregate and filter predicate are not supported right now
-    // TODO: SPARK-30027 Support codegen for filter exprs in HashAggregateExec
-    !(aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) ||
-      aggregateExpressions.exists(_.filter.isDefined))
+    // ImperativeAggregate are not supported right now
+    !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
   }
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -254,7 +253,7 @@ case class HashAggregateExec(
       aggNames: Seq[String],
       aggBufferUpdatingExprs: Seq[Seq[Expression]],
       aggCodeBlocks: Seq[Block],
-      subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
+      subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = {
     val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
     if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
       // `SimpleExprValue`s cannot be used as an input variable for split functions, so
@@ -293,7 +292,7 @@ case class HashAggregateExec(
         val inputVariables = args.map(_.variableName).mkString(", ")
         s"$doAggFuncName($inputVariables);"
       }
-      Some(splitCodes.mkString("\n").trim)
+      Some(splitCodes)
     } else {
       val errMsg = "Failed to split aggregate code into small functions because the parameter " +
         "length of at least one split function went over the JVM limit: " +
@@ -308,6 +307,39 @@ case class HashAggregateExec(
     }
   }
 
+  private def generateEvalCodeForAggFuncs(
+      ctx: CodegenContext,
+      input: Seq[ExprCode],
+      inputAttrs: Seq[Attribute],
+      boundUpdateExprs: Seq[Seq[Expression]],
+      aggNames: Seq[String],
+      aggCodeBlocks: Seq[Block],
+      subExprs: SubExprCodes): String = {
+    val aggCodes = if (conf.codegenSplitAggregateFunc &&
+        aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
+      val maybeSplitCodes = splitAggregateExpressions(
+        ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
+
+      maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code))
+    } else {
+      aggCodeBlocks.map(_.code)
+    }
+
+    aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map {
+      case (aggCode, (Partial | Complete, Some(condition))) =>
+        // Note: wrap in "do { } while(false);", so the generated checks can jump out
+        // with "continue;"
+        s"""
+           |do {
+           |  ${generatePredicateCode(ctx, condition, inputAttrs, input)}
+           |  $aggCode
+           |} while(false);
+         """.stripMargin
+      case (aggCode, _) =>
+        aggCode
+    }.mkString("\n")
+  }
+
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
@@ -354,24 +386,14 @@ case class HashAggregateExec(
       """.stripMargin
     }
 
-    val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
-        aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
-      val maybeSplitCode = splitAggregateExpressions(
-        ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
-      maybeSplitCode.getOrElse {
-        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-      }
-    } else {
-      aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-    }
-
+    val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+      ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
     s"""
       |// do aggregate
      |// common sub-expressions
      |$effectiveCodes
      |// evaluate aggregate functions and update aggregation buffers
-     |$codeToEvalAggFunc
+     |$codeToEvalAggFuncs
     """.stripMargin
   }
 
@@ -908,7 +930,7 @@ case class HashAggregateExec(
       }
     }
 
-    val inputAttr = aggregateBufferAttributes ++ inputAttributes
+    val inputAttrs = aggregateBufferAttributes ++ inputAttributes
     // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
     // generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while
     // generating input columns, we use `currentVars`.
@@ -930,7 +952,7 @@ case class HashAggregateExec(
     val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
       val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
-        bindReferences(updateExprsForOneFunc, inputAttr)
+        bindReferences(updateExprsForOneFunc, inputAttrs)
       }
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
       val effectiveCodes = subExprs.codes.mkString("\n")
@@ -961,23 +983,13 @@ case class HashAggregateExec(
         """.stripMargin
       }
 
-      val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
-          aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
-        val maybeSplitCode = splitAggregateExpressions(
-          ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
-        maybeSplitCode.getOrElse {
-          aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-        }
-      } else {
-        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-      }
-
+      val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+        ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
       s"""
         |// common sub-expressions
        |$effectiveCodes
        |// evaluate aggregate functions and update aggregation buffers
-       |$codeToEvalAggFunc
+       |$codeToEvalAggFuncs
       """.stripMargin
     }
 
@@ -986,7 +998,7 @@ case class HashAggregateExec(
       if (isVectorizedHashMapEnabled) {
         ctx.INPUT_ROW = fastRowBuffer
         val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
-          bindReferences(updateExprsForOneFunc, inputAttr)
+          bindReferences(updateExprsForOneFunc, inputAttrs)
         }
         val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
         val effectiveCodes = subExprs.codes.mkString("\n")
@@ -1016,18 +1028,8 @@ case class HashAggregateExec(
           """.stripMargin
         }
 
-
-        val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
-          aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
-          val maybeSplitCode = splitAggregateExpressions(
-            ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
-          maybeSplitCode.getOrElse {
-            aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-          }
-        } else {
-          aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-        }
+        val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+          ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
 
         // If vectorized fast hash map is on, we first generate code to update row
         // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map.
@@ -1037,7 +1039,7 @@ case class HashAggregateExec(
            |  // common sub-expressions
           |  $effectiveCodes
           |  // evaluate aggregate functions and update aggregation buffers
-          |  $codeToEvalAggFunc
+          |  $codeToEvalAggFuncs
           |} else {
           |  $updateRowInRegularHashMap
           |}
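With `generateEvalCodeForAggFuncs` in place, an aggregate expression in `Partial` or `Complete` mode only updates its buffer when its `FILTER` predicate holds; a failed predicate jumps out of that aggregate's `do { ... } while(false)` block with `continue;`. Not part of the patch: a small end-to-end query of the shape that now stays inside whole-stage codegen; the session setup, table, and column names are hypothetical.

```scala
import org.apache.spark.sql.SparkSession

object FilteredAggregateExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("filtered-agg-demo").getOrCreate()
    import spark.implicits._

    Seq((1, 10), (2, 20), (3, 30)).toDF("key", "val").createOrReplaceTempView("t")

    // Each aggregate that carries a FILTER clause gets its own guarded update block
    // in the generated code; aggregates without a filter are emitted unchanged.
    spark.sql(
      """
        |SELECT
        |  count(val) FILTER (WHERE val > 15) AS filtered_cnt,
        |  sum(key) AS total_key
        |FROM t
      """.stripMargin).show()

    spark.stop()
  }
}
```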
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index d74d0bf733c2..abd336006848 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -109,59 +109,39 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 }
 
-/** Physical plan for Filter. */
-case class FilterExec(condition: Expression, child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport with PredicateHelper {
-
-  // Split out all the IsNotNulls from condition.
-  private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
-    case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
-    case _ => false
-  }
-
-  // If one expression and its children are null intolerant, it is null intolerant.
-  private def isNullIntolerant(expr: Expression): Boolean = expr match {
-    case e: NullIntolerant => e.children.forall(isNullIntolerant)
-    case _ => false
-  }
-
-  // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
-  private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
-
-  // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
-  // all the variables at the beginning to take advantage of short circuiting.
-  override def usedInputs: AttributeSet = AttributeSet.empty
-
-  override def output: Seq[Attribute] = {
-    child.output.map { a =>
-      if (a.nullable && notNullAttributes.contains(a.exprId)) {
-        a.withNullability(false)
-      } else {
-        a
-      }
+trait GeneratePredicateHelper extends PredicateHelper {
+  self: CodegenSupport =>
+
+  protected def generatePredicateCode(
+      ctx: CodegenContext,
+      condition: Expression,
+      inputAttrs: Seq[Attribute],
+      inputExprCode: Seq[ExprCode]): String = {
+    val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
+      case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(AttributeSet(inputAttrs))
+      case _ => false
     }
-  }
-
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
-
-  override def inputRDDs(): Seq[RDD[InternalRow]] = {
-    child.asInstanceOf[CodegenSupport].inputRDDs()
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
-    child.asInstanceOf[CodegenSupport].produce(ctx, this)
-  }
-
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val numOutput = metricTerm(ctx, "numOutputRows")
-
+    val nonNullAttrExprIds = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
+    val outputAttrs = outputWithNullability(inputAttrs, nonNullAttrExprIds)
+    generatePredicateCode(
+      ctx, inputAttrs, inputExprCode, outputAttrs, notNullPreds, otherPreds,
+      nonNullAttrExprIds)
+  }
+
+  protected def generatePredicateCode(
+      ctx: CodegenContext,
+      inputAttrs: Seq[Attribute],
+      inputExprCode: Seq[ExprCode],
+      outputAttrs: Seq[Attribute],
+      notNullPreds: Seq[Expression],
+      otherPreds: Seq[Expression],
+      nonNullAttrExprIds: Seq[ExprId]): String = {
     /**
      * Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
      */
     def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = {
       val bound = BindReferences.bindReference(c, attrs)
-      val evaluated = evaluateRequiredVariables(child.output, in, c.references)
+      val evaluated = evaluateRequiredVariables(inputAttrs, in, c.references)
 
       // Generate the code for the predicate.
       val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx)
@@ -195,10 +175,10 @@ case class FilterExec(condition: Expression, child: SparkPlan)
         if (idx != -1 && !generatedIsNotNullChecks(idx)) {
           generatedIsNotNullChecks(idx) = true
           // Use the child's output. The nullability is what the child produced.
-          genPredicate(notNullPreds(idx), input, child.output)
-        } else if (notNullAttributes.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
+          genPredicate(notNullPreds(idx), inputExprCode, inputAttrs)
+        } else if (nonNullAttrExprIds.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
           extraIsNotNullAttrs += r
-          genPredicate(IsNotNull(r), input, child.output)
+          genPredicate(IsNotNull(r), inputExprCode, inputAttrs)
         } else {
           ""
         }
@@ -208,18 +188,61 @@ case class FilterExec(condition: Expression, child: SparkPlan)
       // enforced them with the IsNotNull checks above.
       s"""
          |$nullChecks
-         |${genPredicate(c, input, output)}
+         |${genPredicate(c, inputExprCode, outputAttrs)}
        """.stripMargin.trim
     }.mkString("\n")
 
     val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
       if (!generatedIsNotNullChecks(idx)) {
-        genPredicate(c, input, child.output)
+        genPredicate(c, inputExprCode, inputAttrs)
       } else {
         ""
       }
     }.mkString("\n")
 
+    s"""
+       |$generated
+       |$nullChecks
+     """.stripMargin
+  }
+}
+
+/** Physical plan for Filter. */
+case class FilterExec(condition: Expression, child: SparkPlan)
+  extends UnaryExecNode with CodegenSupport with GeneratePredicateHelper {
+
+  // Split out all the IsNotNulls from condition.
+  private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
+    case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
+    case _ => false
+  }
+
+  // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
+  private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
+
+  // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
+  // all the variables at the beginning to take advantage of short circuiting.
+  override def usedInputs: AttributeSet = AttributeSet.empty
+
+  override def output: Seq[Attribute] = outputWithNullability(child.output, notNullAttributes)
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    val predicateCode = generatePredicateCode(
+      ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes)
+
     // Reset the isNull to false for the not-null columns, then the followed operators could
     // generate better code (remove dead branches).
     val resultVars = input.zipWithIndex.map { case (ev, i) =>
@@ -232,8 +255,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
     // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;"
     s"""
        |do {
-       |  $generated
-       |  $nullChecks
+       |  $predicateCode
        |  $numOutput.add(1);
        |  ${consume(ctx, resultVars)}
        |} while(false);
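`GeneratePredicateHelper.generatePredicateCode` keeps `FilterExec`'s original strategy: split the conjuncts, emit the `IsNotNull` checks over null-intolerant expressions first, then the remaining predicates against the narrowed output. Not part of the patch: a sketch of that split on a hand-built condition, assuming a Spark 3.x Catalyst classpath; the demo object is hypothetical and the printed `#n` expression IDs will vary.

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// Hypothetical demo object; PredicateHelper provides splitConjunctivePredicates
// and, after this patch, isNullIntolerant.
object PredicateSplitSketch extends PredicateHelper {
  def demo(): Unit = {
    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    val condition =
      And(And(IsNotNull(a), GreaterThan(a, Literal(1))), LessThan(b, Literal(2)))

    // IsNotNull over a null-intolerant child becomes a "not null" predicate; the rest
    // stays in otherPreds and is generated after the null checks.
    val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
      case IsNotNull(e) => isNullIntolerant(e)
      case _ => false
    }
    println(notNullPreds) // e.g. List(isnotnull(a#1))
    println(otherPreds)   // e.g. List((a#1 > 1), (b#2 < 2))
  }
}
```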
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
index e4193d845f2e..c1ccb654ee08 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
@@ -1,4 +1,7 @@
--- Test filter clause for aggregate expression.
+-- Test filter clause for aggregate expression with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
 
 --CONFIG_DIM1 spark.sql.optimizeNullAwareAntiJoin=true
 --CONFIG_DIM1 spark.sql.optimizeNullAwareAntiJoin=false
diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
index 886b98e538d2..a4c92382750e 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -878,7 +878,7 @@ struct
 == Physical Plan ==
 * HashAggregate (5)
 +- Exchange (4)
-   +- HashAggregate (3)
+   +- * HashAggregate (3)
       +- * ColumnarToRow (2)
          +- Scan parquet default.explain_temp1 (1)
 
@@ -892,7 +892,7 @@ struct
 (2) ColumnarToRow [codegen id : 1]
 Input [2]: [key#x, val#x]
 
-(3) HashAggregate
+(3) HashAggregate [codegen id : 1]
 Input [2]: [key#x, val#x]
 Keys: []
 Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
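The golden-file change above is the user-visible effect: the partial `HashAggregate` that evaluates `count(key#x) FILTER (WHERE (val#x > 1))` is now planned inside a `WholeStageCodegen` stage and gains a codegen id. Not part of the patch: one way to observe the same thing interactively; the table and column names are hypothetical.

```scala
import org.apache.spark.sql.SparkSession

object VerifyFilteredAggCodegen {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("verify-filtered-agg-codegen").getOrCreate()

    spark.range(100)
      .selectExpr("id AS key", "id % 10 AS val")
      .createOrReplaceTempView("t")

    val df = spark.sql(
      "SELECT count(val) FILTER (WHERE val > 1) AS cnt, sum(key) AS total FROM t")

    // Before this patch the partial HashAggregate sat outside WholeStageCodegen;
    // with it, both HashAggregate nodes report a "codegen id" in the formatted plan.
    df.explain("formatted")

    spark.stop()
  }
}
```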