Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
"sqlState" : "22023"
},
"INVALID_INPUT_SYNTAX_FOR_NUMERIC_TYPE" : {
"message" : [ "invalid input syntax for type numeric: %s. To return NULL instead, use 'try_cast'. If necessary set %s to false to bypass this error.%s" ],
"message" : [ "invalid input syntax for type %s: %s. To return NULL instead, use 'try_cast'. If necessary set %s to false to bypass this error.%s" ],
"sqlState" : "42000"
},
"INVALID_JSON_SCHEMA_MAPTYPE" : {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
})
case StringType if ansiEnabled =>
buildCast[UTF8String](_,
s => changePrecision(Decimal.fromStringANSI(s, origin.context), target))
s => changePrecision(Decimal.fromStringANSI(s, target.sql, origin.context), target))
case BooleanType =>
buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
case DateType =>
Expand Down Expand Up @@ -845,7 +845,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case _: NumberFormatException =>
val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
if(ansiEnabled && d == null) {
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(s, origin.context)
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(
s, DoubleType.sql, origin.context)
} else {
d
}
Expand All @@ -870,7 +871,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case _: NumberFormatException =>
val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
if (ansiEnabled && f == null) {
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(s, origin.context)
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(
s, FloatType.sql, origin.context)
} else {
f
}
Expand Down Expand Up @@ -1376,9 +1378,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
"""
case StringType if ansiEnabled =>
val errorContext = ctx.addReferenceObj("errCtx", origin.context)
val decimalType = "\"" + target.sql + "\""
(c, evPrim, evNull) =>
code"""
Decimal $tmp = Decimal.fromStringANSI($c, $errorContext);
Decimal $tmp = Decimal.fromStringANSI($c, $decimalType, $errorContext);
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx)}
"""
case BooleanType =>
Expand Down Expand Up @@ -1896,10 +1899,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
from match {
case StringType =>
val floatStr = ctx.freshVariable("floatStr", StringType)
val targetType = "\"" + FloatType.sql + "\""
(c, evPrim, evNull) =>
val handleNull = if (ansiEnabled) {
val errorContext = ctx.addReferenceObj("errCtx", origin.context)
s"throw QueryExecutionErrors.invalidInputSyntaxForNumericError($c, $errorContext);"
"throw QueryExecutionErrors.invalidInputSyntaxForNumericError(" +
s"$c, $targetType, $errorContext);"
} else {
s"$evNull = true;"
}
Expand Down Expand Up @@ -1936,7 +1941,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
(c, evPrim, evNull) =>
val handleNull = if (ansiEnabled) {
val errorContext = ctx.addReferenceObj("errCtx", origin.context)
s"throw QueryExecutionErrors.invalidInputSyntaxForNumericError($c, $errorContext);"
val targetType = "\"" + DoubleType.sql + "\""
"throw QueryExecutionErrors.invalidInputSyntaxForNumericError(" +
s"$c, $targetType, $errorContext);"
} else {
s"$evNull = true;"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.{ByteType, IntegerType, LongType, ShortType}
import org.apache.spark.unsafe.types.UTF8String

/**
Expand All @@ -26,23 +27,27 @@ import org.apache.spark.unsafe.types.UTF8String
object UTF8StringUtils {

def toLongExact(s: UTF8String, errorContext: String): Long =
withException(s.toLongExact, errorContext)
withException(s.toLongExact, s, LongType.sql, errorContext)

def toIntExact(s: UTF8String, errorContext: String): Int =
withException(s.toIntExact, errorContext)
withException(s.toIntExact, s, IntegerType.sql, errorContext)

def toShortExact(s: UTF8String, errorContext: String): Short =
withException(s.toShortExact, errorContext)
withException(s.toShortExact, s, ShortType.sql, errorContext)

def toByteExact(s: UTF8String, errorContext: String): Byte =
withException(s.toByteExact, errorContext)
withException(s.toByteExact, s, ByteType.sql, errorContext)

private def withException[A](f: => A, errorContext: String): A = {
private def withException[A](
f: => A,
s: UTF8String,
targetType: String,
errorContext: String): A = {
try {
f
} catch {
case e: NumberFormatException =>
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(e, errorContext)
case _: NumberFormatException =>
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(s, targetType, errorContext)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,13 @@ object QueryExecutionErrors extends QueryErrorsBase {
decimalPrecision.toString, decimalScale.toString, SQLConf.ANSI_ENABLED.key, context))
}

/**
 * Builds the error raised when a string cannot be parsed as a number, starting from an
 * already-thrown [[NumberFormatException]].
 *
 * The resulting message is the message of `e`, followed by a hint to use `try_cast` or to set
 * the config named by `SQLConf.ANSI_ENABLED.key` to false, with `errorContext` appended last.
 *
 * @param e the original parsing failure; only its message is reused
 * @param errorContext text appended verbatim to the end of the message
 * @return a new plain `NumberFormatException` carrying the composed message
 */
def invalidInputSyntaxForNumericError(
e: NumberFormatException,
errorContext: String): NumberFormatException = {
// Only a message is passed to the constructor, so `e` is not attached as the cause.
new NumberFormatException(s"${e.getMessage}. To return NULL instead, use 'try_cast'. " +
s"If necessary set ${SQLConf.ANSI_ENABLED.key} to false to bypass this error." + errorContext)
}

def invalidInputSyntaxForNumericError(
s: UTF8String,
targetType: String,
errorContext: String): NumberFormatException = {
new SparkNumberFormatException(errorClass = "INVALID_INPUT_SYNTAX_FOR_NUMERIC_TYPE",
messageParameters = Array(toSQLValue(s, StringType), SQLConf.ANSI_ENABLED.key, errorContext))
messageParameters =
Array(targetType, toSQLValue(s, StringType), SQLConf.ANSI_ENABLED.key, errorContext))
}

def cannotCastFromNullTypeError(to: DataType): Throwable = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,10 @@ object Decimal {
}
}

def fromStringANSI(str: UTF8String, errorContext: String = ""): Decimal = {
def fromStringANSI(
str: UTF8String,
decimalType: String = "DECIMAL",
errorContext: String = ""): Decimal = {
try {
val bigDecimal = stringToJavaBigDecimal(str)
// We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow.
Expand All @@ -626,7 +629,7 @@ object Decimal {
}
} catch {
case _: NumberFormatException =>
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(str, errorContext)
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(str, decimalType, errorContext)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.Timestamp
import java.time.DateTimeException

import org.apache.spark.SparkArithmeticException
import org.apache.spark.{SparkArithmeticException, SparkNumberFormatException}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
Expand Down Expand Up @@ -174,29 +174,35 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
test("cast from invalid string to numeric should throw NumberFormatException") {
// cast to integral types (IntegerType, ShortType, ByteType, LongType)
Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType =>
val typeName = dataType.sql
checkExceptionInExpression[SparkNumberFormatException](
cast("string", dataType), s"invalid input syntax for type $typeName: 'string'")
checkExceptionInExpression[NumberFormatException](
cast("string", dataType), "invalid input syntax for type numeric: 'string'")
cast("123-string", dataType), s"invalid input syntax for type $typeName: '123-string'")
checkExceptionInExpression[NumberFormatException](
cast("123-string", dataType), "invalid input syntax for type numeric: '123-string'")
cast("2020-07-19", dataType), s"invalid input syntax for type $typeName: '2020-07-19'")
checkExceptionInExpression[NumberFormatException](
cast("2020-07-19", dataType), "invalid input syntax for type numeric: '2020-07-19'")
checkExceptionInExpression[NumberFormatException](
cast("1.23", dataType), "invalid input syntax for type numeric: '1.23'")
cast("1.23", dataType), s"invalid input syntax for type $typeName: '1.23'")
}

Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType =>
val typeName = dataType.sql
checkExceptionInExpression[NumberFormatException](
cast("string", dataType), "invalid input syntax for type numeric: 'string'")
cast("string", dataType), s"invalid input syntax for type $typeName: 'string'")
checkExceptionInExpression[NumberFormatException](
cast("123.000.00", dataType), "invalid input syntax for type numeric: '123.000.00'")
cast("123.000.00", dataType), s"invalid input syntax for type $typeName: '123.000.00'")
checkExceptionInExpression[NumberFormatException](
cast("abc.com", dataType), "invalid input syntax for type numeric: 'abc.com'")
cast("abc.com", dataType), s"invalid input syntax for type $typeName: 'abc.com'")
}
}

protected def checkCastToNumericError(l: Literal, to: DataType, tryCastResult: Any): Unit = {
val typeName = to match {
case a: ArrayType => a.elementType.sql
case _ => to.sql
}
checkExceptionInExpression[NumberFormatException](
cast(l, to), "invalid input syntax for type numeric: 'true'")
cast(l, to), s"invalid input syntax for type $typeName: 'true'")
}

test("cast from invalid string array to numeric array should throw NumberFormatException") {
Expand Down Expand Up @@ -243,7 +249,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {

checkExceptionInExpression[NumberFormatException](
cast("abcd", DecimalType(38, 1)),
"invalid input syntax for type numeric")
"invalid input syntax for type DECIMAL(38,1): 'abcd'")
}

protected def checkCastToBooleanError(l: Literal, to: DataType, tryCastResult: Any): Unit = {
Expand Down Expand Up @@ -368,8 +374,8 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
assert(ret.resolved == !isTryCast)
if (!isTryCast) {
checkExceptionInExpression[NumberFormatException](
ret, "invalid input syntax for type numeric")
checkExceptionInExpression[SparkNumberFormatException](
ret, "invalid input syntax for type INT: 'a'")
}
}

Expand All @@ -387,7 +393,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
assert(ret.resolved == !isTryCast)
if (!isTryCast) {
checkExceptionInExpression[NumberFormatException](
ret, "invalid input syntax for type numeric")
ret, "invalid input syntax for type INT: 'a'")
}
}
}
Expand Down Expand Up @@ -511,8 +517,8 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {

assert(ret.resolved === !isTryCast)
if (!isTryCast) {
checkExceptionInExpression[NumberFormatException](
ret, "invalid input syntax for type numeric")
checkExceptionInExpression[SparkNumberFormatException](
ret, "invalid input syntax for type INT")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper

assert(Decimal.fromString(UTF8String.fromString("str")) === null)
val e = intercept[NumberFormatException](Decimal.fromStringANSI(UTF8String.fromString("str")))
assert(e.getMessage.contains("invalid input syntax for type numeric"))
assert(e.getMessage.contains("invalid input syntax for type DECIMAL"))
}

test("SPARK-35841: Casting string to decimal type doesn't work " +
Expand Down
10 changes: 8 additions & 2 deletions sql/core/src/test/resources/sql-tests/inputs/cast.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ SELECT CAST('1.23' AS long);
SELECT CAST('-4.56' AS int);
SELECT CAST('-4.56' AS long);

-- cast string which are not numbers to integral should return null
-- cast strings which are not numbers to numeric types
SELECT CAST('abc' AS int);
SELECT CAST('abc' AS long);
SELECT CAST('abc' AS float);
SELECT CAST('abc' AS double);

-- cast string representing a very large number to integral should return null
SELECT CAST('1234567890123' AS int);
Expand All @@ -15,14 +17,18 @@ SELECT CAST('12345678901234567890123' AS long);
-- cast empty string to numeric types should return null
SELECT CAST('' AS int);
SELECT CAST('' AS long);
SELECT CAST('' AS float);
SELECT CAST('' AS double);

-- cast null to integral should return null
SELECT CAST(NULL AS int);
SELECT CAST(NULL AS long);

-- cast invalid decimal string to integral should return null
-- cast invalid decimal string to numeric types
SELECT CAST('123.a' AS int);
SELECT CAST('123.a' AS long);
SELECT CAST('123.a' AS float);
SELECT CAST('123.a' AS double);

-- '-2147483648' is the smallest int value
SELECT CAST('-2147483648' AS int);
Expand Down
Loading