Skip to content

Commit

Permalink
[Spark] Pass sparkSession to commitOwnerBuilder (#3112)
Browse files Browse the repository at this point in the history
<!--
Thanks for sending a pull request!  Here are some tips for you:
1. If this is your first time, please read our contributor guidelines:
https://github.com/delta-io/delta/blob/master/CONTRIBUTING.md
2. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP]
Your PR title ...'.
  3. Be sure to keep the PR description updated to reflect all changes.
  4. Please write your PR title to summarize what this PR proposes.
5. If possible, provide a concise example to reproduce the issue for a
faster review.
6. If applicable, include the corresponding issue number in the PR title
and link it in the body.
-->

#### Which Delta project/connector is this regarding?
<!--
Please add the component selected below to the beginning of the pull
request title
For example: [Spark] Title of my pull request
-->

- [X] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

<!--
- Describe what this PR changes.
- Describe why we need the change.
 
If this PR resolves an issue be sure to include "Resolves #XXX" to
correctly link and close the issue upon merge.
-->
Updates CommitOwnerBuilder.build so that it can take in a sparkSession
object. This allows it to read CommitOwner-related dynamic confs from
the sparkSession while building it.


## Does this PR introduce _any_ user-facing changes?

<!--
If yes, please clarify the previous behavior and the change this PR
proposes - provide the console output, description and/or an example to
show the behavior difference if possible.
If possible, please also clarify if this is a user-facing change
compared to the released Delta Lake versions or within the unreleased
branches such as master.
If no, write 'No'.
-->
No
  • Loading branch information
