Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,12 @@ abstract class Expression extends TreeNode[Expression] {
val value = ctx.freshName("value")
val ve = ExprCode("", isNull, value)
ve.code = genCode(ctx, ve)
// Add `this` in the comment.
ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
if (ve.code != "") {
// Add `this` in the comment.
ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
} else {
ve
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,24 +156,33 @@ class CodegenContext {
/** The variable name of the input row in generated code. */
final var INPUT_ROW = "i"

private val curId = new java.util.concurrent.atomic.AtomicInteger()
/**
* The map from a variable name to it's next ID.
*/
private val freshNameIds = new mutable.HashMap[String, Int]
freshNameIds += INPUT_ROW -> 1

/**
* A prefix used to generate fresh name.
*/
var freshNamePrefix = ""

/**
* Returns a term name that is unique within this instance of a `CodeGenerator`.
*
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
* Returns a term name that is unique within this instance of a `CodegenContext`.
*/
def freshName(name: String): String = {
if (freshNamePrefix == "") {
s"$name${curId.getAndIncrement}"
def freshName(name: String): String = synchronized {
val fullName = if (freshNamePrefix == "") {
name
} else {
s"${freshNamePrefix}_$name"
}
if (freshNameIds.contains(fullName)) {
val id = freshNameIds(fullName)
freshNameIds(fullName) = id + 1
s"$fullName$id"
} else {
s"${freshNamePrefix}_$name${curId.getAndIncrement}"
freshNameIds += fullName -> 1
fullName
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,22 +173,26 @@ case class GetArrayStructFields(
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, eval => {
val n = ctx.freshName("n")
val values = ctx.freshName("values")
val j = ctx.freshName("j")
val row = ctx.freshName("row")
s"""
final int n = $eval.numElements();
final Object[] values = new Object[n];
for (int j = 0; j < n; j++) {
if ($eval.isNullAt(j)) {
values[j] = null;
final int $n = $eval.numElements();
final Object[] $values = new Object[$n];
for (int $j = 0; $j < $n; $j++) {
if ($eval.isNullAt($j)) {
$values[$j] = null;
} else {
final InternalRow row = $eval.getStruct(j, $numFields);
if (row.isNullAt($ordinal)) {
values[j] = null;
final InternalRow $row = $eval.getStruct($j, $numFields);
if ($row.isNullAt($ordinal)) {
$values[$j] = null;
} else {
values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)};
$values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)};
}
}
}
${ev.value} = new $arrayClass(values);
${ev.value} = new $arrayClass($values);
"""
})
}
Expand Down Expand Up @@ -227,12 +231,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression)

override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("index")
s"""
final int index = (int) $eval2;
if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) {
final int $index = (int) $eval2;
if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) {
${ev.isNull} = true;
} else {
${ev.value} = ${ctx.getValue(eval1, dataType, "index")};
${ev.value} = ${ctx.getValue(eval1, dataType, index)};
}
"""
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
s"""
| while (input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n")}
| ${consume(ctx, columns)}
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| }
""".stripMargin
}
Expand Down Expand Up @@ -236,15 +236,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])

private Object[] references;
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public GeneratedIterator(Object[] references) {
this.references = references;
${ctx.initMutableStates()}
this.references = references;
${ctx.initMutableStates()}
}

${ctx.declareAddedFunctions()}

protected void processNext() throws java.io.IOException {
$code
${code.trim}
}
}
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ case class TungstenAggregate(
| $doAgg();
|
| // output the result
| $genResult
| ${genResult.trim}
|
| ${consume(ctx, resultVars)}
| ${consume(ctx, resultVars).trim}
| }
""".stripMargin
}
Expand Down Expand Up @@ -242,9 +242,9 @@ case class TungstenAggregate(
}
s"""
| // do aggregate
| ${aggVals.map(_.code).mkString("\n")}
| ${aggVals.map(_.code).mkString("\n").trim}
| // update aggregation buffer
| ${updates.mkString("")}
| ${updates.mkString("\n").trim}
""".stripMargin
}

Expand Down Expand Up @@ -523,7 +523,7 @@ case class TungstenAggregate(
// Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
s"""
// generate grouping key
${keyCode.code}
${keyCode.code.trim}
UnsafeRow $buffer = null;
if ($checkFallback) {
// try to get the buffer from hash map
Expand All @@ -547,9 +547,9 @@ case class TungstenAggregate(
$incCounter

// evaluate aggregate function
${evals.map(_.code).mkString("\n")}
${evals.map(_.code).mkString("\n").trim}
// update aggregate buffer
${updates.mkString("\n")}
${updates.mkString("\n").trim}
"""
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,14 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
BindReferences.bindReference(condition, child.output))
ctx.currentVars = input
val eval = expr.gen(ctx)
val nullCheck = if (expr.nullable) {
s"!${eval.isNull} &&"
} else {
s""
}
s"""
| ${eval.code}
| if (!${eval.isNull} && ${eval.value}) {
| if ($nullCheck ${eval.value}) {
| ${consume(ctx, ctx.currentVars)}
| }
""".stripMargin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
// These benchmark are skipped in normal build
ignore("benchmark") {
// testWholeStage(200 << 20)
// testStddev(20 << 20)
// testStatFunctions(20 << 20)
// testAggregateWithKey(20 << 20)
// testBytesToBytesMap(1024 * 1024 * 50)
}
Expand Down