SparkConnectArtifactManager.scala
@@ -26,9 +26,12 @@ import javax.ws.rs.core.UriBuilder
import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.{JobArtifactSet, SparkContext, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL
import org.apache.spark.sql.connect.service.SessionHolder
@@ -39,45 +42,29 @@ import org.apache.spark.util.Utils
* The Artifact Manager for the [[SparkConnectService]].
*
* This class handles the storage of artifacts as well as preparing the artifacts for use.
* Currently, jars and classfile artifacts undergo additional processing:
* - Jars and pyfiles are automatically added to the underlying [[SparkContext]] and are
* accessible by all users of the cluster.
* - Class files are moved into a common directory that is shared among all users of the
* cluster. Note: Under a multi-user setup, class file conflicts may occur between user
* classes as the class file directory is shared.
*
* Artifacts belonging to different [[SparkSession]]s are segregated and isolated from each other
* with the help of the `sessionUUID`.
*
* Jars and classfile artifacts are stored under "jars" and "classes" sub-directories respectively
* while other types of artifacts are stored under the root directory for that particular
* [[SparkSession]].
*
* @param sessionHolder
* The object used to hold the Spark Connect session state.
*/
class SparkConnectArtifactManager private[connect] {
class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging {
import SparkConnectArtifactManager._

// The base directory where all artifacts are stored.
// Note: If a REPL is attached to the cluster, class file artifacts are stored in the
// REPL's output directory.
private[connect] lazy val artifactRootPath = SparkContext.getActive match {
case Some(sc) =>
sc.sparkConnectArtifactDirectory.toPath
case None =>
throw new RuntimeException("SparkContext is uninitialized!")
}
private[connect] lazy val artifactRootURI = {
val fileServer = SparkEnv.get.rpcEnv.fileServer
fileServer.addDirectory("artifacts", artifactRootPath.toFile)
}

// The base directory where all class files are stored.
// Note: If a REPL is attached to the cluster, we piggyback on the existing REPL output
// directory to store class file artifacts.
private[connect] lazy val classArtifactDir = SparkEnv.get.conf
.getOption("spark.repl.class.outputDir")
.map(p => Paths.get(p))
.getOrElse(ArtifactUtils.concatenatePaths(artifactRootPath, "classes"))

private[connect] lazy val classArtifactUri: String =
SparkEnv.get.conf.getOption("spark.repl.class.uri") match {
case Some(uri) => uri
case None =>
throw new RuntimeException("Class artifact URI had not been initialised in SparkContext!")
}
private val sessionUUID = sessionHolder.session.sessionUUID
// The base directory/URI where all artifacts are stored for this `sessionUUID`.
val (artifactPath, artifactURI): (Path, String) =
getArtifactDirectoryAndUriForSession(sessionHolder)
// The base directory/URI where all class file artifacts are stored for this `sessionUUID`.
val (classDir, classURI): (Path, String) = getClassfileDirectoryAndUriForSession(sessionHolder)

private val jarsList = new CopyOnWriteArrayList[Path]
private val jarsURI = new CopyOnWriteArrayList[String]
private val pythonIncludeList = new CopyOnWriteArrayList[String]

/**
@@ -98,13 +85,11 @@ class SparkConnectArtifactManager private[connect] {
* Add and prepare a staged artifact (i.e. an artifact that has been rebuilt locally from bytes
* over the wire) for use.
*
* @param session
* @param remoteRelativePath
* @param serverLocalStagingPath
* @param fragment
*/
private[connect] def addArtifact(
sessionHolder: SessionHolder,
remoteRelativePath: Path,
serverLocalStagingPath: Path,
fragment: Option[String]): Unit = {
@@ -127,27 +112,28 @@ class SparkConnectArtifactManager private[connect] {
updater.save()
}(catchBlock = { tmpFile.delete() })
} else if (remoteRelativePath.startsWith(s"classes${File.separator}")) {
// Move class files to common location (shared among all users)
// Move class files to the right directory.
val target = ArtifactUtils.concatenatePaths(
classArtifactDir,
classDir,
remoteRelativePath.toString.stripPrefix(s"classes${File.separator}"))
Files.createDirectories(target.getParent)
// Allow overwriting class files to capture updates to classes.
// This is required because the client currently sends all the class files in each class file
// transfer.
Files.move(serverLocalStagingPath, target, StandardCopyOption.REPLACE_EXISTING)
} else {
val target = ArtifactUtils.concatenatePaths(artifactRootPath, remoteRelativePath)
val target = ArtifactUtils.concatenatePaths(artifactPath, remoteRelativePath)
Files.createDirectories(target.getParent)
// Disallow overwriting jars because Spark doesn't support removing jars that were
// previously added.
// Disallow overwriting non-classfile artifacts
if (Files.exists(target)) {
throw new RuntimeException(
s"Duplicate file: $remoteRelativePath. Files cannot be overwritten.")
s"Duplicate Artifact: $remoteRelativePath. " +
"Artifacts cannot be overwritten.")
}
Files.move(serverLocalStagingPath, target)
if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
// Adding Jars to the underlying spark context (visible to all users)
sessionHolder.session.sessionState.resourceLoader.addJar(target.toString)
jarsList.add(target)
jarsURI.add(artifactURI + "/" + target.toString)
} else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
sessionHolder.session.sparkContext.addFile(target.toString)
val stringRemotePath = remoteRelativePath.toString
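
For context, the routing in this hunk keys entirely off the first segment of the client-supplied relative path: "classes/" artifacts go into the session's class directory with the prefix stripped, while everything else keeps its relative layout under the session root. A minimal standalone sketch of that dispatch; the paths and object name are illustrative, not from the PR:

import java.io.File
import java.nio.file.{Path, Paths}

object ArtifactRoutingSketch {
  // Stand-ins for the session-scoped directories held by the artifact manager.
  def targetFor(sessionRoot: Path, classDir: Path, remoteRelativePath: Path): Path = {
    val sep = File.separator
    if (remoteRelativePath.startsWith(s"classes$sep")) {
      // Class files land directly in the session's class directory.
      classDir.resolve(remoteRelativePath.toString.stripPrefix(s"classes$sep"))
    } else {
      // jars/, pyfiles/, and other artifacts keep their layout under the root.
      sessionRoot.resolve(remoteRelativePath)
    }
  }

  def main(args: Array[String]): Unit = {
    val root = Paths.get("/tmp/artifacts/some-session-uuid")
    println(targetFor(root, root.resolve("classes"), Paths.get("jars/udf-lib.jar")))
    println(targetFor(root, root.resolve("classes"), Paths.get("classes/com/example/Foo.class")))
  }
}
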
@@ -165,8 +151,47 @@
}
}

/**
* Returns a [[JobArtifactSet]] pointing towards the session-specific jars and class files.
*/
def jobArtifactSet: JobArtifactSet = {
val builder = Map.newBuilder[String, Long]
jarsURI.forEach { jar =>
builder += jar -> 0
}

new JobArtifactSet(
uuid = Option(sessionUUID),
replClassDirUri = Option(classURI),
jars = builder.result(),
files = Map.empty,
archives = Map.empty)
}
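
jobArtifactSet snapshots the mutable per-session lists into immutable collections; every jar URI maps to timestamp 0, presumably because Connect never updates a jar in place (addArtifact above rejects overwrites). The map-building idiom in isolation, runnable as a sketch with made-up URIs:

import java.util.concurrent.CopyOnWriteArrayList

val jarsURI = new CopyOnWriteArrayList[String]()
jarsURI.add("spark://host:1234/artifacts/some-uuid/jars/a.jar")
jarsURI.add("spark://host:1234/artifacts/some-uuid/jars/b.jar")

val builder = Map.newBuilder[String, Long]
jarsURI.forEach { jar => builder += jar -> 0L } // java.util.List.forEach, SAM-converted
val jars: Map[String, Long] = builder.result()
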

/**
* Returns a [[ClassLoader]] for session-specific jar/class file resources.
*/
def classloader: ClassLoader = {
val urls = jarsList.asScala.map(_.toUri.toURL) :+ classDir.toUri.toURL
new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
}
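
The classloader built here is a plain parent-delegating URLClassLoader: the context (or Spark) classloader is consulted first, and only classes it cannot find are resolved against the session's jars and class directory. A self-contained sketch of the same construction, with illustrative paths:

import java.net.{URL, URLClassLoader}
import java.nio.file.Paths

val sessionJars = Seq(Paths.get("/tmp/artifacts/some-uuid/jars/udf-lib.jar"))
val sessionClassDir = Paths.get("/tmp/artifacts/some-uuid/classes")

// Note: a directory URL must end in '/'; Path.toUri produces that form when the
// directory exists on disk.
val urls: Array[URL] = (sessionJars.map(_.toUri.toURL) :+ sessionClassDir.toUri.toURL).toArray
val sessionLoader = new URLClassLoader(urls, Thread.currentThread().getContextClassLoader)
// sessionLoader.loadClass("com.example.UserClass") searches the parent first,
// then this session's artifacts.
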

/**
* Cleans up all resources specific to this `sessionHolder`.
*/
private[connect] def cleanUpResources(): Unit = {
logDebug(
s"Cleaning up resources for session with userId: ${sessionHolder.userId} and " +
s"sessionId: ${sessionHolder.sessionId}")
// Clean up cached relations
val blockManager = sessionHolder.session.sparkContext.env.blockManager
blockManager.removeCache(sessionHolder.userId, sessionHolder.sessionId)

// Clean up artifacts folder
FileUtils.deleteDirectory(artifactRootPath.toFile)
}

private[connect] def uploadArtifactToFs(
sessionHolder: SessionHolder,
remoteRelativePath: Path,
serverLocalStagingPath: Path): Unit = {
val hadoopConf = sessionHolder.session.sparkContext.hadoopConfiguration
@@ -200,48 +225,80 @@ class SparkConnectArtifactManager private[connect] {
}
}

object SparkConnectArtifactManager {
object SparkConnectArtifactManager extends Logging {

val forwardToFSPrefix = "forward_to_fs"

private var _activeArtifactManager: SparkConnectArtifactManager = _
private var currentArtifactRootUri: String = _
private var lastKnownSparkContextInstance: SparkContext = _

/**
* Obtain the active artifact manager or create a new artifact manager.
*
* @return
*/
def getOrCreateArtifactManager: SparkConnectArtifactManager = {
if (_activeArtifactManager == null) {
_activeArtifactManager = new SparkConnectArtifactManager
}
_activeArtifactManager
private val ARTIFACT_DIRECTORY_PREFIX = "artifacts"

// The base directory where all artifacts are stored.
private[spark] lazy val artifactRootPath = {
Utils.createTempDir(ARTIFACT_DIRECTORY_PREFIX).toPath
}

private lazy val artifactManager = getOrCreateArtifactManager
private[spark] def getArtifactDirectoryAndUriForSession(session: SparkSession): (Path, String) =
(
ArtifactUtils.concatenatePaths(artifactRootPath, session.sessionUUID),
s"$artifactRootURI/${session.sessionUUID}")

private[spark] def getArtifactDirectoryAndUriForSession(
sessionHolder: SessionHolder): (Path, String) =
getArtifactDirectoryAndUriForSession(sessionHolder.session)

private[spark] def getClassfileDirectoryAndUriForSession(
session: SparkSession): (Path, String) = {
val (artDir, artUri) = getArtifactDirectoryAndUriForSession(session)
(ArtifactUtils.concatenatePaths(artDir, "classes"), s"$artUri/classes/")
}

private[spark] def getClassfileDirectoryAndUriForSession(
sessionHolder: SessionHolder): (Path, String) =
getClassfileDirectoryAndUriForSession(sessionHolder.session)
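
Taken together, these helpers give every session a directory and a served URI keyed by its sessionUUID, with class files one level below. A sketch of the layout they produce; the root, host, and UUID here are made up:

import java.nio.file.Paths

val artifactRoot    = Paths.get("/tmp/spark-artifacts-7f3a") // per-JVM temp root
val artifactRootUri = "spark://host:1234/artifacts"          // RPC file-server URI
val sessionUUID     = "0b5c0000-fake-uuid"

val sessionDir = artifactRoot.resolve(sessionUUID) // <root>/<uuid>
val sessionUri = s"$artifactRootUri/$sessionUUID"  // .../artifacts/<uuid>
val classDir   = sessionDir.resolve("classes")     // <root>/<uuid>/classes
val classUri   = s"$sessionUri/classes/"           // trailing slash, as in the PR
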

/**
* Obtain a classloader that contains jar and classfile artifacts on the classpath.
* Updates the URI for the artifact directory.
*
* @return
* This is required if the SparkContext is restarted.
*
* Note: This logic is solely to handle testing where a [[SparkContext]] may be restarted
* several times in a single JVM lifetime. In a general Spark cluster, the [[SparkContext]] is
* not expected to be restarted at any point in time.
*/
def classLoaderWithArtifacts: ClassLoader = {
val urls = artifactManager.getSparkConnectAddedJars :+
artifactManager.classArtifactDir.toUri.toURL
new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
private def refreshArtifactUri(sc: SparkContext): Unit = synchronized {
// If a competing thread has already updated the URI, we do not need to refresh it again.
if (sc eq lastKnownSparkContextInstance) {
return
}
val oldArtifactUri = currentArtifactRootUri
currentArtifactRootUri = SparkEnv.get.rpcEnv.fileServer
.addDirectoryIfAbsent(ARTIFACT_DIRECTORY_PREFIX, artifactRootPath.toFile)
[Review comment, Contributor] Can we use addDirectory instead? The if-absent bit is pretty well protected by this object.

lastKnownSparkContextInstance = sc
logDebug(s"Artifact URI updated from $oldArtifactUri to $currentArtifactRootUri")
}

/**
* Run a segment of code utilising a classloader that contains jar and classfile artifacts on
* the classpath.
* Checks if the URI for the artifact directory needs to be updated. This is required in cases
* where SparkContext is restarted as the old URI would no longer be valid.
*
* @param thunk
* @tparam T
* @return
* Note: This logic is solely to handle testing where a [[SparkContext]] may be restarted
* several times in a single JVM lifetime. In a general Spark cluster, the [[SparkContext]] is
* not expected to be restarted at any point in time.
*/
def withArtifactClassLoader[T](thunk: => T): T = {
Utils.withContextClassLoader(classLoaderWithArtifacts) {
thunk
private def updateUriIfRequired(): Unit = {
SparkContext.getActive.foreach { sc =>
if (lastKnownSparkContextInstance == null || (sc ne lastKnownSparkContextInstance)) {
logDebug("Refreshing artifact URI due to SparkContext (re)initialisation!")
refreshArtifactUri(sc)
}
}
}

private[connect] def artifactRootURI: String = {
updateUriIfRequired()
require(currentArtifactRootUri != null)
currentArtifactRootUri
}
}
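
The refresh logic above is a double-checked pattern keyed on reference identity of the SparkContext: updateUriIfRequired performs a cheap unsynchronized check, and refreshArtifactUri re-checks under synchronized so that, of several racing threads, only the first re-registers the artifact directory. The same pattern reduced to its essentials; the names and the cached value are generic stand-ins, not from the PR:

object RefreshOnNewOwnerSketch {
  private var lastOwner: AnyRef = _
  private var cached: String = _

  private def refresh(owner: AnyRef): Unit = synchronized {
    // A competing thread may have refreshed while we waited for the lock.
    if (owner eq lastOwner) return
    cached = s"uri-for-owner-${owner.hashCode()}" // stand-in for re-registering the directory
    lastOwner = owner
  }

  def get(owner: AnyRef): String = {
    if (lastOwner == null || (owner ne lastOwner)) refresh(owner)
    require(cached != null)
    cached
  }
}
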
SparkConnectPlanner.scala
@@ -25,6 +25,8 @@ import com.google.protobuf.{Any => ProtoAny, ByteString}
import io.grpc.{Context, Status, StatusRuntimeException}
import io.grpc.stub.StreamObserver
import org.apache.commons.lang3.exception.ExceptionUtils
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
@@ -50,7 +52,6 @@ import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
@@ -88,6 +89,15 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
private lazy val pythonExec =
sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))

// SparkConnectPlanner is used per request.
[Review comment, Contributor] We could put this in the session holder, right?

private lazy val pythonIncludes = {
implicit val formats = DefaultFormats
parse(session.conf.get("spark.connect.pythonUDF.includes", "[]"))
.extract[Array[String]]
.toList
.asJava
}
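
pythonIncludes reads a session conf entry holding a JSON-encoded string array, presumably written when pyfile artifacts are added (that part of the artifact-manager hunk is truncated above). The parsing round trip in isolation, using the same json4s calls; the values are illustrative:

import scala.collection.JavaConverters._
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

implicit val formats: Formats = DefaultFormats
val raw = """["deps.zip", "helper.egg"]""" // e.g. session.conf.get("spark.connect.pythonUDF.includes", "[]")
val includes: java.util.List[String] = parse(raw).extract[Array[String]].toList.asJava
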

// The root of the query plan is a relation and we apply the transformations to it.
def transformRelation(rel: proto.Relation): LogicalPlan = {
val plan = rel.getRelTypeCase match {
@@ -1418,13 +1428,13 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = {
Utils.deserialize[UdfPacket](
fun.getScalarScalaUdf.getPayload.toByteArray,
SparkConnectArtifactManager.classLoaderWithArtifacts)
Utils.getContextOrSparkClassLoader)
}

private def unpackForeachWriter(fun: proto.ScalarScalaUDF): ForeachWriterPacket = {
Utils.deserialize[ForeachWriterPacket](
fun.getPayload.toByteArray,
SparkConnectArtifactManager.classLoaderWithArtifacts)
Utils.getContextOrSparkClassLoader)
}
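
Both unpack helpers now deserialize against the context (or Spark) classloader rather than a globally shared artifact classloader. The loader matters because a plain ObjectInputStream resolves classes through the loader that defined ObjectInputStream itself, which cannot see client-uploaded classes. A minimal sketch of the standard JDK pattern for loader-aware deserialization, which is the shape of a Utils.deserialize(bytes, loader)-style helper:

import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectStreamClass}

def deserializeWith[T](bytes: Array[Byte], loader: ClassLoader): T = {
  val in = new ObjectInputStream(new ByteArrayInputStream(bytes)) {
    // Resolve every class description against the supplied loader instead of
    // the default one.
    override def resolveClass(desc: ObjectStreamClass): Class[_] =
      Class.forName(desc.getName, false, loader)
  }
  try in.readObject().asInstanceOf[T] finally in.close()
}
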

/**
@@ -1481,8 +1491,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
command = fun.getCommand.toByteArray,
// Empty environment variables
envVars = Maps.newHashMap(),
pythonIncludes =
SparkConnectArtifactManager.getOrCreateArtifactManager.getSparkConnectPythonIncludes.asJava,
pythonIncludes = pythonIncludes,
pythonExec = pythonExec,
pythonVer = fun.getPythonVer,
// Empty broadcast variables