Skip to content

Commit

Permalink
[TableUtils] Prevent hard failures on duplicate setups
Browse files (browse the repository at this point in the history)
  • Loading branch information
cristianfr committed Feb 10, 2024
1 parent ac5095b commit 2418803
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
31 changes: 28 additions & 3 deletions spark/src/main/scala/ai/chronon/spark/TableUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.storage.StorageLevel

import java.time.format.DateTimeFormatter
Expand Down Expand Up @@ -324,8 +324,33 @@ case class TableUtils(sparkSession: SparkSession) {
val partitionCount = sparkSession.sparkContext.getConf.getInt("spark.default.parallelism", 1000)
logger.info(
s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n")
val df = sparkSession.sql(query).coalesce(partitionCount)
df
try {
  // Run the query; the try-expression's value is the method result.
  sparkSession.sql(query).coalesce(partitionCount)
} catch {
  case e: AnalysisException =>
    // Cheap, case-insensitive pre-check so that only function-definition queries are re-parsed.
    val normalizedQuery = query.toUpperCase(java.util.Locale.ROOT)
    val containsFunctionDefinitions =
      normalizedQuery.contains("CREATE FUNCTION") || normalizedQuery.contains("CREATE TEMPORARY FUNCTION")

    // Names defined by the query that are already present in the catalog (empty when not applicable).
    val existingFunctions: Seq[String] =
      if (containsFunctionDefinitions) {
        val functionIdentifiers = sparkSession.catalog.listFunctions().collect().map(_.name).toSet
        // The explicit `case _` keeps the traversal total: without it any plan node other than
        // CreateFunctionStatement raises a MatchError. The Try keeps a parser failure from
        // replacing the original exception `e`, which is rethrown below instead.
        val queryFunctions = scala.util
          .Try {
            sparkSession.sessionState.sqlParser.parsePlan(query).flatMap {
              case p: org.apache.spark.sql.catalyst.plans.logical.CreateFunctionStatement => p.functionName
              case _ => Seq.empty[String]
            }
          }
          .getOrElse(Seq.empty[String])
        queryFunctions.filter(functionIdentifiers.contains).distinct
      } else {
        Seq.empty[String]
      }

    if (existingFunctions.nonEmpty) {
      logger.warn(s"The following function(s) already exist(s): ${existingFunctions.mkString(", ")}. Query may result in function redefinition.")
      // Duplicate setup is tolerated: hand back an empty DataFrame rather than null so callers
      // that touch the result do not hit a NullPointerException.
      sparkSession.emptyDataFrame
    } else {
      throw e
    }
  case e: Exception =>
    logger.error("Error running query:", e)
    throw e
}
}

def insertUnPartitioned(df: DataFrame,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,14 @@ class TableUtilsTest {
assertTrue(tableUtils.checkTablePermission(tableName))
}

@Test
def testDoubleUDFRegistration(): Unit = {
  // Registering the same temporary UDF twice must not throw out of tableUtils.sql.
  val resourceURL = getClass.getResource("/jars/brickhouse-0.6.0.jar")
  // getResource returns null when the fixture jar is absent from the test classpath;
  // fail with an assertion here instead of an opaque NullPointerException on getPath.
  assertTrue(resourceURL != null)
  tableUtils.sql(s"ADD JAR ${resourceURL.getPath}")
  val createFunction = "CREATE TEMPORARY FUNCTION test AS 'brickhouse.udf.date.AddDaysUDF';"
  // First call registers the function; the second repeats the identical definition.
  tableUtils.sql(createFunction)
  tableUtils.sql(createFunction)
}

@Test
def testIfPartitionExistsInTable(): Unit = {
val tableName = "db.test_if_partition_exists"
Expand Down

0 comments on commit 2418803

Please sign in to comment.