SparkConnectAnalyzeHandler.scala
@@ -23,10 +23,13 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.classic.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
import org.apache.spark.sql.execution.{CodegenMode, CommandExecutionMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ArrayImplicits._

@@ -61,21 +64,23 @@ private[connect] class SparkConnectAnalyzeHandler(

def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true)

def getDataFrameWithoutExecuting(rel: LogicalPlan): DataFrame = {
val qe = session.sessionState.executePlan(rel, CommandExecutionMode.SKIP)
new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
}

request.getAnalyzeCase match {
case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
val schema = Dataset
.ofRows(session, transformRelation(request.getSchema.getPlan.getRoot))
.schema
val rel = transformRelation(request.getSchema.getPlan.getRoot)
val schema = getDataFrameWithoutExecuting(rel).schema
builder.setSchema(
proto.AnalyzePlanResponse.Schema
.newBuilder()
.setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
val queryExecution = Dataset
.ofRows(session, transformRelation(request.getExplain.getPlan.getRoot))
.queryExecution
val rel = transformRelation(request.getExplain.getPlan.getRoot)
val queryExecution = getDataFrameWithoutExecuting(rel).queryExecution
val explainString = request.getExplain.getExplainMode match {
case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
queryExecution.explainString(SimpleMode)
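
This hunk contains the heart of the fix. Dataset.ofRows analyzes a plan with the default CommandExecutionMode.ALL, and constructing the Dataset forces QueryExecution.commandExecuted, which eagerly runs command plans such as DROP TABLE; merely answering an AnalyzePlanRequest used to execute the command. The new getDataFrameWithoutExecuting helper passes CommandExecutionMode.SKIP instead, so the plan is analyzed but never run. The mechanism, paraphrased from Spark's QueryExecution (a from-memory sketch, not a verbatim excerpt):

    // Inside QueryExecution (paraphrased). Forcing this lazy val is what runs
    // a command; the Dataset constructor forces it via its logicalPlan field.
    lazy val commandExecuted: LogicalPlan = mode match {
      case CommandExecutionMode.NON_ROOT => analyzed.mapChildren(eagerlyExecuteCommands)
      case CommandExecutionMode.ALL => eagerlyExecuteCommands(analyzed)
      case CommandExecutionMode.SKIP => analyzed // analysis only, no side effects
    }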
@@ -96,9 +101,8 @@
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
val schema = Dataset
.ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot))
.schema
val rel = transformRelation(request.getTreeString.getPlan.getRoot)
val schema = getDataFrameWithoutExecuting(rel).schema
val treeString = if (request.getTreeString.hasLevel) {
schema.treeString(request.getTreeString.getLevel)
} else {
@@ -111,29 +115,26 @@
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
val isLocal = Dataset
.ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot))
.isLocal
val rel = transformRelation(request.getIsLocal.getPlan.getRoot)
val isLocal = getDataFrameWithoutExecuting(rel).isLocal
builder.setIsLocal(
proto.AnalyzePlanResponse.IsLocal
.newBuilder()
.setIsLocal(isLocal)
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
val isStreaming = Dataset
.ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot))
.isStreaming
val rel = transformRelation(request.getIsStreaming.getPlan.getRoot)
val isStreaming = getDataFrameWithoutExecuting(rel).isStreaming
builder.setIsStreaming(
proto.AnalyzePlanResponse.IsStreaming
.newBuilder()
.setIsStreaming(isStreaming)
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
val inputFiles = Dataset
.ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot))
.inputFiles
val rel = transformRelation(request.getInputFiles.getPlan.getRoot)
val inputFiles = getDataFrameWithoutExecuting(rel).inputFiles
builder.setInputFiles(
proto.AnalyzePlanResponse.InputFiles
.newBuilder()
@@ -156,29 +157,27 @@
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
val target = Dataset.ofRows(
session,
transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
val other = Dataset.ofRows(
session,
transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
val targetRel = transformRelation(request.getSameSemantics.getTargetPlan.getRoot)
val otherRel = transformRelation(request.getSameSemantics.getOtherPlan.getRoot)
val target = getDataFrameWithoutExecuting(targetRel)
val other = getDataFrameWithoutExecuting(otherRel)
builder.setSameSemantics(
proto.AnalyzePlanResponse.SameSemantics
.newBuilder()
.setResult(target.sameSemantics(other)))

case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
val semanticHash = Dataset
.ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot))
val rel = transformRelation(request.getSemanticHash.getPlan.getRoot)
val semanticHash = getDataFrameWithoutExecuting(rel)
.semanticHash()
builder.setSemanticHash(
proto.AnalyzePlanResponse.SemanticHash
.newBuilder()
.setResult(semanticHash))

case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST =>
val target = Dataset
.ofRows(session, transformRelation(request.getPersist.getRelation))
val rel = transformRelation(request.getPersist.getRelation)
val target = getDataFrameWithoutExecuting(rel)
if (request.getPersist.hasStorageLevel) {
target.persist(
StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel))
@@ -188,8 +187,8 @@
builder.setPersist(proto.AnalyzePlanResponse.Persist.newBuilder().build())

case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST =>
val target = Dataset
.ofRows(session, transformRelation(request.getUnpersist.getRelation))
val rel = transformRelation(request.getUnpersist.getRelation)
val target = getDataFrameWithoutExecuting(rel)
if (request.getUnpersist.hasBlocking) {
target.unpersist(request.getUnpersist.getBlocking)
} else {
@@ -198,8 +197,8 @@
builder.setUnpersist(proto.AnalyzePlanResponse.Unpersist.newBuilder().build())

case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
val target = Dataset
.ofRows(session, transformRelation(request.getGetStorageLevel.getRelation))
val rel = transformRelation(request.getGetStorageLevel.getRelation)
val target = getDataFrameWithoutExecuting(rel)
val storageLevel = target.storageLevel
builder.setGetStorageLevel(
proto.AnalyzePlanResponse.GetStorageLevel
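With the helper in place, every analyze case follows the same three-step pattern: transform the proto relation into a LogicalPlan, wrap it with getDataFrameWithoutExecuting, and read the requested property from the resulting DataFrame. A condensed sketch of that flow (hedged: the Dataset constructor is private[sql], so this only compiles inside Spark's own org.apache.spark.sql namespace, and spark stands in for the session resolved from the request):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.classic.Dataset
    import org.apache.spark.sql.execution.{CommandExecutionMode, SimpleMode}

    // Analyze a command without running it: the table behind "DROP TABLE test"
    // is untouched by every property read below.
    val plan = spark.sessionState.sqlParser.parsePlan("DROP TABLE test")
    val qe = spark.sessionState.executePlan(plan, CommandExecutionMode.SKIP)
    val df = new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))

    df.schema                                   // schema of the command's output
    df.isLocal                                  // no job is launched
    df.queryExecution.explainString(SimpleMode) // plan text, still no execution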
SparkConnectServiceSuite.scala
@@ -688,6 +688,42 @@ class SparkConnectServiceSuite
}
}

test("SPARK-51818: AnalyzePlanRequest does not execute the command") {
withTable("test") {
spark.sql("""
| CREATE TABLE test (col1 INT, col2 STRING)
|""".stripMargin)
val sqlString = "DROP TABLE test"
val plan = proto.Plan
.newBuilder()
.setRoot(
proto.Relation
.newBuilder()
.setCommon(proto.RelationCommon.newBuilder().setPlanId(1))
.setSql(proto.SQL.newBuilder().setQuery(sqlString).build())
.build())
.build()

val handler = new SparkConnectAnalyzeHandler(null)

val request = proto.AnalyzePlanRequest
.newBuilder()
.setExplain(
proto.AnalyzePlanRequest.Explain
.newBuilder()
.setPlan(plan)
.setExplainMode(proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_EXTENDED)
.build())
.build()

handler.process(request, sparkSessionHolder)

// assert that table was not dropped
val tableExists = spark.catalog.tableExists("test")
assert(tableExists, "Table test should still exist after analyze request of DROP TABLE")
}
}

test("Test explain mode in analyze response") {
withTable("test") {
spark.sql("""
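
To run the new regression test by itself, something like the following should work (assuming an sbt build and that connect is the sbt project name for the Connect server module; adjust the project name if the build defines it differently):

    build/sbt "connect/testOnly *SparkConnectServiceSuite* -- -z SPARK-51818"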