Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.InternalRow

import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

package org.apache.spark.sql.expressions

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn}

/**
* A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]]
Expand All @@ -32,55 +31,65 @@ import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
* case class Data(i: Int)
*
* val customSummer = new Aggregator[Data, Int, Int] {
* def zero = 0
* def reduce(b: Int, a: Data) = b + a.i
* def present(r: Int) = r
* def zero: Int = 0
* def reduce(b: Int, a: Data): Int = b + a.i
* def merge(b1: Int, b2: Int): Int = b1 + b2
* def present(r: Int): Int = r
* }.toColumn()
*
* val ds: Dataset[Data]
* val ds: Dataset[Data] = ...
* val aggregated = ds.select(customSummer)
* }}}
*
* Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
*
* @tparam A The input type for the aggregation.
* @tparam I The input type for the aggregation.
* @tparam B The type of the intermediate value of the reduction.
* @tparam C The type of the final result.
* @tparam O The type of the final output result.
*
* @since 1.6.0
*/
abstract class Aggregator[-A, B, C] extends Serializable {
abstract class Aggregator[-I, B, O] extends Serializable {

/** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
/**
* A zero value for this aggregation. Should satisfy the property that any b + zero = b.
* @since 1.6.0
*/
def zero: B

/**
* Combine two values to produce a new value. For performance, the function may modify `b` and
* return it instead of constructing new object for b.
* @since 1.6.0
*/
def reduce(b: B, a: A): B
def reduce(b: B, a: I): B

/**
* Merge two intermediate values
* Merge two intermediate values.
* @since 1.6.0
*/
def merge(b1: B, b2: B): B

/**
* Transform the output of the reduction.
* @since 1.6.0
*/
def finish(reduction: B): C
def finish(reduction: B): O

/**
* Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]]
* operations.
* @since 1.6.0
*/
def toColumn(
implicit bEncoder: Encoder[B],
cEncoder: Encoder[C]): TypedColumn[A, C] = {
cEncoder: Encoder[O]): TypedColumn[I, O] = {
val expr =
new AggregateExpression(
TypedAggregateExpression(this),
Complete,
false)

new TypedColumn[A, C](expr, encoderFor[C])
new TypedColumn[I, O](expr, encoderFor[O])
}
}