dhruvarya-db authored May 17, 2024
1 parent eca5a7f commit 3af4335
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1523,7 +1523,7 @@ trait OptimisticTransactionImpl extends TransactionalWrite
var newManagedCommitTableConf: Option[Map[String, String]] = None
if (finalMetadata.configuration != snapshot.metadata.configuration || snapshot.version == -1L) {
val newCommitOwnerClientOpt =
ManagedCommitUtils.getCommitOwnerClient(finalMetadata, finalProtocol)
ManagedCommitUtils.getCommitOwnerClient(spark, finalMetadata, finalProtocol)
(newCommitOwnerClientOpt, readSnapshotTableCommitOwnerClientOpt) match {
case (Some(newCommitOwnerClient), None) =>
// FS -> MC conversion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ class Snapshot(
*/
val tableCommitOwnerClientOpt: Option[TableCommitOwnerClient] = initializeTableCommitOwner()
protected def initializeTableCommitOwner(): Option[TableCommitOwnerClient] = {
ManagedCommitUtils.getTableCommitOwner(this)
ManagedCommitUtils.getTableCommitOwner(spark, this)
}

/** Number of columns to collect stats on for data skipping */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import org.apache.spark.sql.delta.storage.LogStore
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.SparkSession

/** Representation of a commit file */
case class Commit(
private val version: Long,
Expand Down Expand Up @@ -199,7 +201,7 @@ trait CommitOwnerBuilder {
def getName: String

/** Returns a commit-owner client based on the given conf */
def build(conf: Map[String, String]): CommitOwnerClient
def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient
}

/** Factory to get the correct [[CommitOwnerClient]] for a table */
Expand All @@ -218,10 +220,12 @@ object CommitOwnerProvider {
}
}

/** Returns a [[CommitOwnerClient]] for the given `name` and `conf` */
/** Returns a [[CommitOwnerClient]] for the given `name`, `conf`, and `spark` */
def getCommitOwnerClient(
name: String, conf: Map[String, String]): CommitOwnerClient = synchronized {
nameToBuilderMapping.get(name).map(_.build(conf)).getOrElse {
name: String,
conf: Map[String, String],
spark: SparkSession): CommitOwnerClient = synchronized {
nameToBuilderMapping.get(name).map(_.build(spark, conf)).getOrElse {
throw new IllegalArgumentException(s"Unknown commit-owner: $name")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import org.apache.spark.sql.delta.storage.LogStore
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.SparkSession

class InMemoryCommitOwner(val batchSize: Long)
extends AbstractBatchBackfillingCommitOwnerClient {

Expand Down Expand Up @@ -206,7 +208,7 @@ case class InMemoryCommitOwnerBuilder(batchSize: Long) extends CommitOwnerBuilde
def getName: String = "in-memory"

/** Returns a commit-owner based on the given conf */
def build(conf: Map[String, String]): CommitOwnerClient = {
def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
inMemoryStore
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import org.apache.spark.sql.delta.util.FileNames.{DeltaFile, UnbackfilledDeltaFi
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.SparkSession

object ManagedCommitUtils extends DeltaLogging {

/**
Expand Down Expand Up @@ -111,16 +113,19 @@ object ManagedCommitUtils extends DeltaLogging {
*/
def getTablePath(logPath: Path): Path = logPath.getParent

def getCommitOwnerClient(metadata: Metadata, protocol: Protocol): Option[CommitOwnerClient] = {
def getCommitOwnerClient(
spark: SparkSession, metadata: Metadata, protocol: Protocol): Option[CommitOwnerClient] = {
metadata.managedCommitOwnerName.map { commitOwnerStr =>
assert(protocol.isFeatureSupported(ManagedCommitTableFeature))
CommitOwnerProvider.getCommitOwnerClient(commitOwnerStr, metadata.managedCommitOwnerConf)
CommitOwnerProvider.getCommitOwnerClient(
commitOwnerStr, metadata.managedCommitOwnerConf, spark)
}
}

def getTableCommitOwner(
spark: SparkSession,
snapshotDescriptor: SnapshotDescriptor): Option[TableCommitOwnerClient] = {
getCommitOwnerClient(snapshotDescriptor.metadata, snapshotDescriptor.protocol).map {
getCommitOwnerClient(spark, snapshotDescriptor.metadata, snapshotDescriptor.protocol).map {
commitOwner =>
TableCommitOwnerClient(
commitOwner,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.scalatest.Tag

import org.apache.spark.{DebugFilesystem, SparkException, TaskFailedReason}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row}
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ class DeltaLogSuite extends QueryTest
// For Managed Commit table with a commit that is not backfilled, we can't use
// 00000000002.json yet. Contact commit store to get uuid file path to malform json file.
val oc = CommitOwnerProvider.getCommitOwnerClient(
"tracking-in-memory", Map.empty[String, String])
"tracking-in-memory", Map.empty[String, String], spark)
val commitResponse = oc.getCommits(deltaLog.logPath, Map.empty, Some(2))
if (!commitResponse.getCommits.isEmpty) {
val path = commitResponse.getCommits.last.getFileStatus.getPath
Expand Down Expand Up @@ -602,7 +602,7 @@ class DeltaLogSuite extends QueryTest
// For Managed Commit table with a commit that is not backfilled, we can't use
// 00000000001.json yet. Contact commit store to get uuid file path to malform json file.
val oc = CommitOwnerProvider.getCommitOwnerClient(
"tracking-in-memory", Map.empty[String, String])
"tracking-in-memory", Map.empty[String, String], spark)
val commitResponse = oc.getCommits(log.logPath, Map.empty, Some(1))
if (!commitResponse.getCommits.isEmpty) {
commitFilePath = commitResponse.getCommits.head.getFileStatus.getPath
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ import org.apache.spark.sql.delta.util.{FileNames, JsonUtils}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.Row
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
import org.apache.spark.sql.functions.lit
Expand Down Expand Up @@ -520,7 +519,8 @@ class OptimisticTransactionSuite
}
}
}
override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
override def build(
spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
}

CommitOwnerProvider.registerBuilder(RetryableNonConflictCommitOwnerBuilder$)
Expand Down Expand Up @@ -569,7 +569,8 @@ class OptimisticTransactionSuite
}
}
}
override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
override def build(
spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
}

CommitOwnerProvider.registerBuilder(FileAlreadyExistsCommitOwnerBuilder)
Expand Down Expand Up @@ -878,7 +879,8 @@ class OptimisticTransactionSuite
object RetryableConflictCommitOwnerBuilder$ extends CommitOwnerBuilder {
lazy val commitOwnerClient = new RetryableConflictCommitOwnerClient()
override def getName: String = commitOwnerName
override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
override def build(
spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
}
CommitOwnerProvider.registerBuilder(RetryableConflictCommitOwnerBuilder$)
val conf = Map(DeltaConfigs.MANAGED_COMMIT_OWNER_NAME.key -> commitOwnerName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkConf
import org.apache.spark.SparkException
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.storage.StorageLevel

Expand Down Expand Up @@ -587,7 +588,7 @@ object ConcurrentBackfillCommitOwnerBuilder extends CommitOwnerBuilder {
private lazy val concurrentBackfillCommitOwnerClient =
ConcurrentBackfillCommitOwnerClient(synchronousBackfillThreshold = 2, batchSize)
override def getName: String = "awaiting-commit-owner"
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
concurrentBackfillCommitOwnerClient
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.delta.test.DeltaSQLTestUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.{QueryTest, SparkSession}
import org.apache.spark.sql.test.SharedSparkSession

class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with SharedSparkSession
Expand Down Expand Up @@ -72,15 +72,15 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share

test("registering multiple commit-owner builders with same name") {
object Builder1 extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = null
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = null
override def getName: String = "builder-1"
}
object BuilderWithSameName extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = null
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = null
override def getName: String = "builder-1"
}
object Builder3 extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = null
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = null
override def getName: String = "builder-3"
}
CommitOwnerProvider.registerBuilder(Builder1)
Expand All @@ -94,7 +94,7 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share
object Builder1 extends CommitOwnerBuilder {
val cs1 = new TestCommitOwnerClient1()
val cs2 = new TestCommitOwnerClient2()
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
conf.getOrElse("url", "") match {
case "url1" => cs1
case "url2" => cs2
Expand All @@ -104,21 +104,22 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share
override def getName: String = "cs-x"
}
CommitOwnerProvider.registerBuilder(Builder1)
val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1"))
val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1"), spark)
assert(cs1.isInstanceOf[TestCommitOwnerClient1])
val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1"))
val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1"), spark)
assert(cs1 eq cs1_again)
val cs2 = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url2", "a" -> "b"))
val cs2 =
CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url2", "a" -> "b"), spark)
assert(cs2.isInstanceOf[TestCommitOwnerClient2])
// If builder receives a config which doesn't have expected params, then it can throw exception.
intercept[IllegalArgumentException] {
CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url3"))
CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url3"), spark)
}
}

test("getCommitOwnerClient - builder returns new object each time") {
object Builder1 extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
conf.getOrElse("url", "") match {
case "url1" => new TestCommitOwnerClient1()
case _ => throw new IllegalArgumentException("Invalid url")
Expand All @@ -127,9 +128,9 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share
override def getName: String = "cs-name"
}
CommitOwnerProvider.registerBuilder(Builder1)
val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1"))
val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1"), spark)
assert(cs1.isInstanceOf[TestCommitOwnerClient1])
val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1"))
val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1"), spark)
assert(cs1 ne cs1_again)
}

Expand Down Expand Up @@ -202,21 +203,21 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share
other.asInstanceOf[TestCommitOwnerClient].key == key
}
object Builder1 extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
new TestCommitOwnerClient(conf("key"))
}
override def getName: String = "cs-name"
}
CommitOwnerProvider.registerBuilder(Builder1)

// Different CommitOwner with same keys should be semantically equal.
val obj1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1"))
val obj2 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1"))
val obj1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1"), spark)
val obj2 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1"), spark)
assert(obj1 != obj2)
assert(obj1.semanticEquals(obj2))

// Different CommitOwner with different keys should be semantically unequal.
val obj3 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url2"))
val obj3 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url2"), spark)
assert(obj1 != obj3)
assert(!obj1.semanticEquals(obj3))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ abstract class InMemoryCommitOwnerSuite(batchSize: Int) extends CommitOwnerClien

override protected def createTableCommitOwnerClient(
deltaLog: DeltaLog): TableCommitOwnerClient = {
val cs = InMemoryCommitOwnerBuilder(batchSize).build(Map.empty)
val cs = InMemoryCommitOwnerBuilder(batchSize).build(spark, Map.empty)
TableCommitOwnerClient(cs, deltaLog, Map.empty[String, String])
}

Expand Down Expand Up @@ -65,22 +65,22 @@ abstract class InMemoryCommitOwnerSuite(batchSize: Int) extends CommitOwnerClien

test("InMemoryCommitOwnerBuilder works as expected") {
val builder1 = InMemoryCommitOwnerBuilder(5)
val cs1 = builder1.build(Map.empty)
val cs1 = builder1.build(spark, Map.empty)
assert(cs1.isInstanceOf[InMemoryCommitOwner])
assert(cs1.asInstanceOf[InMemoryCommitOwner].batchSize == 5)

val cs1_again = builder1.build(Map.empty)
val cs1_again = builder1.build(spark, Map.empty)
assert(cs1_again.isInstanceOf[InMemoryCommitOwner])
assert(cs1 == cs1_again)

val builder2 = InMemoryCommitOwnerBuilder(10)
val cs2 = builder2.build(Map.empty)
val cs2 = builder2.build(spark, Map.empty)
assert(cs2.isInstanceOf[InMemoryCommitOwner])
assert(cs2.asInstanceOf[InMemoryCommitOwner].batchSize == 10)
assert(cs2 ne cs1)

val builder3 = InMemoryCommitOwnerBuilder(10)
val cs3 = builder3.build(Map.empty)
val cs3 = builder3.build(spark, Map.empty)
assert(cs3.isInstanceOf[InMemoryCommitOwner])
assert(cs3.asInstanceOf[InMemoryCommitOwner].batchSize == 10)
assert(cs3 ne cs2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.SparkConf
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.{QueryTest, Row, SparkSession}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.ManualClock

Expand Down Expand Up @@ -71,7 +71,7 @@ class ManagedCommitSuite

override def getName: String = commitOwnerName

override def build(conf: Map[String, String]): CommitOwnerClient =
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient =
new InMemoryCommitOwner(batchSize = 5) {
override def commit(
logStore: LogStore,
Expand Down Expand Up @@ -125,7 +125,7 @@ class ManagedCommitSuite

test("cold snapshot initialization") {
val builder = TrackingInMemoryCommitOwnerBuilder(batchSize = 10)
val commitOwnerClient = builder.build(Map.empty).asInstanceOf[TrackingCommitOwnerClient]
val commitOwnerClient = builder.build(spark, Map.empty).asInstanceOf[TrackingCommitOwnerClient]
CommitOwnerProvider.registerBuilder(builder)
withTempDir { tempDir =>
val tablePath = tempDir.getAbsolutePath
Expand Down Expand Up @@ -221,7 +221,7 @@ class ManagedCommitSuite
name: String,
commitOwnerClient: CommitOwnerClient) extends CommitOwnerBuilder {
var numBuildCalled = 0
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
numBuildCalled += 1
commitOwnerClient
}
Expand Down Expand Up @@ -361,7 +361,8 @@ class ManagedCommitSuite
case class TrackingInMemoryCommitOwnerClientBuilder(
name: String,
commitOwnerClient: CommitOwnerClient) extends CommitOwnerBuilder {
override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
override def build(
spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient
override def getName: String = name
}
val builder1 = TrackingInMemoryCommitOwnerClientBuilder(name = "in-memory-1", cs1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.test.SharedSparkSession

trait ManagedCommitTestUtils
Expand Down Expand Up @@ -116,7 +117,7 @@ case class TrackingInMemoryCommitOwnerBuilder(
}

override def getName: String = "tracking-in-memory"
override def build(conf: Map[String, String]): CommitOwnerClient = {
override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = {
trackingInMemoryCommitOwnerClient
}
}
Expand Down

0 comments on commit 3af4335

Please sign in to comment.