@@ -69,8 +69,17 @@ class EquivalentExpressions {
    */
   def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
     val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
-    // the children of CodegenFallback will not be used to generate code (call eval() instead)
-    if (!skip && !addExpr(root) && !root.isInstanceOf[CodegenFallback]) {
+    // There are some special expressions whose children we should not recurse into:
+    // 1. CodegenFallback: its children will not be used to generate code (eval() is called instead).
+    // 2. ReferenceToExpressions: it is a kind of explicit sub-expression elimination.
+    val shouldRecurse = root match {
+      // TODO: some expressions implement `CodegenFallback` but can still do codegen,
+      // e.g. `CaseWhen`; we should support them.
+      case _: CodegenFallback => false

Contributor: a few expressions implement CodegenFallback but only use it in some corner cases.

+      case _: ReferenceToExpressions => false
+      case _ => true
+    }
+    if (!skip && !addExpr(root) && shouldRecurse) {
       root.children.foreach(addExprTree(_, ignoreLeaf))
     }
   }
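For context, a minimal sketch (not part of this patch) of how the traversal typically feeds sub-expression elimination; the driver function below is hypothetical, and its use of `getAllEquivalentExprs` reflects how `CodegenContext` consumes this class as I understand it.

    import org.apache.spark.sql.catalyst.expressions.{EquivalentExpressions, Expression}

    // Hypothetical driver, for illustration only: `candidates` stands in for whatever
    // expression list a physical operator scans for common sub-expressions.
    def findCommonSubexpressions(candidates: Seq[Expression]): Seq[Seq[Expression]] = {
      val equivalentExpressions = new EquivalentExpressions
      candidates.foreach(equivalentExpressions.addExprTree(_))
      // With this patch, children of CodegenFallback / ReferenceToExpressions are never
      // visited, so expressions that are only evaluated via eval() are not extracted
      // as common sub-expressions backed by generated code.
      equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
    }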
@@ -110,13 +110,17 @@ class CodegenContext {
   }

   def declareMutableStates(): String = {
-    mutableStates.map { case (javaType, variableName, _) =>
+    // It's possible that we add the same mutable state twice, e.g. the `mergeExpressions` in
+    // `TypedAggregateExpression`, so we call `distinct` here to remove the duplicated ones.

Contributor: Can we avoid adding the same mutable state twice?

Contributor Author: Yeah, we could transform the left deserializer in TypedAggregateExpression and create new LambdaVariables with different names, but I'm afraid there will be more similar problems, so I went with this approach.

+    mutableStates.distinct.map { case (javaType, variableName, _) =>

Contributor: Needs a comment explaining the `distinct`.

       s"private $javaType $variableName;"
     }.mkString("\n")
   }

   def initMutableStates(): String = {
-    mutableStates.map(_._3).mkString("\n")
+    // It's possible that we add the same mutable state twice, e.g. the `mergeExpressions` in
+    // `TypedAggregateExpression`, so we call `distinct` here to remove the duplicated ones.
+    mutableStates.distinct.map(_._3).mkString("\n")
   }

   /**
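As a standalone illustration of why the `distinct` matters, here is a small sketch assuming the same `(javaType, variableName, initCode)` triples that `mutableStates` holds; the field name and init code are made up.

    import scala.collection.mutable.ArrayBuffer

    // The same state registered twice, e.g. once per appearance of the shared
    // deserializer produced by `mergeExpressions` in `TypedAggregateExpression`.
    val mutableStates = ArrayBuffer.empty[(String, String, String)]
    mutableStates += (("boolean", "resultIsNull", "resultIsNull = false;"))
    mutableStates += (("boolean", "resultIsNull", "resultIsNull = false;"))

    val declarations = mutableStates.distinct.map { case (javaType, variableName, _) =>
      s"private $javaType $variableName;"
    }.mkString("\n")
    // declarations == "private boolean resultIsNull;" -- the field is declared once,
    // so the generated class no longer fails with a duplicate-field ClassFormatError.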
@@ -19,6 +19,7 @@ package org.apache.spark.sql

 import scala.language.postfixOps

+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
@@ -72,6 +73,16 @@ object NameAgg extends Aggregator[AggData, String, String] {
 }


+object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] {
+  def zero: Seq[Int] = Nil
+  def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b
+  def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2
+  def finish(r: Seq[Int]): Seq[Int] = r
+  override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+  override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+}
+
+
 class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
   extends Aggregator[IN, OUT, OUT] {

@@ -212,4 +223,13 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
     val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
     checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil)
   }
+
+  test("SPARK-14675: ClassFormatError when use Seq as Aggregator buffer type") {
+    val ds = Seq(AggData(1, "a"), AggData(2, "a")).toDS()
+
+    checkDataset(
+      ds.groupByKey(_.b).agg(SeqAgg.toColumn),
+      "a" -> Seq(1, 2)
+    )
+  }
 }
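Outside the suite, a hedged usage sketch of the new aggregator; it assumes the usual `spark` session and its implicits are in scope and mirrors what the regression test above exercises.

    import spark.implicits._

    // Global aggregation through the typed column API: the whole Dataset collapses
    // into a single Seq[Int] buffer, the shape that used to hit the ClassFormatError.
    val ds = Seq(AggData(1, "a"), AggData(2, "b")).toDS()
    val merged: Seq[Int] = ds.select(SeqAgg.toColumn).head()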