From 09af5a5851786b918f45c6f997b1c357745fe883 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Thu, 7 Jul 2016 19:36:14 +0900
Subject: [PATCH 1/4] support codegen for an array in CachedBatch

---
 .../vectorized/ByteBufferColumnVector.java    | 381 ++++++++++++++++++
 .../execution/vectorized/ColumnVector.java    |  26 +-
 .../columnar/GenerateColumnAccessor.scala     |   4 +
 .../execution/columnar/InMemoryRelation.scala |  26 +-
 .../columnar/InMemoryTableScanExec.scala      |  88 ++++
 5 files changed, 522 insertions(+), 3 deletions(-)
 create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ByteBufferColumnVector.java

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ByteBufferColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ByteBufferColumnVector.java
new file mode 100644
index 000000000000..56ec159299a8
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ByteBufferColumnVector.java
@@ -0,0 +1,381 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import java.nio.ByteBuffer;
+
+import org.apache.commons.lang.NotImplementedException;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.execution.columnar.ByteBufferHelper;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * A column backed by an in-memory JVM byte array.
+ * This stores the NULLs as a byte per value and a Java array for the values.
+ * Currently, this column vector is read-only.
+ */
+public final class ByteBufferColumnVector extends ColumnVector {
+  // The data stored in these arrays needs to maintain binary compatibility. We can
+  // directly pass this buffer to external components.
+
+  // This is faster than a boolean array and we optimize this over memory footprint.
+  private byte[] nulls;
+
+  // Array stored in byte array
+  private byte[] data;
+  private long offset;
+
+  // Only set if type is Array.
+ private int[] arrayOffsets; + private UnsafeArrayData unsafeArray; + + protected ByteBufferColumnVector(int capacity, DataType type, + boolean isConstant, ByteBuffer buffer, ByteBuffer nullsBuffer) { + super(capacity, type); + if (this.resultArray != null) { + data = buffer.array(); + offset = Platform.BYTE_ARRAY_OFFSET + buffer.position(); + + unsafeArray = new UnsafeArrayData(); + arrayOffsets = new int[capacity]; + + byte[] dataNulls = nullsBuffer.array(); + int posNulls = Platform.BYTE_ARRAY_OFFSET + nullsBuffer.position(); + int numNulls = Platform.getInt(dataNulls, posNulls); + for (int i = 0; i < numNulls; i++) { + int cordinal = Platform.getInt(dataNulls, posNulls + 4 + i * 4); + arrayOffsets[cordinal] = -1; + } + + int pos = 0; + for (int i = 0; i < capacity; i++) { + if (arrayOffsets[i] < 0) continue; + arrayOffsets[i] = pos; + pos += Platform.getInt(data, offset + pos) + 4; + } + } else if (DecimalType.isByteArrayDecimalType(type)) { + throw new NotImplementedException(); + } else if ((type instanceof FloatType) || (type instanceof DoubleType)) { + data = buffer.array(); + offset = Platform.BYTE_ARRAY_OFFSET + buffer.position(); + } else if (resultStruct != null) { + // Nothing to store. + } else { + throw new RuntimeException("Unhandled " + type); + } + nulls = new byte[capacity]; + reset(); + + int numNulls = ByteBufferHelper.getInt(nullsBuffer); + for (int i = 0; i < numNulls; i++) { + int cordinal = ByteBufferHelper.getInt(nullsBuffer); + putNull(cordinal); + } + if (isConstant) { + setIsConstant(); + } + } + + @Override + public final long valuesNativeAddress() { + throw new RuntimeException("Cannot get native address for on heap column"); + } + @Override + public final long nullsNativeAddress() { + throw new RuntimeException("Cannot get native address for on heap column"); + } + + @Override + public final void close() { + } + + // + // APIs dealing with nulls + // + + @Override + public final void putNotNull(int rowId) { + nulls[rowId] = (byte)0; + } + + @Override + public final void putNull(int rowId) { + nulls[rowId] = (byte)1; + ++numNulls; + anyNullsSet = true; + } + + @Override + public final void putNulls(int rowId, int count) { + for (int i = 0; i < count; ++i) { + nulls[rowId + i] = (byte)1; + } + anyNullsSet = true; + numNulls += count; + } + + @Override + public final void putNotNulls(int rowId, int count) { + if (!anyNullsSet) return; + for (int i = 0; i < count; ++i) { + nulls[rowId + i] = (byte)0; + } + } + + @Override + public final boolean isNullAt(int rowId) { + return nulls[rowId] == 1; + } + + // + // APIs dealing with Booleans + // + + @Override + public final void putBoolean(int rowId, boolean value) { + throw new NotImplementedException(); + } + + @Override + public final void putBooleans(int rowId, int count, boolean value) { + throw new NotImplementedException(); + } + + @Override + public final boolean getBoolean(int rowId) { + assert(dictionary == null); + return Platform.getBoolean(data, offset + rowId); + } + + // + // APIs dealing with Bytes + // + + @Override + public final void putByte(int rowId, byte value) { + throw new NotImplementedException(); + } + + @Override + public final void putBytes(int rowId, int count, byte value) { + throw new NotImplementedException(); + } + + @Override + public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public final byte getByte(int rowId) { + assert(dictionary == null); + return Platform.getByte(data, offset + rowId); + } 
+ + // + // APIs dealing with Shorts + // + + @Override + public final void putShort(int rowId, short value) { + throw new NotImplementedException(); + } + + @Override + public final void putShorts(int rowId, int count, short value) { + throw new NotImplementedException(); + } + + @Override + public final void putShorts(int rowId, int count, short[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public final short getShort(int rowId) { + assert(dictionary == null); + return Platform.getShort(data, offset + rowId * 2); + } + + // + // APIs dealing with Ints + // + + @Override + public final void putInt(int rowId, int value) { + throw new NotImplementedException(); + } + + @Override + public final void putInts(int rowId, int count, int value) { + throw new NotImplementedException(); + } + + @Override + public final void putInts(int rowId, int count, int[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public final int getInt(int rowId) { + assert(dictionary == null); + return Platform.getInt(data, offset + rowId * 4); + } + + // + // APIs dealing with Longs + // + + @Override + public final void putLong(int rowId, long value) { + throw new NotImplementedException(); + } + + @Override + public final void putLongs(int rowId, int count, long value) { + throw new NotImplementedException(); + } + + @Override + public final void putLongs(int rowId, int count, long[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public final long getLong(int rowId) { + throw new NotImplementedException(); + } + + // + // APIs dealing with floats + // + + @Override + public final void putFloat(int rowId, float value) { + throw new NotImplementedException(); + } + + @Override + public final void putFloats(int rowId, int count, float value) { + throw new NotImplementedException(); + } + + @Override + public final void putFloats(int rowId, int count, float[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public final float getFloat(int rowId) { + assert(dictionary == null); + return Platform.getFloat(data, offset + rowId * 4); + } + + // + // APIs dealing with doubles + // + + @Override + public final void putDouble(int rowId, double value) { + throw new NotImplementedException(); + } + + @Override + public final void putDoubles(int rowId, int count, double value) { + throw new NotImplementedException(); + } + + @Override + public final void putDoubles(int rowId, int count, double[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public final double getDouble(int rowId) { + assert(dictionary == null); + return Platform.getDouble(data, offset + rowId * 8); + } + + // + // APIs dealing with Arrays + // + + @Override + public final ArrayData getArray(int rowId) { + int length = getArrayLength(rowId); + unsafeArray.pointTo(data, offset + arrayOffsets[rowId] + 4, length); + return unsafeArray; + } + + 
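+  // Layout of an array value in `data`: a 4-byte size word (the total size in bytes of the
+  // UnsafeArrayData) followed by the UnsafeArrayData bytes. arrayOffsets[rowId], filled in by the
+  // constructor, points at the size word, so the element bytes start 4 bytes after it.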
@Override
+  public final int getArrayLength(int rowId) { return Platform.getInt(data, offset + arrayOffsets[rowId]); }
+  @Override
+  public final int getArrayOffset(int rowId) { return arrayOffsets[rowId]; }
+
+  @Override
+  public final void putArray(int rowId, int offset, int length) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public final void loadBytes(ColumnVector.Array array) {
+    throw new NotImplementedException();
+  }
+
+  //
+  // APIs dealing with Byte Arrays
+  //
+
+  @Override
+  public final int putByteArray(int rowId, byte[] value, int offset, int length) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public final void reserve(int requiredCapacity) {
+    if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2);
+  }
+
+  // Split this function out since it is the slow path.
+  @Override
+  protected void reserveInternal(int newCapacity) {
+    throw new NotImplementedException();
+  }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index bbbb796aca0d..16e877c8e12a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -19,6 +19,8 @@
 import java.math.BigDecimal;
 import java.math.BigInteger;
 
+import org.apache.commons.lang.NotImplementedException;
+
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.parquet.column.Dictionary;
 import org.apache.parquet.io.api.Binary;
@@ -529,7 +531,7 @@ public ColumnarBatch.Row getStruct(int rowId, int size) {
   /**
    * Returns the array at rowid.
    */
-  public final Array getArray(int rowId) {
+  public ArrayData getArray(int rowId) {
     resultArray.length = getArrayLength(rowId);
     resultArray.offset = getArrayOffset(rowId);
     return resultArray;
@@ -552,7 +554,7 @@ public final int putByteArray(int rowId, byte[] value) {
    * Returns the value for rowId.
*/ private Array getByteArray(int rowId) { - Array array = getArray(rowId); + Array array = (Array)getArray(rowId); array.data.loadBytes(array); return array; } @@ -1002,4 +1004,24 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { this.resultStruct = null; } } + + protected ColumnVector(int capacity, DataType type) { + this.capacity = capacity; + this.type = type; + + if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType + || DecimalType.isByteArrayDecimalType(type)) { + this.childColumns = null; + this.resultArray = new Array(null); + this.resultStruct = null; + } else if (type instanceof StructType) { + throw new NotImplementedException(); + } else if (type instanceof CalendarIntervalType) { + throw new NotImplementedException(); + } else { + this.childColumns = null; + this.resultArray = null; + this.resultStruct = null; + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 7a14879b8b9d..1378ea6819a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -95,6 +95,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case NullType | StringType | BinaryType => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + case ArrayType(_, _) => + _isSupportColumnarCodeGen = true + s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), + (${dt.getClass.getName}) columnTypes[$index]);""" case other => s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), (${dt.getClass.getName}) columnTypes[$index]);""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 079e122a5a85..7df04a6a15ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -53,7 +53,31 @@ private[sql] object InMemoryRelation { * @param stats The stat of columns */ private[columnar] -case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) +case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) { + def column(columnarIterator: ColumnarIterator, index: Int): ColumnVector = { + val ordinal = columnarIterator.getColumnIndexes(index) + val dataType = columnarIterator.getColumnTypes(index) + val buffer = ByteBuffer.wrap(buffers(ordinal)).order(nativeOrder) + val accessor: BasicColumnAccessor[_] = dataType match { + case FloatType => new FloatColumnAccessor(buffer) + case DoubleType => new DoubleColumnAccessor(buffer) + case arrayType: ArrayType => new ArrayColumnAccessor(buffer, arrayType) + case _ => throw new UnsupportedOperationException(s"CachedBatch.column(): $dataType") + } + + val (out, nullsBuffer) = if (accessor.isInstanceOf[NativeColumnAccessor[_]]) { + val nativeAccessor = accessor.asInstanceOf[NativeColumnAccessor[_]] + nativeAccessor.decompress(numRows); + } else { + val buffer = 
accessor.getByteBuffer + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + (buffer, nullsBuffer) + } + + ColumnVector.allocate(numRows, dataType, true, out, nullsBuffer) + } +} private[sql] case class InMemoryRelation( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 67a410f539b6..274e100fe7a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -180,3 +180,91 @@ private[sql] case class InMemoryTableScanExec( } } } + +private[sql] object InMemoryTableScanExec { + private val columnarItrName = "columnar_itr" + private val columnarBatchIdxName = "columnar_batchIdx" + + def enableColumnCodeGen( + sqlContext: SQLContext, ctx: CodegenContext, child: SparkPlan): Boolean = { + ctx.enableColumnCodeGen && + sqlContext.getConf(SQLConf.COLUMN_VECTOR_CODEGEN.key).toBoolean && + child.find(c => c.isInstanceOf[InMemoryTableScanExec]).isDefined && + child.find(c => c.isInstanceOf[CodegenSupport] && + c.asInstanceOf[CodegenSupport].useUnsafeRow).isEmpty + } + + def produceColumnLoop( + ctx: CodegenContext, codegen: CodegenSupport, output: Seq[Attribute]): String = { + val idx = columnarBatchIdxName + val numRows = "columnar_numRows" + ctx.addMutableState("int", idx, s"$idx = 0;") + ctx.addMutableState("int", numRows, s"$numRows = 0;") + val rowidx = ctx.freshName("rowIdx") + + val colVars = output.indices.map(i => ctx.freshName("col" + i)) + val columnAssigns = colVars.zipWithIndex.map { case (name, i) => + ctx.addMutableState("org.apache.spark.sql.execution.vectorized.ColumnVector", + name, s"$name = null;", s"$name = null;") + s"$name = ${columnarItrName}.getColumn($i);" + } + val columns = (output zip colVars).map { case (attr, colVar) => + new ColumnVectorReference(colVar, rowidx, attr.dataType, attr.nullable).genCode(ctx) } + + s""" + | while (true) { + | if ($idx == 0) { + | $numRows = ${columnarItrName}.initForColumnar(); + | if ($numRows < 0) { + | cleanup(); + | break; + | } + | ${columnAssigns.mkString("", "\n", "")} + | } + | + | while ($idx < $numRows) { + | int $rowidx = $idx++; + | System.out.println("rowIdx="+$rowidx); + | ${codegen.consume(ctx, columns, null).trim} + | if (shouldStop()) return; + | } + | $idx = 0; + | } + """.stripMargin + } + + def produceProcessNext( + ctx: CodegenContext, codegen: CodegenSupport, child: SparkPlan, codeRow: String): String = { + ctx.isRow = false + val codeCol = child.asInstanceOf[CodegenSupport].produce(ctx, codegen) + val columnarItrClz = "org.apache.spark.sql.execution.columnar.ColumnarIterator" + val colItr = columnarItrName + ctx.addMutableState(s"$columnarItrClz", colItr, s"$colItr = null;", s"$colItr = null;") + + s""" + private void processBatch() throws java.io.IOException { + System.out.println("*** processBatch() ***"); + ${codeCol.trim} + } + + private void processRow() throws java.io.IOException { + System.out.println("*** processRow() ***"); + ${codeRow.trim} + } + + private void cleanup() { + ${ctx.cleanupMutableStates()} + } + + protected void processNext() throws java.io.IOException { + if ((${columnarBatchIdxName} != 0) || + (${ctx.iteratorInput} instanceof $columnarItrClz && + ($colItr = ($columnarItrClz)${ctx.iteratorInput}).isSupportColumnarCodeGen())) { + 
processBatch(); + } else { + processRow(); + } + } + """.trim + } +} From 8e218e38d5acb6c04db221fcd3cd6d2483926552 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 7 Jul 2016 19:36:34 +0900 Subject: [PATCH 2/4] update test suites --- .../sql/execution/DataFrameCacheSuite.scala | 116 ++++++++++++++++++ .../vectorized/ColumnarBatchSuite.scala | 23 ++-- 2 files changed, 130 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/DataFrameCacheSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataFrameCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataFrameCacheSuite.scala new file mode 100644 index 000000000000..62b56bf88218 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataFrameCacheSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class DataFrameCacheSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("range/filter should be combined with column codegen") { + val df = sparkContext.parallelize(0 to 9, 1).map(i => i.toFloat).toDF().cache() + .filter("value = 1").selectExpr("value + 1") + assert(df.collect() === Array(Row(2.0))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen).isDefined) + } + + test("filters should be combined with column codegen") { + val df = sparkContext.parallelize(0 to 9, 1).map(i => i.toFloat).toDF().cache() + .filter("value % 2.0 == 0").filter("value % 3.0 == 0") + assert(df.collect() === Array(Row(0), Row(6.0))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen).isDefined) + } + + test("filter with null should be included in WholeStageCodegen with column codegen") { + val toFloat = udf[java.lang.Float, String] { s => if (s == "2") null else s.toFloat } + val df0 = sparkContext.parallelize(0 to 4, 1).map(i => i.toString).toDF() + val df = df0.withColumn("i", toFloat(df0("value"))).select("i").toDF().cache() + .filter("i % 2.0 == 0") + assert(df.collect() === Array(Row(0), Row(4.0))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen).isDefined) + } + + test("Aggregate 
should be included in WholeStageCodegen with column codegen") { + val df = sparkContext.parallelize(0 to 9, 1).map(i => i.toFloat).toDF().cache() + .groupBy().agg(max(col("value")), avg(col("value"))) + assert(df.collect() === Array(Row(9, 4.5))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) + } + + test("Aggregate with grouping keys should be included in WholeStageCodegen with column codegen") { + val df = sparkContext.parallelize(0 to 2, 1).map(i => i.toFloat).toDF().cache() + .groupBy("value").count().orderBy("value") + assert(df.collect() === Array(Row(0.0, 1), Row(1.0, 1), Row(2.0, 1))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) + } + + test("Aggregate with columns should be included in WholeStageCodegen with column codegen") { + val df = sparkContext.parallelize(0 to 10, 1).map(i => (i, (i * 2).toDouble)).toDF("i", "d") + .cache().agg(sum("d")) + assert(df.collect() === Array(Row(110.0))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) + } + + test("Sort should be included in WholeStageCodegen without column codegen") { + val df = sparkContext.parallelize(Seq(3.toFloat, 2.toFloat, 1.toFloat), 1).toDF() + .sort(col("value")) + val plan = df.queryExecution.executedPlan + assert(df.collect() === Array(Row(1.0), Row(2.0), Row(3.0))) + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + !p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined) + } + + test("filter/selectExpr should be combined with column codegen for int array") { + val df = sparkContext.parallelize(0 to 9, 1).map(i => Array(i, i + 1)).toDF().cache() + .filter("value[0] > 4").selectExpr("value[1] + 1") + assert(df.collect() === Array(Row(7), Row(8), Row(9), Row(10), Row(11))) + } + + test("filter/selectExpr should be combined with column codegen for double array") { + val df = sparkContext.parallelize(0 to 9, 1).map(i => Array(i.toDouble, (i + 1).toDouble)) + .toDF().cache().filter("value[0] > 4.0").selectExpr("value[1] + 1") + assert(df.collect() === Array(Row(7.0), Row(8.0), Row(9.0), Row(10.0), Row(11.0))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 100cc4daca87..ebcc33713bb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -401,26 +401,30 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(2, 2, 0) column.putArray(3, 3, 3) - val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] - val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]] - val a3 = 
ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(2)).asInstanceOf[Array[Int]] - val a4 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(3)).asInstanceOf[Array[Int]] + val a1 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(0).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a2 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(1).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a3 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(2).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a4 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(3).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] assert(a1 === Array(0)) assert(a2 === Array(1, 2)) assert(a3 === Array.empty[Int]) assert(a4 === Array(3, 4, 5)) // Verify the ArrayData APIs - assert(column.getArray(0).length == 1) + assert(column.getArray(0).numElements() == 1) assert(column.getArray(0).getInt(0) == 0) - assert(column.getArray(1).length == 2) + assert(column.getArray(1).numElements() == 2) assert(column.getArray(1).getInt(0) == 1) assert(column.getArray(1).getInt(1) == 2) - assert(column.getArray(2).length == 0) + assert(column.getArray(2).numElements() == 0) - assert(column.getArray(3).length == 3) + assert(column.getArray(3).numElements() == 3) assert(column.getArray(3).getInt(0) == 3) assert(column.getArray(3).getInt(1) == 4) assert(column.getArray(3).getInt(2) == 5) @@ -433,7 +437,8 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) column.putArray(0, 0, array.length) - assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] + assert(ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(0).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] === array) }} } From 54df41c8691f02dd9eac3eef3d816a130b87a5c9 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 7 Jul 2016 09:18:58 -0400 Subject: [PATCH 3/4] remove debug print --- .../spark/sql/execution/columnar/InMemoryTableScanExec.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 274e100fe7a2..d0e6696a3081 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -224,7 +224,6 @@ private[sql] object InMemoryTableScanExec { | | while ($idx < $numRows) { | int $rowidx = $idx++; - | System.out.println("rowIdx="+$rowidx); | ${codegen.consume(ctx, columns, null).trim} | if (shouldStop()) return; | } @@ -243,12 +242,10 @@ private[sql] object InMemoryTableScanExec { s""" private void processBatch() throws java.io.IOException { - System.out.println("*** processBatch() ***"); ${codeCol.trim} } private void processRow() throws java.io.IOException { - System.out.println("*** processRow() ***"); ${codeRow.trim} } From 61a4754b898755e293a741fa74518d6e76c5c538 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 8 Aug 2016 15:55:49 +0900 Subject: [PATCH 4/4] alleviate performance overhead in constructor remove duplicated final attribute at method declaration --- .../vectorized/ByteBufferColumnVector.java | 181 +++++++++++------- 1 file changed, 110 insertions(+), 71 deletions(-) diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ByteBufferColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ByteBufferColumnVector.java index 56ec159299a8..f5732debc96f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ByteBufferColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ByteBufferColumnVector.java @@ -43,33 +43,22 @@ public final class ByteBufferColumnVector extends ColumnVector { private long offset; // Only set if type is Array. - private int[] arrayOffsets; + private int lastArrayRow; + private int lastArrayPos; private UnsafeArrayData unsafeArray; protected ByteBufferColumnVector(int capacity, DataType type, boolean isConstant, ByteBuffer buffer, ByteBuffer nullsBuffer) { super(capacity, type); + boolean containsNull = true; if (this.resultArray != null) { + containsNull = ((ArrayType)type).containsNull(); data = buffer.array(); offset = Platform.BYTE_ARRAY_OFFSET + buffer.position(); unsafeArray = new UnsafeArrayData(); - arrayOffsets = new int[capacity]; - - byte[] dataNulls = nullsBuffer.array(); - int posNulls = Platform.BYTE_ARRAY_OFFSET + nullsBuffer.position(); - int numNulls = Platform.getInt(dataNulls, posNulls); - for (int i = 0; i < numNulls; i++) { - int cordinal = Platform.getInt(dataNulls, posNulls + 4 + i * 4); - arrayOffsets[cordinal] = -1; - } - - int pos = 0; - for (int i = 0; i < capacity; i++) { - if (arrayOffsets[i] < 0) continue; - arrayOffsets[i] = pos; - pos += Platform.getInt(data, offset + pos) + 4; - } + lastArrayPos = 0; + lastArrayRow = 0; } else if (DecimalType.isByteArrayDecimalType(type)) { throw new NotImplementedException(); } else if ((type instanceof FloatType) || (type instanceof DoubleType)) { @@ -80,13 +69,17 @@ protected ByteBufferColumnVector(int capacity, DataType type, } else { throw new RuntimeException("Unhandled " + type); } - nulls = new byte[capacity]; + if (containsNull) { + nulls = new byte[capacity]; + } reset(); - int numNulls = ByteBufferHelper.getInt(nullsBuffer); - for (int i = 0; i < numNulls; i++) { - int cordinal = ByteBufferHelper.getInt(nullsBuffer); - putNull(cordinal); + if (containsNull) { + int numNulls = ByteBufferHelper.getInt(nullsBuffer); + for (int i = 0; i < numNulls; i++) { + int cordinal = ByteBufferHelper.getInt(nullsBuffer); + putNull(cordinal); + } } if (isConstant) { setIsConstant(); @@ -94,16 +87,16 @@ protected ByteBufferColumnVector(int capacity, DataType type, } @Override - public final long valuesNativeAddress() { + public long valuesNativeAddress() { throw new RuntimeException("Cannot get native address for on heap column"); } @Override - public final long nullsNativeAddress() { + public long nullsNativeAddress() { throw new RuntimeException("Cannot get native address for on heap column"); } @Override - public final void close() { + public void close() { } // @@ -111,19 +104,19 @@ public final void close() { // @Override - public final void putNotNull(int rowId) { + public void putNotNull(int rowId) { nulls[rowId] = (byte)0; } @Override - public final void putNull(int rowId) { + public void putNull(int rowId) { nulls[rowId] = (byte)1; ++numNulls; anyNullsSet = true; } @Override - public final void putNulls(int rowId, int count) { + public void putNulls(int rowId, int count) { for (int i = 0; i < count; ++i) { nulls[rowId + i] = (byte)1; } @@ -132,7 +125,7 @@ public final void putNulls(int rowId, int count) { } @Override - public final void putNotNulls(int rowId, int count) { + 
public void putNotNulls(int rowId, int count) { if (!anyNullsSet) return; for (int i = 0; i < count; ++i) { nulls[rowId + i] = (byte)0; @@ -140,7 +133,8 @@ public final void putNotNulls(int rowId, int count) { } @Override - public final boolean isNullAt(int rowId) { + public boolean isNullAt(int rowId) { + if (nulls == null) return false; return nulls[rowId] == 1; } @@ -149,17 +143,17 @@ public final boolean isNullAt(int rowId) { // @Override - public final void putBoolean(int rowId, boolean value) { + public void putBoolean(int rowId, boolean value) { throw new NotImplementedException(); } @Override - public final void putBooleans(int rowId, int count, boolean value) { + public void putBooleans(int rowId, int count, boolean value) { throw new NotImplementedException(); } @Override - public final boolean getBoolean(int rowId) { + public boolean getBoolean(int rowId) { assert(dictionary == null); return Platform.getBoolean(data, offset + rowId); } @@ -169,22 +163,22 @@ public final boolean getBoolean(int rowId) { // @Override - public final void putByte(int rowId, byte value) { + public void putByte(int rowId, byte value) { throw new NotImplementedException(); } @Override - public final void putBytes(int rowId, int count, byte value) { + public void putBytes(int rowId, int count, byte value) { throw new NotImplementedException(); } @Override - public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { + public void putBytes(int rowId, int count, byte[] src, int srcIndex) { throw new NotImplementedException(); } @Override - public final byte getByte(int rowId) { + public byte getByte(int rowId) { assert(dictionary == null); return Platform.getByte(data, offset + rowId); } @@ -194,22 +188,22 @@ public final byte getByte(int rowId) { // @Override - public final void putShort(int rowId, short value) { + public void putShort(int rowId, short value) { throw new NotImplementedException(); } @Override - public final void putShorts(int rowId, int count, short value) { + public void putShorts(int rowId, int count, short value) { throw new NotImplementedException(); } @Override - public final void putShorts(int rowId, int count, short[] src, int srcIndex) { + public void putShorts(int rowId, int count, short[] src, int srcIndex) { throw new NotImplementedException(); } @Override - public final short getShort(int rowId) { + public short getShort(int rowId) { assert(dictionary == null); return Platform.getShort(data, offset + rowId * 2); } @@ -219,27 +213,27 @@ public final short getShort(int rowId) { // @Override - public final void putInt(int rowId, int value) { + public void putInt(int rowId, int value) { throw new NotImplementedException(); } @Override - public final void putInts(int rowId, int count, int value) { + public void putInts(int rowId, int count, int value) { throw new NotImplementedException(); } @Override - public final void putInts(int rowId, int count, int[] src, int srcIndex) { + public void putInts(int rowId, int count, int[] src, int srcIndex) { throw new NotImplementedException(); } @Override - public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { throw new NotImplementedException(); } @Override - public final int getInt(int rowId) { + public int getInt(int rowId) { assert(dictionary == null); return Platform.getInt(data, offset + rowId * 4); } @@ -249,27 +243,27 @@ public final int getInt(int rowId) { // @Override - public final void putLong(int 
rowId, long value) { + public void putLong(int rowId, long value) { throw new NotImplementedException(); } @Override - public final void putLongs(int rowId, int count, long value) { + public void putLongs(int rowId, int count, long value) { throw new NotImplementedException(); } @Override - public final void putLongs(int rowId, int count, long[] src, int srcIndex) { + public void putLongs(int rowId, int count, long[] src, int srcIndex) { throw new NotImplementedException(); } @Override - public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { throw new NotImplementedException(); } @Override - public final long getLong(int rowId) { + public long getLong(int rowId) { throw new NotImplementedException(); } @@ -278,27 +272,27 @@ public final long getLong(int rowId) { // @Override - public final void putFloat(int rowId, float value) { + public void putFloat(int rowId, float value) { throw new NotImplementedException(); } @Override - public final void putFloats(int rowId, int count, float value) { + public void putFloats(int rowId, int count, float value) { throw new NotImplementedException(); } @Override - public final void putFloats(int rowId, int count, float[] src, int srcIndex) { + public void putFloats(int rowId, int count, float[] src, int srcIndex) { throw new NotImplementedException(); } @Override - public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { + public void putFloats(int rowId, int count, byte[] src, int srcIndex) { throw new NotImplementedException(); } @Override - public final float getFloat(int rowId) { + public float getFloat(int rowId) { assert(dictionary == null); return Platform.getFloat(data, offset + rowId * 4); } @@ -308,27 +302,27 @@ public final float getFloat(int rowId) { // @Override - public final void putDouble(int rowId, double value) { + public void putDouble(int rowId, double value) { throw new NotImplementedException(); } @Override - public final void putDoubles(int rowId, int count, double value) { + public void putDoubles(int rowId, int count, double value) { throw new NotImplementedException(); } @Override - public final void putDoubles(int rowId, int count, double[] src, int srcIndex) { + public void putDoubles(int rowId, int count, double[] src, int srcIndex) { throw new NotImplementedException(); } @Override - public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { throw new NotImplementedException(); } @Override - public final double getDouble(int rowId) { + public double getDouble(int rowId) { assert(dictionary == null); return Platform.getDouble(data, offset + rowId * 8); } @@ -337,25 +331,70 @@ public final double getDouble(int rowId) { // APIs dealing with Arrays // + private void updateLastArrayPos(int rowId) { + int relative = rowId - lastArrayRow; + lastArrayRow = rowId; + + if (relative == 1) { + int totalBytesLastArray = Platform.getInt(data, offset + lastArrayPos); + lastArrayPos += totalBytesLastArray + 4; // 4 for totalbytes in UnsafeArrayData + } else if (relative == 0) { + // return the same position + return; + } else if (relative > 0) { + for (int i = 0; i < relative; i++) { + int totalBytesLastArray = Platform.getInt(data, offset + lastArrayPos); + lastArrayPos += totalBytesLastArray + 4; // 4 for totalbytes in UnsafeArrayData + } + } else { + // recalculate pos from the first Array entry + 
lastArrayPos = 0; + for (int i = 0; i < rowId; i++) { + int totalBytesLastArray = Platform.getInt(data, offset + lastArrayPos); + lastArrayPos += totalBytesLastArray + 4; // 4 for totalbytes in UnsafeArrayData + } + } + } + @Override - public final ArrayData getArray(int rowId) { - int length = getArrayLength(rowId); - unsafeArray.pointTo(data, offset + arrayOffsets[rowId] + 4, length); - return unsafeArray; + public ArrayData getArray(int rowId) { + if (rowId - lastArrayRow == 1) { + lastArrayRow = rowId; + long localOffset = offset; + int localLastArrayPos = lastArrayPos; + int totalBytesLastArray = Platform.getInt(data, localOffset + localLastArrayPos); + localLastArrayPos += totalBytesLastArray + 4; // 4 for totalbytes in UnsafeArrayData + int length = Platform.getInt(data, localOffset + localLastArrayPos); + unsafeArray.pointTo(data, localOffset + localLastArrayPos + 4, length); + lastArrayPos = localLastArrayPos; + return unsafeArray; + } else { + updateLastArrayPos(rowId); + int length = Platform.getInt(data, offset + lastArrayPos); // inline getArrayLength() + unsafeArray.pointTo(data, offset + lastArrayPos + 4, length); + return unsafeArray; + } } @Override - public final int getArrayLength(int rowId) { return Platform.getInt(data, offset + arrayOffsets[rowId]); } + public int getArrayLength(int rowId) { + updateLastArrayPos(rowId); + return Platform.getInt(data, offset + lastArrayPos); + } + @Override - public final int getArrayOffset(int rowId) { return arrayOffsets[rowId]; } + public int getArrayOffset(int rowId) { + updateLastArrayPos(rowId); + return lastArrayPos; + } @Override - public final void putArray(int rowId, int offset, int length) { + public void putArray(int rowId, int offset, int length) { throw new NotImplementedException(); } @Override - public final void loadBytes(ColumnVector.Array array) { + public void loadBytes(ColumnVector.Array array) { throw new NotImplementedException(); } @@ -364,12 +403,12 @@ public final void loadBytes(ColumnVector.Array array) { // @Override - public final int putByteArray(int rowId, byte[] value, int offset, int length) { + public int putByteArray(int rowId, byte[] value, int offset, int length) { throw new NotImplementedException(); } @Override - public final void reserve(int requiredCapacity) { + public void reserve(int requiredCapacity) { if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2); }