@@ -131,8 +131,9 @@ class CodegenContext {
   def declareMutableStates(): String = {
     // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
     // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
-    mutableStates.distinct.map { case (javaType, variableName, _) =>
+    mutableStates.distinct.map { case (javaType, variableName, _) if variableName != "" =>
       s"private $javaType $variableName;"
+      case _ => ""
     }.mkString("\n")
   }

@@ -188,6 +189,14 @@ class CodegenContext {
   /** The variable name of the input row in generated code. */
   final var INPUT_ROW = "i"
 
+  var isRow = true
+  var enableColumnCodeGen = false
+  var iteratorInput = ""
+  var isRowWrite = true
+  var generateColumnWrite = false
+  var rowWriteIdx = ""
+  var columnarBatch = ""
+
   /**
    * The map from a variable name to its next ID.
    */
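Why the new empty-name guard above matters: later in this PR, `GenerateUnsafeProjection` registers an init-only state via `ctx.addMutableState("", "", s"$allocateCS();")`, an entry that must contribute initialization code but no field declaration. A minimal, self-contained sketch of that contract (illustrative names, not Spark's actual `CodegenContext`):

```scala
object MutableStateSketch {
  // (javaType, variableName, initCode) triples, as in CodegenContext.mutableStates.
  private var mutableStates = Seq.empty[(String, String, String)]

  def addMutableState(javaType: String, name: String, init: String): Unit =
    mutableStates :+= ((javaType, name, init))

  // Entries with an empty variable name carry only init code, so they must be
  // skipped when emitting field declarations.
  def declareMutableStates(): String =
    mutableStates.distinct.collect {
      case (javaType, name, _) if name.nonEmpty => s"private $javaType $name;"
    }.mkString("\n")

  def initMutableStates(): String = mutableStates.map(_._3).mkString("\n")

  def main(args: Array[String]): Unit = {
    addMutableState("UnsafeRow", "result", "result = new UnsafeRow(2);")
    addMutableState("", "", "allocateColumnarStorage();") // init-only entry
    println(declareMutableStates()) // only: private UnsafeRow result;
    println(initMutableStates())    // both init statements
  }
}
```

Using `collect` as in the sketch would also avoid the blank lines that the diff's `case _ => ""` arm leaves in the generated declaration block.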
@@ -72,12 +72,64 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
       inputTypes: Seq[DataType],
       bufferHolder: String,
       isTopLevel: Boolean = false): String = {
+    var colOutVars: Seq[String] = Seq.empty
     val rowWriterClass = classOf[UnsafeRowWriter].getName
     val rowWriter = ctx.freshName("rowWriter")
-    ctx.addMutableState(rowWriterClass, rowWriter,
-      s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
+    if (!ctx.generateColumnWrite) {
+      ctx.addMutableState(rowWriterClass, rowWriter,
+        s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
+    } else if (isTopLevel) {
+      val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
+      val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
+
+      ctx.columnarBatch = ctx.freshName("columnarBatch")
+      ctx.addMutableState(s"$columnarBatchClz", ctx.columnarBatch, "")
+
+      val metadataEmpty = "org.apache.spark.sql.types.Metadata.empty()"
+      val columnarBatchAllocate = inputs.zip(inputTypes).zipWithIndex.map {
+        case ((input, dataType), index) =>
+          val dt = dataType match {
+            case udt: UserDefinedType[_] => udt.sqlType
+            case other => other
+          }
+          val dtClsName = dt match {
+            case FloatType => "org.apache.spark.sql.types.DataTypes.FloatType"
+            case DoubleType => "org.apache.spark.sql.types.DataTypes.DoubleType"
+            case _ => throw new UnsupportedOperationException()
+          }
+          s"""
+            new org.apache.spark.sql.types.StructField(
+              "col$index", $dtClsName, ${(input.isNull != "false")}, $metadataEmpty)
+          """.stripMargin + (if (inputs.length - 1 != index) "," else "")
+      }
 
-    val resetWriter = if (isTopLevel) {
+      colOutVars = inputs.indices.map(i => ctx.freshName("colOutInstance" + i))
+      val columnOutAssigns = colOutVars.zipWithIndex.map { case (name, i) =>
+        ctx.addMutableState(columnVectorClz, name, "")
+        s"$name = ${ctx.columnarBatch}.column($i);"
+      }
+
+      val batchSchema = ctx.freshName("batchSchema")
+      val allocateCS = ctx.freshName("allocateColumnarStorage")
+      ctx.addNewFunction(allocateCS,
+        s"""
+           |void $allocateCS() {
+           |org.apache.spark.sql.types.StructType $batchSchema =
+           |  new org.apache.spark.sql.types.StructType(
+           |    new org.apache.spark.sql.types.StructField[] {
+           |    ${columnarBatchAllocate.mkString("\n")}
+           |});
+           |
+           |${ctx.columnarBatch} = ${columnarBatchClz}.allocate(
+           |  $batchSchema, org.apache.spark.memory.MemoryMode.ON_HEAP);
+           |registerColumnarBatch(${ctx.columnarBatch});
+           |${columnOutAssigns.mkString("", "\n", "\n")}
+           |}
+         """.stripMargin)
+      ctx.addMutableState("", "", s"$allocateCS();")
+    }
+
+    val resetWriter = if (ctx.generateColumnWrite) "" else if (isTopLevel) {
       // For top level row writer, it always writes to the beginning of the global buffer holder,
       // which means its fixed-size region always in the same position, so we don't need to call
       // `reset` to set up its fixed-size region every time.
@@ -100,14 +152,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
         }
         val tmpCursor = ctx.freshName("tmpCursor")
 
-        val setNull = dt match {
+        val setNull = if (ctx.generateColumnWrite) {
+          s"${colOutVars(index)}.putNull(${ctx.rowWriteIdx});"
+        } else dt match {
           case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
             // Can't call setNullAt() for DecimalType with precision larger than 18.
             s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});"
           case _ => s"$rowWriter.setNullAt($index);"
         }
 
-        val writeField = dt match {
+        val writeField = if (ctx.generateColumnWrite) {
+          s"""
+            System.out.println("rowIdx["+${ctx.rowWriteIdx}+"]: v="+${input.value});
+            ${colOutVars(index)}.putFloat(${ctx.rowWriteIdx}, ${input.value});
+          """.stripMargin
+        } else dt match {
           case t: StructType =>
             s"""
               // Remember the current cursor so that we can calculate how many bytes are
@@ -299,19 +358,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
     val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
     val exprTypes = expressions.map(_.dataType)
 
-    val numVarLenFields = exprTypes.count {
+    val numVarLenFields = if (ctx.generateColumnWrite) 0 else exprTypes.count {
       case dt if UnsafeRow.isFixedLength(dt) => false
       // TODO: consider large decimal and interval type
       case _ => true
     }
 
     val result = ctx.freshName("result")
-    ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});")
+    if (!ctx.generateColumnWrite) {
+      ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});")
+    }
 
     val holder = ctx.freshName("holder")
-    val holderClass = classOf[BufferHolder].getName
-    ctx.addMutableState(holderClass, holder,
-      s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});")
+    if (!ctx.generateColumnWrite) {
+      val holderClass = classOf[BufferHolder].getName
+      ctx.addMutableState(holderClass, holder,
+        s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});")
+    }
 
     val resetBufferHolder = if (numVarLenFields == 0) {
       ""
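Taken together, these hunks make `createCode` skip the `UnsafeRow`/`BufferHolder`/`UnsafeRowWriter` machinery when `generateColumnWrite` is set and emit per-column `put` calls instead. A sketch of the shape of the emitted code, written against the Spark 2.x vectorized API already used in this diff (`ColumnarBatch.allocate`, `column(i)`, `putFloat`/`putNull`, `setNumRows`); the schema and values are illustrative, and only `FloatType` appears because the prototype's `writeField` hardcodes `putFloat`:

```scala
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.execution.vectorized.ColumnarBatch
import org.apache.spark.sql.types._

object ColumnarWriteSketch {
  def main(args: Array[String]): Unit = {
    // Mirrors the StructField array built by columnarBatchAllocate.
    val schema = new StructType(Array(
      StructField("col0", FloatType, nullable = true, Metadata.empty)))
    val batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP)
    val col0 = batch.column(0) // mirrors the colOutInstance* assignments

    var rowWriteIdx = 0
    for (v <- Seq(Some(1.0f), None, Some(2.5f))) { // stand-ins for consumed rows
      v match {
        case Some(f) => col0.putFloat(rowWriteIdx, f) // the writeField path
        case None    => col0.putNull(rowWriteIdx)     // the setNull path
      }
      rowWriteIdx += 1
      batch.setNumRows(rowWriteIdx) // the generated loop bumps this per row
    }
    println(s"rows written: ${batch.numRows()}")
  }
}
```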
@@ -25,13 +25,18 @@
 import org.apache.spark.TaskContext;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
+import org.apache.spark.sql.execution.vectorized.ColumnVector;
 
 /**
  * An iterator interface used to pull the output from generated function for multiple operators
  * (whole stage codegen).
  */
 public abstract class BufferedRowIterator {
   protected LinkedList<InternalRow> currentRows = new LinkedList<>();
+  protected ColumnarBatch columnarBatch;
+  protected java.util.Iterator<ColumnarBatch.Row> rowIterator;
+  protected boolean isColumnarBatchAccessed = false;
   // used when there is no column in output
   protected UnsafeRow unsafeRow = new UnsafeRow(0);
   private long startTimeNs = System.nanoTime();
@@ -42,11 +47,15 @@ public boolean hasNext() throws IOException {
     if (currentRows.isEmpty()) {
       processNext();
     }
-    return !currentRows.isEmpty();
+    if (!isColumnarBatchAccessed) { return !currentRows.isEmpty(); }
+    if (rowIterator == null) { rowIterator = columnarBatch.rowIterator(); }
+    return rowIterator.hasNext();
   }
 
   public InternalRow next() {
-    return currentRows.remove();
+    if (!isColumnarBatchAccessed) { return currentRows.remove(); }
+    if (rowIterator == null) { rowIterator = columnarBatch.rowIterator(); }
+    return rowIterator.next().copyUnsafeRow();
   }
 
   /**
@@ -75,9 +84,16 @@ protected void append(InternalRow row) {
    * If it returns true, the caller should exit the loop (return from processNext()).
    */
   protected boolean shouldStop() {
-    return !currentRows.isEmpty();
+    if (!isColumnarBatchAccessed) { return !currentRows.isEmpty(); }
+    return false; // TODO: currentIdx < allocated # of rows of CS
   }
 
+  protected boolean isColumnarBatch() { return this.columnarBatch != null; }
+  protected int numColumns() { return this.columnarBatch.numCols(); }
+  protected int numRows() { return this.columnarBatch.numRows(); }
+  protected ColumnVector column(int i) { return this.columnarBatch.column(i); }
+  protected void registerColumnarBatch(ColumnarBatch columnarBatch) { this.columnarBatch = columnarBatch; }
+
   /**
    * Increase the peak execution memory for current task.
    */
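The pattern here is a two-mode iterator: row mode drains `currentRows`, while columnar mode lazily wraps the registered batch's `rowIterator` so that `processNext()` can fill the batch first. A stand-in sketch of just that gating logic (illustrative types, not Spark's):

```scala
import scala.collection.mutable.Queue

// `columnar` plays the role of isColumnarBatchAccessed; `batchRows` of
// columnarBatch.rowIterator(), created lazily exactly once.
class TwoModeIterator[T](rowQueue: Queue[T], batchRows: () => Iterator[T], columnar: Boolean) {
  private var it: Iterator[T] = _

  def hasNext: Boolean = {
    if (!columnar) return rowQueue.nonEmpty
    if (it == null) it = batchRows() // mirrors the rowIterator == null check
    it.hasNext
  }

  def next(): T = {
    if (!columnar) return rowQueue.dequeue()
    if (it == null) it = batchRows()
    it.next()
  }
}
```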
@@ -164,6 +164,28 @@ public InternalRow copy() {
       return row;
     }
 
+    public UnsafeRow copyUnsafeRow() {
+      UnsafeRow row = new UnsafeRow(columns.length);
+      int fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields();
+      byte[] buffer = new byte[fixedSize + 64]; // 64 is margin
+      row.pointTo(buffer, buffer.length);
+      for (int i = 0; i < numFields(); i++) {
+        if (isNullAt(i)) {
+          row.setNullAt(i);
+        } else {
+          DataType dt = columns[i].dataType();
+          if (dt instanceof FloatType) {
+            row.setFloat(i, getFloat(i));
+          } else if (dt instanceof DoubleType) {
+            row.setDouble(i, getDouble(i));
+          } else {
+            throw new RuntimeException("Not implemented. " + dt);
+          }
+        }
+      }
+      return row;
+    }
+
     @Override
     public boolean anyNull() {
       throw new NotImplementedException();
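A worked example of the fixed-region arithmetic `copyUnsafeRow` relies on: an `UnsafeRow`'s fixed part is a null bitset rounded up to 8-byte words plus one 8-byte slot per field, which the diff then pads with a 64-byte margin. This is sufficient here because the copied columns are fixed-length floats and doubles, which live entirely in the fixed region:

```scala
// Same computation as UnsafeRow.calculateBitSetWidthInBytes(n) + 8 * n.
def fixedSize(numFields: Int): Int = {
  val bitSetWidthInBytes = ((numFields + 63) / 64) * 8 // one 8-byte word per 64 fields
  bitSetWidthInBytes + 8 * numFields
}

// fixedSize(2) == 8 + 16 == 24, so the buffer above would be 24 + 64 bytes.
```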
@@ -304,19 +304,28 @@ private[sql] case class BatchedDataSourceScanExec(
       |}""".stripMargin)
 
     ctx.currentVars = null
+    ctx.isRow = false // always false
+    ctx.isRowWrite = false // always false
     val rowidx = ctx.freshName("rowIdx")
+    val rowWriteIdx = ctx.freshName("rowWriteIdx")
+    ctx.rowWriteIdx = rowWriteIdx
     val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
       genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
     }
-    s"""
+    val isColumnarBatchAccessed = if (ctx.isRowWrite) "" else "isColumnarBatchAccessed = true;"
+    val source = s"""
       |if ($batch == null) {
+      |  $isColumnarBatchAccessed
       |  $nextBatch();
       |}
+      |int $rowWriteIdx = 0;
       |while ($batch != null) {
       |  int numRows = $batch.numRows();
       |  while ($idx < numRows) {
       |    int $rowidx = $idx++;
       |    ${consume(ctx, columnsBatchInput).trim}
+      |    $rowWriteIdx++;
+      |    ${ctx.columnarBatch}.setNumRows($rowWriteIdx);
       |    if (shouldStop()) return;
       |  }
       |  $batch = null;
@@ -325,6 +334,9 @@ private[sql] case class BatchedDataSourceScanExec(
       |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
       |$scanTimeTotalNs = 0;
     """.stripMargin
+    ctx.isRowWrite = true // always true
+    ctx.isRow = true // always true
+    source
   }
 }
 
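The loop that `doProduce` now emits keeps two counters: `$rowidx` is the read position in the scanned input batch, while `$rowWriteIdx` is the write position in the output `ColumnarBatch`, with `setNumRows` bumped after every row. The same shape in plain Scala over the diff's API (a single float column is assumed):

```scala
import org.apache.spark.sql.execution.vectorized.ColumnarBatch

def copyThrough(in: ColumnarBatch, out: ColumnarBatch): Unit = {
  var rowWriteIdx = 0
  var idx = 0
  while (idx < in.numRows) {
    val rowIdx = idx // read index into the input batch
    idx += 1
    // Stands in for ${consume(ctx, columnsBatchInput)}: read at rowIdx, write at rowWriteIdx.
    out.column(0).putFloat(rowWriteIdx, in.column(0).getFloat(rowIdx))
    rowWriteIdx += 1
    out.setNumRows(rowWriteIdx)
  }
}
```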
@@ -132,7 +132,9 @@ trait CodegenSupport extends SparkPlan {
     val evaluateInputs = evaluateVariables(outputVars)
     // generate the code to create a UnsafeRow
     ctx.currentVars = outputVars
+    ctx.generateColumnWrite = !ctx.isRowWrite && parent.isInstanceOf[WholeStageCodegenExec]
     val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+    ctx.generateColumnWrite = false
     val code = s"""
       |$evaluateInputs
       |${ev.code.trim}
@@ -352,13 +354,36 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
       val clazz = CodeGenerator.compile(cleanedSource)
       val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
       buffer.init(index, Array(iter))
-      new Iterator[InternalRow] {
-        override def hasNext: Boolean = {
-          val v = buffer.hasNext
-          if (!v) durationMs += buffer.durationMs()
-          v
+      if (!buffer.isColumnarBatch()) {
+        new Iterator[InternalRow] {
+          override def hasNext: Boolean = {
+            val v = buffer.hasNext
+            if (!v) durationMs += buffer.durationMs()
+            v
+          }
+          override def next: InternalRow = buffer.next()
+        }
+      } else {
+        new ColumnIterator[InternalRow] {
+          override def hasNext: Boolean = {
+            val v = buffer.hasNext
+            if (!v) durationMs += buffer.durationMs()
+            v
+          }
+          override def next: InternalRow = buffer.next()
+
+          override def computeColumn: Unit = {
+            buffer.processNext
+          }
+          override def numColumns: Integer = buffer.numColumns
+          override def numRows: Integer = buffer.numRows
+          override def column(i: Integer):
+            org.apache.spark.sql.execution.vectorized.ColumnVector = buffer.column(i)
+          override def hasNextRow: Boolean = {
+            if (numRows == 0) { buffer.processNext }
+            if (rowIdx < numRows) true else false
+          }
         }
-        override def next: InternalRow = buffer.next()
       }
     }
   } else {
@@ -394,9 +419,10 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
     } else {
       ""
     }
+    val append = if (ctx.isRowWrite) s"append(${row.value}$doCopy);" else ""
     s"""
       |${row.code}
-      |append(${row.value}$doCopy);
+      |$append
     """.stripMargin.trim
   }
 
@@ -416,6 +442,17 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
   override def simpleString: String = "WholeStageCodegen"
 }
 
+abstract class ColumnIterator[T] extends Iterator[T] {
+  def computeColumn: Unit = { }
+  def columnarBatch: T = { null.asInstanceOf[T] }
+  def numColumns: Integer = { 0 }
+  def numRows: Integer = { 0 }
+  def column(i: Integer): org.apache.spark.sql.execution.vectorized.ColumnVector = {
+    null
+  }
+  def hasNextRow: Boolean = { false }
+  var rowIdx: Integer = 0
+}
 
 /**
  * Find the chained plans that support codegen, collapse them together as WholeStageCodegen.
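A hedged sketch of the column-wise consumption `ColumnIterator` is meant to enable, using only members defined in this diff (`computeColumn`, `numRows`, `column`); the aggregation itself is illustrative:

```scala
import org.apache.spark.sql.catalyst.InternalRow

def sumFirstColumn(it: ColumnIterator[InternalRow]): Double = {
  it.computeColumn // runs the generated processNext() once to fill the batch
  val col = it.column(0)
  var acc = 0.0
  var i = 0
  while (i < it.numRows) { // numRows is a boxed Integer; Predef unboxes it
    acc += col.getFloat(i)
    i += 1
  }
  acc
}
```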