Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,13 @@ package object dsl {
/** Builds an [[UnresolvedExtractValue]] that extracts `fieldName` from this expression. */
def getField(fieldName: String): UnresolvedExtractValue = {
  // The field name is supplied to the extractor as a literal.
  val extractionKey = Literal(fieldName)
  UnresolvedExtractValue(expr, extractionKey)
}

/** Wraps this expression in a `Cast` to `to`, unconditionally. */
def cast(to: DataType): Expression = Cast(expr, to)
/**
 * Casts this expression to `to`.
 *
 * When the expression is resolved and its data type is `sameType` as `to`, the expression
 * itself is returned; in every other case it is wrapped in a `Cast`.
 */
def cast(to: DataType): Expression = {
  // `dataType` is only consulted once the expression reports itself as resolved.
  val alreadyTargetType = expr.resolved && expr.dataType.sameType(to)
  if (alreadyTargetType) expr else Cast(expr, to)
}

/** Ascending [[SortOrder]] on this expression. */
def asc: SortOrder = SortOrder(expr, Ascending)
/** Ascending [[SortOrder]] on this expression, built with `NullsLast` and an empty set. */
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1612,6 +1612,48 @@ object CodeGenerator extends Logging {
}
}

/**
 * Extracts all the input variables from references and subexpression elimination states
 * for a given `expr`. This result will be used to split the generated code of
 * expressions into multiple functions.
 *
 * @param ctx the codegen context whose `INPUT_ROW` and `currentVars` are consulted
 * @param expr the root of the expression tree to scan
 * @param subExprs subexpression elimination states keyed by expression; a node found here
 *                 contributes its state's variables and its children are not visited
 * @return the collected [[VariableValue]]s as an immutable set
 */
def getLocalInputVariableValues(
ctx: CodegenContext,
expr: Expression,
subExprs: Map[Expression, SubExprEliminationState]): Set[VariableValue] = {
val argSet = mutable.Set[VariableValue]()
// When the context has an input row, it is always included as an `InternalRow` variable.
if (ctx.INPUT_ROW != null) {
argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
}

// Collects local variables from a given `expr` tree
// Only `VariableValue`s are kept; any other kind of `ExprValue` is ignored.
val collectLocalVariable = (ev: ExprValue) => ev match {
case vv: VariableValue => argSet += vv
case _ =>
}

// Explicit stack of nodes still to visit, starting from the root expression.
val stack = mutable.Stack[Expression](expr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: so the stack here is only used for implementing depth-first pre-order traversal of an expression tree in an iterative style instead of recursive style, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, right. I just followed the other similar logics.

while (stack.nonEmpty) {
stack.pop() match {
// A node with an elimination state: take the state's variables and do not descend.
case e if subExprs.contains(e) =>
val SubExprEliminationState(isNull, value) = subExprs(e)
collectLocalVariable(value)
collectLocalVariable(isNull)

// A bound reference whose slot in `currentVars` is set: take that slot's variables.
case ref: BoundReference if ctx.currentVars != null &&
ctx.currentVars(ref.ordinal) != null =>
val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal)
collectLocalVariable(value)
collectLocalVariable(isNull)

// Any other node: continue the walk through its children.
case e =>
stack.pushAll(e.children)
}
}

// Hand back an immutable copy of the accumulated set.
argSet.toSet
}

/**
* Returns the name used in accessor and setter for a Java primitive type.
*/
Expand Down Expand Up @@ -1719,6 +1761,15 @@ object CodeGenerator extends Logging {
1 + params.map(paramLengthForExpr).sum
}

/**
 * Computes the total parameter length of a method that takes `params`, including one
 * unit for `this`.
 *
 * A value whose Java type is primitive `long` or `double` counts as two units; every
 * other value counts as one.
 */
def calculateParamLengthFromExprValues(params: Seq[ExprValue]): Int = {
  // Starting the fold at 1 accounts for `this`.
  params.foldLeft(1) { (total, param) =>
    val javaType = param.javaType
    val isWide = javaType == java.lang.Long.TYPE || javaType == java.lang.Double.TYPE
    total + (if (isWide) 2 else 1)
  }
}

/**
* In Java, a method descriptor is valid only if it represents method parameters with a total
* length less than a pre-defined constant.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ trait Block extends TreeNode[Block] with JavaCode {
case _ => code.trim
}

/** Length of this block's string form. */
def length: Int = toString.length
/**
 * Length of this block's string form after it has been passed through
 * `CodeFormatter.stripExtraNewLinesAndComments`, so comments do not count towards it.
 */
def length: Int = {
  val strippedCode = CodeFormatter.stripExtraNewLinesAndComments(toString)
  strippedCode.length
}

/** True when this block's string form is empty. */
def isEmpty: Boolean = toString.isEmpty

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,12 +354,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val value = eval.isNull match {
case TrueLiteral => FalseLiteral
case FalseLiteral => TrueLiteral
case v => JavaCode.isNullExpression(s"!$v")
val (value, newCode) = eval.isNull match {
case TrueLiteral => (FalseLiteral, EmptyBlock)
case FalseLiteral => (TrueLiteral, EmptyBlock)
case v =>
val value = ctx.freshName("value")
(JavaCode.variable(value, BooleanType), code"boolean $value = !$v;")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

@maropu maropu Aug 31, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A simple query to reproduce this issue (without both changes) is as follows:

sql("""
  SELECT
    sum(CASE WHEN c0 IS NULL AND c1 IS NOT NULL THEN 1 ELSE 0 END) a,
    sum(CASE WHEN c0 IS NOT NULL AND c1 IS NOT NULL THEN 1 ELSE 0 END) b 
  FROM
    VALUES ((null, null)) t(c0, c1)
""").show

}
ExprCode(code = eval.code, isNull = FalseLiteral, value = value)
ExprCode(code = eval.code + newCode, isNull = FalseLiteral, value = value)
}

override def sql: String = s"(${child.sql} IS NOT NULL)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

// Internal boolean config entry, `true` by default, controlling whether aggregate code is
// split into individual methods.
val CODEGEN_SPLIT_AGGREGATE_FUNC =
buildConf("spark.sql.codegen.aggregate.splitAggregateFunc.enabled")
.internal()
.doc("When true, the code generator would split aggregate code into individual methods " +
"instead of a single big method. This can be used to avoid oversized function that " +
"can miss the opportunity of JIT optimization.")
.booleanConf
.createWithDefault(true)

val MAX_NESTED_VIEW_DEPTH =
buildConf("spark.sql.view.maxNestedViewDepth")
.internal()
Expand Down Expand Up @@ -2310,6 +2319,8 @@ class SQLConf extends Serializable with Logging {
/** Current value of the `CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD` config entry. */
def cartesianProductExecBufferSpillThreshold: Int =
getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)

/** Current value of the `SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC` config entry. */
def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC)

/** Current value of the `SQLConf.MAX_NESTED_VIEW_DEPTH` config entry. */
def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)

/** Current value of the `STARSCHEMA_DETECTION` config entry. */
def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION)
Expand Down
Loading