diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 87dce7146b305..0266120d0ed47 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1070,6 +1070,35 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
       proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
   }
 
+  /**
+   * Groups the Dataset using the specified columns, so that we can run aggregation on them. See
+   * [[RelationalGroupedDataset]] for all the available aggregate functions.
+   *
+   * This is a variant of groupBy that can only group by existing columns using column names (i.e.
+   * cannot construct expressions).
+   *
+   * {{{
+   *   // Compute the average for all numeric columns grouped by department.
+   *   ds.groupBy("department").avg()
+   *
+   *   // Compute the max age and average salary, grouped by department and gender.
+   *   ds.groupBy($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def groupBy(col1: String, cols: String*): RelationalGroupedDataset = {
+    val colNames: Seq[String] = col1 +: cols
+    new RelationalGroupedDataset(
+      toDF(),
+      colNames.map(colName => Column(colName).expr),
+      proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+  }
+
   /**
    * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we
    * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
@@ -1990,14 +2019,14 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
       viewName: String,
       replace: Boolean,
       global: Boolean): Unit = {
-    val command = session.newCommand { builder =>
+    val command = sparkSession.newCommand { builder =>
       builder.getCreateDataframeViewBuilder
         .setInput(plan.getRoot)
         .setName(viewName)
         .setIsGlobal(global)
         .setReplace(replace)
     }
-    session.execute(command)
+    sparkSession.execute(command)
   }
 
   /**
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 48d5f0cb409d0..6a789b1494f21 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -1892,6 +1892,12 @@ class PlanGenerationTestSuite
       "a" -> "count")
   }
 
+  test("groupby agg string") {
+    simple
+      .groupBy("id", "b")
+      .agg("a" -> "max", "a" -> "count")
+  }
+
   test("groupby agg columns") {
     simple
       .groupBy(Column("id"))
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg_string.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg_string.explain
new file mode 100644
index 0000000000000..1c2b2f68c64c6
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg_string.explain
@@ -0,0 +1,2 @@
+Aggregate [id#0L, b#0], [id#0L, b#0, max(a#0) AS max(a)#0, count(a#0) AS count(a)#0L]
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.json b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.json
new file mode 100644
index 0000000000000..26320d404835f
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.json
@@ -0,0 +1,46 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "aggregate": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_GROUPBY",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "id"
+      }
+    }, {
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "b"
+      }
+    }],
+    "aggregateExpressions": [{
+      "unresolvedFunction": {
+        "functionName": "max",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "a",
+            "planId": "0"
+          }
+        }]
+      }
+    }, {
+      "unresolvedFunction": {
+        "functionName": "count",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "a",
+            "planId": "0"
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.proto.bin
new file mode 100644
index 0000000000000..818146f7f6935
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.proto.bin differ
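A minimal usage sketch of the new `groupBy(col1: String, cols: String*)` overload, mirroring the `groupby agg string` test above. It is shown against a regular `SparkSession` for brevity; the `df` DataFrame and its `id`/`a`/`b` columns are illustrative assumptions, not part of the patch.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

// A small DataFrame with the same column names and types as the test suite's `simple` relation.
val df = Seq((1L, 10, 1.0), (1L, 20, 2.0), (2L, 5, 1.0)).toDF("id", "a", "b")

// Group by plain column names (no Column expressions), then aggregate by column -> function name.
df.groupBy("id", "b")
  .agg("a" -> "max", "a" -> "count")
  .show()
```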