UnsafeRowSerializer.scala
@@ -45,16 +45,9 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with Serializable {
 }
 
 private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
-
-  /**
-   * Marks the end of a stream written with [[serializeStream()]].
-   */
-  private[this] val EOF: Int = -1
-
   /**
    * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
    * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
-   * The end of the stream is denoted by a record with the special length `EOF` (-1).
    */
   override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
     private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
@@ -92,7 +85,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {

     override def close(): Unit = {
       writeBuffer = null
-      dOut.writeInt(EOF)
       dOut.close()
     }
   }
@@ -104,12 +96,20 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
       private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
       private[this] var row: UnsafeRow = new UnsafeRow()
       private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
+      private[this] val EOF: Int = -1
 
       override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
         new Iterator[(Int, UnsafeRow)] {
-          private[this] var rowSize: Int = dIn.readInt()
-          if (rowSize == EOF) dIn.close()
 
+          private[this] def readSize(): Int = try {
+            dIn.readInt()
+          } catch {
+            case e: EOFException =>
+              dIn.close()
+              EOF
+          }
+
+          private[this] var rowSize: Int = readSize()
           override def hasNext: Boolean = rowSize != EOF
 
           override def next(): (Int, UnsafeRow) = {
@@ -118,7 +118,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
             }
             ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
             row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
-            rowSize = dIn.readInt() // read the next row's size
+            rowSize = readSize()
             if (rowSize == EOF) { // We are returning the last row in this stream
               dIn.close()
               val _rowTuple = rowTuple
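The serializer change above leaves the on-wire format as nothing but length-prefixed records: each record is a 4-byte big-endian length (DataOutputStream.writeInt) followed by the record's bytes, and the reader now treats an EOFException from readInt() as end-of-stream instead of looking for a -1 sentinel. A minimal standalone sketch of that framing in plain java.io (not part of the patch; writeRecords/readRecords are illustrative names):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException}

object FramingSketch {
  // Each record: 4-byte big-endian length, then the record's bytes.
  // No terminator is written after the last record, mirroring the patched serializer.
  def writeRecords(records: Seq[Array[Byte]]): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    records.foreach { rec =>
      out.writeInt(rec.length) // DataOutputStream.writeInt writes high byte first
      out.write(rec)
    }
    out.close()
    bytes.toByteArray
  }

  // Read until the stream is exhausted: DataInputStream.readInt throws EOFException
  // at end of input, which plays the role the -1 sentinel used to play.
  def readRecords(data: Array[Byte]): Seq[Array[Byte]] = {
    val in = new DataInputStream(new ByteArrayInputStream(data))
    val records = Seq.newBuilder[Array[Byte]]
    var done = false
    while (!done) {
      try {
        val size = in.readInt()
        val buf = new Array[Byte](size)
        in.readFully(buf)
        records += buf
      } catch {
        case _: EOFException =>
          in.close()
          done = true
      }
    }
    records.result()
  }

  def main(args: Array[String]): Unit = {
    val roundTripped = readRecords(writeRecords(Seq("a".getBytes, "bb".getBytes, "ccc".getBytes)))
    assert(roundTripped.map(_.length) == Seq(1, 2, 3))
  }
}

Note that an empty input simply produces no records, which is exactly what the updated "close empty input stream" test below now asserts.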
UnsafeRowSerializerSuite.scala
@@ -17,9 +17,10 @@

 package org.apache.spark.sql.execution
 
-import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream}
 
 import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.util.Utils
@@ -41,7 +42,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStream(buf) {
   }
 }
 
-class UnsafeRowSerializerSuite extends SparkFunSuite {
+class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
 
   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
     val converter = unsafeRowConverter(schema)
@@ -87,11 +88,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
   }
 
   test("close empty input stream") {
-    val baos = new ByteArrayOutputStream()
-    val dout = new DataOutputStream(baos)
-    dout.writeInt(-1) // EOF
-    dout.flush()
-    val input = new ClosableByteArrayInputStream(baos.toByteArray)
+    val input = new ClosableByteArrayInputStream(Array.empty)
     val serializer = new UnsafeRowSerializer(numFields = 2).newInstance()
     val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator
     assert(!deserializerIter.hasNext)
@@ -143,4 +140,16 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
       }
     }
   }
+
+  test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") {
+    val conf = new SparkConf()
+      .set("spark.shuffle.manager", "tungsten-sort")
+    sc = new SparkContext("local", "test", conf)
+    val row = Row("Hello", 123)
+    val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
+    val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)))
+      .asInstanceOf[RDD[Product2[Int, InternalRow]]]
+    val shuffled = new ShuffledRowRDD(rowsRDD, new UnsafeRowSerializer(2), 2)
+    shuffled.count()
+  }
 }
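The new SPARK-10403 test runs a real shuffle with spark.shuffle.manager=tungsten-sort, the path that motivated dropping the per-stream EOF marker: that shuffle writer can merge map output by concatenating already-serialized partition streams, so a -1 terminator written by close() could land in the middle of a merged stream and stop deserialization early. A rough standalone illustration of why sentinel-free framing concatenates cleanly (plain java.io; ConcatenationSketch and its helpers are invented for this sketch, not Spark APIs):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException}

object ConcatenationSketch {
  // Serialize one "partition" of records as length-prefixed bytes with no end marker.
  private def serialize(records: Seq[Array[Byte]]): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    records.foreach { rec =>
      out.writeInt(rec.length)
      out.write(rec)
    }
    out.close()
    bytes.toByteArray
  }

  // Count how many records a reader that stops on EOFException can see.
  private def countRecords(data: Array[Byte]): Int = {
    val in = new DataInputStream(new ByteArrayInputStream(data))
    var count = 0
    try {
      while (true) {
        val size = in.readInt()
        in.skipBytes(size)
        count += 1
      }
    } catch {
      case _: EOFException => in.close()
    }
    count
  }

  def main(args: Array[String]): Unit = {
    val streamA = serialize(Seq(Array[Byte](1, 2), Array[Byte](3)))
    val streamB = serialize(Seq(Array[Byte](4, 5, 6)))
    // Because neither stream ends with a -1 marker, their concatenation is itself a
    // valid stream of three records; a mid-stream -1 would have cut reading short.
    assert(countRecords(streamA ++ streamB) == 3)
  }
}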