From 2f08a567ae03f0f489a3f36875785d870806a63b Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 14 Oct 2022 14:00:12 +0200 Subject: [PATCH 1/5] enforce scalafmt --- .../sql/connect/SparkConnectPlugin.scala | 3 +- .../command/SparkConnectCommandPlanner.scala | 9 ++- .../spark/sql/connect/dsl/package.scala | 67 +++++++++++-------- .../connect/planner/SparkConnectPlanner.scala | 33 ++++----- .../connect/service/SparkConnectService.scala | 5 +- .../service/SparkConnectStreamHandler.scala | 1 - .../planner/SparkConnectPlannerSuite.scala | 14 ++-- .../planner/SparkConnectProtoSuite.scala | 18 ++--- dev/lint-scala | 10 +++ dev/scalafmt | 2 +- pom.xml | 3 +- 11 files changed, 96 insertions(+), 69 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala index 7ac33fa9324ac..4ecbfd123f0d2 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala @@ -39,7 +39,8 @@ class SparkConnectPlugin extends SparkPlugin { /** * Return the plugin's driver-side component. * - * @return The driver-side component. + * @return + * The driver-side component. */ override def driverPlugin(): DriverPlugin = new DriverPlugin { diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala index ebc5cfe5b55b7..36f256120c64b 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.types.StringType - @Unstable @Since("3.4.0") class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) { @@ -47,10 +46,10 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) /** * This is a helper function that registers a new Python function in the SparkSession. * - * Right now this function is very rudimentary and bare-bones just to showcase how it - * is possible to remotely serialize a Python function and execute it on the Spark cluster. - * If the Python version on the client and server diverge, the execution of the function that - * is serialized will most likely fail. + * Right now this function is very rudimentary and bare-bones just to showcase how it is + * possible to remotely serialize a Python function and execute it on the Spark cluster. If the + * Python version on the client and server diverge, the execution of the function that is + * serialized will most likely fail. * * @param cf */ diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 0db8ab9661074..29e3530922099 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -34,59 +34,70 @@ package object dsl { val identifier = CatalystSqlParser.parseMultipartIdentifier(s) def protoAttr: proto.Expression = - proto.Expression.newBuilder() + proto.Expression + .newBuilder() .setUnresolvedAttribute( - proto.Expression.UnresolvedAttribute.newBuilder() + proto.Expression.UnresolvedAttribute + .newBuilder() .addAllParts(identifier.asJava) .build()) .build() } implicit class DslExpression(val expr: proto.Expression) { - def as(alias: String): proto.Expression = proto.Expression.newBuilder().setAlias( - proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)).build() + def as(alias: String): proto.Expression = proto.Expression + .newBuilder() + .setAlias(proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)) + .build() - def < (other: proto.Expression): proto.Expression = - proto.Expression.newBuilder().setUnresolvedFunction( - proto.Expression.UnresolvedFunction.newBuilder() - .addParts("<") - .addArguments(expr) - .addArguments(other) - ).build() + def <(other: proto.Expression): proto.Expression = + proto.Expression + .newBuilder() + .setUnresolvedFunction( + proto.Expression.UnresolvedFunction + .newBuilder() + .addParts("<") + .addArguments(expr) + .addArguments(other)) + .build() } implicit def intToLiteral(i: Int): proto.Expression = - proto.Expression.newBuilder().setLiteral( - proto.Expression.Literal.newBuilder().setI32(i) - ).build() + proto.Expression + .newBuilder() + .setLiteral(proto.Expression.Literal.newBuilder().setI32(i)) + .build() } object plans { // scalastyle:ignore implicit class DslLogicalPlan(val logicalPlan: proto.Relation) { def select(exprs: proto.Expression*): proto.Relation = { - proto.Relation.newBuilder().setProject( - proto.Project.newBuilder() - .setInput(logicalPlan) - .addAllExpressions(exprs.toIterable.asJava) - .build() - ).build() + proto.Relation + .newBuilder() + .setProject( + proto.Project + .newBuilder() + .setInput(logicalPlan) + .addAllExpressions(exprs.toIterable.asJava) + .build()) + .build() } def where(condition: proto.Expression): proto.Relation = { - proto.Relation.newBuilder() - .setFilter( - proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition) - ).build() + proto.Relation + .newBuilder() + .setFilter(proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)) + .build() } - def join( otherPlan: proto.Relation, joinType: JoinType = JoinType.JOIN_TYPE_INNER, condition: Option[proto.Expression] = None): proto.Relation = { val relation = proto.Relation.newBuilder() val join = proto.Join.newBuilder() - join.setLeft(logicalPlan) + join + .setLeft(logicalPlan) .setRight(otherPlan) .setJoinType(joinType) if (condition.isDefined) { @@ -95,8 +106,8 @@ package object dsl { relation.setJoin(join).build() } - def groupBy( - groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = { + def groupBy(groupingExprs: proto.Expression*)( + aggregateExprs: proto.Expression*): proto.Relation = { val agg = proto.Aggregate.newBuilder() agg.setInput(logicalPlan) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 5ad95a6b516ab..46072ec089e03 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -60,7 +60,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort) case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) - case proto.Relation.RelTypeCase.LOCAL_RELATION => transformLocalRelation(rel.getLocalRelation) + case proto.Relation.RelTypeCase.LOCAL_RELATION => + transformLocalRelation(rel.getLocalRelation) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") @@ -109,10 +110,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { // TODO: support the target field for *. val projection = if (rel.getExpressionsCount == 1 && rel.getExpressions(0).hasUnresolvedStar) { - Seq(UnresolvedStar(Option.empty)) - } else { - rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) - } + Seq(UnresolvedStar(Option.empty)) + } else { + rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) + } val project = logical.Project(projectList = projection.toSeq, child = baseRel) if (common.nonEmpty && common.get.getAlias.nonEmpty) { logical.SubqueryAlias(identifier = common.get.getAlias, child = project) @@ -141,7 +142,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { * Transforms the protocol buffers literals into the appropriate Catalyst literal expression. * * TODO(SPARK-40533): Missing support for Instant, BigDecimal, LocalDate, LocalTimestamp, - * Duration, Period. + * Duration, Period. * @param lit * @return * Expression @@ -167,9 +168,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { // Days since UNIX epoch. case proto.Expression.Literal.LiteralTypeCase.DATE => expressions.Literal(lit.getDate, DateType) - case _ => throw InvalidPlanInput( - s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + - s"(${lit.getLiteralTypeCase.name})") + case _ => + throw InvalidPlanInput( + s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + + s"(${lit.getLiteralTypeCase.name})") } } @@ -188,7 +190,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { * * TODO(SPARK-40546) We need to homogenize the function names for binary operators. * - * @param fun Proto representation of the function call. + * @param fun + * Proto representation of the function call. * @return */ private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = { @@ -278,11 +281,11 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { val groupingExprs = rel.getGroupingExpressionsList.asScala - .map(transformExpression) - .map { - case x @ UnresolvedAttribute(_) => x - case x => UnresolvedAlias(x) - } + .map(transformExpression) + .map { + case x @ UnresolvedAttribute(_) => x + case x => UnresolvedAlias(x) + } logical.Aggregate( child = transformRelation(rel.getInput), diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index b62917d94727e..7c494e39a69a0 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -49,7 +49,8 @@ import org.apache.spark.sql.execution.ExtendedMode @Unstable @Since("3.4.0") class SparkConnectService(debug: Boolean) - extends SparkConnectServiceGrpc.SparkConnectServiceImplBase with Logging { + extends SparkConnectServiceGrpc.SparkConnectServiceImplBase + with Logging { /** * This is the main entry method for Spark Connect and all calls to execute a plan. @@ -183,7 +184,6 @@ object SparkConnectService { /** * Starts the GRPC Serivce. - * */ def startGRPCService(): Unit = { val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true) @@ -212,4 +212,3 @@ object SparkConnectService { } } } - diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 52b807f63bb03..84a6efb2baabd 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.util.ArrowUtils - @Unstable @Since("3.4.0") class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 10e17f121f0e5..855450f3c582c 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -88,16 +88,20 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("Simple Project") { - val readWithTable = proto.Read.newBuilder() + val readWithTable = proto.Read + .newBuilder() .setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build()) .build() val project = - proto.Project.newBuilder() + proto.Project + .newBuilder() .setInput(proto.Relation.newBuilder().setRead(readWithTable).build()) .addExpressions( - proto.Expression.newBuilder() - .setUnresolvedStar(UnresolvedStar.newBuilder().build()).build() - ).build() + proto.Expression + .newBuilder() + .setUnresolvedStar(UnresolvedStar.newBuilder().build()) + .build()) + .build() val res = transform(proto.Relation.newBuilder.setProject(project).build()) assert(res !== null) assert(res.nodeName == "Project") diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 351cc70852a18..cfa7189660dd7 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -77,12 +77,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2, condition = None) comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false) for ((t, y) <- Seq( - (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter), - (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter), - (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter), - (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti), - (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi), - (JoinType.JOIN_TYPE_INNER, Inner))) { + (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter), + (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter), + (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter), + (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti), + (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi), + (JoinType.JOIN_TYPE_INNER, Inner))) { val connectPlan3 = { import org.apache.spark.sql.connect.dsl.plans._ transform(connectTestRelation.join(connectTestRelation2, t)) @@ -115,10 +115,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val localRelationBuilder = proto.LocalRelation.newBuilder() for (attr <- attrs) { localRelationBuilder.addAttributes( - proto.Expression.QualifiedAttribute.newBuilder() + proto.Expression.QualifiedAttribute + .newBuilder() .setName(attr.name) - .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType)) - ) + .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))) } proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() } diff --git a/dev/lint-scala b/dev/lint-scala index 9c701ab463fe5..085eb3a76e116 100755 --- a/dev/lint-scala +++ b/dev/lint-scala @@ -21,3 +21,13 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" "$SCRIPT_DIR/scalastyle" "$1" + +# For Spark Connect, we actively enforce scalafmt and check that the produced diff is empty. +./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=true -pl connector/connect +if [[ $? -ne 0 ]]; then + echo "The scalafmt check failed on connector/connect." + echo "Before submitting your change, please make sure to format your code using the following command:" + echo "./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=fase -Dscalafmt.validateOnly=false -pl connector/connect" + exit 1 +fi + diff --git a/dev/scalafmt b/dev/scalafmt index 56ff75fe7d383..3971f7a69e724 100755 --- a/dev/scalafmt +++ b/dev/scalafmt @@ -18,5 +18,5 @@ # VERSION="${@:-2.12}" -./build/mvn -Pscala-$VERSION scalafmt:format -Dscalafmt.skip=false +./build/mvn -Pscala-$VERSION scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false diff --git a/pom.xml b/pom.xml index ee28cbdb014b9..8fd7d4c6719bb 100644 --- a/pom.xml +++ b/pom.xml @@ -172,6 +172,7 @@ 4.7.1 true + true 1.9.13 2.13.4 2.13.4.1 @@ -3412,7 +3413,7 @@ mvn-scalafmt_${scala.binary.version} 1.1.1640084764.9f463a9 - ${scalafmt.skip} + ${scalafmt.validateOnly} ${scalafmt.skip} ${scalafmt.skip} dev/.scalafmt.conf From c9e725c060264e035c978d611cecb59a0628239a Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 17 Oct 2022 10:12:11 +0200 Subject: [PATCH 2/5] fix --- .../sql/connect/SparkConnectPlugin.scala | 3 +- .../command/SparkConnectCommandPlanner.scala | 11 +++---- .../spark/sql/connect/dsl/package.scala | 23 ------------- .../connect/planner/SparkConnectPlanner.scala | 33 +++++++++---------- .../connect/service/SparkConnectService.scala | 5 +-- .../service/SparkConnectStreamHandler.scala | 1 + .../planner/SparkConnectPlannerSuite.scala | 14 +++----- .../planner/SparkConnectProtoSuite.scala | 11 ------- 8 files changed, 29 insertions(+), 72 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala index 4ecbfd123f0d2..7ac33fa9324ac 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala @@ -39,8 +39,7 @@ class SparkConnectPlugin extends SparkPlugin { /** * Return the plugin's driver-side component. * - * @return - * The driver-side component. + * @return The driver-side component. */ override def driverPlugin(): DriverPlugin = new DriverPlugin { diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala index 016c9af0d3dfb..ae606a6a72edd 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala @@ -30,14 +30,11 @@ import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnec import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.types.StringType -<<<<<<< HEAD -======= final case class InvalidCommandInput( private val message: String = "", private val cause: Throwable = null) extends Exception(message, cause) ->>>>>>> origin/master @Unstable @Since("3.4.0") class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) { @@ -58,10 +55,10 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) /** * This is a helper function that registers a new Python function in the SparkSession. * - * Right now this function is very rudimentary and bare-bones just to showcase how it is - * possible to remotely serialize a Python function and execute it on the Spark cluster. If the - * Python version on the client and server diverge, the execution of the function that is - * serialized will most likely fail. + * Right now this function is very rudimentary and bare-bones just to showcase how it + * is possible to remotely serialize a Python function and execute it on the Spark cluster. + * If the Python version on the client and server diverge, the execution of the function that + * is serialized will most likely fail. * * @param cf */ diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 259e71c6fd907..f6553f7e90b64 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -69,8 +69,6 @@ package object dsl { .newBuilder() .setLiteral(proto.Expression.Literal.newBuilder().setI32(i)) .build() -<<<<<<< HEAD -======= } object commands { // scalastyle:ignore @@ -110,22 +108,13 @@ package object dsl { proto.Command.newBuilder().setWriteOperation(writeOp.build()).build() } } ->>>>>>> origin/master } object plans { // scalastyle:ignore implicit class DslLogicalPlan(val logicalPlan: proto.Relation) { def select(exprs: proto.Expression*): proto.Relation = { -<<<<<<< HEAD - proto.Relation - .newBuilder() - .setProject( - proto.Project - .newBuilder() -======= proto.Relation.newBuilder().setProject( proto.Project.newBuilder() ->>>>>>> origin/master .setInput(logicalPlan) .addAllExpressions(exprs.toIterable.asJava) .build()) @@ -133,17 +122,10 @@ package object dsl { } def where(condition: proto.Expression): proto.Relation = { -<<<<<<< HEAD - proto.Relation - .newBuilder() - .setFilter(proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)) - .build() -======= proto.Relation.newBuilder() .setFilter( proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition) ).build() ->>>>>>> origin/master } def join( @@ -162,10 +144,6 @@ package object dsl { relation.setJoin(join).build() } -<<<<<<< HEAD - def groupBy(groupingExprs: proto.Expression*)( - aggregateExprs: proto.Expression*): proto.Relation = { -======= def as(alias: String): proto.Relation = { proto.Relation.newBuilder(logicalPlan) .setCommon(proto.RelationCommon.newBuilder().setAlias(alias)) @@ -174,7 +152,6 @@ package object dsl { def groupBy( groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = { ->>>>>>> origin/master val agg = proto.Aggregate.newBuilder() agg.setInput(logicalPlan) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 46072ec089e03..5ad95a6b516ab 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -60,8 +60,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort) case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) - case proto.Relation.RelTypeCase.LOCAL_RELATION => - transformLocalRelation(rel.getLocalRelation) + case proto.Relation.RelTypeCase.LOCAL_RELATION => transformLocalRelation(rel.getLocalRelation) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") @@ -110,10 +109,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { // TODO: support the target field for *. val projection = if (rel.getExpressionsCount == 1 && rel.getExpressions(0).hasUnresolvedStar) { - Seq(UnresolvedStar(Option.empty)) - } else { - rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) - } + Seq(UnresolvedStar(Option.empty)) + } else { + rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) + } val project = logical.Project(projectList = projection.toSeq, child = baseRel) if (common.nonEmpty && common.get.getAlias.nonEmpty) { logical.SubqueryAlias(identifier = common.get.getAlias, child = project) @@ -142,7 +141,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { * Transforms the protocol buffers literals into the appropriate Catalyst literal expression. * * TODO(SPARK-40533): Missing support for Instant, BigDecimal, LocalDate, LocalTimestamp, - * Duration, Period. + * Duration, Period. * @param lit * @return * Expression @@ -168,10 +167,9 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { // Days since UNIX epoch. case proto.Expression.Literal.LiteralTypeCase.DATE => expressions.Literal(lit.getDate, DateType) - case _ => - throw InvalidPlanInput( - s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + - s"(${lit.getLiteralTypeCase.name})") + case _ => throw InvalidPlanInput( + s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + + s"(${lit.getLiteralTypeCase.name})") } } @@ -190,8 +188,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { * * TODO(SPARK-40546) We need to homogenize the function names for binary operators. * - * @param fun - * Proto representation of the function call. + * @param fun Proto representation of the function call. * @return */ private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = { @@ -281,11 +278,11 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { val groupingExprs = rel.getGroupingExpressionsList.asScala - .map(transformExpression) - .map { - case x @ UnresolvedAttribute(_) => x - case x => UnresolvedAlias(x) - } + .map(transformExpression) + .map { + case x @ UnresolvedAttribute(_) => x + case x => UnresolvedAlias(x) + } logical.Aggregate( child = transformRelation(rel.getInput), diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 7c494e39a69a0..b62917d94727e 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -49,8 +49,7 @@ import org.apache.spark.sql.execution.ExtendedMode @Unstable @Since("3.4.0") class SparkConnectService(debug: Boolean) - extends SparkConnectServiceGrpc.SparkConnectServiceImplBase - with Logging { + extends SparkConnectServiceGrpc.SparkConnectServiceImplBase with Logging { /** * This is the main entry method for Spark Connect and all calls to execute a plan. @@ -184,6 +183,7 @@ object SparkConnectService { /** * Starts the GRPC Serivce. + * */ def startGRPCService(): Unit = { val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true) @@ -212,3 +212,4 @@ object SparkConnectService { } } } + diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 84a6efb2baabd..52b807f63bb03 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.util.ArrowUtils + @Unstable @Since("3.4.0") class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 67518f3bdb172..ba6995bfc5a82 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -108,20 +108,16 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("Simple Project") { - val readWithTable = proto.Read - .newBuilder() + val readWithTable = proto.Read.newBuilder() .setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build()) .build() val project = - proto.Project - .newBuilder() + proto.Project.newBuilder() .setInput(proto.Relation.newBuilder().setRead(readWithTable).build()) .addExpressions( - proto.Expression - .newBuilder() - .setUnresolvedStar(UnresolvedStar.newBuilder().build()) - .build()) - .build() + proto.Expression.newBuilder() + .setUnresolvedStar(UnresolvedStar.newBuilder().build()).build() + ).build() val res = transform(proto.Relation.newBuilder.setProject(project).build()) assert(res !== null) assert(res.nodeName == "Project") diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 48dc8c1e6a800..7395307903ebd 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -109,21 +109,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { comparePlans(connectPlan.analyze, sparkPlan.analyze, false) } -<<<<<<< HEAD - private def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = { - val localRelationBuilder = proto.LocalRelation.newBuilder() - for (attr <- attrs) { - localRelationBuilder.addAttributes( - proto.Expression.QualifiedAttribute - .newBuilder() - .setName(attr.name) - .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))) -======= test("Test as(alias: String)") { val connectPlan = { import org.apache.spark.sql.connect.dsl.plans._ transform(connectTestRelation.as("target_table")) ->>>>>>> origin/master } val sparkPlan = sparkTestRelation.as("target_table") comparePlans(connectPlan.analyze, sparkPlan.analyze, false) From 50486ed68177c0a246d6e2fcd372cd2cac81e582 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 17 Oct 2022 11:48:19 +0200 Subject: [PATCH 3/5] triggering the check --- .../org/apache/spark/sql/connect/SparkConnectPlugin.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala index 7ac33fa9324ac..4ecbfd123f0d2 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala @@ -39,7 +39,8 @@ class SparkConnectPlugin extends SparkPlugin { /** * Return the plugin's driver-side component. * - * @return The driver-side component. + * @return + * The driver-side component. */ override def driverPlugin(): DriverPlugin = new DriverPlugin { From deb890b91b001442b8e66b432694412d6ec3bbd9 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 17 Oct 2022 14:02:24 +0200 Subject: [PATCH 4/5] fixing changed Only --- dev/lint-scala | 5 ++--- pom.xml | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/lint-scala b/dev/lint-scala index 085eb3a76e116..ad2be152cfad6 100755 --- a/dev/lint-scala +++ b/dev/lint-scala @@ -23,11 +23,10 @@ SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" "$SCRIPT_DIR/scalastyle" "$1" # For Spark Connect, we actively enforce scalafmt and check that the produced diff is empty. -./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=true -pl connector/connect +./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=true -Dscalafmt.changedOnly=false -pl connector/connect if [[ $? -ne 0 ]]; then echo "The scalafmt check failed on connector/connect." echo "Before submitting your change, please make sure to format your code using the following command:" - echo "./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=fase -Dscalafmt.validateOnly=false -pl connector/connect" + echo "./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=fase -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl connector/connect" exit 1 fi - diff --git a/pom.xml b/pom.xml index a542187b831fd..65dfcdb22340c 100644 --- a/pom.xml +++ b/pom.xml @@ -173,6 +173,7 @@ true true + true 1.9.13 2.13.4 2.13.4.1 @@ -3417,7 +3418,7 @@ ${scalafmt.skip} ${scalafmt.skip} dev/.scalafmt.conf - true + ${scalafmt.changedOnly} From 6bc068961bffd6cc98f133d94b29f430cbef6909 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 17 Oct 2022 15:00:48 +0200 Subject: [PATCH 5/5] format changes --- .../command/SparkConnectCommandPlanner.scala | 8 ++--- .../spark/sql/connect/dsl/package.scala | 22 ++++++++----- .../connect/planner/SparkConnectPlanner.scala | 33 ++++++++++--------- .../connect/service/SparkConnectService.scala | 5 ++- .../service/SparkConnectStreamHandler.scala | 1 - .../planner/SparkConnectPlannerSuite.scala | 14 +++++--- 6 files changed, 46 insertions(+), 37 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala index ae606a6a72edd..47d421a0359bf 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala @@ -55,10 +55,10 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) /** * This is a helper function that registers a new Python function in the SparkSession. * - * Right now this function is very rudimentary and bare-bones just to showcase how it - * is possible to remotely serialize a Python function and execute it on the Spark cluster. - * If the Python version on the client and server diverge, the execution of the function that - * is serialized will most likely fail. + * Right now this function is very rudimentary and bare-bones just to showcase how it is + * possible to remotely serialize a Python function and execute it on the Spark cluster. If the + * Python version on the client and server diverge, the execution of the function that is + * serialized will most likely fail. * * @param cf */ diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index f6553f7e90b64..401624e9882af 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -113,8 +113,11 @@ package object dsl { object plans { // scalastyle:ignore implicit class DslLogicalPlan(val logicalPlan: proto.Relation) { def select(exprs: proto.Expression*): proto.Relation = { - proto.Relation.newBuilder().setProject( - proto.Project.newBuilder() + proto.Relation + .newBuilder() + .setProject( + proto.Project + .newBuilder() .setInput(logicalPlan) .addAllExpressions(exprs.toIterable.asJava) .build()) @@ -122,10 +125,10 @@ package object dsl { } def where(condition: proto.Expression): proto.Relation = { - proto.Relation.newBuilder() - .setFilter( - proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition) - ).build() + proto.Relation + .newBuilder() + .setFilter(proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)) + .build() } def join( @@ -145,13 +148,14 @@ package object dsl { } def as(alias: String): proto.Relation = { - proto.Relation.newBuilder(logicalPlan) + proto.Relation + .newBuilder(logicalPlan) .setCommon(proto.RelationCommon.newBuilder().setAlias(alias)) .build() } - def groupBy( - groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = { + def groupBy(groupingExprs: proto.Expression*)( + aggregateExprs: proto.Expression*): proto.Relation = { val agg = proto.Aggregate.newBuilder() agg.setInput(logicalPlan) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 5ad95a6b516ab..46072ec089e03 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -60,7 +60,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort) case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) - case proto.Relation.RelTypeCase.LOCAL_RELATION => transformLocalRelation(rel.getLocalRelation) + case proto.Relation.RelTypeCase.LOCAL_RELATION => + transformLocalRelation(rel.getLocalRelation) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") @@ -109,10 +110,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { // TODO: support the target field for *. val projection = if (rel.getExpressionsCount == 1 && rel.getExpressions(0).hasUnresolvedStar) { - Seq(UnresolvedStar(Option.empty)) - } else { - rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) - } + Seq(UnresolvedStar(Option.empty)) + } else { + rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) + } val project = logical.Project(projectList = projection.toSeq, child = baseRel) if (common.nonEmpty && common.get.getAlias.nonEmpty) { logical.SubqueryAlias(identifier = common.get.getAlias, child = project) @@ -141,7 +142,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { * Transforms the protocol buffers literals into the appropriate Catalyst literal expression. * * TODO(SPARK-40533): Missing support for Instant, BigDecimal, LocalDate, LocalTimestamp, - * Duration, Period. + * Duration, Period. * @param lit * @return * Expression @@ -167,9 +168,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { // Days since UNIX epoch. case proto.Expression.Literal.LiteralTypeCase.DATE => expressions.Literal(lit.getDate, DateType) - case _ => throw InvalidPlanInput( - s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + - s"(${lit.getLiteralTypeCase.name})") + case _ => + throw InvalidPlanInput( + s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + + s"(${lit.getLiteralTypeCase.name})") } } @@ -188,7 +190,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { * * TODO(SPARK-40546) We need to homogenize the function names for binary operators. * - * @param fun Proto representation of the function call. + * @param fun + * Proto representation of the function call. * @return */ private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = { @@ -278,11 +281,11 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { val groupingExprs = rel.getGroupingExpressionsList.asScala - .map(transformExpression) - .map { - case x @ UnresolvedAttribute(_) => x - case x => UnresolvedAlias(x) - } + .map(transformExpression) + .map { + case x @ UnresolvedAttribute(_) => x + case x => UnresolvedAlias(x) + } logical.Aggregate( child = transformRelation(rel.getInput), diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index b62917d94727e..7c494e39a69a0 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -49,7 +49,8 @@ import org.apache.spark.sql.execution.ExtendedMode @Unstable @Since("3.4.0") class SparkConnectService(debug: Boolean) - extends SparkConnectServiceGrpc.SparkConnectServiceImplBase with Logging { + extends SparkConnectServiceGrpc.SparkConnectServiceImplBase + with Logging { /** * This is the main entry method for Spark Connect and all calls to execute a plan. @@ -183,7 +184,6 @@ object SparkConnectService { /** * Starts the GRPC Serivce. - * */ def startGRPCService(): Unit = { val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true) @@ -212,4 +212,3 @@ object SparkConnectService { } } } - diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 52b807f63bb03..84a6efb2baabd 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.util.ArrowUtils - @Unstable @Since("3.4.0") class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index ba6995bfc5a82..67518f3bdb172 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -108,16 +108,20 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("Simple Project") { - val readWithTable = proto.Read.newBuilder() + val readWithTable = proto.Read + .newBuilder() .setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build()) .build() val project = - proto.Project.newBuilder() + proto.Project + .newBuilder() .setInput(proto.Relation.newBuilder().setRead(readWithTable).build()) .addExpressions( - proto.Expression.newBuilder() - .setUnresolvedStar(UnresolvedStar.newBuilder().build()).build() - ).build() + proto.Expression + .newBuilder() + .setUnresolvedStar(UnresolvedStar.newBuilder().build()) + .build()) + .build() val res = transform(proto.Relation.newBuilder.setProject(project).build()) assert(res !== null) assert(res.nodeName == "Project")