From f681d524f8f6986d2e05851814d67d2a3a858f0e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 14 Dec 2016 13:42:02 -0800 Subject: [PATCH 01/56] Inital attempt to integrate Arrow for use in dataframe.toPandas. Conversion has basic data types and is working for small datasets with longs, doubles. Using Arrow 0.1.1-SNAPSHOT dependency. --- pom.xml | 20 ++ python/pyspark/serializers.py | 18 ++ python/pyspark/sql/dataframe.py | 20 +- python/pyspark/sql/tests.py | 40 ++++ sql/core/pom.xml | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 211 +++++++++++++++++- .../spark/sql/DatasetToArrowSuite.scala | 172 ++++++++++++++ 7 files changed, 480 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala diff --git a/pom.xml b/pom.xml index c1174593c192..17f6e5492871 100644 --- a/pom.xml +++ b/pom.xml @@ -184,6 +184,7 @@ 2.6 1.8 1.0.0 + 0.1.1-SNAPSHOT ${java.home} @@ -1871,6 +1872,25 @@ paranamer ${paranamer.version} + + org.apache.arrow + arrow-vector + ${arrow.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-databind + + + org.slf4j + log4j-over-slf4j + + + diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ea5e00e9eeef..c291786c8452 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -182,6 +182,24 @@ def loads(self, obj): raise NotImplementedError +class ArrowSerializer(FramedSerializer): + + """ + Serializes an Arrow stream. + """ + + def dumps(self, obj): + raise NotImplementedError + + def loads(self, obj): + from pyarrow.ipc import ArrowFileReader + reader = ArrowFileReader(obj) + return reader.get_record_batch(0) + + def __repr__(self): + return "ArrowSerializer" + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 70efeaf0160c..91ae27cab933 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -27,7 +27,7 @@ from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -36,6 +36,7 @@ from pyspark.sql.streaming import DataStreamWriter from pyspark.sql.types import * + __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] @@ -390,6 +391,15 @@ def collect(self): port = self._jdf.collectToPython() return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + @ignore_unicode_prefix + @since(2.0) + def collectAsArrow(self): + """Returns all the records as an ArrowRecordBatch + """ + with SCCallSiteSync(self._sc) as css: + port = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(port, ArrowSerializer()))[0] + @ignore_unicode_prefix @since(2.0) def toLocalIterator(self): @@ -1597,7 +1607,7 @@ def toDF(self, *cols): return DataFrame(jdf, self.sql_ctx) @since(1.3) - def toPandas(self): + def toPandas(self, useArrow=False): """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. This is only available if Pandas is installed and available. 
@@ -1611,7 +1621,11 @@ def toPandas(self): 1 5 Bob """ import pandas as pd - return pd.DataFrame.from_records(self.collect(), columns=self.columns) + + if useArrow: + return self.collectAsArrow().to_pandas() + else: + return pd.DataFrame.from_records(self.collect(), columns=self.columns) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9058443285ac..69984ff0e9a7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -56,6 +56,15 @@ from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException +_have_arrow = False +try: + import pyarrow + _have_arrow = True +except: + # No Arrow, but that's okay, we'll skip those tests + pass + + class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -2338,6 +2347,37 @@ def range_frame_match(): importlib.reload(window) +@unittest.skipIf(not _have_arrow, "Arrow not installed") +class ArrowTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + def assertFramesEqual(self, df_with_arrow, df_without): + msg = ("DataFrame from Arrow is not equal" + + ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + + ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) + self.assertTrue(df_without.equals(df_with_arrow), msg=msg) + + def test_arrow_toPandas(self): + schema = StructType([ + StructField("str_t", StringType(), True), # Fails in conversion + StructField("int_t", IntegerType(), True), # Fails, without is converted to int64 + StructField("long_t", LongType(), True), # Fails if nullable=False + StructField("double_t", DoubleType(), True)]) + data = [("a", 1, 10, 2.0), + ("b", 2, 20, 4.0), + ("c", 3, 30, 6.0)] + + df = self.spark.createDataFrame(data, schema=schema) + df = df.select("long_t", "double_t") + pdf = df.toPandas(useArrow=False) + pdf_arrow = df.toPandas(useArrow=True) + self.assertFramesEqual(pdf_arrow, pdf) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 69d797b47915..cebecf7dff83 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -103,6 +103,10 @@ jackson-databind ${fasterxml.jackson.version} + + org.apache.arrow + arrow-vector + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 38a24cc8ed8c..a980291ad5e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql -import java.io.CharArrayWriter +import java.io.{ByteArrayOutputStream, CharArrayWriter} +import java.nio.channels.Channels import java.sql.{Date, Timestamp} import java.util.TimeZone @@ -26,6 +27,12 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal +import io.netty.buffer.ArrowBuf +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.file.ArrowWriter +import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} +import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} import org.apache.commons.lang3.StringUtils import 
org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} @@ -56,6 +63,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils + private[sql] object Dataset { def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) @@ -2363,7 +2371,185 @@ class Dataset[T] private[sql]( } /** - * Return an iterator that contains all rows in this Dataset. + * Transform Spark DataType to Arrow ArrowType. + */ + private[sql] def dataTypeToArrowType(dt: DataType): ArrowType = { + dt match { + case IntegerType => + new ArrowType.Int(8 * IntegerType.defaultSize, true) + case LongType => + new ArrowType.Int(8 * LongType.defaultSize, true) + case StringType => + ArrowType.List.INSTANCE + case DoubleType => + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case FloatType => + new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case BooleanType => + ArrowType.Bool.INSTANCE + case ByteType => + new ArrowType.Int(8, false) + case _ => + throw new IllegalArgumentException(s"Unsupported data type") + } + } + + /** + * Transform Spark StructType to Arrow Schema. + */ + private[sql] def schemaToArrowSchema(schema: StructType): Schema = { + val arrowFields = schema.fields.map { + case StructField(name, dataType, nullable, metadata) => + dataType match { + // TODO: Consider other nested types + case StringType => + // TODO: Make sure String => List + val itemField = + new Field("item", false, ArrowType.Utf8.INSTANCE, List.empty[Field].asJava) + new Field(name, nullable, dataTypeToArrowType(dataType), List(itemField).asJava) + case _ => + new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava) + } + } + val arrowSchema = new Schema(arrowFields.toIterable.asJava) + arrowSchema + } + + /** + * Compute the number of bytes needed to build validity map. According to + * [Arrow Layout](https://github.com/apache/arrow/blob/master/format/Layout.md#null-bitmaps), + * the length of the validity bitmap should be multiples of 64 bytes. + */ + private def numBytesOfBitmap(numOfRows: Int): Int = { + Math.ceil(numOfRows / 64.0).toInt * 8 + } + + /** + * Get an entry from the InternalRow, and then set to ArrowBuf. + * Note: No Null check for the entry. + */ + private def getAndSetToArrow( + row: InternalRow, buf: ArrowBuf, dataType: DataType, ordinal: Int): Unit = { + dataType match { + case NullType => + case BooleanType => + buf.writeBoolean(row.getBoolean(ordinal)) + case ShortType => + buf.writeShort(row.getShort(ordinal)) + case IntegerType => + buf.writeInt(row.getInt(ordinal)) + case LongType => + buf.writeLong(row.getLong(ordinal)) + case FloatType => + buf.writeFloat(row.getFloat(ordinal)) + case DoubleType => + buf.writeDouble(row.getDouble(ordinal)) + case ByteType => + buf.writeByte(row.getByte(ordinal)) + case _ => + throw new UnsupportedOperationException( + s"Unsupported data type ${dataType.simpleString}") + } + } + + /** + * Convert an array of InternalRow to an ArrowBuf. 
+ */ + private def internalRowToArrowBuf( + rows: Array[InternalRow], + ordinal: Int, + field: StructField, + allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = { + val numOfRows = rows.length + + field.dataType match { + case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType => + val validity = allocator.buffer(numBytesOfBitmap(numOfRows)) + val buf = allocator.buffer(numOfRows * field.dataType.defaultSize) + var nullCount = 0 + rows.foreach { row => + if (row.isNullAt(ordinal)) { + nullCount += 1 + } else { + getAndSetToArrow(row, buf, field.dataType, ordinal) + } + } + + val fieldNode = new ArrowFieldNode(numOfRows, nullCount) + + (Array(validity, buf), Array(fieldNode)) + + case StringType => + val validityOffset = allocator.buffer(numBytesOfBitmap(numOfRows)) + val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize) + var bytesCount = 0 + bufOffset.writeInt(bytesCount) // Start position + val validityValues = allocator.buffer(numBytesOfBitmap(numOfRows)) + val bufValues = allocator.buffer(Int.MaxValue) // TODO: Reduce the size? + var nullCount = 0 + rows.foreach { row => + if (row.isNullAt(ordinal)) { + nullCount += 1 + bufOffset.writeInt(bytesCount) + } else { + val bytes = row.getUTF8String(ordinal).getBytes + bytesCount += bytes.length + bufOffset.writeInt(bytesCount) + bufValues.writeBytes(bytes) + } + } + + val fieldNodeOffset = if (field.nullable) { + new ArrowFieldNode(numOfRows, nullCount) + } else { + new ArrowFieldNode(numOfRows, 0) + } + + val fieldNodeValues = new ArrowFieldNode(bytesCount, 0) + + (Array(validityOffset, bufOffset, validityValues, bufValues), + Array(fieldNodeOffset, fieldNodeValues)) + } + } + + /** + * Transfer an array of InternalRow to an ArrowRecordBatch. + */ + private[sql] def internalRowsToArrowRecordBatch( + rows: Array[InternalRow], allocator: RootAllocator): ArrowRecordBatch = { + val bufAndField = this.schema.fields.zipWithIndex.map { case (field, ordinal) => + internalRowToArrowBuf(rows, ordinal, field, allocator) + } + + val buffers = bufAndField.flatMap(_._1).toList.asJava + val fieldNodes = bufAndField.flatMap(_._2).toList.asJava + + new ArrowRecordBatch(rows.length, fieldNodes, buffers) + } + + /** + * Collect a Dataset to an ArrowRecordBatch. + * + * @group action + * @since 2.2.0 + */ + @DeveloperApi + def collectAsArrow(): ArrowRecordBatch = { + val allocator = new RootAllocator(Long.MaxValue) + withNewExecutionId { + try { + val collectedRows = queryExecution.executedPlan.executeCollect() + val recordBatch = internalRowsToArrowRecordBatch(collectedRows, allocator) + recordBatch + } catch { + case e: Exception => + throw e + } + } + } + + /** + * Return an iterator that contains all of [[Row]]s in this Dataset. * * The iterator will consume as much memory as the largest partition in this Dataset. * @@ -2747,6 +2933,27 @@ class Dataset[T] private[sql]( } } + /** + * Collect a Dataset as an ArrowRecordBatch, and serve the ArrowRecordBatch to PySpark. 
+ */ + private[sql] def collectAsArrowToPython(): Int = { + val recordBatch = collectAsArrow() + val arrowSchema = schemaToArrowSchema(this.schema) + val out = new ByteArrayOutputStream() + try { + val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema) + writer.writeRecordBatch(recordBatch) + writer.close() + } catch { + case e: Exception => + throw e + } + + withNewExecutionId { + PythonRDD.serveIterator(Iterator(out.toByteArray), "serve-Arrow") + } + } + private[sql] def toPythonIterator(): Int = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala new file mode 100644 index 000000000000..8aec3699c9dd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io._ +import java.net.{InetAddress, Socket} +import java.nio.{ByteBuffer, ByteOrder} +import java.nio.channels.FileChannel + +import scala.util.Random + +import io.netty.buffer.ArrowBuf +import org.apache.arrow.flatbuf.Precision +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.file.ArrowReader +import org.apache.arrow.vector.types.pojo.{ArrowType, Field} + +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + + +case class ArrowTestClass(col1: Int, col2: Double, col3: String) + +class DatasetToArrowSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + final val numElements = 4 + @transient var data: Seq[ArrowTestClass] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + data = Seq.fill(numElements)(ArrowTestClass( + Random.nextInt, Random.nextDouble, Random.nextString(Random.nextInt(100)))) + } + + test("Collect as arrow to python") { + val dataset = data.toDS() + + val port = dataset.collectAsArrowToPython() + + val receiver: RecordBatchReceiver = new RecordBatchReceiver + val (buffer, numBytesRead) = receiver.connectAndRead(port) + val channel = receiver.makeFile(buffer) + val reader = new ArrowReader(channel, receiver.allocator) + + val footer = reader.readFooter() + val schema = footer.getSchema + + val numCols = schema.getFields.size() + assert(numCols === dataset.schema.fields.length) + for (i <- 0 until schema.getFields.size()) { + val arrowField = schema.getFields.get(i) + val sparkField = dataset.schema.fields(i) + assert(arrowField.getName === sparkField.name) + assert(arrowField.isNullable === sparkField.nullable) + 
assert(DatasetToArrowSuite.compareSchemaTypes(arrowField, sparkField)) + } + + val blockMetadata = footer.getRecordBatches + assert(blockMetadata.size() === 1) + + val recordBatch = reader.readRecordBatch(blockMetadata.get(0)) + val nodes = recordBatch.getNodes + assert(nodes.size() === numCols + 1) // +1 for Type String, which has two nodes. + + val firstNode = nodes.get(0) + assert(firstNode.getLength === numElements) + assert(firstNode.getNullCount === 0) + + val buffers = recordBatch.getBuffers + assert(buffers.size() === (numCols + 1) * 2) // +1 for Type String + + assert(receiver.getIntArray(buffers.get(1)) === data.map(_.col1)) + assert(receiver.getDoubleArray(buffers.get(3)) === data.map(_.col2)) + assert(receiver.getStringArray(buffers.get(5), buffers.get(7)) === + data.map(d => UTF8String.fromString(d.col3)).toArray) + } +} + +object DatasetToArrowSuite { + def compareSchemaTypes(arrowField: Field, sparkField: StructField): Boolean = { + val arrowType = arrowField.getType + val sparkType = sparkField.dataType + (arrowType, sparkType) match { + case (_: ArrowType.Int, _: IntegerType) => true + case (_: ArrowType.FloatingPoint, _: DoubleType) => + arrowType.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.DOUBLE + case (_: ArrowType.FloatingPoint, _: FloatType) => + arrowType.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.SINGLE + case (_: ArrowType.List, _: StringType) => + val subField = arrowField.getChildren + (subField.size() == 1) && subField.get(0).getType.isInstanceOf[ArrowType.Utf8] + case (_: ArrowType.Bool, _: BooleanType) => true + case _ => false + } + } +} + +class RecordBatchReceiver { + + val allocator = new RootAllocator(Long.MaxValue) + + def getIntArray(buf: ArrowBuf): Array[Int] = { + val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer() + val resultArray = Array.ofDim[Int](buffer.remaining()) + buffer.get(resultArray) + resultArray + } + + def getDoubleArray(buf: ArrowBuf): Array[Double] = { + val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer() + val resultArray = Array.ofDim[Double](buffer.remaining()) + buffer.get(resultArray) + resultArray + } + + def getStringArray(bufOffsets: ArrowBuf, bufValues: ArrowBuf): Array[UTF8String] = { + val offsets = getIntArray(bufOffsets) + val lens = offsets.zip(offsets.drop(1)) + .map { case (prevOffset, offset) => offset - prevOffset } + + val values = array(bufValues) + val strings = offsets.zip(lens).map { case (offset, len) => + UTF8String.fromBytes(values, offset, len) + } + strings + } + + private def array(buf: ArrowBuf): Array[Byte] = { + val bytes = Array.ofDim[Byte](buf.readableBytes()) + buf.readBytes(bytes) + bytes + } + + def connectAndRead(port: Int): (Array[Byte], Int) = { + val clientSocket = new Socket(InetAddress.getByName("localhost"), port) + val clientDataIns = new DataInputStream(clientSocket.getInputStream) + val messageLength = clientDataIns.readInt() + val buffer = Array.ofDim[Byte](messageLength) + clientDataIns.readFully(buffer, 0, messageLength) + (buffer, messageLength) + } + + def makeFile(buffer: Array[Byte]): FileChannel = { + val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName).getPath + val arrowFile = new File(tempDir, "arrow-bytes") + val arrowOus = new FileOutputStream(arrowFile.getPath) + arrowOus.write(buffer) + arrowOus.close() + + val arrowIns = new FileInputStream(arrowFile.getPath) + arrowIns.getChannel + } +} From afd57398451d58aa37e92e2b5842e263c1e0705e Mon Sep 17 
00:00:00 2001 From: Li Jin Date: Mon, 12 Dec 2016 16:50:51 -0500 Subject: [PATCH 02/56] Test suite prototyping for collectAsArrow Changed scope of arrow-tools dependency to test commented out lines to Integration.compareXX that are private to arrow closes #10 --- pom.xml | 20 +++++++ sql/core/pom.xml | 5 ++ .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../resources/test-data/arrowNullInts.json | 31 ++++++++++ .../org/apache/spark/sql/ArrowSuite.scala | 58 +++++++++++++++++++ 5 files changed, 116 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/arrowNullInts.json create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala diff --git a/pom.xml b/pom.xml index 17f6e5492871..da6202785021 100644 --- a/pom.xml +++ b/pom.xml @@ -1891,6 +1891,26 @@ + + org.apache.arrow + arrow-tools + ${arrow.version} + test + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-databind + + + org.slf4j + log4j-over-slf4j + + + diff --git a/sql/core/pom.xml b/sql/core/pom.xml index cebecf7dff83..77e16c76c9dc 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -107,6 +107,11 @@ org.apache.arrow arrow-vector + + org.apache.arrow + arrow-tools + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a980291ad5e9..6a1382c83b0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2534,8 +2534,8 @@ class Dataset[T] private[sql]( * @since 2.2.0 */ @DeveloperApi - def collectAsArrow(): ArrowRecordBatch = { - val allocator = new RootAllocator(Long.MaxValue) + def collectAsArrow( + allocator: RootAllocator = new RootAllocator(Long.MaxValue)): ArrowRecordBatch = { withNewExecutionId { try { val collectedRows = queryExecution.executedPlan.executeCollect() diff --git a/sql/core/src/test/resources/test-data/arrowNullInts.json b/sql/core/src/test/resources/test-data/arrowNullInts.json new file mode 100644 index 000000000000..31b272af7d12 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrowNullInts.json @@ -0,0 +1,31 @@ +{ + "schema": { + "fields": [ + { + "name": "a", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + "batches": [ + { + "count": 4, + "columns": [ + { + "name": "a", + "count": 4, + "VALIDITY": [1, 1, 1, 0], + "DATA": [1, 2, 3, 0] + } + ] + } + ] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala new file mode 100644 index 000000000000..123b0f56fb47 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import java.io.File + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.tools.Integration +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} +import org.apache.arrow.vector.file.json.JsonFileReader + +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} + +class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + import testImplicits._ + private val nullIntsFile = "test-data/arrowNullInts.json" + + private def testFile(fileName: String): String = { + // TODO: Copied from CSVSuite, find a better way to read test files + Thread.currentThread().getContextClassLoader.getResource(fileName).toString.substring(5) + } + + test("convert int column with null to arrow") { + val df = nullInts + val jsonFilePath = testFile(nullIntsFile) + + val allocator = new RootAllocator(Integer.MAX_VALUE) + val jsonReader = new JsonFileReader(new File(jsonFilePath), allocator) + + val arrowSchema = df.schemaToArrowSchema(df.schema) + val jsonSchema = jsonReader.start() + // TODO - requires changing to public API in arrow, will be addressed in ARROW-411 + //Integration.compareSchemas(arrowSchema, jsonSchema) + + val arrowRecordBatch = df.collectAsArrow(allocator) + val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator) + val vectorLoader = new VectorLoader(arrowRoot) + vectorLoader.load(arrowRecordBatch) + val jsonRoot = jsonReader.read() + + // TODO - requires changing to public API in arrow, will be addressed in ARROW-411 + //Integration.compare(arrowRoot, jsonRoot) + } +} From a4b958e6b149c0734f9c70de2defd8807a3e4972 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 5 Jan 2017 17:21:19 -0500 Subject: [PATCH 03/56] Test compiling against the newest arrow; Fix validity map; Add benchmark script Remove arrow-tools dependency changed zipWithIndex to while loop modified benchmark to work with Python2 timeit closes #13 --- benchmark.py | 41 +++++++++++++++++++ bin/pyspark | 2 +- pom.xml | 20 --------- sql/core/pom.xml | 5 --- .../scala/org/apache/spark/sql/Dataset.scala | 40 ++++++++++++++++-- .../org/apache/spark/sql/ArrowSuite.scala | 11 ++--- 6 files changed, 82 insertions(+), 37 deletions(-) create mode 100644 benchmark.py diff --git a/benchmark.py b/benchmark.py new file mode 100644 index 000000000000..f6e7c0ae8b2b --- /dev/null +++ b/benchmark.py @@ -0,0 +1,41 @@ +import pyspark +import timeit +import random +from pyspark.sql import SparkSession + +numPartition = 8 + +def time(df, repeat, number): + print("toPandas with arrow") + print(timeit.repeat(lambda: df.toPandas(True), repeat=repeat, number=number)) + + print("toPandas without arrow") + print(timeit.repeat(lambda: df.toPandas(False), repeat=repeat, number=number)) + +def long(): + return random.randint(0, 10000) + +def double(): + return random.random() + +def genDataLocal(spark, size, columns): + data = [list([fn() for fn in columns]) for x in range(0, size)] + df = spark.createDataFrame(data) + return df + +def genData(spark, size, columns): + rdd = spark.sparkContext\ + .parallelize(range(0, size), numPartition)\ + .map(lambda _: 
[fn() for fn in columns]) + df = spark.createDataFrame(rdd) + return df + +if __name__ == "__main__": + spark = SparkSession.builder.appName("ArrowBenchmark").getOrCreate() + df = genData(spark, 1000 * 1000, [long, double]) + df.cache() + df.count() + + time(df, 10, 1) + + df.unpersist() diff --git a/bin/pyspark b/bin/pyspark index 98387c2ec5b8..8eeea7716cc9 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$1" + exec "$PYSPARK_DRIVER_PYTHON" -m "$@" exit fi diff --git a/pom.xml b/pom.xml index da6202785021..17f6e5492871 100644 --- a/pom.xml +++ b/pom.xml @@ -1891,26 +1891,6 @@ - - org.apache.arrow - arrow-tools - ${arrow.version} - test - - - com.fasterxml.jackson.core - jackson-annotations - - - com.fasterxml.jackson.core - jackson-databind - - - org.slf4j - log4j-over-slf4j - - - diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 77e16c76c9dc..cebecf7dff83 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -107,11 +107,6 @@ org.apache.arrow arrow-vector - - org.apache.arrow - arrow-tools - test - org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6a1382c83b0d..04de001bf5ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -29,6 +29,7 @@ import scala.util.control.NonFatal import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.BitVector import org.apache.arrow.vector.file.ArrowWriter import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} import org.apache.arrow.vector.types.FloatingPointPrecision @@ -63,7 +64,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils - private[sql] object Dataset { def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) @@ -2424,6 +2424,29 @@ class Dataset[T] private[sql]( Math.ceil(numOfRows / 64.0).toInt * 8 } + private def fillArrow(buf: ArrowBuf, dataType: DataType): Unit = { + dataType match { + case NullType => + case BooleanType => + buf.writeBoolean(false) + case ShortType => + buf.writeShort(0) + case IntegerType => + buf.writeInt(0) + case LongType => + buf.writeLong(0L) + case FloatType => + buf.writeFloat(0f) + case DoubleType => + buf.writeDouble(0d) + case ByteType => + buf.writeByte(0) + case _ => + throw new UnsupportedOperationException( + s"Unsupported data type ${dataType.simpleString}") + } + } + /** * Get an entry from the InternalRow, and then set to ArrowBuf. * Note: No Null check for the entry. 
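// --- Illustrative sketch (not part of the patch): how the validity-bitmap sizing in
// numBytesOfBitmap and the per-row null counting above fit together. The Arrow buffer is
// replaced by a plain Array[Byte] so the example is self-contained; all names here are
// hypothetical and independent of the Arrow API.
object ValidityBitmapSketch {
  // Same rule as numBytesOfBitmap: round the row count up to a multiple of 64 bits
  // (8 bytes), as the Arrow layout requires for null bitmaps.
  def numBytesOfBitmap(numOfRows: Int): Int = math.ceil(numOfRows / 64.0).toInt * 8

  // Build a least-significant-bit-first validity bitmap and count the nulls, mirroring
  // the nullCount bookkeeping done while copying InternalRow values into an ArrowBuf.
  def buildValidity(values: Seq[Option[Long]]): (Array[Byte], Int) = {
    val bitmap = new Array[Byte](numBytesOfBitmap(values.length))
    var nullCount = 0
    values.zipWithIndex.foreach {
      case (Some(_), i) => bitmap(i / 8) = (bitmap(i / 8) | (1 << (i % 8))).toByte
      case (None, _) => nullCount += 1
    }
    (bitmap, nullCount)
  }

  def main(args: Array[String]): Unit = {
    val (bitmap, nulls) = buildValidity(Seq(Some(1L), None, Some(3L)))
    println(s"bitmap bytes = ${bitmap.length}, nulls = $nulls") // bitmap bytes = 8, nulls = 1
  }
}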
@@ -2464,20 +2487,29 @@ class Dataset[T] private[sql]( field.dataType match { case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType => - val validity = allocator.buffer(numBytesOfBitmap(numOfRows)) + val validityVector = new BitVector("validity", allocator) + val validityMutator = validityVector.getMutator + validityVector.allocateNew(numOfRows) + validityMutator.setValueCount(numOfRows) val buf = allocator.buffer(numOfRows * field.dataType.defaultSize) var nullCount = 0 - rows.foreach { row => + var index = 0 + while (index < rows.length) { + val row = rows(index) if (row.isNullAt(ordinal)) { nullCount += 1 + validityMutator.set(index, 0) + fillArrow(buf, field.dataType) } else { + validityMutator.set(index, 1) getAndSetToArrow(row, buf, field.dataType, ordinal) } + index += 1 } val fieldNode = new ArrowFieldNode(numOfRows, nullCount) - (Array(validity, buf), Array(fieldNode)) + (Array(validityVector.getBuffer, buf), Array(fieldNode)) case StringType => val validityOffset = allocator.buffer(numBytesOfBitmap(numOfRows)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala index 123b0f56fb47..9b1786c83f16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql import java.io.File import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.tools.Integration -import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} +import org.apache.arrow.vector.{BitVector, VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.file.json.JsonFileReader +import org.apache.arrow.vector.util.Validator import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils { - import testImplicits._ private val nullIntsFile = "test-data/arrowNullInts.json" private def testFile(fileName: String): String = { @@ -43,8 +42,7 @@ class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val arrowSchema = df.schemaToArrowSchema(df.schema) val jsonSchema = jsonReader.start() - // TODO - requires changing to public API in arrow, will be addressed in ARROW-411 - //Integration.compareSchemas(arrowSchema, jsonSchema) + Validator.compareSchemas(arrowSchema, jsonSchema) val arrowRecordBatch = df.collectAsArrow(allocator) val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator) @@ -52,7 +50,6 @@ class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils { vectorLoader.load(arrowRecordBatch) val jsonRoot = jsonReader.read() - // TODO - requires changing to public API in arrow, will be addressed in ARROW-411 - //Integration.compare(arrowRoot, jsonRoot) + Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) } } From be508a587e25d51aaa755bd6c6e74795b5287645 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 11 Jan 2017 23:38:15 -0500 Subject: [PATCH 04/56] Fix conversion for String type; refactor related functions to Arrow.scala changed tests to use existing SQLTestData and removed unused files closes #14 --- .../scala/org/apache/spark/sql/Arrow.scala | 228 ++++++++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 200 +-------------- .../resources/test-data/arrowNullInts.json | 1 + .../resources/test-data/arrowNullStrings.json | 34 +++ .../org/apache/spark/sql/ArrowSuite.scala | 24 +- 5 files changed, 282 insertions(+), 205 deletions(-) create mode 
100644 sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala create mode 100644 sql/core/src/test/resources/test-data/arrowNullStrings.json diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala new file mode 100644 index 000000000000..31b90259de0e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala @@ -0,0 +1,228 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions + +import io.netty.buffer.ArrowBuf +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.BitVector +import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} +import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +object Arrow { + + /** + * Compute the number of bytes needed to build validity map. According to + * [Arrow Layout](https://github.com/apache/arrow/blob/master/format/Layout.md#null-bitmaps), + * the length of the validity bitmap should be multiples of 64 bytes. + */ + private def numBytesOfBitmap(numOfRows: Int): Int = { + Math.ceil(numOfRows / 64.0).toInt * 8 + } + + private def fillArrow(buf: ArrowBuf, dataType: DataType): Unit = { + dataType match { + case NullType => + case BooleanType => + buf.writeBoolean(false) + case ShortType => + buf.writeShort(0) + case IntegerType => + buf.writeInt(0) + case LongType => + buf.writeLong(0L) + case FloatType => + buf.writeFloat(0f) + case DoubleType => + buf.writeDouble(0d) + case ByteType => + buf.writeByte(0) + case _ => + throw new UnsupportedOperationException( + s"Unsupported data type ${dataType.simpleString}") + } + } + + /** + * Get an entry from the InternalRow, and then set to ArrowBuf. + * Note: No Null check for the entry. + */ + private def getAndSetToArrow( + row: InternalRow, + buf: ArrowBuf, + dataType: DataType, + ordinal: Int): Unit = { + dataType match { + case NullType => + case BooleanType => + buf.writeBoolean(row.getBoolean(ordinal)) + case ShortType => + buf.writeShort(row.getShort(ordinal)) + case IntegerType => + buf.writeInt(row.getInt(ordinal)) + case LongType => + buf.writeLong(row.getLong(ordinal)) + case FloatType => + buf.writeFloat(row.getFloat(ordinal)) + case DoubleType => + buf.writeDouble(row.getDouble(ordinal)) + case ByteType => + buf.writeByte(row.getByte(ordinal)) + case _ => + throw new UnsupportedOperationException( + s"Unsupported data type ${dataType.simpleString}") + } + } + + /** + * Transfer an array of InternalRow to an ArrowRecordBatch. 
+ */ + def internalRowsToArrowRecordBatch( + rows: Array[InternalRow], + schema: StructType, + allocator: RootAllocator): ArrowRecordBatch = { + val bufAndField = schema.fields.zipWithIndex.map { case (field, ordinal) => + internalRowToArrowBuf(rows, ordinal, field, allocator) + } + + val buffers = bufAndField.flatMap(_._1).toList.asJava + val fieldNodes = bufAndField.flatMap(_._2).toList.asJava + + new ArrowRecordBatch(rows.length, fieldNodes, buffers) + } + + /** + * Convert an array of InternalRow to an ArrowBuf. + */ + def internalRowToArrowBuf( + rows: Array[InternalRow], + ordinal: Int, + field: StructField, + allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = { + val numOfRows = rows.length + + field.dataType match { + case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType => + val validityVector = new BitVector("validity", allocator) + val validityMutator = validityVector.getMutator + validityVector.allocateNew(numOfRows) + validityMutator.setValueCount(numOfRows) + + val buf = allocator.buffer(numOfRows * field.dataType.defaultSize) + var nullCount = 0 + var index = 0 + while (index < rows.length) { + val row = rows(index) + if (row.isNullAt(ordinal)) { + nullCount += 1 + validityMutator.set(index, 0) + fillArrow(buf, field.dataType) + } else { + validityMutator.set(index, 1) + getAndSetToArrow(row, buf, field.dataType, ordinal) + } + index += 1 + } + + val fieldNode = new ArrowFieldNode(numOfRows, nullCount) + + (Array(validityVector.getBuffer, buf), Array(fieldNode)) + + case StringType => + val validityVector = new BitVector("validity", allocator) + val validityMutator = validityVector.getMutator() + validityVector.allocateNew(numOfRows) + validityMutator.setValueCount(numOfRows) + + val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize) + var bytesCount = 0 + bufOffset.writeInt(bytesCount) + val bufValues = allocator.buffer(1024) + var nullCount = 0 + rows.zipWithIndex.foreach { case (row, index) => + if (row.isNullAt(ordinal)) { + nullCount += 1 + validityMutator.set(index, 0) + bufOffset.writeInt(bytesCount) + } else { + validityMutator.set(index, 1) + val bytes = row.getUTF8String(ordinal).getBytes + bytesCount += bytes.length + bufOffset.writeInt(bytesCount) + bufValues.writeBytes(bytes) + } + } + + val fieldNode = new ArrowFieldNode(numOfRows, nullCount) + + (Array(validityVector.getBuffer, bufOffset, bufValues), + Array(fieldNode)) + } + } + + private[sql] def schemaToArrowSchema(schema: StructType): Schema = { + val arrowFields = schema.fields.map(sparkFieldToArrowField(_)) + new Schema(arrowFields.toList.asJava) + } + + private[sql] def sparkFieldToArrowField(sparkField: StructField): Field = { + val name = sparkField.name + val dataType = sparkField.dataType + val nullable = sparkField.nullable + + dataType match { + case StructType(fields) => + val childrenFields = fields.map(sparkFieldToArrowField(_)).toList.asJava + new Field(name, nullable, ArrowType.Struct.INSTANCE, childrenFields) + case _ => + new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava) + } + } + + /** + * Transform Spark DataType to Arrow ArrowType. 
+ */ + private[sql] def dataTypeToArrowType(dt: DataType): ArrowType = { + dt match { + case IntegerType => + new ArrowType.Int(8 * IntegerType.defaultSize, true) + case LongType => + new ArrowType.Int(8 * LongType.defaultSize, true) + case StringType => + ArrowType.Utf8.INSTANCE + case DoubleType => + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case FloatType => + new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case BooleanType => + ArrowType.Bool.INSTANCE + case ByteType => + new ArrowType.Int(8, false) + case StructType(_) => + ArrowType.Struct.INSTANCE + case _ => + throw new IllegalArgumentException(s"Unsupported data type") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 04de001bf5ee..6fbdb155847d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -27,13 +27,9 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.BitVector import org.apache.arrow.vector.file.ArrowWriter -import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.FloatingPointPrecision -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} +import org.apache.arrow.vector.schema.ArrowRecordBatch import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} @@ -2370,195 +2366,6 @@ class Dataset[T] private[sql]( java.util.Arrays.asList(values : _*) } - /** - * Transform Spark DataType to Arrow ArrowType. - */ - private[sql] def dataTypeToArrowType(dt: DataType): ArrowType = { - dt match { - case IntegerType => - new ArrowType.Int(8 * IntegerType.defaultSize, true) - case LongType => - new ArrowType.Int(8 * LongType.defaultSize, true) - case StringType => - ArrowType.List.INSTANCE - case DoubleType => - new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case FloatType => - new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case BooleanType => - ArrowType.Bool.INSTANCE - case ByteType => - new ArrowType.Int(8, false) - case _ => - throw new IllegalArgumentException(s"Unsupported data type") - } - } - - /** - * Transform Spark StructType to Arrow Schema. - */ - private[sql] def schemaToArrowSchema(schema: StructType): Schema = { - val arrowFields = schema.fields.map { - case StructField(name, dataType, nullable, metadata) => - dataType match { - // TODO: Consider other nested types - case StringType => - // TODO: Make sure String => List - val itemField = - new Field("item", false, ArrowType.Utf8.INSTANCE, List.empty[Field].asJava) - new Field(name, nullable, dataTypeToArrowType(dataType), List(itemField).asJava) - case _ => - new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava) - } - } - val arrowSchema = new Schema(arrowFields.toIterable.asJava) - arrowSchema - } - - /** - * Compute the number of bytes needed to build validity map. According to - * [Arrow Layout](https://github.com/apache/arrow/blob/master/format/Layout.md#null-bitmaps), - * the length of the validity bitmap should be multiples of 64 bytes. 
- */ - private def numBytesOfBitmap(numOfRows: Int): Int = { - Math.ceil(numOfRows / 64.0).toInt * 8 - } - - private def fillArrow(buf: ArrowBuf, dataType: DataType): Unit = { - dataType match { - case NullType => - case BooleanType => - buf.writeBoolean(false) - case ShortType => - buf.writeShort(0) - case IntegerType => - buf.writeInt(0) - case LongType => - buf.writeLong(0L) - case FloatType => - buf.writeFloat(0f) - case DoubleType => - buf.writeDouble(0d) - case ByteType => - buf.writeByte(0) - case _ => - throw new UnsupportedOperationException( - s"Unsupported data type ${dataType.simpleString}") - } - } - - /** - * Get an entry from the InternalRow, and then set to ArrowBuf. - * Note: No Null check for the entry. - */ - private def getAndSetToArrow( - row: InternalRow, buf: ArrowBuf, dataType: DataType, ordinal: Int): Unit = { - dataType match { - case NullType => - case BooleanType => - buf.writeBoolean(row.getBoolean(ordinal)) - case ShortType => - buf.writeShort(row.getShort(ordinal)) - case IntegerType => - buf.writeInt(row.getInt(ordinal)) - case LongType => - buf.writeLong(row.getLong(ordinal)) - case FloatType => - buf.writeFloat(row.getFloat(ordinal)) - case DoubleType => - buf.writeDouble(row.getDouble(ordinal)) - case ByteType => - buf.writeByte(row.getByte(ordinal)) - case _ => - throw new UnsupportedOperationException( - s"Unsupported data type ${dataType.simpleString}") - } - } - - /** - * Convert an array of InternalRow to an ArrowBuf. - */ - private def internalRowToArrowBuf( - rows: Array[InternalRow], - ordinal: Int, - field: StructField, - allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = { - val numOfRows = rows.length - - field.dataType match { - case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType => - val validityVector = new BitVector("validity", allocator) - val validityMutator = validityVector.getMutator - validityVector.allocateNew(numOfRows) - validityMutator.setValueCount(numOfRows) - val buf = allocator.buffer(numOfRows * field.dataType.defaultSize) - var nullCount = 0 - var index = 0 - while (index < rows.length) { - val row = rows(index) - if (row.isNullAt(ordinal)) { - nullCount += 1 - validityMutator.set(index, 0) - fillArrow(buf, field.dataType) - } else { - validityMutator.set(index, 1) - getAndSetToArrow(row, buf, field.dataType, ordinal) - } - index += 1 - } - - val fieldNode = new ArrowFieldNode(numOfRows, nullCount) - - (Array(validityVector.getBuffer, buf), Array(fieldNode)) - - case StringType => - val validityOffset = allocator.buffer(numBytesOfBitmap(numOfRows)) - val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize) - var bytesCount = 0 - bufOffset.writeInt(bytesCount) // Start position - val validityValues = allocator.buffer(numBytesOfBitmap(numOfRows)) - val bufValues = allocator.buffer(Int.MaxValue) // TODO: Reduce the size? 
- var nullCount = 0 - rows.foreach { row => - if (row.isNullAt(ordinal)) { - nullCount += 1 - bufOffset.writeInt(bytesCount) - } else { - val bytes = row.getUTF8String(ordinal).getBytes - bytesCount += bytes.length - bufOffset.writeInt(bytesCount) - bufValues.writeBytes(bytes) - } - } - - val fieldNodeOffset = if (field.nullable) { - new ArrowFieldNode(numOfRows, nullCount) - } else { - new ArrowFieldNode(numOfRows, 0) - } - - val fieldNodeValues = new ArrowFieldNode(bytesCount, 0) - - (Array(validityOffset, bufOffset, validityValues, bufValues), - Array(fieldNodeOffset, fieldNodeValues)) - } - } - - /** - * Transfer an array of InternalRow to an ArrowRecordBatch. - */ - private[sql] def internalRowsToArrowRecordBatch( - rows: Array[InternalRow], allocator: RootAllocator): ArrowRecordBatch = { - val bufAndField = this.schema.fields.zipWithIndex.map { case (field, ordinal) => - internalRowToArrowBuf(rows, ordinal, field, allocator) - } - - val buffers = bufAndField.flatMap(_._1).toList.asJava - val fieldNodes = bufAndField.flatMap(_._2).toList.asJava - - new ArrowRecordBatch(rows.length, fieldNodes, buffers) - } - /** * Collect a Dataset to an ArrowRecordBatch. * @@ -2571,7 +2378,8 @@ class Dataset[T] private[sql]( withNewExecutionId { try { val collectedRows = queryExecution.executedPlan.executeCollect() - val recordBatch = internalRowsToArrowRecordBatch(collectedRows, allocator) + val recordBatch = Arrow.internalRowsToArrowRecordBatch( + collectedRows, this.schema, allocator) recordBatch } catch { case e: Exception => @@ -2970,7 +2778,7 @@ class Dataset[T] private[sql]( */ private[sql] def collectAsArrowToPython(): Int = { val recordBatch = collectAsArrow() - val arrowSchema = schemaToArrowSchema(this.schema) + val arrowSchema = Arrow.schemaToArrowSchema(this.schema) val out = new ByteArrayOutputStream() try { val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema) diff --git a/sql/core/src/test/resources/test-data/arrowNullInts.json b/sql/core/src/test/resources/test-data/arrowNullInts.json index 31b272af7d12..1a2447abdc0b 100644 --- a/sql/core/src/test/resources/test-data/arrowNullInts.json +++ b/sql/core/src/test/resources/test-data/arrowNullInts.json @@ -15,6 +15,7 @@ } ] }, + "batches": [ { "count": 4, diff --git a/sql/core/src/test/resources/test-data/arrowNullStrings.json b/sql/core/src/test/resources/test-data/arrowNullStrings.json new file mode 100644 index 000000000000..c93e1e757bc5 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrowNullStrings.json @@ -0,0 +1,34 @@ +{ + "schema": { + "fields": [ + { + "name": "s", + "type": {"name": "utf8"}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "OFFSET", "typeBitWidth": 32}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + } + ] + }, + + "batches": [ + { + "count": 3, + "columns": [ + { + "name": "s", + "count": 3, + "VALIDITY": [1, 1, 0], + "OFFSET": [0, 3, 6, 6], + "DATA": ["abc", "ABC", ""] + } + ] + } + ] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala index 9b1786c83f16..036a367cb0db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala @@ -19,28 +19,34 @@ package org.apache.spark.sql import java.io.File import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.{BitVector, VectorLoader, VectorSchemaRoot} +import 
org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.test.SharedSQLContext -class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils { - private val nullIntsFile = "test-data/arrowNullInts.json" +class ArrowSuite extends SharedSQLContext { private def testFile(fileName: String): String = { - // TODO: Copied from CSVSuite, find a better way to read test files - Thread.currentThread().getContextClassLoader.getResource(fileName).toString.substring(5) + Thread.currentThread().getContextClassLoader.getResource(fileName).getFile } test("convert int column with null to arrow") { - val df = nullInts - val jsonFilePath = testFile(nullIntsFile) + testCollect(nullInts, "test-data/arrowNullInts.json") + } + + test("convert string column with null to arrow") { + val nullStringsColOnly = nullStrings.select(nullStrings.columns(1)) + testCollect(nullStringsColOnly, "test-data/arrowNullStrings.json") + } + + private def testCollect(df: DataFrame, arrowFile: String) { + val jsonFilePath = testFile(arrowFile) val allocator = new RootAllocator(Integer.MAX_VALUE) val jsonReader = new JsonFileReader(new File(jsonFilePath), allocator) - val arrowSchema = df.schemaToArrowSchema(df.schema) + val arrowSchema = Arrow.schemaToArrowSchema(df.schema) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) From 5dbad2241f318ce4926d2dc7446dbd81e092d8a3 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 12 Jan 2017 14:40:32 -0800 Subject: [PATCH 05/56] Moved test data files to a sub-dir for arrow, merged dataType matching and cleanup closes #15 --- .../scala/org/apache/spark/sql/Arrow.scala | 137 ++++++-------- .../null-ints.json} | 0 .../null-strings.json} | 0 .../org/apache/spark/sql/ArrowSuite.scala | 5 +- .../spark/sql/DatasetToArrowSuite.scala | 172 ------------------ 5 files changed, 61 insertions(+), 253 deletions(-) rename sql/core/src/test/resources/test-data/{arrowNullInts.json => arrow/null-ints.json} (100%) rename sql/core/src/test/resources/test-data/{arrowNullStrings.json => arrow/null-strings.json} (100%) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala index 31b90259de0e..1b68a9d0429c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala @@ -32,66 +32,70 @@ import org.apache.spark.sql.types._ object Arrow { - /** - * Compute the number of bytes needed to build validity map. According to - * [Arrow Layout](https://github.com/apache/arrow/blob/master/format/Layout.md#null-bitmaps), - * the length of the validity bitmap should be multiples of 64 bytes. 
- */ - private def numBytesOfBitmap(numOfRows: Int): Int = { - Math.ceil(numOfRows / 64.0).toInt * 8 - } + private case class TypeFuncs(getType: () => ArrowType, + fill: ArrowBuf => Unit, + write: (InternalRow, Int, ArrowBuf) => Unit) - private def fillArrow(buf: ArrowBuf, dataType: DataType): Unit = { - dataType match { - case NullType => - case BooleanType => - buf.writeBoolean(false) - case ShortType => - buf.writeShort(0) - case IntegerType => - buf.writeInt(0) - case LongType => - buf.writeLong(0L) - case FloatType => - buf.writeFloat(0f) - case DoubleType => - buf.writeDouble(0d) - case ByteType => - buf.writeByte(0) - case _ => - throw new UnsupportedOperationException( - s"Unsupported data type ${dataType.simpleString}") - } - } + private def getTypeFuncs(dataType: DataType): TypeFuncs = { + val err = s"Unsupported data type ${dataType.simpleString}" - /** - * Get an entry from the InternalRow, and then set to ArrowBuf. - * Note: No Null check for the entry. - */ - private def getAndSetToArrow( - row: InternalRow, - buf: ArrowBuf, - dataType: DataType, - ordinal: Int): Unit = { dataType match { case NullType => + TypeFuncs( + () => ArrowType.Null.INSTANCE, + (buf: ArrowBuf) => (), + (row: InternalRow, ordinal: Int, buf: ArrowBuf) => ()) case BooleanType => - buf.writeBoolean(row.getBoolean(ordinal)) + TypeFuncs( + () => ArrowType.Bool.INSTANCE, + (buf: ArrowBuf) => buf.writeBoolean(false), + (row: InternalRow, ordinal: Int, buf: ArrowBuf) => + buf.writeBoolean(row.getBoolean(ordinal))) case ShortType => - buf.writeShort(row.getShort(ordinal)) + TypeFuncs( + () => new ArrowType.Int(4 * ShortType.defaultSize, true), // TODO - check on this + (buf: ArrowBuf) => buf.writeShort(0), + (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeShort(row.getShort(ordinal))) case IntegerType => - buf.writeInt(row.getInt(ordinal)) + TypeFuncs( + () => new ArrowType.Int(8 * IntegerType.defaultSize, true), + (buf: ArrowBuf) => buf.writeInt(0), + (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeInt(row.getInt(ordinal))) case LongType => - buf.writeLong(row.getLong(ordinal)) + TypeFuncs( + () => new ArrowType.Int(8 * LongType.defaultSize, true), + (buf: ArrowBuf) => buf.writeLong(0L), + (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeLong(row.getLong(ordinal))) case FloatType => - buf.writeFloat(row.getFloat(ordinal)) + TypeFuncs( + () => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), + (buf: ArrowBuf) => buf.writeFloat(0f), + (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeFloat(row.getFloat(ordinal))) case DoubleType => - buf.writeDouble(row.getDouble(ordinal)) + TypeFuncs( + () => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), + (buf: ArrowBuf) => buf.writeDouble(0d), + (row: InternalRow, ordinal: Int, buf: ArrowBuf) => + buf.writeDouble(row.getDouble(ordinal))) case ByteType => - buf.writeByte(row.getByte(ordinal)) + TypeFuncs( + () => new ArrowType.Int(8, false), + (buf: ArrowBuf) => buf.writeByte(0), + (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeByte(row.getByte(ordinal))) + case StringType => + TypeFuncs( + () => ArrowType.Utf8.INSTANCE, + (buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO + (row: InternalRow, ordinal: Int, buf: ArrowBuf) => + throw new UnsupportedOperationException(err)) + case StructType(_) => + TypeFuncs( + () => ArrowType.Struct.INSTANCE, + (buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO + (row: InternalRow, ordinal: Int, buf: ArrowBuf) => + 
throw new UnsupportedOperationException(err)) case _ => - throw new UnsupportedOperationException( - s"Unsupported data type ${dataType.simpleString}") + throw new IllegalArgumentException(err) } } @@ -130,6 +134,7 @@ object Arrow { validityMutator.setValueCount(numOfRows) val buf = allocator.buffer(numOfRows * field.dataType.defaultSize) + val typeFunc = getTypeFuncs(field.dataType) var nullCount = 0 var index = 0 while (index < rows.length) { @@ -137,10 +142,10 @@ object Arrow { if (row.isNullAt(ordinal)) { nullCount += 1 validityMutator.set(index, 0) - fillArrow(buf, field.dataType) + typeFunc.fill(buf) } else { validityMutator.set(index, 1) - getAndSetToArrow(row, buf, field.dataType, ordinal) + typeFunc.write(row, ordinal, buf) } index += 1 } @@ -182,7 +187,7 @@ object Arrow { } private[sql] def schemaToArrowSchema(schema: StructType): Schema = { - val arrowFields = schema.fields.map(sparkFieldToArrowField(_)) + val arrowFields = schema.fields.map(sparkFieldToArrowField) new Schema(arrowFields.toList.asJava) } @@ -193,36 +198,10 @@ object Arrow { dataType match { case StructType(fields) => - val childrenFields = fields.map(sparkFieldToArrowField(_)).toList.asJava + val childrenFields = fields.map(sparkFieldToArrowField).toList.asJava new Field(name, nullable, ArrowType.Struct.INSTANCE, childrenFields) case _ => - new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava) - } - } - - /** - * Transform Spark DataType to Arrow ArrowType. - */ - private[sql] def dataTypeToArrowType(dt: DataType): ArrowType = { - dt match { - case IntegerType => - new ArrowType.Int(8 * IntegerType.defaultSize, true) - case LongType => - new ArrowType.Int(8 * LongType.defaultSize, true) - case StringType => - ArrowType.Utf8.INSTANCE - case DoubleType => - new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case FloatType => - new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case BooleanType => - ArrowType.Bool.INSTANCE - case ByteType => - new ArrowType.Int(8, false) - case StructType(_) => - ArrowType.Struct.INSTANCE - case _ => - throw new IllegalArgumentException(s"Unsupported data type") + new Field(name, nullable, getTypeFuncs(dataType).getType(), List.empty[Field].asJava) } } } diff --git a/sql/core/src/test/resources/test-data/arrowNullInts.json b/sql/core/src/test/resources/test-data/arrow/null-ints.json similarity index 100% rename from sql/core/src/test/resources/test-data/arrowNullInts.json rename to sql/core/src/test/resources/test-data/arrow/null-ints.json diff --git a/sql/core/src/test/resources/test-data/arrowNullStrings.json b/sql/core/src/test/resources/test-data/arrow/null-strings.json similarity index 100% rename from sql/core/src/test/resources/test-data/arrowNullStrings.json rename to sql/core/src/test/resources/test-data/arrow/null-strings.json diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala index 036a367cb0db..ff65f13151ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala @@ -32,14 +32,15 @@ class ArrowSuite extends SharedSQLContext { } test("convert int column with null to arrow") { - testCollect(nullInts, "test-data/arrowNullInts.json") + testCollect(nullInts, "test-data/arrow/null-ints.json") } test("convert string column with null to arrow") { val nullStringsColOnly = nullStrings.select(nullStrings.columns(1)) - testCollect(nullStringsColOnly, 
"test-data/arrowNullStrings.json") + testCollect(nullStringsColOnly, "test-data/arrow/null-strings.json") } + /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ private def testCollect(df: DataFrame, arrowFile: String) { val jsonFilePath = testFile(arrowFile) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala deleted file mode 100644 index 8aec3699c9dd..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.io._ -import java.net.{InetAddress, Socket} -import java.nio.{ByteBuffer, ByteOrder} -import java.nio.channels.FileChannel - -import scala.util.Random - -import io.netty.buffer.ArrowBuf -import org.apache.arrow.flatbuf.Precision -import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.file.ArrowReader -import org.apache.arrow.vector.types.pojo.{ArrowType, Field} - -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.Utils - - -case class ArrowTestClass(col1: Int, col2: Double, col3: String) - -class DatasetToArrowSuite extends QueryTest with SharedSQLContext { - - import testImplicits._ - - final val numElements = 4 - @transient var data: Seq[ArrowTestClass] = _ - - override def beforeAll(): Unit = { - super.beforeAll() - data = Seq.fill(numElements)(ArrowTestClass( - Random.nextInt, Random.nextDouble, Random.nextString(Random.nextInt(100)))) - } - - test("Collect as arrow to python") { - val dataset = data.toDS() - - val port = dataset.collectAsArrowToPython() - - val receiver: RecordBatchReceiver = new RecordBatchReceiver - val (buffer, numBytesRead) = receiver.connectAndRead(port) - val channel = receiver.makeFile(buffer) - val reader = new ArrowReader(channel, receiver.allocator) - - val footer = reader.readFooter() - val schema = footer.getSchema - - val numCols = schema.getFields.size() - assert(numCols === dataset.schema.fields.length) - for (i <- 0 until schema.getFields.size()) { - val arrowField = schema.getFields.get(i) - val sparkField = dataset.schema.fields(i) - assert(arrowField.getName === sparkField.name) - assert(arrowField.isNullable === sparkField.nullable) - assert(DatasetToArrowSuite.compareSchemaTypes(arrowField, sparkField)) - } - - val blockMetadata = footer.getRecordBatches - assert(blockMetadata.size() === 1) - - val recordBatch = reader.readRecordBatch(blockMetadata.get(0)) - val nodes = recordBatch.getNodes - assert(nodes.size() === numCols + 1) // +1 for Type String, 
which has two nodes. - - val firstNode = nodes.get(0) - assert(firstNode.getLength === numElements) - assert(firstNode.getNullCount === 0) - - val buffers = recordBatch.getBuffers - assert(buffers.size() === (numCols + 1) * 2) // +1 for Type String - - assert(receiver.getIntArray(buffers.get(1)) === data.map(_.col1)) - assert(receiver.getDoubleArray(buffers.get(3)) === data.map(_.col2)) - assert(receiver.getStringArray(buffers.get(5), buffers.get(7)) === - data.map(d => UTF8String.fromString(d.col3)).toArray) - } -} - -object DatasetToArrowSuite { - def compareSchemaTypes(arrowField: Field, sparkField: StructField): Boolean = { - val arrowType = arrowField.getType - val sparkType = sparkField.dataType - (arrowType, sparkType) match { - case (_: ArrowType.Int, _: IntegerType) => true - case (_: ArrowType.FloatingPoint, _: DoubleType) => - arrowType.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.DOUBLE - case (_: ArrowType.FloatingPoint, _: FloatType) => - arrowType.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.SINGLE - case (_: ArrowType.List, _: StringType) => - val subField = arrowField.getChildren - (subField.size() == 1) && subField.get(0).getType.isInstanceOf[ArrowType.Utf8] - case (_: ArrowType.Bool, _: BooleanType) => true - case _ => false - } - } -} - -class RecordBatchReceiver { - - val allocator = new RootAllocator(Long.MaxValue) - - def getIntArray(buf: ArrowBuf): Array[Int] = { - val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer() - val resultArray = Array.ofDim[Int](buffer.remaining()) - buffer.get(resultArray) - resultArray - } - - def getDoubleArray(buf: ArrowBuf): Array[Double] = { - val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer() - val resultArray = Array.ofDim[Double](buffer.remaining()) - buffer.get(resultArray) - resultArray - } - - def getStringArray(bufOffsets: ArrowBuf, bufValues: ArrowBuf): Array[UTF8String] = { - val offsets = getIntArray(bufOffsets) - val lens = offsets.zip(offsets.drop(1)) - .map { case (prevOffset, offset) => offset - prevOffset } - - val values = array(bufValues) - val strings = offsets.zip(lens).map { case (offset, len) => - UTF8String.fromBytes(values, offset, len) - } - strings - } - - private def array(buf: ArrowBuf): Array[Byte] = { - val bytes = Array.ofDim[Byte](buf.readableBytes()) - buf.readBytes(bytes) - bytes - } - - def connectAndRead(port: Int): (Array[Byte], Int) = { - val clientSocket = new Socket(InetAddress.getByName("localhost"), port) - val clientDataIns = new DataInputStream(clientSocket.getInputStream) - val messageLength = clientDataIns.readInt() - val buffer = Array.ofDim[Byte](messageLength) - clientDataIns.readFully(buffer, 0, messageLength) - (buffer, messageLength) - } - - def makeFile(buffer: Array[Byte]): FileChannel = { - val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName).getPath - val arrowFile = new File(tempDir, "arrow-bytes") - val arrowOus = new FileOutputStream(arrowFile.getPath) - arrowOus.write(buffer) - arrowOus.close() - - val arrowIns = new FileInputStream(arrowFile.getPath) - arrowIns.getChannel - } -} From 5837b38e5f7a3d8e31c6f06e64c1c7139d40a46a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 13 Jan 2017 16:16:05 -0800 Subject: [PATCH 06/56] added some python unit tests added more conversion tests short type should have a bit-width of 16 closes #17 --- python/pyspark/sql/tests.py | 42 ++-- .../scala/org/apache/spark/sql/Arrow.scala | 4 +- 
.../arrow/decimalData-BigDecimal.json | 50 +++++ .../doubleData-double_precision-nullable.json | 68 ++++++ .../floatData-single_precision-nullable.json | 68 ++++++ .../test-data/arrow/indexData-ints.json | 32 +++ .../arrow/intData-32bit_ints-nullable.json | 68 ++++++ .../test-data/arrow/largeAndSmall-ints.json | 50 +++++ .../arrow/longData-64bit_ints-nullable.json | 68 ++++++ .../test-data/arrow/lowercase-strings.json | 52 +++++ .../arrow/mixedData-standard-nullable.json | 212 ++++++++++++++++++ .../test-data/arrow/null-ints-mixed.json | 50 +++++ .../test-data/arrow/salary-doubles.json | 50 +++++ .../arrow/shortData-16bit_ints-nullable.json | 68 ++++++ .../test-data/arrow/testData2-ints.json | 50 +++++ .../test-data/arrow/uppercase-strings.json | 52 +++++ .../org/apache/spark/sql/ArrowSuite.scala | 151 ++++++++++++- 17 files changed, 1117 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/arrow/decimalData-BigDecimal.json create mode 100644 sql/core/src/test/resources/test-data/arrow/doubleData-double_precision-nullable.json create mode 100644 sql/core/src/test/resources/test-data/arrow/floatData-single_precision-nullable.json create mode 100644 sql/core/src/test/resources/test-data/arrow/indexData-ints.json create mode 100644 sql/core/src/test/resources/test-data/arrow/intData-32bit_ints-nullable.json create mode 100644 sql/core/src/test/resources/test-data/arrow/largeAndSmall-ints.json create mode 100644 sql/core/src/test/resources/test-data/arrow/longData-64bit_ints-nullable.json create mode 100644 sql/core/src/test/resources/test-data/arrow/lowercase-strings.json create mode 100644 sql/core/src/test/resources/test-data/arrow/mixedData-standard-nullable.json create mode 100644 sql/core/src/test/resources/test-data/arrow/null-ints-mixed.json create mode 100644 sql/core/src/test/resources/test-data/arrow/salary-doubles.json create mode 100644 sql/core/src/test/resources/test-data/arrow/shortData-16bit_ints-nullable.json create mode 100644 sql/core/src/test/resources/test-data/arrow/testData2-ints.json create mode 100644 sql/core/src/test/resources/test-data/arrow/uppercase-strings.json diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 69984ff0e9a7..575dc3f20078 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2354,6 +2354,15 @@ class ArrowTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.spark = SparkSession(cls.sc) + cls.schema = StructType([ + StructField("str_t", StringType(), True), + StructField("int_t", IntegerType(), True), + StructField("long_t", LongType(), True), + StructField("float_t", FloatType(), True), + StructField("double_t", DoubleType(), True)]) + cls.data = [("a", 1, 10, 0.2, 2.0), + ("b", 2, 20, 0.4, 4.0), + ("c", 3, 30, 0.8, 6.0)] def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + @@ -2361,20 +2370,27 @@ def assertFramesEqual(self, df_with_arrow, df_without): ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) self.assertTrue(df_without.equals(df_with_arrow), msg=msg) - def test_arrow_toPandas(self): - schema = StructType([ - StructField("str_t", StringType(), True), # Fails in conversion - StructField("int_t", IntegerType(), True), # Fails, without is converted to int64 - StructField("long_t", LongType(), True), # Fails if nullable=False - StructField("double_t", DoubleType(), True)]) - data = [("a", 1, 10, 2.0), - ("b", 2, 20, 4.0), - ("c", 3, 30, 6.0)] + def 
test_null_conversion(self): + df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + + self.data) + pdf = df_null.toPandas(useArrow=True) + null_counts = pdf.isnull().sum().tolist() + self.assertTrue(all([c == 1 for c in null_counts])) + + def test_toPandas_arrow_toggle(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + # NOTE - toPandas(useArrow=False) will infer standard data types + df_sel = df.select("str_t", "long_t", "double_t") + pdf = df_sel.toPandas(useArrow=False) + pdf_arrow = df_sel.toPandas(useArrow=True) + self.assertFramesEqual(pdf_arrow, pdf) - df = self.spark.createDataFrame(data, schema=schema) - df = df.select("long_t", "double_t") - pdf = df.toPandas(useArrow=False) - pdf_arrow = df.toPandas(useArrow=True) + def test_pandas_round_trip(self): + import pandas as pd + data_dict = {name: [self.data[i][j] for i in range(len(self.data))] + for j, name in enumerate(self.schema.names)} + pdf = pd.DataFrame(data=data_dict) + pdf_arrow = self.spark.createDataFrame(pdf).toPandas(useArrow=True) self.assertFramesEqual(pdf_arrow, pdf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala index 1b68a9d0429c..30e79c2e3c6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala @@ -53,7 +53,7 @@ object Arrow { buf.writeBoolean(row.getBoolean(ordinal))) case ShortType => TypeFuncs( - () => new ArrowType.Int(4 * ShortType.defaultSize, true), // TODO - check on this + () => new ArrowType.Int(8 * ShortType.defaultSize, true), (buf: ArrowBuf) => buf.writeShort(0), (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeShort(row.getShort(ordinal))) case IntegerType => @@ -127,7 +127,7 @@ object Arrow { val numOfRows = rows.length field.dataType match { - case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType => + case ShortType | IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType => val validityVector = new BitVector("validity", allocator) val validityMutator = validityVector.getMutator validityVector.allocateNew(numOfRows) diff --git a/sql/core/src/test/resources/test-data/arrow/decimalData-BigDecimal.json b/sql/core/src/test/resources/test-data/arrow/decimalData-BigDecimal.json new file mode 100644 index 000000000000..8449acaab23d --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/decimalData-BigDecimal.json @@ -0,0 +1,50 @@ +{ + "schema": { + "fields": [ + { + "name": "a", + "type": {"name": "floatingpoint", "precision": "DOUBLE"}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b", + "type": {"name": "floatingpoint", "precision": "DOUBLE"}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "a", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 1, 2, 2, 3, 3] + }, + { + "name": "b", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 2, 1, 2, 1, 2] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/doubleData-double_precision-nullable.json b/sql/core/src/test/resources/test-data/arrow/doubleData-double_precision-nullable.json new file mode 100644 index 
000000000000..d29b9ed6a2c9 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/doubleData-double_precision-nullable.json @@ -0,0 +1,68 @@ +{ + "schema": { + "fields": [ + { + "name": "i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + }, + { + "name": "a_d", + "type": {"name": "floatingpoint", "precision": "DOUBLE"}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b_d", + "type": {"name": "floatingpoint", "precision": "DOUBLE"}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "i", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 2, 3, 4, 5, 6] + }, + { + "name": "a_d", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0] + }, + { + "name": "b_d", + "count": 6, + "VALIDITY": [1, 0, 0, 1, 0, 1], + "DATA": [1.1, 0, 0, 2.2, 0, 3.3] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/floatData-single_precision-nullable.json b/sql/core/src/test/resources/test-data/arrow/floatData-single_precision-nullable.json new file mode 100644 index 000000000000..9d686d1367a6 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/floatData-single_precision-nullable.json @@ -0,0 +1,68 @@ +{ + "schema": { + "fields": [ + { + "name": "i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + }, + { + "name": "a_f", + "type": {"name": "floatingpoint", "precision": "SINGLE"}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b_f", + "type": {"name": "floatingpoint", "precision": "SINGLE"}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "i", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 2, 3, 4, 5, 6] + }, + { + "name": "a_f", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0] + }, + { + "name": "b_f", + "count": 6, + "VALIDITY": [1, 0, 0, 1, 0, 1], + "DATA": [1.1, 0, 0, 2.2, 0, 3.3] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/indexData-ints.json b/sql/core/src/test/resources/test-data/arrow/indexData-ints.json new file mode 100644 index 000000000000..e96945d8b7ac --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/indexData-ints.json @@ -0,0 +1,32 @@ +{ + "schema": { + "fields": [ + { + "name": "i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "i", + "count": 6, + "VALIDITY": [1, 
1, 1, 1, 1, 1], + "DATA": [1, 2, 3, 4, 5, 6] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/intData-32bit_ints-nullable.json b/sql/core/src/test/resources/test-data/arrow/intData-32bit_ints-nullable.json new file mode 100644 index 000000000000..049b30cf4a3c --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/intData-32bit_ints-nullable.json @@ -0,0 +1,68 @@ +{ + "schema": { + "fields": [ + { + "name": "i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + }, + { + "name": "a_i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b_i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "i", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 2, 3, 4, 5, 6] + }, + { + "name": "a_i", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, -1, 2, -2, 2147483647, -2147483648] + }, + { + "name": "b_i", + "count": 6, + "VALIDITY": [1, 0, 0, 1, 0, 1], + "DATA": [1, -1, 2, -2, 2147483647, -2147483648] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/largeAndSmall-ints.json b/sql/core/src/test/resources/test-data/arrow/largeAndSmall-ints.json new file mode 100644 index 000000000000..e2f15e865626 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/largeAndSmall-ints.json @@ -0,0 +1,50 @@ +{ + "schema": { + "fields": [ + { + "name": "a", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "a", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [2147483644, 1, 2147483645, 2, 2147483646, 3] + }, + { + "name": "b", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 2, 1, 2, 1, 2] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/longData-64bit_ints-nullable.json b/sql/core/src/test/resources/test-data/arrow/longData-64bit_ints-nullable.json new file mode 100644 index 000000000000..a6bd5f002b05 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/longData-64bit_ints-nullable.json @@ -0,0 +1,68 @@ +{ + "schema": { + "fields": [ + { + "name": "i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + }, + { + "name": "a_l", + "type": {"name": "int", "isSigned": true, "bitWidth": 64}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + 
{"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b_l", + "type": {"name": "int", "isSigned": true, "bitWidth": 64}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "i", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 2, 3, 4, 5, 6] + }, + { + "name": "a_l", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, -1, 2, -2, 9223372036854775807, -9223372036854775808] + }, + { + "name": "b_l", + "count": 6, + "VALIDITY": [1, 0, 0, 1, 0, 1], + "DATA": [1, -1, 2, -2, 9223372036854775807, -9223372036854775808] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/lowercase-strings.json b/sql/core/src/test/resources/test-data/arrow/lowercase-strings.json new file mode 100644 index 000000000000..356c431a671e --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/lowercase-strings.json @@ -0,0 +1,52 @@ +{ + "schema": { + "fields": [ + { + "name": "n", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + }, + { + "name": "l", + "type": {"name": "utf8"}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "OFFSET", "typeBitWidth": 32}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + } + ] + }, + + "batches": [ + { + "count": 4, + "columns": [ + { + "name": "n", + "count": 4, + "VALIDITY": [1, 1, 1, 1], + "DATA": [1, 2, 3, 4] + }, + { + "name": "l", + "count": 4, + "VALIDITY": [1, 1, 1, 1], + "OFFSET": [0, 1, 2, 3, 4], + "DATA": ["a", "b", "c", "d"] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/mixedData-standard-nullable.json b/sql/core/src/test/resources/test-data/arrow/mixedData-standard-nullable.json new file mode 100644 index 000000000000..2d7921001eb7 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/mixedData-standard-nullable.json @@ -0,0 +1,212 @@ +{ + "schema": { + "fields": [ + { + "name": "i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + }, + { + "name": "a_s", + "type": {"name": "int", "isSigned": true, "bitWidth": 16}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b_s", + "type": {"name": "int", "isSigned": true, "bitWidth": 16}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "a_i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b_i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "a_l", + "type": {"name": "int", "isSigned": true, 
"bitWidth": 64}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b_l", + "type": {"name": "int", "isSigned": true, "bitWidth": 64}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "a_f", + "type": {"name": "floatingpoint", "precision": "SINGLE"}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b_f", + "type": {"name": "floatingpoint", "precision": "SINGLE"}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "a_d", + "type": {"name": "floatingpoint", "precision": "DOUBLE"}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b_d", + "type": {"name": "floatingpoint", "precision": "DOUBLE"}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "i", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 2, 3, 4, 5, 6] + }, + { + "name": "a_s", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, -1, 2, -2, 32767, -32768] + }, + { + "name": "b_s", + "count": 6, + "VALIDITY": [1, 0, 0, 1, 0, 1], + "DATA": [1, -1, 2, -2, 32767, -32768] + }, + { + "name": "a_i", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, -1, 2, -2, 2147483647, -2147483648] + }, + { + "name": "b_i", + "count": 6, + "VALIDITY": [1, 0, 0, 1, 0, 1], + "DATA": [1, -1, 2, -2, 2147483647, -2147483648] + }, + { + "name": "a_l", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, -1, 2, -2, 9223372036854775807, -9223372036854775808] + }, + { + "name": "b_l", + "count": 6, + "VALIDITY": [1, 0, 0, 1, 0, 1], + "DATA": [1, -1, 2, -2, 9223372036854775807, -9223372036854775808] + }, + { + "name": "a_f", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0] + }, + { + "name": "b_f", + "count": 6, + "VALIDITY": [1, 0, 0, 1, 0, 1], + "DATA": [1.1, 0, 0, 2.2, 0, 3.3] + }, + { + "name": "a_d", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0] + }, + { + "name": "b_d", + "count": 6, + "VALIDITY": [1, 0, 0, 1, 0, 1], + "DATA": [1.1, 0, 0, 2.2, 0, 3.3] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/null-ints-mixed.json b/sql/core/src/test/resources/test-data/arrow/null-ints-mixed.json new file mode 100644 index 000000000000..a82ba623f539 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/null-ints-mixed.json @@ -0,0 +1,50 @@ +{ + "schema": { + "fields": [ + { + "name": "a", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + 
{"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 2, + "columns": [ + { + "name": "a", + "count": 2, + "VALIDITY": [1, 1], + "DATA": [1, 2] + }, + { + "name": "b", + "count": 2, + "VALIDITY": [0, 1], + "DATA": [0, 2] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/salary-doubles.json b/sql/core/src/test/resources/test-data/arrow/salary-doubles.json new file mode 100644 index 000000000000..2cc42182a56c --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/salary-doubles.json @@ -0,0 +1,50 @@ +{ + "schema": { + "fields": [ + { + "name": "personId", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "salary", + "type": {"name": "floatingpoint", "precision": "DOUBLE"}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 2, + "columns": [ + { + "name": "personId", + "count": 2, + "VALIDITY": [1, 1], + "DATA": [0, 1] + }, + { + "name": "salary", + "count": 2, + "VALIDITY": [1, 1], + "DATA": [2000.0, 1000.0] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/shortData-16bit_ints-nullable.json b/sql/core/src/test/resources/test-data/arrow/shortData-16bit_ints-nullable.json new file mode 100644 index 000000000000..ca04de5b0ea3 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/shortData-16bit_ints-nullable.json @@ -0,0 +1,68 @@ +{ + "schema": { + "fields": [ + { + "name": "i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + }, + { + "name": "a_s", + "type": {"name": "int", "isSigned": true, "bitWidth": 16}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b_s", + "type": {"name": "int", "isSigned": true, "bitWidth": 16}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "i", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 2, 3, 4, 5, 6] + }, + { + "name": "a_s", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, -1, 2, -2, 32767, -32768] + }, + { + "name": "b_s", + "count": 6, + "VALIDITY": [1, 0, 0, 1, 0, 1], + "DATA": [1, -1, 2, -2, 32767, -32768] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/testData2-ints.json b/sql/core/src/test/resources/test-data/arrow/testData2-ints.json new file mode 100644 index 000000000000..6edc2a030287 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/testData2-ints.json @@ -0,0 +1,50 @@ +{ + "schema": { + "fields": [ + { + "name": "a", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b", + "type": {"name": "int", "isSigned": 
true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "a", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 1, 2, 2, 3, 3] + }, + { + "name": "b", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 2, 1, 2, 1, 2] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/uppercase-strings.json b/sql/core/src/test/resources/test-data/arrow/uppercase-strings.json new file mode 100644 index 000000000000..b6016022e314 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/uppercase-strings.json @@ -0,0 +1,52 @@ +{ + "schema": { + "fields": [ + { + "name": "N", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + }, + { + "name": "L", + "type": {"name": "utf8"}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "OFFSET", "typeBitWidth": 32}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "N", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "DATA": [1, 2, 3, 4, 5, 6] + }, + { + "name": "L", + "count": 6, + "VALIDITY": [1, 1, 1, 1, 1, 1], + "OFFSET": [0, 1, 2, 3, 4, 5, 6], + "DATA": ["A", "B", "C", "D", "E", "F"] + } + ] + } + ] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala index ff65f13151ad..7b5231824b2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala @@ -25,23 +25,116 @@ import org.apache.arrow.vector.util.Validator import org.apache.spark.sql.test.SharedSQLContext + +// NOTE - nullable type can be declared as Option[*] or java.lang.* +private[sql] case class ShortData(i: Int, a_s: Short, b_s: Option[Short]) +private[sql] case class IntData(i: Int, a_i: Int, b_i: Option[Int]) +private[sql] case class LongData(i: Int, a_l: Long, b_l: java.lang.Long) +private[sql] case class FloatData(i: Int, a_f: Float, b_f: Option[Float]) +private[sql] case class DoubleData(i: Int, a_d: Double, b_d: Option[Double]) + + class ArrowSuite extends SharedSQLContext { + import testImplicits._ private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).getFile } + test("collect to arrow record batch") { + val arrowRecordBatch = indexData.collectAsArrow() + assert(arrowRecordBatch.getLength > 0) + assert(arrowRecordBatch.getNodes.size() > 0) + arrowRecordBatch.close() + } + + test("standard type conversion") { + collectAndValidate(indexData, "test-data/arrow/indexData-ints.json") + collectAndValidate(largeAndSmallInts, "test-data/arrow/largeAndSmall-ints.json") + collectAndValidate(salary, "test-data/arrow/salary-doubles.json") + } + + test("standard type nullable conversion") { + collectAndValidate(shortData, "test-data/arrow/shortData-16bit_ints-nullable.json") + collectAndValidate(intData, "test-data/arrow/intData-32bit_ints-nullable.json") + collectAndValidate(longData, "test-data/arrow/longData-64bit_ints-nullable.json") + collectAndValidate(floatData, 
"test-data/arrow/floatData-single_precision-nullable.json") + collectAndValidate(doubleData, "test-data/arrow/doubleData-double_precision-nullable.json") + } + + test("mixed standard type nullable conversion") { + val mixedData = shortData.join(intData, "i").join(longData, "i").join(floatData, "i") + .join(doubleData, "i").sort("i") + collectAndValidate(mixedData, "test-data/arrow/mixedData-standard-nullable.json") + } + + test("partitioned DataFrame") { + collectAndValidate(testData2, "test-data/arrow/testData2-ints.json") + } + + test("string type conversion") { + collectAndValidate(upperCaseData, "test-data/arrow/uppercase-strings.json") + collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json") + } + + test("time and date conversion") { } + + test("nested type conversion") { } + + test("array type conversion") { } + + test("mapped type conversion") { } + + test("other type conversion") { + // half-precision + // byte type, or binary + // allNulls + } + + test("floating-point NaN") { } + + // Arrow currently supports single or double precision + ignore("arbitrary precision floating point") { + collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json") + } + + test("other null conversion") { } + test("convert int column with null to arrow") { - testCollect(nullInts, "test-data/arrow/null-ints.json") + collectAndValidate(nullInts, "test-data/arrow/null-ints.json") + collectAndValidate(testData3, "test-data/arrow/null-ints-mixed.json") } test("convert string column with null to arrow") { val nullStringsColOnly = nullStrings.select(nullStrings.columns(1)) - testCollect(nullStringsColOnly, "test-data/arrow/null-strings.json") + collectAndValidate(nullStringsColOnly, "test-data/arrow/null-strings.json") + } + + test("empty frame collect") { + val emptyBatch = spark.emptyDataFrame.collectAsArrow() + assert(emptyBatch.getLength == 0) + } + + test("negative tests") { + + // Missing test file + intercept[NullPointerException] { + collectAndValidate(indexData, "test-data/arrow/missing-file") + } + + // Different schema + intercept[IllegalArgumentException] { + collectAndValidate(shortData, "test-data/arrow/intData-32bit_ints-nullable.json") + } + + // Different values + intercept[IllegalArgumentException] { + collectAndValidate(indexData.sort($"i".desc), "test-data/arrow/indexData-ints.json") + } } /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def testCollect(df: DataFrame, arrowFile: String) { + private def collectAndValidate(df: DataFrame, arrowFile: String) { val jsonFilePath = testFile(arrowFile) val allocator = new RootAllocator(Integer.MAX_VALUE) @@ -59,4 +152,56 @@ class ArrowSuite extends SharedSQLContext { Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) } + + protected lazy val indexData = Seq(1, 2, 3, 4, 5, 6).toDF("i") + + protected lazy val shortData: DataFrame = { + spark.sparkContext.parallelize( + ShortData(1, 1, Some(1)) :: + ShortData(2, -1, None) :: + ShortData(3, 2, None) :: + ShortData(4, -2, Some(-2)) :: + ShortData(5, 32767, None) :: + ShortData(6, -32768, Some(-32768)) :: Nil).toDF() + } + + protected lazy val intData: DataFrame = { + spark.sparkContext.parallelize( + IntData(1, 1, Some(1)) :: + IntData(2, -1, None) :: + IntData(3, 2, None) :: + IntData(4, -2, Some(-2)) :: + IntData(5, 2147483647, None) :: + IntData(6, -2147483648, Some(-2147483648)) :: Nil).toDF() + } + + protected lazy val longData: DataFrame = { + spark.sparkContext.parallelize( + LongData(1, 1L, 1L) :: 
+ LongData(2, -1L, null) :: + LongData(3, 2L, null) :: + LongData(4, -2, -2L) :: + LongData(5, 9223372036854775807L, null) :: + LongData(6, -9223372036854775808L, -9223372036854775808L) :: Nil).toDF() + } + + protected lazy val floatData: DataFrame = { + spark.sparkContext.parallelize( + FloatData(1, 1.0f, Some(1.1f)) :: + FloatData(2, 2.0f, None) :: + FloatData(3, 0.01f, None) :: + FloatData(4, 200.0f, Some(2.2f)) :: + FloatData(5, 0.0001f, None) :: + FloatData(6, 20000.0f, Some(3.3f)) :: Nil).toDF() + } + + protected lazy val doubleData: DataFrame = { + spark.sparkContext.parallelize( + DoubleData(1, 1.0, Some(1.1)) :: + DoubleData(2, 2.0, None) :: + DoubleData(3, 0.01, None) :: + DoubleData(4, 200.0, Some(2.2)) :: + DoubleData(5, 0.0001, None) :: + DoubleData(6, 20000.0, Some(3.3)) :: Nil).toDF() + } } From bdba357eef31b3225ea6a565e7841d3459822e97 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 17 Jan 2017 16:47:07 -0500 Subject: [PATCH 07/56] Implement Arrow column writers Move column writers to Arrow.scala Add support for more types; Switch to arrow NullableVector closes #16 --- .../scala/org/apache/spark/sql/Arrow.scala | 310 ++++++++++-------- 1 file changed, 180 insertions(+), 130 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala index 30e79c2e3c6a..7100a8f03515 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala @@ -21,8 +21,9 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import io.netty.buffer.ArrowBuf -import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.BitVector +import org.apache.arrow.memory.{BaseAllocator, RootAllocator} +import org.apache.arrow.vector._ +import org.apache.arrow.vector.BaseValueVector.BaseMutator import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} import org.apache.arrow.vector.types.FloatingPointPrecision import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} @@ -32,70 +33,17 @@ import org.apache.spark.sql.types._ object Arrow { - private case class TypeFuncs(getType: () => ArrowType, - fill: ArrowBuf => Unit, - write: (InternalRow, Int, ArrowBuf) => Unit) - - private def getTypeFuncs(dataType: DataType): TypeFuncs = { - val err = s"Unsupported data type ${dataType.simpleString}" - + private def sparkTypeToArrowType(dataType: DataType): ArrowType = { dataType match { - case NullType => - TypeFuncs( - () => ArrowType.Null.INSTANCE, - (buf: ArrowBuf) => (), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => ()) - case BooleanType => - TypeFuncs( - () => ArrowType.Bool.INSTANCE, - (buf: ArrowBuf) => buf.writeBoolean(false), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => - buf.writeBoolean(row.getBoolean(ordinal))) - case ShortType => - TypeFuncs( - () => new ArrowType.Int(8 * ShortType.defaultSize, true), - (buf: ArrowBuf) => buf.writeShort(0), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeShort(row.getShort(ordinal))) - case IntegerType => - TypeFuncs( - () => new ArrowType.Int(8 * IntegerType.defaultSize, true), - (buf: ArrowBuf) => buf.writeInt(0), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeInt(row.getInt(ordinal))) - case LongType => - TypeFuncs( - () => new ArrowType.Int(8 * LongType.defaultSize, true), - (buf: ArrowBuf) => buf.writeLong(0L), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeLong(row.getLong(ordinal))) - case 
FloatType => - TypeFuncs( - () => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), - (buf: ArrowBuf) => buf.writeFloat(0f), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeFloat(row.getFloat(ordinal))) - case DoubleType => - TypeFuncs( - () => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), - (buf: ArrowBuf) => buf.writeDouble(0d), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => - buf.writeDouble(row.getDouble(ordinal))) - case ByteType => - TypeFuncs( - () => new ArrowType.Int(8, false), - (buf: ArrowBuf) => buf.writeByte(0), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeByte(row.getByte(ordinal))) - case StringType => - TypeFuncs( - () => ArrowType.Utf8.INSTANCE, - (buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => - throw new UnsupportedOperationException(err)) - case StructType(_) => - TypeFuncs( - () => ArrowType.Struct.INSTANCE, - (buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => - throw new UnsupportedOperationException(err)) - case _ => - throw new IllegalArgumentException(err) + case BooleanType => ArrowType.Bool.INSTANCE + case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) + case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) + case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case ByteType => new ArrowType.Int(8, false) + case StringType => ArrowType.Utf8.INSTANCE + case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}") } } @@ -110,8 +58,8 @@ object Arrow { internalRowToArrowBuf(rows, ordinal, field, allocator) } - val buffers = bufAndField.flatMap(_._1).toList.asJava - val fieldNodes = bufAndField.flatMap(_._2).toList.asJava + val fieldNodes = bufAndField.flatMap(_._1).toList.asJava + val buffers = bufAndField.flatMap(_._2).toList.asJava new ArrowRecordBatch(rows.length, fieldNodes, buffers) } @@ -123,67 +71,24 @@ object Arrow { rows: Array[InternalRow], ordinal: Int, field: StructField, - allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = { + allocator: RootAllocator): (Array[ArrowFieldNode], Array[ArrowBuf]) = { val numOfRows = rows.length + val columnWriter = ColumnWriter(allocator, field.dataType) + columnWriter.init(numOfRows) + var index = 0 - field.dataType match { - case ShortType | IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType => - val validityVector = new BitVector("validity", allocator) - val validityMutator = validityVector.getMutator - validityVector.allocateNew(numOfRows) - validityMutator.setValueCount(numOfRows) - - val buf = allocator.buffer(numOfRows * field.dataType.defaultSize) - val typeFunc = getTypeFuncs(field.dataType) - var nullCount = 0 - var index = 0 - while (index < rows.length) { - val row = rows(index) - if (row.isNullAt(ordinal)) { - nullCount += 1 - validityMutator.set(index, 0) - typeFunc.fill(buf) - } else { - validityMutator.set(index, 1) - typeFunc.write(row, ordinal, buf) - } - index += 1 - } - - val fieldNode = new ArrowFieldNode(numOfRows, nullCount) - - (Array(validityVector.getBuffer, buf), Array(fieldNode)) - - case StringType => - val validityVector = new BitVector("validity", allocator) - val validityMutator = validityVector.getMutator() 
- validityVector.allocateNew(numOfRows) - validityMutator.setValueCount(numOfRows) - - val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize) - var bytesCount = 0 - bufOffset.writeInt(bytesCount) - val bufValues = allocator.buffer(1024) - var nullCount = 0 - rows.zipWithIndex.foreach { case (row, index) => - if (row.isNullAt(ordinal)) { - nullCount += 1 - validityMutator.set(index, 0) - bufOffset.writeInt(bytesCount) - } else { - validityMutator.set(index, 1) - val bytes = row.getUTF8String(ordinal).getBytes - bytesCount += bytes.length - bufOffset.writeInt(bytesCount) - bufValues.writeBytes(bytes) - } - } - - val fieldNode = new ArrowFieldNode(numOfRows, nullCount) - - (Array(validityVector.getBuffer, bufOffset, bufValues), - Array(fieldNode)) + while(index < numOfRows) { + val row = rows(index) + if (row.isNullAt(ordinal)) { + columnWriter.writeNull() + } else { + columnWriter.write(row, ordinal) + } + index += 1 } + + val (arrowFieldNodes, arrowBufs) = columnWriter.finish() + (arrowFieldNodes.toArray, arrowBufs.toArray) } private[sql] def schemaToArrowSchema(schema: StructType): Schema = { @@ -195,13 +100,158 @@ object Arrow { val name = sparkField.name val dataType = sparkField.dataType val nullable = sparkField.nullable + new Field(name, nullable, sparkTypeToArrowType(dataType), List.empty[Field].asJava) + } +} +object ColumnWriter { + def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = { dataType match { - case StructType(fields) => - val childrenFields = fields.map(sparkFieldToArrowField).toList.asJava - new Field(name, nullable, ArrowType.Struct.INSTANCE, childrenFields) - case _ => - new Field(name, nullable, getTypeFuncs(dataType).getType(), List.empty[Field].asJava) + case BooleanType => new BooleanColumnWriter(allocator) + case ShortType => new ShortColumnWriter(allocator) + case IntegerType => new IntegerColumnWriter(allocator) + case LongType => new LongColumnWriter(allocator) + case FloatType => new FloatColumnWriter(allocator) + case DoubleType => new DoubleColumnWriter(allocator) + case ByteType => new ByteColumnWriter(allocator) + case StringType => new UTF8StringColumnWriter(allocator) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}") } } } + +private[sql] trait ColumnWriter { + def init(initialSize: Int): Unit + def writeNull(): Unit + def write(row: InternalRow, ordinal: Int): Unit + def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) +} + +/** + * Base class for flat arrow column writer, i.e., column without children. 
+ */ +private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator) + extends ColumnWriter { + protected val valueVector: BaseDataValueVector + protected val valueMutator: BaseMutator + + protected var count = 0 + protected var nullCount = 0 + + protected def setNull(): Unit + protected def setValue(row: InternalRow, ordinal: Int): Unit + protected def valueBuffers(): Seq[ArrowBuf] = valueVector.getBuffers(true) // TODO: check the flag + + override def init(initialSize: Int): Unit = { + valueVector.allocateNew() + } + + override def writeNull(): Unit = { + setNull() + nullCount += 1 + count += 1 + } + + override def write(row: InternalRow, ordinal: Int): Unit = { + setValue(row, ordinal) + count += 1 + } + + override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = { + valueMutator.setValueCount(count) + val fieldNode = new ArrowFieldNode(count, nullCount) + (List(fieldNode), valueBuffers) + } +} + +private[sql] class BooleanColumnWriter(allocator: BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + private def bool2int(b: Boolean): Int = if (b) 1 else 0 + + override protected val valueVector: NullableBitVector + = new NullableBitVector("BooleanValue", allocator) + override protected val valueMutator: NullableBitVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit + = valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal))) +} + +private[sql] class ShortColumnWriter(allocator: BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + override protected val valueVector: NullableSmallIntVector + = new NullableSmallIntVector("ShortValue", allocator) + override protected val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit + = valueMutator.setSafe(count, row.getShort(ordinal)) +} + +private[sql] class IntegerColumnWriter(allocator: BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + override protected val valueVector: NullableIntVector + = new NullableIntVector("IntValue", allocator) + override protected val valueMutator: NullableIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit + = valueMutator.setSafe(count, row.getInt(ordinal)) +} + +private[sql] class LongColumnWriter(allocator: BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + override protected val valueVector: NullableBigIntVector + = new NullableBigIntVector("LongValue", allocator) + override protected val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit + = valueMutator.setSafe(count, row.getLong(ordinal)) +} + +private[sql] class FloatColumnWriter(allocator: BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + override protected val valueVector: NullableFloat4Vector + = new NullableFloat4Vector("FloatValue", allocator) + override protected val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit + = valueMutator.setSafe(count, row.getFloat(ordinal)) +} + +private[sql] class DoubleColumnWriter(allocator: 
BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + override protected val valueVector: NullableFloat8Vector + = new NullableFloat8Vector("DoubleValue", allocator) + override protected val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit + = valueMutator.setSafe(count, row.getDouble(ordinal)) +} + +private[sql] class ByteColumnWriter(allocator: BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + override protected val valueVector: NullableUInt1Vector + = new NullableUInt1Vector("ByteValue", allocator) + override protected val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit + = valueMutator.setSafe(count, row.getByte(ordinal)) +} + +private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + override protected val valueVector: NullableVarBinaryVector + = new NullableVarBinaryVector("UTF8StringValue", allocator) + override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit = { + val bytes = row.getUTF8String(ordinal).getBytes + valueMutator.setSafe(count, bytes, 0, bytes.length) + } +} From d20437f37253f565a4f2647a7a5768a525678db1 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 19 Jan 2017 17:52:26 -0800 Subject: [PATCH 08/56] added bool type converstion test added test for byte data byte type should be signed closes #18 --- .../scala/org/apache/spark/sql/Arrow.scala | 2 +- .../resources/test-data/arrow/boolData.json | 32 +++++++++++++++ .../resources/test-data/arrow/byteData.json | 32 +++++++++++++++ .../org/apache/spark/sql/ArrowSuite.scala | 41 +++++++++++++++---- 4 files changed, 97 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/arrow/boolData.json create mode 100644 sql/core/src/test/resources/test-data/arrow/byteData.json diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala index 7100a8f03515..d58a25fd05e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala @@ -41,7 +41,7 @@ object Arrow { case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case ByteType => new ArrowType.Int(8, false) + case ByteType => new ArrowType.Int(8, true) case StringType => ArrowType.Utf8.INSTANCE case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}") } diff --git a/sql/core/src/test/resources/test-data/arrow/boolData.json b/sql/core/src/test/resources/test-data/arrow/boolData.json new file mode 100644 index 000000000000..f402e5118cef --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/boolData.json @@ -0,0 +1,32 @@ +{ + "schema": { + "fields": [ + { + "name": "a_bool", + "type": {"name": "bool"}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + } + ] + }, + + "batches": [ 
+ { + "count": 4, + "columns": [ + { + "name": "a_bool", + "count": 4, + "VALIDITY": [1, 1, 1, 1], + "DATA": [true, true, false, true] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/byteData.json b/sql/core/src/test/resources/test-data/arrow/byteData.json new file mode 100644 index 000000000000..d0a6ceb818f7 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/byteData.json @@ -0,0 +1,32 @@ +{ + "schema": { + "fields": [ + { + "name": "a_byte", + "type": {"name": "int", "isSigned": true, "bitWidth": 8}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + } + ] + }, + + "batches": [ + { + "count": 4, + "columns": [ + { + "name": "a_byte", + "count": 4, + "VALIDITY": [1, 1, 1, 1], + "DATA": [1, -1, 64, 127] + } + ] + } + ] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala index 7b5231824b2a..f51a74084a10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql import java.io.File +import java.sql.{Date, Timestamp} +import java.text.SimpleDateFormat +import java.util.Locale import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} @@ -62,9 +65,19 @@ class ArrowSuite extends SharedSQLContext { collectAndValidate(doubleData, "test-data/arrow/doubleData-double_precision-nullable.json") } + test("boolean type conversion") { + val boolData = Seq(true, true, false, true).toDF("a_bool") + collectAndValidate(boolData, "test-data/arrow/boolData.json") + } + + test("byte type conversion") { + val byteData = Seq(1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte") + collectAndValidate(byteData, "test-data/arrow/byteData.json") + } + test("mixed standard type nullable conversion") { - val mixedData = shortData.join(intData, "i").join(longData, "i").join(floatData, "i") - .join(doubleData, "i").sort("i") + val mixedData = Seq(shortData, intData, longData, floatData, doubleData) + .reduce((a, b) => a.join(b, "i")).sort("i") collectAndValidate(mixedData, "test-data/arrow/mixedData-standard-nullable.json") } @@ -77,7 +90,16 @@ class ArrowSuite extends SharedSQLContext { collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json") } - test("time and date conversion") { } + ignore("time and date conversion") { + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val d1 = new Date(sdf.parse("2015-04-08 13:10:15").getTime) + val d2 = new Date(sdf.parse("2015-04-08 13:10:15").getTime) + val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10").getTime) + val dateTimeData = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2)) + .toDF("a_date", "b_string", "c_timestamp") + collectAndValidate(dateTimeData, "test-data/arrow/datetimeData-strings.json") + } test("nested type conversion") { } @@ -93,11 +115,6 @@ class ArrowSuite extends SharedSQLContext { test("floating-point NaN") { } - // Arrow currently supports single or double precision - ignore("arbitrary precision floating point") { - collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json") - } - test("other null conversion") { } test("convert int column with null to arrow") { @@ -115,7 +132,13 @@ class 
ArrowSuite extends SharedSQLContext { assert(emptyBatch.getLength == 0) } - test("negative tests") { + test("unsupported types") { + intercept[UnsupportedOperationException] { + collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json") + } + } + + test("test Arrow Validator") { // Missing test file intercept[NullPointerException] { From 2e81a93735d06d6fdbecd17747d85dcbe79d23bf Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 24 Jan 2017 11:35:00 -0800 Subject: [PATCH 09/56] changed scope of some functions and minor cleanup --- .../scala/org/apache/spark/sql/Arrow.scala | 79 +++++++++---------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala index d58a25fd05e2..bfd7e0c36599 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala @@ -33,7 +33,10 @@ import org.apache.spark.sql.types._ object Arrow { - private def sparkTypeToArrowType(dataType: DataType): ArrowType = { + /** + * Map a Spark Dataset type to ArrowType. + */ + private[sql] def sparkTypeToArrowType(dataType: DataType): ArrowType = { dataType match { case BooleanType => ArrowType.Bool.INSTANCE case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) @@ -50,24 +53,22 @@ object Arrow { /** * Transfer an array of InternalRow to an ArrowRecordBatch. */ - def internalRowsToArrowRecordBatch( + private[sql] def internalRowsToArrowRecordBatch( rows: Array[InternalRow], schema: StructType, allocator: RootAllocator): ArrowRecordBatch = { - val bufAndField = schema.fields.zipWithIndex.map { case (field, ordinal) => + val (fieldNodes, buffers) = schema.fields.zipWithIndex.map { case (field, ordinal) => internalRowToArrowBuf(rows, ordinal, field, allocator) - } + }.unzip - val fieldNodes = bufAndField.flatMap(_._1).toList.asJava - val buffers = bufAndField.flatMap(_._2).toList.asJava - - new ArrowRecordBatch(rows.length, fieldNodes, buffers) + new ArrowRecordBatch(rows.length, + fieldNodes.flatten.toList.asJava, buffers.flatten.toList.asJava) } /** - * Convert an array of InternalRow to an ArrowBuf. + * Write a Field from array of InternalRow to an ArrowBuf. */ - def internalRowToArrowBuf( + private def internalRowToArrowBuf( rows: Array[InternalRow], ordinal: Int, field: StructField, @@ -91,32 +92,14 @@ object Arrow { (arrowFieldNodes.toArray, arrowBufs.toArray) } + /** + * Convert a Spark Dataset schema to Arrow schema. 
+ */ private[sql] def schemaToArrowSchema(schema: StructType): Schema = { - val arrowFields = schema.fields.map(sparkFieldToArrowField) - new Schema(arrowFields.toList.asJava) - } - - private[sql] def sparkFieldToArrowField(sparkField: StructField): Field = { - val name = sparkField.name - val dataType = sparkField.dataType - val nullable = sparkField.nullable - new Field(name, nullable, sparkTypeToArrowType(dataType), List.empty[Field].asJava) - } -} - -object ColumnWriter { - def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = { - dataType match { - case BooleanType => new BooleanColumnWriter(allocator) - case ShortType => new ShortColumnWriter(allocator) - case IntegerType => new IntegerColumnWriter(allocator) - case LongType => new LongColumnWriter(allocator) - case FloatType => new FloatColumnWriter(allocator) - case DoubleType => new DoubleColumnWriter(allocator) - case ByteType => new ByteColumnWriter(allocator) - case StringType => new UTF8StringColumnWriter(allocator) - case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}") + val arrowFields = schema.fields.map { f => + new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) } + new Schema(arrowFields.toList.asJava) } } @@ -132,15 +115,14 @@ private[sql] trait ColumnWriter { */ private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator) extends ColumnWriter { - protected val valueVector: BaseDataValueVector - protected val valueMutator: BaseMutator - - protected var count = 0 - protected var nullCount = 0 + protected def valueVector: BaseDataValueVector + protected def valueMutator: BaseMutator protected def setNull(): Unit protected def setValue(row: InternalRow, ordinal: Int): Unit - protected def valueBuffers(): Seq[ArrowBuf] = valueVector.getBuffers(true) // TODO: check the flag + + protected var count = 0 + protected var nullCount = 0 override def init(initialSize: Int): Unit = { valueVector.allocateNew() @@ -160,6 +142,7 @@ private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseA override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = { valueMutator.setValueCount(count) val fieldNode = new ArrowFieldNode(count, nullCount) + val valueBuffers: Seq[ArrowBuf] = valueVector.getBuffers(true) // TODO: check the flag (List(fieldNode), valueBuffers) } } @@ -255,3 +238,19 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator) valueMutator.setSafe(count, bytes, 0, bytes.length) } } + +private[sql] object ColumnWriter { + def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = { + dataType match { + case BooleanType => new BooleanColumnWriter(allocator) + case ShortType => new ShortColumnWriter(allocator) + case IntegerType => new IntegerColumnWriter(allocator) + case LongType => new LongColumnWriter(allocator) + case FloatType => new FloatColumnWriter(allocator) + case DoubleType => new DoubleColumnWriter(allocator) + case ByteType => new ByteColumnWriter(allocator) + case StringType => new UTF8StringColumnWriter(allocator) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}") + } + } +} From 1ce4f2d30de03833b56e24d06c81d79322aca6ef Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 23 Jan 2017 18:06:16 -0500 Subject: [PATCH 10/56] Add support for date/timestamp/binary; Add more numbers to benchmark.py; Fix memory leaking bug closes #19 --- benchmark.py | 41 ++++++++++-- .../scala/org/apache/spark/sql/Arrow.scala | 66 
+++++++++++++++++-- .../scala/org/apache/spark/sql/Dataset.scala | 2 + .../test-data/arrow/timestampData.json | 32 +++++++++ .../org/apache/spark/sql/ArrowSuite.scala | 29 +++++--- 5 files changed, 150 insertions(+), 20 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/arrow/timestampData.json diff --git a/benchmark.py b/benchmark.py index f6e7c0ae8b2b..b23f34833666 100644 --- a/benchmark.py +++ b/benchmark.py @@ -1,16 +1,45 @@ import pyspark import timeit import random +import sys from pyspark.sql import SparkSession +import numpy as np +import pandas as pd numPartition = 8 -def time(df, repeat, number): +def scala_object(jpkg, obj): + return jpkg.__getattr__(obj + "$").__getattr__("MODULE$") + +def time(spark, df, repeat, number): + print("collect as internal rows") + time = timeit.repeat(lambda: df._jdf.queryExecution().executedPlan().executeCollect(), repeat=repeat, number=number) + time_df = pd.Series(time) + print(time_df.describe()) + + print("internal rows to arrow record batch") + arrow = scala_object(spark._jvm.org.apache.spark.sql, "Arrow") + root_allocator = spark._jvm.org.apache.arrow.memory.RootAllocator(sys.maxsize) + internal_rows = df._jdf.queryExecution().executedPlan().executeCollect() + jschema = df._jdf.schema() + def internalRowsToArrowRecordBatch(): + rb = arrow.internalRowsToArrowRecordBatch(internal_rows, jschema, root_allocator) + rb.close() + + time = timeit.repeat(internalRowsToArrowRecordBatch, repeat=repeat, number=number) + root_allocator.close() + time_df = pd.Series(time) + print(time_df.describe()) + print("toPandas with arrow") - print(timeit.repeat(lambda: df.toPandas(True), repeat=repeat, number=number)) + time = timeit.repeat(lambda: df.toPandas(True), repeat=repeat, number=number) + time_df = pd.Series(time) + print(time_df.describe()) print("toPandas without arrow") - print(timeit.repeat(lambda: df.toPandas(False), repeat=repeat, number=number)) + time = timeit.repeat(lambda: df.toPandas(False), repeat=repeat, number=number) + time_df = pd.Series(time) + print(time_df.describe()) def long(): return random.randint(0, 10000) @@ -32,10 +61,10 @@ def genData(spark, size, columns): if __name__ == "__main__": spark = SparkSession.builder.appName("ArrowBenchmark").getOrCreate() - df = genData(spark, 1000 * 1000, [long, double]) + df = genData(spark, 1000 * 1000, [double]) df.cache() df.count() + df.collect() - time(df, 10, 1) - + time(spark, df, 50, 1) df.unpersist() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala index bfd7e0c36599..41ee20ab34a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala @@ -25,7 +25,7 @@ import org.apache.arrow.memory.{BaseAllocator, RootAllocator} import org.apache.arrow.vector._ import org.apache.arrow.vector.BaseValueVector.BaseMutator import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} import org.apache.spark.sql.catalyst.InternalRow @@ -46,6 +46,9 @@ object Arrow { case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) case ByteType => new ArrowType.Int(8, true) case StringType => ArrowType.Utf8.INSTANCE + case BinaryType => ArrowType.Binary.INSTANCE + case DateType => ArrowType.Date.INSTANCE + case 
TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND) case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}") } } @@ -57,12 +60,17 @@ object Arrow { rows: Array[InternalRow], schema: StructType, allocator: RootAllocator): ArrowRecordBatch = { - val (fieldNodes, buffers) = schema.fields.zipWithIndex.map { case (field, ordinal) => + val fieldAndBuf = schema.fields.zipWithIndex.map { case (field, ordinal) => internalRowToArrowBuf(rows, ordinal, field, allocator) }.unzip + val fieldNodes = fieldAndBuf._1.flatten + val buffers = fieldAndBuf._2.flatten - new ArrowRecordBatch(rows.length, - fieldNodes.flatten.toList.asJava, buffers.flatten.toList.asJava) + val recordBatch = new ArrowRecordBatch(rows.length, + fieldNodes.toList.asJava, buffers.toList.asJava) + + buffers.foreach(_.release()) + recordBatch } /** @@ -107,6 +115,11 @@ private[sql] trait ColumnWriter { def init(initialSize: Int): Unit def writeNull(): Unit def write(row: InternalRow, ordinal: Int): Unit + + /** + * Clear the column writer and return the ArrowFieldNode and ArrowBuf. + * This should be called only once after all the data is written. + */ def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) } @@ -142,7 +155,7 @@ private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseA override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = { valueMutator.setValueCount(count) val fieldNode = new ArrowFieldNode(count, nullCount) - val valueBuffers: Seq[ArrowBuf] = valueVector.getBuffers(true) // TODO: check the flag + val valueBuffers: Seq[ArrowBuf] = valueVector.getBuffers(true) (List(fieldNode), valueBuffers) } } @@ -239,6 +252,44 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator) } } +private[sql] class BinaryColumnWriter(allocator: BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + override protected val valueVector: NullableVarBinaryVector + = new NullableVarBinaryVector("UTF8StringValue", allocator) + override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit = { + val bytes = row.getBinary(ordinal) + valueMutator.setSafe(count, bytes, 0, bytes.length) + } +} + +private[sql] class DateColumnWriter(allocator: BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + override protected val valueVector: NullableDateVector + = new NullableDateVector("DateValue", allocator) + override protected val valueMutator: NullableDateVector#Mutator = valueVector.getMutator + + override protected def setNull(): Unit = valueMutator.setNull(count) + override protected def setValue(row: InternalRow, ordinal: Int): Unit = { + valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000) + } +} + +private[sql] class TimeStampColumnWriter(allocator: BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + override protected val valueVector: NullableTimeStampVector + = new NullableTimeStampVector("TimeStampValue", allocator) + override protected val valueMutator: NullableTimeStampVector#Mutator = valueVector.getMutator + + override protected def setNull(): Unit = valueMutator.setNull(count) + + override protected def setValue(row: InternalRow, ordinal: Int): Unit = { + valueMutator.setSafe(count, row.getLong(ordinal) / 1000) + } +} + private[sql] object ColumnWriter { def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = { dataType match { @@ -250,7 +301,10 @@ 
private[sql] object ColumnWriter { case DoubleType => new DoubleColumnWriter(allocator) case ByteType => new ByteColumnWriter(allocator) case StringType => new UTF8StringColumnWriter(allocator) - case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}") + case BinaryType => new BinaryColumnWriter(allocator) + case DateType => new DateColumnWriter(allocator) + case TimestampType => new TimeStampColumnWriter(allocator) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6fbdb155847d..a53890ca25e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2787,6 +2787,8 @@ class Dataset[T] private[sql]( } catch { case e: Exception => throw e + } finally { + recordBatch.close() } withNewExecutionId { diff --git a/sql/core/src/test/resources/test-data/arrow/timestampData.json b/sql/core/src/test/resources/test-data/arrow/timestampData.json new file mode 100644 index 000000000000..174c62e4a12d --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/timestampData.json @@ -0,0 +1,32 @@ +{ + "schema": { + "fields": [ + { + "name": "a_timestamp", + "type": {"name": "timestamp", "unit": "MILLISECOND"}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 64} + ] + } + } + ] + }, + + "batches": [ + { + "count": 2, + "columns": [ + { + "name": "a_timestamp", + "count": 2, + "VALIDITY": [1, 1], + "DATA": [1365383415567, 1365426610789] + } + ] + } + ] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala index f51a74084a10..13b38c8c8568 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.io.File import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.Locale +import java.util.{Locale, TimeZone} import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} @@ -90,17 +90,30 @@ class ArrowSuite extends SharedSQLContext { collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json") } - ignore("time and date conversion") { - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) - val d1 = new Date(sdf.parse("2015-04-08 13:10:15").getTime) - val d2 = new Date(sdf.parse("2015-04-08 13:10:15").getTime) - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10").getTime) + ignore("date conversion") { + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US) + val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime) + val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime) + val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789").getTime) val dateTimeData = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2)) - .toDF("a_date", "b_string", "c_timestamp") + .toDF("a_date", "b_string", "c_timestamp") collectAndValidate(dateTimeData, "test-data/arrow/datetimeData-strings.json") } + test("timestamp conversion") { + val 
sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) + val dateTimeData = Seq((ts1), (ts2)).toDF("a_timestamp") + collectAndValidate(dateTimeData, "test-data/arrow/timestampData.json") + } + + // Arrow json reader doesn't support binary data + ignore("binary type conversion") { + collectAndValidate(binaryData, "test-data/arrow/binaryData.json") + } + test("nested type conversion") { } test("array type conversion") { } From ed1f0fabaec1b367b593af1a000bf40876e7816c Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 24 Jan 2017 14:20:13 -0800 Subject: [PATCH 11/56] Cleanup of changes before updating the PR for review remove unwanted changes removed benchmark.py from repository, will attach to PR instead --- benchmark.py | 70 ------------------- bin/pyspark | 2 +- python/pyspark/sql/dataframe.py | 1 - .../scala/org/apache/spark/sql/Dataset.scala | 2 +- 4 files changed, 2 insertions(+), 73 deletions(-) delete mode 100644 benchmark.py diff --git a/benchmark.py b/benchmark.py deleted file mode 100644 index b23f34833666..000000000000 --- a/benchmark.py +++ /dev/null @@ -1,70 +0,0 @@ -import pyspark -import timeit -import random -import sys -from pyspark.sql import SparkSession -import numpy as np -import pandas as pd - -numPartition = 8 - -def scala_object(jpkg, obj): - return jpkg.__getattr__(obj + "$").__getattr__("MODULE$") - -def time(spark, df, repeat, number): - print("collect as internal rows") - time = timeit.repeat(lambda: df._jdf.queryExecution().executedPlan().executeCollect(), repeat=repeat, number=number) - time_df = pd.Series(time) - print(time_df.describe()) - - print("internal rows to arrow record batch") - arrow = scala_object(spark._jvm.org.apache.spark.sql, "Arrow") - root_allocator = spark._jvm.org.apache.arrow.memory.RootAllocator(sys.maxsize) - internal_rows = df._jdf.queryExecution().executedPlan().executeCollect() - jschema = df._jdf.schema() - def internalRowsToArrowRecordBatch(): - rb = arrow.internalRowsToArrowRecordBatch(internal_rows, jschema, root_allocator) - rb.close() - - time = timeit.repeat(internalRowsToArrowRecordBatch, repeat=repeat, number=number) - root_allocator.close() - time_df = pd.Series(time) - print(time_df.describe()) - - print("toPandas with arrow") - time = timeit.repeat(lambda: df.toPandas(True), repeat=repeat, number=number) - time_df = pd.Series(time) - print(time_df.describe()) - - print("toPandas without arrow") - time = timeit.repeat(lambda: df.toPandas(False), repeat=repeat, number=number) - time_df = pd.Series(time) - print(time_df.describe()) - -def long(): - return random.randint(0, 10000) - -def double(): - return random.random() - -def genDataLocal(spark, size, columns): - data = [list([fn() for fn in columns]) for x in range(0, size)] - df = spark.createDataFrame(data) - return df - -def genData(spark, size, columns): - rdd = spark.sparkContext\ - .parallelize(range(0, size), numPartition)\ - .map(lambda _: [fn() for fn in columns]) - df = spark.createDataFrame(rdd) - return df - -if __name__ == "__main__": - spark = SparkSession.builder.appName("ArrowBenchmark").getOrCreate() - df = genData(spark, 1000 * 1000, [double]) - df.cache() - df.count() - df.collect() - - time(spark, df, 50, 1) - df.unpersist() diff --git a/bin/pyspark b/bin/pyspark index 8eeea7716cc9..98387c2ec5b8 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then 
unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$@" + exec "$PYSPARK_DRIVER_PYTHON" -m "$1" exit fi diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 91ae27cab933..f9ad9cd6dfd1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -36,7 +36,6 @@ from pyspark.sql.streaming import DataStreamWriter from pyspark.sql.types import * - __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a53890ca25e4..8d77fb17a164 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2389,7 +2389,7 @@ class Dataset[T] private[sql]( } /** - * Return an iterator that contains all of [[Row]]s in this Dataset. + * Return an iterator that contains all rows in this Dataset. * * The iterator will consume as much memory as the largest partition in this Dataset. * From 202650ea6a7fb503bb375c531e1976bf480f4ed6 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 25 Jan 2017 14:02:59 -0800 Subject: [PATCH 12/56] Changed RootAllocator param to Option in collectAsArrow added more tests and cleanup closes #20 --- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../test-data/arrow/allNulls-ints.json | 32 +++++++++ .../arrow/nanData-floating_point.json | 68 +++++++++++++++++++ .../test-data/arrow/timestampData.json | 4 +- .../org/apache/spark/sql/ArrowSuite.scala | 54 +++++++-------- 5 files changed, 131 insertions(+), 31 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/arrow/allNulls-ints.json create mode 100644 sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8d77fb17a164..ce3f3f1f2e52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2373,8 +2373,8 @@ class Dataset[T] private[sql]( * @since 2.2.0 */ @DeveloperApi - def collectAsArrow( - allocator: RootAllocator = new RootAllocator(Long.MaxValue)): ArrowRecordBatch = { + def collectAsArrow(rootAllocator: Option[RootAllocator] = None): ArrowRecordBatch = { + val allocator = rootAllocator.getOrElse(new RootAllocator(Long.MaxValue)) withNewExecutionId { try { val collectedRows = queryExecution.executedPlan.executeCollect() diff --git a/sql/core/src/test/resources/test-data/arrow/allNulls-ints.json b/sql/core/src/test/resources/test-data/arrow/allNulls-ints.json new file mode 100644 index 000000000000..e12f546e461c --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/allNulls-ints.json @@ -0,0 +1,32 @@ +{ + "schema": { + "fields": [ + { + "name": "a", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 4, + "columns": [ + { + "name": "a", + "count": 4, + "VALIDITY": [0, 0, 0, 0], + "DATA": [0, 0, 0, 0] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json b/sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json new file mode 100644 index 
000000000000..4a8407d45f37 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json @@ -0,0 +1,68 @@ +{ + "schema": { + "fields": [ + { + "name": "i", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + }, + { + "name": "NaN_f", + "type": {"name": "floatingpoint", "precision": "SINGLE"}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "NaN_d", + "type": {"name": "floatingpoint", "precision": "DOUBLE"}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 2, + "columns": [ + { + "name": "i", + "count": 2, + "VALIDITY": [1, 1], + "DATA": [1, 2] + }, + { + "name": "NaN_f", + "count": 2, + "VALIDITY": [1, 1], + "DATA": [1.2, "NaN"] + }, + { + "name": "NaN_d", + "count": 2, + "VALIDITY": [1, 1], + "DATA": ["NaN", 1.23] + } + ] + } + ] +} diff --git a/sql/core/src/test/resources/test-data/arrow/timestampData.json b/sql/core/src/test/resources/test-data/arrow/timestampData.json index 174c62e4a12d..6fe59975954d 100644 --- a/sql/core/src/test/resources/test-data/arrow/timestampData.json +++ b/sql/core/src/test/resources/test-data/arrow/timestampData.json @@ -2,7 +2,7 @@ "schema": { "fields": [ { - "name": "a_timestamp", + "name": "c_timestamp", "type": {"name": "timestamp", "unit": "MILLISECOND"}, "nullable": true, "children": [], @@ -21,7 +21,7 @@ "count": 2, "columns": [ { - "name": "a_timestamp", + "name": "c_timestamp", "count": 2, "VALIDITY": [1, 1], "DATA": [1365383415567, 1365426610789] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala index 13b38c8c8568..c784b3eefb74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala @@ -27,6 +27,7 @@ import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.CalendarInterval // NOTE - nullable type can be declared as Option[*] or java.lang.* @@ -88,25 +89,16 @@ class ArrowSuite extends SharedSQLContext { test("string type conversion") { collectAndValidate(upperCaseData, "test-data/arrow/uppercase-strings.json") collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json") + val nullStringsColOnly = nullStrings.select(nullStrings.columns(1)) + collectAndValidate(nullStringsColOnly, "test-data/arrow/null-strings.json") } ignore("date conversion") { - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US) - val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime) - val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime) - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789").getTime) - val dateTimeData = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2)) - .toDF("a_date", "b_string", "c_timestamp") collectAndValidate(dateTimeData, "test-data/arrow/datetimeData-strings.json") } test("timestamp conversion") { - val sdf = new 
SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - val dateTimeData = Seq((ts1), (ts2)).toDF("a_timestamp") - collectAndValidate(dateTimeData, "test-data/arrow/timestampData.json") + collectAndValidate(dateTimeData.select($"c_timestamp"), "test-data/arrow/timestampData.json") } // Arrow json reader doesn't support binary data @@ -120,24 +112,15 @@ class ArrowSuite extends SharedSQLContext { test("mapped type conversion") { } - test("other type conversion") { - // half-precision - // byte type, or binary - // allNulls + test("floating-point NaN") { + val nanData = Seq((1, 1.2F, Double.NaN), (2, Float.NaN, 1.23)).toDF("i", "NaN_f", "NaN_d") + collectAndValidate(nanData, "test-data/arrow/nanData-floating_point.json") } - test("floating-point NaN") { } - - test("other null conversion") { } - test("convert int column with null to arrow") { collectAndValidate(nullInts, "test-data/arrow/null-ints.json") collectAndValidate(testData3, "test-data/arrow/null-ints-mixed.json") - } - - test("convert string column with null to arrow") { - val nullStringsColOnly = nullStrings.select(nullStrings.columns(1)) - collectAndValidate(nullStringsColOnly, "test-data/arrow/null-strings.json") + collectAndValidate(allNulls, "test-data/arrow/allNulls-ints.json") } test("empty frame collect") { @@ -146,7 +129,14 @@ class ArrowSuite extends SharedSQLContext { } test("unsupported types") { - intercept[UnsupportedOperationException] { + def runUnsupported(block: => Unit): Unit = { + val msg = intercept[UnsupportedOperationException] { + block + } + assert(msg.getMessage.contains("Unsupported data type")) + } + + runUnsupported { collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json") } } @@ -180,7 +170,7 @@ class ArrowSuite extends SharedSQLContext { val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) - val arrowRecordBatch = df.collectAsArrow(allocator) + val arrowRecordBatch = df.collectAsArrow(Some(allocator)) val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator) val vectorLoader = new VectorLoader(arrowRoot) vectorLoader.load(arrowRecordBatch) @@ -240,4 +230,14 @@ class ArrowSuite extends SharedSQLContext { DoubleData(5, 0.0001, None) :: DoubleData(6, 20000.0, Some(3.3)) :: Nil).toDF() } + + protected lazy val dateTimeData: DataFrame = { + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) + val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) + val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) + Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2)) + .toDF("a_date", "b_string", "c_timestamp") + } } From fbe3b7ce06c1322306f2e3f3db6e77ec60620189 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 26 Jan 2017 16:25:40 -0800 Subject: [PATCH 13/56] renamed to ArrowConverters defined ArrowPayload and encapsulated Arrow classes in ArrowConverters addressed some minor comments in code review closes #21 --- .../{Arrow.scala => ArrowConverters.scala} | 58 +++++++++++++++++-- .../scala/org/apache/spark/sql/Dataset.scala | 31 +++------- ...Suite.scala => ArrowConvertersSuite.scala} | 32 +++++----- 3 files changed, 77 insertions(+), 44 deletions(-) rename 
sql/core/src/main/scala/org/apache/spark/sql/{Arrow.scala => ArrowConverters.scala} (87%) rename sql/core/src/test/scala/org/apache/spark/sql/{ArrowSuite.scala => ArrowConvertersSuite.scala} (90%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala similarity index 87% rename from sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala rename to sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala index 41ee20ab34a8..8a7379536ab5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala @@ -17,13 +17,16 @@ package org.apache.spark.sql +import java.io.ByteArrayOutputStream +import java.nio.channels.Channels + import scala.collection.JavaConverters._ -import scala.language.implicitConversions import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.{BaseAllocator, RootAllocator} import org.apache.arrow.vector._ import org.apache.arrow.vector.BaseValueVector.BaseMutator +import org.apache.arrow.vector.file.ArrowWriter import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} @@ -31,7 +34,33 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -object Arrow { +/** + * Intermediate data structure returned from Arrow conversions + */ +private[sql] abstract class ArrowPayload extends Iterator[ArrowRecordBatch] + +/** + * Class that wraps an Arrow RootAllocator used in conversion + */ +private[sql] class ArrowConverters { + private val _allocator = new RootAllocator(Long.MaxValue) + + private[sql] def allocator: RootAllocator = _allocator + + private class ArrowStaticPayload(batches: ArrowRecordBatch*) extends ArrowPayload { + private val iter = batches.iterator + + override def next(): ArrowRecordBatch = iter.next() + override def hasNext: Boolean = iter.hasNext + } + + def internalRowsToPayload(rows: Array[InternalRow], schema: StructType): ArrowPayload = { + val batch = ArrowConverters.internalRowsToArrowRecordBatch(rows, schema, allocator) + new ArrowStaticPayload(batch) + } +} + +private[sql] object ArrowConverters { /** * Map a Spark Dataset type to ArrowType. 
@@ -49,7 +78,7 @@ object Arrow { case BinaryType => ArrowType.Binary.INSTANCE case DateType => ArrowType.Date.INSTANCE case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND) - case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}") + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") } } @@ -109,6 +138,25 @@ object Arrow { } new Schema(arrowFields.toList.asJava) } + + /** + * Write an ArrowPayload to a byte array + */ + private[sql] def payloadToByteArray(payload: ArrowPayload, schema: StructType): Array[Byte] = { + val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) + val out = new ByteArrayOutputStream() + val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema) + try { + payload.foreach(writer.writeRecordBatch) + } catch { + case e: Exception => + throw e + } finally { + writer.close() + payload.foreach(_.close()) + } + out.toByteArray + } } private[sql] trait ColumnWriter { @@ -255,7 +303,7 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator) private[sql] class BinaryColumnWriter(allocator: BaseAllocator) extends PrimitiveColumnWriter(allocator) { override protected val valueVector: NullableVarBinaryVector - = new NullableVarBinaryVector("UTF8StringValue", allocator) + = new NullableVarBinaryVector("BinaryValue", allocator) override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) @@ -273,6 +321,7 @@ private[sql] class DateColumnWriter(allocator: BaseAllocator) override protected def setNull(): Unit = valueMutator.setNull(count) override protected def setValue(row: InternalRow, ordinal: Int): Unit = { + // TODO: comment on diff btw value representations of date/timestamp valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000) } } @@ -286,6 +335,7 @@ private[sql] class TimeStampColumnWriter(allocator: BaseAllocator) override protected def setNull(): Unit = valueMutator.setNull(count) override protected def setValue(row: InternalRow, ordinal: Int): Unit = { + // TODO: use microsecond timestamp when ARROW-477 is resolved valueMutator.setSafe(count, row.getLong(ordinal) / 1000) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ce3f3f1f2e52..37ba4e6feaca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql -import java.io.{ByteArrayOutputStream, CharArrayWriter} -import java.nio.channels.Channels +import java.io.CharArrayWriter import java.sql.{Date, Timestamp} import java.util.TimeZone @@ -27,9 +26,6 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.file.ArrowWriter -import org.apache.arrow.vector.schema.ArrowRecordBatch import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} @@ -2373,14 +2369,12 @@ class Dataset[T] private[sql]( * @since 2.2.0 */ @DeveloperApi - def collectAsArrow(rootAllocator: Option[RootAllocator] = None): ArrowRecordBatch = { - val allocator = rootAllocator.getOrElse(new RootAllocator(Long.MaxValue)) + def collectAsArrow(converter: Option[ArrowConverters] = None): ArrowPayload = { + val cnvtr = 
converter.getOrElse(new ArrowConverters) withNewExecutionId { try { val collectedRows = queryExecution.executedPlan.executeCollect() - val recordBatch = Arrow.internalRowsToArrowRecordBatch( - collectedRows, this.schema, allocator) - recordBatch + cnvtr.internalRowsToPayload(collectedRows, this.schema) } catch { case e: Exception => throw e @@ -2777,22 +2771,11 @@ class Dataset[T] private[sql]( * Collect a Dataset as an ArrowRecordBatch, and serve the ArrowRecordBatch to PySpark. */ private[sql] def collectAsArrowToPython(): Int = { - val recordBatch = collectAsArrow() - val arrowSchema = Arrow.schemaToArrowSchema(this.schema) - val out = new ByteArrayOutputStream() - try { - val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema) - writer.writeRecordBatch(recordBatch) - writer.close() - } catch { - case e: Exception => - throw e - } finally { - recordBatch.close() - } + val payload = collectAsArrow() + val payloadBytes = ArrowConverters.payloadToByteArray(payload, this.schema) withNewExecutionId { - PythonRDD.serveIterator(Iterator(out.toByteArray), "serve-Arrow") + PythonRDD.serveIterator(Iterator(payloadBytes), "serve-Arrow") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala similarity index 90% rename from sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala index c784b3eefb74..d4a6b6672e07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.sql import java.io.File import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.{Locale, TimeZone} +import java.util.Locale -import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.unsafe.types.CalendarInterval // NOTE - nullable type can be declared as Option[*] or java.lang.* @@ -38,7 +36,7 @@ private[sql] case class FloatData(i: Int, a_f: Float, b_f: Option[Float]) private[sql] case class DoubleData(i: Int, a_d: Double, b_d: Option[Double]) -class ArrowSuite extends SharedSQLContext { +class ArrowConvertersSuite extends SharedSQLContext { import testImplicits._ private def testFile(fileName: String): String = { @@ -46,10 +44,11 @@ class ArrowSuite extends SharedSQLContext { } test("collect to arrow record batch") { - val arrowRecordBatch = indexData.collectAsArrow() - assert(arrowRecordBatch.getLength > 0) - assert(arrowRecordBatch.getNodes.size() > 0) - arrowRecordBatch.close() + val arrowPayload = indexData.collectAsArrow() + assert(arrowPayload.nonEmpty) + arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getLength > 0)) + arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getNodes.size() > 0)) + arrowPayload.foreach(arrowRecordBatch => arrowRecordBatch.close()) } test("standard type conversion") { @@ -124,8 +123,9 @@ class ArrowSuite extends SharedSQLContext { } test("empty frame collect") { - val emptyBatch = spark.emptyDataFrame.collectAsArrow() - assert(emptyBatch.getLength == 0) + val arrowPayload = spark.emptyDataFrame.collectAsArrow() + assert(arrowPayload.nonEmpty) + arrowPayload.foreach(emptyBatch => 
assert(emptyBatch.getLength == 0)) } test("unsupported types") { @@ -163,17 +163,17 @@ class ArrowSuite extends SharedSQLContext { private def collectAndValidate(df: DataFrame, arrowFile: String) { val jsonFilePath = testFile(arrowFile) - val allocator = new RootAllocator(Integer.MAX_VALUE) - val jsonReader = new JsonFileReader(new File(jsonFilePath), allocator) + val converter = new ArrowConverters + val jsonReader = new JsonFileReader(new File(jsonFilePath), converter.allocator) - val arrowSchema = Arrow.schemaToArrowSchema(df.schema) + val arrowSchema = ArrowConverters.schemaToArrowSchema(df.schema) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) - val arrowRecordBatch = df.collectAsArrow(Some(allocator)) - val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator) + val arrowPayload = df.collectAsArrow(Some(converter)) + val arrowRoot = new VectorSchemaRoot(arrowSchema, converter.allocator) val vectorLoader = new VectorLoader(arrowRoot) - vectorLoader.load(arrowRecordBatch) + arrowPayload.foreach(vectorLoader.load) val jsonRoot = jsonReader.read() Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) From f44e6d74d728e430878586e0eec99c5ac6017e4e Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Mon, 30 Jan 2017 17:36:23 -0500 Subject: [PATCH 14/56] Adjust to cleaned up pyarrow FileReader API, support multiple record batches in a stream closes #22 --- python/pyspark/serializers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index c291786c8452..14af494c4cdd 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -192,9 +192,9 @@ def dumps(self, obj): raise NotImplementedError def loads(self, obj): - from pyarrow.ipc import ArrowFileReader - reader = ArrowFileReader(obj) - return reader.get_record_batch(0) + from pyarrow import FileReader, BufferReader + reader = FileReader(BufferReader(obj)) + return reader.read_all() def __repr__(self): return "ArrowSerializer" From e0bf11b3b7e9669974090c8bfde5fd8a2a0629c0 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 3 Feb 2017 14:53:58 -0800 Subject: [PATCH 15/56] changed conversion to use Iterator[InternalRow] instead of Array arrow conversion done at partition by executors some cleanup of APIs, made tests complete for non-complex data types closes #23 --- python/pyspark/sql/dataframe.py | 15 +- .../apache/spark/sql/ArrowConverters.scala | 309 ++++++++++-------- .../scala/org/apache/spark/sql/Dataset.scala | 41 +-- ...a2-ints.json => testData2-ints-part1.json} | 14 +- .../test-data/arrow/testData2-ints-part2.json | 50 +++ .../spark/sql/ArrowConvertersSuite.scala | 83 +++-- 6 files changed, 326 insertions(+), 186 deletions(-) rename sql/core/src/test/resources/test-data/arrow/{testData2-ints.json => testData2-ints-part1.json} (79%) create mode 100644 sql/core/src/test/resources/test-data/arrow/testData2-ints-part2.json diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index f9ad9cd6dfd1..030717fc78b7 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -393,11 +393,11 @@ def collect(self): @ignore_unicode_prefix @since(2.0) def collectAsArrow(self): - """Returns all the records as an ArrowRecordBatch + """Returns all records as list of deserialized ArrowPayloads """ with SCCallSiteSync(self._sc) as css: port = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(port, ArrowSerializer()))[0] + return 
list(_load_from_socket(port, ArrowSerializer())) @ignore_unicode_prefix @since(2.0) @@ -1611,6 +1611,9 @@ def toPandas(self, useArrow=False): This is only available if Pandas is installed and available. + :param useArrow: Make use of Apache Arrow for conversion, pyarrow must be installed + on the calling Python process. + .. note:: This method should only be used if the resulting Pandas's DataFrame is expected to be small, as all the data is loaded into the driver's memory. @@ -1619,11 +1622,13 @@ def toPandas(self, useArrow=False): 0 2 Alice 1 5 Bob """ - import pandas as pd - if useArrow: - return self.collectAsArrow().to_pandas() + from pyarrow.table import concat_tables + tables = self.collectAsArrow() + table = concat_tables(tables) + return table.to_pandas() else: + import pandas as pd return pd.DataFrame.from_records(self.collect(), columns=self.columns) ########################################################################################## diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala index 8a7379536ab5..47a2d966b0c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql import java.io.ByteArrayOutputStream -import java.nio.channels.Channels +import java.nio.ByteBuffer +import java.nio.channels.{SeekableByteChannel, Channels} import scala.collection.JavaConverters._ @@ -26,7 +27,7 @@ import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.{BaseAllocator, RootAllocator} import org.apache.arrow.vector._ import org.apache.arrow.vector.BaseValueVector.BaseMutator -import org.apache.arrow.vector.file.ArrowWriter +import org.apache.arrow.vector.file.{ArrowReader, ArrowWriter} import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} @@ -34,11 +35,65 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ + +/** + * ArrowReader requires a seekable byte channel. 
+ * NOTE - this is taken from test org.apache.vector.file, see about moving to public util pkg + */ +private[sql] class ByteArrayReadableSeekableByteChannel(var byteArray: Array[Byte]) + extends SeekableByteChannel { + var _position: Long = 0L + + override def isOpen: Boolean = { + byteArray != null + } + + override def close(): Unit = { + byteArray = null + } + + override def read(dst: ByteBuffer): Int = { + val remainingBuf = byteArray.length - _position + val length = Math.min(dst.remaining(), remainingBuf).toInt + dst.put(byteArray, _position.toInt, length) + _position += length + length.toInt + } + + override def position(): Long = _position + + override def position(newPosition: Long): SeekableByteChannel = { + _position = newPosition.toLong + this + } + + override def size: Long = { + byteArray.length.toLong + } + + override def write(src: ByteBuffer): Int = { + throw new UnsupportedOperationException("Read Only") + } + + override def truncate(size: Long): SeekableByteChannel = { + throw new UnsupportedOperationException("Read Only") + } +} + /** * Intermediate data structure returned from Arrow conversions */ private[sql] abstract class ArrowPayload extends Iterator[ArrowRecordBatch] +/** + * Build a payload from existing ArrowRecordBatches + */ +private[sql] class ArrowStaticPayload(batches: ArrowRecordBatch*) extends ArrowPayload { + private val iter = batches.iterator + override def next(): ArrowRecordBatch = iter.next() + override def hasNext: Boolean = iter.hasNext +} + /** * Class that wraps an Arrow RootAllocator used in conversion */ @@ -47,16 +102,24 @@ private[sql] class ArrowConverters { private[sql] def allocator: RootAllocator = _allocator - private class ArrowStaticPayload(batches: ArrowRecordBatch*) extends ArrowPayload { - private val iter = batches.iterator - - override def next(): ArrowRecordBatch = iter.next() - override def hasNext: Boolean = iter.hasNext + def interalRowIterToPayload(rowIter: Iterator[InternalRow], schema: StructType): ArrowPayload = { + val batch = ArrowConverters.internalRowIterToArrowBatch(rowIter, schema, allocator) + new ArrowStaticPayload(batch) } - def internalRowsToPayload(rows: Array[InternalRow], schema: StructType): ArrowPayload = { - val batch = ArrowConverters.internalRowsToArrowRecordBatch(rows, schema, allocator) - new ArrowStaticPayload(batch) + def readPayloadByteArrays(payloadByteArrays: Array[Array[Byte]]): ArrowPayload = { + val batches = scala.collection.mutable.ArrayBuffer.empty[ArrowRecordBatch] + var i = 0 + while (i < payloadByteArrays.length) { + val payloadBytes = payloadByteArrays(i) + val in = new ByteArrayReadableSeekableByteChannel(payloadBytes) + val reader = new ArrowReader(in, _allocator) + val footer = reader.readFooter() + val batchBlocks = footer.getRecordBatches.asScala.toArray + batchBlocks.foreach(block => batches += reader.readRecordBatch(block)) + i += 1 + } + new ArrowStaticPayload(batches: _*) } } @@ -83,52 +146,43 @@ private[sql] object ArrowConverters { } /** - * Transfer an array of InternalRow to an ArrowRecordBatch. + * Iterate over InternalRows and write to an ArrowRecordBatch. 
*/ - private[sql] def internalRowsToArrowRecordBatch( - rows: Array[InternalRow], + private def internalRowIterToArrowBatch( + rowIter: Iterator[InternalRow], schema: StructType, allocator: RootAllocator): ArrowRecordBatch = { - val fieldAndBuf = schema.fields.zipWithIndex.map { case (field, ordinal) => - internalRowToArrowBuf(rows, ordinal, field, allocator) + + val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => + ColumnWriter(ordinal, allocator, field.dataType) + .init() + } + + val writerLength = columnWriters.length + while (rowIter.hasNext) { + val row = rowIter.next() + var i = 0 + while (i < writerLength) { + columnWriters(i).write(row) + i += 1 + } + } + + val fieldAndBuf = columnWriters.map { writer => + writer.finish() }.unzip - val fieldNodes = fieldAndBuf._1.flatten + val fieldNodes = fieldAndBuf._1 val buffers = fieldAndBuf._2.flatten - val recordBatch = new ArrowRecordBatch(rows.length, + val rowLength = if(fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 + + val recordBatch = new ArrowRecordBatch(rowLength, fieldNodes.toList.asJava, buffers.toList.asJava) buffers.foreach(_.release()) recordBatch } - /** - * Write a Field from array of InternalRow to an ArrowBuf. - */ - private def internalRowToArrowBuf( - rows: Array[InternalRow], - ordinal: Int, - field: StructField, - allocator: RootAllocator): (Array[ArrowFieldNode], Array[ArrowBuf]) = { - val numOfRows = rows.length - val columnWriter = ColumnWriter(allocator, field.dataType) - columnWriter.init(numOfRows) - var index = 0 - - while(index < numOfRows) { - val row = rows(index) - if (row.isNullAt(ordinal)) { - columnWriter.writeNull() - } else { - columnWriter.write(row, ordinal) - } - index += 1 - } - - val (arrowFieldNodes, arrowBufs) = columnWriter.finish() - (arrowFieldNodes.toArray, arrowBufs.toArray) - } - /** * Convert a Spark Dataset schema to Arrow schema. */ @@ -160,138 +214,139 @@ private[sql] object ArrowConverters { } private[sql] trait ColumnWriter { - def init(initialSize: Int): Unit - def writeNull(): Unit - def write(row: InternalRow, ordinal: Int): Unit + def init(): this.type + def write(row: InternalRow): Unit /** * Clear the column writer and return the ArrowFieldNode and ArrowBuf. * This should be called only once after all the data is written. */ - def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) + def finish(): (ArrowFieldNode, Array[ArrowBuf]) } /** * Base class for flat arrow column writer, i.e., column without children. 
*/ -private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator) +private[sql] abstract class PrimitiveColumnWriter( + val ordinal: Int, + val allocator: BaseAllocator) extends ColumnWriter { - protected def valueVector: BaseDataValueVector - protected def valueMutator: BaseMutator + def valueVector: BaseDataValueVector + def valueMutator: BaseMutator - protected def setNull(): Unit - protected def setValue(row: InternalRow, ordinal: Int): Unit + def setNull(): Unit + def setValue(row: InternalRow, ordinal: Int): Unit protected var count = 0 protected var nullCount = 0 - override def init(initialSize: Int): Unit = { + override def init(): this.type = { valueVector.allocateNew() + this } - override def writeNull(): Unit = { - setNull() - nullCount += 1 - count += 1 - } - - override def write(row: InternalRow, ordinal: Int): Unit = { - setValue(row, ordinal) + override def write(row: InternalRow): Unit = { + if (row.isNullAt(ordinal)) { + setNull() + nullCount += 1 + } else { + setValue(row, ordinal) + } count += 1 } - override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = { + override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { valueMutator.setValueCount(count) val fieldNode = new ArrowFieldNode(count, nullCount) - val valueBuffers: Seq[ArrowBuf] = valueVector.getBuffers(true) - (List(fieldNode), valueBuffers) + val valueBuffers = valueVector.getBuffers(true) + (fieldNode, valueBuffers) } } -private[sql] class BooleanColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { +private[sql] class BooleanColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { private def bool2int(b: Boolean): Int = if (b) 1 else 0 - override protected val valueVector: NullableBitVector + override val valueVector: NullableBitVector = new NullableBitVector("BooleanValue", allocator) - override protected val valueMutator: NullableBitVector#Mutator = valueVector.getMutator + override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal))) } -private[sql] class ShortColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableSmallIntVector +private[sql] class ShortColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableSmallIntVector = new NullableSmallIntVector("ShortValue", allocator) - override protected val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator + override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getShort(ordinal)) } -private[sql] class IntegerColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableIntVector +private[sql] class IntegerColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableIntVector = new NullableIntVector("IntValue", allocator) - override protected val valueMutator: NullableIntVector#Mutator = valueVector.getMutator + override val valueMutator: 
NullableIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getInt(ordinal)) } -private[sql] class LongColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableBigIntVector +private[sql] class LongColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableBigIntVector = new NullableBigIntVector("LongValue", allocator) - override protected val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator + override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getLong(ordinal)) } -private[sql] class FloatColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableFloat4Vector +private[sql] class FloatColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableFloat4Vector = new NullableFloat4Vector("FloatValue", allocator) - override protected val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator + override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getFloat(ordinal)) } -private[sql] class DoubleColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableFloat8Vector +private[sql] class DoubleColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableFloat8Vector = new NullableFloat8Vector("DoubleValue", allocator) - override protected val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator + override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getDouble(ordinal)) } -private[sql] class ByteColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableUInt1Vector +private[sql] class ByteColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableUInt1Vector = new NullableUInt1Vector("ByteValue", allocator) - override protected val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator + override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getByte(ordinal)) } -private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableVarBinaryVector +private[sql] class UTF8StringColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val 
valueVector: NullableVarBinaryVector = new NullableVarBinaryVector("UTF8StringValue", allocator) - override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator + override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = { @@ -300,11 +355,11 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator) } } -private[sql] class BinaryColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableVarBinaryVector +private[sql] class BinaryColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableVarBinaryVector = new NullableVarBinaryVector("BinaryValue", allocator) - override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator + override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = { @@ -313,47 +368,45 @@ private[sql] class BinaryColumnWriter(allocator: BaseAllocator) } } -private[sql] class DateColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableDateVector +private[sql] class DateColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableDateVector = new NullableDateVector("DateValue", allocator) - override protected val valueMutator: NullableDateVector#Mutator = valueVector.getMutator + override val valueMutator: NullableDateVector#Mutator = valueVector.getMutator - override protected def setNull(): Unit = valueMutator.setNull(count) - override protected def setValue(row: InternalRow, ordinal: Int): Unit = { + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit = { // TODO: comment on diff btw value representations of date/timestamp valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000) } } -private[sql] class TimeStampColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableTimeStampVector - = new NullableTimeStampVector("TimeStampValue", allocator) - override protected val valueMutator: NullableTimeStampVector#Mutator = valueVector.getMutator +private[sql] class TimeStampColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableTimeStampMicroVector + = new NullableTimeStampMicroVector("TimeStampValue", allocator) + override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator - override protected def setNull(): Unit = valueMutator.setNull(count) - - override protected def setValue(row: InternalRow, ordinal: Int): Unit = { - // TODO: use microsecond timestamp when ARROW-477 is resolved + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit = { valueMutator.setSafe(count, row.getLong(ordinal) / 1000) } } private[sql] object ColumnWriter { - def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = { + def apply(ordinal: Int, allocator: BaseAllocator, dataType: 
DataType): ColumnWriter = { dataType match { - case BooleanType => new BooleanColumnWriter(allocator) - case ShortType => new ShortColumnWriter(allocator) - case IntegerType => new IntegerColumnWriter(allocator) - case LongType => new LongColumnWriter(allocator) - case FloatType => new FloatColumnWriter(allocator) - case DoubleType => new DoubleColumnWriter(allocator) - case ByteType => new ByteColumnWriter(allocator) - case StringType => new UTF8StringColumnWriter(allocator) - case BinaryType => new BinaryColumnWriter(allocator) - case DateType => new DateColumnWriter(allocator) - case TimestampType => new TimeStampColumnWriter(allocator) + case BooleanType => new BooleanColumnWriter(ordinal, allocator) + case ShortType => new ShortColumnWriter(ordinal, allocator) + case IntegerType => new IntegerColumnWriter(ordinal, allocator) + case LongType => new LongColumnWriter(ordinal, allocator) + case FloatType => new FloatColumnWriter(ordinal, allocator) + case DoubleType => new DoubleColumnWriter(ordinal, allocator) + case ByteType => new ByteColumnWriter(ordinal, allocator) + case StringType => new UTF8StringColumnWriter(ordinal, allocator) + case BinaryType => new BinaryColumnWriter(ordinal, allocator) + case DateType => new DateColumnWriter(ordinal, allocator) + case TimestampType => new TimeStampColumnWriter(ordinal, allocator) case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 37ba4e6feaca..afceec6f8f74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2362,26 +2362,6 @@ class Dataset[T] private[sql]( java.util.Arrays.asList(values : _*) } - /** - * Collect a Dataset to an ArrowRecordBatch. - * - * @group action - * @since 2.2.0 - */ - @DeveloperApi - def collectAsArrow(converter: Option[ArrowConverters] = None): ArrowPayload = { - val cnvtr = converter.getOrElse(new ArrowConverters) - withNewExecutionId { - try { - val collectedRows = queryExecution.executedPlan.executeCollect() - cnvtr.internalRowsToPayload(collectedRows, this.schema) - } catch { - case e: Exception => - throw e - } - } - } - /** * Return an iterator that contains all rows in this Dataset. * @@ -2768,14 +2748,13 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as an ArrowRecordBatch, and serve the ArrowRecordBatch to PySpark. + * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. 
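+   * Each partition is converted to a single ArrowPayload byte array by
+   * toArrowPayloadBytes() below; the collected byte arrays are then served to the
+   * Python side over a local socket via PythonRDD.serveIterator.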
*/ private[sql] def collectAsArrowToPython(): Int = { - val payload = collectAsArrow() - val payloadBytes = ArrowConverters.payloadToByteArray(payload, this.schema) - + val payloadRdd = toArrowPayloadBytes() + val payloadByteArrays = payloadRdd.collect() withNewExecutionId { - PythonRDD.serveIterator(Iterator(payloadBytes), "serve-Arrow") + PythonRDD.serveIterator(payloadByteArrays.iterator, "serve-Arrow") } } @@ -2860,4 +2839,16 @@ class Dataset[T] private[sql]( Dataset(sparkSession, logicalPlan) } } + + /** Convert to an RDD of ArrowPayload byte arrays */ + private[sql] def toArrowPayloadBytes(): RDD[Array[Byte]] = { + val schema_captured = this.schema + queryExecution.toRdd.mapPartitionsInternal { iter => + val converter = new ArrowConverters + val payload = converter.interalRowIterToPayload(iter, schema_captured) + val payloadBytes = ArrowConverters.payloadToByteArray(payload, schema_captured) + payload.foreach(_.close()) + Iterator(payloadBytes) + } + } } diff --git a/sql/core/src/test/resources/test-data/arrow/testData2-ints.json b/sql/core/src/test/resources/test-data/arrow/testData2-ints-part1.json similarity index 79% rename from sql/core/src/test/resources/test-data/arrow/testData2-ints.json rename to sql/core/src/test/resources/test-data/arrow/testData2-ints-part1.json index 6edc2a030287..bf6f0a38a332 100644 --- a/sql/core/src/test/resources/test-data/arrow/testData2-ints.json +++ b/sql/core/src/test/resources/test-data/arrow/testData2-ints-part1.json @@ -30,19 +30,19 @@ "batches": [ { - "count": 6, + "count": 3, "columns": [ { "name": "a", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 1, 2, 2, 3, 3] + "count": 3, + "VALIDITY": [1, 1, 1], + "DATA": [1, 1, 2] }, { "name": "b", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 1, 2, 1, 2] + "count": 3, + "VALIDITY": [1, 1, 1], + "DATA": [1, 2, 1] } ] } diff --git a/sql/core/src/test/resources/test-data/arrow/testData2-ints-part2.json b/sql/core/src/test/resources/test-data/arrow/testData2-ints-part2.json new file mode 100644 index 000000000000..5261d51ff218 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/testData2-ints-part2.json @@ -0,0 +1,50 @@ +{ + "schema": { + "fields": [ + { + "name": "a", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 3, + "columns": [ + { + "name": "a", + "count": 3, + "VALIDITY": [1, 1, 1], + "DATA": [2, 3, 3] + }, + { + "name": "b", + "count": 3, + "VALIDITY": [1, 1, 1], + "DATA": [2, 1, 2] + } + ] + } + ] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala index d4a6b6672e07..e0497b855e03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala @@ -24,8 +24,10 @@ import java.util.Locale import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator +import 
org.apache.spark.SparkException import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType // NOTE - nullable type can be declared as Option[*] or java.lang.* @@ -39,16 +41,26 @@ private[sql] case class DoubleData(i: Int, a_d: Double, b_d: Option[Double]) class ArrowConvertersSuite extends SharedSQLContext { import testImplicits._ + private def collectAsArrow(df: DataFrame, + converter: Option[ArrowConverters] = None): ArrowPayload = { + val cnvtr = converter.getOrElse(new ArrowConverters) + val payloadByteArrays = df.toArrowPayloadBytes().collect() + cnvtr.readPayloadByteArrays(payloadByteArrays) + } + private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).getFile } test("collect to arrow record batch") { - val arrowPayload = indexData.collectAsArrow() + val arrowPayload = collectAsArrow(indexData) assert(arrowPayload.nonEmpty) - arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getLength > 0)) - arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getNodes.size() > 0)) - arrowPayload.foreach(arrowRecordBatch => arrowRecordBatch.close()) + val arrowBatches = arrowPayload.toArray + assert(arrowBatches.length == indexData.rdd.getNumPartitions) + val rowCount = arrowBatches.map(batch => batch.getLength).sum + assert(rowCount === indexData.count()) + arrowBatches.foreach(batch => assert(batch.getNodes.size() > 0)) + arrowBatches.foreach(batch => batch.close()) } test("standard type conversion") { @@ -82,7 +94,16 @@ class ArrowConvertersSuite extends SharedSQLContext { } test("partitioned DataFrame") { - collectAndValidate(testData2, "test-data/arrow/testData2-ints.json") + val converter = new ArrowConverters + val schema = testData2.schema + val arrowPayload = collectAsArrow(testData2, Some(converter)) + val arrowBatches = arrowPayload.toArray + // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload + assert(arrowBatches.length === 2) + val pl1 = new ArrowStaticPayload(arrowBatches(0)) + val pl2 = new ArrowStaticPayload(arrowBatches(1)) + validateConversion(schema, pl1,"test-data/arrow/testData2-ints-part1.json", Some(converter)) + validateConversion(schema, pl2,"test-data/arrow/testData2-ints-part2.json", Some(converter)) } test("string type conversion") { @@ -105,11 +126,14 @@ class ArrowConvertersSuite extends SharedSQLContext { collectAndValidate(binaryData, "test-data/arrow/binaryData.json") } - test("nested type conversion") { } + // Type not yet supported + ignore("nested type conversion") { } - test("array type conversion") { } + // Type not yet supported + ignore("array type conversion") { } - test("mapped type conversion") { } + // Type not yet supported + ignore("mapped type conversion") { } test("floating-point NaN") { val nanData = Seq((1, 1.2F, Double.NaN), (2, Float.NaN, 1.23)).toDF("i", "NaN_f", "NaN_d") @@ -123,22 +147,32 @@ class ArrowConvertersSuite extends SharedSQLContext { } test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.collectAsArrow() - assert(arrowPayload.nonEmpty) - arrowPayload.foreach(emptyBatch => assert(emptyBatch.getLength == 0)) + val arrowPayload = collectAsArrow(spark.emptyDataFrame) + assert(arrowPayload.isEmpty) + } + + test("empty partition collect") { + val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") + val arrowPayload = collectAsArrow(emptyPart) + val arrowBatches = arrowPayload.toArray + assert(arrowBatches.length === 2) + assert(arrowBatches.count(_.getLength == 
0) === 1) + assert(arrowBatches.count(_.getLength == 1) === 1) } test("unsupported types") { def runUnsupported(block: => Unit): Unit = { - val msg = intercept[UnsupportedOperationException] { + val msg = intercept[SparkException] { block } assert(msg.getMessage.contains("Unsupported data type")) + assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { - collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json") - } + runUnsupported { collectAsArrow(decimalData) } + runUnsupported { collectAsArrow(arrayData.toDF()) } + runUnsupported { collectAsArrow(mapData.toDF()) } + runUnsupported { collectAsArrow(complexData) } } test("test Arrow Validator") { @@ -160,22 +194,29 @@ class ArrowConvertersSuite extends SharedSQLContext { } /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(df: DataFrame, arrowFile: String) { - val jsonFilePath = testFile(arrowFile) - + private def collectAndValidate(df: DataFrame, arrowFile: String): Unit = { val converter = new ArrowConverters + // NOTE: coalesce to single partition because can only load 1 batch in validator + val arrowPayload = collectAsArrow(df.coalesce(1), Some(converter)) + validateConversion(df.schema, arrowPayload, arrowFile, Some(converter)) + } + + private def validateConversion(sparkSchema: StructType, + arrowPayload: ArrowPayload, + arrowFile: String, + converterOpt: Option[ArrowConverters] = None): Unit = { + val converter = converterOpt.getOrElse(new ArrowConverters) + val jsonFilePath = testFile(arrowFile) val jsonReader = new JsonFileReader(new File(jsonFilePath), converter.allocator) - val arrowSchema = ArrowConverters.schemaToArrowSchema(df.schema) + val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) - val arrowPayload = df.collectAsArrow(Some(converter)) val arrowRoot = new VectorSchemaRoot(arrowSchema, converter.allocator) val vectorLoader = new VectorLoader(arrowRoot) arrowPayload.foreach(vectorLoader.load) val jsonRoot = jsonReader.read() - Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) } From 3090a3eeb0f8efb68187ee870c0fb3ea46b9f46e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 21 Feb 2017 17:44:57 -0800 Subject: [PATCH 16/56] Changed tests to use generated JSON data instead of files --- .../test-data/arrow/allNulls-ints.json | 32 -- .../resources/test-data/arrow/boolData.json | 32 -- .../resources/test-data/arrow/byteData.json | 32 -- .../arrow/decimalData-BigDecimal.json | 50 -- .../doubleData-double_precision-nullable.json | 68 --- .../floatData-single_precision-nullable.json | 68 --- .../test-data/arrow/indexData-ints.json | 32 -- .../arrow/intData-32bit_ints-nullable.json | 68 --- .../test-data/arrow/largeAndSmall-ints.json | 50 -- .../arrow/longData-64bit_ints-nullable.json | 68 --- .../test-data/arrow/lowercase-strings.json | 52 -- .../arrow/mixedData-standard-nullable.json | 212 ------- .../arrow/nanData-floating_point.json | 68 --- .../test-data/arrow/null-ints-mixed.json | 50 -- .../resources/test-data/arrow/null-ints.json | 32 -- .../test-data/arrow/null-strings.json | 34 -- .../test-data/arrow/salary-doubles.json | 50 -- .../arrow/shortData-16bit_ints-nullable.json | 68 --- .../test-data/arrow/testData2-ints-part1.json | 50 -- .../test-data/arrow/testData2-ints-part2.json | 50 -- .../test-data/arrow/timestampData.json | 32 -- .../test-data/arrow/uppercase-strings.json | 
52 -- .../spark/sql/ArrowConvertersSuite.scala | 537 +++++++++++++----- 23 files changed, 410 insertions(+), 1377 deletions(-) delete mode 100644 sql/core/src/test/resources/test-data/arrow/allNulls-ints.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/boolData.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/byteData.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/decimalData-BigDecimal.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/doubleData-double_precision-nullable.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/floatData-single_precision-nullable.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/indexData-ints.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/intData-32bit_ints-nullable.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/largeAndSmall-ints.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/longData-64bit_ints-nullable.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/lowercase-strings.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/mixedData-standard-nullable.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/null-ints-mixed.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/null-ints.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/null-strings.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/salary-doubles.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/shortData-16bit_ints-nullable.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/testData2-ints-part1.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/testData2-ints-part2.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/timestampData.json delete mode 100644 sql/core/src/test/resources/test-data/arrow/uppercase-strings.json diff --git a/sql/core/src/test/resources/test-data/arrow/allNulls-ints.json b/sql/core/src/test/resources/test-data/arrow/allNulls-ints.json deleted file mode 100644 index e12f546e461c..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/allNulls-ints.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "a", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 4, - "columns": [ - { - "name": "a", - "count": 4, - "VALIDITY": [0, 0, 0, 0], - "DATA": [0, 0, 0, 0] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/boolData.json b/sql/core/src/test/resources/test-data/arrow/boolData.json deleted file mode 100644 index f402e5118cef..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/boolData.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "a_bool", - "type": {"name": "bool"}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - } - ] - }, - - "batches": [ - { - "count": 4, - "columns": [ - { - "name": "a_bool", - "count": 4, - "VALIDITY": [1, 1, 1, 1], - "DATA": [true, true, false, true] - } - ] - } - ] -} diff --git 
a/sql/core/src/test/resources/test-data/arrow/byteData.json b/sql/core/src/test/resources/test-data/arrow/byteData.json deleted file mode 100644 index d0a6ceb818f7..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/byteData.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "a_byte", - "type": {"name": "int", "isSigned": true, "bitWidth": 8}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - } - ] - }, - - "batches": [ - { - "count": 4, - "columns": [ - { - "name": "a_byte", - "count": 4, - "VALIDITY": [1, 1, 1, 1], - "DATA": [1, -1, 64, 127] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/decimalData-BigDecimal.json b/sql/core/src/test/resources/test-data/arrow/decimalData-BigDecimal.json deleted file mode 100644 index 8449acaab23d..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/decimalData-BigDecimal.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "a", - "type": {"name": "floatingpoint", "precision": "DOUBLE"}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b", - "type": {"name": "floatingpoint", "precision": "DOUBLE"}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 6, - "columns": [ - { - "name": "a", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 1, 2, 2, 3, 3] - }, - { - "name": "b", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 1, 2, 1, 2] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/doubleData-double_precision-nullable.json b/sql/core/src/test/resources/test-data/arrow/doubleData-double_precision-nullable.json deleted file mode 100644 index d29b9ed6a2c9..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/doubleData-double_precision-nullable.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - }, - { - "name": "a_d", - "type": {"name": "floatingpoint", "precision": "DOUBLE"}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b_d", - "type": {"name": "floatingpoint", "precision": "DOUBLE"}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 6, - "columns": [ - { - "name": "i", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 3, 4, 5, 6] - }, - { - "name": "a_d", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0] - }, - { - "name": "b_d", - "count": 6, - "VALIDITY": [1, 0, 0, 1, 0, 1], - "DATA": [1.1, 0, 0, 2.2, 0, 3.3] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/floatData-single_precision-nullable.json 
b/sql/core/src/test/resources/test-data/arrow/floatData-single_precision-nullable.json deleted file mode 100644 index 9d686d1367a6..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/floatData-single_precision-nullable.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - }, - { - "name": "a_f", - "type": {"name": "floatingpoint", "precision": "SINGLE"}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b_f", - "type": {"name": "floatingpoint", "precision": "SINGLE"}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 6, - "columns": [ - { - "name": "i", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 3, 4, 5, 6] - }, - { - "name": "a_f", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0] - }, - { - "name": "b_f", - "count": 6, - "VALIDITY": [1, 0, 0, 1, 0, 1], - "DATA": [1.1, 0, 0, 2.2, 0, 3.3] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/indexData-ints.json b/sql/core/src/test/resources/test-data/arrow/indexData-ints.json deleted file mode 100644 index e96945d8b7ac..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/indexData-ints.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - } - ] - }, - - "batches": [ - { - "count": 6, - "columns": [ - { - "name": "i", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 3, 4, 5, 6] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/intData-32bit_ints-nullable.json b/sql/core/src/test/resources/test-data/arrow/intData-32bit_ints-nullable.json deleted file mode 100644 index 049b30cf4a3c..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/intData-32bit_ints-nullable.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - }, - { - "name": "a_i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b_i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 6, - "columns": [ - { - "name": "i", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 3, 4, 5, 6] - }, - { - "name": "a_i", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 
-1, 2, -2, 2147483647, -2147483648] - }, - { - "name": "b_i", - "count": 6, - "VALIDITY": [1, 0, 0, 1, 0, 1], - "DATA": [1, -1, 2, -2, 2147483647, -2147483648] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/largeAndSmall-ints.json b/sql/core/src/test/resources/test-data/arrow/largeAndSmall-ints.json deleted file mode 100644 index e2f15e865626..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/largeAndSmall-ints.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "a", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 6, - "columns": [ - { - "name": "a", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [2147483644, 1, 2147483645, 2, 2147483646, 3] - }, - { - "name": "b", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 1, 2, 1, 2] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/longData-64bit_ints-nullable.json b/sql/core/src/test/resources/test-data/arrow/longData-64bit_ints-nullable.json deleted file mode 100644 index a6bd5f002b05..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/longData-64bit_ints-nullable.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - }, - { - "name": "a_l", - "type": {"name": "int", "isSigned": true, "bitWidth": 64}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b_l", - "type": {"name": "int", "isSigned": true, "bitWidth": 64}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 6, - "columns": [ - { - "name": "i", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 3, 4, 5, 6] - }, - { - "name": "a_l", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, -1, 2, -2, 9223372036854775807, -9223372036854775808] - }, - { - "name": "b_l", - "count": 6, - "VALIDITY": [1, 0, 0, 1, 0, 1], - "DATA": [1, -1, 2, -2, 9223372036854775807, -9223372036854775808] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/lowercase-strings.json b/sql/core/src/test/resources/test-data/arrow/lowercase-strings.json deleted file mode 100644 index 356c431a671e..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/lowercase-strings.json +++ /dev/null @@ -1,52 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "n", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - }, - { - "name": "l", - "type": {"name": "utf8"}, - "nullable": true, - 
"children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "OFFSET", "typeBitWidth": 32}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - } - ] - }, - - "batches": [ - { - "count": 4, - "columns": [ - { - "name": "n", - "count": 4, - "VALIDITY": [1, 1, 1, 1], - "DATA": [1, 2, 3, 4] - }, - { - "name": "l", - "count": 4, - "VALIDITY": [1, 1, 1, 1], - "OFFSET": [0, 1, 2, 3, 4], - "DATA": ["a", "b", "c", "d"] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/mixedData-standard-nullable.json b/sql/core/src/test/resources/test-data/arrow/mixedData-standard-nullable.json deleted file mode 100644 index 2d7921001eb7..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/mixedData-standard-nullable.json +++ /dev/null @@ -1,212 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - }, - { - "name": "a_s", - "type": {"name": "int", "isSigned": true, "bitWidth": 16}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b_s", - "type": {"name": "int", "isSigned": true, "bitWidth": 16}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "a_i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b_i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "a_l", - "type": {"name": "int", "isSigned": true, "bitWidth": 64}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b_l", - "type": {"name": "int", "isSigned": true, "bitWidth": 64}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "a_f", - "type": {"name": "floatingpoint", "precision": "SINGLE"}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b_f", - "type": {"name": "floatingpoint", "precision": "SINGLE"}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "a_d", - "type": {"name": "floatingpoint", "precision": "DOUBLE"}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b_d", - "type": {"name": "floatingpoint", "precision": "DOUBLE"}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 
32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 6, - "columns": [ - { - "name": "i", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 3, 4, 5, 6] - }, - { - "name": "a_s", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, -1, 2, -2, 32767, -32768] - }, - { - "name": "b_s", - "count": 6, - "VALIDITY": [1, 0, 0, 1, 0, 1], - "DATA": [1, -1, 2, -2, 32767, -32768] - }, - { - "name": "a_i", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, -1, 2, -2, 2147483647, -2147483648] - }, - { - "name": "b_i", - "count": 6, - "VALIDITY": [1, 0, 0, 1, 0, 1], - "DATA": [1, -1, 2, -2, 2147483647, -2147483648] - }, - { - "name": "a_l", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, -1, 2, -2, 9223372036854775807, -9223372036854775808] - }, - { - "name": "b_l", - "count": 6, - "VALIDITY": [1, 0, 0, 1, 0, 1], - "DATA": [1, -1, 2, -2, 9223372036854775807, -9223372036854775808] - }, - { - "name": "a_f", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0] - }, - { - "name": "b_f", - "count": 6, - "VALIDITY": [1, 0, 0, 1, 0, 1], - "DATA": [1.1, 0, 0, 2.2, 0, 3.3] - }, - { - "name": "a_d", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0] - }, - { - "name": "b_d", - "count": 6, - "VALIDITY": [1, 0, 0, 1, 0, 1], - "DATA": [1.1, 0, 0, 2.2, 0, 3.3] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json b/sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json deleted file mode 100644 index 4a8407d45f37..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - }, - { - "name": "NaN_f", - "type": {"name": "floatingpoint", "precision": "SINGLE"}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "NaN_d", - "type": {"name": "floatingpoint", "precision": "DOUBLE"}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 2, - "columns": [ - { - "name": "i", - "count": 2, - "VALIDITY": [1, 1], - "DATA": [1, 2] - }, - { - "name": "NaN_f", - "count": 2, - "VALIDITY": [1, 1], - "DATA": [1.2, "NaN"] - }, - { - "name": "NaN_d", - "count": 2, - "VALIDITY": [1, 1], - "DATA": ["NaN", 1.23] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/null-ints-mixed.json b/sql/core/src/test/resources/test-data/arrow/null-ints-mixed.json deleted file mode 100644 index a82ba623f539..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/null-ints-mixed.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "a", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": true, - "children": [], - 
"typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 2, - "columns": [ - { - "name": "a", - "count": 2, - "VALIDITY": [1, 1], - "DATA": [1, 2] - }, - { - "name": "b", - "count": 2, - "VALIDITY": [0, 1], - "DATA": [0, 2] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/null-ints.json b/sql/core/src/test/resources/test-data/arrow/null-ints.json deleted file mode 100644 index 1a2447abdc0b..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/null-ints.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "a", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 4, - "columns": [ - { - "name": "a", - "count": 4, - "VALIDITY": [1, 1, 1, 0], - "DATA": [1, 2, 3, 0] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/null-strings.json b/sql/core/src/test/resources/test-data/arrow/null-strings.json deleted file mode 100644 index c93e1e757bc5..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/null-strings.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "s", - "type": {"name": "utf8"}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "OFFSET", "typeBitWidth": 32}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - } - ] - }, - - "batches": [ - { - "count": 3, - "columns": [ - { - "name": "s", - "count": 3, - "VALIDITY": [1, 1, 0], - "OFFSET": [0, 3, 6, 6], - "DATA": ["abc", "ABC", ""] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/salary-doubles.json b/sql/core/src/test/resources/test-data/arrow/salary-doubles.json deleted file mode 100644 index 2cc42182a56c..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/salary-doubles.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "personId", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "salary", - "type": {"name": "floatingpoint", "precision": "DOUBLE"}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 2, - "columns": [ - { - "name": "personId", - "count": 2, - "VALIDITY": [1, 1], - "DATA": [0, 1] - }, - { - "name": "salary", - "count": 2, - "VALIDITY": [1, 1], - "DATA": [2000.0, 1000.0] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/shortData-16bit_ints-nullable.json b/sql/core/src/test/resources/test-data/arrow/shortData-16bit_ints-nullable.json deleted file mode 100644 index ca04de5b0ea3..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/shortData-16bit_ints-nullable.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "i", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - 
}, - { - "name": "a_s", - "type": {"name": "int", "isSigned": true, "bitWidth": 16}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b_s", - "type": {"name": "int", "isSigned": true, "bitWidth": 16}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 6, - "columns": [ - { - "name": "i", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 3, 4, 5, 6] - }, - { - "name": "a_s", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, -1, 2, -2, 32767, -32768] - }, - { - "name": "b_s", - "count": 6, - "VALIDITY": [1, 0, 0, 1, 0, 1], - "DATA": [1, -1, 2, -2, 32767, -32768] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/testData2-ints-part1.json b/sql/core/src/test/resources/test-data/arrow/testData2-ints-part1.json deleted file mode 100644 index bf6f0a38a332..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/testData2-ints-part1.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "a", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 3, - "columns": [ - { - "name": "a", - "count": 3, - "VALIDITY": [1, 1, 1], - "DATA": [1, 1, 2] - }, - { - "name": "b", - "count": 3, - "VALIDITY": [1, 1, 1], - "DATA": [1, 2, 1] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/testData2-ints-part2.json b/sql/core/src/test/resources/test-data/arrow/testData2-ints-part2.json deleted file mode 100644 index 5261d51ff218..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/testData2-ints-part2.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "a", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - }, - { - "name": "b", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 32} - ] - } - } - ] - }, - - "batches": [ - { - "count": 3, - "columns": [ - { - "name": "a", - "count": 3, - "VALIDITY": [1, 1, 1], - "DATA": [2, 3, 3] - }, - { - "name": "b", - "count": 3, - "VALIDITY": [1, 1, 1], - "DATA": [2, 1, 2] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/timestampData.json b/sql/core/src/test/resources/test-data/arrow/timestampData.json deleted file mode 100644 index 6fe59975954d..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/timestampData.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "c_timestamp", - "type": {"name": "timestamp", "unit": "MILLISECOND"}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - 
{"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 64} - ] - } - } - ] - }, - - "batches": [ - { - "count": 2, - "columns": [ - { - "name": "c_timestamp", - "count": 2, - "VALIDITY": [1, 1], - "DATA": [1365383415567, 1365426610789] - } - ] - } - ] -} diff --git a/sql/core/src/test/resources/test-data/arrow/uppercase-strings.json b/sql/core/src/test/resources/test-data/arrow/uppercase-strings.json deleted file mode 100644 index b6016022e314..000000000000 --- a/sql/core/src/test/resources/test-data/arrow/uppercase-strings.json +++ /dev/null @@ -1,52 +0,0 @@ -{ - "schema": { - "fields": [ - { - "name": "N", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": false, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - }, - { - "name": "L", - "type": {"name": "utf8"}, - "nullable": true, - "children": [], - "typeLayout": { - "vectors": [ - {"type": "VALIDITY", "typeBitWidth": 1}, - {"type": "OFFSET", "typeBitWidth": 32}, - {"type": "DATA", "typeBitWidth": 8} - ] - } - } - ] - }, - - "batches": [ - { - "count": 6, - "columns": [ - { - "name": "N", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 3, 4, 5, 6] - }, - { - "name": "L", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "OFFSET": [0, 1, 2, 3, 4, 5, 6], - "DATA": ["A", "B", "C", "D", "E", "F"] - } - ] - } - ] -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala index e0497b855e03..51e2455d8bcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala @@ -17,30 +17,31 @@ package org.apache.spark.sql import java.io.File +import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Locale +import com.google.common.io.Files import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator -import org.apache.spark.SparkException +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkException import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils -// NOTE - nullable type can be declared as Option[*] or java.lang.* -private[sql] case class ShortData(i: Int, a_s: Short, b_s: Option[Short]) -private[sql] case class IntData(i: Int, a_i: Int, b_i: Option[Int]) -private[sql] case class LongData(i: Int, a_l: Long, b_l: java.lang.Long) -private[sql] case class FloatData(i: Int, a_f: Float, b_f: Option[Float]) -private[sql] case class DoubleData(i: Int, a_d: Double, b_d: Option[Double]) - - -class ArrowConvertersSuite extends SharedSQLContext { +class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { import testImplicits._ + private var tempDataPath: String = _ + private def collectAsArrow(df: DataFrame, converter: Option[ArrowConverters] = None): ArrowPayload = { val cnvtr = converter.getOrElse(new ArrowConverters) @@ -48,11 +49,13 @@ class ArrowConvertersSuite extends SharedSQLContext { cnvtr.readPayloadByteArrays(payloadByteArrays) } - private def testFile(fileName: String): String = { - 
Thread.currentThread().getContextClassLoader.getResource(fileName).getFile + override def beforeAll(): Unit = { + super.beforeAll() + tempDataPath = Utils.createTempDir(namePrefix = "arrow").getAbsolutePath } test("collect to arrow record batch") { + val indexData = (1 to 6).toDF("i") val arrowPayload = collectAsArrow(indexData) assert(arrowPayload.nonEmpty) val arrowBatches = arrowPayload.toArray @@ -63,87 +66,76 @@ class ArrowConvertersSuite extends SharedSQLContext { arrowBatches.foreach(batch => batch.close()) } - test("standard type conversion") { - collectAndValidate(indexData, "test-data/arrow/indexData-ints.json") - collectAndValidate(largeAndSmallInts, "test-data/arrow/largeAndSmall-ints.json") - collectAndValidate(salary, "test-data/arrow/salary-doubles.json") + test("numeric type conversion") { + collectAndValidate(indexData) + collectAndValidate(shortData) + collectAndValidate(intData) + collectAndValidate(longData) + collectAndValidate(floatData) + collectAndValidate(doubleData) } - test("standard type nullable conversion") { - collectAndValidate(shortData, "test-data/arrow/shortData-16bit_ints-nullable.json") - collectAndValidate(intData, "test-data/arrow/intData-32bit_ints-nullable.json") - collectAndValidate(longData, "test-data/arrow/longData-64bit_ints-nullable.json") - collectAndValidate(floatData, "test-data/arrow/floatData-single_precision-nullable.json") - collectAndValidate(doubleData, "test-data/arrow/doubleData-double_precision-nullable.json") + test("mixed numeric type conversion") { + collectAndValidate(mixedNumericData) } test("boolean type conversion") { - val boolData = Seq(true, true, false, true).toDF("a_bool") - collectAndValidate(boolData, "test-data/arrow/boolData.json") + collectAndValidate(boolData) } - test("byte type conversion") { - val byteData = Seq(1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte") - collectAndValidate(byteData, "test-data/arrow/byteData.json") - } - - test("mixed standard type nullable conversion") { - val mixedData = Seq(shortData, intData, longData, floatData, doubleData) - .reduce((a, b) => a.join(b, "i")).sort("i") - collectAndValidate(mixedData, "test-data/arrow/mixedData-standard-nullable.json") + test("string type conversion") { + collectAndValidate(stringData) } - test("partitioned DataFrame") { - val converter = new ArrowConverters - val schema = testData2.schema - val arrowPayload = collectAsArrow(testData2, Some(converter)) - val arrowBatches = arrowPayload.toArray - // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload - assert(arrowBatches.length === 2) - val pl1 = new ArrowStaticPayload(arrowBatches(0)) - val pl2 = new ArrowStaticPayload(arrowBatches(1)) - validateConversion(schema, pl1,"test-data/arrow/testData2-ints-part1.json", Some(converter)) - validateConversion(schema, pl2,"test-data/arrow/testData2-ints-part2.json", Some(converter)) + test("byte type conversion") { + collectAndValidate(byteData) } - test("string type conversion") { - collectAndValidate(upperCaseData, "test-data/arrow/uppercase-strings.json") - collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json") - val nullStringsColOnly = nullStrings.select(nullStrings.columns(1)) - collectAndValidate(nullStringsColOnly, "test-data/arrow/null-strings.json") + test("timestamp conversion") { + collectAndValidate(timestampData) } + // TODO: Not currently supported in Arrow JSON reader ignore("date conversion") { - collectAndValidate(dateTimeData, "test-data/arrow/datetimeData-strings.json") - } - - 
test("timestamp conversion") { - collectAndValidate(dateTimeData.select($"c_timestamp"), "test-data/arrow/timestampData.json") + // collectAndValidate(dateTimeData) } - // Arrow json reader doesn't support binary data + // TODO: Not currently supported in Arrow JSON reader ignore("binary type conversion") { - collectAndValidate(binaryData, "test-data/arrow/binaryData.json") + // collectAndValidate(binaryData) } - // Type not yet supported - ignore("nested type conversion") { } - - // Type not yet supported - ignore("array type conversion") { } - - // Type not yet supported - ignore("mapped type conversion") { } - test("floating-point NaN") { - val nanData = Seq((1, 1.2F, Double.NaN), (2, Float.NaN, 1.23)).toDF("i", "NaN_f", "NaN_d") - collectAndValidate(nanData, "test-data/arrow/nanData-floating_point.json") + collectAndValidate(floatNaNData) } - test("convert int column with null to arrow") { - collectAndValidate(nullInts, "test-data/arrow/null-ints.json") - collectAndValidate(testData3, "test-data/arrow/null-ints-mixed.json") - collectAndValidate(allNulls, "test-data/arrow/allNulls-ints.json") + test("partitioned DataFrame") { + val converter = new ArrowConverters + val schema = testData2.schema + val arrowPayload = collectAsArrow(testData2, Some(converter)) + val arrowBatches = arrowPayload.toArray + // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload + assert(arrowBatches.length === 2) + val pl1 = new ArrowStaticPayload(arrowBatches(0)) + val pl2 = new ArrowStaticPayload(arrowBatches(1)) + // Generate JSON files + val a = List[Int](1, 1, 2, 2, 3, 3) + val b = List[Int](1, 2, 1, 2, 1, 2) + val fields = Seq(new IntegerType("a", is_signed = true, 32, nullable = false), + new IntegerType("b", is_signed = true, 32, nullable = false)) + def getBatch(x: Seq[Int], y: Seq[Int]): JSONRecordBatch = { + val columns = Seq(new PrimitiveColumn("a", x.length, x.map(_ => true), x), + new PrimitiveColumn("b", y.length, y.map(_ => true), y)) + new JSONRecordBatch(x.length, columns) + } + val json1 = new JSONFile(new JSONSchema(fields), Seq(getBatch(a.take(3), b.take(3)))) + val json2 = new JSONFile(new JSONSchema(fields), Seq(getBatch(a.takeRight(3), b.takeRight(3)))) + val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") + val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") + json1.write(tempFile1) + json2.write(tempFile2) + validateConversion(schema, pl1, tempFile1, Some(converter)) + validateConversion(schema, pl2, tempFile2, Some(converter)) } test("empty frame collect") { @@ -176,38 +168,36 @@ class ArrowConvertersSuite extends SharedSQLContext { } test("test Arrow Validator") { - - // Missing test file - intercept[NullPointerException] { - collectAndValidate(indexData, "test-data/arrow/missing-file") - } + val sdata = shortData + val idata = intData // Different schema intercept[IllegalArgumentException] { - collectAndValidate(shortData, "test-data/arrow/intData-32bit_ints-nullable.json") + collectAndValidate(DataTuple(sdata.df, idata.json, idata.file)) } // Different values intercept[IllegalArgumentException] { - collectAndValidate(indexData.sort($"i".desc), "test-data/arrow/indexData-ints.json") + collectAndValidate(DataTuple(idata.df.sort($"a_i".desc), idata.json, idata.file)) } } /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(df: DataFrame, arrowFile: String): Unit = { + private def collectAndValidate(data: DataTuple): Unit = { val converter = new ArrowConverters 
// NOTE: coalesce to single partition because can only load 1 batch in validator - val arrowPayload = collectAsArrow(df.coalesce(1), Some(converter)) - validateConversion(df.schema, arrowPayload, arrowFile, Some(converter)) + val arrowPayload = collectAsArrow(data.df.coalesce(1), Some(converter)) + val tempFile = new File(tempDataPath, data.file) + data.json.write(tempFile) + validateConversion(data.df.schema, arrowPayload, tempFile, Some(converter)) } private def validateConversion(sparkSchema: StructType, arrowPayload: ArrowPayload, - arrowFile: String, + jsonFile: File, converterOpt: Option[ArrowConverters] = None): Unit = { val converter = converterOpt.getOrElse(new ArrowConverters) - val jsonFilePath = testFile(arrowFile) - val jsonReader = new JsonFileReader(new File(jsonFilePath), converter.allocator) + val jsonReader = new JsonFileReader(jsonFile, converter.allocator) val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) val jsonSchema = jsonReader.start() @@ -220,65 +210,358 @@ class ArrowConvertersSuite extends SharedSQLContext { Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) } - protected lazy val indexData = Seq(1, 2, 3, 4, 5, 6).toDF("i") + // Create Spark DataFrame and matching Arrow JSON at same time for validation + private case class DataTuple(df: DataFrame, json: JSONFile, file: String) - protected lazy val shortData: DataFrame = { - spark.sparkContext.parallelize( - ShortData(1, 1, Some(1)) :: - ShortData(2, -1, None) :: - ShortData(3, 2, None) :: - ShortData(4, -2, Some(-2)) :: - ShortData(5, 32767, None) :: - ShortData(6, -32768, Some(-32768)) :: Nil).toDF() + private def indexData: DataTuple = { + val data = List[Int](1, 2, 3, 4, 5, 6) + val fields = Seq(new IntegerType("i", is_signed = true, 32, nullable = false)) + val schema = new JSONSchema(fields) + val columns = Seq(new PrimitiveColumn("i", data.length, data.map(_ => true), data)) + val batch = new JSONRecordBatch(data.length, columns) + DataTuple(data.toDF("i"), new JSONFile(schema, Seq(batch)), "indexData-ints.json") + } + + private def shortData: DataTuple = { + val a_s = List[Short](1, -1, 2, -2, 32767, -32768) + val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768)) + val fields = Seq(new IntegerType("a_s", is_signed = true, 16, nullable = false), + new IntegerType("b_s", is_signed = true, 16, nullable = true)) + val schema = new JSONSchema(fields) + val b_s_values = b_s.map(_.map(_.toInt).getOrElse(0)) + val columns = Seq( + new PrimitiveColumn("a_s", a_s.length, a_s.map(_ => true), a_s.map(_.toInt)), + new PrimitiveColumn("b_s", b_s.length, b_s.map(_.isDefined), b_s_values)) + val batch = new JSONRecordBatch(a_s.length, columns) + val df = a_s.zip(b_s).toDF("a_s", "b_s") + DataTuple(df, new JSONFile(schema, Seq(batch)), "integer-16bit.json") + } + + private def intData: DataTuple = { + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val fields = Seq(new IntegerType("a_i", is_signed = true, 32, nullable = false), + new IntegerType("b_i", is_signed = true, 32, nullable = true)) + val schema = new JSONSchema(fields) + val columns = Seq( + new PrimitiveColumn("a_i", a_i.length, a_i.map(_ => true), a_i), + new PrimitiveColumn("b_i", b_i.length, b_i.map(_.isDefined), b_i.map(_.getOrElse(0)))) + val batch = new JSONRecordBatch(a_i.length, columns) + val df = a_i.zip(b_i).toDF("a_i", "b_i") + DataTuple(df, new JSONFile(schema, Seq(batch)), "integer-32bit.json") + } + + 
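  // A minimal sketch (not part of this patch) of how the same DataTuple pattern would
  // cover a boolean column; it assumes a JSON helper class for booleans analogous to
  // IntegerType/FloatingPointType and that PrimitiveColumn accepts boolean values.
  // The patch's own boolData fixture is defined elsewhere in this suite.
  //
  //   private def boolDataSketch: DataTuple = {
  //     val a_b = List(true, true, false, true)
  //     val fields = Seq(new BooleanType("a_bool", nullable = false))  // assumed helper
  //     val schema = new JSONSchema(fields)
  //     val columns = Seq(
  //       new PrimitiveColumn("a_bool", a_b.length, a_b.map(_ => true), a_b))
  //     val batch = new JSONRecordBatch(a_b.length, columns)
  //     DataTuple(a_b.toDF("a_bool"), new JSONFile(schema, Seq(batch)), "boolData.json")
  //   }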
private def longData: DataTuple = { + val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L) + val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L)) + val fields = Seq(new IntegerType("a_l", is_signed = true, 64, nullable = false), + new IntegerType("b_l", is_signed = true, 64, nullable = true)) + val schema = new JSONSchema(fields) + val columns = Seq( + new PrimitiveColumn("a_l", a_l.length, a_l.map(_ => true), a_l), + new PrimitiveColumn("b_l", b_l.length, b_l.map(_.isDefined), b_l.map(_.getOrElse(0L)))) + val batch = new JSONRecordBatch(a_l.length, columns) + val df = a_l.zip(b_l).toDF("a_l", "b_l") + DataTuple(df, new JSONFile(schema, Seq(batch)), "integer-64bit.json") + } + + private def floatData: DataTuple = { + val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f) + val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f)) + val fields = Seq(new FloatingPointType("a_f", 32, nullable = false), + new FloatingPointType("b_f", 32, nullable = true)) + val schema = new JSONSchema(fields) + val columns = Seq(new PrimitiveColumn("a_f", a_f.length, a_f.map(_ => true), a_f), + new PrimitiveColumn("b_f", b_f.length, b_f.map(_.isDefined), b_f.map(_.getOrElse(0.0f)))) + val batch = new JSONRecordBatch(a_f.length, columns) + val df = a_f.zip(b_f).toDF("a_f", "b_f") + DataTuple(df, new JSONFile(schema, Seq(batch)), "floating_point-single_precision.json") + } + + private def doubleData: DataTuple = { + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0) + val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3)) + val fields = Seq(new FloatingPointType("a_d", 64, nullable = false), + new FloatingPointType("b_d", 64, nullable = true)) + val schema = new JSONSchema(fields) + val columns = Seq(new PrimitiveColumn("a_d", a_d.length, a_d.map(_ => true), a_d), + new PrimitiveColumn("b_d", b_d.length, b_d.map(_.isDefined), b_d.map(_.getOrElse(0.0)))) + val batch = new JSONRecordBatch(a_d.length, columns) + val df = a_d.zip(b_d).toDF("a_d", "b_d") + DataTuple(df, new JSONFile(schema, Seq(batch)), "floating_point-double_precision.json") + } + + private def mixedNumericData: DataTuple = { + val data = List(1, 2, 3, 4, 5, 6) + val fields = Seq(new IntegerType("a", is_signed = true, 16, nullable = false), + new FloatingPointType("b", 32, nullable = false), + new IntegerType("c", is_signed = true, 32, nullable = false), + new FloatingPointType("d", 64, nullable = false), + new IntegerType("e", is_signed = true, 64, nullable = false)) + val schema = new JSONSchema(fields) + val columns = Seq(new PrimitiveColumn("a", data.length, data.map(_ => true), data), + new PrimitiveColumn("b", data.length, data.map(_ => true), data.map(_.toFloat)), + new PrimitiveColumn("c", data.length, data.map(_ => true), data), + new PrimitiveColumn("d", data.length, data.map(_ => true), data.map(_.toDouble)), + new PrimitiveColumn("e", data.length, data.map(_ => true), data) + ) + val batch = new JSONRecordBatch(data.length, columns) + val data_tuples = for (d <- data) yield { + (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong) + } + val df = data_tuples.toDF("a", "b", "c", "d", "e") + DataTuple(df, new JSONFile(schema, Seq(batch)), "mixed_numeric_types.json") } - protected lazy val intData: DataFrame = { - spark.sparkContext.parallelize( - IntData(1, 1, Some(1)) :: - IntData(2, -1, None) :: - IntData(3, 2, None) :: - IntData(4, -2, Some(-2)) :: - IntData(5, 2147483647, None) :: - IntData(6, -2147483648, 
Some(-2147483648)) :: Nil).toDF() + private def boolData: DataTuple = { + val data = Seq(true, true, false, true) + val fields = Seq(new BooleanType("a_bool", nullable = false)) + val schema = new JSONSchema(fields) + val columns = Seq(new PrimitiveColumn("a_bool", data.length, data.map(_ => true), data)) + val batch = new JSONRecordBatch(data.length, columns) + DataTuple(data.toDF("a_bool"), new JSONFile(schema, Seq(batch)), "boolData.json") } - protected lazy val longData: DataFrame = { - spark.sparkContext.parallelize( - LongData(1, 1L, 1L) :: - LongData(2, -1L, null) :: - LongData(3, 2L, null) :: - LongData(4, -2, -2L) :: - LongData(5, 9223372036854775807L, null) :: - LongData(6, -9223372036854775808L, -9223372036854775808L) :: Nil).toDF() + private def stringData: DataTuple = { + val upperCase = Seq("A", "B", "C") + val lowerCase = Seq("a", "b", "c") + val nullStr = Seq("ab", "CDE", null) + val fields = Seq(new StringType("upper_case", nullable = true), + new StringType("lower_case", nullable = true), + new StringType("null_str", nullable = true)) + val schema = new JSONSchema(fields) + val columns = Seq( + new StringColumn("upper_case", upperCase.length, upperCase.map(_ => true), upperCase), + new StringColumn("lower_case", lowerCase.length, lowerCase.map(_ => true), lowerCase), + new StringColumn("null_str", nullStr.length, nullStr.map(_ != null), + nullStr.map { s => if (s == null) "" else s} + )) + val batch = new JSONRecordBatch(upperCase.length, columns) + val df = (upperCase, lowerCase, nullStr).zipped.toList + .toDF("upper_case", "lower_case", "null_str") + DataTuple(df, new JSONFile(schema, Seq(batch)), "stringData.json") } - protected lazy val floatData: DataFrame = { - spark.sparkContext.parallelize( - FloatData(1, 1.0f, Some(1.1f)) :: - FloatData(2, 2.0f, None) :: - FloatData(3, 0.01f, None) :: - FloatData(4, 200.0f, Some(2.2f)) :: - FloatData(5, 0.0001f, None) :: - FloatData(6, 20000.0f, Some(3.3f)) :: Nil).toDF() + private def byteData: DataTuple = { + val data = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue) + val fields = Seq(new IntegerType("a_byte", is_signed = true, 8, nullable = false)) + val schema = new JSONSchema(fields) + val columns = Seq( + new PrimitiveColumn("a_byte", data.length, data.map(_ => true), data.map(_.toInt))) + val batch = new JSONRecordBatch(data.length, columns) + DataTuple(data.toDF("a_byte"), new JSONFile(schema, Seq(batch)), "byteData.json") } - protected lazy val doubleData: DataFrame = { - spark.sparkContext.parallelize( - DoubleData(1, 1.0, Some(1.1)) :: - DoubleData(2, 2.0, None) :: - DoubleData(3, 0.01, None) :: - DoubleData(4, 200.0, Some(2.2)) :: - DoubleData(5, 0.0001, None) :: - DoubleData(6, 20000.0, Some(3.3)) :: Nil).toDF() + private def floatNaNData: DataTuple = { + val fnan = Seq(1.2F, Float.NaN) + val dnan = Seq(Double.NaN, 1.2) + val fields = Seq(new FloatingPointType("NaN_f", 32, nullable = false), + new FloatingPointType("NaN_d", 64, nullable = false)) + val schema = new JSONSchema(fields) + val columns = Seq(new PrimitiveColumn("NaN_f", fnan.length, fnan.map(_ => true), fnan), + new PrimitiveColumn("NaN_d", dnan.length, dnan.map(_ => true), dnan)) + val batch = new JSONRecordBatch(fnan.length, columns) + val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") + DataTuple(df, new JSONFile(schema, Seq(batch)), "nanData-floating_point.json") } - protected lazy val dateTimeData: DataFrame = { + private def timestampData: DataTuple = { + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val ts1 = new 
Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) + val data = Seq(ts1, ts2) + val schema = new JSONSchema(Seq(new TimestampType("c_timestamp"))) + val columns = Seq( + new PrimitiveColumn("c_timestamp", data.length, data.map(_ => true), data.map(_.getTime))) + val batch = new JSONRecordBatch(data.length, columns) + DataTuple(data.toDF("c_timestamp"), new JSONFile(schema, Seq(batch)), "timestampData.json") + } + + private def dateData: DataTuple = { val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2)) + val df = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2)) .toDF("a_date", "b_string", "c_timestamp") + val jsonFile = new JSONFile(new JSONSchema(Seq.empty[DataType]), Seq.empty[JSONRecordBatch]) + DataTuple(df, jsonFile, "dateData.json") + } + + /** + * Arrow JSON Format Data Generation + * Referenced from https://github.com/apache/arrow/blob/master/integration/integration_test.py + */ + + private abstract class DataType(name: String, nullable: Boolean) { + def _get_type: JObject + def _get_type_layout: JField + def _get_children: JArray + def get_json: JObject = { + JObject( + "name" -> name, + "type" -> _get_type, + "nullable" -> nullable, + "children" -> _get_children, + "typeLayout" -> _get_type_layout) + } + } + + private abstract class Column(name: String, count: Int) { + def _get_children: JArray + def _get_buffers: JObject + def get_json: JObject = { + val entries = JObject( + "name" -> name, + "count" -> count + ).merge(_get_buffers) + + val children = _get_children + if (children.arr.nonEmpty) entries.merge(JObject("children" -> children)) else entries + } + } + + private abstract class PrimitiveType(name: String, nullable: Boolean) + extends DataType(name, nullable) { + val bit_width: Int + override def _get_children: JArray = JArray(List.empty) + override def _get_type_layout: JField = { + JField("vectors", JArray(List( + JObject("type" -> "VALIDITY", "typeBitWidth" -> 1), + JObject("type" -> "DATA", "typeBitWidth" -> bit_width) + ))) + } + } + + private class PrimitiveColumn[T <% JValue](name: String, + count: Int, + is_valid: Seq[Boolean], + values: Seq[T]) + extends Column(name, count) { + override def _get_children: JArray = JArray(List.empty) + override def _get_buffers: JObject = { + JObject( + "VALIDITY" -> is_valid.map(b => if (b) 1 else 0), + "DATA" -> values) + } + } + + private class IntegerType(name: String, + is_signed: Boolean, + override val bit_width: Int, + nullable: Boolean) + extends PrimitiveType(name, nullable = nullable) { + override def _get_type: JObject = { + JObject( + "name" -> "int", + "isSigned" -> is_signed, + "bitWidth" -> bit_width) + } + } + + private class FloatingPointType(name: String, override val bit_width: Int, nullable: Boolean) + extends PrimitiveType(name, nullable = nullable) { + override def _get_type: JObject = { + val precision = bit_width match { + case 16 => "HALF" + case 32 => "SINGLE" + case 64 => "DOUBLE" + } + JObject( + "name" -> "floatingpoint", + "precision" -> precision) + } + } + + private class BooleanType(name: String, nullable: Boolean) + extends PrimitiveType(name, 
nullable = nullable) { + override val bit_width = 1 + override def _get_type: JObject = JObject("name" -> JString("bool")) + } + + private class BinaryType(name: String, nullable: Boolean) + extends PrimitiveType(name, nullable = nullable) { + override val bit_width = 8 + override def _get_type: JObject = JObject("name" -> JString("binary")) + override def _get_type_layout: JField = { + JField("vectors", JArray(List( + JObject("type" -> "VALIDITY", "typeBitWidth" -> 1), + JObject("type" -> "OFFSET", "typeBitWidth" -> 32), + JObject("type" -> "DATA", "typeBitWidth" -> bit_width) + ))) + } + } + + private class StringType(name: String, nullable: Boolean) + extends BinaryType(name, nullable = nullable) { + override def _get_type: JObject = JObject("name" -> JString("utf8")) + } + + private class TimestampType(name: String) extends PrimitiveType(name, nullable = true) { + override val bit_width = 64 + override def _get_type: JObject = { + JObject( + "name" -> "timestamp", + "unit" -> "MILLISECOND") + } + } + + private class JSONSchema(fields: Seq[DataType]) { + def get_json: JObject = { + JObject("fields" -> JArray(fields.map(_.get_json).toList)) + } + } + + private class BinaryColumn(name: String, + count: Int, + is_valid: Seq[Boolean], + values: Seq[String]) + extends PrimitiveColumn(name, count, is_valid, values) { + def _encode_value(v: String): String = { + v.map(c => String.format("%h", c.toString)).reduce(_ + _) + } + override def _get_buffers: JObject = { + var offset = 0 + val offsets = scala.collection.mutable.ArrayBuffer[Int](offset) + val data = values.zip(is_valid).map { case (value, isval) => + if (isval) offset += value.length + val element = _encode_value(if (isval) value else "") + offsets += offset + element + } + JObject( + "VALIDITY" -> is_valid.map(b => if (b) 1 else 0), + "OFFSET" -> offsets, + "DATA" -> data) + } + } + + private class StringColumn(name: String, + count: Int, + is_valid: Seq[Boolean], + values: Seq[String]) + extends BinaryColumn(name, count, is_valid, values) { + override def _encode_value(v: String): String = v + } + + private class JSONRecordBatch(count: Int, columns: Seq[Column]) { + def get_json: JObject = { + JObject( + "count" -> count, + "columns" -> columns.map(_.get_json)) + } + } + + private class JSONFile(schema: JSONSchema, batches: Seq[JSONRecordBatch]) { + def get_json: JObject = { + JObject( + "schema" -> schema.get_json, + "batches" -> batches.map(_.get_json)) + } + def write(file: File): Unit = { + val json = pretty(render(get_json)) + Files.write(json, file, StandardCharsets.UTF_8) + } } } From 54884ed502ad779edf212bb5652afe46f28bd921 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 22 Feb 2017 16:21:00 -0800 Subject: [PATCH 17/56] updated Arrow artifacts to 0.2.0 release --- pom.xml | 2 +- .../main/scala/org/apache/spark/sql/ArrowConverters.scala | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pom.xml b/pom.xml index 17f6e5492871..47dfea20b61c 100644 --- a/pom.xml +++ b/pom.xml @@ -184,7 +184,7 @@ 2.6 1.8 1.0.0 - 0.1.1-SNAPSHOT + 0.2.0 ${java.home} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala index 47a2d966b0c7..04e3697f8d28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.io.ByteArrayOutputStream import java.nio.ByteBuffer -import 
java.nio.channels.{SeekableByteChannel, Channels} +import java.nio.channels.{Channels, SeekableByteChannel} import scala.collection.JavaConverters._ @@ -174,8 +174,7 @@ private[sql] object ArrowConverters { val fieldNodes = fieldAndBuf._1 val buffers = fieldAndBuf._2.flatten - val rowLength = if(fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 - + val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 val recordBatch = new ArrowRecordBatch(rowLength, fieldNodes.toList.asJava, buffers.toList.asJava) From 42af1d59c6c8e6fbb0f38cd8d1b1afbe86076aa1 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 22 Feb 2017 17:03:24 -0800 Subject: [PATCH 18/56] fixed python style checks --- python/pyspark/serializers.py | 1 - python/pyspark/sql/dataframe.py | 5 +++-- python/pyspark/sql/tests.py | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 14af494c4cdd..386fab77429f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -183,7 +183,6 @@ def loads(self, obj): class ArrowSerializer(FramedSerializer): - """ Serializes an Arrow stream. """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 030717fc78b7..9ea0033157e5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -27,7 +27,8 @@ from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ + UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -391,7 +392,7 @@ def collect(self): return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix - @since(2.0) + @since(2.2) def collectAsArrow(self): """Returns all records as list of deserialized ArrowPayloads """ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 575dc3f20078..28ce47072b1e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2347,6 +2347,7 @@ def range_frame_match(): importlib.reload(window) + @unittest.skipIf(not _have_arrow, "Arrow not installed") class ArrowTests(ReusedPySparkTestCase): From 9c8ea63ccec4ceb9dd40834bca530a6265385579 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 22 Feb 2017 17:17:38 -0800 Subject: [PATCH 19/56] updated dependency manifest --- dev/deps/spark-deps-hadoop-2.6 | 10 ++++++++++ dev/deps/spark-deps-hadoop-2.7 | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 73dc1f9a1398..277bc6f44ef2 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar +arrow-format-0.2.0.jar +arrow-memory-0.2.0.jar +arrow-vector-0.2.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.6.5.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar 
hk2-utils-2.4.0-b34.jar +hppc-0.7.1.jar htrace-core-3.0.4.jar httpclient-4.5.2.jar httpcore-4.4.4.jar @@ -139,6 +144,11 @@ minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.9.9.Final.jar netty-all-4.0.43.Final.jar +netty-buffer-4.0.41.Final.jar +netty-codec-4.0.41.Final.jar +netty-common-4.0.41.Final.jar +netty-handler-4.0.41.Final.jar +netty-transport-4.0.41.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 6bf0923a1d75..341db5e78daf 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar +arrow-format-0.2.0.jar +arrow-memory-0.2.0.jar +arrow-vector-0.2.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar +hppc-0.7.1.jar htrace-core-3.1.0-incubating.jar httpclient-4.5.2.jar httpcore-4.4.4.jar @@ -140,6 +145,11 @@ minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.9.9.Final.jar netty-all-4.0.43.Final.jar +netty-buffer-4.0.41.Final.jar +netty-codec-4.0.41.Final.jar +netty-common-4.0.41.Final.jar +netty-handler-4.0.41.Final.jar +netty-transport-4.0.41.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar From b7c28ad19cc56553c8cdb53f98f92f189c9d7a27 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 23 Feb 2017 16:29:30 -0800 Subject: [PATCH 20/56] test format fix for python 2.6 --- python/pyspark/sql/tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 28ce47072b1e..da75714adaf9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2388,8 +2388,8 @@ def test_toPandas_arrow_toggle(self): def test_pandas_round_trip(self): import pandas as pd - data_dict = {name: [self.data[i][j] for i in range(len(self.data))] - for j, name in enumerate(self.schema.names)} + names = [(j, name) for j, name in enumerate(self.schema.names)] + data_dict = {name: [self.data[i][j] for i in range(len(self.data))] for j, name in names} pdf = pd.DataFrame(data=data_dict) pdf_arrow = self.spark.createDataFrame(pdf).toPandas(useArrow=True) self.assertFramesEqual(pdf_arrow, pdf) From 2851cd6ea9e37a9fa439bc1a7510ba62ea90b110 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 28 Feb 2017 14:03:57 -0800 Subject: [PATCH 21/56] fixed docstrings and added list of pyarrow supported types --- python/pyspark/sql/dataframe.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9ea0033157e5..438dbd78f55a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -391,10 +391,13 @@ def collect(self): port = self._jdf.collectToPython() return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) - @ignore_unicode_prefix @since(2.2) def collectAsArrow(self): - """Returns all records as list of deserialized ArrowPayloads + """ + Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed + and available. + + .. note:: Experimental. 
""" with SCCallSiteSync(self._sc) as css: port = self._jdf.collectAsArrowToPython() @@ -1608,16 +1611,21 @@ def toDF(self, *cols): @since(1.3) def toPandas(self, useArrow=False): - """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + """ + Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. This is only available if Pandas is installed and available. :param useArrow: Make use of Apache Arrow for conversion, pyarrow must be installed - on the calling Python process. + and available on the calling Python process (Experimental). .. note:: This method should only be used if the resulting Pandas's DataFrame is expected to be small, as all the data is loaded into the driver's memory. + .. note:: Using pyarrow is experimental and currently supports the following data types: + StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, + LongType, ShortType + >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice From f8f24abeb4fd3394a8aa4136b1dda6b6ae4cd323 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 3 Mar 2017 14:38:03 -0800 Subject: [PATCH 22/56] fixed memory leak of ArrowRecordBatch iterator getting consumed and batches not closed properly --- .../org/apache/spark/sql/ArrowConverters.scala | 15 +++++++-------- .../main/scala/org/apache/spark/sql/Dataset.scala | 1 - 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala index 04e3697f8d28..bfce48a485aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala @@ -199,15 +199,14 @@ private[sql] object ArrowConverters { val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) val out = new ByteArrayOutputStream() val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema) - try { - payload.foreach(writer.writeRecordBatch) - } catch { - case e: Exception => - throw e - } finally { - writer.close() - payload.foreach(_.close()) + payload.foreach { batch => + try { + writer.writeRecordBatch(batch) + } finally { + batch.close() + } } + writer.close() out.toByteArray } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index afceec6f8f74..18abd99ec578 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2847,7 +2847,6 @@ class Dataset[T] private[sql]( val converter = new ArrowConverters val payload = converter.interalRowIterToPayload(iter, schema_captured) val payloadBytes = ArrowConverters.payloadToByteArray(payload, schema_captured) - payload.foreach(_.close()) Iterator(payloadBytes) } } From b6c752b0fcc19f2c5795593d9b4a1511f574c83d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 3 Mar 2017 14:46:50 -0800 Subject: [PATCH 23/56] changed _collectAsArrow to private method --- python/pyspark/sql/dataframe.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 438dbd78f55a..5f733f4d96db 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -391,18 +391,6 @@ def collect(self): port = self._jdf.collectToPython() return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) - @since(2.2) 
- def collectAsArrow(self): - """ - Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed - and available. - - .. note:: Experimental. - """ - with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(port, ArrowSerializer())) - @ignore_unicode_prefix @since(2.0) def toLocalIterator(self): @@ -1633,13 +1621,24 @@ def toPandas(self, useArrow=False): """ if useArrow: from pyarrow.table import concat_tables - tables = self.collectAsArrow() + tables = self._collectAsArrow() table = concat_tables(tables) return table.to_pandas() else: import pandas as pd return pd.DataFrame.from_records(self.collect(), columns=self.columns) + def _collectAsArrow(self): + """ + Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed + and available. + + .. note:: Experimental. + """ + with SCCallSiteSync(self._sc) as css: + port = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(port, ArrowSerializer())) + ########################################################################################## # Pandas compatibility ########################################################################################## From cbab294cbf291bd0082139e6ca988909b06bdd1a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 3 Mar 2017 15:07:03 -0800 Subject: [PATCH 24/56] added netty to exclusion list for arrow dependency --- dev/deps/spark-deps-hadoop-2.6 | 5 ----- dev/deps/spark-deps-hadoop-2.7 | 5 ----- pom.xml | 4 ++++ 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 277bc6f44ef2..5b88a9433a32 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -144,11 +144,6 @@ minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.9.9.Final.jar netty-all-4.0.43.Final.jar -netty-buffer-4.0.41.Final.jar -netty-codec-4.0.41.Final.jar -netty-common-4.0.41.Final.jar -netty-handler-4.0.41.Final.jar -netty-transport-4.0.41.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 341db5e78daf..7959eefde172 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -145,11 +145,6 @@ minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.9.9.Final.jar netty-all-4.0.43.Final.jar -netty-buffer-4.0.41.Final.jar -netty-codec-4.0.41.Final.jar -netty-common-4.0.41.Final.jar -netty-handler-4.0.41.Final.jar -netty-transport-4.0.41.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/pom.xml b/pom.xml index 47dfea20b61c..b48cf5d6e33e 100644 --- a/pom.xml +++ b/pom.xml @@ -1889,6 +1889,10 @@ org.slf4j log4j-over-slf4j + + io.netty + netty-handler + From 44ca3ffc37711225a64fd57dca2ba747debf8c3d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 7 Mar 2017 10:21:46 -0800 Subject: [PATCH 25/56] dict comprehensions not supported in python 2.6 --- python/pyspark/sql/tests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index da75714adaf9..c59cc019c020 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2388,8 +2388,9 @@ def test_toPandas_arrow_toggle(self): def test_pandas_round_trip(self): import pandas as pd - names = [(j, name) for j, name in enumerate(self.schema.names)] - data_dict = {name: [self.data[i][j] for i in range(len(self.data))] for j, name in names} + data_dict = {} + for j, name in 
enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] pdf = pd.DataFrame(data=data_dict) pdf_arrow = self.spark.createDataFrame(pdf).toPandas(useArrow=True) self.assertFramesEqual(pdf_arrow, pdf) From 33b75b99b375292e1f4691983e37919d0a620725 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 10 Mar 2017 13:12:45 -0800 Subject: [PATCH 26/56] ensure payload batches are closed if any exception is thrown, some minor cleanup --- .../apache/spark/sql/ArrowConverters.scala | 58 +++++++++++-------- .../scala/org/apache/spark/sql/Dataset.scala | 1 + .../spark/sql/ArrowConvertersSuite.scala | 1 + 3 files changed, 35 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala index bfce48a485aa..07ca26d3f88f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala @@ -34,11 +34,12 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * ArrowReader requires a seekable byte channel. - * NOTE - this is taken from test org.apache.vector.file, see about moving to public util pkg + * NOTE - this is taken from test org.apache.arrow.vector.file, see about moving to public util pkg */ private[sql] class ByteArrayReadableSeekableByteChannel(var byteArray: Array[Byte]) extends SeekableByteChannel { @@ -57,7 +58,7 @@ private[sql] class ByteArrayReadableSeekableByteChannel(var byteArray: Array[Byt val length = Math.min(dst.remaining(), remainingBuf).toInt dst.put(byteArray, _position.toInt, length) _position += length - length.toInt + length } override def position(): Long = _position @@ -121,6 +122,10 @@ private[sql] class ArrowConverters { } new ArrowStaticPayload(batches: _*) } + + def close(): Unit = { + _allocator.close() + } } private[sql] object ArrowConverters { @@ -168,11 +173,8 @@ private[sql] object ArrowConverters { } } - val fieldAndBuf = columnWriters.map { writer => - writer.finish() - }.unzip - val fieldNodes = fieldAndBuf._1 - val buffers = fieldAndBuf._2.flatten + val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip + val buffers = bufferArrays.flatten val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 val recordBatch = new ArrowRecordBatch(rowLength, @@ -199,14 +201,20 @@ private[sql] object ArrowConverters { val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) val out = new ByteArrayOutputStream() val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema) - payload.foreach { batch => - try { + + // Iterate over payload batches to write each one, ensure all batches get closed + var batch: ArrowRecordBatch = null + Utils.tryWithSafeFinallyAndFailureCallbacks { + while (payload.hasNext) { + batch = payload.next() writer.writeRecordBatch(batch) - } finally { batch.close() } - } - writer.close() + }(catchBlock = { + Option(batch).foreach(_.close()) + payload.foreach(_.close()) + }, finallyBlock = writer.close()) + out.toByteArray } } @@ -233,7 +241,7 @@ private[sql] abstract class PrimitiveColumnWriter( def valueMutator: BaseMutator def setNull(): Unit - def setValue(row: InternalRow, ordinal: Int): Unit + def setValue(row: InternalRow): Unit protected var count = 0 protected var nullCount = 0 @@ -248,7 +256,7 @@ private[sql] abstract class 
PrimitiveColumnWriter( setNull() nullCount += 1 } else { - setValue(row, ordinal) + setValue(row) } count += 1 } @@ -270,7 +278,7 @@ private[sql] class BooleanColumnWriter(ordinal: Int, allocator: BaseAllocator) override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow, ordinal: Int): Unit + override def setValue(row: InternalRow): Unit = valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal))) } @@ -281,7 +289,7 @@ private[sql] class ShortColumnWriter(ordinal: Int, allocator: BaseAllocator) override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow, ordinal: Int): Unit + override def setValue(row: InternalRow): Unit = valueMutator.setSafe(count, row.getShort(ordinal)) } @@ -292,7 +300,7 @@ private[sql] class IntegerColumnWriter(ordinal: Int, allocator: BaseAllocator) override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow, ordinal: Int): Unit + override def setValue(row: InternalRow): Unit = valueMutator.setSafe(count, row.getInt(ordinal)) } @@ -303,7 +311,7 @@ private[sql] class LongColumnWriter(ordinal: Int, allocator: BaseAllocator) override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow, ordinal: Int): Unit + override def setValue(row: InternalRow): Unit = valueMutator.setSafe(count, row.getLong(ordinal)) } @@ -314,7 +322,7 @@ private[sql] class FloatColumnWriter(ordinal: Int, allocator: BaseAllocator) override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow, ordinal: Int): Unit + override def setValue(row: InternalRow): Unit = valueMutator.setSafe(count, row.getFloat(ordinal)) } @@ -325,7 +333,7 @@ private[sql] class DoubleColumnWriter(ordinal: Int, allocator: BaseAllocator) override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow, ordinal: Int): Unit + override def setValue(row: InternalRow): Unit = valueMutator.setSafe(count, row.getDouble(ordinal)) } @@ -336,7 +344,7 @@ private[sql] class ByteColumnWriter(ordinal: Int, allocator: BaseAllocator) override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow, ordinal: Int): Unit + override def setValue(row: InternalRow): Unit = valueMutator.setSafe(count, row.getByte(ordinal)) } @@ -347,7 +355,7 @@ private[sql] class UTF8StringColumnWriter(ordinal: Int, allocator: BaseAllocator override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow, ordinal: Int): Unit = { + override def setValue(row: InternalRow): Unit = { val bytes = row.getUTF8String(ordinal).getBytes valueMutator.setSafe(count, bytes, 0, bytes.length) } @@ -360,7 +368,7 @@ private[sql] class BinaryColumnWriter(ordinal: Int, allocator: BaseAllocator) override val valueMutator: NullableVarBinaryVector#Mutator 
= valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow, ordinal: Int): Unit = { + override def setValue(row: InternalRow): Unit = { val bytes = row.getBinary(ordinal) valueMutator.setSafe(count, bytes, 0, bytes.length) } @@ -373,7 +381,7 @@ private[sql] class DateColumnWriter(ordinal: Int, allocator: BaseAllocator) override val valueMutator: NullableDateVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow, ordinal: Int): Unit = { + override def setValue(row: InternalRow): Unit = { // TODO: comment on diff btw value representations of date/timestamp valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000) } @@ -386,7 +394,7 @@ private[sql] class TimeStampColumnWriter(ordinal: Int, allocator: BaseAllocator) override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow, ordinal: Int): Unit = { + override def setValue(row: InternalRow): Unit = { valueMutator.setSafe(count, row.getLong(ordinal) / 1000) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 18abd99ec578..97bb9eedf2fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2847,6 +2847,7 @@ class Dataset[T] private[sql]( val converter = new ArrowConverters val payload = converter.interalRowIterToPayload(iter, schema_captured) val payloadBytes = ArrowConverters.payloadToByteArray(payload, schema_captured) + converter.close() Iterator(payloadBytes) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala index 51e2455d8bcc..e64911d0964e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala @@ -392,6 +392,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { /** * Arrow JSON Format Data Generation * Referenced from https://github.com/apache/arrow/blob/master/integration/integration_test.py + * TODO: Look into using JSON generation from parquet-vector.jar */ private abstract class DataType(name: String, nullable: Boolean) { From 97742b8bc39aa735af098bba16d5355651b87025 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 13 Mar 2017 11:40:56 -0700 Subject: [PATCH 27/56] changed comment for readable seekable byte channel class --- .../src/main/scala/org/apache/spark/sql/ArrowConverters.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala index 07ca26d3f88f..74ffd1525f1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.Utils /** * ArrowReader requires a seekable byte channel. 
- * NOTE - this is taken from test org.apache.arrow.vector.file, see about moving to public util pkg + * TODO: This is available in arrow-vector now with ARROW-615, to be included in 0.2.1 release */ private[sql] class ByteArrayReadableSeekableByteChannel(var byteArray: Array[Byte]) extends SeekableByteChannel { From b821077b29daffffc5a931a89989fbd73fc90d68 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 27 Mar 2017 21:09:30 -0400 Subject: [PATCH 28/56] Remove Date and Timestamp from supported types closes #24 --- .../scala/org/apache/spark/sql/ArrowConverters.scala | 10 ++++++---- .../org/apache/spark/sql/ArrowConvertersSuite.scala | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala index 74ffd1525f1f..be8676f5f0b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala @@ -144,8 +144,9 @@ private[sql] object ArrowConverters { case ByteType => new ArrowType.Int(8, true) case StringType => ArrowType.Utf8.INSTANCE case BinaryType => ArrowType.Binary.INSTANCE - case DateType => ArrowType.Date.INSTANCE - case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND) + // TODO: Enable Date and Timestamp type with Arrow 0.3 + // case DateType => ArrowType.Date.INSTANCE + // case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND) case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") } } @@ -411,8 +412,9 @@ private[sql] object ColumnWriter { case ByteType => new ByteColumnWriter(ordinal, allocator) case StringType => new UTF8StringColumnWriter(ordinal, allocator) case BinaryType => new BinaryColumnWriter(ordinal, allocator) - case DateType => new DateColumnWriter(ordinal, allocator) - case TimestampType => new TimeStampColumnWriter(ordinal, allocator) + // TODO: Enable Date and Timestamp type with Arrow 0.3 + // case DateType => new DateColumnWriter(ordinal, allocator) + // case TimestampType => new TimeStampColumnWriter(ordinal, allocator) case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala index e64911d0964e..0a61ece23626 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala @@ -91,7 +91,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(byteData) } - test("timestamp conversion") { + ignore("timestamp conversion") { collectAndValidate(timestampData) } From 3d786a2e1b01a697678d1f4868b9e0354fe13333 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 3 Apr 2017 14:22:44 -0700 Subject: [PATCH 29/56] Added scaladocs to methods that did not have it --- .../org/apache/spark/sql/ArrowConverters.scala | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala index be8676f5f0b0..1d1b0266a45b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala @@ -103,11 +103,18 @@ private[sql] class ArrowConverters 
{ private[sql] def allocator: RootAllocator = _allocator + /** + * Iterate over the rows and convert to an ArrowPayload, using RootAllocator from this class + */ def interalRowIterToPayload(rowIter: Iterator[InternalRow], schema: StructType): ArrowPayload = { - val batch = ArrowConverters.internalRowIterToArrowBatch(rowIter, schema, allocator) + val batch = ArrowConverters.internalRowIterToArrowBatch(rowIter, schema, _allocator) new ArrowStaticPayload(batch) } + /** + * Read an Array of Arrow Record batches as byte Arrays into an ArrowPayload, using + * RootAllocator from this class + */ def readPayloadByteArrays(payloadByteArrays: Array[Array[Byte]]): ArrowPayload = { val batches = scala.collection.mutable.ArrayBuffer.empty[ArrowRecordBatch] var i = 0 @@ -123,6 +130,10 @@ private[sql] class ArrowConverters { new ArrowStaticPayload(batches: _*) } + /** + * Call when done using this converter, will close RootAllocator so any ArrowBuffers should be + * closed first + */ def close(): Unit = { _allocator.close() } @@ -152,7 +163,7 @@ private[sql] object ArrowConverters { } /** - * Iterate over InternalRows and write to an ArrowRecordBatch. + * Iterate over InternalRows and convert to an ArrowRecordBatch. */ private def internalRowIterToArrowBatch( rowIter: Iterator[InternalRow], From cb4c510fc5c44ffe8bff0b4d57e0ca4e2af48a8a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 3 Apr 2017 15:01:19 -0700 Subject: [PATCH 30/56] added check for pyarrow import error --- python/pyspark/sql/dataframe.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5f733f4d96db..11d3bdd0a3fc 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1620,10 +1620,13 @@ def toPandas(self, useArrow=False): 1 5 Bob """ if useArrow: - from pyarrow.table import concat_tables - tables = self._collectAsArrow() - table = concat_tables(tables) - return table.to_pandas() + try: + import pyarrow + tables = self._collectAsArrow() + table = pyarrow.table.concat_tables(tables) + return table.to_pandas() + except ImportError as e: + raise ImportError("%s\n%s" % (e.message, self.toPandas.__doc__)) else: import pandas as pd return pd.DataFrame.from_records(self.collect(), columns=self.columns) From 7260217327ccccb5f418a8f7da9d864c13ddcc55 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 13 Apr 2017 14:00:45 -0700 Subject: [PATCH 31/56] changed pyspark script to accept all args when testing --- bin/pyspark | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/pyspark b/bin/pyspark index 98387c2ec5b8..8eeea7716cc9 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$1" + exec "$PYSPARK_DRIVER_PYTHON" -m "$@" exit fi From a0483b8990a8113305a63c3f486df6e835a8ed32 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 13 Apr 2017 14:23:12 -0700 Subject: [PATCH 32/56] added pyarrow tests to be launched during run-pip-tests when using conda --- dev/run-pip-tests | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dev/run-pip-tests b/dev/run-pip-tests index d51dde12a03c..edfc29352d26 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -83,6 +83,8 @@ for python in "${PYTHON_EXECS[@]}"; do if [ -n "$USE_CONDA" ]; then conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools source activate "$VIRTUALENV_PATH" + 
conda install -y -c conda-forge pyarrow=0.2 + TEST_PYARROW=1 else mkdir -p "$VIRTUALENV_PATH" virtualenv --python=$python "$VIRTUALENV_PATH" @@ -120,6 +122,10 @@ for python in "${PYTHON_EXECS[@]}"; do python "$FWDIR"/dev/pip-sanity-check.py echo "Run the tests for context.py" python "$FWDIR"/python/pyspark/context.py + if [ -n "$TEST_PYARROW" ]; then + echo "Run tests for pyarrow" + SPARK_TESTING=1 "$FWDIR"/bin/pyspark pyspark.sql.tests ArrowTests + fi cd "$FWDIR" From c144667857cd27f94319ebf7faf3bcccd1d291c3 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 25 Apr 2017 12:10:56 -0700 Subject: [PATCH 33/56] pre-update for using Arrow 0.3, cleanup of converter functions, timestamp test not working --- pom.xml | 2 +- python/pyspark/serializers.py | 4 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/tests.py | 46 ++- .../apache/spark/sql/ArrowConverters.scala | 343 ++++++++---------- .../scala/org/apache/spark/sql/Dataset.scala | 13 +- .../spark/sql/ArrowConvertersSuite.scala | 80 ++-- 7 files changed, 232 insertions(+), 258 deletions(-) diff --git a/pom.xml b/pom.xml index 18c993a7fea5..3e283d330ad4 100644 --- a/pom.xml +++ b/pom.xml @@ -184,7 +184,7 @@ 2.6 1.8 1.0.0 - 0.2.0 + 0.2.1-SNAPSHOT ${java.home} diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 386fab77429f..07d043f88811 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -191,8 +191,8 @@ def dumps(self, obj): raise NotImplementedError def loads(self, obj): - from pyarrow import FileReader, BufferReader - reader = FileReader(BufferReader(obj)) + import pyarrow as pa + reader = pa.FileReader(pa.BufferReader(obj)) return reader.read_all() def __repr__(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2ced35dcd702..0fa7994c3c9e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1661,7 +1661,7 @@ def toPandas(self, useArrow=False): try: import pyarrow tables = self._collectAsArrow() - table = pyarrow.table.concat_tables(tables) + table = pyarrow.concat_tables(tables) return table.to_pandas() except ImportError as e: raise ImportError("%s\n%s" % (e.message, self.toPandas.__doc__)) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dced1a78b68b..178b93814aeb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2488,17 +2488,34 @@ class ArrowTests(ReusedPySparkTestCase): @classmethod def setUpClass(cls): + from datetime import datetime, tzinfo ReusedPySparkTestCase.setUpClass() cls.spark = SparkSession(cls.sc) cls.schema = StructType([ - StructField("str_t", StringType(), True), - StructField("int_t", IntegerType(), True), - StructField("long_t", LongType(), True), - StructField("float_t", FloatType(), True), - StructField("double_t", DoubleType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0), - ("b", 2, 20, 0.4, 4.0), - ("c", 3, 30, 0.8, 6.0)] + StructField("1_str_t", StringType(), True), + StructField("2_int_t", IntegerType(), True), + StructField("3_long_t", LongType(), True), + StructField("4_float_t", FloatType(), True), + StructField("5_double_t", DoubleType(), True), + StructField("6_date_t", DateType(), True), + StructField("7_timestamp_t", TimestampType(), True)]) + + def mkdt(*args): + class NaiveTZ(tzinfo): + """ + Force Spark to store internal value as offset to UTC, not local time + """ + def utcoffset(self, date_time): + return None + + def dst(self, date_time): + return None + + return datetime(*args, tzinfo=NaiveTZ()) + 
+ cls.data = [("a", 1, 10, 0.2, 2.0, datetime(2011, 1, 1), mkdt(2011, 1, 1, 1, 1, 1)), + ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), mkdt(2012, 2, 2, 2, 2, 2)), + ("c", 3, 30, 0.8, 6.0, datetime(2013, 3, 3), mkdt(2013, 3, 3, 3, 3, 3))] def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + @@ -2515,19 +2532,26 @@ def test_null_conversion(self): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - # NOTE - toPandas(useArrow=False) will infer standard data types - df_sel = df.select("str_t", "long_t", "double_t") + # NOTE - toPandas(useArrow=False) will infer standard python data types + df_sel = df.select("1_str_t", "3_long_t", "5_double_t") pdf = df_sel.toPandas(useArrow=False) pdf_arrow = df_sel.toPandas(useArrow=True) self.assertFramesEqual(pdf_arrow, pdf) def test_pandas_round_trip(self): import pandas as pd + import numpy as np data_dict = {} for j, name in enumerate(self.schema.names): data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + # need to convert these to numpy types first + data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) + data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) + # Pandas will store the datetime as 'object' if has tzinfo + data_dict["7_timestamp_t"] = [dt.replace(tzinfo=None) for dt in data_dict["7_timestamp_t"]] pdf = pd.DataFrame(data=data_dict) - pdf_arrow = self.spark.createDataFrame(pdf).toPandas(useArrow=True) + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf_arrow = df.toPandas(useArrow=True) self.assertFramesEqual(pdf_arrow, pdf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala index 1d1b0266a45b..5a6b14f37c4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala @@ -18,19 +18,19 @@ package org.apache.spark.sql import java.io.ByteArrayOutputStream -import java.nio.ByteBuffer -import java.nio.channels.{Channels, SeekableByteChannel} +import java.nio.channels.Channels import scala.collection.JavaConverters._ import io.netty.buffer.ArrowBuf -import org.apache.arrow.memory.{BaseAllocator, RootAllocator} +import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector._ import org.apache.arrow.vector.BaseValueVector.BaseMutator -import org.apache.arrow.vector.file.{ArrowReader, ArrowWriter} +import org.apache.arrow.vector.file._ import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit} -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ @@ -38,104 +38,16 @@ import org.apache.spark.util.Utils /** - * ArrowReader requires a seekable byte channel. 
- * TODO: This is available in arrow-vector now with ARROW-615, to be included in 0.2.1 release + * Store Arrow data in a form that can be serde by Spark */ -private[sql] class ByteArrayReadableSeekableByteChannel(var byteArray: Array[Byte]) - extends SeekableByteChannel { - var _position: Long = 0L +private[sql] class ArrowPayload(val batchBytes: Array[Byte]) extends Serializable { - override def isOpen: Boolean = { - byteArray != null + def this(batch: ArrowRecordBatch, schema: StructType, allocator: BufferAllocator) = { + this(ArrowConverters.batchToByteArray(batch, schema, allocator)) } - override def close(): Unit = { - byteArray = null - } - - override def read(dst: ByteBuffer): Int = { - val remainingBuf = byteArray.length - _position - val length = Math.min(dst.remaining(), remainingBuf).toInt - dst.put(byteArray, _position.toInt, length) - _position += length - length - } - - override def position(): Long = _position - - override def position(newPosition: Long): SeekableByteChannel = { - _position = newPosition.toLong - this - } - - override def size: Long = { - byteArray.length.toLong - } - - override def write(src: ByteBuffer): Int = { - throw new UnsupportedOperationException("Read Only") - } - - override def truncate(size: Long): SeekableByteChannel = { - throw new UnsupportedOperationException("Read Only") - } -} - -/** - * Intermediate data structure returned from Arrow conversions - */ -private[sql] abstract class ArrowPayload extends Iterator[ArrowRecordBatch] - -/** - * Build a payload from existing ArrowRecordBatches - */ -private[sql] class ArrowStaticPayload(batches: ArrowRecordBatch*) extends ArrowPayload { - private val iter = batches.iterator - override def next(): ArrowRecordBatch = iter.next() - override def hasNext: Boolean = iter.hasNext -} - -/** - * Class that wraps an Arrow RootAllocator used in conversion - */ -private[sql] class ArrowConverters { - private val _allocator = new RootAllocator(Long.MaxValue) - - private[sql] def allocator: RootAllocator = _allocator - - /** - * Iterate over the rows and convert to an ArrowPayload, using RootAllocator from this class - */ - def interalRowIterToPayload(rowIter: Iterator[InternalRow], schema: StructType): ArrowPayload = { - val batch = ArrowConverters.internalRowIterToArrowBatch(rowIter, schema, _allocator) - new ArrowStaticPayload(batch) - } - - /** - * Read an Array of Arrow Record batches as byte Arrays into an ArrowPayload, using - * RootAllocator from this class - */ - def readPayloadByteArrays(payloadByteArrays: Array[Array[Byte]]): ArrowPayload = { - val batches = scala.collection.mutable.ArrayBuffer.empty[ArrowRecordBatch] - var i = 0 - while (i < payloadByteArrays.length) { - val payloadBytes = payloadByteArrays(i) - val in = new ByteArrayReadableSeekableByteChannel(payloadBytes) - val reader = new ArrowReader(in, _allocator) - val footer = reader.readFooter() - val batchBlocks = footer.getRecordBatches.asScala.toArray - batchBlocks.foreach(block => batches += reader.readRecordBatch(block)) - i += 1 - } - new ArrowStaticPayload(batches: _*) - } - - /** - * Call when done using this converter, will close RootAllocator so any ArrowBuffers should be - * closed first - */ - def close(): Unit = { - _allocator.close() + def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { + ArrowConverters.byteArrayToBatch(batchBytes, allocator) } } @@ -155,24 +67,64 @@ private[sql] object ArrowConverters { case ByteType => new ArrowType.Int(8, true) case StringType => ArrowType.Utf8.INSTANCE case BinaryType => 
ArrowType.Binary.INSTANCE - // TODO: Enable Date and Timestamp type with Arrow 0.3 - // case DateType => ArrowType.Date.INSTANCE - // case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") } } /** - * Iterate over InternalRows and convert to an ArrowRecordBatch. + * Convert a Spark Dataset schema to Arrow schema. + */ + private[sql] def schemaToArrowSchema(schema: StructType): Schema = { + val arrowFields = schema.fields.map { f => + new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) + } + new Schema(arrowFields.toList.asJava) + } + + /** + * Maps Iterator from InternalRow to ArrowPayload + */ + private[sql] def toPayloadIterator( + rowIter: Iterator[InternalRow], + schema: StructType): Iterator[ArrowPayload] = { + new Iterator[ArrowPayload] { + private val _allocator = new RootAllocator(Long.MaxValue) + private var _nextPayload = if (rowIter.nonEmpty) convert() else null + + override def hasNext: Boolean = _nextPayload != null + + override def next(): ArrowPayload = { + val obj = _nextPayload + if (hasNext) { + if (rowIter.hasNext) { + _nextPayload = convert() + } else { + _allocator.close() + _nextPayload = null + } + } + obj + } + + private def convert(): ArrowPayload = { + val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator) + new ArrowPayload(batch, schema, _allocator) + } + } + } + + /** + * Iterate over InternalRows and write to an ArrowRecordBatch. */ private def internalRowIterToArrowBatch( rowIter: Iterator[InternalRow], schema: StructType, - allocator: RootAllocator): ArrowRecordBatch = { + allocator: BufferAllocator): ArrowRecordBatch = { val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => - ColumnWriter(ordinal, allocator, field.dataType) - .init() + ColumnWriter(ordinal, allocator, field.dataType).init() } val writerLength = columnWriters.length @@ -197,40 +149,50 @@ private[sql] object ArrowConverters { } /** - * Convert a Spark Dataset schema to Arrow schema. 
+ * Convert an ArrowRecordBatch to a byte array and close batch */ - private[sql] def schemaToArrowSchema(schema: StructType): Schema = { - val arrowFields = schema.fields.map { f => - new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) - } - new Schema(arrowFields.toList.asJava) - } - - /** - * Write an ArrowPayload to a byte array - */ - private[sql] def payloadToByteArray(payload: ArrowPayload, schema: StructType): Array[Byte] = { + private[sql] def batchToByteArray( + batch: ArrowRecordBatch, + schema: StructType, + allocator: BufferAllocator): Array[Byte] = { val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) + val root = VectorSchemaRoot.create(arrowSchema, allocator) val out = new ByteArrayOutputStream() - val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema) + val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) // Iterate over payload batches to write each one, ensure all batches get closed - var batch: ArrowRecordBatch = null - Utils.tryWithSafeFinallyAndFailureCallbacks { - while (payload.hasNext) { - batch = payload.next() - writer.writeRecordBatch(batch) - batch.close() - } - }(catchBlock = { - Option(batch).foreach(_.close()) - payload.foreach(_.close()) - }, finallyBlock = writer.close()) - + Utils.tryWithSafeFinally { + val loader = new VectorLoader(root) + loader.load(batch) + writer.writeBatch() + } { + batch.close() + root.close() + writer.close() + } out.toByteArray } + + /** + * Convert a byte array to an ArrowRecordBatch + */ + private[sql] def byteArrayToBatch( + batchBytes: Array[Byte], + allocator: BufferAllocator): ArrowRecordBatch = { + val in = new ByteArrayReadableSeekableByteChannel(batchBytes) + val reader = new ArrowFileReader(in, allocator) + val root = reader.getVectorSchemaRoot + val unloader = new VectorUnloader(root) + reader.loadNextBatch() + val batch = unloader.getRecordBatch + reader.close() + batch + } } +/** + * Interface for writing InternalRows to Arrow Buffers + */ private[sql] trait ColumnWriter { def init(): this.type def write(row: InternalRow): Unit @@ -245,10 +207,11 @@ private[sql] trait ColumnWriter { /** * Base class for flat arrow column writer, i.e., column without children. 
*/ -private[sql] abstract class PrimitiveColumnWriter( - val ordinal: Int, - val allocator: BaseAllocator) - extends ColumnWriter { +private[sql] abstract class PrimitiveColumnWriter(val ordinal: Int) + extends ColumnWriter { + + def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) + def valueVector: BaseDataValueVector def valueMutator: BaseMutator @@ -281,23 +244,21 @@ private[sql] abstract class PrimitiveColumnWriter( } } -private[sql] class BooleanColumnWriter(ordinal: Int, allocator: BaseAllocator) - extends PrimitiveColumnWriter(ordinal, allocator) { - private def bool2int(b: Boolean): Int = if (b) 1 else 0 - +private[sql] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableBitVector - = new NullableBitVector("BooleanValue", allocator) + = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal))) + = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) } -private[sql] class ShortColumnWriter(ordinal: Int, allocator: BaseAllocator) - extends PrimitiveColumnWriter(ordinal, allocator) { +private[sql] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableSmallIntVector - = new NullableSmallIntVector("ShortValue", allocator) + = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) @@ -305,10 +266,10 @@ private[sql] class ShortColumnWriter(ordinal: Int, allocator: BaseAllocator) = valueMutator.setSafe(count, row.getShort(ordinal)) } -private[sql] class IntegerColumnWriter(ordinal: Int, allocator: BaseAllocator) - extends PrimitiveColumnWriter(ordinal, allocator) { +private[sql] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableIntVector - = new NullableIntVector("IntValue", allocator) + = new NullableIntVector("IntValue", getFieldType(dtype), allocator) override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) @@ -316,10 +277,10 @@ private[sql] class IntegerColumnWriter(ordinal: Int, allocator: BaseAllocator) = valueMutator.setSafe(count, row.getInt(ordinal)) } -private[sql] class LongColumnWriter(ordinal: Int, allocator: BaseAllocator) - extends PrimitiveColumnWriter(ordinal, allocator) { +private[sql] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableBigIntVector - = new NullableBigIntVector("LongValue", allocator) + = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) @@ -327,10 +288,10 @@ private[sql] class LongColumnWriter(ordinal: Int, allocator: BaseAllocator) = valueMutator.setSafe(count, row.getLong(ordinal)) } -private[sql] class 
FloatColumnWriter(ordinal: Int, allocator: BaseAllocator) - extends PrimitiveColumnWriter(ordinal, allocator) { +private[sql] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableFloat4Vector - = new NullableFloat4Vector("FloatValue", allocator) + = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) @@ -338,10 +299,10 @@ private[sql] class FloatColumnWriter(ordinal: Int, allocator: BaseAllocator) = valueMutator.setSafe(count, row.getFloat(ordinal)) } -private[sql] class DoubleColumnWriter(ordinal: Int, allocator: BaseAllocator) - extends PrimitiveColumnWriter(ordinal, allocator) { +private[sql] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableFloat8Vector - = new NullableFloat8Vector("DoubleValue", allocator) + = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) @@ -349,10 +310,10 @@ private[sql] class DoubleColumnWriter(ordinal: Int, allocator: BaseAllocator) = valueMutator.setSafe(count, row.getDouble(ordinal)) } -private[sql] class ByteColumnWriter(ordinal: Int, allocator: BaseAllocator) - extends PrimitiveColumnWriter(ordinal, allocator) { +private[sql] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableUInt1Vector - = new NullableUInt1Vector("ByteValue", allocator) + = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) @@ -360,10 +321,13 @@ private[sql] class ByteColumnWriter(ordinal: Int, allocator: BaseAllocator) = valueMutator.setSafe(count, row.getByte(ordinal)) } -private[sql] class UTF8StringColumnWriter(ordinal: Int, allocator: BaseAllocator) - extends PrimitiveColumnWriter(ordinal, allocator) { +private[sql] class UTF8StringColumnWriter( + dtype: ArrowType, + ordinal: Int, + allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableVarBinaryVector - = new NullableVarBinaryVector("UTF8StringValue", allocator) + = new NullableVarBinaryVector("UTF8StringValue", getFieldType(dtype), allocator) override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) @@ -373,10 +337,10 @@ private[sql] class UTF8StringColumnWriter(ordinal: Int, allocator: BaseAllocator } } -private[sql] class BinaryColumnWriter(ordinal: Int, allocator: BaseAllocator) - extends PrimitiveColumnWriter(ordinal, allocator) { +private[sql] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableVarBinaryVector - = new NullableVarBinaryVector("BinaryValue", allocator) + = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) @@ 
-386,46 +350,45 @@ private[sql] class BinaryColumnWriter(ordinal: Int, allocator: BaseAllocator) } } -private[sql] class DateColumnWriter(ordinal: Int, allocator: BaseAllocator) - extends PrimitiveColumnWriter(ordinal, allocator) { - override val valueVector: NullableDateVector - = new NullableDateVector("DateValue", allocator) - override val valueMutator: NullableDateVector#Mutator = valueVector.getMutator +private[sql] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableDateDayVector + = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) + override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow): Unit = { - // TODO: comment on diff btw value representations of date/timestamp - valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000) + valueMutator.setSafe(count, row.getInt(ordinal)) } } -private[sql] class TimeStampColumnWriter(ordinal: Int, allocator: BaseAllocator) - extends PrimitiveColumnWriter(ordinal, allocator) { +private[sql] class TimeStampColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableTimeStampMicroVector - = new NullableTimeStampMicroVector("TimeStampValue", allocator) + = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getLong(ordinal) / 1000) + valueMutator.setSafe(count, row.getLong(ordinal)) } } private[sql] object ColumnWriter { - def apply(ordinal: Int, allocator: BaseAllocator, dataType: DataType): ColumnWriter = { + def apply(ordinal: Int, allocator: BufferAllocator, dataType: DataType): ColumnWriter = { + val dtype = ArrowConverters.sparkTypeToArrowType(dataType) dataType match { - case BooleanType => new BooleanColumnWriter(ordinal, allocator) - case ShortType => new ShortColumnWriter(ordinal, allocator) - case IntegerType => new IntegerColumnWriter(ordinal, allocator) - case LongType => new LongColumnWriter(ordinal, allocator) - case FloatType => new FloatColumnWriter(ordinal, allocator) - case DoubleType => new DoubleColumnWriter(ordinal, allocator) - case ByteType => new ByteColumnWriter(ordinal, allocator) - case StringType => new UTF8StringColumnWriter(ordinal, allocator) - case BinaryType => new BinaryColumnWriter(ordinal, allocator) - // TODO: Enable Date and Timestamp type with Arrow 0.3 - // case DateType => new DateColumnWriter(ordinal, allocator) - // case TimestampType => new TimeStampColumnWriter(ordinal, allocator) + case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) + case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) + case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator) + case LongType => new LongColumnWriter(dtype, ordinal, allocator) + case FloatType => new FloatColumnWriter(dtype, ordinal, allocator) + case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator) + case ByteType => new ByteColumnWriter(dtype, ordinal, allocator) + case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator) + case BinaryType => new BinaryColumnWriter(dtype, 
ordinal, allocator) + case DateType => new DateColumnWriter(dtype, ordinal, allocator) + case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator) case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a358166dc25c..31d7e751f988 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2776,10 +2776,9 @@ class Dataset[T] private[sql]( * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ private[sql] def collectAsArrowToPython(): Int = { - val payloadRdd = toArrowPayloadBytes() - val payloadByteArrays = payloadRdd.collect() withNewExecutionId { - PythonRDD.serveIterator(payloadByteArrays.iterator, "serve-Arrow") + val iter = toArrowPayload.collect().iterator.map(_.batchBytes) + PythonRDD.serveIterator(iter, "serve-Arrow") } } @@ -2866,14 +2865,10 @@ class Dataset[T] private[sql]( } /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayloadBytes(): RDD[Array[Byte]] = { + private[sql] def toArrowPayload: RDD[ArrowPayload] = { val schema_captured = this.schema queryExecution.toRdd.mapPartitionsInternal { iter => - val converter = new ArrowConverters - val payload = converter.interalRowIterToPayload(iter, schema_captured) - val payloadBytes = ArrowConverters.payloadToByteArray(payload, schema_captured) - converter.close() - Iterator(payloadBytes) + ArrowConverters.toPayloadIterator(iter, schema_captured) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala index 0a61ece23626..3d37960bba1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala @@ -23,6 +23,7 @@ import java.text.SimpleDateFormat import java.util.Locale import com.google.common.io.Files +import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator @@ -42,13 +43,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { private var tempDataPath: String = _ - private def collectAsArrow(df: DataFrame, - converter: Option[ArrowConverters] = None): ArrowPayload = { - val cnvtr = converter.getOrElse(new ArrowConverters) - val payloadByteArrays = df.toArrowPayloadBytes().collect() - cnvtr.readPayloadByteArrays(payloadByteArrays) - } - override def beforeAll(): Unit = { super.beforeAll() tempDataPath = Utils.createTempDir(namePrefix = "arrow").getAbsolutePath @@ -56,14 +50,16 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { test("collect to arrow record batch") { val indexData = (1 to 6).toDF("i") - val arrowPayload = collectAsArrow(indexData) - assert(arrowPayload.nonEmpty) - val arrowBatches = arrowPayload.toArray - assert(arrowBatches.length == indexData.rdd.getNumPartitions) - val rowCount = arrowBatches.map(batch => batch.getLength).sum + val arrowPayloads = indexData.toArrowPayload.collect() + assert(arrowPayloads.nonEmpty) + assert(arrowPayloads.length == indexData.rdd.getNumPartitions) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = 
arrowPayloads.map(_.loadBatch(allocator)) + val rowCount = arrowRecordBatches.map(_.getLength).sum assert(rowCount === indexData.count()) - arrowBatches.foreach(batch => assert(batch.getNodes.size() > 0)) - arrowBatches.foreach(batch => batch.close()) + arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) + arrowRecordBatches.foreach(_.close()) + allocator.close() } test("numeric type conversion") { @@ -91,7 +87,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(byteData) } - ignore("timestamp conversion") { + test("timestamp conversion") { collectAndValidate(timestampData) } @@ -110,14 +106,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("partitioned DataFrame") { - val converter = new ArrowConverters - val schema = testData2.schema - val arrowPayload = collectAsArrow(testData2, Some(converter)) - val arrowBatches = arrowPayload.toArray + val arrowPayloads = testData2.toArrowPayload.collect() // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload - assert(arrowBatches.length === 2) - val pl1 = new ArrowStaticPayload(arrowBatches(0)) - val pl2 = new ArrowStaticPayload(arrowBatches(1)) + assert(arrowPayloads.length === 2) // Generate JSON files val a = List[Int](1, 1, 2, 2, 3, 3) val b = List[Int](1, 2, 1, 2, 1, 2) @@ -134,22 +125,25 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") json1.write(tempFile1) json2.write(tempFile2) - validateConversion(schema, pl1, tempFile1, Some(converter)) - validateConversion(schema, pl2, tempFile2, Some(converter)) + val schema = testData2.schema + validateConversion(schema, arrowPayloads(0), tempFile1) + validateConversion(schema, arrowPayloads(1), tempFile2) } test("empty frame collect") { - val arrowPayload = collectAsArrow(spark.emptyDataFrame) + val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() assert(arrowPayload.isEmpty) } test("empty partition collect") { val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") - val arrowPayload = collectAsArrow(emptyPart) - val arrowBatches = arrowPayload.toArray - assert(arrowBatches.length === 2) - assert(arrowBatches.count(_.getLength == 0) === 1) - assert(arrowBatches.count(_.getLength == 1) === 1) + val arrowPayloads = emptyPart.toArrowPayload.collect() + assert(arrowPayloads.length === 1) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + assert(arrowRecordBatches.head.getLength == 1) + arrowRecordBatches.foreach(_.close()) + allocator.close() } test("unsupported types") { @@ -161,10 +155,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { collectAsArrow(decimalData) } - runUnsupported { collectAsArrow(arrayData.toDF()) } - runUnsupported { collectAsArrow(mapData.toDF()) } - runUnsupported { collectAsArrow(complexData) } + runUnsupported { decimalData.toArrowPayload.collect() } + runUnsupported { arrayData.toDF().toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowPayload.collect() } + runUnsupported { complexData.toArrowPayload.collect() } } test("test Arrow Validator") { @@ -184,28 +178,26 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { /** Test that a converted DataFrame to Arrow record batch equals batch read from 
JSON file */ private def collectAndValidate(data: DataTuple): Unit = { - val converter = new ArrowConverters // NOTE: coalesce to single partition because can only load 1 batch in validator - val arrowPayload = collectAsArrow(data.df.coalesce(1), Some(converter)) + val arrowPayload = data.df.coalesce(1).toArrowPayload.collect().head val tempFile = new File(tempDataPath, data.file) data.json.write(tempFile) - validateConversion(data.df.schema, arrowPayload, tempFile, Some(converter)) + validateConversion(data.df.schema, arrowPayload, tempFile) } private def validateConversion(sparkSchema: StructType, arrowPayload: ArrowPayload, - jsonFile: File, - converterOpt: Option[ArrowConverters] = None): Unit = { - val converter = converterOpt.getOrElse(new ArrowConverters) - val jsonReader = new JsonFileReader(jsonFile, converter.allocator) + jsonFile: File): Unit = { + val allocator = new RootAllocator(Long.MaxValue) + val jsonReader = new JsonFileReader(jsonFile, allocator) val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) - val arrowRoot = new VectorSchemaRoot(arrowSchema, converter.allocator) + val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) val vectorLoader = new VectorLoader(arrowRoot) - arrowPayload.foreach(vectorLoader.load) + vectorLoader.load(arrowPayload.loadBatch(allocator)) val jsonRoot = jsonReader.read() Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) } @@ -504,7 +496,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { override def _get_type: JObject = { JObject( "name" -> "timestamp", - "unit" -> "MILLISECOND") + "unit" -> "MICROSECOND") } } From 250b5810d4c9152066e84ed34653227956f5297b Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 26 Apr 2017 10:07:36 -0700 Subject: [PATCH 34/56] added DateType to tests --- python/pyspark/sql/tests.py | 2 +- .../spark/sql/ArrowConvertersSuite.scala | 36 ++++++++++++------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 178b93814aeb..55003c85406e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2503,7 +2503,7 @@ def setUpClass(cls): def mkdt(*args): class NaiveTZ(tzinfo): """ - Force Spark to store internal value as offset to UTC, not local time + This will have Spark store internal value as UTC, not local time """ def utcoffset(self, date_time): return None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala index 3d37960bba1b..7457b287147b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala @@ -27,6 +27,7 @@ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.json4s.jackson.JsonMethods._ import org.json4s.JsonAST._ import org.json4s.JsonDSL._ @@ -92,8 +93,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } // TODO: Not currently supported in Arrow JSON reader - ignore("date conversion") { - // collectAndValidate(dateTimeData) + test("date conversion") { + collectAndValidate(dateData) } // TODO: Not currently 
supported in Arrow JSON reader @@ -362,23 +363,25 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) val data = Seq(ts1, ts2) - val schema = new JSONSchema(Seq(new TimestampType("c_timestamp"))) + val schema = new JSONSchema(Seq(new TimestampType("timestamp"))) + val us_data = data.map(_.getTime * 1000) // convert to microseconds val columns = Seq( - new PrimitiveColumn("c_timestamp", data.length, data.map(_ => true), data.map(_.getTime))) + new PrimitiveColumn("timestamp", data.length, data.map(_ => true), us_data)) val batch = new JSONRecordBatch(data.length, columns) - DataTuple(data.toDF("c_timestamp"), new JSONFile(schema, Seq(batch)), "timestampData.json") + DataTuple(data.toDF("timestamp"), new JSONFile(schema, Seq(batch)), "timestampData.json") } private def dateData: DataTuple = { val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - val df = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2)) - .toDF("a_date", "b_string", "c_timestamp") - val jsonFile = new JSONFile(new JSONSchema(Seq.empty[DataType]), Seq.empty[JSONRecordBatch]) - DataTuple(df, jsonFile, "dateData.json") + val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) + val data = Seq(d1, d2) + val day_data = data.map(d => DateTimeUtils.millisToDays(d.getTime)) + val schema = new JSONSchema(Seq(new DateType("date"))) + val columns = Seq( + new PrimitiveColumn("date", data.length, data.map(_ => true), day_data)) + val batch = new JSONRecordBatch(data.length, columns) + DataTuple(data.toDF("date"), new JSONFile(schema, Seq(batch)), "dateData.json") } /** @@ -491,6 +494,15 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { override def _get_type: JObject = JObject("name" -> JString("utf8")) } + private class DateType(name: String) extends PrimitiveType(name, nullable = true) { + override val bit_width = 32 + override def _get_type: JObject = { + JObject( + "name" -> "date", + "unit" -> "DAY") + } + } + private class TimestampType(name: String) extends PrimitiveType(name, nullable = true) { override val bit_width = 64 override def _get_type: JObject = { From f667a7add608a57fbe876018c4154af5b70bd7d7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 26 Apr 2017 11:01:49 -0700 Subject: [PATCH 35/56] removed support for DateType and TimestampType for now --- python/pyspark/sql/tests.py | 12 +++------ .../apache/spark/sql/ArrowConverters.scala | 4 +-- .../spark/sql/ArrowConvertersSuite.scala | 25 ++++++++++--------- 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 55003c85406e..a77c0d877a93 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2496,9 +2496,7 @@ def setUpClass(cls): StructField("2_int_t", IntegerType(), True), StructField("3_long_t", LongType(), True), StructField("4_float_t", FloatType(), True), - StructField("5_double_t", DoubleType(), True), - StructField("6_date_t", DateType(), True), - StructField("7_timestamp_t", TimestampType(), True)]) + StructField("5_double_t", 
DoubleType(), True)]) def mkdt(*args): class NaiveTZ(tzinfo): @@ -2513,9 +2511,9 @@ def dst(self, date_time): return datetime(*args, tzinfo=NaiveTZ()) - cls.data = [("a", 1, 10, 0.2, 2.0, datetime(2011, 1, 1), mkdt(2011, 1, 1, 1, 1, 1)), - ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), mkdt(2012, 2, 2, 2, 2, 2)), - ("c", 3, 30, 0.8, 6.0, datetime(2013, 3, 3), mkdt(2013, 3, 3, 3, 3, 3))] + cls.data = [("a", 1, 10, 0.2, 2.0), + ("b", 2, 20, 0.4, 4.0), + ("c", 3, 30, 0.8, 6.0)] def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + @@ -2547,8 +2545,6 @@ def test_pandas_round_trip(self): # need to convert these to numpy types first data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) - # Pandas will store the datetime as 'object' if has tzinfo - data_dict["7_timestamp_t"] = [dt.replace(tzinfo=None) for dt in data_dict["7_timestamp_t"]] pdf = pd.DataFrame(data=data_dict) df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas(useArrow=True) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala index 5a6b14f37c4c..9384a72f313e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala @@ -67,8 +67,6 @@ private[sql] object ArrowConverters { case ByteType => new ArrowType.Int(8, true) case StringType => ArrowType.Utf8.INSTANCE case BinaryType => ArrowType.Binary.INSTANCE - case DateType => new ArrowType.Date(DateUnit.DAY) - case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") } } @@ -160,7 +158,7 @@ private[sql] object ArrowConverters { val out = new ByteArrayOutputStream() val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) - // Iterate over payload batches to write each one, ensure all batches get closed + // Write batch to a byte array, ensure the batch, allocator and writer are closed Utils.tryWithSafeFinally { val loader = new VectorLoader(root) loader.load(batch) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala index 7457b287147b..c4a58959388d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala @@ -76,30 +76,29 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(mixedNumericData) } - test("boolean type conversion") { - collectAndValidate(boolData) - } - test("string type conversion") { collectAndValidate(stringData) } + test("boolean type conversion") { + collectAndValidate(boolData) + } + test("byte type conversion") { collectAndValidate(byteData) } - test("timestamp conversion") { - collectAndValidate(timestampData) + // TODO: Not currently supported in Arrow JSON reader + ignore("binary type conversion") { + // collectAndValidate(binaryData) } - // TODO: Not currently supported in Arrow JSON reader - test("date conversion") { - collectAndValidate(dateData) + ignore("timestamp conversion") { + collectAndValidate(timestampData) } - // TODO: Not currently supported in Arrow JSON reader - ignore("binary type conversion") { - // collectAndValidate(binaryData) + ignore("date conversion") 
{ + collectAndValidate(dateData) } test("floating-point NaN") { @@ -160,6 +159,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { runUnsupported { arrayData.toDF().toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } + runUnsupported { dateData.df.toArrowPayload.collect() } + runUnsupported { timestampData.df.toArrowPayload.collect() } } test("test Arrow Validator") { From 76f7ddb4b21bce901a4511f98e0a46541fe8f770 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 26 Apr 2017 14:06:24 -0700 Subject: [PATCH 36/56] moved ArrowConverters to o.a.s.sql.execution.arrow --- .../scala/org/apache/spark/sql/Dataset.scala | 1 + .../arrow}/ArrowConverters.scala | 42 +++++++++---------- .../arrow}/ArrowConvertersSuite.scala | 3 +- 3 files changed, 24 insertions(+), 22 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/arrow}/ArrowConverters.scala (90%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/arrow}/ArrowConvertersSuite.scala (99%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 31d7e751f988..b7a6b29fcc29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -48,6 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala similarity index 90% rename from sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 9384a72f313e..870ac8f597ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.execution.arrow import java.io.ByteArrayOutputStream import java.nio.channels.Channels @@ -38,7 +38,7 @@ import org.apache.spark.util.Utils /** - * Store Arrow data in a form that can be serde by Spark + * Store Arrow data in a form that can be serialized by Spark */ private[sql] class ArrowPayload(val batchBytes: Array[Byte]) extends Serializable { @@ -54,9 +54,9 @@ private[sql] class ArrowPayload(val batchBytes: Array[Byte]) extends Serializabl private[sql] object ArrowConverters { /** - * Map a Spark Dataset type to ArrowType. + * Map a Spark DataType to ArrowType. 
*/ - private[sql] def sparkTypeToArrowType(dataType: DataType): ArrowType = { + private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = { dataType match { case BooleanType => ArrowType.Bool.INSTANCE case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) @@ -74,7 +74,7 @@ private[sql] object ArrowConverters { /** * Convert a Spark Dataset schema to Arrow schema. */ - private[sql] def schemaToArrowSchema(schema: StructType): Schema = { + private[arrow] def schemaToArrowSchema(schema: StructType): Schema = { val arrowFields = schema.fields.map { f => new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) } @@ -149,7 +149,7 @@ private[sql] object ArrowConverters { /** * Convert an ArrowRecordBatch to a byte array and close batch */ - private[sql] def batchToByteArray( + private[arrow] def batchToByteArray( batch: ArrowRecordBatch, schema: StructType, allocator: BufferAllocator): Array[Byte] = { @@ -174,7 +174,7 @@ private[sql] object ArrowConverters { /** * Convert a byte array to an ArrowRecordBatch */ - private[sql] def byteArrayToBatch( + private[arrow] def byteArrayToBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { val in = new ByteArrayReadableSeekableByteChannel(batchBytes) @@ -191,7 +191,7 @@ private[sql] object ArrowConverters { /** * Interface for writing InternalRows to Arrow Buffers */ -private[sql] trait ColumnWriter { +private[arrow] trait ColumnWriter { def init(): this.type def write(row: InternalRow): Unit @@ -205,7 +205,7 @@ private[sql] trait ColumnWriter { /** * Base class for flat arrow column writer, i.e., column without children. */ -private[sql] abstract class PrimitiveColumnWriter(val ordinal: Int) +private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int) extends ColumnWriter { def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) @@ -242,7 +242,7 @@ private[sql] abstract class PrimitiveColumnWriter(val ordinal: Int) } } -private[sql] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) +private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableBitVector = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) @@ -253,7 +253,7 @@ private[sql] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) } -private[sql] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) +private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableSmallIntVector = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) @@ -264,7 +264,7 @@ private[sql] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: = valueMutator.setSafe(count, row.getShort(ordinal)) } -private[sql] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) +private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableIntVector = new NullableIntVector("IntValue", getFieldType(dtype), allocator) @@ -275,7 +275,7 @@ private[sql] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator = valueMutator.setSafe(count, 
row.getInt(ordinal)) } -private[sql] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) +private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableBigIntVector = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) @@ -286,7 +286,7 @@ private[sql] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: B = valueMutator.setSafe(count, row.getLong(ordinal)) } -private[sql] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) +private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableFloat4Vector = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) @@ -297,7 +297,7 @@ private[sql] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: = valueMutator.setSafe(count, row.getFloat(ordinal)) } -private[sql] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) +private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableFloat8Vector = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) @@ -308,7 +308,7 @@ private[sql] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: = valueMutator.setSafe(count, row.getDouble(ordinal)) } -private[sql] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) +private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableUInt1Vector = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) @@ -319,7 +319,7 @@ private[sql] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: B = valueMutator.setSafe(count, row.getByte(ordinal)) } -private[sql] class UTF8StringColumnWriter( +private[arrow] class UTF8StringColumnWriter( dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) @@ -335,7 +335,7 @@ private[sql] class UTF8StringColumnWriter( } } -private[sql] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) +private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableVarBinaryVector = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) @@ -348,7 +348,7 @@ private[sql] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: } } -private[sql] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) +private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableDateDayVector = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) @@ -360,7 +360,7 @@ private[sql] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: B } } -private[sql] class TimeStampColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) +private[arrow] class TimeStampColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableTimeStampMicroVector = new 
NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) @@ -372,7 +372,7 @@ private[sql] class TimeStampColumnWriter(dtype: ArrowType, ordinal: Int, allocat } } -private[sql] object ColumnWriter { +private[arrow] object ColumnWriter { def apply(ordinal: Int, allocator: BufferAllocator, dataType: DataType): ColumnWriter = { val dtype = ArrowConverters.sparkTypeToArrowType(dataType) dataType match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index c4a58959388d..179d35366c0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.execution.arrow import java.io.File import java.nio.charset.StandardCharsets @@ -34,6 +34,7 @@ import org.json4s.JsonDSL._ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils From 89dd0f46a7c5e8c11368912647338a2f96518a09 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 26 Apr 2017 17:13:31 -0700 Subject: [PATCH 37/56] changed useArrow flag to SQLConf spark.sql.execution.arrow.enable --- python/pyspark/sql/dataframe.py | 15 +++----- python/pyspark/sql/tests.py | 34 ++++++++----------- .../sql/execution/arrow/ArrowConverters.scala | 5 ++- .../arrow/ArrowConvertersSuite.scala | 4 +-- 4 files changed, 25 insertions(+), 33 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0fa7994c3c9e..57e538d6a669 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1636,35 +1636,30 @@ def toDF(self, *cols): return DataFrame(jdf, self.sql_ctx) @since(1.3) - def toPandas(self, useArrow=False): + def toPandas(self): """ Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. This is only available if Pandas is installed and available. - :param useArrow: Make use of Apache Arrow for conversion, pyarrow must be installed - and available on the calling Python process (Experimental). - .. note:: This method should only be used if the resulting Pandas's DataFrame is expected to be small, as all the data is loaded into the driver's memory. - .. 
note:: Using pyarrow is experimental and currently supports the following data types: - StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, - LongType, ShortType - >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice 1 5 Bob """ - if useArrow: + if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": try: import pyarrow tables = self._collectAsArrow() table = pyarrow.concat_tables(tables) return table.to_pandas() except ImportError as e: - raise ImportError("%s\n%s" % (e.message, self.toPandas.__doc__)) + msg = "note: pyarrow must be installed and available on calling Python process " \ + "if using spark.sql.execution.arrow.enable=true" + raise ImportError("%s\n%s" % (e.message, msg)) else: import pandas as pd return pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a77c0d877a93..cc6dd49f9b13 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -50,7 +50,7 @@ from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type -from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException @@ -2488,29 +2488,15 @@ class ArrowTests(ReusedPySparkTestCase): @classmethod def setUpClass(cls): - from datetime import datetime, tzinfo ReusedPySparkTestCase.setUpClass() cls.spark = SparkSession(cls.sc) + cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), StructField("3_long_t", LongType(), True), StructField("4_float_t", FloatType(), True), StructField("5_double_t", DoubleType(), True)]) - - def mkdt(*args): - class NaiveTZ(tzinfo): - """ - This will have Spark store internal value as UTC, not local time - """ - def utcoffset(self, date_time): - return None - - def dst(self, date_time): - return None - - return datetime(*args, tzinfo=NaiveTZ()) - cls.data = [("a", 1, 10, 0.2, 2.0), ("b", 2, 20, 0.4, 4.0), ("c", 3, 30, 0.8, 6.0)] @@ -2521,10 +2507,16 @@ def assertFramesEqual(self, df_with_arrow, df_without): ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) self.assertTrue(df_without.equals(df_with_arrow), msg=msg) + def test_unsupported_datatype(self): + schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) + df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: df.toPandas()) + def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + self.data) - pdf = df_null.toPandas(useArrow=True) + pdf = df_null.toPandas() null_counts = pdf.isnull().sum().tolist() self.assertTrue(all([c == 1 for c in null_counts])) @@ -2532,8 +2524,10 @@ def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) # NOTE - toPandas(useArrow=False) will infer standard python data types df_sel = df.select("1_str_t", "3_long_t", "5_double_t") - pdf = df_sel.toPandas(useArrow=False) - pdf_arrow = df_sel.toPandas(useArrow=True) + 
self.spark.conf.set("spark.sql.execution.arrow.enable", "false") + pdf = df_sel.toPandas() + self.spark.conf.set("spark.sql.execution.arrow.enable", "true") + pdf_arrow = df_sel.toPandas() self.assertFramesEqual(pdf_arrow, pdf) def test_pandas_round_trip(self): @@ -2547,7 +2541,7 @@ def test_pandas_round_trip(self): data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) pdf = pd.DataFrame(data=data_dict) df = self.spark.createDataFrame(self.data, schema=self.schema) - pdf_arrow = df.toPandas(useArrow=True) + pdf_arrow = df.toPandas() self.assertFramesEqual(pdf_arrow, pdf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 870ac8f597ba..03f91b3f9211 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -360,7 +360,10 @@ private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: } } -private[arrow] class TimeStampColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) +private[arrow] class TimeStampColumnWriter( + dtype: ArrowType, + ordinal: Int, + allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { override val valueVector: NullableTimeStampMicroVector = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 179d35366c0d..ef34598e1913 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -27,7 +27,6 @@ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.json4s.jackson.JsonMethods._ import org.json4s.JsonAST._ import org.json4s.JsonDSL._ @@ -35,6 +34,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -147,7 +147,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { allocator.close() } - test("unsupported types") { + testQuietly("unsupported types") { def runUnsupported(block: => Unit): Unit = { val msg = intercept[SparkException] { block From d7cb4ab823ca1dd6fea11082f23a01c9b5de95ed Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 26 Apr 2017 18:02:43 -0700 Subject: [PATCH 38/56] separated numeric tests, moved data to test scope --- .../sql/execution/arrow/ArrowConverters.scala | 4 +- .../arrow/ArrowConvertersSuite.scala | 423 ++++++++++-------- 2 files changed, 231 insertions(+), 196 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 03f91b3f9211..8a0f9c39e1eb 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -158,11 +158,11 @@ private[sql] object ArrowConverters { val out = new ByteArrayOutputStream() val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) - // Write batch to a byte array, ensure the batch, allocator and writer are closed + // Write a batch to byte stream, ensure the batch, allocator and writer are closed Utils.tryWithSafeFinally { val loader = new VectorLoader(root) loader.load(batch) - writer.writeBatch() + writer.writeBatch() // writeBatch can throw IOException } { batch.close() root.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index ef34598e1913..8f5a6e47eaf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -64,162 +64,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { allocator.close() } - test("numeric type conversion") { - collectAndValidate(indexData) - collectAndValidate(shortData) - collectAndValidate(intData) - collectAndValidate(longData) - collectAndValidate(floatData) - collectAndValidate(doubleData) - } - - test("mixed numeric type conversion") { - collectAndValidate(mixedNumericData) - } - - test("string type conversion") { - collectAndValidate(stringData) - } - - test("boolean type conversion") { - collectAndValidate(boolData) - } - - test("byte type conversion") { - collectAndValidate(byteData) - } - - // TODO: Not currently supported in Arrow JSON reader - ignore("binary type conversion") { - // collectAndValidate(binaryData) - } - - ignore("timestamp conversion") { - collectAndValidate(timestampData) - } - - ignore("date conversion") { - collectAndValidate(dateData) - } - - test("floating-point NaN") { - collectAndValidate(floatNaNData) - } - - test("partitioned DataFrame") { - val arrowPayloads = testData2.toArrowPayload.collect() - // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload - assert(arrowPayloads.length === 2) - // Generate JSON files - val a = List[Int](1, 1, 2, 2, 3, 3) - val b = List[Int](1, 2, 1, 2, 1, 2) - val fields = Seq(new IntegerType("a", is_signed = true, 32, nullable = false), - new IntegerType("b", is_signed = true, 32, nullable = false)) - def getBatch(x: Seq[Int], y: Seq[Int]): JSONRecordBatch = { - val columns = Seq(new PrimitiveColumn("a", x.length, x.map(_ => true), x), - new PrimitiveColumn("b", y.length, y.map(_ => true), y)) - new JSONRecordBatch(x.length, columns) - } - val json1 = new JSONFile(new JSONSchema(fields), Seq(getBatch(a.take(3), b.take(3)))) - val json2 = new JSONFile(new JSONSchema(fields), Seq(getBatch(a.takeRight(3), b.takeRight(3)))) - val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") - val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") - json1.write(tempFile1) - json2.write(tempFile2) - val schema = testData2.schema - validateConversion(schema, arrowPayloads(0), tempFile1) - validateConversion(schema, arrowPayloads(1), tempFile2) - } - - test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() - assert(arrowPayload.isEmpty) - } - - test("empty partition collect") { - val emptyPart = 
spark.sparkContext.parallelize(Seq(1), 2).toDF("i") - val arrowPayloads = emptyPart.toArrowPayload.collect() - assert(arrowPayloads.length === 1) - val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) - assert(arrowRecordBatches.head.getLength == 1) - arrowRecordBatches.foreach(_.close()) - allocator.close() - } - - testQuietly("unsupported types") { - def runUnsupported(block: => Unit): Unit = { - val msg = intercept[SparkException] { - block - } - assert(msg.getMessage.contains("Unsupported data type")) - assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) - } - - runUnsupported { decimalData.toArrowPayload.collect() } - runUnsupported { arrayData.toDF().toArrowPayload.collect() } - runUnsupported { mapData.toDF().toArrowPayload.collect() } - runUnsupported { complexData.toArrowPayload.collect() } - runUnsupported { dateData.df.toArrowPayload.collect() } - runUnsupported { timestampData.df.toArrowPayload.collect() } - } - - test("test Arrow Validator") { - val sdata = shortData - val idata = intData - - // Different schema - intercept[IllegalArgumentException] { - collectAndValidate(DataTuple(sdata.df, idata.json, idata.file)) - } - - // Different values - intercept[IllegalArgumentException] { - collectAndValidate(DataTuple(idata.df.sort($"a_i".desc), idata.json, idata.file)) - } - } - - /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(data: DataTuple): Unit = { - // NOTE: coalesce to single partition because can only load 1 batch in validator - val arrowPayload = data.df.coalesce(1).toArrowPayload.collect().head - val tempFile = new File(tempDataPath, data.file) - data.json.write(tempFile) - validateConversion(data.df.schema, arrowPayload, tempFile) - } - - private def validateConversion(sparkSchema: StructType, - arrowPayload: ArrowPayload, - jsonFile: File): Unit = { - val allocator = new RootAllocator(Long.MaxValue) - val jsonReader = new JsonFileReader(jsonFile, allocator) - - val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) - val jsonSchema = jsonReader.start() - Validator.compareSchemas(arrowSchema, jsonSchema) - - val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) - val vectorLoader = new VectorLoader(arrowRoot) - vectorLoader.load(arrowPayload.loadBatch(allocator)) - val jsonRoot = jsonReader.read() - Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) - } - - // Create Spark DataFrame and matching Arrow JSON at same time for validation - private case class DataTuple(df: DataFrame, json: JSONFile, file: String) - - private def indexData: DataTuple = { - val data = List[Int](1, 2, 3, 4, 5, 6) - val fields = Seq(new IntegerType("i", is_signed = true, 32, nullable = false)) - val schema = new JSONSchema(fields) - val columns = Seq(new PrimitiveColumn("i", data.length, data.map(_ => true), data)) - val batch = new JSONRecordBatch(data.length, columns) - DataTuple(data.toDF("i"), new JSONFile(schema, Seq(batch)), "indexData-ints.json") - } - - private def shortData: DataTuple = { + test("short conversion") { val a_s = List[Short](1, -1, 2, -2, 32767, -32768) val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768)) + val fields = Seq(new IntegerType("a_s", is_signed = true, 16, nullable = false), new IntegerType("b_s", is_signed = true, 16, nullable = true)) val schema = new JSONSchema(fields) @@ -228,13 +76,17 @@ class ArrowConvertersSuite extends SharedSQLContext 
with BeforeAndAfterAll { new PrimitiveColumn("a_s", a_s.length, a_s.map(_ => true), a_s.map(_.toInt)), new PrimitiveColumn("b_s", b_s.length, b_s.map(_.isDefined), b_s_values)) val batch = new JSONRecordBatch(a_s.length, columns) + val json = new JSONFile(schema, Seq(batch)) + val df = a_s.zip(b_s).toDF("a_s", "b_s") - DataTuple(df, new JSONFile(schema, Seq(batch)), "integer-16bit.json") + + collectAndValidate(df, json, "integer-16bit.json") } - private def intData: DataTuple = { + test("int conversion") { val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val fields = Seq(new IntegerType("a_i", is_signed = true, 32, nullable = false), new IntegerType("b_i", is_signed = true, 32, nullable = true)) val schema = new JSONSchema(fields) @@ -242,13 +94,17 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { new PrimitiveColumn("a_i", a_i.length, a_i.map(_ => true), a_i), new PrimitiveColumn("b_i", b_i.length, b_i.map(_.isDefined), b_i.map(_.getOrElse(0)))) val batch = new JSONRecordBatch(a_i.length, columns) + val json = new JSONFile(schema, Seq(batch)) + val df = a_i.zip(b_i).toDF("a_i", "b_i") - DataTuple(df, new JSONFile(schema, Seq(batch)), "integer-32bit.json") + + collectAndValidate(df, json, "integer-32bit.json") } - private def longData: DataTuple = { + test("long conversion") { val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L) val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L)) + val fields = Seq(new IntegerType("a_l", is_signed = true, 64, nullable = false), new IntegerType("b_l", is_signed = true, 64, nullable = true)) val schema = new JSONSchema(fields) @@ -256,26 +112,34 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { new PrimitiveColumn("a_l", a_l.length, a_l.map(_ => true), a_l), new PrimitiveColumn("b_l", b_l.length, b_l.map(_.isDefined), b_l.map(_.getOrElse(0L)))) val batch = new JSONRecordBatch(a_l.length, columns) + val json = new JSONFile(schema, Seq(batch)) + val df = a_l.zip(b_l).toDF("a_l", "b_l") - DataTuple(df, new JSONFile(schema, Seq(batch)), "integer-64bit.json") + + collectAndValidate(df, json, "integer-64bit.json") } - private def floatData: DataTuple = { + test("float conversion") { val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f) val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f)) + val fields = Seq(new FloatingPointType("a_f", 32, nullable = false), new FloatingPointType("b_f", 32, nullable = true)) val schema = new JSONSchema(fields) val columns = Seq(new PrimitiveColumn("a_f", a_f.length, a_f.map(_ => true), a_f), new PrimitiveColumn("b_f", b_f.length, b_f.map(_.isDefined), b_f.map(_.getOrElse(0.0f)))) val batch = new JSONRecordBatch(a_f.length, columns) + val json = new JSONFile(schema, Seq(batch)) + val df = a_f.zip(b_f).toDF("a_f", "b_f") - DataTuple(df, new JSONFile(schema, Seq(batch)), "floating_point-single_precision.json") + + collectAndValidate(df, json, "floating_point-single_precision.json") } - private def doubleData: DataTuple = { + test("double conversion") { val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0) val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3)) + val fields = Seq(new FloatingPointType("a_d", 64, nullable = false), new FloatingPointType("b_d", 64, nullable = true)) val schema = new JSONSchema(fields) @@ -283,11 +147,29 @@ class 
ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { new PrimitiveColumn("b_d", b_d.length, b_d.map(_.isDefined), b_d.map(_.getOrElse(0.0)))) val batch = new JSONRecordBatch(a_d.length, columns) val df = a_d.zip(b_d).toDF("a_d", "b_d") - DataTuple(df, new JSONFile(schema, Seq(batch)), "floating_point-double_precision.json") + + val json = new JSONFile(schema, Seq(batch)) + + collectAndValidate(df, json, "floating_point-double_precision.json") } - private def mixedNumericData: DataTuple = { + test("index conversion") { + val data = List[Int](1, 2, 3, 4, 5, 6) + + val fields = Seq(new IntegerType("i", is_signed = true, 32, nullable = false)) + val schema = new JSONSchema(fields) + val columns = Seq(new PrimitiveColumn("i", data.length, data.map(_ => true), data)) + val batch = new JSONRecordBatch(data.length, columns) + val json = new JSONFile(schema, Seq(batch)) + + val df = data.toDF("i") + + collectAndValidate(df, json, "indexData-ints.json") + } + + test("mixed numeric type conversion") { val data = List(1, 2, 3, 4, 5, 6) + val fields = Seq(new IntegerType("a", is_signed = true, 16, nullable = false), new FloatingPointType("b", 32, nullable = false), new IntegerType("c", is_signed = true, 32, nullable = false), @@ -301,26 +183,21 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { new PrimitiveColumn("e", data.length, data.map(_ => true), data) ) val batch = new JSONRecordBatch(data.length, columns) + val json = new JSONFile(schema, Seq(batch)) + val data_tuples = for (d <- data) yield { (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong) } val df = data_tuples.toDF("a", "b", "c", "d", "e") - DataTuple(df, new JSONFile(schema, Seq(batch)), "mixed_numeric_types.json") - } - private def boolData: DataTuple = { - val data = Seq(true, true, false, true) - val fields = Seq(new BooleanType("a_bool", nullable = false)) - val schema = new JSONSchema(fields) - val columns = Seq(new PrimitiveColumn("a_bool", data.length, data.map(_ => true), data)) - val batch = new JSONRecordBatch(data.length, columns) - DataTuple(data.toDF("a_bool"), new JSONFile(schema, Seq(batch)), "boolData.json") + collectAndValidate(df, json, "mixed_numeric_types.json") } - private def stringData: DataTuple = { + test("string type conversion") { val upperCase = Seq("A", "B", "C") val lowerCase = Seq("a", "b", "c") val nullStr = Seq("ab", "CDE", null) + val fields = Seq(new StringType("upper_case", nullable = true), new StringType("lower_case", nullable = true), new StringType("null_str", nullable = true)) @@ -332,60 +209,218 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { nullStr.map { s => if (s == null) "" else s} )) val batch = new JSONRecordBatch(upperCase.length, columns) + val json = new JSONFile(schema, Seq(batch)) + val df = (upperCase, lowerCase, nullStr).zipped.toList .toDF("upper_case", "lower_case", "null_str") - DataTuple(df, new JSONFile(schema, Seq(batch)), "stringData.json") + + collectAndValidate(df, json, "stringData.json") + } + + test("boolean type conversion") { + val data = Seq(true, true, false, true) + + val fields = Seq(new BooleanType("a_bool", nullable = false)) + val schema = new JSONSchema(fields) + val columns = Seq(new PrimitiveColumn("a_bool", data.length, data.map(_ => true), data)) + val batch = new JSONRecordBatch(data.length, columns) + val json = new JSONFile(schema, Seq(batch)) + + val df = data.toDF("a_bool") + + collectAndValidate(df, json, "boolData.json") } - private def byteData: DataTuple = { + 
test("byte type conversion") { val data = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue) + val fields = Seq(new IntegerType("a_byte", is_signed = true, 8, nullable = false)) val schema = new JSONSchema(fields) val columns = Seq( new PrimitiveColumn("a_byte", data.length, data.map(_ => true), data.map(_.toInt))) val batch = new JSONRecordBatch(data.length, columns) - DataTuple(data.toDF("a_byte"), new JSONFile(schema, Seq(batch)), "byteData.json") + val json = new JSONFile(schema, Seq(batch)) + + val df = data.toDF("a_byte") + + collectAndValidate(df, json, "byteData.json") } - private def floatNaNData: DataTuple = { - val fnan = Seq(1.2F, Float.NaN) - val dnan = Seq(Double.NaN, 1.2) - val fields = Seq(new FloatingPointType("NaN_f", 32, nullable = false), - new FloatingPointType("NaN_d", 64, nullable = false)) - val schema = new JSONSchema(fields) - val columns = Seq(new PrimitiveColumn("NaN_f", fnan.length, fnan.map(_ => true), fnan), - new PrimitiveColumn("NaN_d", dnan.length, dnan.map(_ => true), dnan)) - val batch = new JSONRecordBatch(fnan.length, columns) - val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") - DataTuple(df, new JSONFile(schema, Seq(batch)), "nanData-floating_point.json") + // TODO: Not currently supported in Arrow JSON reader + ignore("binary type conversion") { + // collectAndValidate(binaryData) } - private def timestampData: DataTuple = { + ignore("timestamp conversion") { val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) val data = Seq(ts1, ts2) + val schema = new JSONSchema(Seq(new TimestampType("timestamp"))) val us_data = data.map(_.getTime * 1000) // convert to microseconds val columns = Seq( - new PrimitiveColumn("timestamp", data.length, data.map(_ => true), us_data)) + new PrimitiveColumn("timestamp", data.length, data.map(_ => true), us_data)) val batch = new JSONRecordBatch(data.length, columns) - DataTuple(data.toDF("timestamp"), new JSONFile(schema, Seq(batch)), "timestampData.json") + val json = new JSONFile(schema, Seq(batch)) + + val df = data.toDF("timestamp") + + collectAndValidate(df, json, "timestampData.json") } - private def dateData: DataTuple = { + ignore("date conversion") { val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) val data = Seq(d1, d2) + val day_data = data.map(d => DateTimeUtils.millisToDays(d.getTime)) val schema = new JSONSchema(Seq(new DateType("date"))) val columns = Seq( - new PrimitiveColumn("date", data.length, data.map(_ => true), day_data)) + new PrimitiveColumn("date", data.length, data.map(_ => true), day_data)) val batch = new JSONRecordBatch(data.length, columns) - DataTuple(data.toDF("date"), new JSONFile(schema, Seq(batch)), "dateData.json") + val json = new JSONFile(schema, Seq(batch)) + + val df = data.toDF("date") + + collectAndValidate(df, json, "dateData.json") + } + + test("floating-point NaN") { + val fnan = Seq(1.2F, Float.NaN) + val dnan = Seq(Double.NaN, 1.2) + + val fields = Seq(new FloatingPointType("NaN_f", 32, nullable = false), + new FloatingPointType("NaN_d", 64, nullable = false)) + val schema = new JSONSchema(fields) + val columns = Seq(new PrimitiveColumn("NaN_f", fnan.length, fnan.map(_ => true), fnan), + new PrimitiveColumn("NaN_d", dnan.length, dnan.map(_ => true), 
dnan)) + val batch = new JSONRecordBatch(fnan.length, columns) + val json = new JSONFile(schema, Seq(batch)) + + val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") + + collectAndValidate(df, json, "nanData-floating_point.json") + } + + test("partitioned DataFrame") { + val arrowPayloads = testData2.toArrowPayload.collect() + // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload + assert(arrowPayloads.length === 2) + // Generate JSON files + val a = List[Int](1, 1, 2, 2, 3, 3) + val b = List[Int](1, 2, 1, 2, 1, 2) + val fields = Seq(new IntegerType("a", is_signed = true, 32, nullable = false), + new IntegerType("b", is_signed = true, 32, nullable = false)) + def getBatch(x: Seq[Int], y: Seq[Int]): JSONRecordBatch = { + val columns = Seq(new PrimitiveColumn("a", x.length, x.map(_ => true), x), + new PrimitiveColumn("b", y.length, y.map(_ => true), y)) + new JSONRecordBatch(x.length, columns) + } + val json1 = new JSONFile(new JSONSchema(fields), Seq(getBatch(a.take(3), b.take(3)))) + val json2 = new JSONFile(new JSONSchema(fields), Seq(getBatch(a.takeRight(3), b.takeRight(3)))) + val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") + val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") + json1.write(tempFile1) + json2.write(tempFile2) + val schema = testData2.schema + validateConversion(schema, arrowPayloads(0), tempFile1) + validateConversion(schema, arrowPayloads(1), tempFile2) + } + + test("empty frame collect") { + val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() + assert(arrowPayload.isEmpty) + } + + test("empty partition collect") { + val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") + val arrowPayloads = emptyPart.toArrowPayload.collect() + assert(arrowPayloads.length === 1) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + assert(arrowRecordBatches.head.getLength == 1) + arrowRecordBatches.foreach(_.close()) + allocator.close() + } + + testQuietly("unsupported types") { + def runUnsupported(block: => Unit): Unit = { + val msg = intercept[SparkException] { + block + } + assert(msg.getMessage.contains("Unsupported data type")) + assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) + } + + runUnsupported { decimalData.toArrowPayload.collect() } + runUnsupported { arrayData.toDF().toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowPayload.collect() } + runUnsupported { complexData.toArrowPayload.collect() } + // runUnsupported { dateData.df.toArrowPayload.collect() } + // runUnsupported { timestampData.df.toArrowPayload.collect() } + } + + test("test Arrow Validator") { + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + + val fields = Seq(new IntegerType("a_i", is_signed = true, 32, nullable = false), + new IntegerType("b_i", is_signed = true, 32, nullable = true)) + val schema = new JSONSchema(fields) + val columns = Seq( + new PrimitiveColumn("a_i", a_i.length, a_i.map(_ => true), a_i), + new PrimitiveColumn("b_i", b_i.length, b_i.map(_.isDefined), b_i.map(_.getOrElse(0)))) + val batch = new JSONRecordBatch(a_i.length, columns) + val json = new JSONFile(schema, Seq(batch)) + + val df = a_i.zip(b_i).toDF("a_i", "b_i") + + // Different schema + val schema_other = new JSONSchema(fields.reverse) + val json_other = new JSONFile(schema_other, Seq(batch)) + intercept[IllegalArgumentException] { + 
collectAndValidate(df, json_other, "validador_diff_schema.json") + } + + // Different values + intercept[IllegalArgumentException] { + collectAndValidate(df.sort($"a_i".desc), json, "validador_diff_values.json") + } } + /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ + private def collectAndValidate(df: DataFrame, json: JSONFile, file: String): Unit = { + // NOTE: coalesce to single partition because can only load 1 batch in validator + val arrowPayload = df.coalesce(1).toArrowPayload.collect().head + val tempFile = new File(tempDataPath, file) + json.write(tempFile) + validateConversion(df.schema, arrowPayload, tempFile) + } + + private def validateConversion( + sparkSchema: StructType, + arrowPayload: ArrowPayload, + jsonFile: File): Unit = { + val allocator = new RootAllocator(Long.MaxValue) + val jsonReader = new JsonFileReader(jsonFile, allocator) + + val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) + val jsonSchema = jsonReader.start() + Validator.compareSchemas(arrowSchema, jsonSchema) + + val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) + val vectorLoader = new VectorLoader(arrowRoot) + vectorLoader.load(arrowPayload.loadBatch(allocator)) + val jsonRoot = jsonReader.read() + Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) + } + + // Create Spark DataFrame and matching Arrow JSON at same time for validation + private case class DataTuple(df: DataFrame, json: JSONFile, file: String) + + /** * Arrow JSON Format Data Generation * Referenced from https://github.com/apache/arrow/blob/master/integration/integration_test.py From b6fe733955d6e153722b1945c09ed663d8ed9be2 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 26 Apr 2017 18:10:31 -0700 Subject: [PATCH 39/56] removed timestamp and date test until fully supported --- .../arrow/ArrowConvertersSuite.scala | 81 +++++++------------ 1 file changed, 27 insertions(+), 54 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 8f5a6e47eaf1..976f56be4da1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -251,42 +251,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { // collectAndValidate(binaryData) } - ignore("timestamp conversion") { - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - val data = Seq(ts1, ts2) - - val schema = new JSONSchema(Seq(new TimestampType("timestamp"))) - val us_data = data.map(_.getTime * 1000) // convert to microseconds - val columns = Seq( - new PrimitiveColumn("timestamp", data.length, data.map(_ => true), us_data)) - val batch = new JSONRecordBatch(data.length, columns) - val json = new JSONFile(schema, Seq(batch)) - - val df = data.toDF("timestamp") - - collectAndValidate(df, json, "timestampData.json") - } - - ignore("date conversion") { - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) - val data = Seq(d1, d2) - - val day_data = data.map(d => 
DateTimeUtils.millisToDays(d.getTime)) - val schema = new JSONSchema(Seq(new DateType("date"))) - val columns = Seq( - new PrimitiveColumn("date", data.length, data.map(_ => true), day_data)) - val batch = new JSONRecordBatch(data.length, columns) - val json = new JSONFile(schema, Seq(batch)) - - val df = data.toDF("date") - - collectAndValidate(df, json, "dateData.json") - } - test("floating-point NaN") { val fnan = Seq(1.2F, Float.NaN) val dnan = Seq(Double.NaN, 1.2) @@ -358,8 +322,13 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { runUnsupported { arrayData.toDF().toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } - // runUnsupported { dateData.df.toArrowPayload.collect() } - // runUnsupported { timestampData.df.toArrowPayload.collect() } + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) + val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) + val data = Seq(d1, d2) + runUnsupported { data.toDF("date").toArrowPayload.collect() } + runUnsupported { data.toDF("timestamp").toArrowPayload.collect() } } test("test Arrow Validator") { @@ -467,10 +436,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - private class PrimitiveColumn[T <% JValue](name: String, - count: Int, - is_valid: Seq[Boolean], - values: Seq[T]) + private class PrimitiveColumn[T <% JValue]( + name: String, + count: Int, + is_valid: Seq[Boolean], + values: Seq[T]) extends Column(name, count) { override def _get_children: JArray = JArray(List.empty) override def _get_buffers: JObject = { @@ -480,10 +450,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - private class IntegerType(name: String, - is_signed: Boolean, - override val bit_width: Int, - nullable: Boolean) + private class IntegerType( + name: String, + is_signed: Boolean, + override val bit_width: Int, + nullable: Boolean) extends PrimitiveType(name, nullable = nullable) { override def _get_type: JObject = { JObject( @@ -555,10 +526,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - private class BinaryColumn(name: String, - count: Int, - is_valid: Seq[Boolean], - values: Seq[String]) + private class BinaryColumn( + name: String, + count: Int, + is_valid: Seq[Boolean], + values: Seq[String]) extends PrimitiveColumn(name, count, is_valid, values) { def _encode_value(v: String): String = { v.map(c => String.format("%h", c.toString)).reduce(_ + _) @@ -579,10 +551,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - private class StringColumn(name: String, - count: Int, - is_valid: Seq[Boolean], - values: Seq[String]) + private class StringColumn( + name: String, + count: Int, + is_valid: Seq[Boolean], + values: Seq[String]) extends BinaryColumn(name, count, is_valid, values) { override def _encode_value(v: String): String = v } From 36f8127b0c705153a5c20364e66d141702377abb Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 1 May 2017 11:11:15 -0700 Subject: [PATCH 40/56] added exception handling in byteArrayToBatch conversion, changed ArrowPayload storage name to payload --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../sql/execution/arrow/ArrowConverters.scala | 20 +++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git 
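The hunks below wrap byteArrayToBatch in Utils.tryWithSafeFinally so the ArrowFileReader is always closed, but the ArrowRecordBatch it returns, and the allocator backing it, remain the caller's responsibility. A minimal caller-side sketch, assuming a `payload` obtained from df.toArrowPayload and mirroring how the test suite releases these resources:

    import org.apache.arrow.memory.RootAllocator

    // Illustrative only: `payload` is an ArrowPayload collected from df.toArrowPayload.
    val allocator = new RootAllocator(Long.MaxValue)
    val batch = payload.loadBatch(allocator)   // can propagate an IOException from the Arrow reader
    try {
      assert(batch.getLength > 0)              // consume the record batch here
    } finally {
      batch.close()                            // release the batch's buffers...
      allocator.close()                        // ...then the allocator that owns them
    }
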
a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b7a6b29fcc29..a62de4a08893 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2778,7 +2778,7 @@ class Dataset[T] private[sql]( */ private[sql] def collectAsArrowToPython(): Int = { withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.batchBytes) + val iter = toArrowPayload.collect().iterator.map(_.payload) PythonRDD.serveIterator(iter, "serve-Arrow") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 8a0f9c39e1eb..70dae9748933 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -40,14 +40,14 @@ import org.apache.spark.util.Utils /** * Store Arrow data in a form that can be serialized by Spark */ -private[sql] class ArrowPayload(val batchBytes: Array[Byte]) extends Serializable { +private[sql] class ArrowPayload(val payload: Array[Byte]) extends Serializable { def this(batch: ArrowRecordBatch, schema: StructType, allocator: BufferAllocator) = { this(ArrowConverters.batchToByteArray(batch, schema, allocator)) } def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { - ArrowConverters.byteArrayToBatch(batchBytes, allocator) + ArrowConverters.byteArrayToBatch(payload, allocator) } } @@ -179,12 +179,16 @@ private[sql] object ArrowConverters { allocator: BufferAllocator): ArrowRecordBatch = { val in = new ByteArrayReadableSeekableByteChannel(batchBytes) val reader = new ArrowFileReader(in, allocator) - val root = reader.getVectorSchemaRoot - val unloader = new VectorUnloader(root) - reader.loadNextBatch() - val batch = unloader.getRecordBatch - reader.close() - batch + + // Read a batch from a byte stream, ensure the reader is closed + Utils.tryWithSafeFinally { + val root = reader.getVectorSchemaRoot // throws IOException + val unloader = new VectorUnloader(root) + reader.loadNextBatch() // throws IOException + unloader.getRecordBatch + } { + reader.close() + } } } From 088f79e4d0043b2e0d48b3d1f2d41134e2ba0b72 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 1 May 2017 13:52:34 -0700 Subject: [PATCH 41/56] added binary conversion test --- .../arrow/ArrowConvertersSuite.scala | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 976f56be4da1..23e9d79fadc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -33,10 +33,10 @@ import org.json4s.JsonDSL._ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{BinaryType, StructField, StructType} import org.apache.spark.util.Utils @@ -245,10 +245,21 @@ class ArrowConvertersSuite 
extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "byteData.json") } + + test("binary type conversion") { + val data = Seq("abc", "d", "ef") - // TODO: Not currently supported in Arrow JSON reader - ignore("binary type conversion") { - // collectAndValidate(binaryData) + val fields = Seq(new BinaryType("a_binary", nullable = true)) + val schema = new JSONSchema(fields) + val columns = Seq( + new BinaryColumn("a_binary", data.length, data.map(_ => true), data)) + val batch = new JSONRecordBatch(data.length, columns) + val json = new JSONFile(schema, Seq(batch)) + + val rdd = sparkContext.parallelize(data.map(s => Row(s.getBytes("utf-8")))) + val df = spark.createDataFrame(rdd, StructType(Seq(StructField("a_binary", BinaryType)))) + + collectAndValidate(df, json, "binaryData.json") } test("floating-point NaN") { @@ -350,12 +361,12 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val schema_other = new JSONSchema(fields.reverse) val json_other = new JSONFile(schema_other, Seq(batch)) intercept[IllegalArgumentException] { - collectAndValidate(df, json_other, "validador_diff_schema.json") + collectAndValidate(df, json_other, "validator_diff_schema.json") } // Different values intercept[IllegalArgumentException] { - collectAndValidate(df.sort($"a_i".desc), json, "validador_diff_values.json") + collectAndValidate(df.sort($"a_i".desc), json, "validator_diff_values.json") } } @@ -386,9 +397,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) } - // Create Spark DataFrame and matching Arrow JSON at same time for validation - private case class DataTuple(df: DataFrame, json: JSONFile, file: String) - /** * Arrow JSON Format Data Generation From e0449ebb45ded6fc9ee14027cf2ae2b9ced8537f Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 1 May 2017 14:01:46 -0700 Subject: [PATCH 42/56] fixed up unsupported test for timestamp --- .../sql/execution/arrow/ArrowConvertersSuite.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 23e9d79fadc6..acaa507e9a00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -33,8 +33,7 @@ import org.json4s.JsonDSL._ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException -import org.apache.spark.sql.{Row, DataFrame} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, StructField, StructType} import org.apache.spark.util.Utils @@ -245,7 +244,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "byteData.json") } - + test("binary type conversion") { val data = Seq("abc", "d", "ef") @@ -337,9 +336,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) - val data = Seq(d1, d2) - runUnsupported { 
data.toDF("date").toArrowPayload.collect() } - runUnsupported { data.toDF("timestamp").toArrowPayload.collect() } + runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } + + val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) + runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } } test("test Arrow Validator") { From b6bfcd7c8ec0300aa7d0931c717dde00e4c56672 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 2 May 2017 15:42:09 -0700 Subject: [PATCH 43/56] Updated Arrow version to 0.3.0 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 3e283d330ad4..dc967e224f98 100644 --- a/pom.xml +++ b/pom.xml @@ -184,7 +184,7 @@ 2.6 1.8 1.0.0 - 0.2.1-SNAPSHOT + 0.3.0 ${java.home} From 2c1af59afe4f590bdea5897139d1d771478c3d50 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 9 May 2017 11:50:15 -0700 Subject: [PATCH 44/56] added ArrowPayload method toByteArray --- .../main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/execution/arrow/ArrowConverters.scala | 13 ++++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a62de4a08893..147f32038ff3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2778,7 +2778,7 @@ class Dataset[T] private[sql]( */ private[sql] def collectAsArrowToPython(): Int = { withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.payload) + val iter = toArrowPayload.collect().iterator.map(_.toByteArray) PythonRDD.serveIterator(iter, "serve-Arrow") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 70dae9748933..a0db9ad51974 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -40,15 +40,26 @@ import org.apache.spark.util.Utils /** * Store Arrow data in a form that can be serialized by Spark */ -private[sql] class ArrowPayload(val payload: Array[Byte]) extends Serializable { +private[sql] class ArrowPayload(payload: Array[Byte]) extends Serializable { + /** + * Create an ArrowPayload from an ArrowRecordBatch and Spark schema + */ def this(batch: ArrowRecordBatch, schema: StructType, allocator: BufferAllocator) = { this(ArrowConverters.batchToByteArray(batch, schema, allocator)) } + /** + * Convert the ArrowPayload to an ArrowRecordBatch + */ def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { ArrowConverters.byteArrayToBatch(payload, allocator) } + + /** + * Get the ArrowPayload as an Array[Byte] + */ + def toByteArray: Array[Byte] = payload } private[sql] object ArrowConverters { From 1d471ac1612e29e9486fab465bd5b9db0c20d2a8 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 9 May 2017 13:29:53 -0700 Subject: [PATCH 45/56] removed unused imports, arrow.vector.DateUnit and TimeUnit --- .../org/apache/spark/sql/execution/arrow/ArrowConverters.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index a0db9ad51974..4e822b05726e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -28,7 +28,7 @@ import org.apache.arrow.vector._ import org.apache.arrow.vector.BaseValueVector.BaseMutator import org.apache.arrow.vector.file._ import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} +import org.apache.arrow.vector.types.FloatingPointPrecision import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel From a4d6057642a922c4beb5b396591ba9f1b5e3f883 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 9 May 2017 15:40:18 -0700 Subject: [PATCH 46/56] Added conf spark.sql.execution.arrow.maxRecordsPerBatch to limit num records in a single batch, added unit test for conf, minor cleanup --- .../apache/spark/sql/internal/SQLConf.scala | 22 ++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 3 +- .../sql/execution/arrow/ArrowConverters.scala | 42 ++++++++++++------- .../arrow/ArrowConvertersSuite.scala | 20 +++++++++ 4 files changed, 71 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2e1798e22b9f..c4620e670f14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -781,6 +781,24 @@ object SQLConf { .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + val ARROW_EXECUTION_ENABLE = + buildConf("spark.sql.execution.arrow.enable") + .internal() + .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + + "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + + "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + + "LongType, ShortType") + .booleanConf + .createWithDefault(false) + + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = + buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") + .internal() + .doc("When using Apache Arrow, limit the maximum number of records that can be written " + + "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") + .intConf + .createWithDefault(10000) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1033,6 +1051,10 @@ class SQLConf extends Serializable with Logging { def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) + def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) + + def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 147f32038ff3..4f409b21c780 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2868,8 +2868,9 @@ class Dataset[T] private[sql]( /** Convert to an RDD of ArrowPayload byte arrays */ private[sql] def toArrowPayload: RDD[ArrowPayload] = { val schema_captured = this.schema + val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch queryExecution.toRdd.mapPartitionsInternal { iter => - ArrowConverters.toPayloadIterator(iter, schema_captured) + ArrowConverters.toPayloadIterator(iter, schema_captured, maxRecordsPerBatch) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 4e822b05726e..17c2a7b725c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -38,26 +38,26 @@ import org.apache.spark.util.Utils /** - * Store Arrow data in a form that can be serialized by Spark + * Store Arrow data in a form that can be serialized by Spark. */ private[sql] class ArrowPayload(payload: Array[Byte]) extends Serializable { /** - * Create an ArrowPayload from an ArrowRecordBatch and Spark schema + * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. */ def this(batch: ArrowRecordBatch, schema: StructType, allocator: BufferAllocator) = { this(ArrowConverters.batchToByteArray(batch, schema, allocator)) } /** - * Convert the ArrowPayload to an ArrowRecordBatch + * Convert the ArrowPayload to an ArrowRecordBatch. */ def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { ArrowConverters.byteArrayToBatch(payload, allocator) } /** - * Get the ArrowPayload as an Array[Byte] + * Get the ArrowPayload as an Array[Byte]. */ def toByteArray: Array[Byte] = payload } @@ -93,11 +93,13 @@ private[sql] object ArrowConverters { } /** - * Maps Iterator from InternalRow to ArrowPayload + * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload + * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. */ private[sql] def toPayloadIterator( rowIter: Iterator[InternalRow], - schema: StructType): Iterator[ArrowPayload] = { + schema: StructType, + maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { new Iterator[ArrowPayload] { private val _allocator = new RootAllocator(Long.MaxValue) private var _nextPayload = if (rowIter.nonEmpty) convert() else null @@ -118,32 +120,37 @@ private[sql] object ArrowConverters { } private def convert(): ArrowPayload = { - val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator) + val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) new ArrowPayload(batch, schema, _allocator) } } } /** - * Iterate over InternalRows and write to an ArrowRecordBatch. + * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed + * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0, + * then rowIter will be fully consumed. 
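+ * For example, with maxRecordsPerBatch = 3 an iterator holding 5 rows yields a batch of
+ * 3 records on the first call and 2 on the next; with 0, all 5 rows land in a single batch.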
*/ private def internalRowIterToArrowBatch( rowIter: Iterator[InternalRow], schema: StructType, - allocator: BufferAllocator): ArrowRecordBatch = { + allocator: BufferAllocator, + maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => - ColumnWriter(ordinal, allocator, field.dataType).init() + ColumnWriter(field.dataType, ordinal, allocator).init() } val writerLength = columnWriters.length - while (rowIter.hasNext) { + var recordsInBatch = 0 + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { val row = rowIter.next() var i = 0 while (i < writerLength) { columnWriters(i).write(row) i += 1 } + recordsInBatch += 1 } val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip @@ -158,7 +165,8 @@ private[sql] object ArrowConverters { } /** - * Convert an ArrowRecordBatch to a byte array and close batch + * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed, + * the batch can no longer be used. */ private[arrow] def batchToByteArray( batch: ArrowRecordBatch, @@ -183,7 +191,7 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch + * Convert a byte array to an ArrowRecordBatch. */ private[arrow] def byteArrayToBatch( batchBytes: Array[Byte], @@ -204,7 +212,7 @@ private[sql] object ArrowConverters { } /** - * Interface for writing InternalRows to Arrow Buffers + * Interface for writing InternalRows to Arrow Buffers. */ private[arrow] trait ColumnWriter { def init(): this.type @@ -391,7 +399,11 @@ private[arrow] class TimeStampColumnWriter( } private[arrow] object ColumnWriter { - def apply(ordinal: Int, allocator: BufferAllocator, dataType: DataType): ColumnWriter = { + + /** + * Create an Arrow ColumnWriter given the type and ordinal of row. 
+ */ + def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { val dtype = ArrowConverters.sparkTypeToArrowType(dataType) dataType match { case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index acaa507e9a00..6d6df4307c24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -319,6 +319,26 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { allocator.close() } + test("max records in batch conf") { + val totalRecords = 10 + val maxRecordsPerBatch = 3 + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) + val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") + val arrowPayloads = df.toArrowPayload.collect() + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + var recordCount = 0 + arrowRecordBatches.foreach { batch => + assert(batch.getLength > 0) + assert(batch.getLength <= maxRecordsPerBatch) + recordCount += batch.getLength + batch.close() + } + assert(recordCount == totalRecords) + allocator.close() + spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") + } + testQuietly("unsupported types") { def runUnsupported(block: => Unit): Unit = { val msg = intercept[SparkException] { From 934c147cf41752d382ee6ae304ed18ca5bed73e4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 9 May 2017 16:18:56 -0700 Subject: [PATCH 47/56] update dependency manifests for Arrow 0.3.0 --- dev/deps/spark-deps-hadoop-2.6 | 6 +++--- dev/deps/spark-deps-hadoop-2.7 | 6 +++--- dev/run-pip-tests | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index af0bb8fb4e90..bb0cb28f40bf 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -13,9 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.2.0.jar -arrow-memory-0.2.0.jar -arrow-vector-0.2.0.jar +arrow-format-0.3.0.jar +arrow-memory-0.3.0.jar +arrow-vector-0.3.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 242aa7a65050..4e40b4dbfb10 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -13,9 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.2.0.jar -arrow-memory-0.2.0.jar -arrow-vector-0.2.0.jar +arrow-format-0.3.0.jar +arrow-memory-0.3.0.jar +arrow-vector-0.3.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar diff --git a/dev/run-pip-tests b/dev/run-pip-tests index edfc29352d26..c92fa33a116b 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -83,7 +83,7 @@ for python in "${PYTHON_EXECS[@]}"; do if [ -n "$USE_CONDA" ]; then conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools source activate "$VIRTUALENV_PATH" - conda install -y -c conda-forge pyarrow=0.2 + conda install -y -c conda-forge pyarrow=0.3 TEST_PYARROW=1 else mkdir -p "$VIRTUALENV_PATH" From 
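The maxRecordsPerBatch conf added in PATCH 46 is read by toArrowPayload at execution time, while spark.sql.execution.arrow.enable is documented for the PySpark toPandas path. A usage sketch mirroring the new "max records in batch conf" test (values are illustrative, and toArrowPayload is private[sql], so this compiles only inside the org.apache.spark.sql package):

    // Cap each ArrowRecordBatch at 5000 rows for this session (the default is 10000).
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 5000)
    val payloads = spark.range(20000).toDF("i").toArrowPayload.collect()
    assert(payloads.nonEmpty)   // several payloads, each holding at most 5000 records
    spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
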
2e4747b9ade6a38360d77078679db5731d45bb8e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 10 May 2017 11:59:32 -0700 Subject: [PATCH 48/56] changed tests to close resources properly --- .../spark/sql/execution/arrow/ArrowConvertersSuite.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 6d6df4307c24..4f4be87187da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -413,9 +413,16 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) val vectorLoader = new VectorLoader(arrowRoot) - vectorLoader.load(arrowPayload.loadBatch(allocator)) + val arrowRecordBatch = arrowPayload.loadBatch(allocator) + vectorLoader.load(arrowRecordBatch) val jsonRoot = jsonReader.read() Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) + + jsonRoot.close() + jsonReader.close() + arrowRecordBatch.close() + arrowRoot.close() + allocator.close() } From b4eebc27e261eddb4d8b0b829245fa3c187dade1 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 17 May 2017 15:41:38 -0700 Subject: [PATCH 49/56] Made JSON test data local string for each test, removed JSON generation --- .../arrow/ArrowConvertersSuite.scala | 1296 ++++++++++++----- 1 file changed, 948 insertions(+), 348 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 4f4be87187da..8400de93a01d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -27,9 +27,6 @@ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator -import org.json4s.jackson.JsonMethods._ -import org.json4s.JsonAST._ -import org.json4s.JsonDSL._ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException @@ -64,126 +61,506 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("short conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_s", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 32767, -32768 ] + | }, { + | "name" : "b_s", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | 
"DATA" : [ 1, 0, 0, -2, 0, -32768 ] + | } ] + | } ] + |} + """.stripMargin + val a_s = List[Short](1, -1, 2, -2, 32767, -32768) val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768)) - - val fields = Seq(new IntegerType("a_s", is_signed = true, 16, nullable = false), - new IntegerType("b_s", is_signed = true, 16, nullable = true)) - val schema = new JSONSchema(fields) - val b_s_values = b_s.map(_.map(_.toInt).getOrElse(0)) - val columns = Seq( - new PrimitiveColumn("a_s", a_s.length, a_s.map(_ => true), a_s.map(_.toInt)), - new PrimitiveColumn("b_s", b_s.length, b_s.map(_.isDefined), b_s_values)) - val batch = new JSONRecordBatch(a_s.length, columns) - val json = new JSONFile(schema, Seq(batch)) - val df = a_s.zip(b_s).toDF("a_s", "b_s") collectAndValidate(df, json, "integer-16bit.json") } test("int conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) - - val fields = Seq(new IntegerType("a_i", is_signed = true, 32, nullable = false), - new IntegerType("b_i", is_signed = true, 32, nullable = true)) - val schema = new JSONSchema(fields) - val columns = Seq( - new PrimitiveColumn("a_i", a_i.length, a_i.map(_ => true), a_i), - new PrimitiveColumn("b_i", b_i.length, b_i.map(_.isDefined), b_i.map(_.getOrElse(0)))) - val batch = new JSONRecordBatch(a_i.length, columns) - val json = new JSONFile(schema, Seq(batch)) - val df = a_i.zip(b_i).toDF("a_i", "b_i") collectAndValidate(df, json, "integer-32bit.json") } test("long conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_l", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 9223372036854775807, -9223372036854775808 ] + | }, 
{ + | "name" : "b_l", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -9223372036854775808 ] + | } ] + | } ] + |} + """.stripMargin + val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L) val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L)) - - val fields = Seq(new IntegerType("a_l", is_signed = true, 64, nullable = false), - new IntegerType("b_l", is_signed = true, 64, nullable = true)) - val schema = new JSONSchema(fields) - val columns = Seq( - new PrimitiveColumn("a_l", a_l.length, a_l.map(_ => true), a_l), - new PrimitiveColumn("b_l", b_l.length, b_l.map(_.isDefined), b_l.map(_.getOrElse(0L)))) - val batch = new JSONRecordBatch(a_l.length, columns) - val json = new JSONFile(schema, Seq(batch)) - val df = a_l.zip(b_l).toDF("a_l", "b_l") collectAndValidate(df, json, "integer-64bit.json") } test("float conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_f", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0 ] + | }, { + | "name" : "b_f", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f) val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f)) - - val fields = Seq(new FloatingPointType("a_f", 32, nullable = false), - new FloatingPointType("b_f", 32, nullable = true)) - val schema = new JSONSchema(fields) - val columns = Seq(new PrimitiveColumn("a_f", a_f.length, a_f.map(_ => true), a_f), - new PrimitiveColumn("b_f", b_f.length, b_f.map(_.isDefined), b_f.map(_.getOrElse(0.0f)))) - val batch = new JSONRecordBatch(a_f.length, columns) - val json = new JSONFile(schema, Seq(batch)) - val df = a_f.zip(b_f).toDF("a_f", "b_f") collectAndValidate(df, json, "floating_point-single_precision.json") } test("double conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 
1.0, 2.0, 0.01, 200.0, 1.0E-4, 20000.0 ] + | }, { + | "name" : "b_d", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0) val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3)) - - val fields = Seq(new FloatingPointType("a_d", 64, nullable = false), - new FloatingPointType("b_d", 64, nullable = true)) - val schema = new JSONSchema(fields) - val columns = Seq(new PrimitiveColumn("a_d", a_d.length, a_d.map(_ => true), a_d), - new PrimitiveColumn("b_d", b_d.length, b_d.map(_.isDefined), b_d.map(_.getOrElse(0.0)))) - val batch = new JSONRecordBatch(a_d.length, columns) val df = a_d.zip(b_d).toDF("a_d", "b_d") - val json = new JSONFile(schema, Seq(batch)) - collectAndValidate(df, json, "floating_point-double_precision.json") } test("index conversion") { val data = List[Int](1, 2, 3, 4, 5, 6) - - val fields = Seq(new IntegerType("i", is_signed = true, 32, nullable = false)) - val schema = new JSONSchema(fields) - val columns = Seq(new PrimitiveColumn("i", data.length, data.map(_ => true), data)) - val batch = new JSONRecordBatch(data.length, columns) - val json = new JSONFile(schema, Seq(batch)) - + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin val df = data.toDF("i") collectAndValidate(df, json, "indexData-ints.json") } test("mixed numeric type conversion") { - val data = List(1, 2, 3, 4, 5, 6) - - val fields = Seq(new IntegerType("a", is_signed = true, 16, nullable = false), - new FloatingPointType("b", 32, nullable = false), - new IntegerType("c", is_signed = true, 32, nullable = false), - new FloatingPointType("d", 64, nullable = false), - new IntegerType("e", is_signed = true, 64, nullable = false)) - val schema = new JSONSchema(fields) - val columns = Seq(new PrimitiveColumn("a", data.length, data.map(_ => true), data), - new PrimitiveColumn("b", data.length, data.map(_ => true), data.map(_.toFloat)), - new PrimitiveColumn("c", data.length, data.map(_ => true), data), - new PrimitiveColumn("d", data.length, data.map(_ => true), data.map(_.toDouble)), - new PrimitiveColumn("e", data.length, data.map(_ => true), data) - ) - val batch = new JSONRecordBatch(data.length, columns) - val json = new JSONFile(schema, Seq(batch)) + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { 
+ | "name" : "c", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "e", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "b", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "c", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "e", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin + val data = List(1, 2, 3, 4, 5, 6) val data_tuples = for (d <- data) yield { (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong) } @@ -193,23 +570,97 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("string type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "upper_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "lower_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "null_str", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "upper_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "A", "B", "C" ] + | }, { + | "name" : "lower_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "a", "b", "c" ] + | }, { + | "name" : "null_str", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 0 ], + | "OFFSET" : [ 0, 2, 5, 5 ], + | "DATA" : [ "ab", "CDE", "" ] + | } ] + | } ] + |} + """.stripMargin + val upperCase = Seq("A", "B", "C") val lowerCase = 
Seq("a", "b", "c") val nullStr = Seq("ab", "CDE", null) - - val fields = Seq(new StringType("upper_case", nullable = true), - new StringType("lower_case", nullable = true), - new StringType("null_str", nullable = true)) - val schema = new JSONSchema(fields) - val columns = Seq( - new StringColumn("upper_case", upperCase.length, upperCase.map(_ => true), upperCase), - new StringColumn("lower_case", lowerCase.length, lowerCase.map(_ => true), lowerCase), - new StringColumn("null_str", nullStr.length, nullStr.map(_ != null), - nullStr.map { s => if (s == null) "" else s} - )) - val batch = new JSONRecordBatch(upperCase.length, columns) - val json = new JSONFile(schema, Seq(batch)) - val df = (upperCase, lowerCase, nullStr).zipped.toList .toDF("upper_case", "lower_case", "null_str") @@ -217,44 +668,124 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("boolean type conversion") { - val data = Seq(true, true, false, true) - - val fields = Seq(new BooleanType("a_bool", nullable = false)) - val schema = new JSONSchema(fields) - val columns = Seq(new PrimitiveColumn("a_bool", data.length, data.map(_ => true), data)) - val batch = new JSONRecordBatch(data.length, columns) - val json = new JSONFile(schema, Seq(batch)) - - val df = data.toDF("a_bool") - + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_bool", + | "type" : { + | "name" : "bool" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 1 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_bool", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ true, true, false, true ] + | } ] + | } ] + |} + """.stripMargin + val df = Seq(true, true, false, true).toDF("a_bool") collectAndValidate(df, json, "boolData.json") } test("byte type conversion") { - val data = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue) - - val fields = Seq(new IntegerType("a_byte", is_signed = true, 8, nullable = false)) - val schema = new JSONSchema(fields) - val columns = Seq( - new PrimitiveColumn("a_byte", data.length, data.map(_ => true), data.map(_.toInt))) - val batch = new JSONRecordBatch(data.length, columns) - val json = new JSONFile(schema, Seq(batch)) - - val df = data.toDF("a_byte") - + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_byte", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 8 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_byte", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 64, 127 ] + | } ] + | } ] + |} + | + """.stripMargin + val df = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte") collectAndValidate(df, json, "byteData.json") } test("binary type conversion") { - val data = Seq("abc", "d", "ef") - - val fields = Seq(new BinaryType("a_binary", nullable = true)) - val schema = new JSONSchema(fields) - val columns = Seq( - new BinaryColumn("a_binary", data.length, data.map(_ => true), data)) - val batch = new JSONRecordBatch(data.length, columns) - val json = new JSONFile(schema, Seq(batch)) + val json = + s""" + |{ + | 
"schema" : { + | "fields" : [ { + | "name" : "a_binary", + | "type" : { + | "name" : "binary" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a_binary", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 3, 4, 6 ], + | "DATA" : [ "616263", "64", "6566" ] + | } ] + | } ] + |} + """.stripMargin + val data = Seq("abc", "d", "ef") val rdd = sparkContext.parallelize(data.map(s => Row(s.getBytes("utf-8")))) val df = spark.createDataFrame(rdd, StructType(Seq(StructField("a_binary", BinaryType)))) @@ -262,43 +793,198 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("floating-point NaN") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "NaN_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "NaN_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 2, + | "columns" : [ { + | "name" : "NaN_f", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ 1.2000000476837158, "NaN" ] + | }, { + | "name" : "NaN_d", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ "NaN", 1.2 ] + | } ] + | } ] + |} + """.stripMargin + val fnan = Seq(1.2F, Float.NaN) val dnan = Seq(Double.NaN, 1.2) - - val fields = Seq(new FloatingPointType("NaN_f", 32, nullable = false), - new FloatingPointType("NaN_d", 64, nullable = false)) - val schema = new JSONSchema(fields) - val columns = Seq(new PrimitiveColumn("NaN_f", fnan.length, fnan.map(_ => true), fnan), - new PrimitiveColumn("NaN_d", dnan.length, dnan.map(_ => true), dnan)) - val batch = new JSONRecordBatch(fnan.length, columns) - val json = new JSONFile(schema, Seq(batch)) - val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") collectAndValidate(df, json, "nanData-floating_point.json") } test("partitioned DataFrame") { + val json1 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 1, 2 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 1 ] + | } ] + | } ] + |} + 
""".stripMargin + val json2 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 3, 3 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 1, 2 ] + | } ] + | } ] + |} + """.stripMargin + val arrowPayloads = testData2.toArrowPayload.collect() // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload assert(arrowPayloads.length === 2) - // Generate JSON files - val a = List[Int](1, 1, 2, 2, 3, 3) - val b = List[Int](1, 2, 1, 2, 1, 2) - val fields = Seq(new IntegerType("a", is_signed = true, 32, nullable = false), - new IntegerType("b", is_signed = true, 32, nullable = false)) - def getBatch(x: Seq[Int], y: Seq[Int]): JSONRecordBatch = { - val columns = Seq(new PrimitiveColumn("a", x.length, x.map(_ => true), x), - new PrimitiveColumn("b", y.length, y.map(_ => true), y)) - new JSONRecordBatch(x.length, columns) - } - val json1 = new JSONFile(new JSONSchema(fields), Seq(getBatch(a.take(3), b.take(3)))) - val json2 = new JSONFile(new JSONSchema(fields), Seq(getBatch(a.takeRight(3), b.takeRight(3)))) + val schema = testData2.schema + val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") - json1.write(tempFile1) - json2.write(tempFile2) - val schema = testData2.schema + Files.write(json1, tempFile1, StandardCharsets.UTF_8) + Files.write(json2, tempFile2, StandardCharsets.UTF_8) + validateConversion(schema, arrowPayloads(0), tempFile1) validateConversion(schema, arrowPayloads(1), tempFile2) } @@ -364,25 +1050,130 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("test Arrow Validator") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + val json_diff_col_order = + s""" + |{ + | 
"schema" : { + | "fields" : [ { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) - - val fields = Seq(new IntegerType("a_i", is_signed = true, 32, nullable = false), - new IntegerType("b_i", is_signed = true, 32, nullable = true)) - val schema = new JSONSchema(fields) - val columns = Seq( - new PrimitiveColumn("a_i", a_i.length, a_i.map(_ => true), a_i), - new PrimitiveColumn("b_i", b_i.length, b_i.map(_.isDefined), b_i.map(_.getOrElse(0)))) - val batch = new JSONRecordBatch(a_i.length, columns) - val json = new JSONFile(schema, Seq(batch)) - val df = a_i.zip(b_i).toDF("a_i", "b_i") // Different schema - val schema_other = new JSONSchema(fields.reverse) - val json_other = new JSONFile(schema_other, Seq(batch)) intercept[IllegalArgumentException] { - collectAndValidate(df, json_other, "validator_diff_schema.json") + collectAndValidate(df, json_diff_col_order, "validator_diff_schema.json") } // Different values @@ -392,11 +1183,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(df: DataFrame, json: JSONFile, file: String): Unit = { + private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator val arrowPayload = df.coalesce(1).toArrowPayload.collect().head val tempFile = new File(tempDataPath, file) - json.write(tempFile) + Files.write(json, tempFile, StandardCharsets.UTF_8) validateConversion(df.schema, arrowPayload, tempFile) } @@ -424,195 +1215,4 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { arrowRoot.close() allocator.close() } - - - /** - * Arrow JSON Format Data Generation - * Referenced from https://github.com/apache/arrow/blob/master/integration/integration_test.py - * TODO: Look into using JSON generation from parquet-vector.jar - */ - - private abstract class DataType(name: String, nullable: Boolean) { - def _get_type: JObject - def _get_type_layout: JField - def _get_children: JArray - def get_json: JObject = { - JObject( - "name" -> name, - "type" -> _get_type, - "nullable" -> nullable, - "children" -> _get_children, - "typeLayout" -> _get_type_layout) - } - } - - private abstract class Column(name: String, count: Int) { - def _get_children: JArray - def _get_buffers: JObject - def get_json: JObject = { - val 
entries = JObject( - "name" -> name, - "count" -> count - ).merge(_get_buffers) - - val children = _get_children - if (children.arr.nonEmpty) entries.merge(JObject("children" -> children)) else entries - } - } - - private abstract class PrimitiveType(name: String, nullable: Boolean) - extends DataType(name, nullable) { - val bit_width: Int - override def _get_children: JArray = JArray(List.empty) - override def _get_type_layout: JField = { - JField("vectors", JArray(List( - JObject("type" -> "VALIDITY", "typeBitWidth" -> 1), - JObject("type" -> "DATA", "typeBitWidth" -> bit_width) - ))) - } - } - - private class PrimitiveColumn[T <% JValue]( - name: String, - count: Int, - is_valid: Seq[Boolean], - values: Seq[T]) - extends Column(name, count) { - override def _get_children: JArray = JArray(List.empty) - override def _get_buffers: JObject = { - JObject( - "VALIDITY" -> is_valid.map(b => if (b) 1 else 0), - "DATA" -> values) - } - } - - private class IntegerType( - name: String, - is_signed: Boolean, - override val bit_width: Int, - nullable: Boolean) - extends PrimitiveType(name, nullable = nullable) { - override def _get_type: JObject = { - JObject( - "name" -> "int", - "isSigned" -> is_signed, - "bitWidth" -> bit_width) - } - } - - private class FloatingPointType(name: String, override val bit_width: Int, nullable: Boolean) - extends PrimitiveType(name, nullable = nullable) { - override def _get_type: JObject = { - val precision = bit_width match { - case 16 => "HALF" - case 32 => "SINGLE" - case 64 => "DOUBLE" - } - JObject( - "name" -> "floatingpoint", - "precision" -> precision) - } - } - - private class BooleanType(name: String, nullable: Boolean) - extends PrimitiveType(name, nullable = nullable) { - override val bit_width = 1 - override def _get_type: JObject = JObject("name" -> JString("bool")) - } - - private class BinaryType(name: String, nullable: Boolean) - extends PrimitiveType(name, nullable = nullable) { - override val bit_width = 8 - override def _get_type: JObject = JObject("name" -> JString("binary")) - override def _get_type_layout: JField = { - JField("vectors", JArray(List( - JObject("type" -> "VALIDITY", "typeBitWidth" -> 1), - JObject("type" -> "OFFSET", "typeBitWidth" -> 32), - JObject("type" -> "DATA", "typeBitWidth" -> bit_width) - ))) - } - } - - private class StringType(name: String, nullable: Boolean) - extends BinaryType(name, nullable = nullable) { - override def _get_type: JObject = JObject("name" -> JString("utf8")) - } - - private class DateType(name: String) extends PrimitiveType(name, nullable = true) { - override val bit_width = 32 - override def _get_type: JObject = { - JObject( - "name" -> "date", - "unit" -> "DAY") - } - } - - private class TimestampType(name: String) extends PrimitiveType(name, nullable = true) { - override val bit_width = 64 - override def _get_type: JObject = { - JObject( - "name" -> "timestamp", - "unit" -> "MICROSECOND") - } - } - - private class JSONSchema(fields: Seq[DataType]) { - def get_json: JObject = { - JObject("fields" -> JArray(fields.map(_.get_json).toList)) - } - } - - private class BinaryColumn( - name: String, - count: Int, - is_valid: Seq[Boolean], - values: Seq[String]) - extends PrimitiveColumn(name, count, is_valid, values) { - def _encode_value(v: String): String = { - v.map(c => String.format("%h", c.toString)).reduce(_ + _) - } - override def _get_buffers: JObject = { - var offset = 0 - val offsets = scala.collection.mutable.ArrayBuffer[Int](offset) - val data = values.zip(is_valid).map { case (value, 
isval) => - if (isval) offset += value.length - val element = _encode_value(if (isval) value else "") - offsets += offset - element - } - JObject( - "VALIDITY" -> is_valid.map(b => if (b) 1 else 0), - "OFFSET" -> offsets, - "DATA" -> data) - } - } - - private class StringColumn( - name: String, - count: Int, - is_valid: Seq[Boolean], - values: Seq[String]) - extends BinaryColumn(name, count, is_valid, values) { - override def _encode_value(v: String): String = v - } - - private class JSONRecordBatch(count: Int, columns: Seq[Column]) { - def get_json: JObject = { - JObject( - "count" -> count, - "columns" -> columns.map(_.get_json)) - } - } - - private class JSONFile(schema: JSONSchema, batches: Seq[JSONRecordBatch]) { - def get_json: JObject = { - JObject( - "schema" -> schema.get_json, - "batches" -> batches.map(_.get_json)) - } - def write(file: File): Unit = { - val json = pretty(render(get_json)) - Files.write(json, file, StandardCharsets.UTF_8) - } - } } From d49a14daea3a5e92c2cfdf579373ca13b96c20e5 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 19 May 2017 14:40:05 -0700 Subject: [PATCH 50/56] upgrade to use Arrow 0.4 removed exclusion for log4j-over-slf4j as it had been moved to test scope in Arrow --- dev/run-pip-tests | 2 +- pom.xml | 6 +----- python/pyspark/serializers.py | 2 +- python/pyspark/sql/tests.py | 2 +- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/dev/run-pip-tests b/dev/run-pip-tests index c92fa33a116b..225e9209536f 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -83,7 +83,7 @@ for python in "${PYTHON_EXECS[@]}"; do if [ -n "$USE_CONDA" ]; then conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools source activate "$VIRTUALENV_PATH" - conda install -y -c conda-forge pyarrow=0.3 + conda install -y -c conda-forge pyarrow=0.4.0 TEST_PYARROW=1 else mkdir -p "$VIRTUALENV_PATH" diff --git a/pom.xml b/pom.xml index dc967e224f98..8811ffcae0c9 100644 --- a/pom.xml +++ b/pom.xml @@ -184,7 +184,7 @@ 2.6 1.8 1.0.0 - 0.3.0 + 0.4.0 ${java.home} @@ -1885,10 +1885,6 @@ com.fasterxml.jackson.core jackson-databind - - org.slf4j - log4j-over-slf4j - io.netty netty-handler diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 07d043f88811..d5c2a7518b18 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -192,7 +192,7 @@ def dumps(self, obj): def loads(self, obj): import pyarrow as pa - reader = pa.FileReader(pa.BufferReader(obj)) + reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) return reader.read_all() def __repr__(self): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cc6dd49f9b13..12706b7a4783 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2522,7 +2522,7 @@ def test_null_conversion(self): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - # NOTE - toPandas(useArrow=False) will infer standard python data types + # NOTE - toPandas() without pyarrow will infer standard python data types df_sel = df.select("1_str_t", "3_long_t", "5_double_t") self.spark.conf.set("spark.sql.execution.arrow.enable", "false") pdf = df_sel.toPandas() From a630bf0d867c31be10660f25ae0d9b185dfa00e2 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 26 May 2017 10:21:14 -0700 Subject: [PATCH 51/56] forgot to update arrow version in dependency manifests --- dev/deps/spark-deps-hadoop-2.6 | 6 +++--- dev/deps/spark-deps-hadoop-2.7 | 6 +++--- 2 files changed, 6 insertions(+), 6 
deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index bb0cb28f40bf..9868c1ab7c2a 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -13,9 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.3.0.jar -arrow-memory-0.3.0.jar -arrow-vector-0.3.0.jar +arrow-format-0.4.0.jar +arrow-memory-0.4.0.jar +arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 4e40b4dbfb10..d6001ee71ee7 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -13,9 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.3.0.jar -arrow-memory-0.3.0.jar -arrow-vector-0.3.0.jar +arrow-format-0.4.0.jar +arrow-memory-0.4.0.jar +arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar From 748e6fbc9c257e4af4ba920c3af2eebe97ed9c7d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 15 Jun 2017 15:27:56 -0700 Subject: [PATCH 52/56] Changed UTF8StringColumnWriter to use VarCharVector --- .../spark/sql/execution/arrow/ArrowConverters.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 17c2a7b725c1..c38d78327512 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -347,14 +347,14 @@ private[arrow] class UTF8StringColumnWriter( ordinal: Int, allocator: BufferAllocator) extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarBinaryVector - = new NullableVarBinaryVector("UTF8StringValue", getFieldType(dtype), allocator) - override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator + override val valueVector: NullableVarCharVector + = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator) + override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow): Unit = { - val bytes = row.getUTF8String(ordinal).getBytes - valueMutator.setSafe(count, bytes, 0, bytes.length) + val str = row.getUTF8String(ordinal) + valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes) } } From b361bdcef27bd3d5a864c4fd6f60e0b6a2cf5d4e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 19 Jun 2017 11:36:37 -0700 Subject: [PATCH 53/56] Added check for DataFrame that is filtered out completely and converted to Pandas --- python/pyspark/sql/dataframe.py | 9 ++++++--- python/pyspark/sql/tests.py | 7 +++++++ .../spark/sql/execution/arrow/ArrowConvertersSuite.scala | 4 ++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 57e538d6a669..9b781c0a2e83 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1650,18 +1650,21 @@ def toPandas(self): 0 2 Alice 1 5 Bob """ + import pandas as pd if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": try: import pyarrow tables = 
self._collectAsArrow() - table = pyarrow.concat_tables(tables) - return table.to_pandas() + if tables: + table = pyarrow.concat_tables(tables) + return table.to_pandas() + else: + return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: msg = "note: pyarrow must be installed and available on calling Python process " \ "if using spark.sql.execution.arrow.enable=true" raise ImportError("%s\n%s" % (e.message, msg)) else: - import pandas as pd return pd.DataFrame.from_records(self.collect(), columns=self.columns) def _collectAsArrow(self): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 12706b7a4783..0aab151fc43e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2544,6 +2544,13 @@ def test_pandas_round_trip(self): pdf_arrow = df.toPandas() self.assertFramesEqual(pdf_arrow, pdf) + def test_filtered_frame(self): + df = self.spark.range(3).toDF("i") + pdf = df.filter("i < 0").toPandas() + self.assertEqual(len(pdf.columns), 1) + self.assertEqual(pdf.columns[0], "i") + self.assertTrue(pdf.empty) + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 8400de93a01d..159328cc0d95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -992,6 +992,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { test("empty frame collect") { val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() assert(arrowPayload.isEmpty) + + val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") + val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() + assert(filteredArrowPayload.isEmpty) } test("empty partition collect") { From 8bff966b637ee35a8c1cb051c7eb700f017e4d71 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 19 Jun 2017 14:20:41 -0700 Subject: [PATCH 54/56] Moved all work out of ArrowPayload construction to companion object --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../sql/execution/arrow/ArrowConverters.scala | 30 +++++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4f409b21c780..72b7fb13fe3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2778,7 +2778,7 @@ class Dataset[T] private[sql]( */ private[sql] def collectAsArrowToPython(): Int = { withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.toByteArray) + val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) PythonRDD.serveIterator(iter, "serve-Arrow") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index c38d78327512..6af5c7342237 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -38,16 +38,9 @@ import org.apache.spark.util.Utils /** - * Store Arrow data in a form that can be serialized by Spark. 
+ * Store Arrow data in a form that can be serialized by Spark and served to a Python process. */ -private[sql] class ArrowPayload(payload: Array[Byte]) extends Serializable { - - /** - * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. - */ - def this(batch: ArrowRecordBatch, schema: StructType, allocator: BufferAllocator) = { - this(ArrowConverters.batchToByteArray(batch, schema, allocator)) - } +private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable { /** * Convert the ArrowPayload to an ArrowRecordBatch. @@ -57,9 +50,22 @@ private[sql] class ArrowPayload(payload: Array[Byte]) extends Serializable { } /** - * Get the ArrowPayload as an Array[Byte]. + * Get the ArrowPayload as a type that can be served to Python. + */ + def asPythonSerializable: Array[Byte] = payload +} + +private[sql] object ArrowPayload { + + /** + * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. */ - def toByteArray: Array[Byte] = payload + def apply( + batch: ArrowRecordBatch, + schema: StructType, + allocator: BufferAllocator): ArrowPayload = { + new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator)) + } } private[sql] object ArrowConverters { @@ -121,7 +127,7 @@ private[sql] object ArrowConverters { private def convert(): ArrowPayload = { val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) - new ArrowPayload(batch, schema, _allocator) + ArrowPayload(batch, schema, _allocator) } } } From f96f555e1a3b8aabc7949d4b355f3af3b0e78b5a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 20 Jun 2017 16:49:09 -0700 Subject: [PATCH 55/56] Renamed variable to schemaCaptured --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 72b7fb13fe3c..a0cbba04e9b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2867,10 +2867,10 @@ class Dataset[T] private[sql]( /** Convert to an RDD of ArrowPayload byte arrays */ private[sql] def toArrowPayload: RDD[ArrowPayload] = { - val schema_captured = this.schema + val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch queryExecution.toRdd.mapPartitionsInternal { iter => - ArrowConverters.toPayloadIterator(iter, schema_captured, maxRecordsPerBatch) + ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch) } } } From 44d7a2a3fedb4f4bec167d763d0df3d6448bbe49 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 22 Jun 2017 11:22:51 -0700 Subject: [PATCH 56/56] cleanup up test now that toPandas without Arrow will have correct dtypes --- python/pyspark/sql/tests.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c16a0d5a85f6..326e8548a617 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2669,12 +2669,10 @@ def test_null_conversion(self): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - # NOTE - toPandas() without pyarrow will infer standard python data types - df_sel = df.select("1_str_t", "3_long_t", "5_double_t") self.spark.conf.set("spark.sql.execution.arrow.enable", "false") - pdf = df_sel.toPandas() + pdf = df.toPandas() 
self.spark.conf.set("spark.sql.execution.arrow.enable", "true") - pdf_arrow = df_sel.toPandas() + pdf_arrow = df.toPandas() self.assertFramesEqual(pdf_arrow, pdf) def test_pandas_round_trip(self):
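The last two patches assume that the same toPandas() call switches between the Arrow and row-based paths based solely on the spark.sql.execution.arrow.enable setting. A minimal sketch of that usage pattern, assuming an existing SparkSession `spark` with pyarrow installed on the driver (the DataFrame and column names below are illustrative, not taken from the patch):

# Sketch only: mirrors what test_toPandas_arrow_toggle exercises; `spark`,
# `df`, and the column names are assumptions for illustration.
df = spark.createDataFrame([("a", 1, 2.0), ("b", 2, 4.0)],
                           ["str_t", "long_t", "double_t"])

spark.conf.set("spark.sql.execution.arrow.enable", "false")
pdf = df.toPandas()          # row-based collect(), pandas infers dtypes

spark.conf.set("spark.sql.execution.arrow.enable", "true")
pdf_arrow = df.toPandas()    # Arrow payload deserialized by pyarrow, then to_pandas()

assert pdf_arrow.equals(pdf)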