diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6e2a5aa4f97c..22c52dacdffd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType import org.apache.spark.util.random.PoissonSampler @@ -79,16 +79,20 @@ case class Filter(condition: Expression, child: SparkPlan) // Split out all the IsNotNulls from condition. private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition { - case IsNotNull(a) if child.output.contains(a) => true + case IsNotNull(a) if child.output.exists(_.semanticEquals(a)) => true case _ => false } // The columns that will filtered out by `IsNotNull` could be considered as not nullable. private val notNullAttributes = notNullPreds.flatMap(_.references) + // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate + // all the variables at the beginning to take advantage of short circuiting. 
+ override def usedInputs: AttributeSet = AttributeSet.empty + override def output: Seq[Attribute] = { child.output.map { a => - if (a.nullable && notNullAttributes.contains(a)) { + if (a.nullable && notNullAttributes.exists(_.semanticEquals(a))) { a.withNullability(false) } else { a @@ -110,39 +114,80 @@ case class Filter(condition: Expression, child: SparkPlan) override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { val numOutput = metricTerm(ctx, "numOutputRows") - // filter out the nulls - val filterOutNull = notNullAttributes.map { a => - val idx = child.output.indexOf(a) - s"if (${input(idx).isNull}) continue;" - }.mkString("\n") + /** + * Generates code for `c`, using `in` for input attributes and `attrs` for nullability. + */ + def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = { + val bound = BindReferences.bindReference(c, attrs) + val evaluated = evaluateRequiredVariables(child.output, in, c.references) - ctx.currentVars = input - val predicates = otherPreds.map { e => - val bound = ExpressionCanonicalizer.execute( - BindReferences.bindReference(e, output)) - val ev = bound.gen(ctx) + // Generate the code for the predicate. + val ev = ExpressionCanonicalizer.execute(bound).gen(ctx) val nullCheck = if (bound.nullable) { s"${ev.isNull} || " } else { s"" } + s""" + |$evaluated |${ev.code} |if (${nullCheck}!${ev.value}) continue; """.stripMargin + } + + ctx.currentVars = input + + // To generate the predicates we will follow this algorithm. + // For each predicate that is not IsNotNull, we will generate them one by one loading attributes + // as necessary. For each such attribute, if there is an IsNotNull predicate we will generate + // that check *before* the predicate. After all of these predicates, we will generate the + // remaining IsNotNull checks that were not part of other predicates. 
+ // This has the property of not doing redundant IsNotNull checks and taking better advantage of + // short-circuiting, not loading attributes until they are needed. + // This is very perf sensitive. + // TODO: revisit this. We can consider reordering predicates as well. + val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) + val generated = otherPreds.map { c => + val nullChecks = c.references.map { r => + val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} + if (idx != -1 && !generatedIsNotNullChecks(idx)) { + generatedIsNotNullChecks(idx) = true + // Use the child's output. The nullability is what the child produced. + genPredicate(notNullPreds(idx), input, child.output) + } else { + "" + } + }.mkString("\n").trim + + // Here we use *this* operator's output with this output's nullability since we already + // enforced them with the IsNotNull checks above. + s""" + |$nullChecks + |${genPredicate(c, input, output)} + """.stripMargin.trim + }.mkString("\n") + + val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) => + if (!generatedIsNotNullChecks(idx)) { + genPredicate(c, input, child.output) + } else { + "" + } }.mkString("\n") // Reset the isNull to false for the not-null columns, then the followed operators could // generate better code (remove dead branches). val resultVars = input.zipWithIndex.map { case (ev, i) => - if (notNullAttributes.contains(child.output(i))) { + if (notNullAttributes.exists(_.semanticEquals(child.output(i)))) { ev.isNull = "false" } ev } + s""" - |$filterOutNull - |$predicates + |$generated + |$nullChecks |$numOutput.add(1); |${consume(ctx, resultVars)} """.stripMargin