
Commit a393d6c

[SPARK-48370][CONNECT] Checkpoint and localCheckpoint in Scala Spark Connect client
### What changes were proposed in this pull request?

This PR adds `Dataset.checkpoint` and `Dataset.localCheckpoint` to the Scala Spark Connect client. The Python API was implemented in #46570.

### Why are the changes needed?

For API parity.

### Does this PR introduce _any_ user-facing change?

Yes, it adds `Dataset.checkpoint` and `Dataset.localCheckpoint` to the Scala Spark Connect client.

### How was this patch tested?

Unit tests added.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46683 from HyukjinKwon/SPARK-48370.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 5df9a08 commit a393d6c

File tree

6 files changed, +379 −13 lines changed


connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 96 additions & 11 deletions
```diff
@@ -3402,20 +3402,105 @@ class Dataset[T] private[sql] (
     df
   }
 
-  def checkpoint(): Dataset[T] = {
-    throw new UnsupportedOperationException("checkpoint is not implemented.")
-  }
+  /**
+   * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to
+   * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms
+   * where the plan may grow exponentially. It will be saved to files inside the checkpoint
+   * directory set with `SparkContext#setCheckpointDir`.
+   *
+   * @group basic
+   * @since 4.0.0
+   */
+  def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true)
 
-  def checkpoint(eager: Boolean): Dataset[T] = {
-    throw new UnsupportedOperationException("checkpoint is not implemented.")
-  }
+  /**
+   * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
+   * logical plan of this Dataset, which is especially useful in iterative algorithms where the
+   * plan may grow exponentially. It will be saved to files inside the checkpoint directory set
+   * with `SparkContext#setCheckpointDir`.
+   *
+   * @param eager
+   *   Whether to checkpoint this dataframe immediately
+   *
+   * @note
+   *   When checkpoint is used with eager = false, the final data that is checkpointed after the
+   *   first action may be different from the data that was used during the job due to
+   *   non-determinism of the underlying operation and retries. If checkpoint is used to achieve
+   *   saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is
+   *   only deterministic after the first execution, after the checkpoint was finalized.
+   *
+   * @group basic
+   * @since 4.0.0
+   */
+  def checkpoint(eager: Boolean): Dataset[T] =
+    checkpoint(eager = eager, reliableCheckpoint = true)
 
-  def localCheckpoint(): Dataset[T] = {
-    throw new UnsupportedOperationException("localCheckpoint is not implemented.")
-  }
+  /**
+   * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used
+   * to truncate the logical plan of this Dataset, which is especially useful in iterative
+   * algorithms where the plan may grow exponentially. Local checkpoints are written to executor
+   * storage and despite potentially faster they are unreliable and may compromise job completion.
+   *
+   * @group basic
+   * @since 4.0.0
+   */
+  def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)
+
+  /**
+   * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to
+   * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms
+   * where the plan may grow exponentially. Local checkpoints are written to executor storage and
+   * despite potentially faster they are unreliable and may compromise job completion.
+   *
+   * @param eager
+   *   Whether to checkpoint this dataframe immediately
+   *
+   * @note
+   *   When checkpoint is used with eager = false, the final data that is checkpointed after the
+   *   first action may be different from the data that was used during the job due to
+   *   non-determinism of the underlying operation and retries. If checkpoint is used to achieve
+   *   saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is
+   *   only deterministic after the first execution, after the checkpoint was finalized.
+   *
+   * @group basic
+   * @since 4.0.0
+   */
+  def localCheckpoint(eager: Boolean): Dataset[T] =
+    checkpoint(eager = eager, reliableCheckpoint = false)
 
-  def localCheckpoint(eager: Boolean): Dataset[T] = {
-    throw new UnsupportedOperationException("localCheckpoint is not implemented.")
+  /**
+   * Returns a checkpointed version of this Dataset.
+   *
+   * @param eager
+   *   Whether to checkpoint this dataframe immediately
+   * @param reliableCheckpoint
+   *   Whether to create a reliable checkpoint saved to files inside the checkpoint directory. If
+   *   false creates a local checkpoint using the caching subsystem
+   */
+  private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
+    sparkSession.newDataset(agnosticEncoder) { builder =>
+      val command = sparkSession.newCommand { builder =>
+        builder.getCheckpointCommandBuilder
+          .setLocal(reliableCheckpoint)
+          .setEager(eager)
+          .setRelation(this.plan.getRoot)
+      }
+      val responseIter = sparkSession.execute(command)
+      try {
+        val response = responseIter
+          .find(_.hasCheckpointCommandResult)
+          .getOrElse(throw new RuntimeException("CheckpointCommandResult must be present"))
+
+        val cachedRemoteRelation = response.getCheckpointCommandResult.getRelation
+        sparkSession.cleaner.registerCachedRemoteRelationForCleanup(cachedRemoteRelation)
+
+        // Update the builder with the values from the result.
+        builder.setCachedRemoteRelation(cachedRemoteRelation)
+      } finally {
+        // consume the rest of the iterator
+        responseIter.foreach(_ => ())
+      }
+    }
   }
 
   /**
```
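For context, a minimal usage sketch of the new API from the Scala Spark Connect client. The connection string, the sample data, and the assumption that the server has a checkpoint directory configured for reliable checkpoints are illustrative only, not part of this change:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical connect endpoint; reliable checkpoints additionally assume the
// server side has a checkpoint directory configured.
val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()

val df = spark.range(0, 1000000).toDF("id")

// Eager, reliable checkpoint: runs immediately and returns a Dataset whose
// logical plan is truncated to the checkpointed result.
val checkpointed = df.checkpoint()

// Lazy local checkpoint: stored via the caching subsystem on executors and
// only finalized by the first action on the returned Dataset.
val locallyCheckpointed = df.localCheckpoint(eager = false)
locallyCheckpointed.count()

spark.close()
```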

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 9 additions & 1 deletion
```diff
@@ -41,7 +41,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, Spar
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
 import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf}
+import org.apache.spark.sql.internal.{CatalogImpl, SessionCleaner, SqlApiConf}
 import org.apache.spark.sql.streaming.DataStreamReader
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.types.StructType
@@ -73,6 +73,11 @@ class SparkSession private[sql] (
     with Logging {
 
   private[this] val allocator = new RootAllocator()
+  private var shouldStopCleaner = false
+  private[sql] lazy val cleaner = {
+    shouldStopCleaner = true
+    new SessionCleaner(this)
+  }
 
   // a unique session ID for this session from client.
   private[sql] def sessionId: String = client.sessionId
@@ -714,6 +719,9 @@ class SparkSession private[sql] (
     if (releaseSessionOnClose) {
       client.releaseSession()
     }
+    if (shouldStopCleaner) {
+      cleaner.stop()
+    }
     client.shutdown()
     allocator.close()
     SparkSession.onSessionClose(this)
```
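The `lazy val cleaner` paired with the `shouldStopCleaner` flag means the cleaner (and its daemon thread) only comes into existence once something is registered for cleanup, and `close()` only stops it if it was ever created. A standalone sketch of that idiom, using hypothetical `Resource`/`Holder` names for illustration only:

```scala
// Hypothetical names; this only illustrates the lazy-create / conditional-stop idiom.
class Resource {
  def stop(): Unit = println("resource stopped")
}

class Holder extends AutoCloseable {
  private var shouldStopResource = false

  // First access creates the resource and records that it must be stopped later.
  private lazy val resource: Resource = {
    shouldStopResource = true
    new Resource
  }

  def doWork(): Unit = {
    val r = resource // forces creation only when actually needed
    // ... use r ...
  }

  override def close(): Unit = {
    // Touching `resource` unconditionally here would create it just to stop it;
    // the flag keeps close() cheap for holders that never used the resource.
    if (shouldStopResource) {
      resource.stop()
    }
  }
}
```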
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala

Lines changed: 146 additions & 0 deletions (new file)
@@ -0,0 +1,146 @@

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.internal

import java.lang.ref.{ReferenceQueue, WeakReference}
import java.util.Collections
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession

/**
 * Classes that represent cleaning tasks.
 */
private sealed trait CleanupTask
private case class CleanupCachedRemoteRelation(dfID: String) extends CleanupTask

/**
 * A WeakReference associated with a CleanupTask.
 *
 * When the referent object becomes only weakly reachable, the corresponding
 * CleanupTaskWeakReference is automatically added to the given reference queue.
 */
private class CleanupTaskWeakReference(
    val task: CleanupTask,
    referent: AnyRef,
    referenceQueue: ReferenceQueue[AnyRef])
  extends WeakReference(referent, referenceQueue)

/**
 * An asynchronous cleaner for objects.
 *
 * This maintains a weak reference for each CashRemoteRelation, etc. of interest, to be processed
 * when the associated object goes out of scope of the application. Actual cleanup is performed in
 * a separate daemon thread.
 */
private[sql] class SessionCleaner(session: SparkSession) extends Logging {

  /**
   * How often (seconds) to trigger a garbage collection in this JVM. This context cleaner
   * triggers cleanups only when weak references are garbage collected. In long-running
   * applications with large driver JVMs, where there is little memory pressure on the driver,
   * this may happen very occasionally or not at all. Not cleaning at all may lead to executors
   * running out of disk space after a while.
   */
  private val refQueuePollTimeout: Long = 100

  /**
   * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they
   * have not been handled by the reference queue.
   */
  private val referenceBuffer =
    Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap)

  private val referenceQueue = new ReferenceQueue[AnyRef]

  private val cleaningThread = new Thread() { override def run(): Unit = keepCleaning() }

  @volatile private var started = false
  @volatile private var stopped = false

  /** Start the cleaner. */
  def start(): Unit = {
    cleaningThread.setDaemon(true)
    cleaningThread.setName("Spark Connect Context Cleaner")
    cleaningThread.start()
  }

  /**
   * Stop the cleaning thread and wait until the thread has finished running its current task.
   */
  def stop(): Unit = {
    stopped = true
    // Interrupt the cleaning thread, but wait until the current task has finished before
    // doing so. This guards against the race condition where a cleaning thread may
    // potentially clean similarly named variables created by a different SparkSession.
    synchronized {
      cleaningThread.interrupt()
    }
    cleaningThread.join()
  }

  /** Register a CachedRemoteRelation for cleanup when it is garbage collected. */
  def registerCachedRemoteRelationForCleanup(relation: proto.CachedRemoteRelation): Unit = {
    registerForCleanup(relation, CleanupCachedRemoteRelation(relation.getRelationId))
  }

  /** Register an object for cleanup. */
  private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
    if (!started) {
      // Lazily starts when the first cleanup is registered.
      start()
      started = true
    }
    referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue))
  }

  /** Keep cleaning objects. */
  private def keepCleaning(): Unit = {
    while (!stopped && !session.client.channel.isShutdown) {
      try {
        val reference = Option(referenceQueue.remove(refQueuePollTimeout))
          .map(_.asInstanceOf[CleanupTaskWeakReference])
        // Synchronize here to avoid being interrupted on stop()
        synchronized {
          reference.foreach { ref =>
            logDebug("Got cleaning task " + ref.task)
            referenceBuffer.remove(ref)
            ref.task match {
              case CleanupCachedRemoteRelation(dfID) =>
                doCleanupCachedRemoteRelation(dfID)
            }
          }
        }
      } catch {
        case e: Throwable => logError("Error in cleaning thread", e)
      }
    }
  }

  /** Perform CleanupCachedRemoteRelation cleanup. */
  private[spark] def doCleanupCachedRemoteRelation(dfID: String): Unit = {
    session.execute {
      session.newCommand { builder =>
        builder.getRemoveCachedRemoteRelationCommandBuilder
          .setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfID).build())
      }
    }
  }
}
```
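The cleaner leans on the standard `java.lang.ref` contract: a `WeakReference` constructed with a `ReferenceQueue` is enqueued by the JVM once its referent is only weakly reachable, and a background thread drains the queue with a timed `remove`. A self-contained sketch of just that mechanism (independent of Spark; whether the GC actually collects the referent in time is best-effort):

```scala
import java.lang.ref.{ReferenceQueue, WeakReference}

object WeakRefQueueDemo extends App {
  val queue = new ReferenceQueue[AnyRef]

  // Track an object; the WeakReference itself must stay strongly reachable,
  // which is what SessionCleaner's referenceBuffer is for.
  var payload: AnyRef = new Array[Byte](1 << 20)
  val tracked = new WeakReference[AnyRef](payload, queue)

  // Drop the only strong reference and ask for a collection (only a hint).
  payload = null
  System.gc()

  // Poll with a timeout, mirroring keepCleaning's referenceQueue.remove(timeout).
  Option(queue.remove(1000)) match {
    case Some(ref) => println(s"enqueued after GC, same reference: ${ref eq tracked}")
    case None => println("not collected yet; a real cleaner loops and polls again")
  }
}
```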
