diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 0f9b665718ca6..a3bc08cd4f198 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -143,6 +143,11 @@
           "Offset expression <offset> must be a literal."
         ]
       },
+      "HASH_MAP_TYPE" : {
+        "message" : [
+          "Input to the function <functionName> cannot contain elements of the \"MAP\" type. In Spark, same maps may have different hashcode, thus hash expressions are prohibited on \"MAP\" elements. To restore previous behavior set \"spark.sql.legacy.allowHashOnMapType\" to \"true\"."
+        ]
+      },
       "INVALID_JSON_MAP_KEY_TYPE" : {
         "message" : [
           "Input schema <schema> can only contain STRING as a key type for a MAP."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 7ac486f05af1b..4f8ed1953f409 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -28,6 +28,8 @@ import org.apache.commons.codec.digest.MessageDigestAlgorithms
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
@@ -268,15 +270,17 @@ abstract class HashExpression[E] extends Expression {
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.length < 1) {
-      TypeCheckResult.TypeCheckFailure(
-        s"input to function $prettyName requires at least one argument")
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_PARAMS",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "expectedNum" -> "> 0",
+          "actualNum" -> children.length.toString))
     } else if (children.exists(child => hasMapType(child.dataType)) &&
         !SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE)) {
-      TypeCheckResult.TypeCheckFailure(
-        s"input to function $prettyName cannot contain elements of MapType. In Spark, same maps " +
-          "may have different hashcode, thus hash expressions are prohibited on MapType elements." +
-          s" To restore previous behavior set ${SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE.key} " +
-          "to true.")
+      DataTypeMismatch(
+        errorSubClass = "HASH_MAP_TYPE",
+        messageParameters = Map("functionName" -> toSQLId(prettyName)))
    } else {
       TypeCheckResult.TypeCheckSuccess
     }
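Illustrative sketch (not part of the patch; variable names are mine): the new HASH_MAP_TYPE error can be exercised directly against catalyst, mirroring the unit test this patch adds in ExpressionTypeCheckingSuite below.

    import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash}
    import org.apache.spark.sql.types.{BooleanType, LongType, MapType}

    val mapArg = Literal.create(Map(42L -> true), MapType(LongType, BooleanType))
    // With spark.sql.legacy.allowHashOnMapType left at its default (false), the
    // type check now returns a structured error instead of a free-form string:
    // DataTypeMismatch("HASH_MAP_TYPE", Map("functionName" -> "`hash`"))
    new Murmur3Hash(Seq(mapArg)).checkInputDataTypes()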
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 5643598b4bd56..28739fb47a2b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -22,7 +22,8 @@ import java.util.Locale
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils}
@@ -1481,7 +1482,12 @@ abstract class RoundBase(child: Expression, scale: Expression,
       if (scale.foldable) {
         TypeCheckSuccess
       } else {
-        TypeCheckFailure("Only foldable Expression is allowed for scale arguments")
+        DataTypeMismatch(
+          errorSubClass = "NON_FOLDABLE_INPUT",
+          messageParameters = Map(
+            "inputName" -> "scale",
+            "inputType" -> toSQLType(scale.dataType),
+            "inputExpr" -> toSQLExpr(scale)))
       }
     case f => f
   }
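Illustrative sketch (not part of the patch): the round/bround change can be exercised the same way; a non-foldable scale argument now yields DATATYPE_MISMATCH.NON_FOLDABLE_INPUT, as pinned down in the suites below.

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Round}
    import org.apache.spark.sql.types.IntegerType

    val intField = AttributeReference("intField", IntegerType)()
    // The scale argument is an attribute, not a literal, so it is not foldable;
    // checkInputDataTypes() returns DataTypeMismatch("NON_FOLDABLE_INPUT", ...)
    // with inputName = "scale", inputType = "INT", inputExpr = "intField".
    Round(intField, intField).checkInputDataTypes()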
"functionName" -> toSQLId(xxHash64.prettyName), + "expectedNum" -> "> 0", + "actualNum" -> "0")) + assertError(Explode($"intField"), "input to function explode should be array or map type") assertError(PosExplode($"intField"), @@ -478,8 +502,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer assertSuccess(Round(Literal(null), Literal(null))) assertSuccess(Round($"intField", Literal(1))) - assertError(Round($"intField", $"intField"), - "Only foldable Expression is allowed") + checkError( + exception = intercept[AnalysisException] { + assertSuccess(Round($"intField", $"intField")) + }, + errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "sqlExpr" -> "\"round(intField, intField)\"", + "inputName" -> "scala", + "inputType" -> "\"INT\"", + "inputExpr" -> "\"intField\"")) + checkError( exception = intercept[AnalysisException] { assertSuccess(Round($"intField", $"booleanField")) @@ -516,9 +549,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer assertSuccess(BRound(Literal(null), Literal(null))) assertSuccess(BRound($"intField", Literal(1))) - - assertError(BRound($"intField", $"intField"), - "Only foldable Expression is allowed") + checkError( + exception = intercept[AnalysisException] { + assertSuccess(BRound($"intField", $"intField")) + }, + errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "sqlExpr" -> "\"bround(intField, intField)\"", + "inputName" -> "scala", + "inputType" -> "\"INT\"", + "inputExpr" -> "\"intField\"")) checkError( exception = intercept[AnalysisException] { assertSuccess(BRound($"intField", $"booleanField")) @@ -602,4 +642,15 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer assert(Literal.create(Map(42L -> null), MapType(LongType, NullType)).sql == "MAP(42L, NULL)") } + + test("hash expressions are prohibited on MapType elements") { + val argument = Literal.create(Map(42L -> true), MapType(LongType, BooleanType)) + val murmur3Hash = new Murmur3Hash(Seq(argument)) + assert(murmur3Hash.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "HASH_MAP_TYPE", + messageParameters = Map("functionName" -> toSQLId(murmur3Hash.prettyName)) + ) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7dea7799b666d..c52cb85e119d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -4219,16 +4219,68 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val funcsMustHaveAtLeastOneArg = ("coalesce", (df: DataFrame) => df.select(coalesce())) :: - ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: - ("hash", (df: DataFrame) => df.select(hash())) :: - ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: - ("xxhash64", (df: DataFrame) => df.select(xxhash64())) :: - ("xxhash64", (df: DataFrame) => df.selectExpr("xxhash64()")) :: Nil + ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: Nil funcsMustHaveAtLeastOneArg.foreach { case (name, func) => val errMsg = intercept[AnalysisException] { func(df) }.getMessage assert(errMsg.contains(s"input to function $name requires at least one argument")) } + checkError( + exception = intercept[AnalysisException] { + df.select(hash()) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 7dea7799b666d..c52cb85e119d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -4219,16 +4219,68 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
 
     val funcsMustHaveAtLeastOneArg =
       ("coalesce", (df: DataFrame) => df.select(coalesce())) ::
-      ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) ::
-      ("hash", (df: DataFrame) => df.select(hash())) ::
-      ("hash", (df: DataFrame) => df.selectExpr("hash()")) ::
-      ("xxhash64", (df: DataFrame) => df.select(xxhash64())) ::
-      ("xxhash64", (df: DataFrame) => df.selectExpr("xxhash64()")) :: Nil
+      ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: Nil
     funcsMustHaveAtLeastOneArg.foreach { case (name, func) =>
       val errMsg = intercept[AnalysisException] { func(df) }.getMessage
       assert(errMsg.contains(s"input to function $name requires at least one argument"))
     }
 
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(hash())
+      },
+      errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"hash()\"",
+        "functionName" -> "`hash`",
+        "expectedNum" -> "> 0",
+        "actualNum" -> "0"))
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("hash()")
+      },
+      errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"hash()\"",
+        "functionName" -> "`hash`",
+        "expectedNum" -> "> 0",
+        "actualNum" -> "0"),
+      context = ExpectedContext(
+        fragment = "hash()",
+        start = 0,
+        stop = 5))
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(xxhash64())
+      },
+      errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"xxhash64()\"",
+        "functionName" -> "`xxhash64`",
+        "expectedNum" -> "> 0",
+        "actualNum" -> "0"))
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("xxhash64()")
+      },
+      errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"xxhash64()\"",
+        "functionName" -> "`xxhash64`",
+        "expectedNum" -> "> 0",
+        "actualNum" -> "0"),
+      context = ExpectedContext(
+        fragment = "xxhash64()",
+        start = 0,
+        stop = 9))
+
     checkError(
       exception = intercept[AnalysisException] {
         df.select(greatest())