@@ -17,20 +17,21 @@

 package org.apache.spark.sql.execution.python
 
-import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream}
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, OutputStream}
 import java.net.Socket
 
 import scala.collection.JavaConverters._
 
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter}
-import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.arrow.vector.types.pojo._
 
 import org.apache.spark.{SparkEnv, SparkFiles, TaskContext}
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRDD, SpecialLengths}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator}
 import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ArrowUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -192,22 +193,6 @@ class VectorizedPythonRunner(
     this.interrupt()
   }
 
-  private def writeValue(
-      vector: FieldVector,
-      rowId: Int,
-      row: InternalRow,
-      columnIndex: Int): Unit = vector match {
-    // todo: since we know the batch size. can we pre-allocate memory for VectorSchemaRoot and
-    //       call set instead of setSafe here?
-    // todo: null handling
-    case v: NullableIntVector =>
-      v.getMutator.setSafe(rowId, row.getInt(columnIndex))
-    case v: NullableBigIntVector =>
-      v.getMutator.setSafe(rowId, row.getLong(columnIndex))
-    case v: NullableVarBinaryVector =>
-      v.getMutator.set(rowId, row.getUTF8String(columnIndex).getBytes)
-  }
-
   override def run(): Unit = Utils.logUncaughtExceptions {
     try {
       TaskContext.setTaskContext(context)
@@ -266,32 +251,9 @@
        }
        dataOut.flush()
 
-       val root = VectorSchemaRoot.create(schema, ArrowColumnVector.allocator)
-       // TODO: does ArrowStreamWriter buffer data?
-       // TODO: who decides the dictionary?
-       val writer = new ArrowStreamWriter(root, null, dataOut)
-       writer.start()
-
-       val fields = schema.getFields.toArray
-       while (inputRows.hasNext) {
-         var rowId = 0
-         root.getFieldVectors.asScala.foreach(_.allocateNew())
-         while (inputRows.hasNext && rowId < batchSize) {
-           val row = inputRows.next()
-           var columnIndex = 0
-           while (columnIndex < fields.length) {
-             val vector = root.getFieldVectors.get(columnIndex)
-             writeValue(vector, rowId, row, columnIndex)
-             columnIndex += 1
-           }
-           rowId += 1
-         }
-         root.getFieldVectors.asScala.foreach(_.getMutator.setValueCount(rowId))
-         root.setRowCount(rowId)
-         writer.writeBatch()
-       }
-       writer.end()
-       root.close()
+       val arrowWriter = GenerateArrowWriter.generate(schema)
+       arrowWriter.initialize(batchSize)
+       arrowWriter.writeAll(inputRows, dataOut)
 
        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
        dataOut.flush()
@@ -314,3 +276,131 @@
   }
 
 }
+
+abstract class ArrowWriter {
+  def initialize(batchSize: Int): Unit
+  def writeAll(inputRows: Iterator[InternalRow], out: OutputStream): Unit
+}
+
+object GenerateArrowWriter extends CodeGenerator[Schema, ArrowWriter] with Logging {
+
+  // An Arrow schema is already canonical and has nothing to bind against input attributes.
+  protected def canonicalize(in: Schema): Schema = in
+  protected def bind(in: Schema, inputSchema: Seq[Attribute]): Schema = in
+
+  // Maps an Arrow field type to the vector class the generated code downcasts to.
+  // Types the runner does not support yet fall through and throw a MatchError.
+  private def fieldTypeToVectorClass(arrowType: ArrowType, nullable: Boolean): String =
+    (arrowType, nullable) match {
+      case (ArrowType.Bool.INSTANCE, true) =>
+        classOf[NullableBitVector].getName
+      case (intType: ArrowType.Int, true) if intType.getBitWidth() == 8 * 4 =>
+        classOf[NullableIntVector].getName
+      case (intType: ArrowType.Int, true) if intType.getBitWidth() == 8 * 8 =>
+        classOf[NullableBigIntVector].getName
+      case (ArrowType.Utf8.INSTANCE, true) =>
+        classOf[NullableVarCharVector].getName
+    }
+
+  protected def create(schema: Schema): ArrowWriter = {
+    val ctx = newCodeGenContext()
+    val schemaName = ctx.addReferenceObj("schema", schema)
+
+    // One fresh variable name per field, paired with its Arrow type and nullability.
+    val (vectorNames, vectorTypes) = schema.getFields().asScala.map { field =>
+      val name = ctx.freshName(s"${field.getName}Vector")
+      val fieldType = field.getFieldType
+      (name, (fieldType.getType, fieldType.isNullable))
+    }.unzip
+
+    val initFieldVectors = vectorNames.zip(vectorTypes).zipWithIndex.map {
+      case ((name, (arrowType, nullable)), idx) =>
+        val cls = fieldTypeToVectorClass(arrowType, nullable)
+        s"$cls $name = ($cls) root.getFieldVectors().get($idx);"
+    }.mkString("\n")
+
+    val allocateMemories = vectorNames.map { name =>
+      s"$name.allocateNew();"
+    }.mkString("\n")
+
+    // Note: like the interpreted writeValue this replaces, null values are not handled yet.
+    val writeValues = vectorNames.zip(vectorTypes.map(_._1)).zipWithIndex.map {
+      case ((name, ArrowType.Bool.INSTANCE), idx) =>
+        s"$name.getMutator().setSafe(rowId, row.getBoolean($idx) ? 1 : 0);"
+      case ((name, intType: ArrowType.Int), idx) if intType.getBitWidth() == 8 * 4 =>
+        s"$name.getMutator().setSafe(rowId, row.getInt($idx));"
+      case ((name, intType: ArrowType.Int), idx) if intType.getBitWidth() == 8 * 8 =>
+        s"$name.getMutator().setSafe(rowId, row.getLong($idx));"
+      case ((name, ArrowType.Utf8.INSTANCE), idx) =>
+        s"$name.getMutator().set(rowId, row.getUTF8String($idx).getBytes());"
+    }.mkString("\n")
+
+    val setValueCounts = vectorNames.map { name =>
+      s"$name.getMutator().setValueCount(rowId);"
+    }.mkString("\n")
+
val codeBody = s"""
import java.io.IOException;
import java.io.OutputStream;
import scala.collection.Iterator;
import ${classOf[VectorSchemaRoot].getName};
import ${classOf[ArrowStreamWriter].getName};
import ${classOf[ArrowWriter].getName};
import ${classOf[ArrowColumnVector].getName};

public SpecificArrowWriter generate(Object[] references) {
return new SpecificArrowWriter(references);
}

class SpecificArrowWriter extends ArrowWriter {

private Object[] references;
${ctx.declareMutableStates()}

private int batchSize;

public SpecificArrowWriter(Object[] references) {
this.references = references;
${ctx.initMutableStates()}
}

${ctx.declareAddedFunctions()}

public void initialize(int batchSize) {
this.batchSize = batchSize;
}

public void writeAll(Iterator<InternalRow> inputRows, OutputStream out) throws IOException {
VectorSchemaRoot root = VectorSchemaRoot.create($schemaName, ArrowColumnVector.allocator);
$initFieldVectors

ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out);
writer.start();

try {
while (inputRows.hasNext()) {
int rowId = 0;
$allocateMemories
while (inputRows.hasNext() && rowId < batchSize) {
InternalRow row = (InternalRow) inputRows.next();
$writeValues
rowId += 1;
}
$setValueCounts
root.setRowCount(rowId);
writer.writeBatch();
}
} finally {
writer.end();
out.flush();
root.close();
}
}

${ctx.initNestedClasses()}
${ctx.declareNestedClasses()}
}
"""

+    val code = CodeFormatter.stripOverlappingComments(
+      new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
+    logDebug(s"Generated ArrowWriter:\n${CodeFormatter.format(code)}")
+
+    CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[ArrowWriter]
+  }
+}
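
For illustration, here is roughly the class the template above produces for a schema with a single
nullable 32-bit integer field named "id" (a hand-expanded sketch of the $initFieldVectors,
$allocateMemories, $writeValues and $setValueCounts snippets, not actual compiler output; the real
template emits fully qualified vector class names, and the schema reference is whatever
ctx.addReferenceObj("schema", schema) resolves to, shown here as a cast of references[0]):

    public SpecificArrowWriter generate(Object[] references) {
      return new SpecificArrowWriter(references);
    }

    class SpecificArrowWriter extends ArrowWriter {

      private Object[] references;
      private int batchSize;

      public SpecificArrowWriter(Object[] references) {
        this.references = references;
      }

      public void initialize(int batchSize) {
        this.batchSize = batchSize;
      }

      public void writeAll(Iterator<InternalRow> inputRows, OutputStream out) throws IOException {
        VectorSchemaRoot root = VectorSchemaRoot.create(
          (org.apache.arrow.vector.types.pojo.Schema) references[0], ArrowColumnVector.allocator);
        // $initFieldVectors: one downcast per field, in schema order.
        NullableIntVector idVector = (NullableIntVector) root.getFieldVectors().get(0);

        ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out);
        writer.start();

        try {
          while (inputRows.hasNext()) {
            int rowId = 0;
            // $allocateMemories
            idVector.allocateNew();
            while (inputRows.hasNext() && rowId < batchSize) {
              InternalRow row = (InternalRow) inputRows.next();
              // $writeValues: a nullable 32-bit ArrowType.Int becomes a NullableIntVector setter.
              idVector.getMutator().setSafe(rowId, row.getInt(0));
              rowId += 1;
            }
            // $setValueCounts
            idVector.getMutator().setValueCount(rowId);
            root.setRowCount(rowId);
            writer.writeBatch();
          }
        } finally {
          writer.end();
          out.flush();
          root.close();
        }
      }
    }

Compared with the interpreted writeValue path this change removes, the per-row vector match
dispatch is gone: the column types are baked into straight-line setter calls when the writer is
compiled, which is the point of generating the writer instead of interpreting the schema per row.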