@@ -230,6 +230,19 @@ object AppendColumns {
      encoderFor[U].namedExpressions,
      child)
  }

  def apply[T : Encoder, U : Encoder](
Contributor: Here you use `T : Encoder`, i.e. with spaces before and after `:`, while...

      func: T => U,
      inputAttributes: Seq[Attribute],
      child: LogicalPlan): AppendColumns = {
    new AppendColumns(
      func.asInstanceOf[Any => Any],
      implicitly[Encoder[T]].clsTag.runtimeClass,
      implicitly[Encoder[T]].schema,
      UnresolvedDeserializer(encoderFor[T].deserializer, inputAttributes),
      encoderFor[U].namedExpressions,
      child)
  }
}
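As an aside on the `T : Encoder` syntax the review comment above flags: a context bound is just sugar for an implicit parameter, which the method body recovers with `implicitly`, as the new `apply` does. A minimal, self-contained sketch of that pattern, using a hypothetical `Enc` type class standing in for Spark's `Encoder` (all names here are illustrative, not part of the patch):

```scala
// Toy stand-in for Spark's Encoder; `Enc` and `describe` are illustrative only.
trait Enc[T] { def schema: String }

object Enc {
  // Found via the implicit scope of Enc's companion, so no import is needed.
  implicit val intEnc: Enc[Int]       = new Enc[Int] { def schema = "int" }
  implicit val stringEnc: Enc[String] = new Enc[String] { def schema = "string" }
}

object ContextBoundDemo {
  // `[T : Enc, U : Enc]` desugars to `(implicit ev1: Enc[T], ev2: Enc[U])`.
  def describe[T : Enc, U : Enc](func: T => U): String =
    s"${implicitly[Enc[T]].schema} => ${implicitly[Enc[U]].schema}"

  def main(args: Array[String]): Unit =
    println(describe((s: String) => s.length)) // prints: string => int
}
```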

@@ -66,6 +66,48 @@ class KeyValueGroupedDataset[K, V] private[sql](
    dataAttributes,
    groupingAttributes)

  /**
   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
   * to the data. The grouping key is unchanged by this.
   *
   * {{{
   *   // Create values grouped by key from a Dataset[(K, V)]
   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
   * }}}
   *
   * @since 2.1.0
   */
  def mapValues[W : Encoder](func: V => W): KeyValueGroupedDataset[K, W] = {
    val withNewData = AppendColumns(func, dataAttributes, logicalPlan)
    val projected = Project(withNewData.newColumns ++ groupingAttributes, withNewData)
    val executed = sparkSession.sessionState.executePlan(projected)

    new KeyValueGroupedDataset(
      encoderFor[K],
      encoderFor[W],
      executed,
      withNewData.newColumns,
      groupingAttributes)
  }
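For context, a hedged end-to-end sketch of how this Scala overload is meant to be used, mirroring the test added below. The `MapValuesDemo` object name and the `local[*]` master are illustrative assumptions, not part of the patch:

```scala
import org.apache.spark.sql.SparkSession

object MapValuesDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("MapValuesDemo").getOrCreate()
    import spark.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

    // Project the Int payload once with mapValues instead of repeating `._2`
    // in every downstream aggregation; the String grouping key is untouched.
    val summed = ds.groupByKey(_._1)
      .mapValues(_._2)
      .reduceGroups(_ + _)

    summed.collect().foreach(println) // (a,30), (b,3), (c,1), in some order
    spark.stop()
  }
}
```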

  /**
   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
   * to the data. The grouping key is unchanged by this.
   *
   * {{{
   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
   *   Dataset<Tuple2<String, Integer>> ds = ...;
   *   KeyValueGroupedDataset<String, Integer> grouped =
   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT()); // Java 8
   * }}}
   *
   * @since 2.1.0
   */
  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
    implicit val uEnc = encoder
    mapValues { (v: V) => func.call(v) }
  }
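And a sketch of the Java-friendly overload exercised from Scala with an explicit anonymous `MapFunction`, so the two signatures can be compared side by side. Again illustrative names under the same local-session assumption:

```scala
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.{Encoders, SparkSession}

object MapValuesJavaStyleDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("MapValuesJavaStyleDemo").getOrCreate()
    import spark.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()

    // Encoders are passed explicitly, exactly as a Java caller would do.
    val grouped = ds.groupByKey((t: (String, Int)) => t._1)(Encoders.STRING)
    val values = grouped.mapValues(
      new MapFunction[(String, Int), java.lang.Integer] {
        override def call(t: (String, Int)): java.lang.Integer = t._2
      },
      Encoders.INT)

    values.count().show() // one (key, count) row per key
    spark.stop()
  }
}
```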

  /**
   * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping
   * over the Dataset to extract the keys and then running a distinct operation on those.
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -336,6 +336,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
"a", "30", "b", "3", "c", "1")
}

test("groupBy function, mapValues, flatMap") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
Contributor: Just .toDS? (no brackets)

Contributor (author): It seems the other tests all use toDS(), so I will stick to that convention.

    val keyValue = ds.groupByKey(_._1).mapValues(_._2)
    val agged = keyValue.mapGroups { case (g, iter) => (g, iter.sum) }
    checkDataset(agged, ("a", 30), ("b", 3), ("c", 1))

    val keyValue1 = ds.groupByKey(t => (t._1, "key")).mapValues(t => (t._2, "value"))
    val agged1 = keyValue1.mapGroups { case (g, iter) => (g._1, iter.map(_._1).sum) }
    checkDataset(agged1, ("a", 30), ("b", 3), ("c", 1))
  }

test("groupBy function, reduce") {
val ds = Seq("abc", "xyz", "hello").toDS()
val agged = ds.groupByKey(_.length).reduceGroups(_ + _)