diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ByteBufferColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ByteBufferColumnVector.java new file mode 100644 index 000000000000..f5732debc96f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ByteBufferColumnVector.java @@ -0,0 +1,420 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.nio.ByteBuffer; + +import org.apache.commons.lang.NotImplementedException; + +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.execution.columnar.ByteBufferHelper; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; + +/** + * A column vector backed by an in-memory JVM byte array. + * This stores the NULLs as a byte per value and a Java byte array for the values. + * Currently, this column vector is read-only. + */ +public final class ByteBufferColumnVector extends ColumnVector { + // The data stored in these arrays needs to remain binary compatible. We can + // directly pass this buffer to external components. + + // This is faster than a boolean array and we optimize this over memory footprint. + private byte[] nulls; + + // Values stored in a byte array + private byte[] data; + private long offset; + + // Only set if type is Array. + private int lastArrayRow; + private int lastArrayPos; + private UnsafeArrayData unsafeArray; + + protected ByteBufferColumnVector(int capacity, DataType type, + boolean isConstant, ByteBuffer buffer, ByteBuffer nullsBuffer) { + super(capacity, type); + boolean containsNull = true; + if (this.resultArray != null) { + containsNull = ((ArrayType)type).containsNull(); + data = buffer.array(); + offset = Platform.BYTE_ARRAY_OFFSET + buffer.position(); + + unsafeArray = new UnsafeArrayData(); + lastArrayPos = 0; + lastArrayRow = 0; + } else if (DecimalType.isByteArrayDecimalType(type)) { + throw new NotImplementedException(); + } else if ((type instanceof FloatType) || (type instanceof DoubleType)) { + data = buffer.array(); + offset = Platform.BYTE_ARRAY_OFFSET + buffer.position(); + } else if (resultStruct != null) { + // Nothing to store.
+ } else { + throw new RuntimeException("Unhandled " + type); + } + if (containsNull) { + nulls = new byte[capacity]; + } + reset(); + + if (containsNull) { + int numNulls = ByteBufferHelper.getInt(nullsBuffer); + for (int i = 0; i < numNulls; i++) { + int ordinal = ByteBufferHelper.getInt(nullsBuffer); + putNull(ordinal); + } + } + if (isConstant) { + setIsConstant(); + } + } + + @Override + public long valuesNativeAddress() { + throw new RuntimeException("Cannot get native address for on heap column"); + } + @Override + public long nullsNativeAddress() { + throw new RuntimeException("Cannot get native address for on heap column"); + } + + @Override + public void close() { + } + + // + // APIs dealing with nulls + // + + @Override + public void putNotNull(int rowId) { + nulls[rowId] = (byte)0; + } + + @Override + public void putNull(int rowId) { + nulls[rowId] = (byte)1; + ++numNulls; + anyNullsSet = true; + } + + @Override + public void putNulls(int rowId, int count) { + for (int i = 0; i < count; ++i) { + nulls[rowId + i] = (byte)1; + } + anyNullsSet = true; + numNulls += count; + } + + @Override + public void putNotNulls(int rowId, int count) { + if (!anyNullsSet) return; + for (int i = 0; i < count; ++i) { + nulls[rowId + i] = (byte)0; + } + } + + @Override + public boolean isNullAt(int rowId) { + if (nulls == null) return false; + return nulls[rowId] == 1; + } + + // + // APIs dealing with Booleans + // + + @Override + public void putBoolean(int rowId, boolean value) { + throw new NotImplementedException(); + } + + @Override + public void putBooleans(int rowId, int count, boolean value) { + throw new NotImplementedException(); + } + + @Override + public boolean getBoolean(int rowId) { + assert(dictionary == null); + return Platform.getBoolean(data, offset + rowId); + } + + // + // APIs dealing with Bytes + // + + @Override + public void putByte(int rowId, byte value) { + throw new NotImplementedException(); + } + + @Override + public void putBytes(int rowId, int count, byte value) { + throw new NotImplementedException(); + } + + @Override + public void putBytes(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public byte getByte(int rowId) { + assert(dictionary == null); + return Platform.getByte(data, offset + rowId); + } + + // + // APIs dealing with Shorts + // + + @Override + public void putShort(int rowId, short value) { + throw new NotImplementedException(); + } + + @Override + public void putShorts(int rowId, int count, short value) { + throw new NotImplementedException(); + } + + @Override + public void putShorts(int rowId, int count, short[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public short getShort(int rowId) { + assert(dictionary == null); + return Platform.getShort(data, offset + rowId * 2); + } + + // + // APIs dealing with Ints + // + + @Override + public void putInt(int rowId, int value) { + throw new NotImplementedException(); + } + + @Override + public void putInts(int rowId, int count, int value) { + throw new NotImplementedException(); + } + + @Override + public void putInts(int rowId, int count, int[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public int getInt(int rowId) { + assert(dictionary == null); + return Platform.getInt(data, offset + rowId * 4); + } + + // + // APIs dealing
with Longs + // + + @Override + public void putLong(int rowId, long value) { + throw new NotImplementedException(); + } + + @Override + public void putLongs(int rowId, int count, long value) { + throw new NotImplementedException(); + } + + @Override + public void putLongs(int rowId, int count, long[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public long getLong(int rowId) { + throw new NotImplementedException(); + } + + // + // APIs dealing with floats + // + + @Override + public void putFloat(int rowId, float value) { + throw new NotImplementedException(); + } + + @Override + public void putFloats(int rowId, int count, float value) { + throw new NotImplementedException(); + } + + @Override + public void putFloats(int rowId, int count, float[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public void putFloats(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public float getFloat(int rowId) { + assert(dictionary == null); + return Platform.getFloat(data, offset + rowId * 4); + } + + // + // APIs dealing with doubles + // + + @Override + public void putDouble(int rowId, double value) { + throw new NotImplementedException(); + } + + @Override + public void putDoubles(int rowId, int count, double value) { + throw new NotImplementedException(); + } + + @Override + public void putDoubles(int rowId, int count, double[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public double getDouble(int rowId) { + assert(dictionary == null); + return Platform.getDouble(data, offset + rowId * 8); + } + + // + // APIs dealing with Arrays + // + + private void updateLastArrayPos(int rowId) { + int relative = rowId - lastArrayRow; + lastArrayRow = rowId; + + if (relative == 1) { + int totalBytesLastArray = Platform.getInt(data, offset + lastArrayPos); + lastArrayPos += totalBytesLastArray + 4; // 4 for totalbytes in UnsafeArrayData + } else if (relative == 0) { + // return the same position + return; + } else if (relative > 0) { + for (int i = 0; i < relative; i++) { + int totalBytesLastArray = Platform.getInt(data, offset + lastArrayPos); + lastArrayPos += totalBytesLastArray + 4; // 4 for totalbytes in UnsafeArrayData + } + } else { + // recalculate pos from the first Array entry + lastArrayPos = 0; + for (int i = 0; i < rowId; i++) { + int totalBytesLastArray = Platform.getInt(data, offset + lastArrayPos); + lastArrayPos += totalBytesLastArray + 4; // 4 for totalbytes in UnsafeArrayData + } + } + } + + @Override + public ArrayData getArray(int rowId) { + if (rowId - lastArrayRow == 1) { + lastArrayRow = rowId; + long localOffset = offset; + int localLastArrayPos = lastArrayPos; + int totalBytesLastArray = Platform.getInt(data, localOffset + localLastArrayPos); + localLastArrayPos += totalBytesLastArray + 4; // 4 for totalbytes in UnsafeArrayData + int length = Platform.getInt(data, localOffset + localLastArrayPos); + unsafeArray.pointTo(data, localOffset + localLastArrayPos + 4, length); + lastArrayPos = localLastArrayPos; + return unsafeArray; + } else { + updateLastArrayPos(rowId); + int length = Platform.getInt(data, offset + lastArrayPos); // inline getArrayLength() + 
unsafeArray.pointTo(data, offset + lastArrayPos + 4, length); + return unsafeArray; + } + } + + @Override + public int getArrayLength(int rowId) { + updateLastArrayPos(rowId); + return Platform.getInt(data, offset + lastArrayPos); + } + + @Override + public int getArrayOffset(int rowId) { + updateLastArrayPos(rowId); + return lastArrayPos; + } + + @Override + public void putArray(int rowId, int offset, int length) { + throw new NotImplementedException(); + } + + @Override + public void loadBytes(ColumnVector.Array array) { + throw new NotImplementedException(); + } + + // + // APIs dealing with Byte Arrays + // + + @Override + public int putByteArray(int rowId, byte[] value, int offset, int length) { + throw new NotImplementedException(); + } + + @Override + public void reserve(int requiredCapacity) { + if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2); + } + + // Split this function out since it is the slow path. + @Override + protected void reserveInternal(int newCapacity) { + throw new NotImplementedException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index bbbb796aca0d..16e877c8e12a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -19,6 +19,8 @@ import java.math.BigDecimal; import java.math.BigInteger; +import org.apache.commons.lang.NotImplementedException; + import com.google.common.annotations.VisibleForTesting; import org.apache.parquet.column.Dictionary; import org.apache.parquet.io.api.Binary; @@ -529,7 +531,7 @@ public ColumnarBatch.Row getStruct(int rowId, int size) { /** * Returns the array at rowid. */ - public final Array getArray(int rowId) { + public ArrayData getArray(int rowId) { resultArray.length = getArrayLength(rowId); resultArray.offset = getArrayOffset(rowId); return resultArray; @@ -552,7 +554,7 @@ public final int putByteArray(int rowId, byte[] value) { * Returns the value for rowId.
*/ private Array getByteArray(int rowId) { - Array array = getArray(rowId); + Array array = (Array)getArray(rowId); array.data.loadBytes(array); return array; } @@ -1002,4 +1004,24 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { this.resultStruct = null; } } + + protected ColumnVector(int capacity, DataType type) { + this.capacity = capacity; + this.type = type; + + if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType + || DecimalType.isByteArrayDecimalType(type)) { + this.childColumns = null; + this.resultArray = new Array(null); + this.resultStruct = null; + } else if (type instanceof StructType) { + throw new NotImplementedException(); + } else if (type instanceof CalendarIntervalType) { + throw new NotImplementedException(); + } else { + this.childColumns = null; + this.resultArray = null; + this.resultStruct = null; + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 7a14879b8b9d..1378ea6819a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -95,6 +95,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case NullType | StringType | BinaryType => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + case ArrayType(_, _) => + _isSupportColumnarCodeGen = true + s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), + (${dt.getClass.getName}) columnTypes[$index]);""" case other => s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), (${dt.getClass.getName}) columnTypes[$index]);""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 079e122a5a85..7df04a6a15ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -53,7 +53,31 @@ private[sql] object InMemoryRelation { * @param stats The stat of columns */ private[columnar] -case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) +case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) { + def column(columnarIterator: ColumnarIterator, index: Int): ColumnVector = { + val ordinal = columnarIterator.getColumnIndexes(index) + val dataType = columnarIterator.getColumnTypes(index) + val buffer = ByteBuffer.wrap(buffers(ordinal)).order(nativeOrder) + val accessor: BasicColumnAccessor[_] = dataType match { + case FloatType => new FloatColumnAccessor(buffer) + case DoubleType => new DoubleColumnAccessor(buffer) + case arrayType: ArrayType => new ArrayColumnAccessor(buffer, arrayType) + case _ => throw new UnsupportedOperationException(s"CachedBatch.column(): $dataType") + } + + val (out, nullsBuffer) = if (accessor.isInstanceOf[NativeColumnAccessor[_]]) { + val nativeAccessor = accessor.asInstanceOf[NativeColumnAccessor[_]] + nativeAccessor.decompress(numRows); + } else { + val buffer = 
accessor.getByteBuffer + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + (buffer, nullsBuffer) + } + + ColumnVector.allocate(numRows, dataType, true, out, nullsBuffer) + } +} private[sql] case class InMemoryRelation( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 67a410f539b6..d0e6696a3081 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -180,3 +180,88 @@ private[sql] case class InMemoryTableScanExec( } } } + +private[sql] object InMemoryTableScanExec { + private val columnarItrName = "columnar_itr" + private val columnarBatchIdxName = "columnar_batchIdx" + + def enableColumnCodeGen( + sqlContext: SQLContext, ctx: CodegenContext, child: SparkPlan): Boolean = { + ctx.enableColumnCodeGen && + sqlContext.getConf(SQLConf.COLUMN_VECTOR_CODEGEN.key).toBoolean && + child.find(c => c.isInstanceOf[InMemoryTableScanExec]).isDefined && + child.find(c => c.isInstanceOf[CodegenSupport] && + c.asInstanceOf[CodegenSupport].useUnsafeRow).isEmpty + } + + def produceColumnLoop( + ctx: CodegenContext, codegen: CodegenSupport, output: Seq[Attribute]): String = { + val idx = columnarBatchIdxName + val numRows = "columnar_numRows" + ctx.addMutableState("int", idx, s"$idx = 0;") + ctx.addMutableState("int", numRows, s"$numRows = 0;") + val rowidx = ctx.freshName("rowIdx") + + val colVars = output.indices.map(i => ctx.freshName("col" + i)) + val columnAssigns = colVars.zipWithIndex.map { case (name, i) => + ctx.addMutableState("org.apache.spark.sql.execution.vectorized.ColumnVector", + name, s"$name = null;", s"$name = null;") + s"$name = ${columnarItrName}.getColumn($i);" + } + val columns = (output zip colVars).map { case (attr, colVar) => + new ColumnVectorReference(colVar, rowidx, attr.dataType, attr.nullable).genCode(ctx) } + + s""" + | while (true) { + | if ($idx == 0) { + | $numRows = ${columnarItrName}.initForColumnar(); + | if ($numRows < 0) { + | cleanup(); + | break; + | } + | ${columnAssigns.mkString("", "\n", "")} + | } + | + | while ($idx < $numRows) { + | int $rowidx = $idx++; + | ${codegen.consume(ctx, columns, null).trim} + | if (shouldStop()) return; + | } + | $idx = 0; + | } + """.stripMargin + } + + def produceProcessNext( + ctx: CodegenContext, codegen: CodegenSupport, child: SparkPlan, codeRow: String): String = { + ctx.isRow = false + val codeCol = child.asInstanceOf[CodegenSupport].produce(ctx, codegen) + val columnarItrClz = "org.apache.spark.sql.execution.columnar.ColumnarIterator" + val colItr = columnarItrName + ctx.addMutableState(s"$columnarItrClz", colItr, s"$colItr = null;", s"$colItr = null;") + + s""" + private void processBatch() throws java.io.IOException { + ${codeCol.trim} + } + + private void processRow() throws java.io.IOException { + ${codeRow.trim} + } + + private void cleanup() { + ${ctx.cleanupMutableStates()} + } + + protected void processNext() throws java.io.IOException { + if ((${columnarBatchIdxName} != 0) || + (${ctx.iteratorInput} instanceof $columnarItrClz && + ($colItr = ($columnarItrClz)${ctx.iteratorInput}).isSupportColumnarCodeGen())) { + processBatch(); + } else { + processRow(); + } + } + """.trim + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataFrameCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataFrameCacheSuite.scala new file mode 100644 index 000000000000..62b56bf88218 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataFrameCacheSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class DataFrameCacheSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("range/filter should be combined with column codegen") { + val df = sparkContext.parallelize(0 to 9, 1).map(i => i.toFloat).toDF().cache() + .filter("value = 1").selectExpr("value + 1") + assert(df.collect() === Array(Row(2.0))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen).isDefined) + } + + test("filters should be combined with column codegen") { + val df = sparkContext.parallelize(0 to 9, 1).map(i => i.toFloat).toDF().cache() + .filter("value % 2.0 == 0").filter("value % 3.0 == 0") + assert(df.collect() === Array(Row(0), Row(6.0))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen).isDefined) + } + + test("filter with null should be included in WholeStageCodegen with column codegen") { + val toFloat = udf[java.lang.Float, String] { s => if (s == "2") null else s.toFloat } + val df0 = sparkContext.parallelize(0 to 4, 1).map(i => i.toString).toDF() + val df = df0.withColumn("i", toFloat(df0("value"))).select("i").toDF().cache() + .filter("i % 2.0 == 0") + assert(df.collect() === Array(Row(0), Row(4.0))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen).isDefined) + } + + test("Aggregate should be included in WholeStageCodegen with column codegen") { + val df = sparkContext.parallelize(0 to 9, 1).map(i => i.toFloat).toDF().cache() + .groupBy().agg(max(col("value")), avg(col("value"))) + assert(df.collect() === Array(Row(9, 4.5))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) + } + + 
test("Aggregate with grouping keys should be included in WholeStageCodegen with column codegen") { + val df = sparkContext.parallelize(0 to 2, 1).map(i => i.toFloat).toDF().cache() + .groupBy("value").count().orderBy("value") + assert(df.collect() === Array(Row(0.0, 1), Row(1.0, 1), Row(2.0, 1))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) + } + + test("Aggregate with columns should be included in WholeStageCodegen with column codegen") { + val df = sparkContext.parallelize(0 to 10, 1).map(i => (i, (i * 2).toDouble)).toDF("i", "d") + .cache().agg(sum("d")) + assert(df.collect() === Array(Row(110.0))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) + } + + test("Sort should be included in WholeStageCodegen without column codegen") { + val df = sparkContext.parallelize(Seq(3.toFloat, 2.toFloat, 1.toFloat), 1).toDF() + .sort(col("value")) + val plan = df.queryExecution.executedPlan + assert(df.collect() === Array(Row(1.0), Row(2.0), Row(3.0))) + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + !p.asInstanceOf[WholeStageCodegenExec].enableColumnCodeGen && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined) + } + + test("filter/selectExpr should be combined with column codegen for int array") { + val df = sparkContext.parallelize(0 to 9, 1).map(i => Array(i, i + 1)).toDF().cache() + .filter("value[0] > 4").selectExpr("value[1] + 1") + assert(df.collect() === Array(Row(7), Row(8), Row(9), Row(10), Row(11))) + } + + test("filter/selectExpr should be combined with column codegen for double array") { + val df = sparkContext.parallelize(0 to 9, 1).map(i => Array(i.toDouble, (i + 1).toDouble)) + .toDF().cache().filter("value[0] > 4.0").selectExpr("value[1] + 1") + assert(df.collect() === Array(Row(7.0), Row(8.0), Row(9.0), Row(10.0), Row(11.0))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 100cc4daca87..ebcc33713bb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -401,26 +401,30 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(2, 2, 0) column.putArray(3, 3, 3) - val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] - val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]] - val a3 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(2)).asInstanceOf[Array[Int]] - val a4 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(3)).asInstanceOf[Array[Int]] + val a1 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(0).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a2 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(1).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a3 = ColumnVectorUtils.toPrimitiveJavaArray( + 
column.getArray(2).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a4 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(3).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] assert(a1 === Array(0)) assert(a2 === Array(1, 2)) assert(a3 === Array.empty[Int]) assert(a4 === Array(3, 4, 5)) // Verify the ArrayData APIs - assert(column.getArray(0).length == 1) + assert(column.getArray(0).numElements() == 1) assert(column.getArray(0).getInt(0) == 0) - assert(column.getArray(1).length == 2) + assert(column.getArray(1).numElements() == 2) assert(column.getArray(1).getInt(0) == 1) assert(column.getArray(1).getInt(1) == 2) - assert(column.getArray(2).length == 0) + assert(column.getArray(2).numElements() == 0) - assert(column.getArray(3).length == 3) + assert(column.getArray(3).numElements() == 3) assert(column.getArray(3).getInt(0) == 3) assert(column.getArray(3).getInt(1) == 4) assert(column.getArray(3).getInt(2) == 5) @@ -433,7 +437,8 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) column.putArray(0, 0, array.length) - assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] + assert(ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(0).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] === array) }} }
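Note on the ArrayType read path (illustrative, not part of the patch): ByteBufferColumnVector.getArray and updateLastArrayPos assume that each row of an ArrayType column is laid out in the cached buffer as a 4-byte length prefix followed by the UnsafeArrayData payload of exactly that many bytes, so sequential access only needs to advance the cached position by length + 4 per row, while access to an earlier row rescans from the start of the buffer (which is why lastArrayRow and lastArrayPos are cached and the rowId == lastArrayRow + 1 case is the fast path). A minimal Scala sketch of that walk, using a hypothetical helper name (arrayStartOffsets) that is not part of the patch:

import java.nio.{ByteBuffer, ByteOrder}

// Walks the per-row layout assumed by ByteBufferColumnVector for ArrayType columns:
// [4-byte payload length][UnsafeArrayData bytes], repeated once per row.
// Advancing by payloadBytes + 4 mirrors what updateLastArrayPos does.
def arrayStartOffsets(buffer: ByteBuffer, numRows: Int): Seq[Int] = {
  val buf = buffer.duplicate().order(ByteOrder.nativeOrder())
  var pos = buf.position()
  (0 until numRows).map { _ =>
    val payloadBytes = buf.getInt(pos)  // length prefix for this row's UnsafeArrayData
    val start = pos + 4                 // payload begins right after the prefix
    pos += payloadBytes + 4             // skip to the next row's prefix
    start
  }
}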