@@ -21,7 +21,7 @@
import java.util.Iterator;

/**
* Base interface for a map function used in GroupedDataset's map function.
* Base interface for a map function used in GroupedDataset's mapGroup function.
*/
public interface MapGroupFunction<K, V, R> extends Serializable {
R call(K key, Iterator<V> values) throws Exception;
@@ -30,6 +30,8 @@ import org.apache.spark.sql.types._
*
* Encoders are not intended to be thread-safe and thus they are allowed to avoid internal locking
* and reuse internal buffers to improve performance.
*
* @since 1.6.0
*/
trait Encoder[T] extends Serializable {

@@ -42,6 +44,8 @@ trait Encoder[T] extends Serializable {

/**
* Methods for creating encoders.
*
* @since 1.6.0
*/
object Encoders {

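For context, a usage sketch that is not part of this diff: in Scala, encoders usually come from the `sqlContext.implicits._` import, while the factory methods on `Encoders` (e.g. `Encoders.STRING()`, `Encoders.INT()`) are used when an encoder has to be passed explicitly, as the Java suites later in this diff do. The `sqlContext` value is assumed to be in scope.

import sqlContext.implicits._              // brings implicit encoders for common Scala types into scope
val words = Seq("a", "foo", "bar").toDS()  // Encoder[String] is resolved implicitly here
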
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types._
/**
* Type-inference utilities for POJOs and Java collections.
*/
private [sql] object JavaTypeInference {
object JavaTypeInference {

private val iterableType = TypeToken.of(classOf[JIterable[_]])
private val mapType = TypeToken.of(classOf[JMap[_, _]])
@@ -53,7 +53,6 @@ private [sql] object JavaTypeInference {
* @return (SQL data type, nullable)
*/
private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
typeToken.getRawType match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
2 changes: 2 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -46,6 +46,8 @@ private[sql] object Column {
* @tparam T The input type expected for this expression. Can be `Any` if the expression is type
* checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
* @tparam U The output type of this column.
*
* @since 1.6.0
*/
class TypedColumn[-T, U](
expr: Expression,
@@ -110,7 +110,6 @@ private[sql] object DataFrame {
* @groupname action Actions
* @since 1.3.0
*/
// TODO: Improve documentation.
@Experimental
class DataFrame private[sql](
@transient val sqlContext: SQLContext,
132 changes: 106 additions & 26 deletions sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Ou
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.Aggregator

/**
* :: Experimental ::
@@ -36,11 +37,13 @@ import org.apache.spark.sql.execution.QueryExecution
* making this change to the class hierarchy would break some function signatures. As such, this
* class should be considered a preview of the final API. Changes will be made to the interface
* after Spark 1.6.
*
* @since 1.6.0
*/
@Experimental
class GroupedDataset[K, T] private[sql](
class GroupedDataset[K, V] private[sql](
kEncoder: Encoder[K],
tEncoder: Encoder[T],
tEncoder: Encoder[V],
Review comment (Contributor): Should this variable be renamed to vEncoder?

Reply (Contributor, author): yea that's a good catch. will fix it in my next pr.

val queryExecution: QueryExecution,
private val dataAttributes: Seq[Attribute],
private val groupingAttributes: Seq[Attribute]) extends Serializable {
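
For orientation only (not part of this diff): a GroupedDataset is normally obtained from `Dataset.groupBy`, and the typed operations below then run once per grouping key. A minimal sketch, assuming `sqlContext.implicits._` is in scope as in the suites further down:

val ds = Seq("a", "foo", "bar").toDS()
val byLength: GroupedDataset[Int, String] = ds.groupBy(_.length)  // key = word length, value = the word
val justKeys: Dataset[Int] = byLength.keys                        // Dataset of the distinct keys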
@@ -67,8 +70,10 @@
/**
* Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
* type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
*
* @since 1.6.0
*/
def asKey[L : Encoder]: GroupedDataset[L, T] =
def keyAs[L : Encoder]: GroupedDataset[L, V] =
new GroupedDataset(
encoderFor[L],
unresolvedTEncoder,
@@ -78,6 +83,8 @@

/**
* Returns a [[Dataset]] that contains each unique key.
*
* @since 1.6.0
*/
def keys: Dataset[K] = {
new Dataset[K](
@@ -92,12 +99,18 @@
* function can return an iterator containing elements of an arbitrary type which will be returned
* as a new [[Dataset]].
*
* This function does not support partial aggregation, and as a result requires shuffling all
* the data in the [[Dataset]]. If an application intends to perform an aggregation over each
* key, it is best to use the reduce function or an [[Aggregator]].
*
* Internally, the implementation will spill to disk if any given group is too large to fit into
* memory. However, users must take care to avoid materializing the whole iterator for a group
* (for example, by calling `toList`) unless they are sure that this is possible given the memory
* constraints of their cluster.
*
* @since 1.6.0
*/
def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
new Dataset[U](
sqlContext,
MapGroups(
@@ -108,41 +121,88 @@
logicalPlan))
}
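
An illustrative sketch of the Scala overload above (data and names are made up; assumes `sqlContext.implicits._`): each group may emit zero or more output rows.

val tagged: Dataset[String] =
  Seq("a", "foo", "bar").toDS()
    .groupBy(_.length)
    .flatMapGroup((len, words) => words.map(w => s"$len:$w"))
// yields "1:a", "3:foo", "3:bar" (ordering across groups is not guaranteed)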

def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder)
/**
* Applies the given function to each group of data. For each unique group, the function will
* be passed the group key and an iterator that contains all of the elements in the group. The
* function can return an iterator containing elements of an arbitrary type which will be returned
* as a new [[Dataset]].
*
* This function does not support partial aggregation, and as a result requires shuffling all
* the data in the [[Dataset]]. If an application intends to perform an aggregation over each
* key, it is best to use the reduce function or an [[Aggregator]].
*
* Internally, the implementation will spill to disk if any given group is too large to fit into
* memory. However, users must take care to avoid materializing the whole iterator for a group
* (for example, by calling `toList`) unless they are sure that this is possible given the memory
* constraints of their cluster.
*
* @since 1.6.0
*/
def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder)
}

/**
* Applies the given function to each group of data. For each unique group, the function will
* be passed the group key and an iterator that contains all of the elements in the group. The
* function can return an element of arbitrary type which will be returned as a new [[Dataset]].
*
* This function does not support partial aggregation, and as a result requires shuffling all
* the data in the [[Dataset]]. If an application intends to perform an aggregation over each
* key, it is best to use the reduce function or an [[Aggregator]].
*
* Internally, the implementation will spill to disk if any given group is too large to fit into
* memory. However, users must take care to avoid materializing the whole iterator for a group
* (for example, by calling `toList`) unless they are sure that this is possible given the memory
* constraints of their cluster.
*
* @since 1.6.0
*/
def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
flatMap(func)
def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
val func = (key: K, it: Iterator[V]) => Iterator(f(key, it))
flatMapGroup(func)
}
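
A corresponding sketch for mapGroup (illustrative; assumes `sqlContext.implicits._`): exactly one output row per key, mirroring the Java suite further down.

val concatenated: Dataset[(Int, String)] =
  Seq("a", "foo", "bar").toDS()
    .groupBy(_.length)
    .mapGroup((len, words) => (len, words.mkString))
// one row per key, e.g. (1, "a") and (3, "foobar")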

def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
map((key, data) => f.call(key, data.asJava))(encoder)
/**
* Applies the given function to each group of data. For each unique group, the function will
* be passed the group key and an iterator that contains all of the elements in the group. The
* function can return an element of arbitrary type which will be returned as a new [[Dataset]].
*
* This function does not support partial aggregation, and as a result requires shuffling all
* the data in the [[Dataset]]. If an application intends to perform an aggregation over each
* key, it is best to use the reduce function or an [[Aggregator]].
*
* Internally, the implementation will spill to disk if any given group is too large to fit into
* memory. However, users must take care to avoid materializing the whole iterator for a group
* (for example, by calling `toList`) unless they are sure that this is possible given the memory
* constraints of their cluster.
*
* @since 1.6.0
*/
def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
mapGroup((key, data) => f.call(key, data.asJava))(encoder)
}

/**
* Reduces the elements of each group of data using the specified binary function.
* The given function must be commutative and associative or the result may be non-deterministic.
*
* @since 1.6.0
*/
def reduce(f: (T, T) => T): Dataset[(K, T)] = {
val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f))
def reduce(f: (V, V) => V): Dataset[(K, V)] = {
val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))

implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder)
flatMap(func)
flatMapGroup(func)
}
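
A small sketch of the reduce path (illustrative; assumes `sqlContext.implicits._`), which is preferable to mapGroup when the per-key result is a plain fold over the values:

val sums: Dataset[(Int, Int)] =
  Seq(1, 2, 3, 4, 5).toDS()
    .groupBy(_ % 2)   // key 0 for even values, 1 for odd
    .reduce(_ + _)    // yields (0, 6) and (1, 9)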

def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = {
/**
* Reduces the elements of each group of data using the specified binary function.
* The given function must be commutative and associative or the result may be non-deterministic.
*
* @since 1.6.0
*/
def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = {
reduce(f.call _)
}

@@ -185,41 +245,51 @@
/**
* Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
* and the result of computing this aggregation over all elements in the group.
*
* @since 1.6.0
*/
def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] =
def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] =
aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]

/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
*
* @since 1.6.0
*/
def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] =
def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] =
aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]

/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
*
* @since 1.6.0
*/
def agg[U1, U2, U3](
col1: TypedColumn[T, U1],
col2: TypedColumn[T, U2],
col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] =
col1: TypedColumn[V, U1],
col2: TypedColumn[V, U2],
col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] =
aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]

/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
*
* @since 1.6.0
*/
def agg[U1, U2, U3, U4](
col1: TypedColumn[T, U1],
col2: TypedColumn[T, U2],
col3: TypedColumn[T, U3],
col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] =
col1: TypedColumn[V, U1],
col2: TypedColumn[V, U2],
col3: TypedColumn[V, U3],
col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] =
aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]
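
A sketch of the typed-aggregation path (illustrative; assumes `sqlContext.implicits._` and that `Column.as[U]` with an implicit encoder is available, per the TypedColumn change above):

import org.apache.spark.sql.functions.count
val grouped = Seq("a", "foo", "bar").toDS().groupBy(_.length)
val counted: Dataset[(Int, Long)] = grouped.agg(count("*").as[Long])  // same result as grouped.count() below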

/**
* Returns a [[Dataset]] that contains a tuple with each key and the number of items present
* for that key.
*
* @since 1.6.0
*/
def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]))

Expand All @@ -228,10 +298,12 @@ class GroupedDataset[K, T] private[sql](
* be passed the grouping key and 2 iterators containing all elements in the group from
* [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an
* arbitrary type which will be returned as a new [[Dataset]].
*
* @since 1.6.0
*/
def cogroup[U, R : Encoder](
other: GroupedDataset[K, U])(
f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
implicit def uEnc: Encoder[U] = other.unresolvedTEncoder
new Dataset[R](
sqlContext,
@@ -243,9 +315,17 @@
other.logicalPlan))
}
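
An illustrative cogroup sketch (data and names are made up; assumes `sqlContext.implicits._`): both sides are grouped on the same key type and the function receives one iterator per side.

val left  = Seq(1 -> "a", 1 -> "b", 2 -> "c").toDS().groupBy(_._1)
val right = Seq(1 -> "x", 3 -> "y").toDS().groupBy(_._1)
val merged: Dataset[String] = left.cogroup(right) { (key, l, r) =>
  Iterator(s"$key: " + (l ++ r).map(_._2).mkString(","))
}
// e.g. "1: a,b,x", "2: c", "3: y"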

/**
* Applies the given function to each cogrouped group of data. For each unique group, the function will
* be passed the grouping key and 2 iterators containing all elements in the group from
* [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an
* arbitrary type which will be returned as a new [[Dataset]].
*
* @since 1.6.0
*/
def cogroup[U, R](
other: GroupedDataset[K, U],
f: CoGroupFunction[K, T, U, R],
f: CoGroupFunction[K, V, U, R],
encoder: Encoder[R]): Dataset[R] = {
cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
}
@@ -170,7 +170,7 @@ public Integer call(String v) throws Exception {
}
}, Encoders.INT());

Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
Dataset<String> mapped = grouped.mapGroup(new MapGroupFunction<Integer, String, String>() {
@Override
public String call(Integer key, Iterator<String> values) throws Exception {
StringBuilder sb = new StringBuilder(key.toString());
@@ -183,7 +183,7 @@

Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());

Dataset<String> flatMapped = grouped.flatMap(
Dataset<String> flatMapped = grouped.flatMapGroup(
new FlatMapGroupFunction<Integer, String, String>() {
@Override
public Iterable<String> call(Integer key, Iterator<String> values) throws Exception {
@@ -247,9 +247,9 @@ public void testGroupByColumn() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
GroupedDataset<Integer, String> grouped =
ds.groupBy(length(col("value"))).asKey(Encoders.INT());
ds.groupBy(length(col("value"))).keyAs(Encoders.INT());

Dataset<String> mapped = grouped.map(
Dataset<String> mapped = grouped.mapGroup(
new MapGroupFunction<Integer, String, String>() {
@Override
public String call(Integer key, Iterator<String> data) throws Exception {
@@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("groupBy function, map") {
val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
val grouped = ds.groupBy(_ % 2)
val agged = grouped.map { case (g, iter) =>
val agged = grouped.mapGroup { case (g, iter) =>
val name = if (g == 0) "even" else "odd"
(name, iter.size)
}
@@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("groupBy function, flatMap") {
val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
val grouped = ds.groupBy(_.length)
val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) }
val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) }

checkAnswer(
agged,