diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index f751546cf705..8fa003c11681 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -23,10 +23,13 @@ import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.classic.Dataset
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
-import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
+import org.apache.spark.sql.execution.{CodegenMode, CommandExecutionMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ArrayImplicits._
 
@@ -61,21 +64,23 @@ private[connect] class SparkConnectAnalyzeHandler(
 
     def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true)
 
+    def getDataFrameWithoutExecuting(rel: LogicalPlan): DataFrame = {
+      val qe = session.sessionState.executePlan(rel, CommandExecutionMode.SKIP)
+      new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
+    }
+
     request.getAnalyzeCase match {
       case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getSchema.getPlan.getRoot))
-          .schema
+        val rel = transformRelation(request.getSchema.getPlan.getRoot)
+        val schema = getDataFrameWithoutExecuting(rel).schema
         builder.setSchema(
           proto.AnalyzePlanResponse.Schema
             .newBuilder()
             .setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
             .build())
-
       case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
-        val queryExecution = Dataset
-          .ofRows(session, transformRelation(request.getExplain.getPlan.getRoot))
-          .queryExecution
+        val rel = transformRelation(request.getExplain.getPlan.getRoot)
+        val queryExecution = getDataFrameWithoutExecuting(rel).queryExecution
         val explainString = request.getExplain.getExplainMode match {
           case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
             queryExecution.explainString(SimpleMode)
@@ -96,9 +101,8 @@
             .build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot))
-          .schema
+        val rel = transformRelation(request.getTreeString.getPlan.getRoot)
+        val schema = getDataFrameWithoutExecuting(rel).schema
         val treeString = if (request.getTreeString.hasLevel) {
           schema.treeString(request.getTreeString.getLevel)
         } else {
@@ -111,9 +115,8 @@
             .build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
-        val isLocal = Dataset
-          .ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot))
-          .isLocal
+        val rel = transformRelation(request.getIsLocal.getPlan.getRoot)
+        val isLocal = getDataFrameWithoutExecuting(rel).isLocal
         builder.setIsLocal(
           proto.AnalyzePlanResponse.IsLocal
             .newBuilder()
@@ -121,9 +124,8 @@
             .build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
-        val isStreaming = Dataset
-          .ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot))
-          .isStreaming
+        val rel = transformRelation(request.getIsStreaming.getPlan.getRoot)
+        val isStreaming = getDataFrameWithoutExecuting(rel).isStreaming
         builder.setIsStreaming(
           proto.AnalyzePlanResponse.IsStreaming
             .newBuilder()
@@ -131,9 +133,8 @@
             .build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
-        val inputFiles = Dataset
-          .ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot))
-          .inputFiles
+        val rel = transformRelation(request.getInputFiles.getPlan.getRoot)
+        val inputFiles = getDataFrameWithoutExecuting(rel).inputFiles
         builder.setInputFiles(
           proto.AnalyzePlanResponse.InputFiles
             .newBuilder()
@@ -156,20 +157,18 @@
             .build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
-        val target = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
-        val other = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
+        val targetRel = transformRelation(request.getSameSemantics.getTargetPlan.getRoot)
+        val otherRel = transformRelation(request.getSameSemantics.getOtherPlan.getRoot)
+        val target = getDataFrameWithoutExecuting(targetRel)
+        val other = getDataFrameWithoutExecuting(otherRel)
         builder.setSameSemantics(
           proto.AnalyzePlanResponse.SameSemantics
             .newBuilder()
            .setResult(target.sameSemantics(other)))
 
       case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
-        val semanticHash = Dataset
-          .ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot))
+        val rel = transformRelation(request.getSemanticHash.getPlan.getRoot)
+        val semanticHash = getDataFrameWithoutExecuting(rel)
           .semanticHash()
         builder.setSemanticHash(
           proto.AnalyzePlanResponse.SemanticHash
@@ -177,8 +176,8 @@
             .setResult(semanticHash))
 
       case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getPersist.getRelation))
+        val rel = transformRelation(request.getPersist.getRelation)
+        val target = getDataFrameWithoutExecuting(rel)
         if (request.getPersist.hasStorageLevel) {
           target.persist(
             StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel))
@@ -188,8 +187,8 @@
         builder.setPersist(proto.AnalyzePlanResponse.Persist.newBuilder().build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getUnpersist.getRelation))
+        val rel = transformRelation(request.getUnpersist.getRelation)
+        val target = getDataFrameWithoutExecuting(rel)
         if (request.getUnpersist.hasBlocking) {
           target.unpersist(request.getUnpersist.getBlocking)
         } else {
@@ -198,8 +197,8 @@
         builder.setUnpersist(proto.AnalyzePlanResponse.Unpersist.newBuilder().build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getGetStorageLevel.getRelation))
+        val rel = transformRelation(request.getGetStorageLevel.getRelation)
+        val target = getDataFrameWithoutExecuting(rel)
         val storageLevel = target.storageLevel
         builder.setGetStorageLevel(
           proto.AnalyzePlanResponse.GetStorageLevel
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 5e8872569165..c036f162b7c0 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -688,6 +688,42 @@ class SparkConnectServiceSuite
     }
   }
 
+  test("SPARK-51818: AnalyzePlanRequest does not execute the command") {
+    withTable("test") {
+      spark.sql("""
+        | CREATE TABLE test (col1 INT, col2 STRING)
+        |""".stripMargin)
+      val sqlString = "DROP TABLE test"
+      val plan = proto.Plan
+        .newBuilder()
+        .setRoot(
+          proto.Relation
+            .newBuilder()
+            .setCommon(proto.RelationCommon.newBuilder().setPlanId(1))
+            .setSql(proto.SQL.newBuilder().setQuery(sqlString).build())
+            .build())
+        .build()
+
+      val handler = new SparkConnectAnalyzeHandler(null)
+
+      val request = proto.AnalyzePlanRequest
+        .newBuilder()
+        .setExplain(
+          proto.AnalyzePlanRequest.Explain
+            .newBuilder()
+            .setPlan(plan)
+            .setExplainMode(proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_EXTENDED)
+            .build())
+        .build()
+
+      handler.process(request, sparkSessionHolder)
+
+      // assert that table was not dropped
+      val tableExists = spark.catalog.tableExists("test")
+      assert(tableExists, "Table test should still exist after analyze request of DROP TABLE")
+    }
+  }
+
   test("Test explain mode in analyze response") {
     withTable("test") {
       spark.sql("""