Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.language.existentials

import com.google.common.cache.{CacheBuilder, CacheLoader}
Expand Down Expand Up @@ -265,6 +266,45 @@ class CodeGenContext {
def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)

def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))

/**
 * Breaks the generated code of a list of expressions into several private functions, since
 * the JVM limits the code size of a single method to 64kb.
 *
 * When everything fits into one chunk, the concatenated code is returned as is so that it is
 * executed inline. Otherwise every chunk is registered through `addNewFunction` and the
 * returned code consists of the calls to those functions, one per line.
 *
 * @param row the variable name of row that is used by expressions
 * @param expressions the generated code snippets, in evaluation order
 * @return the code to put at the place where the expressions should be evaluated
 */
def splitExpressions(row: String, expressions: Seq[String]): String = {
  val chunks = new ArrayBuffer[String]()
  val current = new StringBuilder()
  expressions.foreach { snippet =>
    // The amount of byte code that will be generated is unknown at this point, so the length
    // of the source text collected so far is used as the limit instead.
    if (current.length > 64 * 1000) {
      chunks += current.toString()
      current.clear()
    }
    current.append(snippet)
  }
  // The remainder always forms the last chunk (an empty one when there are no expressions).
  chunks += current.toString()

  if (chunks.length != 1) {
    val prefix = freshName("apply")
    val calls = chunks.zipWithIndex.map { case (body, i) =>
      val name = s"${prefix}_$i"
      val code = s"""
        |private void $name(InternalRow $row) {
        | $body
        |}
        |""".stripMargin
      addNewFunction(name, code)
      s"$name($row);"
    }
    calls.mkString("\n")
  } else {
    // inline execution if only one chunk
    chunks.head
  }
}
}

/**
Expand All @@ -289,15 +329,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
/**
 * Renders one `private <javaType> <variableName>;` declaration for every entry of
 * `ctx.mutableStates` (the third component of each entry is ignored here) and joins the
 * declarations into a single string.
 *
 * NOTE(review): the two consecutive `mkString` lines look like the removed and the added side
 * of a diff hunk; the later one joins the declarations with a newline -- confirm against the
 * actual file.
 */
protected def declareMutableStates(ctx: CodeGenContext): String = {
ctx.mutableStates.map { case (javaType, variableName, _) =>
s"private $javaType $variableName;"
}.mkString
}.mkString("\n")
}

/**
 * Joins the third component of every entry of `ctx.mutableStates` into a single string.
 * Presumably that component is the initialization code passed to `addMutableState` -- verify
 * against CodeGenContext.
 *
 * NOTE(review): the two statements look like the removed and the added side of a diff hunk;
 * only the last one determines the result, and it separates the pieces with a newline.
 */
protected def initMutableStates(ctx: CodeGenContext): String = {
ctx.mutableStates.map(_._3).mkString
ctx.mutableStates.map(_._3).mkString("\n")
}

/**
 * Joins the code of every function stored in `ctx.addedFuntions` into a single string; the
 * function names are not used.
 *
 * NOTE(review): the two statements look like the removed and the added side of a diff hunk;
 * only the last one determines the result, and it separates the functions with a newline.
 */
protected def declareAddedFunctions(ctx: CodeGenContext): String = {
ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString
ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
}

