Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/sql-migration-guide-upgrade.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ license: |
{:toc}

## Upgrading From Spark SQL 2.4 to 3.0
- Since Spark 3.0, when casting from string to boolean, date, timestamp or numeric types, whitespace is trimmed from both ends of the value before it is parsed.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about:
Since Spark 3.0, when a string is cast to boolean/date/timestamp/numeric types, it is trimmed before it is parsed.


- Since Spark 3.0, PySpark requires a Pandas version of 0.23.2 or higher to use Pandas related functionality, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc.

- Since Spark 3.0, PySpark requires a PyArrow version of 0.12.1 or higher to use PyArrow related functionality, such as `pandas_udf`, `toPandas` and `createDataFrame` with "spark.sql.execution.arrow.enabled=true", etc.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
val result = new LongWrapper()
buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
buildCast[UTF8String](_, s => if (s.trim.toLong(result)) result.value else null)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a correct fix. Is there any possibility of a performance regression?

case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0L)
case DateType =>
Expand All @@ -448,7 +448,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
buildCast[UTF8String](_, s => if (s.trim.toInt(result)) result.value else null)
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1 else 0)
case DateType =>
Expand All @@ -463,7 +463,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toShort(result)) {
buildCast[UTF8String](_, s => if (s.trim.toShort(result)) {
result.value.toShort
} else {
null
Expand All @@ -482,7 +482,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toByte(result)) {
buildCast[UTF8String](_, s => if (s.trim.toByte(result)) {
result.value.toByte
} else {
null
Expand Down Expand Up @@ -518,7 +518,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try {
changePrecision(Decimal(new JavaBigDecimal(s.toString)), target)
changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target)
} catch {
case _: NumberFormatException => null
})
Expand All @@ -544,7 +544,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try s.toString.toDouble catch {
buildCast[UTF8String](_, s => try s.toString.trim.toDouble catch {
case _: NumberFormatException => null
})
case BooleanType =>
Expand All @@ -560,7 +560,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try s.toString.toFloat catch {
buildCast[UTF8String](_, s => try s.toString.trim.toFloat catch {
case _: NumberFormatException => null
})
case BooleanType =>
Expand Down Expand Up @@ -983,7 +983,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) =>
code"""
try {
Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString()));
Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: $c.trim().toString() may be more efficient?

${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)}
} catch (java.lang.NumberFormatException e) {
$evNull = true;
Expand Down Expand Up @@ -1136,7 +1136,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) =>
code"""
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toByte($wrapper)) {
if ($c.trim().toByte($wrapper)) {
$evPrim = (byte) $wrapper.value;
} else {
$evNull = true;
Expand All @@ -1163,7 +1163,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) =>
code"""
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toShort($wrapper)) {
if ($c.trim().toShort($wrapper)) {
$evPrim = (short) $wrapper.value;
} else {
$evNull = true;
Expand All @@ -1188,7 +1188,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) =>
code"""
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toInt($wrapper)) {
if ($c.trim().toInt($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
Expand All @@ -1214,7 +1214,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) =>
code"""
UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
if ($c.toLong($wrapper)) {
if ($c.trim().toLong($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
Expand All @@ -1238,7 +1238,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) =>
code"""
try {
$evPrim = Float.valueOf($c.toString());
$evPrim = Float.valueOf($c.toString().trim());
} catch (java.lang.NumberFormatException e) {
$evNull = true;
}
Expand All @@ -1260,7 +1260,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) =>
code"""
try {
$evPrim = Double.valueOf($c.toString());
$evPrim = Double.valueOf($c.toString().trim());
} catch (java.lang.NumberFormatException e) {
$evNull = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ object StringUtils extends Logging {
private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)

// scalastyle:off caselocale
def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
// Returns true when `s`, after `trim` and `toLowerCase`, is a member of `trueStrings`.
// Trimming first means surrounding padding does not prevent a match.
def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.trim.toLowerCase)
// Returns true when `s`, after `trim` and `toLowerCase`, is a member of `falseStrings`
// ("f", "false", "n", "no", "0"). Trimming first means surrounding padding does not prevent a match.
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.trim.toLowerCase)
// scalastyle:on caselocale

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.math.BigDecimal
import java.sql.{Date, Timestamp}
import java.util.{Calendar, TimeZone}
import java.util.concurrent.TimeUnit._
Expand Down Expand Up @@ -1018,4 +1019,27 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(ret, InternalRow(null))
}
}

// Verifies that string -> boolean/numeric casts ignore padding around the value.
test("Trim the string when cast string type to Boolean/Numeric types") {
  // Produces the three padded variants of a token: both sides, leading only, trailing only.
  def padded(token: String): Seq[String] = Seq(s" $token ", s" $token", s"$token ")

  // Boolean literals.
  for (input <- padded("true")) {
    checkEvaluation(Cast(Literal(input), BooleanType), true)
  }
  for (input <- padded("false")) {
    checkEvaluation(Cast(Literal(input), BooleanType), false)
  }

  // Integral types and an integer-shaped decimal.
  for (input <- padded("1")) {
    val literal = Literal(input)
    checkEvaluation(Cast(literal, ByteType), 1.toByte)
    checkEvaluation(Cast(literal, ShortType), 1.toShort)
    checkEvaluation(Cast(literal, IntegerType), 1)
    checkEvaluation(Cast(literal, LongType), 1L)
    checkEvaluation(Cast(literal, DecimalType.IntDecimal), BigDecimal.ONE)
  }

  // Fractional types and a fractional decimal.
  for (input <- padded("1.23")) {
    val literal = Literal(input)
    checkEvaluation(Cast(literal, FloatType), 1.23F)
    checkEvaluation(Cast(literal, DoubleType), 1.23D)
    checkEvaluation(Cast(literal, DecimalType.FloatDecimal), BigDecimal.valueOf(1.2300000))
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ SELECT boolean(' f ') AS false
-- !query 4 schema
struct<false:boolean>
-- !query 4 output
NULL
false


-- !query 5
Expand Down Expand Up @@ -296,7 +296,7 @@ SELECT boolean(string(' true ')) AS true,
-- !query 36 schema
struct<true:boolean,false:boolean>
-- !query 36 output
NULL NULL
true false


-- !query 37
Expand Down