From 1439fe6d320208ccb565d18fc4b7485210068330 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 13 Jul 2017 17:45:33 +0900 Subject: [PATCH 01/15] Introduce ArrowWriter and ArrowColumnVector. --- .../vectorized/ArrowColumnVector.java | 502 ++++++++++++++++++ .../execution/vectorized/ColumnVector.java | 16 +- .../vectorized/OnHeapColumnVector.java | 2 +- .../sql/execution/vectorized/ArrowUtils.scala | 107 ++++ .../execution/vectorized/ArrowWriter.scala | 405 ++++++++++++++ .../vectorized/ArrowUtilsSuite.scala | 65 +++ .../vectorized/ArrowWriterSuite.scala | 250 +++++++++ 7 files changed, 1338 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowUtils.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowWriter.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowUtilsSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowWriterSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java new file mode 100644 index 000000000000..abc8a8d23d44 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -0,0 +1,502 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized; + +import org.apache.arrow.vector.*; +import org.apache.arrow.vector.complex.*; +import org.apache.arrow.vector.holders.NullableVarCharHolder; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column backed by Apache Arrow. 
+ */ +public final class ArrowColumnVector extends ColumnVector { + + private ValueVector vector; + private ValueVector.Accessor nulls; + + private NullableBitVector boolData; + private NullableTinyIntVector byteData; + private NullableSmallIntVector shortData; + private NullableIntVector intData; + private NullableBigIntVector longData; + + private NullableFloat4Vector floatData; + private NullableFloat8Vector doubleData; + private NullableDecimalVector decimalData; + + private NullableVarCharVector stringData; + + private NullableVarBinaryVector binaryData; + + private UInt4Vector listOffsetData; + + public ArrowColumnVector(ValueVector vector) { + super(vector.getValueCapacity(), DataTypes.NullType, MemoryMode.OFF_HEAP); + initialize(vector); + } + + @Override + public long nullsNativeAddress() { + throw new RuntimeException("Cannot get native address for arrow column"); + } + + @Override + public long valuesNativeAddress() { + throw new RuntimeException("Cannot get native address for arrow column"); + } + + @Override + public void close() { + if (childColumns != null) { + for (int i = 0; i < childColumns.length; i++) { + childColumns[i].close(); + } + } + vector.close(); + } + + // + // APIs dealing with nulls + // + + @Override + public void putNotNull(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public void putNull(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public void putNulls(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + @Override + public void putNotNulls(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNullAt(int rowId) { + return nulls.isNull(rowId); + } + + // + // APIs dealing with Booleans + // + + @Override + public void putBoolean(int rowId, boolean value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putBooleans(int rowId, int count, boolean value) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getBoolean(int rowId) { + return boolData.getAccessor().get(rowId) == 1; + } + + @Override + public boolean[] getBooleans(int rowId, int count) { + assert(dictionary == null); + boolean[] array = new boolean[count]; + for (int i = 0; i < count; ++i) { + array[i] = (boolData.getAccessor().get(rowId + i) == 1); + } + return array; + } + + // + // APIs dealing with Bytes + // + + @Override + public void putByte(int rowId, byte value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putBytes(int rowId, int count, byte value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putBytes(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + return byteData.getAccessor().get(rowId); + } + + @Override + public byte[] getBytes(int rowId, int count) { + assert(dictionary == null); + byte[] array = new byte[count]; + for (int i = 0; i < count; ++i) { + array[i] = byteData.getAccessor().get(rowId + i); + } + return array; + } + + // + // APIs dealing with Shorts + // + + @Override + public void putShort(int rowId, short value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putShorts(int rowId, int count, short value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putShorts(int rowId, int count, short[] src, int srcIndex) { + throw new 
UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + return shortData.getAccessor().get(rowId); + } + + @Override + public short[] getShorts(int rowId, int count) { + assert(dictionary == null); + short[] array = new short[count]; + for (int i = 0; i < count; ++i) { + array[i] = shortData.getAccessor().get(rowId + i); + } + return array; + } + + // + // APIs dealing with Ints + // + + @Override + public void putInt(int rowId, int value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putInts(int rowId, int count, int value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putInts(int rowId, int count, int[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + return intData.getAccessor().get(rowId); + } + + @Override + public int[] getInts(int rowId, int count) { + assert(dictionary == null); + int[] array = new int[count]; + for (int i = 0; i < count; ++i) { + array[i] = intData.getAccessor().get(rowId + i); + } + return array; + } + + @Override + public int getDictId(int rowId) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Longs + // + + @Override + public void putLong(int rowId, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putLongs(int rowId, int count, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putLongs(int rowId, int count, long[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + return longData.getAccessor().get(rowId); + } + + @Override + public long[] getLongs(int rowId, int count) { + assert(dictionary == null); + long[] array = new long[count]; + for (int i = 0; i < count; ++i) { + array[i] = longData.getAccessor().get(rowId + i); + } + return array; + } + + // + // APIs dealing with floats + // + + @Override + public void putFloat(int rowId, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putFloats(int rowId, int count, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putFloats(int rowId, int count, float[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public void putFloats(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + return floatData.getAccessor().get(rowId); + } + + @Override + public float[] getFloats(int rowId, int count) { + assert(dictionary == null); + float[] array = new float[count]; + for (int i = 0; i < count; ++i) { + array[i] = floatData.getAccessor().get(rowId + i); + } + return array; + } + + // + // APIs dealing with doubles + // + + @Override + public void putDouble(int rowId, double value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putDoubles(int rowId, int count, double value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putDoubles(int rowId, int count, double[] src, int srcIndex) { + throw new 
UnsupportedOperationException(); + } + + @Override + public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + return doubleData.getAccessor().get(rowId); + } + + @Override + public double[] getDoubles(int rowId, int count) { + assert(dictionary == null); + double[] array = new double[count]; + for (int i = 0; i < count; ++i) { + array[i] = doubleData.getAccessor().get(rowId + i); + } + return array; + } + + // + // APIs dealing with Arrays + // + + @Override + public int getArrayLength(int rowId) { + return listOffsetData.getAccessor().get(rowId + 1) - listOffsetData.getAccessor().get(rowId); + } + + @Override + public int getArrayOffset(int rowId) { + return listOffsetData.getAccessor().get(rowId); + } + + @Override + public void putArray(int rowId, int offset, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public void loadBytes(Array array) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Byte Arrays + // + + @Override + public int putByteArray(int rowId, byte[] value, int offset, int count) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Decimals + // + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + return Decimal.apply(decimalData.getAccessor().getObject(rowId), precision, scale); + } + + @Override + public final void putDecimal(int rowId, Decimal value, int precision) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with UTF8Strings + // + + private NullableVarCharHolder stringResult = new NullableVarCharHolder(); + + @Override + public UTF8String getUTF8String(int rowId) { + stringData.getAccessor().get(rowId, stringResult); + if (stringResult.isSet == 0) { + return null; + } else { + return UTF8String.fromAddress(null, + stringResult.buffer.memoryAddress() + stringResult.start, + stringResult.end - stringResult.start); + } + } + + // + // APIs dealing with Binaries + // + + @Override + public byte[] getBinary(int rowId) { + return binaryData.getAccessor().getObject(rowId); + } + + @Override + protected void reserveInternal(int newCapacity) { + while (vector.getValueCapacity() <= newCapacity) { + vector.reAlloc(); + } + capacity = vector.getValueCapacity(); + } + + private void initialize(ValueVector vector) { + this.vector = vector; + this.type = ArrowUtils.fromArrowField(vector.getField()); + if (vector instanceof NullableBitVector) { + boolData = (NullableBitVector) vector; + nulls = boolData.getAccessor(); + } else if (vector instanceof NullableTinyIntVector) { + byteData = (NullableTinyIntVector) vector; + nulls = byteData.getAccessor(); + } else if (vector instanceof NullableSmallIntVector) { + shortData = (NullableSmallIntVector) vector; + nulls = shortData.getAccessor(); + } else if (vector instanceof NullableIntVector) { + intData = (NullableIntVector) vector; + nulls = intData.getAccessor(); + } else if (vector instanceof NullableBigIntVector) { + longData = (NullableBigIntVector) vector; + nulls = longData.getAccessor(); + } else if (vector instanceof NullableFloat4Vector) { + floatData = (NullableFloat4Vector) vector; + nulls = floatData.getAccessor(); + } else if (vector instanceof NullableFloat8Vector) { + doubleData = (NullableFloat8Vector) vector; + nulls = doubleData.getAccessor(); + } else if (vector instanceof NullableDecimalVector) { + decimalData = 
(NullableDecimalVector) vector; + nulls = decimalData.getAccessor(); + } else if (vector instanceof NullableVarCharVector) { + stringData = (NullableVarCharVector) vector; + nulls = stringData.getAccessor(); + } else if (vector instanceof NullableVarBinaryVector) { + binaryData = (NullableVarBinaryVector) vector; + nulls = binaryData.getAccessor(); + } else if (vector instanceof ListVector) { + ListVector listVector = (ListVector) vector; + listOffsetData = listVector.getOffsetVector(); + nulls = listVector.getAccessor(); + + childColumns = new ColumnVector[1]; + childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); + resultArray = new Array(childColumns[0]); + } else if (vector instanceof MapVector) { + MapVector mapVector = (MapVector) vector; + nulls = mapVector.getAccessor(); + + childColumns = new ArrowColumnVector[mapVector.size()]; + for (int i = 0; i < childColumns.length; ++i) { + childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); + } + resultStruct = new ColumnarBatch.Row(childColumns); + } + numNulls = nulls.getNullCount(); + anyNullsSet = numNulls > 0; + isConstant = true; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 0c027f80d48c..77966382881b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -646,7 +646,7 @@ public MapData getMap(int ordinal) { /** * Returns the decimal for rowId. */ - public final Decimal getDecimal(int rowId, int precision, int scale) { + public Decimal getDecimal(int rowId, int precision, int scale) { if (precision <= Decimal.MAX_INT_DIGITS()) { return Decimal.createUnsafe(getInt(rowId), precision, scale); } else if (precision <= Decimal.MAX_LONG_DIGITS()) { @@ -661,7 +661,7 @@ public final Decimal getDecimal(int rowId, int precision, int scale) { } - public final void putDecimal(int rowId, Decimal value, int precision) { + public void putDecimal(int rowId, Decimal value, int precision) { if (precision <= Decimal.MAX_INT_DIGITS()) { putInt(rowId, (int) value.toUnscaledLong()); } else if (precision <= Decimal.MAX_LONG_DIGITS()) { @@ -675,7 +675,7 @@ public final void putDecimal(int rowId, Decimal value, int precision) { /** * Returns the UTF8String for rowId. */ - public final UTF8String getUTF8String(int rowId) { + public UTF8String getUTF8String(int rowId) { if (dictionary == null) { ColumnVector.Array a = getByteArray(rowId); return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); @@ -688,7 +688,7 @@ public final UTF8String getUTF8String(int rowId) { /** * Returns the byte array for rowId. */ - public final byte[] getBinary(int rowId) { + public byte[] getBinary(int rowId) { if (dictionary == null) { ColumnVector.Array array = getByteArray(rowId); byte[] bytes = new byte[array.length]; @@ -956,7 +956,7 @@ public final int appendStruct(boolean isNull) { /** * Data type for this column. */ - protected final DataType type; + protected DataType type; /** * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. @@ -988,17 +988,17 @@ public final int appendStruct(boolean isNull) { /** * If this is a nested type (array or struct), the column for the child data. */ - protected final ColumnVector[] childColumns; + protected ColumnVector[] childColumns; /** * Reusable Array holder for getArray(). 
*/ - protected final Array resultArray; + protected Array resultArray; /** * Reusable Struct holder for getStruct(). */ - protected final ColumnarBatch.Row resultStruct; + protected ColumnarBatch.Row resultStruct; /** * The Dictionary for this column. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 506434364be4..0775d381ef3b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -137,7 +137,7 @@ public boolean[] getBooleans(int rowId, int count) { for (int i = 0; i < count; ++i) { array[i] = (byteData[rowId + i] == 1); } - return array; + return array; } // diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowUtils.scala new file mode 100644 index 000000000000..9ded74b496a5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowUtils.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized + +import scala.collection.JavaConverters._ + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} + +import org.apache.spark.sql.types._ + +object ArrowUtils { + + val rootAllocator = new RootAllocator(Long.MaxValue) + + // todo: support more types. 
+ + def toArrowType(dt: DataType): ArrowType = dt match { + case BooleanType => ArrowType.Bool.INSTANCE + case ByteType => new ArrowType.Int(8, true) + case ShortType => new ArrowType.Int(8 * 2, true) + case IntegerType => new ArrowType.Int(8 * 4, true) + case LongType => new ArrowType.Int(8 * 8, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case StringType => ArrowType.Utf8.INSTANCE + case BinaryType => ArrowType.Binary.INSTANCE + case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") + } + + def fromArrowType(dt: ArrowType): DataType = dt match { + case ArrowType.Bool.INSTANCE => BooleanType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType + case ArrowType.Utf8.INSTANCE => StringType + case ArrowType.Binary.INSTANCE => BinaryType + case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") + } + + def toArrowField(name: String, dt: DataType, nullable: Boolean): Field = { + dt match { + case ArrayType(elementType, containsNull) => + new Field(name, nullable, ArrowType.List.INSTANCE, + Seq(toArrowField("element", elementType, containsNull)).asJava) + case StructType(fields) => + new Field(name, nullable, ArrowType.Struct.INSTANCE, + fields.map { field => + toArrowField(field.name, field.dataType, field.nullable) + }.toSeq.asJava) + case dataType => + new Field(name, nullable, toArrowType(dataType), Seq.empty[Field].asJava) + } + } + + def fromArrowField(field: Field): DataType = { + field.getType match { + case ArrowType.List.INSTANCE => + val elementField = field.getChildren().get(0) + val elementType = fromArrowField(elementField) + ArrayType(elementType, containsNull = elementField.isNullable) + case ArrowType.Struct.INSTANCE => + val fields = field.getChildren().asScala.map { child => + val dt = fromArrowField(child) + StructField(child.getName, dt, child.isNullable) + } + StructType(fields) + case arrowType => fromArrowType(arrowType) + } + } + + def toArrowSchema(schema: StructType): Schema = { + new Schema(schema.map { field => + toArrowField(field.name, field.dataType, field.nullable) + }.asJava) + } + + def fromArrowSchema(schema: Schema): StructType = { + StructType(schema.getFields.asScala.map { field => + val dt = fromArrowField(field) + StructField(field.getName, dt, field.isNullable) + }) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowWriter.scala new file mode 100644 index 000000000000..0547e9585664 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowWriter.scala @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex._ +import org.apache.arrow.vector.util.DecimalUtility + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.types._ + +object ArrowWriter { + + def create(schema: StructType): ArrowWriter = { + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + create(root) + } + + def create(root: VectorSchemaRoot): ArrowWriter = { + val children = root.getFieldVectors().asScala.map { vector => + vector.allocateNew() + createFieldWriter(vector) + } + new ArrowWriter(root, children.toArray) + } + + private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { + val field = vector.getField() + ArrowUtils.fromArrowField(field) match { + case BooleanType => + new BooleanWriter(vector.asInstanceOf[NullableBitVector]) + case ByteType => + new ByteWriter(vector.asInstanceOf[NullableTinyIntVector]) + case ShortType => + new ShortWriter(vector.asInstanceOf[NullableSmallIntVector]) + case IntegerType => + new IntegerWriter(vector.asInstanceOf[NullableIntVector]) + case LongType => + new LongWriter(vector.asInstanceOf[NullableBigIntVector]) + case FloatType => + new FloatWriter(vector.asInstanceOf[NullableFloat4Vector]) + case DoubleType => + new DoubleWriter(vector.asInstanceOf[NullableFloat8Vector]) + case DecimalType.Fixed(precision, scale) => + new DecimalWriter(vector.asInstanceOf[NullableDecimalVector], precision, scale) + case StringType => + new StringWriter(vector.asInstanceOf[NullableVarCharVector]) + case BinaryType => + new BinaryWriter(vector.asInstanceOf[NullableVarBinaryVector]) + case ArrayType(_, _) => + val v = vector.asInstanceOf[ListVector] + val elementVector = createFieldWriter(v.getDataVector()) + new ArrayWriter(v, elementVector) + case StructType(_) => + val v = vector.asInstanceOf[NullableMapVector] + val children = (0 until v.size()).map { ordinal => + createFieldWriter(v.getChildByOrdinal(ordinal)) + } + new StructWriter(v, children.toArray) + } + } +} + +class ArrowWriter( + val root: VectorSchemaRoot, + fields: Array[ArrowFieldWriter]) { + + def schema: StructType = StructType(fields.map { f => + StructField(f.name, f.dataType, f.nullable) + }) + + private var count: Int = 0 + + def write(row: InternalRow): Unit = { + var i = 0 + while (i < fields.size) { + fields(i).write(row, i) + i += 1 + } + count += 1 + } + + def finish(): Unit = { + root.setRowCount(count) + fields.foreach(_.finish()) + } + + def reset(): Unit = { + root.setRowCount(0) + count = 0 + fields.foreach(_.reset()) + } +} + +private[sql] abstract class 
ArrowFieldWriter { + + def valueVector: ValueVector + def valueMutator: ValueVector.Mutator + + def name: String = valueVector.getField().getName() + def dataType: DataType = ArrowUtils.fromArrowField(valueVector.getField()) + def nullable: Boolean = valueVector.getField().isNullable() + + def setNull(): Unit + def setValue(input: SpecializedGetters, ordinal: Int): Unit + def skip(): Unit + + protected var count: Int = 0 + + def write(input: SpecializedGetters, ordinal: Int): Unit = { + if (input.isNullAt(ordinal)) { + setNull() + } else { + setValue(input, ordinal) + } + count += 1 + } + + def writeSkip(): Unit = { + skip() + count += 1 + } + + def finish(): Unit = { + valueMutator.setValueCount(count) + } + + def reset(): Unit = { + valueMutator.reset() + count = 0 + } +} + +private[sql] class BooleanWriter(val valueVector: NullableBitVector) extends ArrowFieldWriter { + + override def valueMutator: NullableBitVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0) + } + + override def skip(): Unit = { + valueMutator.setIndexDefined(count) + } +} + +private[sql] class ByteWriter(val valueVector: NullableTinyIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableTinyIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getByte(ordinal)) + } + + override def skip(): Unit = { + valueMutator.setIndexDefined(count) + } +} + +private[sql] class ShortWriter(val valueVector: NullableSmallIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getShort(ordinal)) + } + + override def skip(): Unit = { + valueMutator.setIndexDefined(count) + } +} + +private[sql] class IntegerWriter(val valueVector: NullableIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getInt(ordinal)) + } + + override def skip(): Unit = { + valueMutator.setIndexDefined(count) + } +} + +private[sql] class LongWriter(val valueVector: NullableBigIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getLong(ordinal)) + } + + override def skip(): Unit = { + valueMutator.setIndexDefined(count) + } +} + +private[sql] class FloatWriter(val valueVector: NullableFloat4Vector) extends ArrowFieldWriter { + + override def valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getFloat(ordinal)) + } + 
+ override def skip(): Unit = { + valueMutator.setIndexDefined(count) + } +} + +private[sql] class DoubleWriter(val valueVector: NullableFloat8Vector) extends ArrowFieldWriter { + + override def valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getDouble(ordinal)) + } + + override def skip(): Unit = { + valueMutator.setIndexDefined(count) + } +} + +private[sql] class DecimalWriter( + val valueVector: NullableDecimalVector, + precision: Int, + scale: Int) extends ArrowFieldWriter { + + override def valueMutator: NullableDecimalVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setIndexDefined(count) + val decimal = input.getDecimal(ordinal, precision, scale) + decimal.changePrecision(precision, scale) + DecimalUtility.writeBigDecimalToArrowBuf(decimal.toJavaBigDecimal, valueVector.getBuffer, count) + } + + override def skip(): Unit = { + valueMutator.setIndexDefined(count) + } +} + +private[sql] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter { + + override def valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val utf8 = input.getUTF8String(ordinal) + // todo: for off-heap UTF8String, how to pass in to arrow without copy? + valueMutator.setSafe(count, utf8.getByteBuffer, 0, utf8.numBytes()) + } + + override def skip(): Unit = { + valueMutator.setIndexDefined(count) + } +} + +private[sql] class BinaryWriter(val valueVector: NullableVarBinaryVector) extends ArrowFieldWriter { + + override def valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val bytes = input.getBinary(ordinal) + valueMutator.setSafe(count, bytes, 0, bytes.length) + } + + override def skip(): Unit = { + valueMutator.setIndexDefined(count) + } +} + +private[sql] class ArrayWriter( + val valueVector: ListVector, + val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { + + override def valueMutator: ListVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val array = input.getArray(ordinal) + var i = 0 + valueMutator.startNewValue(count) + while (i < array.numElements()) { + elementWriter.write(array, i) + i += 1 + } + valueMutator.endValue(count, array.numElements()) + } + + override def skip(): Unit = { + valueMutator.setNotNull(count) + } + + override def finish(): Unit = { + super.finish() + elementWriter.finish() + } + + override def reset(): Unit = { + super.reset() + elementWriter.reset() + } +} + +private[sql] class StructWriter( + val valueVector: NullableMapVector, + children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { + + override def valueMutator: NullableMapVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + var i = 0 + while (i < children.length) { + children(i).writeSkip() + i += 1 + } + valueMutator.setNull(count) + } + + override def 
setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val struct = input.getStruct(ordinal, children.length) + var i = 0 + while (i < struct.numFields) { + children(i).write(struct, i) + i += 1 + } + valueMutator.setIndexDefined(count) + } + + override def skip(): Unit = { + valueMutator.setIndexDefined(count) + } + + override def finish(): Unit = { + super.finish() + children.foreach(_.finish()) + } + + override def reset(): Unit = { + super.reset() + children.foreach(_.reset()) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowUtilsSuite.scala new file mode 100644 index 000000000000..f318d887d57a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowUtilsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class ArrowUtilsSuite extends SparkFunSuite { + + def roundtrip(dt: DataType): Unit = { + dt match { + case schema: StructType => + assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema)) === schema) + case _ => + roundtrip(new StructType().add("value", dt)) + } + } + + test("simple") { + roundtrip(BooleanType) + roundtrip(ByteType) + roundtrip(ShortType) + roundtrip(IntegerType) + roundtrip(LongType) + roundtrip(FloatType) + roundtrip(DoubleType) + roundtrip(StringType) + roundtrip(BinaryType) + roundtrip(DecimalType.SYSTEM_DEFAULT) + } + + test("array") { + roundtrip(ArrayType(IntegerType, containsNull = true)) + roundtrip(ArrayType(IntegerType, containsNull = false)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = true)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = false)) + } + + test("struct") { + roundtrip(new StructType()) + roundtrip(new StructType().add("i", IntegerType)) + roundtrip(new StructType().add("arr", ArrayType(IntegerType))) + roundtrip(new StructType().add("i", IntegerType).add("arr", ArrayType(IntegerType))) + roundtrip(new StructType().add( + "struct", + new StructType().add("i", IntegerType).add("arr", ArrayType(IntegerType)))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowWriterSuite.scala new file mode 100644 index 000000000000..e1ca05a14180 --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowWriterSuite.scala @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ArrowWriterSuite extends SparkFunSuite { + + test("simple") { + def check(dt: DataType, data: Seq[Any], get: (ArrowColumnVector, Int) => Any): Unit = { + val schema = new StructType().add("value", dt, nullable = true) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + data.foreach { datum => + writer.write(InternalRow(datum)) + } + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + data.zipWithIndex.foreach { + case (null, rowId) => assert(reader.isNullAt(rowId)) + case (datum, rowId) => assert(get(reader, rowId) === datum) + } + + writer.root.close() + } + check(BooleanType, Seq(true, null, false), (reader, rowId) => reader.getBoolean(rowId)) + check(ByteType, + Seq(1.toByte, 2.toByte, null, 4.toByte), (reader, rowId) => reader.getByte(rowId)) + check(ShortType, + Seq(1.toShort, 2.toShort, null, 4.toShort), (reader, rowId) => reader.getShort(rowId)) + check(IntegerType, Seq(1, 2, null, 4), (reader, rowId) => reader.getInt(rowId)) + check(LongType, Seq(1L, 2L, null, 4L), (reader, rowId) => reader.getLong(rowId)) + check(FloatType, Seq(1.0f, 2.0f, null, 4.0f), (reader, rowId) => reader.getFloat(rowId)) + check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d), (reader, rowId) => reader.getDouble(rowId)) + + check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4)), + (reader, rowId) => reader.getDecimal( + rowId, DecimalType.SYSTEM_DEFAULT.precision, DecimalType.SYSTEM_DEFAULT.scale)) + + check(StringType, + Seq("a", "b", null, "d").map(UTF8String.fromString), + (reader, rowId) => reader.getUTF8String(rowId)) + + check(BinaryType, + Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()), + (reader, rowId) => reader.getBinary(rowId)) + } + + test("get multiple") { + def check[A](dt: DataType, data: Seq[A], get: (ArrowColumnVector, Int) => Seq[A]): Unit = { + val schema = new StructType().add("value", dt, nullable = false) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + data.foreach { datum => + writer.write(InternalRow(datum)) + } + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + assert(get(reader, data.size) === data) + + writer.root.close() + } + check(BooleanType, Seq(true, false), (reader, count) => reader.getBooleans(0, count)) + 
check(ByteType, (0 until 10).map(_.toByte), (reader, count) => reader.getBytes(0, count)) + check(ShortType, (0 until 10).map(_.toShort), (reader, count) => reader.getShorts(0, count)) + check(IntegerType, (0 until 10), (reader, count) => reader.getInts(0, count)) + check(LongType, (0 until 10).map(_.toLong), (reader, count) => reader.getLongs(0, count)) + check(FloatType, (0 until 10).map(_.toFloat), (reader, count) => reader.getFloats(0, count)) + check(DoubleType, (0 until 10).map(_.toDouble), (reader, count) => reader.getDoubles(0, count)) + } + + test("array") { + val schema = new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))) + writer.write(InternalRow(ArrayData.toArrayData(Array(4, 5)))) + writer.write(InternalRow(null)) + writer.write(InternalRow(ArrayData.toArrayData(Array.empty[Int]))) + writer.write(InternalRow(ArrayData.toArrayData(Array(6, null, 8)))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val array0 = reader.getArray(0) + assert(array0.numElements() === 3) + assert(array0.getInt(0) === 1) + assert(array0.getInt(1) === 2) + assert(array0.getInt(2) === 3) + + val array1 = reader.getArray(1) + assert(array1.numElements() === 2) + assert(array1.getInt(0) === 4) + assert(array1.getInt(1) === 5) + + assert(reader.isNullAt(2)) + + val array3 = reader.getArray(3) + assert(array3.numElements() === 0) + + val array4 = reader.getArray(4) + assert(array4.numElements() === 3) + assert(array4.getInt(0) === 6) + assert(array4.isNullAt(1)) + assert(array4.getInt(2) === 8) + + writer.root.close() + } + + test("nested array") { + val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType))) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(ArrayData.toArrayData(Array( + ArrayData.toArrayData(Array(1, 2, 3)), + ArrayData.toArrayData(Array(4, 5)), + null, + ArrayData.toArrayData(Array.empty[Int]), + ArrayData.toArrayData(Array(6, null, 8)))))) + writer.write(InternalRow(null)) + writer.write(InternalRow(ArrayData.toArrayData(Array.empty))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val array0 = reader.getArray(0) + assert(array0.numElements() === 5) + + val array00 = array0.getArray(0) + assert(array00.numElements() === 3) + assert(array00.getInt(0) === 1) + assert(array00.getInt(1) === 2) + assert(array00.getInt(2) === 3) + + val array01 = array0.getArray(1) + assert(array01.numElements() === 2) + assert(array01.getInt(0) === 4) + assert(array01.getInt(1) === 5) + + assert(array0.isNullAt(2)) + + val array03 = array0.getArray(3) + assert(array03.numElements() === 0) + + val array04 = array0.getArray(4) + assert(array04.numElements() === 3) + assert(array04.getInt(0) === 6) + assert(array04.isNullAt(1)) + assert(array04.getInt(2) === 8) + + assert(reader.isNullAt(1)) + + val array2 = reader.getArray(2) + assert(array2.numElements() === 0) + + writer.root.close() + } + + test("struct") { + val schema = new StructType() + .add("struct", new StructType().add("i", IntegerType).add("str", StringType)) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1")))) + writer.write(InternalRow(InternalRow(null, null))) + 
writer.write(InternalRow(null)) + writer.write(InternalRow(InternalRow(4, null))) + writer.write(InternalRow(InternalRow(null, UTF8String.fromString("str5")))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val struct0 = reader.getStruct(0, 2) + assert(struct0.getInt(0) === 1) + assert(struct0.getUTF8String(1) === UTF8String.fromString("str1")) + + val struct1 = reader.getStruct(1, 2) + assert(struct1.isNullAt(0)) + assert(struct1.isNullAt(1)) + + assert(reader.isNullAt(2)) + + val struct3 = reader.getStruct(3, 2) + assert(struct3.getInt(0) === 4) + assert(struct3.isNullAt(1)) + + val struct4 = reader.getStruct(4, 2) + assert(struct4.isNullAt(0)) + assert(struct4.getUTF8String(1) === UTF8String.fromString("str5")) + + writer.root.close() + } + + test("nested struct") { + val schema = new StructType().add("struct", + new StructType().add("nested", new StructType().add("i", IntegerType).add("str", StringType))) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))) + writer.write(InternalRow(InternalRow(InternalRow(null, null)))) + writer.write(InternalRow(InternalRow(null))) + writer.write(InternalRow(null)) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val struct00 = reader.getStruct(0, 1).getStruct(0, 2) + assert(struct00.getInt(0) === 1) + assert(struct00.getUTF8String(1) === UTF8String.fromString("str1")) + + val struct10 = reader.getStruct(1, 1).getStruct(0, 2) + assert(struct10.isNullAt(0)) + assert(struct10.isNullAt(1)) + + val struct2 = reader.getStruct(2, 1) + assert(struct2.isNullAt(0)) + + assert(reader.isNullAt(3)) + + writer.root.close() + } +} From 6fcf700d381a18704b3ee485083ee51c82bbd2f8 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 14 Jul 2017 14:48:33 +0900 Subject: [PATCH 02/15] Use ArrowWriter for ArrowConverters. 
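
A minimal round-trip sketch of how the two classes introduced in the previous patch fit together, mirroring ArrowWriterSuite. The schema, sample values, and the close() handling below are illustrative only and not part of the patch; ArrowWriter.create(schema) allocates its vectors from ArrowUtils.rootAllocator.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ArrowWriter}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Write a few InternalRows into Arrow vectors ...
val schema = new StructType().add("i", IntegerType).add("str", StringType)
val writer = ArrowWriter.create(schema)
writer.write(InternalRow(1, UTF8String.fromString("a")))
writer.write(InternalRow(2, null))
writer.finish()  // sets the row count on the underlying VectorSchemaRoot

// ... and read them back through the ColumnVector API.
val intCol = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
val strCol = new ArrowColumnVector(writer.root.getFieldVectors().get(1))
assert(intCol.getInt(0) == 1)
assert(strCol.getUTF8String(0) == UTF8String.fromString("a"))
assert(strCol.isNullAt(1))

writer.root.close()  // releases the Arrow buffers
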
--- .../sql/execution/arrow/ArrowConverters.scala | 279 +-------- .../arrow/ArrowConvertersSuite.scala | 529 +++++++++++++++++- 2 files changed, 539 insertions(+), 269 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 6af5c7342237..66b0a597995d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -20,19 +20,14 @@ package org.apache.spark.sql.execution.arrow import java.io.ByteArrayOutputStream import java.nio.channels.Channels -import scala.collection.JavaConverters._ - -import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector._ -import org.apache.arrow.vector.BaseValueVector.BaseMutator import org.apache.arrow.vector.file._ -import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.FloatingPointPrecision -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.schema.ArrowRecordBatch import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.vectorized.{ArrowUtils, ArrowWriter} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -70,34 +65,6 @@ private[sql] object ArrowPayload { private[sql] object ArrowConverters { - /** - * Map a Spark DataType to ArrowType. - */ - private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = { - dataType match { - case BooleanType => ArrowType.Bool.INSTANCE - case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) - case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) - case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) - case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case ByteType => new ArrowType.Int(8, true) - case StringType => ArrowType.Utf8.INSTANCE - case BinaryType => ArrowType.Binary.INSTANCE - case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") - } - } - - /** - * Convert a Spark Dataset schema to Arrow schema. - */ - private[arrow] def schemaToArrowSchema(schema: StructType): Schema = { - val arrowFields = schema.fields.map { f => - new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) - } - new Schema(arrowFields.toList.asJava) - } - /** * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. 
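
The helpers deleted above are superseded by the ArrowUtils conversions added in the previous patch. A quick sketch of the schema round-trip the converter now relies on; the field names and types here are arbitrary, chosen only for illustration.

import org.apache.spark.sql.execution.vectorized.ArrowUtils
import org.apache.spark.sql.types._

// Spark schema -> Arrow schema and back; nullability and nested types survive the trip.
val sparkSchema = new StructType()
  .add("id", LongType, nullable = false)
  .add("tags", ArrayType(StringType, containsNull = true))

val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema)  // org.apache.arrow.vector.types.pojo.Schema
assert(ArrowUtils.fromArrowSchema(arrowSchema) == sparkSchema)
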
@@ -143,31 +110,24 @@ private[sql] object ArrowConverters { allocator: BufferAllocator, maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { - val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => - ColumnWriter(field.dataType, ordinal, allocator).init() - } + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) - val writerLength = columnWriters.length var recordsInBatch = 0 while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { val row = rowIter.next() - var i = 0 - while (i < writerLength) { - columnWriters(i).write(row) - i += 1 - } + arrowWriter.write(row) recordsInBatch += 1 } + arrowWriter.finish() - val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip - val buffers = bufferArrays.flatten - - val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 - val recordBatch = new ArrowRecordBatch(rowLength, - fieldNodes.toList.asJava, buffers.toList.asJava) - - buffers.foreach(_.release()) - recordBatch + Utils.tryWithSafeFinally { + val unloader = new VectorUnloader(arrowWriter.root) + unloader.getRecordBatch() + } { + root.close() + } } /** @@ -178,7 +138,7 @@ private[sql] object ArrowConverters { batch: ArrowRecordBatch, schema: StructType, allocator: BufferAllocator): Array[Byte] = { - val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) + val arrowSchema = ArrowUtils.toArrowSchema(schema) val root = VectorSchemaRoot.create(arrowSchema, allocator) val out = new ByteArrayOutputStream() val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) @@ -216,214 +176,3 @@ private[sql] object ArrowConverters { } } } - -/** - * Interface for writing InternalRows to Arrow Buffers. - */ -private[arrow] trait ColumnWriter { - def init(): this.type - def write(row: InternalRow): Unit - - /** - * Clear the column writer and return the ArrowFieldNode and ArrowBuf. - * This should be called only once after all the data is written. - */ - def finish(): (ArrowFieldNode, Array[ArrowBuf]) -} - -/** - * Base class for flat arrow column writer, i.e., column without children. 
- */ -private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int) - extends ColumnWriter { - - def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) - - def valueVector: BaseDataValueVector - def valueMutator: BaseMutator - - def setNull(): Unit - def setValue(row: InternalRow): Unit - - protected var count = 0 - protected var nullCount = 0 - - override def init(): this.type = { - valueVector.allocateNew() - this - } - - override def write(row: InternalRow): Unit = { - if (row.isNullAt(ordinal)) { - setNull() - nullCount += 1 - } else { - setValue(row) - } - count += 1 - } - - override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { - valueMutator.setValueCount(count) - val fieldNode = new ArrowFieldNode(count, nullCount) - val valueBuffers = valueVector.getBuffers(true) - (fieldNode, valueBuffers) - } -} - -private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBitVector - = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) -} - -private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableSmallIntVector - = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) - override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getShort(ordinal)) -} - -private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableIntVector - = new NullableIntVector("IntValue", getFieldType(dtype), allocator) - override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getInt(ordinal)) -} - -private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBigIntVector - = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getLong(ordinal)) -} - -private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat4Vector - = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getFloat(ordinal)) -} - -private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, 
allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat8Vector - = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getDouble(ordinal)) -} - -private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableUInt1Vector - = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) - override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getByte(ordinal)) -} - -private[arrow] class UTF8StringColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarCharVector - = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator) - override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val str = row.getUTF8String(ordinal) - valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes) - } -} - -private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarBinaryVector - = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) - override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val bytes = row.getBinary(ordinal) - valueMutator.setSafe(count, bytes, 0, bytes.length) - } -} - -private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableDateDayVector - = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) - override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getInt(ordinal)) - } -} - -private[arrow] class TimeStampColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableTimeStampMicroVector - = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) - override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getLong(ordinal)) - } -} - -private[arrow] object ColumnWriter { - - /** - * Create an Arrow ColumnWriter given the type and ordinal of row. 
- */ - def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { - val dtype = ArrowConverters.sparkTypeToArrowType(dataType) - dataType match { - case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) - case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) - case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator) - case LongType => new LongColumnWriter(dtype, ordinal, allocator) - case FloatType => new FloatColumnWriter(dtype, ordinal, allocator) - case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator) - case ByteType => new ByteColumnWriter(dtype, ordinal, allocator) - case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator) - case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator) - case DateType => new DateColumnWriter(dtype, ordinal, allocator) - case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator) - case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 159328cc0d95..38ac88a21acc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -31,8 +31,9 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.execution.vectorized.ArrowUtils import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BinaryType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -391,6 +392,85 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "floating_point-double_precision.json") } + ignore("decimal conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_d", + | "nullable" : true, + | "type" : { + | "name" : "decimal", + | "precision" : 38, + | "scale" : 18 + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_d", + | "nullable" : true, + | "type" : { + | "name" : "decimal", + | "precision" : 38, + | "scale" : 18 + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ + | 1.000000000000000000, + | 2.000000000000000000, + | 0.010000000000000000, + | 200.000000000000000000, + | 0.000100000000000000, + | 20000.000000000000000000 ] + | }, { + | "name" : "b_d", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ + | 1.100000000000000000, + | 0E-18, + | 0E-18, + | 2.200000000000000000, + | 0E-18, + | 3.300000000000000000 ] + | } ] + | } ] + |} + """.stripMargin + + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0).map(Decimal(_)) + val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, 
Some(Decimal(3.3))) + val df = a_d.zip(b_d).toDF("a_d", "b_d") + + collectAndValidate(df, json, "decimalData.json") + } + test("index conversion") { val data = List[Int](1, 2, 3, 4, 5, 6) val json = @@ -857,6 +937,449 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "nanData-floating_point.json") } + test("array type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "c_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "d_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 4, 4, 5 ], + | "children" : [ { + | "name" : "element", + | "count" : 5, + | "VALIDITY" : [ 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5 ] + | } ] + | }, { + | "name" : "b_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 0, 1, 0 ], + | "OFFSET" : [ 0, 2, 2, 2, 2 ], + | 
"children" : [ { + | "name" : "element", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ 1, 2 ] + | } ] + | }, { + | "name" : "c_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 4, 4, 5 ], + | "children" : [ { + | "name" : "element", + | "count" : 5, + | "VALIDITY" : [ 1, 1, 1, 0, 1 ], + | "DATA" : [ 1, 2, 3, 0, 5 ] + | } ] + | }, { + | "name" : "d_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 3, 3, 4 ], + | "children" : [ { + | "name" : "element", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 3, 3, 4 ], + | "children" : [ { + | "name" : "element", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 5 ] + | } ] + | } ] + | } ] + | } ] + |} + """.stripMargin + + val a_arr = Seq(Seq(1, 2), Seq(3, 4), Seq(), Seq(5)) + val b_arr = Seq(Some(Seq(1, 2)), None, Some(Seq()), None) + val c_arr = Seq(Seq(Some(1), Some(2)), Seq(Some(3), None), Seq(), Seq(Some(5))) + val d_arr = Seq(Seq(Seq(1, 2)), Seq(Seq(3), Seq()), Seq(), Seq(Seq(5))) + + val df = a_arr.zip(b_arr).zip(c_arr).zip(d_arr).map { + case (((a, b), c), d) => (a, b, c, d) + }.toDF("a_arr", "b_arr", "c_arr", "d_arr") + + collectAndValidate(df, json, "arrayData.json") + } + + test("struct type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_struct", + | "nullable" : false, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "b_struct", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "c_struct", + | "nullable" : false, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "d_struct", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "nested", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : 
[ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, { + | "name" : "b_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, { + | "name" : "c_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, { + | "name" : "d_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "children" : [ { + | "name" : "nested", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 0 ], + | "children" : [ { + | "name" : "i", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ 1, 2 ] + | } ] + | } ] + | } ] + | } ] + |} + """.stripMargin + + val a_struct = Seq(Row(1), Row(2), Row(3)) + val b_struct = Seq(Row(1), null, Row(3)) + val c_struct = Seq(Row(1), Row(null), Row(3)) + val d_struct = Seq(Row(Row(1)), null, Row(null)) + val data = a_struct.zip(b_struct).zip(c_struct).zip(d_struct).map { + case (((a, b), c), d) => Row(a, b, c, d) + } + + val rdd = sparkContext.parallelize(data) + val schema = new StructType() + .add("a_struct", new StructType().add("i", IntegerType, nullable = false), nullable = false) + .add("b_struct", new StructType().add("i", IntegerType, nullable = false), nullable = true) + .add("c_struct", new StructType().add("i", IntegerType, nullable = true), nullable = false) + .add("d_struct", new StructType().add("nested", new StructType().add("i", IntegerType))) + val df = spark.createDataFrame(rdd, schema) + + collectAndValidate(df, json, "structData.json") + } + test("partitioned DataFrame") { val json1 = s""" @@ -1038,8 +1561,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { decimalData.toArrowPayload.collect() } - runUnsupported { arrayData.toDF().toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } @@ -1202,7 +1723,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val allocator = new RootAllocator(Long.MaxValue) val jsonReader = new JsonFileReader(jsonFile, allocator) - val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) From 579def2db0a0f015760a458032d3bd916669201c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 14 Jul 2017 18:42:32 +0900 Subject: [PATCH 03/15] Refactor ArrowConverters. 
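In outline, the new per-batch write path looks like the following sketch, condensed from
the ArrowConverters hunk below; the enclosing Iterator, allocator setup, and final cleanup
are elided, so this is illustrative only and not additional code in the change:

    // Reuse the shared VectorSchemaRoot/ArrowWriter and serialize one batch
    // of at most maxRecordsPerBatch rows into a standalone Arrow payload.
    val out = new ByteArrayOutputStream()
    val writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
    Utils.tryWithSafeFinally {
      var rowId = 0
      while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowId < maxRecordsPerBatch)) {
        arrowWriter.write(rowIter.next())
        rowId += 1
      }
      arrowWriter.finish()   // set value counts on the vectors and the row count on the root
      writer.writeBatch()    // write the current contents of the root as one record batch
    } {
      arrowWriter.reset()    // clear the vectors so the next payload starts empty
      writer.close()
    }
    new ArrowPayload(out.toByteArray)

This replaces the batchToByteArray/internalRowIterToArrowBatch pair: each payload is now
produced directly from the reusable root instead of going through an intermediate
ArrowRecordBatch.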
--- python/pyspark/sql/tests.py | 4 +- .../sql/execution/arrow/ArrowConverters.scala | 115 +++++------------- 2 files changed, 31 insertions(+), 88 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index be5495ca019a..fba13cb0722f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2923,8 +2923,8 @@ def assertFramesEqual(self, df_with_arrow, df_without): self.assertTrue(df_without.equals(df_with_arrow), msg=msg) def test_unsupported_datatype(self): - schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) - df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema) + schema = StructType([StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) with QuietTest(self.sc): self.assertRaises(Exception, lambda: df.toPandas()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 66b0a597995d..86a85550ef57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow import java.io.ByteArrayOutputStream import java.nio.channels.Channels -import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ import org.apache.arrow.vector.file._ import org.apache.arrow.vector.schema.ArrowRecordBatch @@ -50,19 +50,6 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Se def asPythonSerializable: Array[Byte] = payload } -private[sql] object ArrowPayload { - - /** - * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. 
- */ - def apply( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): ArrowPayload = { - new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator)) - } -} - private[sql] object ArrowConverters { /** @@ -73,89 +60,45 @@ private[sql] object ArrowConverters { rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { + new Iterator[ArrowPayload] { - private val _allocator = new RootAllocator(Long.MaxValue) - private var _nextPayload = if (rowIter.nonEmpty) convert() else null - override def hasNext: Boolean = _nextPayload != null + private val arrowSchema = ArrowUtils.toArrowSchema(schema) + private val allocator = + ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) + + private val root = VectorSchemaRoot.create(arrowSchema, allocator) + private val arrowWriter = ArrowWriter.create(root) + + override def hasNext: Boolean = rowIter.hasNext || { + root.close() + allocator.close() + false + } override def next(): ArrowPayload = { - val obj = _nextPayload - if (hasNext) { - if (rowIter.hasNext) { - _nextPayload = convert() - } else { - _allocator.close() - _nextPayload = null + val out = new ByteArrayOutputStream() + val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + + Utils.tryWithSafeFinally { + var rowId = 0 + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowId < maxRecordsPerBatch)) { + val row = rowIter.next() + arrowWriter.write(row) + rowId += 1 } + arrowWriter.finish() + writer.writeBatch() + } { + arrowWriter.reset() + writer.close() } - obj - } - private def convert(): ArrowPayload = { - val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) - ArrowPayload(batch, schema, _allocator) + new ArrowPayload(out.toByteArray) } } } - /** - * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed - * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0, - * then rowIter will be fully consumed. - */ - private def internalRowIterToArrowBatch( - rowIter: Iterator[InternalRow], - schema: StructType, - allocator: BufferAllocator, - maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { - - val arrowSchema = ArrowUtils.toArrowSchema(schema) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val arrowWriter = ArrowWriter.create(root) - - var recordsInBatch = 0 - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { - val row = rowIter.next() - arrowWriter.write(row) - recordsInBatch += 1 - } - arrowWriter.finish() - - Utils.tryWithSafeFinally { - val unloader = new VectorUnloader(arrowWriter.root) - unloader.getRecordBatch() - } { - root.close() - } - } - - /** - * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed, - * the batch can no longer be used. 
- */ - private[arrow] def batchToByteArray( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): Array[Byte] = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val out = new ByteArrayOutputStream() - val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) - - // Write a batch to byte stream, ensure the batch, allocator and writer are closed - Utils.tryWithSafeFinally { - val loader = new VectorLoader(root) - loader.load(batch) - writer.writeBatch() // writeBatch can throw IOException - } { - batch.close() - root.close() - writer.close() - } - out.toByteArray - } - /** * Convert a byte array to an ArrowRecordBatch. */ From 58cd46506b02800269380f7c8acb5f9825664cad Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 17 Jul 2017 15:30:17 +0900 Subject: [PATCH 04/15] Move releasing memory into task completion listener. --- .../scala/org/apache/spark/sql/Dataset.scala | 4 ++- .../sql/execution/arrow/ArrowConverters.scala | 27 ++++++++++--------- .../arrow/ArrowConvertersSuite.scala | 1 + 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b825b6cd6160..82c8699e4e8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils +import org.apache.spark.TaskContext import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ @@ -3086,7 +3087,8 @@ class Dataset[T] private[sql]( val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch queryExecution.toRdd.mapPartitionsInternal { iter => - ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch) + val context = TaskContext.get() + ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 86a85550ef57..95dce98c198c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -26,6 +26,7 @@ import org.apache.arrow.vector.file._ import org.apache.arrow.vector.schema.ArrowRecordBatch import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel +import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.{ArrowUtils, ArrowWriter} import org.apache.spark.sql.types._ @@ -59,22 +60,24 @@ private[sql] object ArrowConverters { private[sql] def toPayloadIterator( rowIter: Iterator[InternalRow], schema: StructType, - maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { + maxRecordsPerBatch: Int, + context: TaskContext): Iterator[ArrowPayload] = { - new Iterator[ArrowPayload] { + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val allocator = + ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) - private val arrowSchema = ArrowUtils.toArrowSchema(schema) - private val allocator = - 
ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) - private val root = VectorSchemaRoot.create(arrowSchema, allocator) - private val arrowWriter = ArrowWriter.create(root) + context.addTaskCompletionListener { _ => + root.close() + allocator.close() + } - override def hasNext: Boolean = rowIter.hasNext || { - root.close() - allocator.close() - false - } + new Iterator[ArrowPayload] { + + override def hasNext: Boolean = rowIter.hasNext override def next(): ArrowPayload = { val out = new ByteArrayOutputStream() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 38ac88a21acc..9b3210ca62ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1538,6 +1538,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") val arrowPayloads = df.toArrowPayload.collect() + assert(arrowPayloads.length >= 4) val allocator = new RootAllocator(Long.MaxValue) val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) var recordCount = 0 From 8ffedda9f05d379d700aef95dca049a751374f87 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 18 Jul 2017 15:31:59 +0900 Subject: [PATCH 05/15] Revert ArrowColumnVector related implementations to split a pr into multiple prs. --- .../vectorized/ArrowColumnVector.java | 502 ------------------ .../execution/vectorized/ColumnVector.java | 16 +- .../vectorized/OnHeapColumnVector.java | 2 +- .../vectorized/ArrowWriterSuite.scala | 250 --------- 4 files changed, 9 insertions(+), 761 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowWriterSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java deleted file mode 100644 index abc8a8d23d44..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ /dev/null @@ -1,502 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.vectorized; - -import org.apache.arrow.vector.*; -import org.apache.arrow.vector.complex.*; -import org.apache.arrow.vector.holders.NullableVarCharHolder; - -import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * A column backed by Apache Arrow. - */ -public final class ArrowColumnVector extends ColumnVector { - - private ValueVector vector; - private ValueVector.Accessor nulls; - - private NullableBitVector boolData; - private NullableTinyIntVector byteData; - private NullableSmallIntVector shortData; - private NullableIntVector intData; - private NullableBigIntVector longData; - - private NullableFloat4Vector floatData; - private NullableFloat8Vector doubleData; - private NullableDecimalVector decimalData; - - private NullableVarCharVector stringData; - - private NullableVarBinaryVector binaryData; - - private UInt4Vector listOffsetData; - - public ArrowColumnVector(ValueVector vector) { - super(vector.getValueCapacity(), DataTypes.NullType, MemoryMode.OFF_HEAP); - initialize(vector); - } - - @Override - public long nullsNativeAddress() { - throw new RuntimeException("Cannot get native address for arrow column"); - } - - @Override - public long valuesNativeAddress() { - throw new RuntimeException("Cannot get native address for arrow column"); - } - - @Override - public void close() { - if (childColumns != null) { - for (int i = 0; i < childColumns.length; i++) { - childColumns[i].close(); - } - } - vector.close(); - } - - // - // APIs dealing with nulls - // - - @Override - public void putNotNull(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public void putNull(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public void putNulls(int rowId, int count) { - throw new UnsupportedOperationException(); - } - - @Override - public void putNotNulls(int rowId, int count) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isNullAt(int rowId) { - return nulls.isNull(rowId); - } - - // - // APIs dealing with Booleans - // - - @Override - public void putBoolean(int rowId, boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putBooleans(int rowId, int count, boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(int rowId) { - return boolData.getAccessor().get(rowId) == 1; - } - - @Override - public boolean[] getBooleans(int rowId, int count) { - assert(dictionary == null); - boolean[] array = new boolean[count]; - for (int i = 0; i < count; ++i) { - array[i] = (boolData.getAccessor().get(rowId + i) == 1); - } - return array; - } - - // - // APIs dealing with Bytes - // - - @Override - public void putByte(int rowId, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putBytes(int rowId, int count, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putBytes(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(int rowId) { - return byteData.getAccessor().get(rowId); - } - - @Override - public byte[] getBytes(int rowId, int count) { - assert(dictionary == null); - byte[] array = new byte[count]; - for (int i = 0; i < count; ++i) { - array[i] = byteData.getAccessor().get(rowId + i); - } - return array; - } - - // - // APIs dealing 
with Shorts - // - - @Override - public void putShort(int rowId, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putShorts(int rowId, int count, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putShorts(int rowId, int count, short[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(int rowId) { - return shortData.getAccessor().get(rowId); - } - - @Override - public short[] getShorts(int rowId, int count) { - assert(dictionary == null); - short[] array = new short[count]; - for (int i = 0; i < count; ++i) { - array[i] = shortData.getAccessor().get(rowId + i); - } - return array; - } - - // - // APIs dealing with Ints - // - - @Override - public void putInt(int rowId, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putInts(int rowId, int count, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putInts(int rowId, int count, int[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public int getInt(int rowId) { - return intData.getAccessor().get(rowId); - } - - @Override - public int[] getInts(int rowId, int count) { - assert(dictionary == null); - int[] array = new int[count]; - for (int i = 0; i < count; ++i) { - array[i] = intData.getAccessor().get(rowId + i); - } - return array; - } - - @Override - public int getDictId(int rowId) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Longs - // - - @Override - public void putLong(int rowId, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putLongs(int rowId, int count, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putLongs(int rowId, int count, long[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(int rowId) { - return longData.getAccessor().get(rowId); - } - - @Override - public long[] getLongs(int rowId, int count) { - assert(dictionary == null); - long[] array = new long[count]; - for (int i = 0; i < count; ++i) { - array[i] = longData.getAccessor().get(rowId + i); - } - return array; - } - - // - // APIs dealing with floats - // - - @Override - public void putFloat(int rowId, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putFloats(int rowId, int count, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putFloats(int rowId, int count, float[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public void putFloats(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(int rowId) { - return floatData.getAccessor().get(rowId); - } - - @Override - public float[] getFloats(int rowId, int count) { - assert(dictionary == null); - float[] array = new float[count]; - for (int i = 0; i < count; ++i) { - array[i] = floatData.getAccessor().get(rowId + i); - } - return array; - } - - // - // APIs dealing with doubles - // - - 
@Override - public void putDouble(int rowId, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putDoubles(int rowId, int count, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putDoubles(int rowId, int count, double[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(int rowId) { - return doubleData.getAccessor().get(rowId); - } - - @Override - public double[] getDoubles(int rowId, int count) { - assert(dictionary == null); - double[] array = new double[count]; - for (int i = 0; i < count; ++i) { - array[i] = doubleData.getAccessor().get(rowId + i); - } - return array; - } - - // - // APIs dealing with Arrays - // - - @Override - public int getArrayLength(int rowId) { - return listOffsetData.getAccessor().get(rowId + 1) - listOffsetData.getAccessor().get(rowId); - } - - @Override - public int getArrayOffset(int rowId) { - return listOffsetData.getAccessor().get(rowId); - } - - @Override - public void putArray(int rowId, int offset, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public void loadBytes(Array array) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Byte Arrays - // - - @Override - public int putByteArray(int rowId, byte[] value, int offset, int count) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Decimals - // - - @Override - public Decimal getDecimal(int rowId, int precision, int scale) { - if (isNullAt(rowId)) return null; - return Decimal.apply(decimalData.getAccessor().getObject(rowId), precision, scale); - } - - @Override - public final void putDecimal(int rowId, Decimal value, int precision) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with UTF8Strings - // - - private NullableVarCharHolder stringResult = new NullableVarCharHolder(); - - @Override - public UTF8String getUTF8String(int rowId) { - stringData.getAccessor().get(rowId, stringResult); - if (stringResult.isSet == 0) { - return null; - } else { - return UTF8String.fromAddress(null, - stringResult.buffer.memoryAddress() + stringResult.start, - stringResult.end - stringResult.start); - } - } - - // - // APIs dealing with Binaries - // - - @Override - public byte[] getBinary(int rowId) { - return binaryData.getAccessor().getObject(rowId); - } - - @Override - protected void reserveInternal(int newCapacity) { - while (vector.getValueCapacity() <= newCapacity) { - vector.reAlloc(); - } - capacity = vector.getValueCapacity(); - } - - private void initialize(ValueVector vector) { - this.vector = vector; - this.type = ArrowUtils.fromArrowField(vector.getField()); - if (vector instanceof NullableBitVector) { - boolData = (NullableBitVector) vector; - nulls = boolData.getAccessor(); - } else if (vector instanceof NullableTinyIntVector) { - byteData = (NullableTinyIntVector) vector; - nulls = byteData.getAccessor(); - } else if (vector instanceof NullableSmallIntVector) { - shortData = (NullableSmallIntVector) vector; - nulls = shortData.getAccessor(); - } else if (vector instanceof NullableIntVector) { - intData = (NullableIntVector) vector; - nulls = intData.getAccessor(); - } else if (vector instanceof NullableBigIntVector) { - longData = (NullableBigIntVector) vector; - nulls = longData.getAccessor(); - } else if 
(vector instanceof NullableFloat4Vector) { - floatData = (NullableFloat4Vector) vector; - nulls = floatData.getAccessor(); - } else if (vector instanceof NullableFloat8Vector) { - doubleData = (NullableFloat8Vector) vector; - nulls = doubleData.getAccessor(); - } else if (vector instanceof NullableDecimalVector) { - decimalData = (NullableDecimalVector) vector; - nulls = decimalData.getAccessor(); - } else if (vector instanceof NullableVarCharVector) { - stringData = (NullableVarCharVector) vector; - nulls = stringData.getAccessor(); - } else if (vector instanceof NullableVarBinaryVector) { - binaryData = (NullableVarBinaryVector) vector; - nulls = binaryData.getAccessor(); - } else if (vector instanceof ListVector) { - ListVector listVector = (ListVector) vector; - listOffsetData = listVector.getOffsetVector(); - nulls = listVector.getAccessor(); - - childColumns = new ColumnVector[1]; - childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); - resultArray = new Array(childColumns[0]); - } else if (vector instanceof MapVector) { - MapVector mapVector = (MapVector) vector; - nulls = mapVector.getAccessor(); - - childColumns = new ArrowColumnVector[mapVector.size()]; - for (int i = 0; i < childColumns.length; ++i) { - childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); - } - resultStruct = new ColumnarBatch.Row(childColumns); - } - numNulls = nulls.getNullCount(); - anyNullsSet = numNulls > 0; - isConstant = true; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 77966382881b..0c027f80d48c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -646,7 +646,7 @@ public MapData getMap(int ordinal) { /** * Returns the decimal for rowId. */ - public Decimal getDecimal(int rowId, int precision, int scale) { + public final Decimal getDecimal(int rowId, int precision, int scale) { if (precision <= Decimal.MAX_INT_DIGITS()) { return Decimal.createUnsafe(getInt(rowId), precision, scale); } else if (precision <= Decimal.MAX_LONG_DIGITS()) { @@ -661,7 +661,7 @@ public Decimal getDecimal(int rowId, int precision, int scale) { } - public void putDecimal(int rowId, Decimal value, int precision) { + public final void putDecimal(int rowId, Decimal value, int precision) { if (precision <= Decimal.MAX_INT_DIGITS()) { putInt(rowId, (int) value.toUnscaledLong()); } else if (precision <= Decimal.MAX_LONG_DIGITS()) { @@ -675,7 +675,7 @@ public void putDecimal(int rowId, Decimal value, int precision) { /** * Returns the UTF8String for rowId. */ - public UTF8String getUTF8String(int rowId) { + public final UTF8String getUTF8String(int rowId) { if (dictionary == null) { ColumnVector.Array a = getByteArray(rowId); return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); @@ -688,7 +688,7 @@ public UTF8String getUTF8String(int rowId) { /** * Returns the byte array for rowId. */ - public byte[] getBinary(int rowId) { + public final byte[] getBinary(int rowId) { if (dictionary == null) { ColumnVector.Array array = getByteArray(rowId); byte[] bytes = new byte[array.length]; @@ -956,7 +956,7 @@ public final int appendStruct(boolean isNull) { /** * Data type for this column. */ - protected DataType type; + protected final DataType type; /** * Number of nulls in this column. 
This is an optimization for the reader, to skip NULL checks. @@ -988,17 +988,17 @@ public final int appendStruct(boolean isNull) { /** * If this is a nested type (array or struct), the column for the child data. */ - protected ColumnVector[] childColumns; + protected final ColumnVector[] childColumns; /** * Reusable Array holder for getArray(). */ - protected Array resultArray; + protected final Array resultArray; /** * Reusable Struct holder for getStruct(). */ - protected ColumnarBatch.Row resultStruct; + protected final ColumnarBatch.Row resultStruct; /** * The Dictionary for this column. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 0775d381ef3b..506434364be4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -137,7 +137,7 @@ public boolean[] getBooleans(int rowId, int count) { for (int i = 0; i < count; ++i) { array[i] = (byteData[rowId + i] == 1); } - return array; + return array; } // diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowWriterSuite.scala deleted file mode 100644 index e1ca05a14180..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowWriterSuite.scala +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.vectorized - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -class ArrowWriterSuite extends SparkFunSuite { - - test("simple") { - def check(dt: DataType, data: Seq[Any], get: (ArrowColumnVector, Int) => Any): Unit = { - val schema = new StructType().add("value", dt, nullable = true) - val writer = ArrowWriter.create(schema) - assert(writer.schema === schema) - - data.foreach { datum => - writer.write(InternalRow(datum)) - } - writer.finish() - - val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - data.zipWithIndex.foreach { - case (null, rowId) => assert(reader.isNullAt(rowId)) - case (datum, rowId) => assert(get(reader, rowId) === datum) - } - - writer.root.close() - } - check(BooleanType, Seq(true, null, false), (reader, rowId) => reader.getBoolean(rowId)) - check(ByteType, - Seq(1.toByte, 2.toByte, null, 4.toByte), (reader, rowId) => reader.getByte(rowId)) - check(ShortType, - Seq(1.toShort, 2.toShort, null, 4.toShort), (reader, rowId) => reader.getShort(rowId)) - check(IntegerType, Seq(1, 2, null, 4), (reader, rowId) => reader.getInt(rowId)) - check(LongType, Seq(1L, 2L, null, 4L), (reader, rowId) => reader.getLong(rowId)) - check(FloatType, Seq(1.0f, 2.0f, null, 4.0f), (reader, rowId) => reader.getFloat(rowId)) - check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d), (reader, rowId) => reader.getDouble(rowId)) - - check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4)), - (reader, rowId) => reader.getDecimal( - rowId, DecimalType.SYSTEM_DEFAULT.precision, DecimalType.SYSTEM_DEFAULT.scale)) - - check(StringType, - Seq("a", "b", null, "d").map(UTF8String.fromString), - (reader, rowId) => reader.getUTF8String(rowId)) - - check(BinaryType, - Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()), - (reader, rowId) => reader.getBinary(rowId)) - } - - test("get multiple") { - def check[A](dt: DataType, data: Seq[A], get: (ArrowColumnVector, Int) => Seq[A]): Unit = { - val schema = new StructType().add("value", dt, nullable = false) - val writer = ArrowWriter.create(schema) - assert(writer.schema === schema) - - data.foreach { datum => - writer.write(InternalRow(datum)) - } - writer.finish() - - val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - assert(get(reader, data.size) === data) - - writer.root.close() - } - check(BooleanType, Seq(true, false), (reader, count) => reader.getBooleans(0, count)) - check(ByteType, (0 until 10).map(_.toByte), (reader, count) => reader.getBytes(0, count)) - check(ShortType, (0 until 10).map(_.toShort), (reader, count) => reader.getShorts(0, count)) - check(IntegerType, (0 until 10), (reader, count) => reader.getInts(0, count)) - check(LongType, (0 until 10).map(_.toLong), (reader, count) => reader.getLongs(0, count)) - check(FloatType, (0 until 10).map(_.toFloat), (reader, count) => reader.getFloats(0, count)) - check(DoubleType, (0 until 10).map(_.toDouble), (reader, count) => reader.getDoubles(0, count)) - } - - test("array") { - val schema = new StructType() - .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true) - val writer = ArrowWriter.create(schema) - assert(writer.schema === schema) - - writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))) - writer.write(InternalRow(ArrayData.toArrayData(Array(4, 5)))) - 
writer.write(InternalRow(null)) - writer.write(InternalRow(ArrayData.toArrayData(Array.empty[Int]))) - writer.write(InternalRow(ArrayData.toArrayData(Array(6, null, 8)))) - writer.finish() - - val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - - val array0 = reader.getArray(0) - assert(array0.numElements() === 3) - assert(array0.getInt(0) === 1) - assert(array0.getInt(1) === 2) - assert(array0.getInt(2) === 3) - - val array1 = reader.getArray(1) - assert(array1.numElements() === 2) - assert(array1.getInt(0) === 4) - assert(array1.getInt(1) === 5) - - assert(reader.isNullAt(2)) - - val array3 = reader.getArray(3) - assert(array3.numElements() === 0) - - val array4 = reader.getArray(4) - assert(array4.numElements() === 3) - assert(array4.getInt(0) === 6) - assert(array4.isNullAt(1)) - assert(array4.getInt(2) === 8) - - writer.root.close() - } - - test("nested array") { - val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType))) - val writer = ArrowWriter.create(schema) - assert(writer.schema === schema) - - writer.write(InternalRow(ArrayData.toArrayData(Array( - ArrayData.toArrayData(Array(1, 2, 3)), - ArrayData.toArrayData(Array(4, 5)), - null, - ArrayData.toArrayData(Array.empty[Int]), - ArrayData.toArrayData(Array(6, null, 8)))))) - writer.write(InternalRow(null)) - writer.write(InternalRow(ArrayData.toArrayData(Array.empty))) - writer.finish() - - val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - - val array0 = reader.getArray(0) - assert(array0.numElements() === 5) - - val array00 = array0.getArray(0) - assert(array00.numElements() === 3) - assert(array00.getInt(0) === 1) - assert(array00.getInt(1) === 2) - assert(array00.getInt(2) === 3) - - val array01 = array0.getArray(1) - assert(array01.numElements() === 2) - assert(array01.getInt(0) === 4) - assert(array01.getInt(1) === 5) - - assert(array0.isNullAt(2)) - - val array03 = array0.getArray(3) - assert(array03.numElements() === 0) - - val array04 = array0.getArray(4) - assert(array04.numElements() === 3) - assert(array04.getInt(0) === 6) - assert(array04.isNullAt(1)) - assert(array04.getInt(2) === 8) - - assert(reader.isNullAt(1)) - - val array2 = reader.getArray(2) - assert(array2.numElements() === 0) - - writer.root.close() - } - - test("struct") { - val schema = new StructType() - .add("struct", new StructType().add("i", IntegerType).add("str", StringType)) - val writer = ArrowWriter.create(schema) - assert(writer.schema === schema) - - writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1")))) - writer.write(InternalRow(InternalRow(null, null))) - writer.write(InternalRow(null)) - writer.write(InternalRow(InternalRow(4, null))) - writer.write(InternalRow(InternalRow(null, UTF8String.fromString("str5")))) - writer.finish() - - val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - - val struct0 = reader.getStruct(0, 2) - assert(struct0.getInt(0) === 1) - assert(struct0.getUTF8String(1) === UTF8String.fromString("str1")) - - val struct1 = reader.getStruct(1, 2) - assert(struct1.isNullAt(0)) - assert(struct1.isNullAt(1)) - - assert(reader.isNullAt(2)) - - val struct3 = reader.getStruct(3, 2) - assert(struct3.getInt(0) === 4) - assert(struct3.isNullAt(1)) - - val struct4 = reader.getStruct(4, 2) - assert(struct4.isNullAt(0)) - assert(struct4.getUTF8String(1) === UTF8String.fromString("str5")) - - writer.root.close() - } - - test("nested struct") { - val schema = new StructType().add("struct", - new StructType().add("nested", new 
StructType().add("i", IntegerType).add("str", StringType))) - val writer = ArrowWriter.create(schema) - assert(writer.schema === schema) - - writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))) - writer.write(InternalRow(InternalRow(InternalRow(null, null)))) - writer.write(InternalRow(InternalRow(null))) - writer.write(InternalRow(null)) - writer.finish() - - val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - - val struct00 = reader.getStruct(0, 1).getStruct(0, 2) - assert(struct00.getInt(0) === 1) - assert(struct00.getUTF8String(1) === UTF8String.fromString("str1")) - - val struct10 = reader.getStruct(1, 1).getStruct(0, 2) - assert(struct10.isNullAt(0)) - assert(struct10.isNullAt(1)) - - val struct2 = reader.getStruct(2, 1) - assert(struct2.isNullAt(0)) - - assert(reader.isNullAt(3)) - - writer.root.close() - } -} From e3a4fc03b9bdded64cf209d05d1ef6d6c3d926c1 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 19 Jul 2017 12:50:03 +0900 Subject: [PATCH 06/15] Use rowCount instead of rowId. --- .../apache/spark/sql/execution/arrow/ArrowConverters.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 95dce98c198c..c4eb869fd986 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -84,11 +84,11 @@ private[sql] object ArrowConverters { val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) Utils.tryWithSafeFinally { - var rowId = 0 - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowId < maxRecordsPerBatch)) { + var rowCount = 0 + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { val row = rowIter.next() arrowWriter.write(row) - rowId += 1 + rowCount += 1 } arrowWriter.finish() writer.writeBatch() From b5988f9a223de407b7709f239fca672bb02b60aa Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 19 Jul 2017 12:57:35 +0900 Subject: [PATCH 07/15] Move files back to arrow package. 
--- .../sql/execution/arrow/ArrowConverters.scala | 1 - .../{vectorized => arrow}/ArrowUtils.scala | 2 +- .../{vectorized => arrow}/ArrowWriter.scala | 29 ++++++++++--------- .../arrow/ArrowConvertersSuite.scala | 1 - .../ArrowUtilsSuite.scala | 2 +- 5 files changed, 17 insertions(+), 18 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{vectorized => arrow}/ArrowUtils.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{vectorized => arrow}/ArrowWriter.scala (91%) rename sql/core/src/test/scala/org/apache/spark/sql/execution/{vectorized => arrow}/ArrowUtilsSuite.scala (97%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index c4eb869fd986..85bed216c631 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -28,7 +28,6 @@ import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.{ArrowUtils, ArrowWriter} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 9ded74b496a5..d57b1a5eecaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized +package org.apache.spark.sql.execution.arrow import scala.collection.JavaConverters._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala similarity index 91% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowWriter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 0547e9585664..5484fc77a49e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/vectorized/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.vectorized +package org.apache.spark.sql.execution.arrow import scala.collection.JavaConverters._ @@ -111,7 +111,7 @@ class ArrowWriter( } } -private[sql] abstract class ArrowFieldWriter { +private[arrow] abstract class ArrowFieldWriter { def valueVector: ValueVector def valueMutator: ValueVector.Mutator @@ -150,7 +150,7 @@ private[sql] abstract class ArrowFieldWriter { } } -private[sql] class BooleanWriter(val valueVector: NullableBitVector) extends ArrowFieldWriter { +private[arrow] class BooleanWriter(val valueVector: NullableBitVector) extends ArrowFieldWriter { override def valueMutator: NullableBitVector#Mutator = valueVector.getMutator() @@ -167,7 +167,7 @@ private[sql] class BooleanWriter(val valueVector: NullableBitVector) extends Arr } } -private[sql] class ByteWriter(val valueVector: NullableTinyIntVector) extends ArrowFieldWriter { +private[arrow] class ByteWriter(val valueVector: NullableTinyIntVector) extends ArrowFieldWriter { override def valueMutator: NullableTinyIntVector#Mutator = valueVector.getMutator() @@ -184,7 +184,7 @@ private[sql] class ByteWriter(val valueVector: NullableTinyIntVector) extends Ar } } -private[sql] class ShortWriter(val valueVector: NullableSmallIntVector) extends ArrowFieldWriter { +private[arrow] class ShortWriter(val valueVector: NullableSmallIntVector) extends ArrowFieldWriter { override def valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator() @@ -201,7 +201,7 @@ private[sql] class ShortWriter(val valueVector: NullableSmallIntVector) extends } } -private[sql] class IntegerWriter(val valueVector: NullableIntVector) extends ArrowFieldWriter { +private[arrow] class IntegerWriter(val valueVector: NullableIntVector) extends ArrowFieldWriter { override def valueMutator: NullableIntVector#Mutator = valueVector.getMutator() @@ -218,7 +218,7 @@ private[sql] class IntegerWriter(val valueVector: NullableIntVector) extends Arr } } -private[sql] class LongWriter(val valueVector: NullableBigIntVector) extends ArrowFieldWriter { +private[arrow] class LongWriter(val valueVector: NullableBigIntVector) extends ArrowFieldWriter { override def valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator() @@ -235,7 +235,7 @@ private[sql] class LongWriter(val valueVector: NullableBigIntVector) extends Arr } } -private[sql] class FloatWriter(val valueVector: NullableFloat4Vector) extends ArrowFieldWriter { +private[arrow] class FloatWriter(val valueVector: NullableFloat4Vector) extends ArrowFieldWriter { override def valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator() @@ -252,7 +252,7 @@ private[sql] class FloatWriter(val valueVector: NullableFloat4Vector) extends Ar } } -private[sql] class DoubleWriter(val valueVector: NullableFloat8Vector) extends ArrowFieldWriter { +private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends ArrowFieldWriter { override def valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator() @@ -269,7 +269,7 @@ private[sql] class DoubleWriter(val valueVector: NullableFloat8Vector) extends A } } -private[sql] class DecimalWriter( +private[arrow] class DecimalWriter( val valueVector: NullableDecimalVector, precision: Int, scale: Int) extends ArrowFieldWriter { @@ -292,7 +292,7 @@ private[sql] class DecimalWriter( } } -private[sql] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter { +private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter { override 
def valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator() @@ -311,7 +311,8 @@ private[sql] class StringWriter(val valueVector: NullableVarCharVector) extends } } -private[sql] class BinaryWriter(val valueVector: NullableVarBinaryVector) extends ArrowFieldWriter { +private[arrow] class BinaryWriter( + val valueVector: NullableVarBinaryVector) extends ArrowFieldWriter { override def valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator() @@ -329,7 +330,7 @@ private[sql] class BinaryWriter(val valueVector: NullableVarBinaryVector) extend } } -private[sql] class ArrayWriter( +private[arrow] class ArrayWriter( val valueVector: ListVector, val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { @@ -364,7 +365,7 @@ private[sql] class ArrayWriter( } } -private[sql] class StructWriter( +private[arrow] class StructWriter( val valueVector: NullableMapVector, children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 9b3210ca62ba..f5ab9044c272 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -31,7 +31,6 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.execution.vectorized.ArrowUtils import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType} import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowUtilsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala index f318d887d57a..638619fd39d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized +package org.apache.spark.sql.execution.arrow import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ From a50a271b2738cf25748f2376935d5b30bf4bc3aa Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 19 Jul 2017 14:59:50 +0900 Subject: [PATCH 08/15] Modify ArrowUtils to avoid deprecated APIs. 
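In the Arrow version this builds against, constructing a Field directly from
(name, nullable, arrowType, children) is deprecated; nullability, the Arrow
type and the (optional) dictionary encoding are bundled into a FieldType
first. A minimal sketch of the replacement pattern, using only classes
already visible in the hunk below (the "i" field and the ArrowType.Int
choice are illustrative):

    import scala.collection.JavaConverters._

    import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType}

    // Deprecated style:
    //   new Field("i", true, new ArrowType.Int(32, true), Seq.empty[Field].asJava)
    // Non-deprecated style: wrap nullability + type (no dictionary) in a FieldType.
    val fieldType = new FieldType(true, new ArrowType.Int(32, true), null)
    val intField = new Field("i", fieldType, Seq.empty[Field].asJava)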
--- .../spark/sql/execution/arrow/ArrowUtils.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index d57b1a5eecaf..2caf1ef02909 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.types.FloatingPointPrecision -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.sql.types._ @@ -64,15 +64,17 @@ object ArrowUtils { def toArrowField(name: String, dt: DataType, nullable: Boolean): Field = { dt match { case ArrayType(elementType, containsNull) => - new Field(name, nullable, ArrowType.List.INSTANCE, - Seq(toArrowField("element", elementType, containsNull)).asJava) + val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) + new Field(name, fieldType, Seq(toArrowField("element", elementType, containsNull)).asJava) case StructType(fields) => - new Field(name, nullable, ArrowType.Struct.INSTANCE, + val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) + new Field(name, fieldType, fields.map { field => toArrowField(field.name, field.dataType, field.nullable) }.toSeq.asJava) case dataType => - new Field(name, nullable, toArrowType(dataType), Seq.empty[Field].asJava) + val fieldType = new FieldType(nullable, toArrowType(dataType), null) + new Field(name, fieldType, Seq.empty[Field].asJava) } } From 7084b388d87c8347b79898827658d7827bf5649d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 20 Jul 2017 12:33:20 +0900 Subject: [PATCH 09/15] Modify to close resources also immediately after row iterator is consumed. --- .../sql/execution/arrow/ArrowConverters.scala | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 85bed216c631..240f38f5bfeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -69,14 +69,23 @@ private[sql] object ArrowConverters { val root = VectorSchemaRoot.create(arrowSchema, allocator) val arrowWriter = ArrowWriter.create(root) + var closed = false + context.addTaskCompletionListener { _ => - root.close() - allocator.close() + if (!closed) { + root.close() + allocator.close() + } } new Iterator[ArrowPayload] { - override def hasNext: Boolean = rowIter.hasNext + override def hasNext: Boolean = rowIter.hasNext || { + root.close() + allocator.close() + closed = true + false + } override def next(): ArrowPayload = { val out = new ByteArrayOutputStream() From 6fc4da05a84ee55ec8fd98c078c16d5671a303cc Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 20 Jul 2017 22:30:23 +0900 Subject: [PATCH 10/15] Add ArrowWriterSuite. 
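The suite round-trips each supported type through ArrowWriter and reads the
result back through ArrowColumnVector. Condensed, and with an illustrative
IntegerType column, the lifecycle it exercises looks roughly like this:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.execution.vectorized.ArrowColumnVector
    import org.apache.spark.sql.types._

    val schema = new StructType().add("value", IntegerType, nullable = true)
    val writer = ArrowWriter.create(schema)        // allocates the Arrow vectors
    Seq[Any](1, null, 3).foreach(v => writer.write(InternalRow(v)))
    writer.finish()                                // fixes the value counts

    val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
    assert(reader.getInt(0) == 1 && reader.isNullAt(1) && reader.getInt(2) == 3)

    writer.root.close()                            // releases the Arrow buffers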
--- .../execution/arrow/ArrowWriterSuite.scala | 251 ++++++++++++++++++ 1 file changed, 251 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala new file mode 100644 index 000000000000..57db5dc3b735 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.execution.vectorized.ArrowColumnVector +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ArrowWriterSuite extends SparkFunSuite { + + test("simple") { + def check(dt: DataType, data: Seq[Any], get: (ArrowColumnVector, Int) => Any): Unit = { + val schema = new StructType().add("value", dt, nullable = true) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + data.foreach { datum => + writer.write(InternalRow(datum)) + } + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + data.zipWithIndex.foreach { + case (null, rowId) => assert(reader.isNullAt(rowId)) + case (datum, rowId) => assert(get(reader, rowId) === datum) + } + + writer.root.close() + } + check(BooleanType, Seq(true, null, false), (reader, rowId) => reader.getBoolean(rowId)) + check(ByteType, + Seq(1.toByte, 2.toByte, null, 4.toByte), (reader, rowId) => reader.getByte(rowId)) + check(ShortType, + Seq(1.toShort, 2.toShort, null, 4.toShort), (reader, rowId) => reader.getShort(rowId)) + check(IntegerType, Seq(1, 2, null, 4), (reader, rowId) => reader.getInt(rowId)) + check(LongType, Seq(1L, 2L, null, 4L), (reader, rowId) => reader.getLong(rowId)) + check(FloatType, Seq(1.0f, 2.0f, null, 4.0f), (reader, rowId) => reader.getFloat(rowId)) + check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d), (reader, rowId) => reader.getDouble(rowId)) + + check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4)), + (reader, rowId) => reader.getDecimal( + rowId, DecimalType.SYSTEM_DEFAULT.precision, DecimalType.SYSTEM_DEFAULT.scale)) + + check(StringType, + Seq("a", "b", null, "d").map(UTF8String.fromString), + (reader, rowId) => reader.getUTF8String(rowId)) + + check(BinaryType, + Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()), + (reader, rowId) => reader.getBinary(rowId)) + } + + test("get multiple") { + def check[A](dt: DataType, data: 
Seq[A], get: (ArrowColumnVector, Int) => Seq[A]): Unit = { + val schema = new StructType().add("value", dt, nullable = false) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + data.foreach { datum => + writer.write(InternalRow(datum)) + } + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + assert(get(reader, data.size) === data) + + writer.root.close() + } + check(BooleanType, Seq(true, false), (reader, count) => reader.getBooleans(0, count)) + check(ByteType, (0 until 10).map(_.toByte), (reader, count) => reader.getBytes(0, count)) + check(ShortType, (0 until 10).map(_.toShort), (reader, count) => reader.getShorts(0, count)) + check(IntegerType, (0 until 10), (reader, count) => reader.getInts(0, count)) + check(LongType, (0 until 10).map(_.toLong), (reader, count) => reader.getLongs(0, count)) + check(FloatType, (0 until 10).map(_.toFloat), (reader, count) => reader.getFloats(0, count)) + check(DoubleType, (0 until 10).map(_.toDouble), (reader, count) => reader.getDoubles(0, count)) + } + + test("array") { + val schema = new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))) + writer.write(InternalRow(ArrayData.toArrayData(Array(4, 5)))) + writer.write(InternalRow(null)) + writer.write(InternalRow(ArrayData.toArrayData(Array.empty[Int]))) + writer.write(InternalRow(ArrayData.toArrayData(Array(6, null, 8)))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val array0 = reader.getArray(0) + assert(array0.numElements() === 3) + assert(array0.getInt(0) === 1) + assert(array0.getInt(1) === 2) + assert(array0.getInt(2) === 3) + + val array1 = reader.getArray(1) + assert(array1.numElements() === 2) + assert(array1.getInt(0) === 4) + assert(array1.getInt(1) === 5) + + assert(reader.isNullAt(2)) + + val array3 = reader.getArray(3) + assert(array3.numElements() === 0) + + val array4 = reader.getArray(4) + assert(array4.numElements() === 3) + assert(array4.getInt(0) === 6) + assert(array4.isNullAt(1)) + assert(array4.getInt(2) === 8) + + writer.root.close() + } + + test("nested array") { + val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType))) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(ArrayData.toArrayData(Array( + ArrayData.toArrayData(Array(1, 2, 3)), + ArrayData.toArrayData(Array(4, 5)), + null, + ArrayData.toArrayData(Array.empty[Int]), + ArrayData.toArrayData(Array(6, null, 8)))))) + writer.write(InternalRow(null)) + writer.write(InternalRow(ArrayData.toArrayData(Array.empty))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val array0 = reader.getArray(0) + assert(array0.numElements() === 5) + + val array00 = array0.getArray(0) + assert(array00.numElements() === 3) + assert(array00.getInt(0) === 1) + assert(array00.getInt(1) === 2) + assert(array00.getInt(2) === 3) + + val array01 = array0.getArray(1) + assert(array01.numElements() === 2) + assert(array01.getInt(0) === 4) + assert(array01.getInt(1) === 5) + + assert(array0.isNullAt(2)) + + val array03 = array0.getArray(3) + assert(array03.numElements() === 0) + + val array04 = array0.getArray(4) + assert(array04.numElements() === 3) + assert(array04.getInt(0) === 6) + assert(array04.isNullAt(1)) + 
assert(array04.getInt(2) === 8) + + assert(reader.isNullAt(1)) + + val array2 = reader.getArray(2) + assert(array2.numElements() === 0) + + writer.root.close() + } + + test("struct") { + val schema = new StructType() + .add("struct", new StructType().add("i", IntegerType).add("str", StringType)) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1")))) + writer.write(InternalRow(InternalRow(null, null))) + writer.write(InternalRow(null)) + writer.write(InternalRow(InternalRow(4, null))) + writer.write(InternalRow(InternalRow(null, UTF8String.fromString("str5")))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val struct0 = reader.getStruct(0, 2) + assert(struct0.getInt(0) === 1) + assert(struct0.getUTF8String(1) === UTF8String.fromString("str1")) + + val struct1 = reader.getStruct(1, 2) + assert(struct1.isNullAt(0)) + assert(struct1.isNullAt(1)) + + assert(reader.isNullAt(2)) + + val struct3 = reader.getStruct(3, 2) + assert(struct3.getInt(0) === 4) + assert(struct3.isNullAt(1)) + + val struct4 = reader.getStruct(4, 2) + assert(struct4.isNullAt(0)) + assert(struct4.getUTF8String(1) === UTF8String.fromString("str5")) + + writer.root.close() + } + + test("nested struct") { + val schema = new StructType().add("struct", + new StructType().add("nested", new StructType().add("i", IntegerType).add("str", StringType))) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))) + writer.write(InternalRow(InternalRow(InternalRow(null, null)))) + writer.write(InternalRow(InternalRow(null))) + writer.write(InternalRow(null)) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val struct00 = reader.getStruct(0, 1).getStruct(0, 2) + assert(struct00.getInt(0) === 1) + assert(struct00.getUTF8String(1) === UTF8String.fromString("str1")) + + val struct10 = reader.getStruct(1, 1).getStruct(0, 2) + assert(struct10.isNullAt(0)) + assert(struct10.isNullAt(1)) + + val struct2 = reader.getStruct(2, 1) + assert(struct2.isNullAt(0)) + + assert(reader.isNullAt(3)) + + writer.root.close() + } +} From 5bbb46f55a90bae3008ee374ef2c0349a3489c09 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 26 Jul 2017 13:14:38 +0900 Subject: [PATCH 11/15] Remove DecimalType support for now. 
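With the DecimalWriter gone, createFieldWriter now falls through to an
UnsupportedOperationException ("Unsupported data type: ...") for any type it
has no writer for, so decimal columns fail fast when the writer is created
rather than partway through a batch. A hedged sketch of what callers should
expect (this assumes schema conversion in ArrowUtils still accepts
DecimalType, so the failure surfaces from ArrowWriter itself):

    import org.apache.spark.sql.types._

    val schema = new StructType().add("d", DecimalType.SYSTEM_DEFAULT)
    try {
      ArrowWriter.create(schema)    // no DecimalWriter to dispatch to any more
    } catch {
      case e: UnsupportedOperationException =>
        println(e.getMessage)       // e.g. "Unsupported data type: decimal(38,18)"
    }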
--- .../sql/execution/arrow/ArrowWriter.scala | 27 +----- .../arrow/ArrowConvertersSuite.scala | 82 +------------------ .../execution/arrow/ArrowWriterSuite.scala | 4 - 3 files changed, 4 insertions(+), 109 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 5484fc77a49e..b7eb6ebc372f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -60,8 +60,6 @@ object ArrowWriter { new FloatWriter(vector.asInstanceOf[NullableFloat4Vector]) case DoubleType => new DoubleWriter(vector.asInstanceOf[NullableFloat8Vector]) - case DecimalType.Fixed(precision, scale) => - new DecimalWriter(vector.asInstanceOf[NullableDecimalVector], precision, scale) case StringType => new StringWriter(vector.asInstanceOf[NullableVarCharVector]) case BinaryType => @@ -76,6 +74,8 @@ object ArrowWriter { createFieldWriter(v.getChildByOrdinal(ordinal)) } new StructWriter(v, children.toArray) + case dt => + throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") } } } @@ -269,29 +269,6 @@ private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends } } -private[arrow] class DecimalWriter( - val valueVector: NullableDecimalVector, - precision: Int, - scale: Int) extends ArrowFieldWriter { - - override def valueMutator: NullableDecimalVector#Mutator = valueVector.getMutator() - - override def setNull(): Unit = { - valueMutator.setNull(count) - } - - override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setIndexDefined(count) - val decimal = input.getDecimal(ordinal, precision, scale) - decimal.changePrecision(precision, scale) - DecimalUtility.writeBigDecimalToArrowBuf(decimal.toJavaBigDecimal, valueVector.getBuffer, count) - } - - override def skip(): Unit = { - valueMutator.setIndexDefined(count) - } -} - private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter { override def valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index f5ab9044c272..f8af89044b3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -32,7 +32,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -391,85 +391,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "floating_point-double_precision.json") } - ignore("decimal conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_d", - | "nullable" : true, - | "type" : { - | "name" : "decimal", - | "precision" : 38, - | "scale" : 18 - | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | 
"typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | }, { - | "name" : "b_d", - | "nullable" : true, - | "type" : { - | "name" : "decimal", - | "precision" : 38, - | "scale" : 18 - | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_d", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ - | 1.000000000000000000, - | 2.000000000000000000, - | 0.010000000000000000, - | 200.000000000000000000, - | 0.000100000000000000, - | 20000.000000000000000000 ] - | }, { - | "name" : "b_d", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ - | 1.100000000000000000, - | 0E-18, - | 0E-18, - | 2.200000000000000000, - | 0E-18, - | 3.300000000000000000 ] - | } ] - | } ] - |} - """.stripMargin - - val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0).map(Decimal(_)) - val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3))) - val df = a_d.zip(b_d).toDF("a_d", "b_d") - - collectAndValidate(df, json, "decimalData.json") - } - test("index conversion") { val data = List[Int](1, 2, 3, 4, 5, 6) val json = @@ -1561,6 +1482,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } + runUnsupported { decimalData.toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 57db5dc3b735..7bf841a16166 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -55,10 +55,6 @@ class ArrowWriterSuite extends SparkFunSuite { check(FloatType, Seq(1.0f, 2.0f, null, 4.0f), (reader, rowId) => reader.getFloat(rowId)) check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d), (reader, rowId) => reader.getDouble(rowId)) - check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4)), - (reader, rowId) => reader.getDecimal( - rowId, DecimalType.SYSTEM_DEFAULT.precision, DecimalType.SYSTEM_DEFAULT.scale)) - check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString), (reader, rowId) => reader.getUTF8String(rowId)) From beff6ef8d44a695b3c6d2730e1dbd4b46d97589b Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 26 Jul 2017 15:00:10 +0900 Subject: [PATCH 12/15] Fix skip() for StructType. 
--- .../org/apache/spark/sql/execution/arrow/ArrowWriter.scala | 5 +++++ .../spark/sql/execution/arrow/ArrowConvertersSuite.scala | 6 +++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index b7eb6ebc372f..30b43c898c31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -368,6 +368,11 @@ private[arrow] class StructWriter( } override def skip(): Unit = { + var i = 0 + while (i < children.length) { + children(i).writeSkip() + i += 1 + } valueMutator.setIndexDefined(count) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index f8af89044b3b..4f32c374dce0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1271,9 +1271,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "VALIDITY" : [ 1, 1, 0 ], | "children" : [ { | "name" : "i", - | "count" : 2, - | "VALIDITY" : [ 1, 1 ], - | "DATA" : [ 1, 2 ] + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 0 ] | } ] | } ] | } ] From 19f3973c4acf1b05ae51c338481d975cebf66a98 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 26 Jul 2017 15:11:39 +0900 Subject: [PATCH 13/15] Address comments. --- .../sql/execution/arrow/ArrowWriter.scala | 47 +++++-------- .../arrow/ArrowConvertersSuite.scala | 20 +++--- .../execution/arrow/ArrowWriterSuite.scala | 69 +++++++++++-------- 3 files changed, 69 insertions(+), 67 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 30b43c898c31..63a1e0cf5634 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -45,36 +45,25 @@ object ArrowWriter { private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { val field = vector.getField() - ArrowUtils.fromArrowField(field) match { - case BooleanType => - new BooleanWriter(vector.asInstanceOf[NullableBitVector]) - case ByteType => - new ByteWriter(vector.asInstanceOf[NullableTinyIntVector]) - case ShortType => - new ShortWriter(vector.asInstanceOf[NullableSmallIntVector]) - case IntegerType => - new IntegerWriter(vector.asInstanceOf[NullableIntVector]) - case LongType => - new LongWriter(vector.asInstanceOf[NullableBigIntVector]) - case FloatType => - new FloatWriter(vector.asInstanceOf[NullableFloat4Vector]) - case DoubleType => - new DoubleWriter(vector.asInstanceOf[NullableFloat8Vector]) - case StringType => - new StringWriter(vector.asInstanceOf[NullableVarCharVector]) - case BinaryType => - new BinaryWriter(vector.asInstanceOf[NullableVarBinaryVector]) - case ArrayType(_, _) => - val v = vector.asInstanceOf[ListVector] - val elementVector = createFieldWriter(v.getDataVector()) - new ArrayWriter(v, elementVector) - case StructType(_) => - val v = vector.asInstanceOf[NullableMapVector] - val children = (0 until v.size()).map { ordinal => - 
createFieldWriter(v.getChildByOrdinal(ordinal)) + (ArrowUtils.fromArrowField(field), vector) match { + case (BooleanType, vector: NullableBitVector) => new BooleanWriter(vector) + case (ByteType, vector: NullableTinyIntVector) => new ByteWriter(vector) + case (ShortType, vector: NullableSmallIntVector) => new ShortWriter(vector) + case (IntegerType, vector: NullableIntVector) => new IntegerWriter(vector) + case (LongType, vector: NullableBigIntVector) => new LongWriter(vector) + case (FloatType, vector: NullableFloat4Vector) => new FloatWriter(vector) + case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector) + case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) + case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) + case (ArrayType(_, _), vector: ListVector) => + val elementVector = createFieldWriter(vector.getDataVector()) + new ArrayWriter(vector, elementVector) + case (StructType(_), vector: NullableMapVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) } - new StructWriter(v, children.toArray) - case dt => + new StructWriter(vector, children.toArray) + case (dt, _) => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 4f32c374dce0..74235637549c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1073,12 +1073,12 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { |} """.stripMargin - val a_arr = Seq(Seq(1, 2), Seq(3, 4), Seq(), Seq(5)) - val b_arr = Seq(Some(Seq(1, 2)), None, Some(Seq()), None) - val c_arr = Seq(Seq(Some(1), Some(2)), Seq(Some(3), None), Seq(), Seq(Some(5))) - val d_arr = Seq(Seq(Seq(1, 2)), Seq(Seq(3), Seq()), Seq(), Seq(Seq(5))) + val aArr = Seq(Seq(1, 2), Seq(3, 4), Seq(), Seq(5)) + val bArr = Seq(Some(Seq(1, 2)), None, Some(Seq()), None) + val cArr = Seq(Seq(Some(1), Some(2)), Seq(Some(3), None), Seq(), Seq(Some(5))) + val dArr = Seq(Seq(Seq(1, 2)), Seq(Seq(3), Seq()), Seq(), Seq(Seq(5))) - val df = a_arr.zip(b_arr).zip(c_arr).zip(d_arr).map { + val df = aArr.zip(bArr).zip(cArr).zip(dArr).map { case (((a, b), c), d) => (a, b, c, d) }.toDF("a_arr", "b_arr", "c_arr", "d_arr") @@ -1281,11 +1281,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { |} """.stripMargin - val a_struct = Seq(Row(1), Row(2), Row(3)) - val b_struct = Seq(Row(1), null, Row(3)) - val c_struct = Seq(Row(1), Row(null), Row(3)) - val d_struct = Seq(Row(Row(1)), null, Row(null)) - val data = a_struct.zip(b_struct).zip(c_struct).zip(d_struct).map { + val aStruct = Seq(Row(1), Row(2), Row(3)) + val bStruct = Seq(Row(1), null, Row(3)) + val cStruct = Seq(Row(1), Row(null), Row(3)) + val dStruct = Seq(Row(Row(1)), null, Row(null)) + val data = aStruct.zip(bStruct).zip(cStruct).zip(dStruct).map { case (((a, b), c), d) => Row(a, b, c, d) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 7bf841a16166..e9a629315f5f 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { test("simple") { - def check(dt: DataType, data: Seq[Any], get: (ArrowColumnVector, Int) => Any): Unit = { + def check(dt: DataType, data: Seq[Any]): Unit = { val schema = new StructType().add("value", dt, nullable = true) val writer = ArrowWriter.create(schema) assert(writer.schema === schema) @@ -40,32 +40,36 @@ class ArrowWriterSuite extends SparkFunSuite { val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) data.zipWithIndex.foreach { case (null, rowId) => assert(reader.isNullAt(rowId)) - case (datum, rowId) => assert(get(reader, rowId) === datum) + case (datum, rowId) => + val value = dt match { + case BooleanType => reader.getBoolean(rowId) + case ByteType => reader.getByte(rowId) + case ShortType => reader.getShort(rowId) + case IntegerType => reader.getInt(rowId) + case LongType => reader.getLong(rowId) + case FloatType => reader.getFloat(rowId) + case DoubleType => reader.getDouble(rowId) + case StringType => reader.getUTF8String(rowId) + case BinaryType => reader.getBinary(rowId) + } + assert(value === datum) } writer.root.close() } - check(BooleanType, Seq(true, null, false), (reader, rowId) => reader.getBoolean(rowId)) - check(ByteType, - Seq(1.toByte, 2.toByte, null, 4.toByte), (reader, rowId) => reader.getByte(rowId)) - check(ShortType, - Seq(1.toShort, 2.toShort, null, 4.toShort), (reader, rowId) => reader.getShort(rowId)) - check(IntegerType, Seq(1, 2, null, 4), (reader, rowId) => reader.getInt(rowId)) - check(LongType, Seq(1L, 2L, null, 4L), (reader, rowId) => reader.getLong(rowId)) - check(FloatType, Seq(1.0f, 2.0f, null, 4.0f), (reader, rowId) => reader.getFloat(rowId)) - check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d), (reader, rowId) => reader.getDouble(rowId)) - - check(StringType, - Seq("a", "b", null, "d").map(UTF8String.fromString), - (reader, rowId) => reader.getUTF8String(rowId)) - - check(BinaryType, - Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()), - (reader, rowId) => reader.getBinary(rowId)) + check(BooleanType, Seq(true, null, false)) + check(ByteType, Seq(1.toByte, 2.toByte, null, 4.toByte)) + check(ShortType, Seq(1.toShort, 2.toShort, null, 4.toShort)) + check(IntegerType, Seq(1, 2, null, 4)) + check(LongType, Seq(1L, 2L, null, 4L)) + check(FloatType, Seq(1.0f, 2.0f, null, 4.0f)) + check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d)) + check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) + check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) } test("get multiple") { - def check[A](dt: DataType, data: Seq[A], get: (ArrowColumnVector, Int) => Seq[A]): Unit = { + def check(dt: DataType, data: Seq[Any]): Unit = { val schema = new StructType().add("value", dt, nullable = false) val writer = ArrowWriter.create(schema) assert(writer.schema === schema) @@ -76,17 +80,26 @@ class ArrowWriterSuite extends SparkFunSuite { writer.finish() val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - assert(get(reader, data.size) === data) + val values = dt match { + case BooleanType => reader.getBooleans(0, data.size) + case ByteType => reader.getBytes(0, data.size) + case ShortType => reader.getShorts(0, data.size) + case IntegerType => reader.getInts(0, data.size) + case LongType => reader.getLongs(0, 
data.size) + case FloatType => reader.getFloats(0, data.size) + case DoubleType => reader.getDoubles(0, data.size) + } + assert(values === data) writer.root.close() } - check(BooleanType, Seq(true, false), (reader, count) => reader.getBooleans(0, count)) - check(ByteType, (0 until 10).map(_.toByte), (reader, count) => reader.getBytes(0, count)) - check(ShortType, (0 until 10).map(_.toShort), (reader, count) => reader.getShorts(0, count)) - check(IntegerType, (0 until 10), (reader, count) => reader.getInts(0, count)) - check(LongType, (0 until 10).map(_.toLong), (reader, count) => reader.getLongs(0, count)) - check(FloatType, (0 until 10).map(_.toFloat), (reader, count) => reader.getFloats(0, count)) - check(DoubleType, (0 until 10).map(_.toDouble), (reader, count) => reader.getDoubles(0, count)) + check(BooleanType, Seq(true, false)) + check(ByteType, (0 until 10).map(_.toByte)) + check(ShortType, (0 until 10).map(_.toShort)) + check(IntegerType, (0 until 10)) + check(LongType, (0 until 10).map(_.toLong)) + check(FloatType, (0 until 10).map(_.toFloat)) + check(DoubleType, (0 until 10).map(_.toDouble)) } test("array") { From 0bac10d95637c1afa632210b5feca079a61a35d2 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 27 Jul 2017 16:32:03 +0900 Subject: [PATCH 14/15] Modify skip semantic. --- .../sql/execution/arrow/ArrowWriter.scala | 56 +------------------ .../arrow/ArrowConvertersSuite.scala | 6 +- 2 files changed, 6 insertions(+), 56 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 63a1e0cf5634..5649b3d4003f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -111,7 +111,6 @@ private[arrow] abstract class ArrowFieldWriter { def setNull(): Unit def setValue(input: SpecializedGetters, ordinal: Int): Unit - def skip(): Unit protected var count: Int = 0 @@ -124,8 +123,8 @@ private[arrow] abstract class ArrowFieldWriter { count += 1 } - def writeSkip(): Unit = { - skip() + def writeNull(): Unit = { + setNull() count += 1 } @@ -150,10 +149,6 @@ private[arrow] class BooleanWriter(val valueVector: NullableBitVector) extends A override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { valueMutator.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0) } - - override def skip(): Unit = { - valueMutator.setIndexDefined(count) - } } private[arrow] class ByteWriter(val valueVector: NullableTinyIntVector) extends ArrowFieldWriter { @@ -167,10 +162,6 @@ private[arrow] class ByteWriter(val valueVector: NullableTinyIntVector) extends override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { valueMutator.setSafe(count, input.getByte(ordinal)) } - - override def skip(): Unit = { - valueMutator.setIndexDefined(count) - } } private[arrow] class ShortWriter(val valueVector: NullableSmallIntVector) extends ArrowFieldWriter { @@ -184,10 +175,6 @@ private[arrow] class ShortWriter(val valueVector: NullableSmallIntVector) extend override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { valueMutator.setSafe(count, input.getShort(ordinal)) } - - override def skip(): Unit = { - valueMutator.setIndexDefined(count) - } } private[arrow] class IntegerWriter(val valueVector: NullableIntVector) extends ArrowFieldWriter { @@ -201,10 +188,6 @@ private[arrow] class IntegerWriter(val valueVector: 
NullableIntVector) extends A override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { valueMutator.setSafe(count, input.getInt(ordinal)) } - - override def skip(): Unit = { - valueMutator.setIndexDefined(count) - } } private[arrow] class LongWriter(val valueVector: NullableBigIntVector) extends ArrowFieldWriter { @@ -218,10 +201,6 @@ private[arrow] class LongWriter(val valueVector: NullableBigIntVector) extends A override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { valueMutator.setSafe(count, input.getLong(ordinal)) } - - override def skip(): Unit = { - valueMutator.setIndexDefined(count) - } } private[arrow] class FloatWriter(val valueVector: NullableFloat4Vector) extends ArrowFieldWriter { @@ -235,10 +214,6 @@ private[arrow] class FloatWriter(val valueVector: NullableFloat4Vector) extends override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { valueMutator.setSafe(count, input.getFloat(ordinal)) } - - override def skip(): Unit = { - valueMutator.setIndexDefined(count) - } } private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends ArrowFieldWriter { @@ -252,10 +227,6 @@ private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { valueMutator.setSafe(count, input.getDouble(ordinal)) } - - override def skip(): Unit = { - valueMutator.setIndexDefined(count) - } } private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter { @@ -271,10 +242,6 @@ private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extend // todo: for off-heap UTF8String, how to pass in to arrow without copy? valueMutator.setSafe(count, utf8.getByteBuffer, 0, utf8.numBytes()) } - - override def skip(): Unit = { - valueMutator.setIndexDefined(count) - } } private[arrow] class BinaryWriter( @@ -290,10 +257,6 @@ private[arrow] class BinaryWriter( val bytes = input.getBinary(ordinal) valueMutator.setSafe(count, bytes, 0, bytes.length) } - - override def skip(): Unit = { - valueMutator.setIndexDefined(count) - } } private[arrow] class ArrayWriter( @@ -316,10 +279,6 @@ private[arrow] class ArrayWriter( valueMutator.endValue(count, array.numElements()) } - override def skip(): Unit = { - valueMutator.setNotNull(count) - } - override def finish(): Unit = { super.finish() elementWriter.finish() @@ -340,7 +299,7 @@ private[arrow] class StructWriter( override def setNull(): Unit = { var i = 0 while (i < children.length) { - children(i).writeSkip() + children(i).writeNull() i += 1 } valueMutator.setNull(count) @@ -356,15 +315,6 @@ private[arrow] class StructWriter( valueMutator.setIndexDefined(count) } - override def skip(): Unit = { - var i = 0 - while (i < children.length) { - children(i).writeSkip() - i += 1 - } - valueMutator.setIndexDefined(count) - } - override def finish(): Unit = { super.finish() children.foreach(_.finish()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 74235637549c..4893b52f240e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1248,7 +1248,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "children" : [ { | "name" : "i", | "count" : 3, - | "VALIDITY" : [ 1, 
1, 1 ], + | "VALIDITY" : [ 1, 0, 1 ], | "DATA" : [ 1, 2, 3 ] | } ] | }, { @@ -1268,11 +1268,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "children" : [ { | "name" : "nested", | "count" : 3, - | "VALIDITY" : [ 1, 1, 0 ], + | "VALIDITY" : [ 1, 0, 0 ], | "children" : [ { | "name" : "i", | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], + | "VALIDITY" : [ 1, 0, 0 ], | "DATA" : [ 1, 2, 0 ] | } ] | } ] From b85dc231d05f5e1a1a3d8b0bcbc778b85d83c533 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 27 Jul 2017 16:54:20 +0900 Subject: [PATCH 15/15] Inline writeNull(). --- .../apache/spark/sql/execution/arrow/ArrowWriter.scala | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 5649b3d4003f..11ba04d2ce9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -112,7 +112,7 @@ private[arrow] abstract class ArrowFieldWriter { def setNull(): Unit def setValue(input: SpecializedGetters, ordinal: Int): Unit - protected var count: Int = 0 + private[arrow] var count: Int = 0 def write(input: SpecializedGetters, ordinal: Int): Unit = { if (input.isNullAt(ordinal)) { @@ -123,11 +123,6 @@ private[arrow] abstract class ArrowFieldWriter { count += 1 } - def writeNull(): Unit = { - setNull() - count += 1 - } - def finish(): Unit = { valueMutator.setValueCount(count) } @@ -299,7 +294,8 @@ private[arrow] class StructWriter( override def setNull(): Unit = { var i = 0 while (i < children.length) { - children(i).writeNull() + children(i).setNull() + children(i).count += 1 i += 1 } valueMutator.setNull(count)