@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.arrow
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.file._
@@ -28,14 +30,15 @@ import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


/**
* Store Arrow data in a form that can be serialized by Spark and served to a Python process.
*/
private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable {

/**
* Convert the ArrowPayload to an ArrowRecordBatch.
@@ -50,6 +53,17 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
def asPythonSerializable: Array[Byte] = payload
}

/**
* Iterator interface to iterate over Arrow record batches and return rows
*/
private[sql] trait ArrowRowIterator extends Iterator[InternalRow] {

/**
* Return the schema loaded from the Arrow record batch being iterated over
*/
def schema: StructType
}

private[sql] object ArrowConverters {

/**
@@ -110,6 +124,66 @@ private[sql] object ArrowConverters {
}
}

/**
* Maps an Iterator of ArrowPayload to InternalRow. Returns an ArrowRowIterator that also
* exposes the schema loaded from the first batch of Arrow data read.
*/
private[sql] def fromPayloadIterator(
payloadIter: Iterator[ArrowPayload],
context: TaskContext): ArrowRowIterator = {
val allocator =
ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue)

new ArrowRowIterator {
private var reader: ArrowFileReader = null
private var schemaRead = StructType(Seq.empty)
private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty
Member: We can simply put Iterator.empty here.

Author: nextBatch() returns the row iterator, so rowIter needs to be initialized here to the first row of the first batch.

Member: Never mind, I thought the first call of hasNext would initialize it.

context.addTaskCompletionListener { _ =>
closeReader()
allocator.close()
}

override def schema: StructType = schemaRead

override def hasNext: Boolean = rowIter.hasNext || {
closeReader()
if (payloadIter.hasNext) {
rowIter = nextBatch()
true
} else {
allocator.close()
false
}
}

override def next(): InternalRow = rowIter.next()

private def closeReader(): Unit = {
if (reader != null) {
reader.close()
reader = null
}
}

private def nextBatch(): Iterator[InternalRow] = {
val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable)
reader = new ArrowFileReader(in, allocator)
reader.loadNextBatch() // throws IOException
val root = reader.getVectorSchemaRoot // throws IOException
schemaRead = ArrowUtils.fromArrowSchema(root.getSchema)

val columns = root.getFieldVectors.asScala.map { vector =>
new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
}.toArray

val batch = new ColumnarBatch(schemaRead, columns, root.getRowCount)
batch.setNumRows(root.getRowCount)
batch.rowIterator().asScala
}
}
}
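
A minimal usage sketch of the API added above (not part of the patch): it assumes an Iterator[ArrowPayload], for example one produced by toPayloadIterator, and the current TaskContext are in scope as payloads and context.

  // Sketch: decode serialized Arrow batches back into InternalRows.
  val rowIter: ArrowRowIterator = ArrowConverters.fromPayloadIterator(payloads, context)
  val schema: StructType = rowIter.schema  // schema loaded from the first Arrow batch
  while (rowIter.hasNext) {
    val row: InternalRow = rowIter.next()
    // consume the row here
  }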

/**
* Convert a byte array to an ArrowRecordBatch.
*/
@@ -29,8 +29,9 @@ import org.apache.arrow.vector.file.json.JsonFileReader
import org.apache.arrow.vector.util.Validator
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkException
import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -1629,6 +1630,32 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
}
}

test("roundtrip payloads") {
val inputRows = (0 until 9).map { i =>
InternalRow(i)
} :+ InternalRow(null)

val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))

val ctx = TaskContext.empty()
val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx)
val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)

assert(schema.equals(outputRowIter.schema))

var count = 0
outputRowIter.zipWithIndex.foreach { case (row, i) =>
if (i != 9) {
assert(row.getInt(0) == i)
} else {
assert(row.isNullAt(0))
}
count += 1
}

assert(count == inputRows.length)
}

/** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */
private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = {
// NOTE: coalesce to single partition because can only load 1 batch in validator
@@ -25,10 +25,13 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random

import org.apache.arrow.vector.NullableIntVector

import org.apache.spark.SparkFunSuite
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.CalendarInterval
@@ -1261,4 +1264,55 @@ class ColumnarBatchSuite extends SparkFunSuite {
s"vectorized reader"))
}
}

test("create columnar batch from Arrow column vectors") {
val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true)
.createVector(allocator).asInstanceOf[NullableIntVector]
vector1.allocateNew()
val mutator1 = vector1.getMutator()
val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true)
.createVector(allocator).asInstanceOf[NullableIntVector]
vector2.allocateNew()
val mutator2 = vector2.getMutator()

(0 until 10).foreach { i =>
mutator1.setSafe(i, i)
mutator2.setSafe(i + 1, i)
}
mutator1.setNull(10)
mutator1.setValueCount(11)
mutator2.setNull(0)
mutator2.setValueCount(11)

val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2))

val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType)))
val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11)
batch.setNumRows(11)

assert(batch.numCols() == 2)
assert(batch.numRows() == 11)

val rowIter = batch.rowIterator().asScala
rowIter.zipWithIndex.foreach { case (row, i) =>
if (i == 10) {
assert(row.isNullAt(0))
} else {
assert(row.getInt(0) == i)
}
if (i == 0) {
assert(row.isNullAt(1))
} else {
assert(row.getInt(1) == i - 1)
}
}

intercept[java.lang.AssertionError] {
batch.getRow(100)
Author: Hmm, that is strange. I'll take a look, thanks.

Author: It's probably because the assert is being compiled out. This should probably not be in the test then.

Member (@dongjoon-hyun, Aug 31, 2017): Then, please check the error message here. Please ignore this.

Author (@BryanCutler, Aug 31, 2017): I think the problem is that if the Java assertion is compiled out, then no error is produced and the test fails.

Author: I just made #19098 to remove this check; it's not really testing the functionality added here anyway, but maybe another test should be added for checking index-out-of-bounds errors.

}

batch.close()
allocator.close()
}
}
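
Regarding the getRow(100) discussion above: per that thread, ColumnarBatch.getRow guards its row-index check with a Java assert, so the AssertionError is only raised when the JVM runs with assertions enabled (-ea). Below is a minimal sketch, not what the patch or #19098 does, of guarding the expectation on the runtime assertion status via Class.desiredAssertionStatus:

  // Sketch only: skip the AssertionError expectation when Java assertions are
  // disabled for ColumnarBatch, because the bounds assert is then a no-op.
  if (classOf[ColumnarBatch].desiredAssertionStatus()) {
    intercept[java.lang.AssertionError] {
      batch.getRow(100)  // row index past numRows should trip the bounds assert
    }
  }

An alternative, as the author suggests in the thread, is a dedicated negative test for out-of-bounds access that does not depend on JVM assertion flags.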