This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-207] fix issues found from aggregate unit tests (#233)
* fix incorrect input in Expand

* fix empty input for aggregate

* fix only result expressions

* fix empty aggregate expressions

* fix res attr not found issue

* refine

* fix count distinct with null

* fix groupby of NaN, -0.0 and 0.0

* fix count on multiple cols with null in WSCG

* format code

* support normalize NaN and 0.0

* revert and update

* support normalize function in WSCG
rui-mo authored Apr 30, 2021
1 parent 2591312 commit 1126320
Showing 36 changed files with 635 additions and 184 deletions.
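
A note on the null semantics behind two of the commit-message items: "fix count distinct with null" and "fix count on multiple cols with null in WSCG" both come down to Spark's rule that count(a, b, ...) counts only rows where every argument is non-null, and count(DISTINCT a) ignores nulls. A minimal sketch of those semantics, illustrative only and not part of this commit (assumes a local Spark session):

import org.apache.spark.sql.SparkSession

object CountNullSemantics {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("count-null-semantics").getOrCreate()
    import spark.implicits._

    // a is null in the second row, b is null in the third row
    Seq((Some(1), Some("x")), (None, Some("y")), (Some(1), None))
      .toDF("a", "b").createOrReplaceTempView("t")

    // count(a, b) counts only rows where both arguments are non-null -> 1
    // count(DISTINCT a) ignores nulls and collapses duplicates       -> 1
    spark.sql("SELECT count(a, b), count(DISTINCT a) FROM t").show()

    spark.stop()
  }
}
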
@@ -1571,6 +1571,12 @@ final void setLongs(int rowId, int count, byte[] src, int srcIndex) {
}
}

@Override
final void setDouble(int rowId, double value) {
long val = (long)value;
writer.setSafe(rowId, val);
}

@Override
void setLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET;
@@ -90,19 +90,20 @@ case class ColumnarExpandExec(
private[this] val numGroups = columnarGroups.length
private[this] val resultStructType =
StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
private[this] var input_cb: ColumnarBatch = _

override def hasNext: Boolean = (-1 < idx && idx < numGroups) || iter.hasNext

override def next(): ColumnarBatch = {
if (idx <= 0) {
// at the initial state (-1) or the beginning (0) of a new input batch, fetch the next input batch
val input_cb = iter.next()
input = (0 until input_cb.numCols).toList
.map(input_cb.column(_).asInstanceOf[ArrowWritableColumnVector].getValueVector)
input_cb = iter.next()
numRows = input_cb.numRows
numInputBatches += 1
idx = 0
}
input = columnarGroups(idx).ordinalList
.map(input_cb.column(_).asInstanceOf[ArrowWritableColumnVector].getValueVector)

if (numRows == 0) {
idx = -1
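
For context on the ColumnarExpandExec change above: Expand replays every input row once per grouping set, so the value vectors handed to each projection must be re-selected per group through columnarGroups(idx).ordinalList instead of being fixed once when the batch is fetched. A small query that is planned with an Expand node, illustrative only (assumes a local Spark session):

import org.apache.spark.sql.SparkSession

object ExpandNodeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("expand-example").getOrCreate()
    import spark.implicits._

    Seq(("a", 1), ("a", 2), ("b", 3)).toDF("k", "v").createOrReplaceTempView("t")

    // ROLLUP is planned with an Expand operator: each input row is emitted once per
    // grouping set, keeping a different subset of the grouping columns each time.
    spark.sql("SELECT k, sum(v) FROM t GROUP BY ROLLUP(k)").show()

    spark.stop()
  }
}
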
@@ -107,6 +107,11 @@ case class ColumnarHashAggregateExec(

buildCheck()

val onlyResultExpressions: Boolean =
if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty &&
child.output.isEmpty && resultExpressions.nonEmpty) true
else false

override def doExecuteColumnar(): RDD[ColumnarBatch] = {
var eval_elapse: Long = 0
child.executeColumnar().mapPartitions { iter =>
@@ -138,10 +143,16 @@ }
}

var numRowsInput = 0
var hasNextCount = 0
// now we can return this wholestagecodegen iter
val res = new Iterator[ColumnarBatch] {
var processed = false
/** Three special cases need to be handled on the Scala side:
* (1) count_literal (2) only result expressions (3) empty input
*/
var skip_native = false
var onlyResExpr = false
var emptyInput = false
var count_num_row = 0
def process: Unit = {
while (iter.hasNext) {
@@ -150,7 +161,9 @@ case class ColumnarHashAggregateExec(
if (cb.numRows != 0) {
numRowsInput += cb.numRows
val beforeEval = System.nanoTime()
if (hash_aggr_input_schema.getFields.size == 0) {
if (hash_aggr_input_schema.getFields.size == 0 &&
aggregateExpressions.nonEmpty &&
aggregateExpressions.head.aggregateFunction.isInstanceOf[Count]) {
// This is a special case where only a count over a literal is performed
count_num_row += cb.numRows
skip_native = true
@@ -166,9 +179,17 @@ case class ColumnarHashAggregateExec(
processed = true
}
override def hasNext: Boolean = {
hasNextCount += 1
if (!processed) process
if (skip_native) {
count_num_row > 0
} else if (onlyResultExpressions && hasNextCount == 1) {
onlyResExpr = true
true
} else if (!onlyResultExpressions && groupingExpressions.isEmpty &&
numRowsInput == 0 && hasNextCount == 1) {
emptyInput = true
true
} else {
nativeIterator.hasNext
}
@@ -179,28 +200,19 @@ case class ColumnarHashAggregateExec(
val beforeEval = System.nanoTime()
if (skip_native) {
// special handling for only count literal in this operator
val out_res = count_num_row
count_num_row = 0
val resultColumnVectors =
ArrowWritableColumnVector.allocateColumns(0, resultStructType).toArray
resultColumnVectors.foreach { v =>
{
val numRows = v.dataType match {
case t: IntegerType =>
out_res.asInstanceOf[Number].intValue
case t: LongType =>
out_res.asInstanceOf[Number].longValue
}
v.put(0, numRows)
}
}
return new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 1)
getResForCountLiteral
} else if (onlyResExpr) {
// special handling for only result expressions
getResForOnlyResExpr
} else if (emptyInput) {
// special handling for empty input batch
getResForEmptyInput
} else {
val output_rb = nativeIterator.next
if (output_rb == null) {
eval_elapse += System.nanoTime() - beforeEval
val resultColumnVectors =
ArrowWritableColumnVector.allocateColumns(0, resultStructType).toArray
ArrowWritableColumnVector.allocateColumns(0, resultStructType)
return new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0)
}
val outputNumRows = output_rb.getLength
@@ -212,6 +224,123 @@ case class ColumnarHashAggregateExec(
new ColumnarBatch(output.map(v => v.asInstanceOf[ColumnVector]), outputNumRows)
}
}
def getResForCountLiteral: ColumnarBatch = {
val resultColumnVectors =
ArrowWritableColumnVector.allocateColumns(0, resultStructType)
if (count_num_row == 0) {
new ColumnarBatch(
resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0)
} else {
val out_res = count_num_row
count_num_row = 0
for (idx <- resultColumnVectors.indices) {
resultColumnVectors(idx).dataType match {
case t: IntegerType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].intValue)
case t: LongType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].longValue)
case t: DoubleType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].doubleValue())
case t: FloatType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].floatValue())
case t: ByteType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].byteValue())
case t: ShortType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].shortValue())
case t: StringType =>
val values = (out_res :: Nil).map(_.toByte).toArray
resultColumnVectors(idx)
.putBytes(0, 1, values, 0)
}
}
new ColumnarBatch(
resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 1)
}
}
def getResForOnlyResExpr: ColumnarBatch = {
// This function has limited support for only-result-expression case.
// Fake input for projection:
val inputColumnVectors =
ArrowWritableColumnVector.allocateColumns(0, resultStructType)
val valueVectors =
inputColumnVectors.map(columnVector => columnVector.getValueVector).toList
val projector = ColumnarProjection.create(child.output, resultExpressions)
val resultColumnVectorList = projector.evaluate(1, valueVectors)
new ColumnarBatch(
resultColumnVectorList.map(v => v.asInstanceOf[ColumnVector]).toArray,
1)
}
def getResForEmptyInput: ColumnarBatch = {
val resultColumnVectors =
ArrowWritableColumnVector.allocateColumns(0, resultStructType)
if (aggregateExpressions.isEmpty) {
// To align with Spark, one empty row is returned in this case.
return new ColumnarBatch(
resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 1)
}
// If no grouping is required, a default value will be
// returned for Final mode when the input is empty.
var idx = 0
for (expr <- aggregateExpressions) {
expr.aggregateFunction match {
case Average(_) | StddevSamp(_) | Sum(_) | Max(_) | Min(_) =>
expr.mode match {
case Final =>
resultColumnVectors(idx).putNull(0)
idx += 1
case _ =>
}
case Count(_) =>
expr.mode match {
case Final =>
val out_res = 0
resultColumnVectors(idx).dataType match {
case t: IntegerType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].intValue)
case t: LongType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].longValue)
case t: DoubleType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].doubleValue())
case t: FloatType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].floatValue())
case t: ByteType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].byteValue())
case t: ShortType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].shortValue())
case t: StringType =>
val values = (out_res :: Nil).map(_.toByte).toArray
resultColumnVectors(idx)
.putBytes(0, 1, values, 0)
}
idx += 1
case _ =>
}
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
}
// will only put default value for Final mode
aggregateExpressions.head.mode match {
case Final =>
new ColumnarBatch(
resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 1)
case _ =>
new ColumnarBatch(
resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0)
}
}
}
SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
close
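
The getResForEmptyInput branch above mirrors what vanilla Spark returns for a global aggregate over zero rows. For reference, a sketch of those expected results, illustrative only and not part of this commit (assumes a local Spark session):

import org.apache.spark.sql.SparkSession

object EmptyInputAggSemantics {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("empty-input-agg").getOrCreate()
    import spark.implicits._

    Seq.empty[Int].toDF("i").createOrReplaceTempView("t")

    // Global aggregate (no grouping keys) over empty input: exactly one row comes back,
    // null for sum/avg/max/min and 0 for count -- the defaults getResForEmptyInput fills in.
    spark.sql("SELECT sum(i), avg(i), max(i), count(i) FROM t").show()

    // With grouping keys, empty input yields an empty result instead of a default row.
    spark.sql("SELECT i, count(i) FROM t GROUP BY i").show()

    spark.stop()
  }
}
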
@@ -344,6 +344,29 @@ class ColumnarHashAggregation(
aggregateAttr.toList
}

def existsAttrNotFound(allAggregateResultAttributes: List[Attribute]): Unit = {
if (resultExpressions.size == allAggregateResultAttributes.size) {
var resAllAttr = true
breakable {
for (expr <- resultExpressions) {
if (!expr.isInstanceOf[AttributeReference]) {
resAllAttr = false
break
}
}
}
if (resAllAttr) {
for (attr <- resultExpressions) {
if (allAggregateResultAttributes
.indexOf(attr.asInstanceOf[AttributeReference]) == -1) {
throw new IllegalArgumentException(
s"$attr in resultExpressions is not found in allAggregateResultAttributes!")
}
}
}
}
}

def prepareKernelFunction: TreeNode = {
// build gandiva projection here.
ColumnarPluginConfig.getConf
@@ -420,6 +443,11 @@ class ColumnarHashAggregation(
s"${attr.name}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
})

// If the result expressions consist solely of attributes and some of them are not
// found in allAggregateResultAttributes, an exception will be thrown.
existsAttrNotFound(allAggregateResultAttributes)

val nativeFuncNodes = groupingNativeFuncNodes ::: aggrNativeFuncNodes

// 4. prepare after aggregate result expressions
@@ -623,6 +623,31 @@ class ColumnarMakeDecimal(
}
}

class ColumnarNormalizeNaNAndZero(child: Expression, original: NormalizeNaNAndZero)
extends NormalizeNaNAndZero(child: Expression)
with ColumnarExpression
with Logging {

buildCheck()

def buildCheck(): Unit = {
val supportedTypes = List(FloatType, DoubleType)
if (supportedTypes.indexOf(child.dataType) == -1) {
throw new UnsupportedOperationException(
s"${child.dataType} is not supported in ColumnarNormalizeNaNAndZero")
}
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val (child_node, childType): (TreeNode, ArrowType) =
child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val normalizeNode = TreeBuilder.makeFunction(
"normalize", Lists.newArrayList(child_node), childType)
(normalizeNode, childType)
}
}

object ColumnarUnaryOperator {

def create(child: Expression, original: Expression): Expression = original match {
@@ -652,8 +677,8 @@ object ColumnarUnaryOperator {
new ColumnarBitwiseNot(child, n)
case a: KnownFloatingPointNormalized =>
child
case a: NormalizeNaNAndZero =>
child
case n: NormalizeNaNAndZero =>
new ColumnarNormalizeNaNAndZero(child, n)
case a: PromotePrecision =>
child
case a: CheckOverflow =>
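
The new ColumnarNormalizeNaNAndZero expression above delegates to a Gandiva "normalize" function; the intent of Spark's NormalizeNaNAndZero is to fold -0.0 into 0.0 and collapse every NaN bit pattern into one canonical NaN before the value is used as a grouping or join key. A plain-Scala sketch of that mapping, illustrative only and not the Gandiva implementation:

object NormalizeSketch {
  // Assumed semantics: after normalization, bitwise hashing/comparison of the value
  // agrees with value equality for floating-point grouping keys.
  def normalize(d: Double): Double =
    if (d.isNaN) Double.NaN      // one canonical NaN bit pattern
    else if (d == 0.0d) 0.0d     // matches both 0.0 and -0.0, returns +0.0
    else d

  def main(args: Array[String]): Unit = {
    val anotherNaN = java.lang.Double.longBitsToDouble(0x7ff8000000000001L)
    val keys = Seq(0.0d, -0.0d, Double.NaN, anotherNaN)
    // Raw bit patterns: 4 distinct values before normalization, 2 after (zero and NaN).
    println(keys.map(d => java.lang.Double.doubleToRawLongBits(d)).distinct.size)                // 4
    println(keys.map(normalize).map(d => java.lang.Double.doubleToRawLongBits(d)).distinct.size) // 2
  }
}
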
@@ -742,7 +742,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
}
}

ignore("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") {
test("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") {
withTable("t") {
withTempPath { path =>
Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
@@ -824,7 +824,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
}
}

ignore("SPARK-19993 subquery with cached underlying relation") {
test("SPARK-19993 subquery with cached underlying relation") {
withTempView("t1") {
Seq(1).toDF("c1").createOrReplaceTempView("t1")
spark.catalog.cacheTable("t1")
@@ -1029,7 +1029,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
SHUFFLE_HASH)
}

ignore("analyzes column statistics in cached query") {
test("analyzes column statistics in cached query") {
def query(): DataFrame = {
spark.range(100)
.selectExpr("id % 3 AS c0", "id % 5 AS c1", "2 AS c2")