diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index 15007dee39..323036b856 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession}
 import org.apache.spark.storage.StorageLevel
 
 import java.time.format.DateTimeFormatter
@@ -324,8 +324,40 @@ case class TableUtils(sparkSession: SparkSession) {
     val partitionCount = sparkSession.sparkContext.getConf.getInt("spark.default.parallelism", 1000)
     logger.info(
       s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n")
-    val df = sparkSession.sql(query).coalesce(partitionCount)
-    df
+    try {
+      sparkSession.sql(query).coalesce(partitionCount)
+    } catch {
+      case e: AnalysisException =>
+        // If the failing query registers functions, check whether any of them
+        // already exist in the session catalog: a duplicate registration is
+        // downgraded to a warning instead of failing the job.
+        val containsFunctionDefinitions =
+          query.contains("CREATE FUNCTION") || query.contains("CREATE TEMPORARY FUNCTION")
+        if (containsFunctionDefinitions) {
+          val catalogFunctions = sparkSession.catalog.listFunctions().collect().map(_.name)
+          // Use collect rather than flatMap with a partial function, which would
+          // throw a MatchError on plan nodes that are not function definitions;
+          // functionName is a multipart identifier, so take its last part.
+          val queryFunctions = sparkSession.sessionState.sqlParser
+            .parsePlan(query)
+            .collect {
+              case p: org.apache.spark.sql.catalyst.plans.logical.CreateFunctionStatement =>
+                p.functionName.last
+            }
+          val existingFunctions = catalogFunctions.intersect(queryFunctions)
+          if (existingFunctions.nonEmpty) {
+            logger.warn(
+              s"The following function(s) already exist: ${existingFunctions.mkString(", ")}. " +
+                "Query may result in function redefinition.")
+            // Return an empty DataFrame rather than null so callers can chain safely.
+            return sparkSession.emptyDataFrame
+          }
+        }
+        throw e
+      case e: Exception =>
+        logger.error("Error running query:", e)
+        throw e
+    }
   }
 
   def insertUnPartitioned(df: DataFrame,
diff --git a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala
index f151525338..64cdc143bd 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala
@@ -409,6 +409,16 @@ class TableUtilsTest {
     assertTrue(tableUtils.checkTablePermission(tableName))
   }
 
+  @Test
+  def testDoubleUDFRegistration(): Unit = {
+    // Registering the same temporary function twice should be downgraded to a
+    // warning by TableUtils.sql instead of throwing an AnalysisException.
+    val resourceURL = getClass.getResource("/jars/brickhouse-0.6.0.jar")
+    tableUtils.sql(s"ADD JAR ${resourceURL.getPath}")
+    tableUtils.sql("CREATE TEMPORARY FUNCTION test AS 'brickhouse.udf.date.AddDaysUDF';")
+    tableUtils.sql("CREATE TEMPORARY FUNCTION test AS 'brickhouse.udf.date.AddDaysUDF';")
+  }
+
   @Test
   def testIfPartitionExistsInTable(): Unit = {
     val tableName = "db.test_if_partition_exists"
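
Reviewer note: a minimal standalone sketch of the parse-and-collect step above, assuming a Spark 3.1-era build where CreateFunctionStatement still lives in org.apache.spark.sql.catalyst.plans.logical (later releases replaced it); the object name and app name are illustrative. parsePlan only parses the statement, so it is safe to call inside the AnalysisException handler:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.CreateFunctionStatement

object ParseFunctionNamesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    val query = "CREATE TEMPORARY FUNCTION test AS 'brickhouse.udf.date.AddDaysUDF'"
    // TreeNode.collect walks every node of the parsed plan and applies the
    // partial function only where it matches, so other statements yield Nil.
    val names = spark.sessionState.sqlParser
      .parsePlan(query)
      .collect { case p: CreateFunctionStatement => p.functionName.last }
    println(names.mkString(", ")) // prints: test
    spark.stop()
  }
}

Cross-checking the collected names against sparkSession.catalog.listFunctions() is what lets the handler distinguish a harmless re-registration from a genuine analysis failure.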