From 8f53ed24e39fb9d43594cb335e1fe084d9791a0e Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Fri, 16 Jun 2017 18:15:22 -0700
Subject: [PATCH] Prototype of codegen for ArrowWriter.

---
 .../python/VectorizedPythonRunner.scala       | 180 +++++++++++++-----
 1 file changed, 135 insertions(+), 45 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/VectorizedPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/VectorizedPythonRunner.scala
index d462d916bc079..094cb44e568de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/VectorizedPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/VectorizedPythonRunner.scala
@@ -17,20 +17,21 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream}
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, OutputStream}
 import java.net.Socket
 
 import scala.collection.JavaConverters._
 
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter}
-import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.arrow.vector.types.pojo._
 
 import org.apache.spark.{SparkEnv, SparkFiles, TaskContext}
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRDD, SpecialLengths}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator}
 import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ArrowUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -192,22 +193,6 @@ class VectorizedPythonRunner(
       this.interrupt()
     }
 
-    private def writeValue(
-        vector: FieldVector,
-        rowId: Int,
-        row: InternalRow,
-        columnIndex: Int): Unit = vector match {
-      // todo: since we know the batch size. can we pre-allocate memory for VectorSchemaRoot and
-      // call set instead of setSafe here?
-      // todo: null handling
-      case v: NullableIntVector =>
-        v.getMutator.setSafe(rowId, row.getInt(columnIndex))
-      case v: NullableBigIntVector =>
-        v.getMutator.setSafe(rowId, row.getLong(columnIndex))
-      case v: NullableVarBinaryVector =>
-        v.getMutator.set(rowId, row.getUTF8String(columnIndex).getBytes)
-    }
-
     override def run(): Unit = Utils.logUncaughtExceptions {
       try {
         TaskContext.setTaskContext(context)
@@ -266,32 +251,9 @@
         }
         dataOut.flush()
 
-        val root = VectorSchemaRoot.create(schema, ArrowColumnVector.allocator)
-        // TODO: does ArrowStreamWriter buffer data?
-        // TODO: who decides the dictionary?
-        val writer = new ArrowStreamWriter(root, null, dataOut)
-        writer.start()
-
-        val fields = schema.getFields.toArray
-        while (inputRows.hasNext) {
-          var rowId = 0
-          root.getFieldVectors.asScala.foreach(_.allocateNew())
-          while (inputRows.hasNext && rowId < batchSize) {
-            val row = inputRows.next()
-            var columnIndex = 0
-            while (columnIndex < fields.length) {
-              val vector = root.getFieldVectors.get(columnIndex)
-              writeValue(vector, rowId, row, columnIndex)
-              columnIndex += 1
-            }
-            rowId += 1
-          }
-          root.getFieldVectors.asScala.foreach(_.getMutator.setValueCount(rowId))
-          root.setRowCount(rowId)
-          writer.writeBatch()
-        }
-        writer.end()
-        root.close()
+        val arrowWriter = GenerateArrowWriter.generate(schema)
+        arrowWriter.initialize(batchSize)
+        arrowWriter.writeAll(inputRows, dataOut)
 
         dataOut.writeInt(SpecialLengths.END_OF_STREAM)
         dataOut.flush()
@@ -314,3 +276,131 @@
 
   }
 }
+
+abstract class ArrowWriter {
+  def initialize(batchSize: Int): Unit
+  def writeAll(inputRows: Iterator[InternalRow], out: OutputStream): Unit
+}
+
+object GenerateArrowWriter extends CodeGenerator[Schema, ArrowWriter] with Logging {
+
+  protected def canonicalize(in: Schema): Schema = in
+  protected def bind(in: Schema, inputSchema: Seq[Attribute]): Schema = in
+
+  private def fieldTypeToVectorClass(arrowType: ArrowType, nullable: Boolean): String =
+    (arrowType, nullable) match {
+      case (ArrowType.Bool.INSTANCE, true) =>
+        classOf[NullableBitVector].getName
+      case (intType: ArrowType.Int, true) if intType.getBitWidth() == 8 * 4 =>
+        classOf[NullableIntVector].getName
+      case (intType: ArrowType.Int, true) if intType.getBitWidth() == 8 * 8 =>
+        classOf[NullableBigIntVector].getName
+      case (ArrowType.Utf8.INSTANCE, true) =>
+        classOf[NullableVarCharVector].getName
+    }
+
+  protected def create(schema: Schema): ArrowWriter = {
+    val ctx = newCodeGenContext()
+    val schemaName = ctx.addReferenceObj("schema", schema)
+
+    val (vectorNames, vectorTypes) = schema.getFields().asScala.map { field =>
+      val name = ctx.freshName(s"${field.getName}Vector")
+      val fieldType = field.getFieldType
+      (name, (fieldType.getType, fieldType.isNullable))
+    }.unzip
+
+    val initFieldVectors = vectorNames.zip(vectorTypes).zipWithIndex.map {
+      case ((name, (arrowType, nullable)), idx) =>
+        val cls = fieldTypeToVectorClass(arrowType, nullable)
+        s"$cls $name = ($cls) root.getFieldVectors().get($idx);"
+    }.mkString("\n")
+
+    val allocateMemories = vectorNames.map { name =>
+      s"$name.allocateNew();"
+    }.mkString("\n")
+
+    val writeValues = vectorNames.zip(vectorTypes.map(_._1)).zipWithIndex.map {
+      case ((name, ArrowType.Bool.INSTANCE), idx) =>
+        s"$name.getMutator().setSafe(rowId, row.getBoolean($idx) ? 1 : 0);"
+      case ((name, intType: ArrowType.Int), idx) if intType.getBitWidth() == 8 * 4 =>
+        s"$name.getMutator().setSafe(rowId, row.getInt($idx));"
+      case ((name, intType: ArrowType.Int), idx) if intType.getBitWidth() == 8 * 8 =>
+        s"$name.getMutator().setSafe(rowId, row.getLong($idx));"
+      case ((name, ArrowType.Utf8.INSTANCE), idx) =>
+        s"$name.getMutator().set(rowId, row.getUTF8String($idx).getBytes());"
+    }.mkString("\n")
+
+    val setValueCounts = vectorNames.map { name =>
+      s"$name.getMutator().setValueCount(rowId);"
+    }.mkString("\n")
+
+    val codeBody = s"""
+      import java.io.IOException;
+      import java.io.OutputStream;
+      import scala.collection.Iterator;
+      import ${classOf[VectorSchemaRoot].getName};
+      import ${classOf[ArrowStreamWriter].getName};
+      import ${classOf[ArrowWriter].getName};
+      import ${classOf[ArrowColumnVector].getName};
+
+      public SpecificArrowWriter generate(Object[] references) {
+        return new SpecificArrowWriter(references);
+      }
+
+      class SpecificArrowWriter extends ArrowWriter {
+
+        private Object[] references;
+        ${ctx.declareMutableStates()}
+
+        private int batchSize;
+
+        public SpecificArrowWriter(Object[] references) {
+          this.references = references;
+          ${ctx.initMutableStates()}
+        }
+
+        ${ctx.declareAddedFunctions()}
+
+        public void initialize(int batchSize) {
+          this.batchSize = batchSize;
+        }
+
+        public void writeAll(Iterator inputRows, OutputStream out) throws IOException {
+          VectorSchemaRoot root = VectorSchemaRoot.create($schemaName, ArrowColumnVector.allocator);
+          $initFieldVectors
+
+          ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out);
+          writer.start();
+
+          try {
+            while (inputRows.hasNext()) {
+              int rowId = 0;
+              $allocateMemories
+              while (inputRows.hasNext() && rowId < batchSize) {
+                InternalRow row = (InternalRow) inputRows.next();
+                $writeValues
+                rowId += 1;
+              }
+              $setValueCounts
+              root.setRowCount(rowId);
+              writer.writeBatch();
+            }
+          } finally {
+            writer.end();
+            out.flush();
+            root.close();
+          }
+        }
+
+        ${ctx.initNestedClasses()}
+        ${ctx.declareNestedClasses()}
+      }
+    """
+
+    val code = CodeFormatter.stripOverlappingComments(
+      new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
+    logDebug(s"Generated ArrowWriter:\n${CodeFormatter.format(code)}")
+
+    CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[ArrowWriter]
+  }
+}
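
Not part of the patch: a minimal, hypothetical smoke test for the generated writer, mirroring the generate/initialize/writeAll call site in run() above. ArrowWriterSmokeTest, the two-column schema, and ArrowUtils.toArrowSchema(StructType) are assumptions for illustration (the exact ArrowUtils helper and its signature are not shown in this diff), and the rows avoid nulls because the prototype does not handle null values yet.

package org.apache.spark.sql.execution.python

import java.io.ByteArrayOutputStream

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.execution.vectorized.ArrowUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object ArrowWriterSmokeTest {
  def main(args: Array[String]): Unit = {
    // Spark-side schema limited to types the prototype handles (int and string here).
    val sparkSchema = StructType(Seq(
      StructField("id", IntegerType, nullable = true),
      StructField("name", StringType, nullable = true)))
    // Assumed helper: convert the Spark schema to an Arrow Schema for code generation.
    val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema)

    // Rows are InternalRows with no nulls, matching what writeAll expects.
    val rows: Iterator[InternalRow] = (0 until 10).iterator.map { i =>
      new GenericInternalRow(Array[Any](i, UTF8String.fromString(s"row-$i")))
    }

    // Compile a SpecificArrowWriter for this schema and stream the rows as Arrow
    // record batches of at most 4 rows each (here: 4 + 4 + 2).
    val writer: ArrowWriter = GenerateArrowWriter.generate(arrowSchema)
    writer.initialize(4)
    val out = new ByteArrayOutputStream()
    writer.writeAll(rows, out)
    println(s"wrote ${out.size()} bytes in Arrow stream format")
  }
}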