@@ -18,15 +18,14 @@
package org.apache.spark.sql.catalyst.catalog

import java.net.URI
import java.util.Locale

import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.Shell

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, Predicate}

object ExternalCatalogUtils {
// This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't
@@ -148,7 +147,7 @@ object ExternalCatalogUtils {
}

val boundPredicate =
InterpretedPredicate.create(predicates.reduce(And).transform {
Predicate.createInterpreted(predicates.reduce(And).transform {
case att: AttributeReference =>
val index = partitionSchema.indexWhere(_.name == att.name)
BoundReference(index, partitionSchema(index).dataType, nullable = true)
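
As a rough illustration of the pattern in this hunk (hypothetical schema and filter, not part of the change): attribute references are rebound to ordinals in the partition schema, and the bound expression goes through the new Predicate.createInterpreted helper, keeping this driver-side catalog pruning on the interpreted path.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, GreaterThan, Literal, Predicate}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical partition schema and pruning filter: p > 3
val partitionSchema = StructType(Seq(StructField("p", IntegerType)))
val filter = GreaterThan(AttributeReference("p", IntegerType)(), Literal(3))

// Rebind attributes to their ordinal in the partition schema, as the hunk does,
// then build an interpreted (non-codegen) predicate for driver-side evaluation.
val boundPredicate = Predicate.createInterpreted(filter.transform {
  case att: AttributeReference =>
    val index = partitionSchema.indexWhere(_.name == att.name)
    BoundReference(index, partitionSchema(index).dataType, nullable = true)
})

boundPredicate.eval(InternalRow(5)) // true: a partition with p = 5 is kept
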
@@ -89,14 +89,14 @@ object MutableProjection
}

/**
* Returns an MutableProjection for given sequence of bound Expressions.
* Returns a MutableProjection for given sequence of bound Expressions.
*/
def create(exprs: Seq[Expression]): MutableProjection = {
createObject(exprs)
}

/**
* Returns an MutableProjection for given sequence of Expressions, which will be bound to
* Returns a MutableProjection for given sequence of Expressions, which will be bound to
Member: Let's not touch this file in this PR~

maropu (Member Author, Nov 20, 2019): You meant a separate PR for this typo?

* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {
@@ -20,31 +20,17 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._

/**
* Interface for generated predicate
*/
abstract class Predicate {
def eval(r: InternalRow): Boolean

/**
* Initializes internal states given the current partition index.
* This is used by nondeterministic expressions to set initial states.
* The default implementation does nothing.
*/
def initialize(partitionIndex: Int): Unit = {}
}

/**
* Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]].
*/
object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] {

protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)

protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
BindReferences.bindReference(in, inputSchema)

protected def create(predicate: Expression): Predicate = {
protected def create(predicate: Expression): BasePredicate = {
val ctx = newCodeGenContext()
val eval = predicate.genCode(ctx)

@@ -53,7 +39,7 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
return new SpecificPredicate(references);
}

class SpecificPredicate extends ${classOf[Predicate].getName} {
class SpecificPredicate extends ${classOf[BasePredicate].getName} {
private final Object[] references;
${ctx.declareMutableStates()}

@@ -79,6 +65,6 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")

val (clazz, _) = CodeGenerator.compile(code)
clazz.generate(ctx.references.toArray).asInstanceOf[Predicate]
clazz.generate(ctx.references.toArray).asInstanceOf[BasePredicate]
}
}
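
A rough usage sketch of the codegen path above (hypothetical expression and schema): generate binds the expression, Janino-compiles the SpecificPredicate subclass of BasePredicate, and returns an instance whose eval runs the generated code.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.types.IntegerType

// Hypothetical one-column schema and predicate: id = 1
val id = AttributeReference("id", IntegerType)()
val compiled = GeneratePredicate.generate(EqualTo(id, Literal(1)), Seq(id)) // a BasePredicate
compiled.initialize(0)        // no-op here; needed when the expression is nondeterministic
compiled.eval(InternalRow(1)) // true
compiled.eval(InternalRow(2)) // false
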
@@ -21,20 +21,28 @@ import scala.collection.immutable.TreeSet

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


object InterpretedPredicate {
def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate =
create(BindReferences.bindReference(expression, inputSchema))
/**
* A base class for generated/interpreted predicate
*/
abstract class BasePredicate {
Member: It seems reasonable because we renamed Predicate => BasePredicate before.

def eval(r: InternalRow): Boolean

def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression)
/**
* Initializes internal states given the current partition index.
* This is used by nondeterministic expressions to set initial states.
* The default implementation does nothing.
*/
def initialize(partitionIndex: Int): Unit = {}
}

case class InterpretedPredicate(expression: Expression) extends BasePredicate {
@@ -56,6 +64,35 @@ trait Predicate extends Expression {
override def dataType: DataType = BooleanType
}

/**
* The factory object for `BasePredicate`.
*/
object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePredicate] {

override protected def createCodeGeneratedObject(in: Expression): BasePredicate = {
GeneratePredicate.generate(in)
}

override protected def createInterpretedObject(in: Expression): BasePredicate = {
InterpretedPredicate(in)
}

def createInterpreted(e: Expression): InterpretedPredicate = InterpretedPredicate(e)

/**
* Returns a BasePredicate for an Expression, which will be bound to `inputSchema`.
*/
def create(e: Expression, inputSchema: Seq[Attribute]): BasePredicate = {
createObject(bindReference(e, inputSchema))
}

Contributor: We can add a method def createInterpreted... for places that want to use interpreted predicates.

Member Author: Yea. ok.

/**
* Returns a BasePredicate for a given bound Expression.
*/
def create(e: Expression): BasePredicate = {
createObject(e)
}
}
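
Put together, call sites get one factory that chooses codegen or interpretation through CodeGeneratorWithInterpretedFallback, while createInterpreted stays available for callers that explicitly want the interpreted path. A minimal usage sketch (hypothetical expression and schema, not part of this diff):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal, Predicate}
import org.apache.spark.sql.types.IntegerType

val output = Seq(AttributeReference("value", IntegerType)())

// Binds the expression to `output`, then uses codegen with interpreted fallback.
val p = Predicate.create(GreaterThan(output.head, Literal(1)), output)
p.initialize(0)        // partition index, for nondeterministic expressions
p.eval(InternalRow(2)) // true
p.eval(InternalRow(0)) // false
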

trait PredicateHelper {
protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
@@ -1507,7 +1507,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {

case Filter(condition, LocalRelation(output, data, isStreaming))
if !hasUnevaluableExpr(condition) =>
val predicate = InterpretedPredicate.create(condition, output)
val predicate = Predicate.create(condition, output)
Contributor: This is to optimize local relation so perf doesn't matter too much. The change should be fine.

Member Author: ok

predicate.initialize(0)
LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming)
}
@@ -510,7 +510,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("Interpreted Predicate should initialize nondeterministic expressions") {
val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0)))
val interpreted = Predicate.create(LessThan(Rand(7), Literal(1.0)))
interpreted.initialize(0)
assert(interpreted.eval(new UnsafeRow()))
}
@@ -230,7 +230,7 @@ case class FileSourceScanExec(
// call the file index for the files matching all filters except dynamic partition filters
val predicate = dynamicPartitionFilters.reduce(And)
val partitionColumns = relation.partitionSchema
val boundPredicate = newPredicate(predicate.transform {
val boundPredicate = Predicate.create(predicate.transform {
case a: AttributeReference =>
val index = partitionColumns.indexWhere(a.name == _.name)
BoundReference(index, partitionColumns(index).dataType, nullable = true)
@@ -21,7 +21,6 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext

import org.codehaus.commons.compiler.CompileException
import org.codehaus.janino.InternalCompilerException
@@ -33,7 +32,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
@@ -471,28 +470,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
MutableProjection.create(expressions, inputSchema)
}

private def genInterpretedPredicate(
expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = {
val str = expression.toString
val logMessage = if (str.length > 256) {
str.substring(0, 256 - 3) + "..."
} else {
str
}
logWarning(s"Codegen disabled for this expression:\n $logMessage")
InterpretedPredicate.create(expression, inputSchema)
dongjoon-hyun (Member, Nov 20, 2019): I didn't follow the context in the previous PR. Is this genInterpretedPredicate function unused?

Member Author: Yea, CodeGeneratorWithInterpretedFallback#createInterpretedObject does the same thing with genInterpretedPredicate :

}

protected def newPredicate(
expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = {
try {
GeneratePredicate.generate(expression, inputSchema)
} catch {
case _ @ (_: InternalCompilerException | _: CompileException) if codeGenFallBack =>
genInterpretedPredicate(expression, inputSchema)
}
}
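
The removed helpers are covered by the new factory: CodeGeneratorWithInterpretedFallback tries the code-generated object first and falls back to the interpreted one, which is what newPredicate's try/catch plus genInterpretedPredicate did by hand here. A simplified sketch of that pattern, not Spark's actual implementation (the real class also consults a codegen factory mode setting and catches only compiler exceptions):

// Simplified illustration of the fallback pattern used by the Predicate factory.
abstract class CodegenWithFallbackFactory[IN, OUT] {
  protected def createCodeGeneratedObject(in: IN): OUT
  protected def createInterpretedObject(in: IN): OUT

  def createObject(in: IN): OUT =
    try {
      createCodeGeneratedObject(in)   // e.g. GeneratePredicate.generate(in)
    } catch {
      case _: Exception =>            // real code narrows this to Janino compile failures
        createInterpretedObject(in)   // e.g. InterpretedPredicate(in)
    }
}
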

protected def newOrdering(
order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = {
GenerateOrdering.generate(order, inputSchema)
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{LongType, StructType}
@@ -227,7 +226,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
val predicate = newPredicate(condition, child.output)
val predicate = Predicate.create(condition, child.output)
predicate.initialize(0)
iter.filter { row =>
val r = predicate.eval(row)
@@ -310,7 +310,7 @@ case class InMemoryTableScanExec(
val buffers = relation.cacheBuilder.cachedColumnBuffers

buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
val partitionFilter = newPredicate(
val partitionFilter = Predicate.create(
partitionFilters.reduceOption(And).getOrElse(Literal(true)),
schema)
partitionFilter.initialize(index)
@@ -171,7 +171,7 @@ abstract class PartitioningAwareFileIndex(
if (partitionPruningPredicates.nonEmpty) {
val predicate = partitionPruningPredicates.reduce(expressions.And)

val boundPredicate = InterpretedPredicate.create(predicate.transform {
val boundPredicate = Predicate.createInterpreted(predicate.transform {
case a: AttributeReference =>
val index = partitionColumns.indexWhere(a.name == _.name)
BoundReference(index, partitionColumns(index).dataType, nullable = true)
@@ -19,14 +19,12 @@ package org.apache.spark.sql.execution.joins

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.collection.{BitSet, CompactBuffer}

case class BroadcastNestedLoopJoinExec(
@@ -84,7 +82,7 @@ case class BroadcastNestedLoopJoinExec(

@transient private lazy val boundCondition = {
if (condition.isDefined) {
newPredicate(condition.get, streamed.output ++ broadcast.output).eval _
Predicate.create(condition.get, streamed.output ++ broadcast.output).eval _
} else {
(r: InternalRow) => true
}
@@ -20,9 +20,8 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark._
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Predicate, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution.{BinaryExecNode, ExplainUtils, ExternalAppendOnlyUnsafeRowArray, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.CompletionIterator
@@ -93,7 +92,7 @@ case class CartesianProductExec(
pair.mapPartitionsWithIndexInternal { (index, iter) =>
val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
val filtered = if (condition.isDefined) {
val boundCondition = newPredicate(condition.get, left.output ++ right.output)
val boundCondition = Predicate.create(condition.get, left.output ++ right.output)
boundCondition.initialize(index)
val joined = new JoinedRow

@@ -99,7 +99,7 @@ trait HashJoin {
UnsafeProjection.create(streamedKeys)

@transient private[this] lazy val boundCondition = if (condition.isDefined) {
newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _
Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
} else {
(r: InternalRow) => true
}
@@ -168,7 +168,7 @@ case class SortMergeJoinExec(
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
val boundCondition: (InternalRow) => Boolean = {
condition.map { cond =>
newPredicate(cond, left.output ++ right.output).eval _
Predicate.create(cond, left.output ++ right.output).eval _
}.getOrElse {
(r: InternalRow) => true
}
@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
import org.apache.spark.sql.catalyst.plans.physical._
@@ -233,8 +233,9 @@ case class StreamingSymmetricHashJoinExec(
val joinedRow = new JoinedRow


val inputSchema = left.output ++ right.output
val postJoinFilter =
newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _
Predicate.create(condition.bothSides.getOrElse(Literal(true)), inputSchema).eval _
val leftSideJoiner = new OneSideHashJoiner(
LeftSide, left.output, leftKeys, leftInputIter,
condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left)
@@ -417,7 +418,7 @@ case class StreamingSymmetricHashJoinExec(

// Filter the joined rows based on the given condition.
val preJoinFilter =
newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _
Predicate.create(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _

private val joinStateManager = new SymmetricHashJoinStateManager(
joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value,
@@ -428,16 +429,16 @@
case Some(JoinStateKeyWatermarkPredicate(expr)) =>
// inputSchema can be empty as expr should only have BoundReferences and does not require
// the schema to generated predicate. See [[StreamingSymmetricHashJoinHelper]].
newPredicate(expr, Seq.empty).eval _
Predicate.create(expr, Seq.empty).eval _
case _ =>
newPredicate(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
}

private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match {
case Some(JoinStateValueWatermarkPredicate(expr)) =>
newPredicate(expr, inputAttributes).eval _
Predicate.create(expr, inputAttributes).eval _
case _ =>
newPredicate(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
}

private[this] var updatedStateRowsCount = 0
@@ -457,7 +458,7 @@ case class StreamingSymmetricHashJoinExec(
val nonLateRows =
WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match {
case Some(watermarkExpr) =>
val predicate = newPredicate(watermarkExpr, inputAttributes)
val predicate = Predicate.create(watermarkExpr, inputAttributes)
inputIter.filter { row => !predicate.eval(row) }
case None =>
inputIter