Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,7 @@ public UTF8String translate(Map<Character, Character> dict) {
*/
public static class LongWrapper implements Serializable {
public transient long value = 0;
public transient boolean formatInvalid = false;
}

/**
Expand All @@ -1088,6 +1089,7 @@ public static class LongWrapper implements Serializable {
*/
public static class IntWrapper implements Serializable {
public transient int value = 0;
public transient boolean formatInvalid = false;
}

/**
Expand Down Expand Up @@ -1140,6 +1142,7 @@ public boolean toLong(LongWrapper toLongResult) {
if (b >= '0' && b <= '9') {
digit = b - '0';
} else {
toLongResult.formatInvalid = false;
return false;
}

Expand Down Expand Up @@ -1233,6 +1236,7 @@ public boolean toInt(IntWrapper intWrapper) {
if (b >= '0' && b <= '9') {
digit = b - '0';
} else {
intWrapper.formatInvalid = true;
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -480,11 +480,28 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s))
}

private[this] def onStringToIntegerFailed(
str: UTF8String,
formatInvalid: Boolean,
typeName: String): Any = {
if (ansiEnabled) {
if (formatInvalid) {
throw new ArithmeticException(s"Invalid input syntax for type integer: $str")
} else {
throw new ArithmeticException(s"Casting $str to $typeName causes overflow")
}
} else {
null
}
}

// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
val result = new LongWrapper()
buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else {
onStringToIntegerFailed(s, result.formatInvalid, "Long")
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0L)
case DateType =>
Expand All @@ -501,7 +518,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else {
onStringToIntegerFailed(s, result.formatInvalid, "Int")
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1 else 0)
case DateType =>
Expand All @@ -523,7 +542,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
buildCast[UTF8String](_, s => if (s.toShort(result)) {
result.value.toShort
} else {
null
onStringToIntegerFailed(s, result.formatInvalid, "Short")
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
Expand Down Expand Up @@ -564,7 +583,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
buildCast[UTF8String](_, s => if (s.toByte(result)) {
result.value.toByte
} else {
null
onStringToIntegerFailed(s, result.formatInvalid, "Byte")
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
Expand Down
30 changes: 30 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3384,6 +3384,36 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession {
assert(exp.getMessage.contains("Resources not found"))
}
}

test("SPARK-30472: ANSI SQL: Throw exception on format invalid and overflow when casting " +
"String to Integer type.") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
intercept[ArithmeticException](
sql("SELECT CAST(CAST('abc' as STRING) AS INTEGER)").collect()
)
intercept[ArithmeticException](
sql("SELECT CAST(CAST('abc' as STRING) AS BYTE)").collect()
)
intercept[ArithmeticException](
sql("SELECT CAST(CAST('abc' as STRING) AS SHORT)").collect()
)
intercept[ArithmeticException](
sql("SELECT CAST(CAST('abc' as STRING) AS LONG)").collect()
)
intercept[ArithmeticException](
sql("SELECT CAST(CAST('2147483648' as STRING) AS INTEGER)").collect()
)
intercept[ArithmeticException](
sql("SELECT CAST(CAST('128' as STRING) AS BYTE)").collect()
)
intercept[ArithmeticException](
sql("SELECT CAST(CAST('32768' as STRING) AS SHORT)").collect()
)
intercept[ArithmeticException](
sql("SELECT CAST(CAST('9223372036854775808' as STRING) AS LONG)").collect()
)
}
}
}

case class Foo(bar: Option[String])