diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 3754a1a0374a..9b64144af64b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1077,6 +1077,7 @@ public UTF8String translate(Map dict) { */ public static class LongWrapper implements Serializable { public transient long value = 0; + public transient boolean formatInvalid = false; } /** @@ -1088,6 +1089,7 @@ public static class LongWrapper implements Serializable { */ public static class IntWrapper implements Serializable { public transient int value = 0; + public transient boolean formatInvalid = false; } /** @@ -1140,6 +1142,7 @@ public boolean toLong(LongWrapper toLongResult) { if (b >= '0' && b <= '9') { digit = b - '0'; } else { + toLongResult.formatInvalid = false; return false; } @@ -1233,6 +1236,7 @@ public boolean toInt(IntWrapper intWrapper) { if (b >= '0' && b <= '9') { digit = b - '0'; } else { + intWrapper.formatInvalid = true; return false; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index fa27a48419db..a61863c9a129 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -480,11 +480,28 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s)) } + private[this] def onStringToIntegerFailed( + str: UTF8String, + formatInvalid: Boolean, + typeName: String): Any = { + if (ansiEnabled) { + if (formatInvalid) { + throw new ArithmeticException(s"Invalid input syntax for type integer: $str") + } else { + throw new ArithmeticException(s"Casting $str to $typeName causes overflow") + } + } else { + null + } + } + // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => val result = new LongWrapper() - buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) + buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else { + onStringToIntegerFailed(s, result.formatInvalid, "Long") + }) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => @@ -501,7 +518,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) + buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else { + onStringToIntegerFailed(s, result.formatInvalid, "Int") + }) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => @@ -523,7 +542,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[UTF8String](_, s => if (s.toShort(result)) { result.value.toShort } else { - null + onStringToIntegerFailed(s, result.formatInvalid, "Short") }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) @@ -564,7 +583,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[UTF8String](_, s => if (s.toByte(result)) { result.value.toByte } else { - null + onStringToIntegerFailed(s, result.formatInvalid, "Byte") }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cf24372e0e0b..59676eb97325 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3384,6 +3384,36 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { assert(exp.getMessage.contains("Resources not found")) } } + + test("SPARK-30472: ANSI SQL: Throw exception on format invalid and overflow when casting " + + "String to Integer type.") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('abc' as STRING) AS INTEGER)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('abc' as STRING) AS BYTE)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('abc' as STRING) AS SHORT)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('abc' as STRING) AS LONG)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('2147483648' as STRING) AS INTEGER)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('128' as STRING) AS BYTE)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('32768' as STRING) AS SHORT)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('9223372036854775808' as STRING) AS LONG)").collect() + ) + } + } } case class Foo(bar: Option[String])