diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala index 05c0a5977226..0a91c6b95502 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala @@ -26,9 +26,12 @@ import javax.ws.rs.core.UriBuilder import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath} -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.{JobArtifactSet, SparkContext, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.util.ArtifactUtils import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL import org.apache.spark.sql.connect.service.SessionHolder @@ -39,45 +42,29 @@ import org.apache.spark.util.Utils * The Artifact Manager for the [[SparkConnectService]]. * * This class handles the storage of artifacts as well as preparing the artifacts for use. - * Currently, jars and classfile artifacts undergo additional processing: - * - Jars and pyfiles are automatically added to the underlying [[SparkContext]] and are - * accessible by all users of the cluster. - * - Class files are moved into a common directory that is shared among all users of the - * cluster. Note: Under a multi-user setup, class file conflicts may occur between user - * classes as the class file directory is shared. + * + * Artifacts belonging to different [[SparkSession]]s are segregated and isolated from each other + * with the help of the `sessionUUID`. + * + * Jars and classfile artifacts are stored under "jars" and "classes" sub-directories respectively + * while other types of artifacts are stored under the root directory for that particular + * [[SparkSession]]. + * + * @param sessionHolder + * The object used to hold the Spark Connect session state. */ -class SparkConnectArtifactManager private[connect] { +class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging { + import SparkConnectArtifactManager._ - // The base directory where all artifacts are stored. - // Note: If a REPL is attached to the cluster, class file artifacts are stored in the - // REPL's output directory. - private[connect] lazy val artifactRootPath = SparkContext.getActive match { - case Some(sc) => - sc.sparkConnectArtifactDirectory.toPath - case None => - throw new RuntimeException("SparkContext is uninitialized!") - } - private[connect] lazy val artifactRootURI = { - val fileServer = SparkEnv.get.rpcEnv.fileServer - fileServer.addDirectory("artifacts", artifactRootPath.toFile) - } - - // The base directory where all class files are stored. - // Note: If a REPL is attached to the cluster, we piggyback on the existing REPL output - // directory to store class file artifacts. 
- private[connect] lazy val classArtifactDir = SparkEnv.get.conf - .getOption("spark.repl.class.outputDir") - .map(p => Paths.get(p)) - .getOrElse(ArtifactUtils.concatenatePaths(artifactRootPath, "classes")) - - private[connect] lazy val classArtifactUri: String = - SparkEnv.get.conf.getOption("spark.repl.class.uri") match { - case Some(uri) => uri - case None => - throw new RuntimeException("Class artifact URI had not been initialised in SparkContext!") - } + private val sessionUUID = sessionHolder.session.sessionUUID + // The base directory/URI where all artifacts are stored for this `sessionUUID`. + val (artifactPath, artifactURI): (Path, String) = + getArtifactDirectoryAndUriForSession(sessionHolder) + // The base directory/URI where all class file artifacts are stored for this `sessionUUID`. + val (classDir, classURI): (Path, String) = getClassfileDirectoryAndUriForSession(sessionHolder) private val jarsList = new CopyOnWriteArrayList[Path] + private val jarsURI = new CopyOnWriteArrayList[String] private val pythonIncludeList = new CopyOnWriteArrayList[String] /** @@ -98,13 +85,11 @@ class SparkConnectArtifactManager private[connect] { * Add and prepare a staged artifact (i.e an artifact that has been rebuilt locally from bytes * over the wire) for use. * - * @param session * @param remoteRelativePath * @param serverLocalStagingPath * @param fragment */ private[connect] def addArtifact( - sessionHolder: SessionHolder, remoteRelativePath: Path, serverLocalStagingPath: Path, fragment: Option[String]): Unit = { @@ -127,27 +112,28 @@ class SparkConnectArtifactManager private[connect] { updater.save() }(catchBlock = { tmpFile.delete() }) } else if (remoteRelativePath.startsWith(s"classes${File.separator}")) { - // Move class files to common location (shared among all users) + // Move class files to the right directory. val target = ArtifactUtils.concatenatePaths( - classArtifactDir, + classDir, remoteRelativePath.toString.stripPrefix(s"classes${File.separator}")) Files.createDirectories(target.getParent) // Allow overwriting class files to capture updates to classes. + // This is required because the client currently sends all the class files in each class file + // transfer. Files.move(serverLocalStagingPath, target, StandardCopyOption.REPLACE_EXISTING) } else { - val target = ArtifactUtils.concatenatePaths(artifactRootPath, remoteRelativePath) + val target = ArtifactUtils.concatenatePaths(artifactPath, remoteRelativePath) Files.createDirectories(target.getParent) - // Disallow overwriting jars because spark doesn't support removing jars that were - // previously added, + // Disallow overwriting non-classfile artifacts if (Files.exists(target)) { throw new RuntimeException( - s"Duplicate file: $remoteRelativePath. Files cannot be overwritten.") + s"Duplicate Artifact: $remoteRelativePath. 
" + + "Artifacts cannot be overwritten.") } Files.move(serverLocalStagingPath, target) if (remoteRelativePath.startsWith(s"jars${File.separator}")) { - // Adding Jars to the underlying spark context (visible to all users) - sessionHolder.session.sessionState.resourceLoader.addJar(target.toString) jarsList.add(target) + jarsURI.add(artifactURI + "/" + target.toString) } else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) { sessionHolder.session.sparkContext.addFile(target.toString) val stringRemotePath = remoteRelativePath.toString @@ -165,8 +151,47 @@ class SparkConnectArtifactManager private[connect] { } } + /** + * Returns a [[JobArtifactSet]] pointing towards the session-specific jars and class files. + */ + def jobArtifactSet: JobArtifactSet = { + val builder = Map.newBuilder[String, Long] + jarsURI.forEach { jar => + builder += jar -> 0 + } + + new JobArtifactSet( + uuid = Option(sessionUUID), + replClassDirUri = Option(classURI), + jars = builder.result(), + files = Map.empty, + archives = Map.empty) + } + + /** + * Returns a [[ClassLoader]] for session-specific jar/class file resources. + */ + def classloader: ClassLoader = { + val urls = jarsList.asScala.map(_.toUri.toURL) :+ classDir.toUri.toURL + new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) + } + + /** + * Cleans up all resources specific to this `sessionHolder`. + */ + private[connect] def cleanUpResources(): Unit = { + logDebug( + s"Cleaning up resources for session with userId: ${sessionHolder.userId} and " + + s"sessionId: ${sessionHolder.sessionId}") + // Clean up cached relations + val blockManager = sessionHolder.session.sparkContext.env.blockManager + blockManager.removeCache(sessionHolder.userId, sessionHolder.sessionId) + + // Clean up artifacts folder + FileUtils.deleteDirectory(artifactRootPath.toFile) + } + private[connect] def uploadArtifactToFs( - sessionHolder: SessionHolder, remoteRelativePath: Path, serverLocalStagingPath: Path): Unit = { val hadoopConf = sessionHolder.session.sparkContext.hadoopConfiguration @@ -200,48 +225,80 @@ class SparkConnectArtifactManager private[connect] { } } -object SparkConnectArtifactManager { +object SparkConnectArtifactManager extends Logging { val forwardToFSPrefix = "forward_to_fs" - private var _activeArtifactManager: SparkConnectArtifactManager = _ + private var currentArtifactRootUri: String = _ + private var lastKnownSparkContextInstance: SparkContext = _ - /** - * Obtain the active artifact manager or create a new artifact manager. - * - * @return - */ - def getOrCreateArtifactManager: SparkConnectArtifactManager = { - if (_activeArtifactManager == null) { - _activeArtifactManager = new SparkConnectArtifactManager - } - _activeArtifactManager + private val ARTIFACT_DIRECTORY_PREFIX = "artifacts" + + // The base directory where all artifacts are stored. 
+ private[spark] lazy val artifactRootPath = { + Utils.createTempDir(ARTIFACT_DIRECTORY_PREFIX).toPath } - private lazy val artifactManager = getOrCreateArtifactManager + private[spark] def getArtifactDirectoryAndUriForSession(session: SparkSession): (Path, String) = + ( + ArtifactUtils.concatenatePaths(artifactRootPath, session.sessionUUID), + s"$artifactRootURI/${session.sessionUUID}") + + private[spark] def getArtifactDirectoryAndUriForSession( + sessionHolder: SessionHolder): (Path, String) = + getArtifactDirectoryAndUriForSession(sessionHolder.session) + + private[spark] def getClassfileDirectoryAndUriForSession( + session: SparkSession): (Path, String) = { + val (artDir, artUri) = getArtifactDirectoryAndUriForSession(session) + (ArtifactUtils.concatenatePaths(artDir, "classes"), s"$artUri/classes/") + } + + private[spark] def getClassfileDirectoryAndUriForSession( + sessionHolder: SessionHolder): (Path, String) = + getClassfileDirectoryAndUriForSession(sessionHolder.session) /** - * Obtain a classloader that contains jar and classfile artifacts on the classpath. + * Updates the URI for the artifact directory. * - * @return + * This is required if the SparkContext is restarted. + * + * Note: This logic is solely to handle testing where a [[SparkContext]] may be restarted + * several times in a single JVM lifetime. In a general Spark cluster, the [[SparkContext]] is + * not expected to be restarted at any point in time. */ - def classLoaderWithArtifacts: ClassLoader = { - val urls = artifactManager.getSparkConnectAddedJars :+ - artifactManager.classArtifactDir.toUri.toURL - new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) + private def refreshArtifactUri(sc: SparkContext): Unit = synchronized { + // If a competing thread had updated the URI, we do not need to refresh the URI again. + if (sc eq lastKnownSparkContextInstance) { + return + } + val oldArtifactUri = currentArtifactRootUri + currentArtifactRootUri = SparkEnv.get.rpcEnv.fileServer + .addDirectoryIfAbsent(ARTIFACT_DIRECTORY_PREFIX, artifactRootPath.toFile) + lastKnownSparkContextInstance = sc + logDebug(s"Artifact URI updated from $oldArtifactUri to $currentArtifactRootUri") } /** - * Run a segment of code utilising a classloader that contains jar and classfile artifacts on - * the classpath. + * Checks if the URI for the artifact directory needs to be updated. This is required in cases + * where SparkContext is restarted as the old URI would no longer be valid. * - * @param thunk - * @tparam T - * @return + * Note: This logic is solely to handle testing where a [[SparkContext]] may be restarted + * several times in a single JVM lifetime. In a general Spark cluster, the [[SparkContext]] is + * not expected to be restarted at any point in time. 
*/ - def withArtifactClassLoader[T](thunk: => T): T = { - Utils.withContextClassLoader(classLoaderWithArtifacts) { - thunk + private def updateUriIfRequired(): Unit = { + SparkContext.getActive.foreach { sc => + if (lastKnownSparkContextInstance == null || (sc ne lastKnownSparkContextInstance)) { + logDebug("Refreshing artifact URI due to SparkContext (re)initialisation!") + refreshArtifactUri(sc) + } } } + + private[connect] def artifactRootURI: String = { + updateUriIfRequired() + require(currentArtifactRootUri != null) + currentArtifactRootUri + } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 856d0f06ba43..dfdba6b3c58e 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -25,6 +25,8 @@ import com.google.protobuf.{Any => ProtoAny, ByteString} import io.grpc.{Context, Status, StatusRuntimeException} import io.grpc.stub.StreamObserver import org.apache.commons.lang3.exception.ExceptionUtils +import org.json4s._ +import org.json4s.jackson.JsonMethods.parse import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} @@ -50,7 +52,6 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} -import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry @@ -88,6 +89,15 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { private lazy val pythonExec = sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) + // SparkConnectPlanner is used per request. + private lazy val pythonIncludes = { + implicit val formats = DefaultFormats + parse(session.conf.get("spark.connect.pythonUDF.includes", "[]")) + .extract[Array[String]] + .toList + .asJava + } + // The root of the query plan is a relation and we apply the transformations to it. 
def transformRelation(rel: proto.Relation): LogicalPlan = { val plan = rel.getRelTypeCase match { @@ -1418,13 +1428,13 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = { Utils.deserialize[UdfPacket]( fun.getScalarScalaUdf.getPayload.toByteArray, - SparkConnectArtifactManager.classLoaderWithArtifacts) + Utils.getContextOrSparkClassLoader) } private def unpackForeachWriter(fun: proto.ScalarScalaUDF): ForeachWriterPacket = { Utils.deserialize[ForeachWriterPacket]( fun.getPayload.toByteArray, - SparkConnectArtifactManager.classLoaderWithArtifacts) + Utils.getContextOrSparkClassLoader) } /** @@ -1481,8 +1491,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { command = fun.getCommand.toByteArray, // Empty environment variables envVars = Maps.newHashMap(), - pythonIncludes = - SparkConnectArtifactManager.getOrCreateArtifactManager.getSparkConnectPythonIncludes.asJava, + pythonIncludes = pythonIncludes, pythonExec = pythonExec, pythonVer = fun.getPythonVer, // Empty broadcast variables diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index cc2327abb5cd..004322097790 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -17,15 +17,22 @@ package org.apache.spark.sql.connect.service +import java.nio.file.Path import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import scala.collection.JavaConverters._ import scala.util.control.NonFatal +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + +import org.apache.spark.JobArtifactSet import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager +import org.apache.spark.util.Utils /** * Object used to hold the Spark Connect session state. @@ -60,6 +67,102 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } } } + + private[connect] lazy val artifactManager = new SparkConnectArtifactManager(this) + + /** + * Add an artifact to this SparkConnect session. + * + * @param remoteRelativePath + * @param serverLocalStagingPath + * @param fragment + */ + private[connect] def addArtifact( + remoteRelativePath: Path, + serverLocalStagingPath: Path, + fragment: Option[String]): Unit = { + artifactManager.addArtifact(remoteRelativePath, serverLocalStagingPath, fragment) + } + + /** + * A [[JobArtifactSet]] for this SparkConnect session. + */ + def connectJobArtifactSet: JobArtifactSet = artifactManager.jobArtifactSet + + /** + * A [[ClassLoader]] for jar/class file resources specific to this SparkConnect session. + */ + def classloader: ClassLoader = artifactManager.classloader + + /** + * Expire this session and trigger state cleanup mechanisms. + */ + private[connect] def expireSession(): Unit = { + logDebug(s"Expiring session with userId: $userId and sessionId: $sessionId") + artifactManager.cleanUpResources() + } + + /** + * Execute a block of code using this session's classloader. 
+ * @param f + * @tparam T + */ + def withContext[T](f: => T): T = { + // Needed for deserializing and evaluating the UDF on the driver + Utils.withContextClassLoader(classloader) { + // Needed for propagating the dependencies to the executors. + JobArtifactSet.withActive(connectJobArtifactSet) { + f + } + } + } + + /** + * Set the session-based Python paths to include in Python UDF. + * @param f + * @tparam T + */ + def withSessionBasedPythonPaths[T](f: => T): T = { + try { + session.conf.set( + "spark.connect.pythonUDF.includes", + compact(render(artifactManager.getSparkConnectPythonIncludes))) + f + } finally { + session.conf.unset("spark.connect.pythonUDF.includes") + } + } + + /** + * Execute a block of code with this session as the active SparkConnect session. + * @param f + * @tparam T + */ + def withSession[T](f: SparkSession => T): T = { + withSessionBasedPythonPaths { + withContext { + session.withActive { + f(session) + } + } + } + } + + /** + * Execute a block of code using the session from this [[SessionHolder]] as the active + * SparkConnect session. + * @param f + * @tparam T + */ + def withSessionHolder[T](f: SessionHolder => T): T = { + withSessionBasedPythonPaths { + withContext { + session.withActive { + f(this) + } + } + } + } } object SessionHolder { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala index 179ff1b3ec9c..e424331e7617 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala @@ -49,8 +49,6 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr // several [[AddArtifactsRequest]]s. 
private var chunkedArtifact: StagedChunkedArtifact = _ private var holder: SessionHolder = _ - private def artifactManager: SparkConnectArtifactManager = - SparkConnectArtifactManager.getOrCreateArtifactManager override def onNext(req: AddArtifactsRequest): Unit = { if (this.holder == null) { @@ -87,7 +85,8 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr } protected def addStagedArtifactToArtifactManager(artifact: StagedArtifact): Unit = { - artifactManager.addArtifact(holder, artifact.path, artifact.stagedPath, artifact.fragment) + require(holder != null) + holder.addArtifact(artifact.path, artifact.stagedPath, artifact.fragment) } /** @@ -103,7 +102,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr if (artifact.getCrcStatus.contains(true)) { if (artifact.path.startsWith( SparkConnectArtifactManager.forwardToFSPrefix + File.separator)) { - artifactManager.uploadArtifactToFs(holder, artifact.path, artifact.stagedPath) + holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath) } else { addStagedArtifactToArtifactManager(artifact) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala index 947f6ebbebeb..5c069bfaf5d0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala @@ -24,7 +24,6 @@ import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.Dataset -import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter} import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExtendedMode, FormattedMode, SimpleMode} @@ -33,16 +32,18 @@ private[connect] class SparkConnectAnalyzeHandler( responseObserver: StreamObserver[proto.AnalyzePlanResponse]) extends Logging { - def handle(request: proto.AnalyzePlanRequest): Unit = - SparkConnectArtifactManager.withArtifactClassLoader { - val sessionHolder = SparkConnectService - .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId) - sessionHolder.session.withActive { - val response = process(request, sessionHolder) - responseObserver.onNext(response) - responseObserver.onCompleted() - } + def handle(request: proto.AnalyzePlanRequest): Unit = { + val sessionHolder = SparkConnectService.getOrCreateIsolatedSession( + request.getUserContext.getUserId, + request.getSessionId) + // `withSession` ensures that session-specific artifacts (such as JARs and class files) are + // available during processing (such as deserialization). 
+ sessionHolder.withSessionHolder { sessionHolder => + val response = process(request, sessionHolder) + responseObserver.onNext(response) + responseObserver.onCompleted() } + } def process( request: proto.AnalyzePlanRequest, diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index c1647fd85a05..0f90bccaac8f 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -301,9 +301,7 @@ object SparkConnectService { private class RemoveSessionListener extends RemovalListener[SessionCacheKey, SessionHolder] { override def onRemoval( notification: RemovalNotification[SessionCacheKey, SessionHolder]): Unit = { - val SessionHolder(userId, sessionId, session) = notification.getValue - val blockManager = session.sparkContext.env.blockManager - blockManager.removeCache(userId, sessionId) + notification.getValue.expireSession() } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 70204f2913da..892ddaa9b44c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -30,7 +30,6 @@ import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ProtoUtils} import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE @@ -45,12 +44,14 @@ import org.apache.spark.util.{ThreadUtils, Utils} class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse]) extends Logging { - def handle(v: ExecutePlanRequest): Unit = SparkConnectArtifactManager.withArtifactClassLoader { - val sessionHolder = SparkConnectService - .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId) - val session = sessionHolder.session - - session.withActive { + def handle(v: ExecutePlanRequest): Unit = { + val sessionHolder = + SparkConnectService + .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId) + // `withSession` ensures that session-specific artifacts (such as JARs and class files) are + // available during processing. + sessionHolder.withSession { session => + // Add debug information to the query execution so that the jobs are traceable. 
val debugString = try { Utils.redact( diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala index b6da38fc5726..42ab8ca18f6e 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala @@ -21,7 +21,7 @@ import java.nio.file.{Files, Paths} import org.apache.commons.io.FileUtils -import org.apache.spark.SparkConf +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.sql.connect.ResourceHelper import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService} import org.apache.spark.sql.functions.col @@ -39,21 +39,30 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { } private val artifactPath = commonResourcePath.resolve("artifact-tests") - private lazy val artifactManager = SparkConnectArtifactManager.getOrCreateArtifactManager - private def sessionHolder(): SessionHolder = { SessionHolder("test", spark.sessionUUID, spark) } + private lazy val artifactManager = new SparkConnectArtifactManager(sessionHolder()) + + private def sessionUUID: String = spark.sessionUUID + + override def afterEach(): Unit = { + artifactManager.cleanUpResources() + super.afterEach() + } test("Jar artifacts are added to spark session") { val copyDir = Utils.createTempDir().toPath FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile) val stagingPath = copyDir.resolve("smallJar.jar") val remotePath = Paths.get("jars/smallJar.jar") - artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None) + artifactManager.addArtifact(remotePath, stagingPath, None) - val jarList = spark.sparkContext.listJars() - assert(jarList.exists(_.contains(remotePath.toString))) + val expectedPath = SparkConnectArtifactManager.artifactRootPath + .resolve(s"$sessionUUID/jars/smallJar.jar") + assert(expectedPath.toFile.exists()) + val jars = artifactManager.jobArtifactSet.jars + assert(jars.exists(_._1.contains(remotePath.toString))) } test("Class artifacts are added to the correct directory.") { @@ -62,10 +71,11 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = copyDir.resolve("smallClassFile.class") val remotePath = Paths.get("classes/smallClassFile.class") assert(stagingPath.toFile.exists()) - artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None) + artifactManager.addArtifact(remotePath, stagingPath, None) - val classFileDirectory = artifactManager.classArtifactDir - val movedClassFile = classFileDirectory.resolve("smallClassFile.class").toFile + val movedClassFile = SparkConnectArtifactManager.artifactRootPath + .resolve(s"$sessionUUID/classes/smallClassFile.class") + .toFile assert(movedClassFile.exists()) } @@ -75,13 +85,14 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = copyDir.resolve("Hello.class") val remotePath = Paths.get("classes/Hello.class") assert(stagingPath.toFile.exists()) - artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None) + artifactManager.addArtifact(remotePath, stagingPath, None) - val classFileDirectory = artifactManager.classArtifactDir - val movedClassFile = classFileDirectory.resolve("Hello.class").toFile + val 
movedClassFile = SparkConnectArtifactManager.artifactRootPath + .resolve(s"$sessionUUID/classes/Hello.class") + .toFile assert(movedClassFile.exists()) - val classLoader = SparkConnectArtifactManager.classLoaderWithArtifacts + val classLoader = artifactManager.classloader val instance = classLoader .loadClass("Hello") @@ -98,22 +109,26 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = copyDir.resolve("Hello.class") val remotePath = Paths.get("classes/Hello.class") assert(stagingPath.toFile.exists()) - artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None) - val classFileDirectory = artifactManager.classArtifactDir - val movedClassFile = classFileDirectory.resolve("Hello.class").toFile - assert(movedClassFile.exists()) + val sessionHolder = SparkConnectService.getOrCreateIsolatedSession("c1", "session") + sessionHolder.addArtifact(remotePath, stagingPath, None) - val classLoader = SparkConnectArtifactManager.classLoaderWithArtifacts + val movedClassFile = SparkConnectArtifactManager.artifactRootPath + .resolve(s"${sessionHolder.session.sessionUUID}/classes/Hello.class") + .toFile + assert(movedClassFile.exists()) + val classLoader = sessionHolder.classloader val instance = classLoader .loadClass("Hello") .getDeclaredConstructor(classOf[String]) .newInstance("Talon") .asInstanceOf[String => String] val udf = org.apache.spark.sql.functions.udf(instance) - val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session - session.range(10).select(udf(col("id").cast("string"))).collect() + + sessionHolder.withSession { session => + session.range(10).select(udf(col("id").cast("string"))).collect() + } } test("add a cache artifact to the Block Manager") { @@ -125,7 +140,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val blockManager = spark.sparkContext.env.blockManager val blockId = CacheId(session.userId, session.sessionId, "abc") try { - artifactManager.addArtifact(session, remotePath, stagingPath, None) + artifactManager.addArtifact(remotePath, stagingPath, None) val bytes = blockManager.getLocalBytes(blockId) assert(bytes.isDefined) val readback = new String(bytes.get.toByteBuffer().array(), StandardCharsets.UTF_8) @@ -141,9 +156,8 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { withTempPath { path => val stagingPath = path.toPath Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8)) - val session = sessionHolder() val remotePath = Paths.get("pyfiles/abc.zip") - artifactManager.addArtifact(session, remotePath, stagingPath, None) + artifactManager.addArtifact(remotePath, stagingPath, None) assert(artifactManager.getSparkConnectPythonIncludes == Seq("abc.zip")) } } @@ -155,10 +169,113 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = copyDir.resolve("smallClassFile.class") val remotePath = Paths.get("forward_to_fs", destFSDir.toString, "smallClassFileCopied.class") assert(stagingPath.toFile.exists()) - artifactManager.uploadArtifactToFs(sessionHolder, remotePath, stagingPath) - artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None) + artifactManager.uploadArtifactToFs(remotePath, stagingPath) + artifactManager.addArtifact(remotePath, stagingPath, None) val copiedClassFile = Paths.get(destFSDir.toString, "smallClassFileCopied.class").toFile assert(copiedClassFile.exists()) } + + test("Removal of resources") { + withTempPath { path => + // Setup cache + val stagingPath 
= path.toPath + Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8)) + val remotePath = Paths.get("cache/abc") + val session = sessionHolder() + val blockManager = spark.sparkContext.env.blockManager + val blockId = CacheId(session.userId, session.sessionId, "abc") + // Setup artifact dir + val copyDir = Utils.createTempDir().toPath + FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile) + try { + artifactManager.addArtifact(remotePath, stagingPath, None) + val stagingPathFile = copyDir.resolve("smallClassFile.class") + val remotePathFile = Paths.get("classes/smallClassFile.class") + artifactManager.addArtifact(remotePathFile, stagingPathFile, None) + + // Verify resources exist + val bytes = blockManager.getLocalBytes(blockId) + assert(bytes.isDefined) + blockManager.releaseLock(blockId) + val expectedPath = SparkConnectArtifactManager.artifactRootPath + .resolve(s"$sessionUUID/classes/smallClassFile.class") + assert(expectedPath.toFile.exists()) + + // Remove resources + artifactManager.cleanUpResources() + + assert(!blockManager.getLocalBytes(blockId).isDefined) + assert(!expectedPath.toFile.exists()) + } finally { + try { + blockManager.releaseLock(blockId) + } catch { + case _: SparkException => + case throwable: Throwable => throw throwable + } finally { + FileUtils.deleteDirectory(copyDir.toFile) + blockManager.removeCache(session.userId, session.sessionId) + } + } + } + } + + test("Classloaders for spark sessions are isolated") { + val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", "session1") + val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", "session2") + + def addHelloClass(holder: SessionHolder): Unit = { + val copyDir = Utils.createTempDir().toPath + FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile) + val stagingPath = copyDir.resolve("Hello.class") + val remotePath = Paths.get("classes/Hello.class") + assert(stagingPath.toFile.exists()) + holder.addArtifact(remotePath, stagingPath, None) + } + + // Add the classfile only for the first user + addHelloClass(holder1) + + val classLoader1 = holder1.classloader + val instance1 = classLoader1 + .loadClass("Hello") + .getDeclaredConstructor(classOf[String]) + .newInstance("Talon") + .asInstanceOf[String => String] + val udf1 = org.apache.spark.sql.functions.udf(instance1) + + holder1.withSession { session => + session.range(10).select(udf1(col("id").cast("string"))).collect() + } + + assertThrows[ClassNotFoundException] { + val classLoader2 = holder2.classloader + val instance2 = classLoader2 + .loadClass("Hello") + .getDeclaredConstructor(classOf[String]) + .newInstance("Talon") + .asInstanceOf[String => String] + } + } +} + +class ArtifactUriSuite extends SparkFunSuite with LocalSparkContext { + + private def createSparkContext(): Unit = { + resetSparkContext() + sc = new SparkContext("local[4]", "test", new SparkConf()) + + } + override def beforeEach(): Unit = { + super.beforeEach() + createSparkContext() + } + + test("Artifact URI is reset when SparkContext is restarted") { + val oldUri = SparkConnectArtifactManager.artifactRootURI + createSparkContext() + val newUri = SparkConnectArtifactManager.artifactRootURI + assert(newUri != oldUri) + } } diff --git a/core/src/main/scala/org/apache/spark/JobArtifactSet.scala b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala index d87c25c0b7c3..3e402b3b3302 100644 --- a/core/src/main/scala/org/apache/spark/JobArtifactSet.scala +++ b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala @@ -18,6 +18,7 @@ package 
org.apache.spark import java.io.Serializable +import java.util.Objects /** * Artifact set for a job. @@ -41,7 +42,7 @@ class JobArtifactSet( def withActive[T](f: => T): T = JobArtifactSet.withActive(this)(f) override def hashCode(): Int = { - Seq(uuid, replClassDirUri, jars.toSeq, files.toSeq, archives.toSeq).hashCode() + Objects.hash(uuid, replClassDirUri, jars.toSeq, files.toSeq, archives.toSeq) } override def equals(obj: Any): Boolean = { @@ -76,17 +77,17 @@ object JobArtifactSet { archives = sc.addedArchives.toMap) } + private lazy val emptyJobArtifactSet = new JobArtifactSet( + None, + None, + Map.empty, + Map.empty, + Map.empty) + /** * Empty artifact set for use in tests. */ - private[spark] def apply(): JobArtifactSet = { - new JobArtifactSet( - None, - None, - Map.empty, - Map.empty, - Map.empty) - } + private[spark] def apply(): JobArtifactSet = emptyJobArtifactSet /** * Used for testing. Returns artifacts from [[SparkContext]] if one exists or otherwise, an diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c32c674d64e0..51161a31e7d7 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.io._ import java.net.URI -import java.nio.file.Files import java.util.{Arrays, Locale, Properties, ServiceLoader, UUID} import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} @@ -42,7 +41,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.logging.log4j.Level -import org.apache.spark.annotation.{DeveloperApi, Experimental, Private} +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.executor.{Executor, ExecutorMetrics, ExecutorMetricsSource} @@ -387,13 +386,6 @@ class SparkContext(config: SparkConf) extends Logging { Utils.setLogLevel(Level.toLevel(upperCased)) } - /** - * :: Private :: - * Returns the directory that stores artifacts transferred through Spark Connect. - */ - @Private - private[spark] lazy val sparkConnectArtifactDirectory: File = Utils.createTempDir("artifacts") - try { _conf = config.clone() _conf.get(SPARK_LOG_LEVEL).foreach { level => @@ -479,18 +471,7 @@ class SparkContext(config: SparkConf) extends Logging { SparkEnv.set(_env) // If running the REPL, register the repl's output dir with the file server. - _conf.getOption("spark.repl.class.outputDir").orElse { - if (_conf.get(PLUGINS).contains("org.apache.spark.sql.connect.SparkConnectPlugin")) { - // For Spark Connect, we piggyback on the existing REPL integration to load class - // files on the executors. - // This is a temporary intermediate step due to unavailable classloader isolation. 
- val classDirectory = sparkConnectArtifactDirectory.toPath.resolve("classes") - Files.createDirectories(classDirectory) - Some(classDirectory.toString) - } else { - None - } - }.foreach { path => + _conf.getOption("spark.repl.class.outputDir").foreach { path => val replUri = _env.rpcEnv.fileServer.addDirectory("/classes", new File(path)) _conf.set("spark.repl.class.uri", replUri) } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 82d3a28894b6..2fce2889c097 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -186,6 +186,17 @@ private[spark] trait RpcEnvFileServer { */ def addDirectory(baseUri: String, path: File): String + /** + * Adds a local directory to be served via this file server. + * If the directory is already registered with the file server, it will result in a no-op. + * + * @param baseUri Leading URI path (files can be retrieved by appending their relative + * path to this base URI). This cannot be "files" nor "jars". + * @param path Path to the local directory. + * @return URI for the root of the directory in the file server. + */ + def addDirectoryIfAbsent(baseUri: String, path: File): String + /** Validates and normalizes the base URI for directories. */ protected def validateDirectoryUri(baseUri: String): String = { val baseCanonicalUri = new URI(baseUri).normalize().getPath diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index 73eb9a34669c..57243133aba9 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -90,4 +90,9 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) s"${rpcEnv.address.toSparkURL}$fixedBaseUri" } + override def addDirectoryIfAbsent(baseUri: String, path: File): String = { + val fixedBaseUri = validateDirectoryUri(baseUri) + dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path.getCanonicalFile) + s"${rpcEnv.address.toSparkURL}$fixedBaseUri" + } }