From 67b99fc8a695db3d2a019ad8adbdd3eba9572e5f Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Wed, 22 Feb 2023 10:09:27 -0800 Subject: [PATCH] [SPARK-42468][CONNECT][FOLLOW-UP] Add agg in Dataset. --- .../scala/org/apache/spark/sql/Dataset.scala | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 3c876c054320c..de7cd93d684ab 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1058,6 +1058,61 @@ class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan: new RelationalGroupedDataset(toDF(), cols.map(_.expr)) } + /** + * (Scala-specific) Aggregates on the entire Dataset without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg("age" -> "max", "salary" -> "avg") + * ds.groupBy().agg("age" -> "max", "salary" -> "avg") + * }}} + * + * @group untypedrel + * @since 3.4.0 + */ + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { + groupBy().agg(aggExpr, aggExprs: _*) + } + + /** + * (Scala-specific) Aggregates on the entire Dataset without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(Map("age" -> "max", "salary" -> "avg")) + * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }}} + * + * @group untypedrel + * @since 3.4.0 + */ + def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) + + /** + * (Java-specific) Aggregates on the entire Dataset without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(Map("age" -> "max", "salary" -> "avg")) + * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }}} + * + * @group untypedrel + * @since 3.4.0 + */ + def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) + + /** + * Aggregates on the entire Dataset without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(max($"age"), avg($"salary")) + * ds.groupBy().agg(max($"age"), avg($"salary")) + * }}} + * + * @group untypedrel + * @since 3.4.0 + */ + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs: _*) + /** * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,