
Commit bc30a07

Ngone51 authored and cloud-fan committed
[SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default
### What changes were proposed in this pull request?

This PR proposes to throw an exception by default when a user uses an untyped UDF (a.k.a. `org.apache.spark.sql.functions.udf(AnyRef, DataType)`). Users can still opt back in by setting `spark.sql.legacy.allowUntypedScalaUDF` to `true`.

### Why are the changes needed?

According to #23498, since Spark 3.0 the untyped UDF returns the default value of the Java type if the input value is null. For example, with `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` returns 0 in Spark 3.0 but null in Spark 2.4. This behavior change was introduced because Spark 3.0 is built with Scala 2.12 by default. As a result, data may change silently and correctness issues may arise when users still expect `null` in some cases. We therefore want to encourage users to use typed UDFs to avoid this problem.

### Does this PR introduce any user-facing change?

Yes. Users now hit an exception when they use an untyped UDF.

### How was this patch tested?

Added a test and updated some existing tests.

Closes #27488 from Ngone51/spark_26580_followup.

Lead-authored-by: yi.wu <yi.wu@databricks.com>
Co-authored-by: wuyi <yi.wu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 82ce475)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent a415d07 commit bc30a07

File tree

10 files changed: 62 additions & 31 deletions


docs/sql-migration-guide.md

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ license: |
 
   - Since Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match to the pattern defined by the JSON option `timestampFormat`. Set JSON option `inferTimestamp` to `false` to disable such type inferring.
 
-  - In Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(Any, DataType)` gets a Scala closure with primitive-type argument, the returned UDF will return null if the input values is null. Since Spark 3.0, the UDF will return the default value of the Java type if the input value is null. For example, `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` will return null in Spark 2.4 and earlier if column `x` is null, and return 0 in Spark 3.0. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.
+  - Since Spark 3.0, using `org.apache.spark.sql.functions.udf(AnyRef, DataType)` is not allowed by default. Set `spark.sql.legacy.allowUntypedScalaUDF` to true to keep using it. But please note that, in Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(AnyRef, DataType)` gets a Scala closure with primitive-type argument, the returned UDF will return null if the input values is null. However, since Spark 3.0, the UDF will return the default value of the Java type if the input value is null. For example, `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` will return null in Spark 2.4 and earlier if column `x` is null, and return 0 in Spark 3.0. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.
 
   - Since Spark 3.0, Proleptic Gregorian calendar is used in parsing, formatting, and converting dates and timestamps as well as in extracting sub-components like years, days and etc. Spark 3.0 uses Java 8 API classes from the java.time packages that based on ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html). In Spark version 2.4 and earlier, those operations are performed by using the hybrid calendar (Julian + Gregorian, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html). The changes impact on the results for dates before October 15, 1582 (Gregorian) and affect on the following Spark 3.0 API:
 
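To make the migration note above concrete, here is a minimal sketch (editorial illustration, not part of the commit), assuming a local `SparkSession` named `spark` and a toy column `x`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

val spark = SparkSession.builder().master("local[*]").appName("untyped-udf-demo").getOrCreate()
import spark.implicits._

// Untyped overload: since Spark 3.0 the line below throws an AnalysisException unless
// spark.sql.legacy.allowUntypedScalaUDF is set to true. When it was allowed, a null in
// column "x" surfaced as 0, because the closure only ever sees Int's default value.
// val untyped = udf((x: Int) => x + 1, IntegerType)   // needs org.apache.spark.sql.types.IntegerType

// Typed overload: the input type is known, so Spark skips the call for a null
// primitive input and returns null instead of the primitive default.
val typed = udf((x: Int) => x + 1)

Seq(Integer.valueOf(1), null).toDF("x")
  .select(typed(col("x")).as("x_plus_one"))
  .show()   // first row: 2, second row: null
```

The typed overload derives both the result type and the null handling from the closure's type, which is exactly what the ML code below switches to.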

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 3 additions & 2 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.ml
 
 import scala.annotation.varargs
+import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.internal.Logging
@@ -79,7 +80,7 @@ abstract class Transformer extends PipelineStage {
  * result as a new column.
  */
 @DeveloperApi
-abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
+abstract class UnaryTransformer[IN: TypeTag, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
   extends Transformer with HasInputCol with HasOutputCol with Logging {
 
   /** @group setParam */
@@ -118,7 +119,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema, logging = true)
-    val transformUDF = udf(this.createTransformFunc, outputDataType)
+    val transformUDF = udf(this.createTransformFunc)
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))),
       outputSchema($(outputCol)).metadata)
   }
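Because `UnaryTransformer` now carries `TypeTag` context bounds, concrete subclasses must fix `IN` and `OUT` to statically known types so the typed `udf` overload can derive the schema. A hedged sketch of what a downstream subclass could look like after this change (the class `UpperCaser` and its logic are purely illustrative):

```scala
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{DataType, StringType}

// Illustrative subclass: upper-cases a string column. With IN = String and
// OUT = String fixed, the required TypeTags are supplied implicitly by the compiler.
class UpperCaser(override val uid: String)
  extends UnaryTransformer[String, String, UpperCaser] {

  def this() = this(Identifiable.randomUID("upperCaser"))

  override protected def createTransformFunc: String => String = _.toUpperCase

  override protected def outputDataType: DataType = StringType

  override def copy(extra: ParamMap): UpperCaser = defaultCopy(extra)
}
```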

mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala

Lines changed: 5 additions & 6 deletions
@@ -98,7 +98,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val transformUDF = udf(hashFunction(_: Vector), DataTypes.createArrayType(new VectorUDT))
+    val transformUDF = udf(hashFunction(_: Vector))
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
   }
 
@@ -128,14 +128,13 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
       }
 
       // In the origin dataset, find the hash value that hash the same bucket with the key
-      val sameBucketWithKeyUDF = udf((x: Seq[Vector]) =>
-        sameBucket(x, keyHash), DataTypes.BooleanType)
+      val sameBucketWithKeyUDF = udf((x: Seq[Vector]) => sameBucket(x, keyHash))
 
       modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
     } else {
       // In the origin dataset, find the hash value that is closest to the key
       // Limit the use of hashDist since it's controversial
-      val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
+      val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash))
       val hashDistCol = hashDistUDF(col($(outputCol)))
       val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)
 
@@ -172,7 +171,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
     }
 
     // Get the top k nearest neighbor by their distance to the key
-    val keyDistUDF = udf((x: Vector) => keyDistance(x, key), DataTypes.DoubleType)
+    val keyDistUDF = udf((x: Vector) => keyDistance(x, key))
     val modelSubsetWithDistCol = modelSubset.withColumn(distCol, keyDistUDF(col($(inputCol))))
     modelSubsetWithDistCol.sort(distCol).limit(numNearestNeighbors)
   }
@@ -290,7 +289,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
       .drop(explodeCols: _*).distinct()
 
     // Add a new column to store the distance of the two rows.
-    val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType)
+    val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y))
     val joinedDatasetWithDist = joinedDataset.select(col("*"),
       distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol)
     )

mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 7 additions & 4 deletions
@@ -29,10 +29,10 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasPredictionCol
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
-import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
-  FPGrowth => MLlibFPGrowth}
+import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth}
 import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
 import org.apache.spark.sql._
+import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -286,14 +286,17 @@ class FPGrowthModel private[ml] (
 
     val dt = dataset.schema($(itemsCol)).dataType
     // For each rule, examine the input items and summarize the consequents
-    val predictUDF = udf((items: Seq[Any]) => {
+    val predictUDF = SparkUserDefinedFunction((items: Seq[Any]) => {
       if (items != null) {
         val itemset = items.toSet
         brRules.value.filter(_._1.forall(itemset.contains))
           .flatMap(_._2.filter(!itemset.contains(_))).distinct
       } else {
         Seq.empty
-      }}, dt)
+      }},
+      dt,
+      Nil
+    )
     dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol))))
   }
 

mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala

Lines changed: 4 additions & 5 deletions
@@ -76,9 +76,8 @@ private[ml] object LSHTest {
 
     // Perform a cross join and label each pair of same_bucket and distance
     val pairs = transformedData.as("a").crossJoin(transformedData.as("b"))
-    val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
-    val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0,
-      DataTypes.BooleanType)
+    val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y))
+    val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0)
     val result = pairs
       .withColumn("same_bucket", sameBucket(col(s"a.$outputCol"), col(s"b.$outputCol")))
       .withColumn("distance", distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")))
@@ -110,7 +109,7 @@ private[ml] object LSHTest {
     val model = lsh.fit(dataset)
 
     // Compute expected
-    val distUDF = udf((x: Vector) => model.keyDistance(x, key), DataTypes.DoubleType)
+    val distUDF = udf((x: Vector) => model.keyDistance(x, key))
     val expected = dataset.sort(distUDF(col(model.getInputCol))).limit(k)
 
     // Compute actual
@@ -148,7 +147,7 @@ private[ml] object LSHTest {
     val inputCol = model.getInputCol
 
     // Compute expected
-    val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
+    val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y))
     val expected = datasetA.as("a").crossJoin(datasetB.as("b"))
       .filter(distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")) < threshold)
 

project/MimaExcludes.scala

Lines changed: 3 additions & 0 deletions
@@ -74,6 +74,9 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.getRuns"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.setRuns"),
 
+    // [SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.UnaryTransformer.this"),
+
     // [SPARK-27090][CORE] Removing old LEGACY_DRIVER_IDENTIFIER ("<driver>")
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.LEGACY_DRIVER_IDENTIFIER"),
     // [SPARK-25838] Remove formatVersion from Saveable

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions
@@ -2017,6 +2017,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val LEGACY_ALLOW_UNTYPED_SCALA_UDF =
+    buildConf("spark.sql.legacy.allowUntypedScalaUDF")
+      .internal()
+      .doc("When set to true, user is allowed to use org.apache.spark.sql.functions." +
+        "udf(f: AnyRef, dataType: DataType). Otherwise, exception will be throw.")
+      .booleanConf
+      .createWithDefault(false)
+
   val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL =
     buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled")
       .internal()
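The new flag is an ordinary runtime SQL conf (internal, default `false`). Assuming an existing `SparkSession` named `spark`, a minimal sketch of the two usual ways to flip it for a legacy job (the session name and submit command are illustrative):

```scala
// Per session, at runtime; affects untyped udf(...) definitions made afterwards.
spark.conf.set("spark.sql.legacy.allowUntypedScalaUDF", "true")

// Or at submit time:
//   spark-submit --conf spark.sql.legacy.allowUntypedScalaUDF=true ...
```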

sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ sealed abstract class UserDefinedFunction {
   def asNondeterministic(): UserDefinedFunction
 }
 
-private[sql] case class SparkUserDefinedFunction(
+private[spark] case class SparkUserDefinedFunction(
     f: AnyRef,
     dataType: DataType,
     inputSchemas: Seq[Option[ScalaReflection.Schema]],

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 9 additions & 0 deletions
@@ -4733,6 +4733,15 @@ object functions {
    * @since 2.0.0
    */
   def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
+    if (!SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF)) {
+      val errorMsg = "You're using untyped Scala UDF, which does not have the input type " +
+        "information. Spark may blindly pass null to the Scala closure with primitive-type " +
+        "argument, and the closure will see the default value of the Java type for the null " +
+        "argument, e.g. `udf((x: Int) => x, IntegerType)`, the result is 0 for null input. " +
+        "You could use other typed Scala UDF APIs to avoid this problem, or set " +
+        s"${SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key} to true and use this API with caution."
+      throw new AnalysisException(errorMsg)
+    }
     SparkUserDefinedFunction(f, dataType, inputSchemas = Nil)
   }
 
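For callers that still reach this overload, here is a small sketch of what the failure looks like with default settings (it mirrors the new test added below; the surrounding session setup is assumed):

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.IntegerType

try {
  // With spark.sql.legacy.allowUntypedScalaUDF left at its default (false),
  // the untyped overload fails fast at definition time, before any data is touched.
  udf((x: Int) => x, IntegerType)
} catch {
  case e: AnalysisException =>
    // The message starts with "You're using untyped Scala UDF" and names the legacy conf.
    println(e.getMessage)
}
```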

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 21 additions & 12 deletions
@@ -134,10 +134,12 @@ class UDFSuite extends QueryTest with SharedSparkSession {
     assert(df1.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
     assert(df1.head().getDouble(0) >= 0.0)
 
-    val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic()
-    val df2 = testData.select(bar())
-    assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
-    assert(df2.head().getDouble(0) >= 0.0)
+    withSQLConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key -> "true") {
+      val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic()
+      val df2 = testData.select(bar())
+      assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
+      assert(df2.head().getDouble(0) >= 0.0)
+    }
 
     val javaUdf = udf(new UDF0[Double] {
       override def call(): Double = Math.random()
@@ -441,16 +443,23 @@ class UDFSuite extends QueryTest with SharedSparkSession {
   }
 
   test("SPARK-25044 Verify null input handling for primitive types - with udf(Any, DataType)") {
-    val f = udf((x: Int) => x, IntegerType)
-    checkAnswer(
-      Seq(Integer.valueOf(1), null).toDF("x").select(f($"x")),
-      Row(1) :: Row(0) :: Nil)
+    withSQLConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key -> "true") {
+      val f = udf((x: Int) => x, IntegerType)
+      checkAnswer(
+        Seq(Integer.valueOf(1), null).toDF("x").select(f($"x")),
+        Row(1) :: Row(0) :: Nil)
+
+      val f2 = udf((x: Double) => x, DoubleType)
+      checkAnswer(
+        Seq(java.lang.Double.valueOf(1.1), null).toDF("x").select(f2($"x")),
+        Row(1.1) :: Row(0.0) :: Nil)
+    }
 
-    val f2 = udf((x: Double) => x, DoubleType)
-    checkAnswer(
-      Seq(java.lang.Double.valueOf(1.1), null).toDF("x").select(f2($"x")),
-      Row(1.1) :: Row(0.0) :: Nil)
+  }
 
+  test("use untyped Scala UDF should fail by default") {
+    val e = intercept[AnalysisException](udf((x: Int) => x, IntegerType))
+    assert(e.getMessage.contains("You're using untyped Scala UDF"))
   }
 
   test("SPARK-26308: udf with decimal") {

0 commit comments
