@@ -44,8 +44,8 @@ message CreateScalarFunction {
   repeated string parts = 1;
   FunctionLanguage language = 2;
   bool temporary = 3;
-  repeated Type argument_types = 4;
-  Type return_type = 5;

   // How the function body is defined:
   oneof function_definition {
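+  repeated DataType argument_types = 4;
+  DataType return_type = 5;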
@@ -65,10 +65,10 @@ message Expression {
     // Timestamp in units of microseconds since the UNIX epoch.
     int64 timestamp_tz = 27;
     bytes uuid = 28;
-    Type null = 29; // a typed null literal
+    DataType null = 29; // a typed null literal
     List list = 30;
-    Type.List empty_list = 31;
-    Type.Map empty_map = 32;
+    DataType.List empty_list = 31;
+    DataType.Map empty_map = 32;
     UserDefined user_defined = 33;
   }

@@ -164,5 +164,6 @@ message Expression {
   // by the analyzer.
   message QualifiedAttribute {
     string name = 1;
+    DataType type = 2;
   }
 }
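As a quick illustration of the new field: with this change a client can attach a type when it builds a QualifiedAttribute. A minimal sketch, assuming standard protobuf-java builder conventions (the attribute name "id" is illustrative); the test changes at the end of this PR build the same message via DataTypeProtoConverter:

import org.apache.spark.connect.proto

// Build a QualifiedAttribute named "id" typed as i32; setType also accepts
// a DataType.Builder, which protobuf-java builds implicitly.
val attr = proto.Expression.QualifiedAttribute.newBuilder()
  .setName("id")
  .setType(proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance))
  .build()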
18 changes: 2 additions & 16 deletions connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -130,22 +130,8 @@ message Fetch {
 // Relation of type [[Aggregate]].
 message Aggregate {
   Relation input = 1;
-
-  // Grouping sets are used in rollups
-  repeated GroupingSet grouping_sets = 2;
-
-  // Measures
-  repeated Measure measures = 3;
-
-  message GroupingSet {
-    repeated Expression aggregate_expressions = 1;
-  }
-
-  message Measure {
-    AggregateFunction function = 1;
-    // Conditional filter for SUM(x FILTER WHERE x < 10)
-    Expression filter = 2;
-  }
+  repeated Expression grouping_expressions = 2;
+  repeated AggregateFunction result_expressions = 3;
@amaliujia (Contributor, Author) commented on Oct 10, 2022:
@cloud-fan I still keep this as AggregateFunction; proto.Expression is too general a type for now.

Connect does not have a NamedExpression yet. I will follow up on this to improve it.

This PR is about improving the grouping_expressions anyway.
A Contributor replied:
Follow-up improvement SGTM. I don't think we even need AggregateFunction. The SQL parser usually just generates UnresolvedFunction, and the analyzer looks up the function and figures out whether it's a scalar/aggregate/window/table-valued function.
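For context, a minimal sketch of the approach the reviewer describes, using Catalyst's UnresolvedFunction with the same apply signature the planner below already uses (the attribute name and surrounding wiring are illustrative):

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.Expression

// The parser records only the function name and its arguments; it does not
// decide whether "sum" is a scalar, aggregate, window, or table-valued function.
val unresolved: Expression = UnresolvedFunction(
  name = "sum",
  arguments = Seq(UnresolvedAttribute("id")),
  isDistinct = false)

// The analyzer's function-resolution rule later looks "sum" up in the function
// registry and replaces this node with the concrete aggregate expression.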


   message AggregateFunction {
     string name = 1;
12 changes: 6 additions & 6 deletions connector/connect/src/main/protobuf/spark/connect/types.proto
@@ -22,9 +22,9 @@ package spark.connect;
 option java_multiple_files = true;
 option java_package = "org.apache.spark.connect.proto";

-// This message describes the logical [[Type]] of something. It does not carry the value
+// This message describes the logical [[DataType]] of something. It does not carry the value
 // itself but only describes it.
-message Type {
+message DataType {
   oneof kind {
     Boolean bool = 1;
     I8 i8 = 2;
@@ -168,20 +168,20 @@ message Type {
   }

   message Struct {
-    repeated Type types = 1;
+    repeated DataType types = 1;
     uint32 type_variation_reference = 2;
     Nullability nullability = 3;
   }

   message List {
-    Type type = 1;
+    DataType DataType = 1;
     uint32 type_variation_reference = 2;
     Nullability nullability = 3;
   }

   message Map {
-    Type key = 1;
-    Type value = 2;
+    DataType key = 1;
+    DataType value = 2;
     uint32 type_variation_reference = 3;
     Nullability nullability = 4;
   }
@@ -67,6 +67,19 @@ package object dsl {
       }
       relation.setJoin(join).build()
     }
+
+    def groupBy(
+        groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = {
+      val agg = proto.Aggregate.newBuilder()
+      agg.setInput(logicalPlan)
+
+      for (groupingExpr <- groupingExprs) {
+        agg.addGroupingExpressions(groupingExpr)
+      }
+      // TODO: support aggregateExprs, which is blocked on the analyzer supporting
+      // built-in function resolution by name.
+      proto.Relation.newBuilder().setAggregate(agg.build()).build()
+    }
     }
   }
 }
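A usage sketch for the new DSL method; readRel stands in for any existing proto.Relation (hypothetical here), and "id".protoAttr comes from the expressions DSL exercised in SparkConnectProtoSuite below:

import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._

// Group readRel by two attributes; the second argument list stays empty
// because aggregate expressions are not wired through yet (see the TODO above).
val grouped = readRel.groupBy("id".protoAttr, "name".protoAttr)()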
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+
+/**
+ * This object offers methods to convert between Connect proto data types and Catalyst types.
+ */
+object DataTypeProtoConverter {
+  def toCatalystType(t: proto.DataType): DataType = {
+    t.getKindCase match {
+      case proto.DataType.KindCase.I32 => IntegerType
+      case proto.DataType.KindCase.STRING => StringType
+      case _ =>
+        throw InvalidPlanInput(s"Does not support converting ${t.getKindCase} to catalyst types.")
+    }
+  }
+
+  def toConnectProtoType(t: DataType): proto.DataType = {
+    t match {
+      case IntegerType =>
+        proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build()
+      case StringType =>
+        proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build()
+      case _ =>
+        throw InvalidPlanInput(s"Does not support converting ${t.typeName} to connect proto types.")
+    }
+  }
+}
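A minimal round-trip sketch of the converter as added here (only IntegerType and StringType are supported at this point; anything else throws InvalidPlanInput):

import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.planner.DataTypeProtoConverter
import org.apache.spark.sql.types.IntegerType

// Catalyst -> proto: IntegerType maps to a DataType message with the i32 kind set.
val protoInt: proto.DataType = DataTypeProtoConverter.toConnectProtoType(IntegerType)
assert(protoInt.getKindCase == proto.DataType.KindCase.I32)

// proto -> Catalyst: the same message converts back to IntegerType.
assert(DataTypeProtoConverter.toCatalystType(protoInt) == IntegerType)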
@@ -77,8 +77,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
   }

   private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = {
-    // TODO: use data type from the proto.
-    AttributeReference(exp.getName, IntegerType)()
+    AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))()
   }

   private def transformReadRel(
@@ -271,11 +270,9 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {

   private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
     assert(rel.hasInput)
-    assert(rel.getGroupingSetsCount == 1, "Only one grouping set is supported")

-    val groupingSet = rel.getGroupingSetsList.asScala.take(1)
-    val ge = groupingSet
-      .flatMap(f => f.getAggregateExpressionsList.asScala)
+    val groupingExprs =
+      rel.getGroupingExpressionsList.asScala
       .map(transformExpression)
       .map {
         case x @ UnresolvedAttribute(_) => x
@@ -284,18 +281,18 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {

     logical.Aggregate(
       child = transformRelation(rel.getInput),
-      groupingExpressions = ge.toSeq,
+      groupingExpressions = groupingExprs.toSeq,
       aggregateExpressions =
-        (rel.getMeasuresList.asScala.map(transformAggregateExpression) ++ ge).toSeq)
+        rel.getResultExpressionsList.asScala.map(transformAggregateExpression).toSeq)
   }

   private def transformAggregateExpression(
-      exp: proto.Aggregate.Measure): expressions.NamedExpression = {
-    val fun = exp.getFunction.getName
+      exp: proto.Aggregate.AggregateFunction): expressions.NamedExpression = {
+    val fun = exp.getName
     UnresolvedAlias(
       UnresolvedFunction(
         name = fun,
-        arguments = exp.getFunction.getArgumentsList.asScala.map(transformExpression).toSeq,
+        arguments = exp.getArgumentsList.asScala.map(transformExpression).toSeq,
         isDistinct = false))
   }

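To make the new mapping concrete, a hedged sketch of the Catalyst plan transformAggregate now produces for a proto Aggregate with grouping expression id and result expression sum(id) (the relation name t is illustrative):

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation}
import org.apache.spark.sql.catalyst.plans.logical

// Grouping expressions pass through as unresolved attributes; each proto
// AggregateFunction becomes UnresolvedAlias(UnresolvedFunction(...)) for the
// analyzer to resolve.
val plan = logical.Aggregate(
  groupingExpressions = Seq(UnresolvedAttribute("id")),
  aggregateExpressions = Seq(UnresolvedAlias(
    UnresolvedFunction(name = "sum", arguments = Seq(UnresolvedAttribute("id")), isDistinct = false))),
  child = UnresolvedRelation(Seq("t")))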
@@ -45,11 +45,6 @@ trait SparkConnectPlanTest {
     .build()
   }

-trait SparkConnectSessionTest {
-  protected var spark: SparkSession
-
-}
-
 /**
  * This is a rudimentary test class for SparkConnect. The main goal of these basic tests is to
  * ensure that the transformation from Proto to LogicalPlan works and that the right nodes are
@@ -222,16 +217,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {

     val agg = proto.Aggregate.newBuilder
       .setInput(readRel)
-      .addAllMeasures(
-        Seq(
-          proto.Aggregate.Measure.newBuilder
-            .setFunction(proto.Aggregate.AggregateFunction.newBuilder
-              .setName("sum")
-              .addArguments(unresolvedAttribute))
-            .build()).asJava)
-      .addGroupingSets(proto.Aggregate.GroupingSet.newBuilder
-        .addAggregateExpressions(unresolvedAttribute)
-        .build())
+      .addResultExpressions(
+        proto.Aggregate.AggregateFunction.newBuilder
+          .setName("sum")
+          .addArguments(unresolvedAttribute))
+      .addGroupingExpressions(unresolvedAttribute)
       .build()

     val res = transform(proto.Relation.newBuilder.setAggregate(agg).build())
@@ -31,11 +31,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
  */
 class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {

-  lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int))
+  lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, $"name".string))

   lazy val connectTestRelation2 = createLocalRelationProto(Seq($"key".int, $"value".int))

-  lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int)
+  lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, $"name".string)

   lazy val sparkTestRelation2: LocalRelation = LocalRelation($"key".int, $"value".int)

@@ -81,12 +81,23 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     }
   }

+  test("Aggregate with more than 1 grouping expressions") {
+    val connectPlan = {
+      import org.apache.spark.sql.connect.dsl.expressions._
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)())
+    }
+    val sparkPlan = sparkTestRelation.groupBy($"id", $"name")()
+    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+  }
+
   private def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
-    // TODO: set data types for each local relation attribute one proto supports data type.
     for (attr <- attrs) {
       localRelationBuilder.addAttributes(
-        proto.Expression.QualifiedAttribute.newBuilder().setName(attr.name).build()
+        proto.Expression.QualifiedAttribute.newBuilder()
+          .setName(attr.name)
+          .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))
       )
     }
     proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build()