/**
Expand Down Expand Up @@ -328,6 +368,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
private[this] def doCompile(code: String): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(getClass.getClassLoader)
// Cannot be under package codegen, or fail with java.lang.InstantiationException
evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass")
evaluator.setDefaultImports(Array(
classOf[PlatformDependent].getName,
classOf[InternalRow].getName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu

protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
val ctx = newCodeGenContext()
val projectionCode = expressions.zipWithIndex.map {
val projectionCodes = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
Expand All @@ -65,49 +65,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
"""
}
}
// collect projections into blocks as function has 64kb codesize limit in JVM
val projectionBlocks = new ArrayBuffer[String]()
val blockBuilder = new StringBuilder()
for (projection <- projectionCode) {
if (blockBuilder.length > 16 * 1000) {
projectionBlocks.append(blockBuilder.toString())
blockBuilder.clear()
}
blockBuilder.append(projection)
}
projectionBlocks.append(blockBuilder.toString())

val (projectionFuns, projectionCalls) = {
// inline execution if codesize limit was not broken
if (projectionBlocks.length == 1) {
("", projectionBlocks.head)
} else {
(
projectionBlocks.zipWithIndex.map { case (body, i) =>
s"""
|private void apply$i(InternalRow i) {
| $body
|}
""".stripMargin
}.mkString,
projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
)
}
}
val allProjections = ctx.splitExpressions("i", projectionCodes)

val code = s"""
public Object generate($exprType[] expr) {
return new SpecificProjection(expr);
return new SpecificMutableProjection(expr);
}

class SpecificProjection extends ${classOf[BaseMutableProjection].getName} {
class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} {

private $exprType[] expressions;
private $mutableRowType mutableRow;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}

public SpecificProjection($exprType[] expr) {
public SpecificMutableProjection($exprType[] expr) {
expressions = expr;
mutableRow = new $genericMutableRowType(${expressions.size});
${initMutableStates(ctx)}
Expand All @@ -123,12 +95,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
return (InternalRow) mutableRow;
}

$projectionFuns

public Object apply(Object _i) {
InternalRow i = (InternalRow) _i;
$projectionCalls

$allProjections
return mutableRow;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.expressions.codegen

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.types._
Expand All @@ -43,6 +41,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val tmp = ctx.freshName("tmp")
val output = ctx.freshName("safeRow")
val values = ctx.freshName("values")
// These expressions could be split into multiple functions, so the values array is
// registered as mutable state instead of being declared as a local variable
ctx.addMutableState("Object[]", values, s"this.$values = null;")

val rowClass = classOf[GenericInternalRow].getName

val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
Expand All @@ -53,12 +54,12 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
$values[$i] = ${converter.primitive};
}
"""
}.mkString("\n")

}
val allFields = ctx.splitExpressions(tmp, fieldWriters)
val code = s"""
final InternalRow $tmp = $input;
final Object[] $values = new Object[${schema.length}];
$fieldWriters
this.$values = new Object[${schema.length}];
$allFields
final InternalRow $output = new $rowClass($values);
"""

Expand Down Expand Up @@ -128,7 +129,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]

protected def create(expressions: Seq[Expression]): Projection = {
val ctx = newCodeGenContext()
val projectionCode = expressions.zipWithIndex.map {
val expressionCodes = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
Expand All @@ -143,36 +144,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
}
"""
}
// collect projections into blocks as function has 64kb codesize limit in JVM
val projectionBlocks = new ArrayBuffer[String]()
val blockBuilder = new StringBuilder()
for (projection <- projectionCode) {
if (blockBuilder.length > 16 * 1000) {
projectionBlocks.append(blockBuilder.toString())
blockBuilder.clear()
}
blockBuilder.append(projection)
}
projectionBlocks.append(blockBuilder.toString())

val (projectionFuns, projectionCalls) = {
// inline it if we have only one block
if (projectionBlocks.length == 1) {
("", projectionBlocks.head)
} else {
(
projectionBlocks.zipWithIndex.map { case (body, i) =>
s"""
|private void apply$i(InternalRow i) {
| $body
|}
""".stripMargin
}.mkString,
projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
)
}
}

val allExpressions = ctx.splitExpressions("i", expressionCodes)
val code = s"""
public Object generate($exprType[] expr) {
return new SpecificSafeProjection(expr);
Expand All @@ -183,19 +155,17 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
private $exprType[] expressions;
private $mutableRowType mutableRow;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}

public SpecificSafeProjection($exprType[] expr) {
expressions = expr;
mutableRow = new $genericMutableRowType(${expressions.size});
${initMutableStates(ctx)}
}

$projectionFuns

public Object apply(Object _i) {
InternalRow i = (InternalRow) _i;
$projectionCalls

$allExpressions
return mutableRow;
}
}
Expand Down
Loading