diff --git a/arrow-data-source/common/src/main/java/com/intel/oap/vectorized/ArrowWritableColumnVector.java b/arrow-data-source/common/src/main/java/com/intel/oap/vectorized/ArrowWritableColumnVector.java index b89e74fb6..acc59e6d7 100644 --- a/arrow-data-source/common/src/main/java/com/intel/oap/vectorized/ArrowWritableColumnVector.java +++ b/arrow-data-source/common/src/main/java/com/intel/oap/vectorized/ArrowWritableColumnVector.java @@ -1571,6 +1571,12 @@ final void setLongs(int rowId, int count, byte[] src, int srcIndex) { } } + @Override + final void setDouble(int rowId, double value) { + long val = (long)value; + writer.setSafe(rowId, val); + } + @Override void setLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarExpandExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarExpandExec.scala index 04801d040..55f7eb664 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarExpandExec.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarExpandExec.scala @@ -90,19 +90,20 @@ case class ColumnarExpandExec( private[this] val numGroups = columnarGroups.length private[this] val resultStructType = StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + private[this] var input_cb: ColumnarBatch = _ override def hasNext: Boolean = (-1 < idx && idx < numGroups) || iter.hasNext override def next(): ColumnarBatch = { if (idx <= 0) { // in the initial (-1) or beginning(0) of a new input row, fetch the next input tuple - val input_cb = iter.next() - input = (0 until input_cb.numCols).toList - .map(input_cb.column(_).asInstanceOf[ArrowWritableColumnVector].getValueVector) + input_cb = iter.next() numRows = input_cb.numRows numInputBatches += 1 idx = 0 } + input = columnarGroups(idx).ordinalList + .map(input_cb.column(_).asInstanceOf[ArrowWritableColumnVector].getValueVector) if (numRows == 0) { idx = -1 diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala index 61f71b859..2e724af39 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala @@ -107,6 +107,11 @@ case class ColumnarHashAggregateExec( buildCheck() + val onlyResultExpressions: Boolean = + if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty && + child.output.isEmpty && resultExpressions.nonEmpty) true + else false + override def doExecuteColumnar(): RDD[ColumnarBatch] = { var eval_elapse: Long = 0 child.executeColumnar().mapPartitions { iter => @@ -138,10 +143,16 @@ case class ColumnarHashAggregateExec( } var numRowsInput = 0 + var hasNextCount = 0 // now we can return this wholestagecodegen iter val res = new Iterator[ColumnarBatch] { var processed = false + /** Three special cases need to be handled in scala side: + * (1) count_literal (2) only result expressions (3) empty input + */ var skip_native = false + var onlyResExpr = false + var emptyInput = false var count_num_row = 0 def process: Unit = { while (iter.hasNext) { @@ -150,7 +161,9 @@ case class ColumnarHashAggregateExec( if (cb.numRows != 0) { numRowsInput += cb.numRows val beforeEval = System.nanoTime() - if 
(hash_aggr_input_schema.getFields.size == 0) { + if (hash_aggr_input_schema.getFields.size == 0 && + aggregateExpressions.nonEmpty && + aggregateExpressions.head.aggregateFunction.isInstanceOf[Count]) { // This is a special case used by only do count literal count_num_row += cb.numRows skip_native = true @@ -166,9 +179,17 @@ case class ColumnarHashAggregateExec( processed = true } override def hasNext: Boolean = { + hasNextCount += 1 if (!processed) process if (skip_native) { count_num_row > 0 + } else if (onlyResultExpressions && hasNextCount == 1) { + onlyResExpr = true + true + } else if (!onlyResultExpressions && groupingExpressions.isEmpty && + numRowsInput == 0 && hasNextCount == 1) { + emptyInput = true + true } else { nativeIterator.hasNext } @@ -179,28 +200,19 @@ case class ColumnarHashAggregateExec( val beforeEval = System.nanoTime() if (skip_native) { // special handling for only count literal in this operator - val out_res = count_num_row - count_num_row = 0 - val resultColumnVectors = - ArrowWritableColumnVector.allocateColumns(0, resultStructType).toArray - resultColumnVectors.foreach { v => - { - val numRows = v.dataType match { - case t: IntegerType => - out_res.asInstanceOf[Number].intValue - case t: LongType => - out_res.asInstanceOf[Number].longValue - } - v.put(0, numRows) - } - } - return new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 1) + getResForCountLiteral + } else if (onlyResExpr) { + // special handling for only result expressions + getResForOnlyResExpr + } else if (emptyInput) { + // special handling for empty input batch + getResForEmptyInput } else { val output_rb = nativeIterator.next if (output_rb == null) { eval_elapse += System.nanoTime() - beforeEval val resultColumnVectors = - ArrowWritableColumnVector.allocateColumns(0, resultStructType).toArray + ArrowWritableColumnVector.allocateColumns(0, resultStructType) return new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0) } val outputNumRows = output_rb.getLength @@ -212,6 +224,123 @@ case class ColumnarHashAggregateExec( new ColumnarBatch(output.map(v => v.asInstanceOf[ColumnVector]), outputNumRows) } } + def getResForCountLiteral: ColumnarBatch = { + val resultColumnVectors = + ArrowWritableColumnVector.allocateColumns(0, resultStructType) + if (count_num_row == 0) { + new ColumnarBatch( + resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0) + } else { + val out_res = count_num_row + count_num_row = 0 + for (idx <- resultColumnVectors.indices) { + resultColumnVectors(idx).dataType match { + case t: IntegerType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].intValue) + case t: LongType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].longValue) + case t: DoubleType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].doubleValue()) + case t: FloatType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].floatValue()) + case t: ByteType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].byteValue()) + case t: ShortType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].shortValue()) + case t: StringType => + val values = (out_res :: Nil).map(_.toByte).toArray + resultColumnVectors(idx) + .putBytes(0, 1, values, 0) + } + } + new ColumnarBatch( + resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 1) + } + } + def getResForOnlyResExpr: ColumnarBatch = { + // This function has limited support for only-result-expression case. 
+ // Fake input for projection: + val inputColumnVectors = + ArrowWritableColumnVector.allocateColumns(0, resultStructType) + val valueVectors = + inputColumnVectors.map(columnVector => columnVector.getValueVector).toList + val projector = ColumnarProjection.create(child.output, resultExpressions) + val resultColumnVectorList = projector.evaluate(1, valueVectors) + new ColumnarBatch( + resultColumnVectorList.map(v => v.asInstanceOf[ColumnVector]).toArray, + 1) + } + def getResForEmptyInput: ColumnarBatch = { + val resultColumnVectors = + ArrowWritableColumnVector.allocateColumns(0, resultStructType) + if (aggregateExpressions.isEmpty) { + // To align with Spark, one empty row is returned in this case. + return new ColumnarBatch( + resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 1) + } + // If no group-by is required, a default value is returned for Final mode + // when the input is empty. + var idx = 0 + for (expr <- aggregateExpressions) { + expr.aggregateFunction match { + case Average(_) | StddevSamp(_) | Sum(_) | Max(_) | Min(_) => + expr.mode match { + case Final => + resultColumnVectors(idx).putNull(0) + idx += 1 + case _ => + } + case Count(_) => + expr.mode match { + case Final => + val out_res = 0 + resultColumnVectors(idx).dataType match { + case t: IntegerType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].intValue) + case t: LongType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].longValue) + case t: DoubleType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].doubleValue()) + case t: FloatType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].floatValue()) + case t: ByteType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].byteValue()) + case t: ShortType => + resultColumnVectors(idx) + .put(0, out_res.asInstanceOf[Number].shortValue()) + case t: StringType => + val values = (out_res :: Nil).map(_.toByte).toArray + resultColumnVectors(idx) + .putBytes(0, 1, values, 0) + } + idx += 1 + case _ => + } + case other => + throw new UnsupportedOperationException(s"not currently supported: $other.") + } + } + // Only put a default value for Final mode. + aggregateExpressions.head.mode match { + case Final => + new ColumnarBatch( + resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 1) + case _ => + new ColumnarBatch( + resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0) + } + } } SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { close diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala index 91639ebd6..3295be474 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala @@ -344,6 +344,29 @@ class ColumnarHashAggregation( aggregateAttr.toList } + def existsAttrNotFound(allAggregateResultAttributes: List[Attribute]): Unit = { + if (resultExpressions.size == allAggregateResultAttributes.size) { + var resAllAttr = true + breakable { + for (expr <- resultExpressions) { + if (!expr.isInstanceOf[AttributeReference]) { + resAllAttr = false + break + } + } + } + if (resAllAttr) { + for (attr <- resultExpressions) { + if (allAggregateResultAttributes + .indexOf(attr.asInstanceOf[AttributeReference]) == -1) { + throw new IllegalArgumentException( + s"$attr in
resultExpressions is not found in allAggregateResultAttributes!") + } + } + } + } + } + def prepareKernelFunction: TreeNode = { // build gandiva projection here. ColumnarPluginConfig.getConf @@ -420,6 +443,11 @@ class ColumnarHashAggregation( s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType)) }) + + // If some Attributes in result expressions (contain attributes only) are not found + // in allAggregateResultAttributes, an exception will be thrown. + existsAttrNotFound(allAggregateResultAttributes) + val nativeFuncNodes = groupingNativeFuncNodes ::: aggrNativeFuncNodes // 4. prepare after aggregate result expressions diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala index 413a4a0e7..4fe641e61 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala @@ -623,6 +623,31 @@ class ColumnarMakeDecimal( } } +class ColumnarNormalizeNaNAndZero(child: Expression, original: NormalizeNaNAndZero) + extends NormalizeNaNAndZero(child: Expression) + with ColumnarExpression + with Logging { + + buildCheck() + + def buildCheck(): Unit = { + val supportedTypes = List(FloatType, DoubleType) + if (supportedTypes.indexOf(child.dataType) == -1) { + throw new UnsupportedOperationException( + s"${child.dataType} is not supported in ColumnarNormalizeNaNAndZero") + } + } + + override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = { + val (child_node, childType): (TreeNode, ArrowType) = + child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + + val normalizeNode = TreeBuilder.makeFunction( + "normalize", Lists.newArrayList(child_node), childType) + (normalizeNode, childType) + } +} + object ColumnarUnaryOperator { def create(child: Expression, original: Expression): Expression = original match { @@ -652,8 +677,8 @@ object ColumnarUnaryOperator { new ColumnarBitwiseNot(child, n) case a: KnownFloatingPointNormalized => child - case a: NormalizeNaNAndZero => - child + case n: NormalizeNaNAndZero => + new ColumnarNormalizeNaNAndZero(child, n) case a: PromotePrecision => child case a: CheckOverflow => diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 597877ed1..6b785fff3 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -742,7 +742,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils } } - ignore("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") { + test("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") { withTable("t") { withTempPath { path => Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath) @@ -824,7 +824,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils } } - ignore("SPARK-19993 subquery with cached underlying relation") { + test("SPARK-19993 subquery with cached underlying relation") { withTempView("t1") { Seq(1).toDF("c1").createOrReplaceTempView("t1") spark.catalog.cacheTable("t1") @@ -1029,7 +1029,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils SHUFFLE_HASH) } - 
ignore("analyzes column statistics in cached query") { + test("analyzes column statistics in cached query") { def query(): DataFrame = { spark.range(100) .selectExpr("id % 3 AS c0", "id % 5 AS c1", "2 AS c2") diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index f3b1d7fe3..f094e48d0 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -744,26 +744,28 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } } - ignore("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") { - withTempPath { dir => - val data = sparkContext.parallelize(0 to 10).toDF("id") - data.write.parquet(dir.getCanonicalPath) - - // Test the 3 expressions when reading from files - val q = spark.read.parquet(dir.getCanonicalPath).select( - input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()")) - val firstRow = q.head() - assert(firstRow.getString(0).contains(dir.toURI.getPath)) - assert(firstRow.getLong(1) == 0) - assert(firstRow.getLong(2) > 0) - - // Now read directly from the original RDD without going through any files to make sure - // we are returning empty string, -1, and -1. - checkAnswer( - data.select( - input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()") - ).limit(1), - Row("", -1L, -1L)) + test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") { + withSQLConf(("spark.oap.sql.columnar.batchscan", "true")) { + withTempPath { dir => + val data = sparkContext.parallelize(0 to 10).toDF("id") + data.write.parquet(dir.getCanonicalPath) + + // Test the 3 expressions when reading from files + val q = spark.read.parquet(dir.getCanonicalPath).select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()")) + val firstRow = q.head() + assert(firstRow.getString(0).contains(dir.toURI.getPath)) + assert(firstRow.getLong(1) == 0) + assert(firstRow.getLong(2) > 0) + + // Now read directly from the original RDD without going through any files to make sure + // we are returning empty string, -1, and -1. 
+ checkAnswer( + data.select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()") + ).limit(1), + Row("", -1L, -1L)) + } } } diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala index 2eb959a4c..1943ca80a 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -80,7 +80,7 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession { checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) } - ignore("named_struct is used in the top Project") { + test("named_struct is used in the top Project") { val df = spark.table("tab").selectExpr( "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") .selectExpr("col1.a", "col1") diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 7d91919d8..0448e2cdd 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -149,7 +149,7 @@ class DataFrameAggregateSuite extends QueryTest ) } - ignore("cube") { + test("cube") { checkAnswer( courseSales.cube("course", "year").sum("earnings"), Row("Java", 2012, 20000.0) :: @@ -173,7 +173,7 @@ class DataFrameAggregateSuite extends QueryTest assert(cube0.where("date IS NULL").count > 0) } - ignore("grouping and grouping_id") { + test("grouping and grouping_id") { checkAnswer( courseSales.cube("course", "year") .agg(grouping("course"), grouping("year"), grouping_id("course", "year")), @@ -211,7 +211,7 @@ class DataFrameAggregateSuite extends QueryTest } } - ignore("grouping/grouping_id inside window function") { + test("grouping/grouping_id inside window function") { val w = Window.orderBy(sum("earnings")) checkAnswer( @@ -231,7 +231,7 @@ class DataFrameAggregateSuite extends QueryTest ) } - ignore("SPARK-21980: References in grouping functions should be indexed with semanticEquals") { + test("SPARK-21980: References in grouping functions should be indexed with semanticEquals") { checkAnswer( courseSales.cube("course", "year") .agg(grouping("CouRse"), grouping("year")), @@ -302,7 +302,7 @@ class DataFrameAggregateSuite extends QueryTest ) } - ignore("agg without groups and functions") { + test("agg without groups and functions") { checkAnswer( testData2.agg(lit(1)), Row(1) @@ -350,7 +350,7 @@ class DataFrameAggregateSuite extends QueryTest Row(2.0, 2.0)) } - ignore("zero average") { + test("zero average") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(avg($"a")), @@ -369,7 +369,7 @@ class DataFrameAggregateSuite extends QueryTest Row(6, 6.0)) } - ignore("null count") { + test("null count") { checkAnswer( testData3.groupBy($"a").agg(count($"b")), Seq(Row(1, 0), Row(2, 1)) @@ -392,7 +392,7 @@ class DataFrameAggregateSuite extends QueryTest ) } - ignore("multiple column distinct count") { + test("multiple column distinct count") { val df1 = Seq( ("a", "b", "c"), ("a", "b", "c"), @@ -417,7 +417,7 @@ class DataFrameAggregateSuite extends QueryTest ) } - ignore("zero count") { + test("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") 
checkAnswer( emptyTableData.agg(count($"a"), sumDistinct($"a")), // non-partial @@ -441,14 +441,14 @@ class DataFrameAggregateSuite extends QueryTest Row(null, null, null)) } - ignore("zero sum") { + test("zero sum") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(sum($"a")), Row(null)) } - ignore("zero sum distinct") { + test("zero sum distinct") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(sumDistinct($"a")), @@ -593,7 +593,7 @@ class DataFrameAggregateSuite extends QueryTest Seq(Row(Seq(1.0, 2.0)))) } - ignore("SPARK-14664: Decimal sum/avg over window should work.") { + test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil) @@ -732,8 +732,6 @@ class DataFrameAggregateSuite extends QueryTest } } - //TODO: failed ut - /* testWithWholeStageCodegenOnAndOff("SPARK-22951: dropDuplicates on empty dataFrames " + "should produce correct aggregate") { _ => // explicit global aggregations @@ -748,7 +746,6 @@ class DataFrameAggregateSuite extends QueryTest // global aggregation is converted to grouping aggregation: assert(spark.emptyDataFrame.dropDuplicates().count() == 0) } - */ test("SPARK-21896: Window functions inside aggregate functions") { def checkWindowError(df: => DataFrame): Unit = { @@ -790,7 +787,7 @@ class DataFrameAggregateSuite extends QueryTest "type: GroupBy]")) } - ignore("SPARK-26021: NaN and -0.0 in grouping expressions") { + test("SPARK-26021: NaN and -0.0 in grouping expressions") { checkAnswer( Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f").groupBy("f").count(), Row(0.0f, 2) :: Row(Float.NaN, 2) :: Nil) @@ -954,7 +951,7 @@ class DataFrameAggregateSuite extends QueryTest } } - ignore("count_if") { + test("count_if") { withTempView("tempView") { Seq(("a", None), ("a", Some(1)), ("a", Some(2)), ("a", Some(3)), ("b", None), ("b", Some(4)), ("b", Some(5)), ("b", Some(6))) diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 7aa19abbd..d97649f09 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -73,7 +73,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { spark.sparkContext.parallelize(data), schema) } - ignore("drop") { + test("drop") { val input = createDF() val rows = input.collect() diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index c45eea7db..595a2c221 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -47,7 +47,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSparkSession with Eventua .set("spark.oap.sql.columnar.preferColumnar", "true") .set("spark.oap.sql.columnar.sortmergejoin", "true") - ignore("SPARK-7150 range api") { + test("SPARK-7150 range api") { // numSlice is greater than length val res1 = spark.range(0, 10, 1, 15).select("id") assert(res1.count == 10) diff --git 
a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a4b265365..77b9164c0 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -112,7 +112,7 @@ class DataFrameSuite extends QueryTest testData.collect().toSeq) } - ignore("empty data frame") { + test("empty data frame") { assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(spark.emptyDataFrame.count() === 0) } diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 665f52deb..42a1a2856 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -134,7 +134,7 @@ class DatasetSuite extends QueryTest 1, 1, 1) } - ignore("emptyDataset") { + test("emptyDataset") { val ds = spark.emptyDataset[Int] assert(ds.count() == 0L) assert(ds.collect() sameElements Array.empty[Int]) @@ -1542,7 +1542,7 @@ class DatasetSuite extends QueryTest checkDataset(ds, WithMapInOption(Some(Map(1 -> 1)))) } - ignore("SPARK-20399: do not unescaped regex pattern when ESCAPED_STRING_LITERALS is enabled") { + test("SPARK-20399: do not unescaped regex pattern when ESCAPED_STRING_LITERALS is enabled") { withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") { val data = Seq("\u0020\u0021\u0023", "abc") val df = data.toDF() diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 04ece5e78..77e74f2de 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3192,7 +3192,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } - ignore("reset command should not fail with cache") { + test("reset command should not fail with cache") { withTable("tbl") { val provider = spark.sessionState.conf.defaultDataSourceName sql(s"CREATE TABLE tbl(i INT, j STRING) USING $provider") diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index dcc628223..accd9996f 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -113,7 +113,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } - ignore("analyze empty table") { + test("analyze empty table") { val table = "emptyTable" withTable(table) { val df = Seq.empty[Int].toDF("key") @@ -413,7 +413,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } - ignore("invalidation of tableRelationCache after alter table add partition") { + test("invalidation of tableRelationCache after alter table add partition") { val table = "invalidate_catalog_cache_table" Seq(false, true).foreach { autoUpdate => withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { diff --git 
a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 1f9ac3887..7d33ef6b5 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -980,7 +980,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark assert(optimizedPlan.resolved) } - ignore("SPARK-23316: AnalysisException after max iteration reached for IN query") { + test("SPARK-23316: AnalysisException after max iteration reached for IN query") { // before the fix this would throw AnalysisException spark.range(10).where("(id,id) in (select id, null from range(3))").count } diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index de5bb259f..edf6e5a76 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -151,7 +151,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } } - ignore("Ignore mode if table exists - session catalog") { + test("Ignore mode if table exists - session catalog") { sql(s"create table t1 (id bigint) using $format") val df = spark.range(10).withColumn("part", 'id % 5) val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") @@ -163,7 +163,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with assert(load("t1", None).count() === 0) } - ignore("Ignore mode if table exists - testcat catalog") { + test("Ignore mode if table exists - testcat catalog") { sql(s"create table $catalogName.t1 (id bigint) using $format") val df = spark.range(10).withColumn("part", 'id % 5) val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 09af5997d..e3938220b 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -154,7 +154,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession assert(dsStringFilter.collect() === Array("1")) } - ignore("SPARK-19512 codegen for comparing structs is incorrect") { + test("SPARK-19512 codegen for comparing structs is incorrect") { // this would raise CompileException before the fix spark.range(10) .selectExpr("named_struct('a', id) as col1", "named_struct('a', id+2) as col2") diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 7c4d39c0e..264a89a71 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -478,7 +478,7 @@ class 
InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession { } } - ignore("SPARK-22249: IN should work also with cached DataFrame") { + test("SPARK-22249: IN should work also with cached DataFrame") { val df = spark.range(10).cache() // with an empty list assert(df.filter($"id".isin()).count() == 0) diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index ecadd0b02..402756913 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2244,7 +2244,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - ignore("show functions") { + test("show functions") { withUserDefinedFunction("add_one" -> true) { val numFunctions = FunctionRegistry.functionSet.size.toLong + FunctionsCommand.virtualOperators.size.toLong @@ -2286,7 +2286,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(rows.length > 0) } - ignore("SET LOCATION for managed table") { + test("SET LOCATION for managed table") { withTable("tbl") { withTempDir { dir => sql("CREATE TABLE tbl(i INT) USING parquet") @@ -2465,7 +2465,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - ignore("Partition table should load empty static partitions") { + test("Partition table should load empty static partitions") { // All static partitions withTable("t", "t1", "t2") { withTempPath { dir => diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 31768dab3..9e8b55ce0 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -392,7 +392,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa assert(result.schema.fieldNames.size === 1) } - ignore("DDL test with empty file") { + test("DDL test with empty file") { withView("carsTable") { spark.sql( s""" @@ -1376,7 +1376,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa } } - ignore("SPARK-21263: Invalid float and double are handled correctly in different modes") { + test("SPARK-21263: Invalid float and double are handled correctly in different modes") { val exception = intercept[SparkException] { spark.read.schema("a DOUBLE") .option("mode", "FAILFAST") @@ -1862,7 +1862,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa } } - ignore("count() for malformed input") { + test("count() for malformed input") { def countForMalformedCSV(expected: Long, input: Seq[String]): Unit = { val schema = new StructType().add("a", IntegerType) val strings = spark.createDataset(input) diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 869b90dd1..4e711eccd 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala 
@@ -2509,7 +2509,7 @@ abstract class JsonSuite extends QueryTest with SharedSparkSession with TestJson } } - ignore("count() for malformed input") { + test("count() for malformed input") { def countForMalformedJSON(expected: Long, input: Seq[String]): Unit = { val schema = new StructType().add("a", StringType) val strings = spark.createDataset(input) diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 29ad8bb36..046b5b74d 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -390,7 +390,7 @@ abstract class OrcQueryTest extends OrcTest { } } - ignore("SPARK-10623 Enable ORC PPD") { + test("SPARK-10623 Enable ORC PPD") { withTempPath { dir => withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { @@ -460,7 +460,7 @@ abstract class OrcQueryTest extends OrcTest { } } - ignore("SPARK-15198 Support for pushing down filters for boolean types") { + test("SPARK-15198 Support for pushing down filters for boolean types") { withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { val data = (0 until 10).map(_ => (true, false)) withOrcFile(data) { file => diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 3fffedee1..5cc122232 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -97,7 +97,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { } if (joinType != FullOuter) { - ignore(s"$testName using ShuffledHashJoin") { + test(s"$testName using ShuffledHashJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft @@ -131,7 +131,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { } } - ignore(s"$testName using SortMergeJoin") { + test(s"$testName using SortMergeJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index d44326d4b..a7fc3cf18 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -147,7 +147,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { assert(conf.get("spark.sql.warehouse.dir") === warehouseDir) } - ignore("reset - public conf") { + test("reset - public conf") { spark.sessionState.conf.clear() val original = spark.conf.get(SQLConf.GROUP_BY_ORDINAL) try { @@ -163,7 +163,7 @@ class SQLConfSuite extends QueryTest with 
SharedSparkSession { } } - ignore("reset - internal conf") { + test("reset - internal conf") { spark.sessionState.conf.clear() val original = spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) try { diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e8cb0479e..10f51d955 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -208,7 +208,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { } } - ignore("Truncate") { + test("Truncate") { JdbcDialects.registerDialect(testH2Dialect) val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index fa4980876..7be15e9d8 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -86,7 +86,7 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with protected def testWithWholeStageCodegenOnAndOff(testName: String)(f: String => Unit): Unit = { Seq("false", "true").foreach { codegenEnabled => val isTurnOn = if (codegenEnabled == "true") "on" else "off" - ignore(s"$testName (whole-stage-codegen ${isTurnOn})") { + test(s"$testName (whole-stage-codegen ${isTurnOn})") { withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled) { f(codegenEnabled) } diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/travis/TravisDataFrameAggregateSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/travis/TravisDataFrameAggregateSuite.scala index 288d0e206..a99f95f28 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/travis/TravisDataFrameAggregateSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/travis/TravisDataFrameAggregateSuite.scala @@ -59,7 +59,7 @@ class TravisDataFrameAggregateSuite extends QueryTest val absTol = 1e-8 - ignore("groupBy") { + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(1, 3), Row(2, 3), Row(3, 3)) @@ -127,7 +127,7 @@ class TravisDataFrameAggregateSuite extends QueryTest ) } - ignore("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") { + test("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") { val df = Seq(("some[thing]", "random-string")).toDF("key", "val") checkAnswer( @@ -149,7 +149,7 @@ class TravisDataFrameAggregateSuite extends QueryTest ) } - ignore("cube") { + test("cube") { checkAnswer( courseSales.cube("course", "year").sum("earnings"), Row("Java", 2012, 20000.0) :: @@ -173,7 +173,7 @@ class TravisDataFrameAggregateSuite extends QueryTest assert(cube0.where("date IS NULL").count > 0) } - ignore("grouping and grouping_id") { + test("grouping and grouping_id") { checkAnswer( courseSales.cube("course", "year") .agg(grouping("course"), grouping("year"), grouping_id("course", "year")), @@ -211,7 +211,7 @@ class TravisDataFrameAggregateSuite extends QueryTest } } - ignore("grouping/grouping_id inside window function") { + 
test("grouping/grouping_id inside window function") { val w = Window.orderBy(sum("earnings")) checkAnswer( @@ -231,7 +231,7 @@ class TravisDataFrameAggregateSuite extends QueryTest ) } - ignore("SPARK-21980: References in grouping functions should be indexed with semanticEquals") { + test("SPARK-21980: References in grouping functions should be indexed with semanticEquals") { checkAnswer( courseSales.cube("course", "year") .agg(grouping("CouRse"), grouping("year")), @@ -302,14 +302,14 @@ class TravisDataFrameAggregateSuite extends QueryTest ) } - ignore("agg without groups and functions") { + test("agg without groups and functions") { checkAnswer( testData2.agg(lit(1)), Row(1) ) } - ignore("average") { + test("average") { checkAnswer( testData2.agg(avg($"a"), mean($"a")), Row(2.0, 2.0)) @@ -350,7 +350,7 @@ class TravisDataFrameAggregateSuite extends QueryTest Row(2.0, 2.0)) } - ignore("zero average") { + test("zero average") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(avg($"a")), @@ -369,7 +369,7 @@ class TravisDataFrameAggregateSuite extends QueryTest Row(6, 6.0)) } - ignore("null count") { + test("null count") { checkAnswer( testData3.groupBy($"a").agg(count($"b")), Seq(Row(1, 0), Row(2, 1)) @@ -392,7 +392,7 @@ class TravisDataFrameAggregateSuite extends QueryTest ) } - ignore("multiple column distinct count") { + test("multiple column distinct count") { val df1 = Seq( ("a", "b", "c"), ("a", "b", "c"), @@ -417,14 +417,14 @@ class TravisDataFrameAggregateSuite extends QueryTest ) } - ignore("zero count") { + test("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(count($"a"), sumDistinct($"a")), // non-partial Row(0, null)) } - ignore("stddev") { + test("stddev") { val testData2ADev = math.sqrt(4.0 / 5.0) checkAnswer( testData2.agg(stddev($"a"), stddev_pop($"a"), stddev_samp($"a")), @@ -434,28 +434,28 @@ class TravisDataFrameAggregateSuite extends QueryTest Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) } - ignore("zero stddev") { + test("zero stddev") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( - emptyTableData.agg(stddev($"a"), stddev_pop($"a"), stddev_samp($"a")), - Row(null, null, null)) + emptyTableData.agg(stddev($"a"), stddev_pop($"a"), stddev_samp($"a")), + Row(null, null, null)) } - ignore("zero sum") { + test("zero sum") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(sum($"a")), Row(null)) } - ignore("zero sum distinct") { + test("zero sum distinct") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(sumDistinct($"a")), Row(null)) } - ignore("moments") { + test("moments") { val sparkVariance = testData2.agg(variance($"a")) checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) @@ -473,7 +473,7 @@ class TravisDataFrameAggregateSuite extends QueryTest checkAggregatesWithTol(sparkKurtosis, Row(-1.5), absTol) } - ignore("zero moments") { + test("zero moments") { val input = Seq((1, 2)).toDF("a", "b") checkAnswer( input.agg(stddev($"a"), stddev_samp($"a"), stddev_pop($"a"), variance($"a"), @@ -495,7 +495,7 @@ class TravisDataFrameAggregateSuite extends QueryTest Double.NaN, Double.NaN)) } - ignore("null moments") { + test("null moments") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer(emptyTableData.agg( variance($"a"), var_samp($"a"), var_pop($"a"), skewness($"a"), kurtosis($"a")), @@ -547,7 +547,7 @@ class 
TravisDataFrameAggregateSuite extends QueryTest ) } - ignore("SPARK-31500: collect_set() of BinaryType returns duplicate elements") { + test("SPARK-31500: collect_set() of BinaryType returns duplicate elements") { val bytesTest1 = "test1".getBytes val bytesTest2 = "test2".getBytes val df = Seq(bytesTest1, bytesTest1, bytesTest2).toDF("a") @@ -593,7 +593,7 @@ class TravisDataFrameAggregateSuite extends QueryTest Seq(Row(Seq(1.0, 2.0)))) } - ignore("SPARK-14664: Decimal sum/avg over window should work.") { + test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil) @@ -602,7 +602,7 @@ class TravisDataFrameAggregateSuite extends QueryTest Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } - ignore("SQL decimal test (used for catching certain decimal handling bugs in aggregates)") { + test("SQL decimal test (used for catching certain decimal handling bugs in aggregates)") { checkAnswer( decimalData.groupBy($"a" cast DecimalType(10, 2)).agg(avg($"b" cast DecimalType(10, 2))), Seq(Row(new java.math.BigDecimal(1), new java.math.BigDecimal("1.5")), @@ -626,7 +626,7 @@ class TravisDataFrameAggregateSuite extends QueryTest limit2Df.select($"id")) } - ignore("SPARK-17237 remove backticks in a pivot result schema") { + test("SPARK-17237 remove backticks in a pivot result schema") { val df = Seq((2, 3, 4), (3, 4, 5)).toDF("a", "x", "y") withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { checkAnswer( @@ -645,7 +645,7 @@ class TravisDataFrameAggregateSuite extends QueryTest private def assertNoExceptions(c: Column): Unit = { for ((wholeStage, useObjectHashAgg) <- - Seq((true, true), (true, false), (false, true), (false, false))) { + Seq((true, true), (true, false), (false, true), (false, false))) { withSQLConf( (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString), (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) { @@ -679,7 +679,7 @@ class TravisDataFrameAggregateSuite extends QueryTest } } - ignore("SPARK-19471: AggregationIterator does not initialize the generated result projection" + + test("SPARK-19471: AggregationIterator does not initialize the generated result projection" + " before using it") { Seq( monotonically_increasing_id(), spark_partition_id(), @@ -732,8 +732,6 @@ class TravisDataFrameAggregateSuite extends QueryTest } } - //TODO: failed ut - /* testWithWholeStageCodegenOnAndOff("SPARK-22951: dropDuplicates on empty dataFrames " + "should produce correct aggregate") { _ => // explicit global aggregations @@ -748,7 +746,6 @@ class TravisDataFrameAggregateSuite extends QueryTest // global aggregation is converted to grouping aggregation: assert(spark.emptyDataFrame.dropDuplicates().count() == 0) } - */ test("SPARK-21896: Window functions inside aggregate functions") { def checkWindowError(df: => DataFrame): Unit = { @@ -790,7 +787,7 @@ class TravisDataFrameAggregateSuite extends QueryTest "type: GroupBy]")) } - ignore("SPARK-26021: NaN and -0.0 in grouping expressions") { + test("SPARK-26021: NaN and -0.0 in grouping expressions") { checkAnswer( Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f").groupBy("f").count(), Row(0.0f, 2) :: Row(Float.NaN, 2) :: Nil) @@ -842,7 +839,7 @@ class TravisDataFrameAggregateSuite extends QueryTest checkAnswer(countAndDistinct, Row(100000, 100)) } - ignore("max_by") { + test("max_by") { val yearOfMaxEarnings = sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course") 
checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil) @@ -898,7 +895,7 @@ class TravisDataFrameAggregateSuite extends QueryTest } } - ignore("min_by") { + test("min_by") { val yearOfMinEarnings = sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course") checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil) @@ -954,7 +951,7 @@ class TravisDataFrameAggregateSuite extends QueryTest } } - ignore("count_if") { + test("count_if") { withTempView("tempView") { Seq(("a", None), ("a", Some(1)), ("a", Some(2)), ("a", Some(3)), ("b", None), ("b", Some(4)), ("b", Some(5)), ("b", Some(6))) diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc index 8cac3a5cb..c930d5cce 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc @@ -137,8 +137,6 @@ class UniqueAction : public ActionBase { #endif } - int RequiredColNum() { return 1; } - arrow::Status Submit(ArrayList in_list, int max_group_id, std::function* on_valid, std::function* on_null) override { @@ -288,8 +286,6 @@ class CountAction : public ActionBase { #endif } - int RequiredColNum() { return 1; } - arrow::Status Submit(ArrayList in_list, int max_group_id, std::function* on_valid, std::function* on_null) override { @@ -299,12 +295,25 @@ class CountAction : public ActionBase { length_ = cache_.size(); } - in_ = in_list[0]; + in_list_ = in_list; row_id = 0; + bool has_null = false; + for (int i = 0; i < in_list.size(); i++) { + if (in_list_[i]->null_count()) { + has_null = true; + break; + } + } // prepare evaluate lambda - if (in_->null_count()) { + if (has_null) { *on_valid = [this](int dest_group_id) { - const bool is_null = in_->IsNull(row_id); + bool is_null = false; + for (int i = 0; i < in_list_.size(); i++) { + if (in_list_[i]->IsNull(row_id)) { + is_null = true; + break; + } + } if (!is_null) { cache_[dest_group_id] += 1; } @@ -341,12 +350,23 @@ class CountAction : public ActionBase { cache_.resize(1, 0); length_ = 1; } - arrow::Datum output; - arrow::compute::CountOptions option(arrow::compute::CountOptions::COUNT_NON_NULL); - auto maybe_output = arrow::compute::Count(*in[0].get(), option, ctx_); - output = *std::move(maybe_output); - auto typed_scalar = std::dynamic_pointer_cast(output.scalar()); - cache_[0] += typed_scalar->value; + int length = in[0]->length(); + int count_non_null = 0; + if (in.size() == 1) { + count_non_null = length - in[0]->null_count(); + } else { + int count_null = 0; + for (int id = 0; id < length; id++) { + for (int colId = 0; colId < in.size(); colId++) { + if (in[colId]->IsNull(id)) { + count_null++; + break; + } + } + } + count_non_null = length - count_null; + } + cache_[0] += count_non_null; return arrow::Status::OK(); } @@ -399,7 +419,7 @@ class CountAction : public ActionBase { using ScalarType = typename arrow::TypeTraits::ScalarType; // input arrow::compute::ExecContext* ctx_; - std::shared_ptr in_; + ArrayList in_list_; int32_t row_id; // result using CType = typename arrow::TypeTraits::CType; @@ -428,8 +448,6 @@ class CountLiteralAction : public ActionBase { #endif } - int RequiredColNum() { return 0; } - arrow::Status Submit(ArrayList in_list, int max_group_id, std::function* on_valid, std::function* on_null) override { @@ -553,8 +571,6 @@ class MinAction> #endif } - int RequiredColNum() { return 1; } - arrow::Status Submit(ArrayList 
in_list, int max_group_id, std::function* on_valid, std::function* on_null) override { @@ -715,8 +731,6 @@ class MinAction> #endif } - int RequiredColNum() { return 1; } - arrow::Status Submit(ArrayList in_list, int max_group_id, std::function* on_valid, std::function* on_null) override { @@ -877,8 +891,6 @@ class MaxAction> #endif } - int RequiredColNum() { return 1; } - arrow::Status Submit(ArrayList in_list, int max_group_id, std::function* on_valid, std::function* on_null) override { @@ -1039,8 +1051,6 @@ class MaxAction> #endif } - int RequiredColNum() { return 1; } - arrow::Status Submit(ArrayList in_list, int max_group_id, std::function* on_valid, std::function* on_null) override { @@ -1203,8 +1213,6 @@ class SumAction* on_valid, std::function* on_null) override { @@ -1350,8 +1358,6 @@ class SumAction* on_valid, std::function* on_null) override { @@ -1499,8 +1505,6 @@ class AvgAction : public ActionBase { #endif } - int RequiredColNum() { return 1; } - arrow::Status Submit(ArrayList in_list, int max_group_id, std::function* on_valid, std::function* on_null) override { @@ -1682,8 +1686,6 @@ class SumCountAction* on_valid, std::function* on_null) override { @@ -1863,8 +1865,6 @@ class SumCountAction* on_valid, std::function* on_null) override { @@ -2039,8 +2039,6 @@ class SumCountMergeAction* on_valid, std::function* on_null) override { @@ -2222,8 +2220,6 @@ class SumCountMergeAction* on_valid, std::function* on_null) override { @@ -2394,8 +2390,6 @@ class AvgByCountAction* on_valid, std::function* on_null) override { @@ -2574,8 +2568,6 @@ class AvgByCountAction* on_valid, std::function* on_null) override { @@ -2771,8 +2763,6 @@ class StddevSampPartialAction* on_valid, std::function* on_null) override { @@ -3007,8 +2997,6 @@ class StddevSampPartialAction* on_valid, std::function* on_null) override { @@ -3235,8 +3223,6 @@ class StddevSampFinalAction* on_valid, std::function* on_null) override { @@ -3434,8 +3420,6 @@ class StddevSampFinalAction* on_valid, std::function* on_null) override { diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.h b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.h index eb2bfa664..afae347ca 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.h @@ -40,8 +40,6 @@ class ActionBase { public: virtual ~ActionBase() {} - virtual int RequiredColNum() { return 1; } - virtual arrow::Status Submit(ArrayList in, int max_group_id, std::function* on_valid, std::function* on_null); diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc index 4502ea474..1285522d0 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc @@ -802,6 +802,27 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } prepare_str_ += prepare_ss.str(); check_str_ = validity; + } else if (func_name.compare("normalize") == 0) { + codes_str_ = "normalize_" + std::to_string(cur_func_id); + auto validity = codes_str_ + "_validity"; + std::stringstream fix_ss; + fix_ss << "normalize_nan_zero(" << child_visitor_list[0]->GetResult() << ")"; + std::stringstream prepare_ss; + prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" + << std::endl; + prepare_ss << "bool " << 
validity << " = " << child_visitor_list[0]->GetPreCheck() + << ";" << std::endl; + prepare_ss << "if (" << validity << ") {" << std::endl; + prepare_ss << codes_str_ << " = (" << GetCTypeString(node.return_type()) << ")" + << fix_ss.str() << ";" << std::endl; + prepare_ss << "}" << std::endl; + + for (int i = 0; i < 1; i++) { + prepare_str_ += child_visitor_list[i]->GetPrepare(); + } + prepare_str_ += prepare_ss.str(); + check_str_ = validity; + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else { return arrow::Status::NotImplemented(func_name, " is currently not supported."); } diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc index 9b565a7d2..47172abb8 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc @@ -372,12 +372,33 @@ class HashAggregateKernel::Impl { action_codes_ss << project_output_list[i].first.second << std::endl; project_output_list[i].first.second = ""; } - if (idx_v.size() > 0) - action_codes_ss << "if (" << project_output_list[idx_v[0]].first.first - << "_validity) {" << std::endl; + if (idx_v.size() > 0) { + if (action_name_str_list[action_idx] != "\"action_count\"") { + action_codes_ss << "if (" << project_output_list[idx_v[0]].first.first + << "_validity) {" << std::endl; + } else { + // For action_count with mutiple-col input, will check the validity + // of all the input cols. + action_codes_ss << "if (" << project_output_list[idx_v[0]].first.first + << "_validity"; + for (int i = 1; i < idx_v.size() - 1; i++) { + action_codes_ss << " && " << project_output_list[idx_v[i]].first.first + << "_validity"; + } + action_codes_ss << " && " + << project_output_list[idx_v[idx_v.size() - 1]].first.first + << "_validity) {" << std::endl; + } + } std::vector parameter_list; - for (auto i : idx_v) { - parameter_list.push_back("(void*)&" + project_output_list[i].first.first); + if (action_name_str_list[action_idx] != "\"action_count\"") { + for (auto i : idx_v) { + parameter_list.push_back("(void*)&" + project_output_list[i].first.first); + } + } else { + // For action_count, only the first col will be used as input to Evaluate + // function, in which it will not be used. 
+ parameter_list.push_back("(void*)&" + project_output_list[idx_v[0]].first.first); } action_codes_ss << "RETURN_NOT_OK(aggr_action_list_" << level << "[" << action_idx << "]->Evaluate(memo_index" << GetParameterList(parameter_list) diff --git a/native-sql-engine/cpp/src/precompile/gandiva.h b/native-sql-engine/cpp/src/precompile/gandiva.h index 6d1614684..b7500bee4 100644 --- a/native-sql-engine/cpp/src/precompile/gandiva.h +++ b/native-sql-engine/cpp/src/precompile/gandiva.h @@ -205,6 +205,16 @@ bool equal_with_nan(double left, double right) { return left == right; } +double normalize_nan_zero(double in) { + if (std::isnan(in)) { + return 0.0 / 0.0; + } else if (in < 0 && std::abs(in) < 0.0000001) { + return 0.0; + } else { + return in; + } +} + arrow::Decimal128 round(arrow::Decimal128 in, int32_t original_precision, int32_t original_scale, bool* overflow_, int32_t res_scale = 2) { bool overflow = false; diff --git a/native-sql-engine/cpp/src/tests/arrow_compute_test_aggregate.cc b/native-sql-engine/cpp/src/tests/arrow_compute_test_aggregate.cc index 28e85509b..231466bd7 100644 --- a/native-sql-engine/cpp/src/tests/arrow_compute_test_aggregate.cc +++ b/native-sql-engine/cpp/src/tests/arrow_compute_test_aggregate.cc @@ -346,6 +346,84 @@ TEST(TestArrowCompute, GroupByCountAll) { } } +TEST(TestArrowCompute, GroupByCountOnMultipleCols) { + auto f0 = field("f0", utf8()); + auto f1 = field("f1", utf8()); + auto f2 = field("f2", utf8()); + auto f_unique = field("unique", utf8()); + auto f_count = field("count", int64()); + auto f_res = field("res", uint32()); + + auto arg0 = TreeExprBuilder::MakeField(f0); + auto arg1 = TreeExprBuilder::MakeField(f1); + auto arg2 = TreeExprBuilder::MakeField(f2); + + auto n_groupby = TreeExprBuilder::MakeFunction("action_groupby", {arg0}, uint32()); + auto n_count = TreeExprBuilder::MakeFunction("action_count", {arg1, arg2}, uint32()); + auto n_proj = + TreeExprBuilder::MakeFunction("aggregateExpressions", {arg0, arg1, arg2}, uint32()); + auto n_action = + TreeExprBuilder::MakeFunction("aggregateActions", {n_groupby, n_count}, uint32()); + auto n_result = TreeExprBuilder::MakeFunction( + "resultSchema", + {TreeExprBuilder::MakeField(f_unique), TreeExprBuilder::MakeField(f_count)}, + uint32()); + auto n_result_expr = TreeExprBuilder::MakeFunction( + "resultExpressions", + {TreeExprBuilder::MakeField(f_unique), TreeExprBuilder::MakeField(f_count)}, + uint32()); + auto n_aggr = TreeExprBuilder::MakeFunction( + "hashAggregateArrays", {n_proj, n_action, n_result, n_result_expr}, uint32()); + auto n_child = TreeExprBuilder::MakeFunction("standalone", {n_aggr}, uint32()); + auto aggr_expr = TreeExprBuilder::MakeExpression(n_child, f_res); + + std::vector> expr_vector = {aggr_expr}; + + auto sch = arrow::schema({f0, f1, f2}); + std::vector> ret_types = {f_unique, f_count}; + + /////////////////////// Create Expression Evaluator //////////////////// + std::shared_ptr expr; + arrow::compute::ExecContext ctx; + ASSERT_NOT_OK( + CreateCodeGenerator(ctx.memory_pool(), sch, expr_vector, ret_types, &expr, true)) + + std::shared_ptr input_batch; + std::vector> output_batch_list; + + ////////////////////// calculation ///////////////////// + + std::shared_ptr> aggr_result_iterator; + std::shared_ptr aggr_result_iterator_base; + ASSERT_NOT_OK(expr->finish(&aggr_result_iterator_base)); + aggr_result_iterator = std::dynamic_pointer_cast>( + aggr_result_iterator_base); + + std::vector input_data = {R"(["a", "a", "a", "x", "x"])", + R"(["b", "b", "b", "y", "q"])", + R"([null, "c",
"d", "z", null])"}; + MakeInputBatch(input_data, sch, &input_batch); + ASSERT_NOT_OK(aggr_result_iterator->ProcessAndCacheOne(input_batch->columns())); + + std::vector input_data_2 = {R"(["b", "a", "b", "a", "x"])", + R"(["b", "b", "b", null, "q"])", + R"(["c", null, "d", "z", null])"}; + MakeInputBatch(input_data_2, sch, &input_batch); + ASSERT_NOT_OK(aggr_result_iterator->ProcessAndCacheOne(input_batch->columns())); + + ////////////////////// Finish ////////////////////////// + + std::shared_ptr expected_result; + std::shared_ptr result_batch; + std::vector expected_result_string = {R"(["a", "x", "b"])", "[2, 1, 2]"}; + auto res_sch = arrow::schema(ret_types); + MakeInputBatch(expected_result_string, res_sch, &expected_result); + if (aggr_result_iterator->HasNext()) { + ASSERT_NOT_OK(aggr_result_iterator->Next(&result_batch)); + ASSERT_NOT_OK(Equals(*expected_result.get(), *result_batch.get())); + } +} + TEST(TestArrowCompute, GroupByTwoAggregateTest) { ////////////////////// prepare expr_vector /////////////////////// auto f0 = field("f0", int64()); diff --git a/native-sql-engine/cpp/src/tests/arrow_compute_test_wscg.cc b/native-sql-engine/cpp/src/tests/arrow_compute_test_wscg.cc index 0bdd5686a..c875ec263 100644 --- a/native-sql-engine/cpp/src/tests/arrow_compute_test_wscg.cc +++ b/native-sql-engine/cpp/src/tests/arrow_compute_test_wscg.cc @@ -3827,6 +3827,80 @@ TEST(TestArrowComputeWSCG, WSCGTestAggregate) { } } +TEST(TestArrowComputeWSCG, WSCGTestCountOnMultipleCols) { + auto f0 = field("f0", utf8()); + auto f1 = field("f1", utf8()); + auto f2 = field("f2", utf8()); + ; + auto f_unique = field("unique", utf8()); + auto f_count = field("count", int64()); + auto f_res = field("res", uint32()); + + auto arg0 = TreeExprBuilder::MakeField(f0); + auto arg1 = TreeExprBuilder::MakeField(f1); + auto arg2 = TreeExprBuilder::MakeField(f2); + + auto n_groupby = TreeExprBuilder::MakeFunction("action_groupby", {arg0}, uint32()); + auto n_count = TreeExprBuilder::MakeFunction("action_count", {arg1, arg2}, uint32()); + auto n_proj = + TreeExprBuilder::MakeFunction("aggregateExpressions", {arg0, arg1, arg2}, uint32()); + auto n_action = + TreeExprBuilder::MakeFunction("aggregateActions", {n_groupby, n_count}, uint32()); + auto n_result = TreeExprBuilder::MakeFunction( + "resultSchema", + {TreeExprBuilder::MakeField(f_unique), TreeExprBuilder::MakeField(f_count)}, + uint32()); + auto n_result_expr = TreeExprBuilder::MakeFunction( + "resultExpressions", + {TreeExprBuilder::MakeField(f_unique), TreeExprBuilder::MakeField(f_count)}, + uint32()); + auto n_aggr = TreeExprBuilder::MakeFunction( + "hashAggregateArrays", {n_proj, n_action, n_result, n_result_expr}, uint32()); + auto n_child = TreeExprBuilder::MakeFunction("child", {n_aggr}, uint32()); + auto n_wscg = TreeExprBuilder::MakeFunction("wholestagecodegen", {n_child}, uint32()); + auto aggr_expr = TreeExprBuilder::MakeExpression(n_wscg, f_res); + + std::vector> expr_vector = {aggr_expr}; + + auto sch = arrow::schema({f0, f1, f2}); + std::vector> ret_types = {f_unique, f_count}; + + /////////////////////// Create Expression Evaluator //////////////////// + std::shared_ptr expr; + arrow::compute::ExecContext ctx; + ASSERT_NOT_OK( + CreateCodeGenerator(ctx.memory_pool(), sch, expr_vector, ret_types, &expr, true)); + std::shared_ptr input_batch; + std::vector> output_batch_list; + + std::shared_ptr> aggr_result_iterator; + std::shared_ptr aggr_result_iterator_base; + ASSERT_NOT_OK(expr->finish(&aggr_result_iterator_base)); + aggr_result_iterator =
std::dynamic_pointer_cast>( + aggr_result_iterator_base); + + std::vector input_data = {R"(["a", "a", "a", "x", "x"])", + R"(["b", "b", "b", "y", "q"])", + R"([null, "c", "d", "z", null])"}; + MakeInputBatch(input_data, sch, &input_batch); + ASSERT_NOT_OK(aggr_result_iterator->ProcessAndCacheOne(input_batch->columns())); + + std::vector input_data_2 = {R"(["b", "a", "b", "a", "x"])", + R"(["b", "b", "b", null, "q"])", + R"(["c", null, "d", "z", null])"}; + MakeInputBatch(input_data_2, sch, &input_batch); + ASSERT_NOT_OK(aggr_result_iterator->ProcessAndCacheOne(input_batch->columns())); + + std::shared_ptr expected_result; + std::shared_ptr result_batch; + std::vector expected_result_string = {R"(["a", "x", "b"])", "[2, 1, 2]"}; + MakeInputBatch(expected_result_string, arrow::schema(ret_types), &expected_result); + if (aggr_result_iterator->HasNext()) { + ASSERT_NOT_OK(aggr_result_iterator->Next(&result_batch)); + ASSERT_NOT_OK(Equals(*expected_result.get(), *result_batch.get())); + } +} + TEST(TestArrowComputeWSCG, WSCGTestGroupbyHashAggregateTwoKeys) { ////////////////////// prepare expr_vector /////////////////////// auto f0 = field("f0", int64()); diff --git a/native-sql-engine/cpp/src/third_party/sparsehash/sparse_hash_map.h b/native-sql-engine/cpp/src/third_party/sparsehash/sparse_hash_map.h index 885b5840a..e609c23d8 100644 --- a/native-sql-engine/cpp/src/third_party/sparsehash/sparse_hash_map.h +++ b/native-sql-engine/cpp/src/third_party/sparsehash/sparse_hash_map.h @@ -19,14 +19,19 @@ #include #include +#include + #include "sparsehash/dense_hash_map" using google::dense_hash_map; #define NOTFOUND -1 +template +class SparseHashMap {}; + template -class SparseHashMap { +class SparseHashMap::value>> { public: SparseHashMap() { dense_map_.set_empty_key(0); } SparseHashMap(arrow::MemoryPool* pool) { @@ -81,3 +86,78 @@ class SparseHashMap { bool null_index_set_ = false; int32_t null_index_; }; + +template +class SparseHashMap::value>> { + public: + SparseHashMap() { dense_map_.set_empty_key(0); } + SparseHashMap(arrow::MemoryPool* pool) { + dense_map_.set_empty_key(std::numeric_limits::max()); + } + template + arrow::Status GetOrInsert(const Scalar& value, Func1&& on_found, Func2&& on_not_found, + int32_t* out_memo_index) { + if (dense_map_.find(value) == dense_map_.end()) { + if (!nan_index_set_) { + auto index = size_++; + dense_map_[value] = index; + *out_memo_index = index; + on_not_found(index); + if (std::isnan(value)) { + nan_index_set_ = true; + nan_index_ = index; + } + } else { + if (std::isnan(value)) { + *out_memo_index = nan_index_; + on_found(nan_index_); + } else { + auto index = size_++; + dense_map_[value] = index; + *out_memo_index = index; + on_not_found(index); + } + } + } else { + auto index = dense_map_[value]; + *out_memo_index = index; + on_found(index); + } + return arrow::Status::OK(); + } + template + int32_t GetOrInsertNull(Func1&& on_found, Func2&& on_not_found) { + if (!null_index_set_) { + null_index_set_ = true; + null_index_ = size_++; + on_not_found(null_index_); + } else { + on_found(null_index_); + } + return null_index_; + } + int32_t Get(const Scalar& value) { + if (dense_map_.find(value) == dense_map_.end()) { + return NOTFOUND; + } else { + auto ret = dense_map_[value]; + return ret; + } + } + int32_t GetNull() { + if (!null_index_set_) { + return NOTFOUND; + } else { + auto ret = null_index_; + return ret; + } + } + + private: + dense_hash_map dense_map_; + int32_t size_ = 0; + bool null_index_set_ = false; + int32_t null_index_; + bool 
nan_index_set_ = false; + int32_t nan_index_; +};
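
Note on the count-on-multiple-columns semantics exercised by the two new tests above: count(col1, col2) grouped by a key counts only those rows in which every counted column is non-null, which is exactly what the generated validity chain in hash_aggregate_kernel.cc enforces. The following is a minimal, self-contained sketch of that behaviour using plain STL containers and the test data from the patch; it does not use the project's APIs, and every name in it is invented for illustration.

// count_multi_col_demo.cc, illustrative only; mirrors the SQL semantics of
// count(f1, f2) grouped by f0, as covered by the count-on-multiple-columns tests.
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <vector>

using OptStr = std::optional<std::string>;

int main() {
  // Concatenation of the two input batches from the test (f0 is the group key).
  std::vector<std::string> f0 = {"a", "a", "a", "x", "x", "b", "a", "b", "a", "x"};
  std::vector<OptStr> f1 = {"b", "b", "b", "y", "q", "b", "b", "b", std::nullopt, "q"};
  std::vector<OptStr> f2 = {std::nullopt, "c", "d", "z", std::nullopt,
                            "c", std::nullopt, "d", "z", std::nullopt};

  std::map<std::string, long> counts;
  for (std::size_t i = 0; i < f0.size(); ++i) {
    // Mirrors the generated validity check: every counted column must be valid.
    if (f1[i].has_value() && f2[i].has_value()) {
      counts[f0[i]] += 1;
    }
  }

  // Agrees with expected_result_string = {R"(["a", "x", "b"])", "[2, 1, 2]"}.
  assert(counts["a"] == 2);
  assert(counts["x"] == 1);
  assert(counts["b"] == 2);
  return 0;
}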
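
Note on the NaN handling introduced in precompile/gandiva.h and sparse_hash_map.h: the new normalize handler in expression_codegen_visitor.cc emits a call to normalize_nan_zero(), which maps every NaN to a single canonical NaN and snaps tiny negative values to 0.0, and the floating-point SparseHashMap specialization keeps one dedicated group index for NaN keys, since NaN never compares equal to itself and would otherwise never be found in the map. The sketch below illustrates that idea with std::unordered_map only; it is not the project's SparseHashMap API, and the class name is invented for the example.

// nan_group_demo.cc, illustrative only; shows why NaN keys need a dedicated
// group index, as in the floating-point SparseHashMap specialization above.
#include <cassert>
#include <cmath>
#include <unordered_map>

class NanAwareGrouper {
 public:
  int GetOrInsert(double key) {
    if (std::isnan(key)) {
      // NaN != NaN, so the map could never find an existing NaN key;
      // keep one canonical index for all NaNs instead.
      if (nan_index_ < 0) nan_index_ = size_++;
      return nan_index_;
    }
    auto it = map_.find(key);
    if (it != map_.end()) return it->second;
    int index = size_++;
    map_.emplace(key, index);
    return index;
  }

 private:
  std::unordered_map<double, int> map_;
  int size_ = 0;
  int nan_index_ = -1;  // set once the first NaN key is seen
};

int main() {
  NanAwareGrouper grouper;
  // NaNs with different payloads still land in the same group.
  assert(grouper.GetOrInsert(std::nan("1")) == grouper.GetOrInsert(std::nan("2")));
  // 0.0 and -0.0 compare equal, so a comparison-based map already groups them together.
  assert(grouper.GetOrInsert(0.0) == grouper.GetOrInsert(-0.0));
  return 0;
}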