2 changes: 1 addition & 1 deletion pom.xml
@@ -182,7 +182,7 @@
<paranamer.version>2.8</paranamer.version>
<maven-antrun.version>1.8</maven-antrun.version>
<commons-crypto.version>1.0.0</commons-crypto.version>
<arrow.version>0.1.0</arrow.version>
<arrow.version>0.1.1-SNAPSHOT</arrow.version>

<test.java.home>${java.home}</test.java.home>
<test.exclude.tags></test.exclude.tags>
21 changes: 18 additions & 3 deletions python/pyspark/sql/tests.py
@@ -1971,12 +1971,27 @@ def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.spark = SparkSession(cls.sc)

def assertFramesEqual(self, df_with_arrow, df_without):
msg = ("DataFrame from Arrow is not equal" +
("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) +
("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
self.assertTrue(df_without.equals(df_with_arrow), msg=msg)

def test_arrow_toPandas(self):
schema = StructType().add("key", IntegerType()).add("value", IntegerType())
df = self.spark.createDataFrame([(1, 2), (2, 4), (3, 6), (4, 8)], schema=schema)
schema = StructType([
StructField("str_t", StringType(), True), # Fails in conversion
StructField("int_t", IntegerType(), True), # Fails, without is converted to int64
StructField("long_t", LongType(), True), # Fails if nullable=False
StructField("double_t", DoubleType(), True)])
data = [("a", 1, 10, 2.0),
("b", 2, 20, 4.0),
("c", 3, 30, 6.0)]

df = self.spark.createDataFrame(data, schema=schema)
df = df.select("long_t", "double_t")
pdf = df.toPandas(useArrow=False)
pdf_arrow = df.toPandas(useArrow=True)
self.assertTrue(pdf.equals(pdf_arrow))
self.assertFramesEqual(pdf_arrow, pdf)


if __name__ == "__main__":
140 changes: 113 additions & 27 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -29,6 +29,7 @@ import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.file.ArrowWriter
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
import org.apache.commons.lang3.StringUtils

@@ -2291,6 +2292,18 @@
dt match {
case IntegerType =>
new ArrowType.Int(8 * IntegerType.defaultSize, true)
case LongType =>
new ArrowType.Int(8 * LongType.defaultSize, true)
case StringType =>
ArrowType.List.INSTANCE
case DoubleType =>
new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
case FloatType =>
new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
case BooleanType =>
ArrowType.Bool.INSTANCE
case ByteType =>
new ArrowType.Int(8, false)
case _ =>
throw new IllegalArgumentException(s"Unsupported data type")
}
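
For reference, Spark's DataType.defaultSize is measured in bytes while ArrowType.Int takes a width in bits, hence the factor of 8 in the mapping above. A minimal sketch of the expected widths (not part of the patch), assuming Spark's default sizes of 4 and 8 bytes for IntegerType and LongType:

```scala
// Sketch only: defaultSize is in bytes, ArrowType.Int expects a bit width.
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.spark.sql.types.{IntegerType, LongType}

assert(IntegerType.defaultSize == 4 && LongType.defaultSize == 8)
val intArrowType = new ArrowType.Int(8 * IntegerType.defaultSize, true)  // 32-bit signed
val longArrowType = new ArrowType.Int(8 * LongType.defaultSize, true)    // 64-bit signed
```
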
@@ -2302,8 +2315,16 @@
private[sql] def schemaToArrowSchema(schema: StructType): Schema = {
val arrowFields = schema.fields.map {
case StructField(name, dataType, nullable, metadata) =>
// TODO: Consider nested types
new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava)
dataType match {
// TODO: Consider other nested types
case StringType =>
// TODO: Make sure String => List<Utf8>
val itemField =
new Field("item", false, ArrowType.Utf8.INSTANCE, List.empty[Field].asJava)
new Field(name, nullable, dataTypeToArrowType(dataType), List(itemField).asJava)
case _ =>
new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava)
}
}
val arrowSchema = new Schema(arrowFields.toIterable.asJava)
arrowSchema
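
A sketch of what schemaToArrowSchema produces for a schema with one string and one long column, assuming the String-as-List&lt;Utf8&gt; encoding above (the field names here are illustrative only):

```scala
// Sketch: expected Arrow schema for StructType(str_t: String, long_t: Long).
// A string column becomes a List field whose only child is a non-nullable Utf8 field named "item".
import scala.collection.JavaConverters._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

val itemField = new Field("item", false, ArrowType.Utf8.INSTANCE, List.empty[Field].asJava)
val strField  = new Field("str_t", true, ArrowType.List.INSTANCE, List(itemField).asJava)
val longField = new Field("long_t", true, new ArrowType.Int(64, true), List.empty[Field].asJava)
val expected  = new Schema(List(strField, longField).asJava)
```
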
@@ -2319,41 +2340,106 @@
}

/**
* Infer the validity map from the internal rows.
* @param rows An array of InternalRows
* @param idx Index of current column in the array of InternalRows
* @param field StructField related to the current column
* @param allocator ArrowBuf allocator
* Get an entry from the InternalRow and write it to the ArrowBuf.
* Note: no null check is performed on the entry.
*/
private def getAndSetToArrow(
row: InternalRow, buf: ArrowBuf, dataType: DataType, ordinal: Int): Unit = {
dataType match {
case NullType =>
case BooleanType =>
buf.writeBoolean(row.getBoolean(ordinal))
case ShortType =>
buf.writeShort(row.getShort(ordinal))
case IntegerType =>
buf.writeInt(row.getInt(ordinal))
case LongType =>
buf.writeLong(row.getLong(ordinal))
case FloatType =>
buf.writeFloat(row.getFloat(ordinal))
case DoubleType =>
buf.writeDouble(row.getDouble(ordinal))
case ByteType =>
buf.writeByte(row.getByte(ordinal))
case _ =>
throw new UnsupportedOperationException(
s"Unsupported data type ${dataType.simpleString}")
}
}

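numBytesOfBitmap, used below to size the validity buffers, is not shown in this hunk; presumably it rounds one validity bit per row up to whole bytes. A hedged sketch of that assumption:

```scala
// Assumption only: the patch's actual helper is outside the shown hunk.
// One validity bit per row, rounded up to whole bytes.
def numBytesOfBitmap(numOfRows: Int): Int = (numOfRows + 7) / 8

assert(numBytesOfBitmap(3) == 1)
assert(numBytesOfBitmap(9) == 2)
```
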
/**
* Convert one column of an array of InternalRows into its Arrow buffers and field nodes.
*/
private def internalRowToValidityMap(
rows: Array[InternalRow], idx: Int, field: StructField, allocator: RootAllocator): ArrowBuf = {
val buf = allocator.buffer(numBytesOfBitmap(rows.length))
buf
private def internalRowToArrowBuf(
rows: Array[InternalRow],
ordinal: Int,
field: StructField,
allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = {
val numOfRows = rows.length

field.dataType match {
case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType =>
val validity = allocator.buffer(numBytesOfBitmap(numOfRows))
val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
var nullCount = 0
rows.foreach { row =>
if (row.isNullAt(ordinal)) {
nullCount += 1
} else {
getAndSetToArrow(row, buf, field.dataType, ordinal)
}
}

val fieldNode = new ArrowFieldNode(numOfRows, nullCount)

(Array(validity, buf), Array(fieldNode))

case StringType =>
val validityOffset = allocator.buffer(numBytesOfBitmap(numOfRows))
val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize)
var bytesCount = 0
bufOffset.writeInt(bytesCount) // Start position
val validityValues = allocator.buffer(numBytesOfBitmap(numOfRows))
val bufValues = allocator.buffer(Int.MaxValue) // TODO: Reduce the size?
var nullCount = 0
rows.foreach { row =>
if (row.isNullAt(ordinal)) {
nullCount += 1
bufOffset.writeInt(bytesCount)
} else {
val bytes = row.getUTF8String(ordinal).getBytes
bytesCount += bytes.length
bufOffset.writeInt(bytesCount)
bufValues.writeBytes(bytes)
}
}

val fieldNodeOffset = if (field.nullable) {
new ArrowFieldNode(numOfRows, nullCount)
} else {
new ArrowFieldNode(numOfRows, 0)
}

val fieldNodeValues = new ArrowFieldNode(bytesCount, 0)

(Array(validityOffset, bufOffset, validityValues, bufValues),
Array(fieldNodeOffset, fieldNodeValues))
}
}
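
The variable-length string branch above builds two pairs of buffers: a validity bitmap plus a running offset buffer with numOfRows + 1 entries, and a validity bitmap plus a concatenated UTF-8 values buffer. A self-contained sketch of the offsets that loop produces, using plain Scala collections rather than Arrow allocations:

```scala
// Sketch of the offset/values layout for a string column ["a", "bb", null, "ccc"].
// A null slot repeats the previous offset, so it contributes zero bytes.
val column: Seq[Option[String]] = Seq(Some("a"), Some("bb"), None, Some("ccc"))

var bytesCount = 0
val offsets = scala.collection.mutable.ArrayBuffer(bytesCount)  // start position
val values = new StringBuilder
column.foreach {
  case Some(s) =>
    bytesCount += s.getBytes("UTF-8").length
    offsets += bytesCount
    values ++= s
  case None =>
    offsets += bytesCount
}

assert(offsets == Seq(0, 1, 3, 3, 6))
assert(values.toString == "abbccc")
```
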

/**
* Convert an array of InternalRows to an ArrowRecordBatch.
*/
private[sql] def internalRowsToArrowRecordBatch(
rows: Array[InternalRow], allocator: RootAllocator): ArrowRecordBatch = {
val numOfRows = rows.length

val buffers = this.schema.fields.zipWithIndex.flatMap { case (field, idx) =>
val validity = internalRowToValidityMap(rows, idx, field, allocator)
val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
rows.foreach { row => buf.writeInt(row.getInt(idx)) }
Array(validity, buf)
}.toList.asJava
val bufAndField = this.schema.fields.zipWithIndex.map { case (field, ordinal) =>
internalRowToArrowBuf(rows, ordinal, field, allocator)
}

val fieldNodes = this.schema.fields.zipWithIndex.map { case (field, idx) =>
if (field.nullable) {
new ArrowFieldNode(numOfRows, 0)
} else {
new ArrowFieldNode(numOfRows, 0)
}
}.toList.asJava
val buffers = bufAndField.flatMap(_._1).toList.asJava
val fieldNodes = bufAndField.flatMap(_._2).toList.asJava

new ArrowRecordBatch(numOfRows, fieldNodes, buffers)
new ArrowRecordBatch(rows.length, fieldNodes, buffers)
}

/**
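
A rough sketch of how these helpers might be exercised end to end from within the org.apache.spark.sql package (they are private[sql]); the RootAllocator constructor argument, the existing SparkSession named spark, and the use of queryExecution.toRdd to obtain InternalRows are assumptions, not taken from this patch:

```scala
// Hypothetical driver, assuming it lives in package org.apache.spark.sql.
import org.apache.arrow.memory.RootAllocator

val df = spark.range(0, 4).toDF("long_t")         // assumes an existing SparkSession named spark
val rows = df.queryExecution.toRdd.collect()      // Array[InternalRow] backing the Dataset
val allocator = new RootAllocator(Long.MaxValue)  // assumed constructor: allocation limit in bytes
val recordBatch = df.internalRowsToArrowRecordBatch(rows, allocator)
val arrowSchema = df.schemaToArrowSchema(df.schema)  // Arrow schema matching the batch
```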