diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f15ae3255ca98..cb3aafa7da04c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} -import org.apache.spark.SparkException +import org.apache.spark.{SparkContext, SparkException, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{DataTypes, _} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -122,6 +122,9 @@ object Cast { """) case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant { + private val fractionalToTimestampCastingErrorMessage = "Can not cast NaN or infinite" + + s" fractional value to ${TimestampType.simpleString}." + override def toString: String = s"cast($child as ${dataType.simpleString})" override def checkInputDataTypes(): TypeCheckResult = { @@ -161,7 +164,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } else if (StringUtils.isFalseString(s)) { false } else { - null + if (failOnCastErrorEnabled) { + throw new RuntimeException(s"Can not cast '$s' to ${BooleanType.simpleString}.") + } else { + null + } } }) case TimestampType => @@ -215,8 +222,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def decimalToTimestamp(d: Decimal): Long = { (d.toBigDecimal * 1000000L).longValue() } + private[this] def doubleToTimestamp(d: Double): Any = { - if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong + if (d.isNaN || d.isInfinite) { + if (failOnCastErrorEnabled) { + throw new TypeCastException(DoubleType, TimestampType, d) + } else { + null + } + } + else (d * 1000000L).toLong } // converting seconds to us @@ -231,7 +246,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s).orNull) + buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s).getOrElse(() => { + if (failOnCastErrorEnabled) { + throw new TypeCastException(StringType, DateType, s) + } else { + null + } + })) case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. @@ -248,7 +269,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try s.toLong catch { - case _: NumberFormatException => null + case _: NumberFormatException => if (failOnCastErrorEnabled) { + throw new TypeCastException(StringType, LongType, s) + } else { + null + } }) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) @@ -264,7 +289,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try s.toInt catch { - case _: NumberFormatException => null + case _: NumberFormatException => if (failOnCastErrorEnabled) { + throw new TypeCastException(StringType, IntegerType, s) + } else { + null + } }) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) @@ -280,7 +309,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try s.toShort catch { - case _: NumberFormatException => null + case _: NumberFormatException => if (failOnCastErrorEnabled) { + throw new TypeCastException(StringType, ShortType, s) + } else { + null + } }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) @@ -296,7 +329,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try s.toByte catch { - case _: NumberFormatException => null + case _: NumberFormatException => if (failOnCastErrorEnabled) { + throw new TypeCastException(StringType, ByteType, s) + } else { + null + } }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) @@ -315,7 +352,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w * NOTE: this modifies `value` in-place, so don't call it on external data. */ private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { - if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null + if (value.changePrecision(decimalType.precision, decimalType.scale)) { + value + } else { + if (failOnCastErrorEnabled) { + throw new TypeCastException(DecimalType(value.precision, value.scale), decimalType, value) + } else { + null + } + } } private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { @@ -323,7 +368,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[UTF8String](_, s => try { changePrecision(Decimal(new JavaBigDecimal(s.toString)), target) } catch { - case _: NumberFormatException => null + case _: NumberFormatException => if (failOnCastErrorEnabled) { + throw new TypeCastException(StringType, target, s) + } else { + null + } }) case BooleanType => buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) @@ -340,7 +389,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w b => try { changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target) } catch { - case _: NumberFormatException => null + case _: NumberFormatException => if (failOnCastErrorEnabled) { + throw new TypeCastException(StringType, target, b) + } else { + null + } } } @@ -348,7 +401,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try s.toString.toDouble catch { - case _: NumberFormatException => null + case _: NumberFormatException => if (failOnCastErrorEnabled) { + throw new TypeCastException(StringType, DoubleType, s) + } else { + null + } }) case BooleanType => buildCast[Boolean](_, b => if (b) 1d else 0d) @@ -364,7 +421,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try s.toString.toFloat catch { - case _: NumberFormatException => null + case _: NumberFormatException => if (failOnCastErrorEnabled) { + throw new TypeCastException(StringType, FloatType, s) + } else { + null + } }) case BooleanType => buildCast[Boolean](_, b => if (b) 1f else 0f) @@ -535,7 +596,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if ($intOpt.isDefined()) { $evPrim = ((Integer) $intOpt.get()).intValue(); } else { - $evNull = true; + ${ + if (failOnCastErrorEnabled) { + s"throw new org.apache.spark.sql.catalyst.expressions.TypeCastException(" + + s"${javaDataTypeName(StringType)}, ${javaDataTypeName(DateType)}," + + s" $c);" + } else { + s"$evNull = true;" + } + } } """ case TimestampType => @@ -551,7 +620,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { $evPrim = $d; } else { - $evNull = true; + ${ + if (failOnCastErrorEnabled) { + s"throw new org.apache.spark.sql.catalyst.expressions.TypeCastException(" + + s"$javaDataTypesClassName.createDecimalType($d.precision(), $d.scale())," + + s" ${javaDataTypeName(decimalType)}, $d);" + } else { + s"$evNull = true;" + } + } } """ @@ -563,14 +640,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w from match { case StringType => (c, evPrim, evNull) => - s""" - try { - Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); - ${changePrecision(tmp, target, evPrim, evNull)} - } catch (java.lang.NumberFormatException e) { - $evNull = true; - } - """ + castStringToNumberCode( + s"""Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); + ${changePrecision(tmp, target, evPrim, evNull)}""", evNull, c, target) case BooleanType => (c, evPrim, evNull) => s""" @@ -579,7 +651,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w """ case DateType => // date can't cast to decimal in Hive - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => + if (failOnCastErrorEnabled) { + s"throw new org.apache.spark.sql.catalyst.expressions.TypeCastException(" + + s"${javaDataTypeName(DateType)}, ${javaDataTypeName(target)}," + + s" $c);" + } else { + s"$evNull = true;" + } case TimestampType => // Note that we lose precision here. (c, evPrim, evNull) => @@ -608,7 +687,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); ${changePrecision(tmp, target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { - $evNull = true; + ${ + if (failOnCastErrorEnabled) { + s"throw new org.apache.spark.sql.catalyst.expressions.TypeCastException(" + + s"${javaDataTypeName(x)}, ${javaDataTypeName(target)}," + + s" $c);" + } else { + s"$evNull = true;" + } + } } """ } @@ -626,7 +713,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if ($longOpt.isDefined()) { $evPrim = ((Long) $longOpt.get()).longValue(); } else { - $evNull = true; + ${ + if (failOnCastErrorEnabled) { + s"throw new org.apache.spark.sql.catalyst.expressions.TypeCastException(" + + s"${javaDataTypeName(StringType)}, ${javaDataTypeName(TimestampType)}," + + s" $c);" + } else { + s"$evNull = true;" + } + } } """ case BooleanType => @@ -642,7 +737,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s""" if (Double.isNaN($c) || Double.isInfinite($c)) { - $evNull = true; + ${ + if (failOnCastErrorEnabled) { + s"throw new org.apache.spark.sql.catalyst.expressions.TypeCastException(" + + s"${javaDataTypeName(DoubleType)}, ${javaDataTypeName(TimestampType)}," + + s" $c);" + } else { + s"$evNull = true; " + } + } } else { $evPrim = (long)($c * 1000000L); } @@ -651,7 +754,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s""" if (Float.isNaN($c) || Float.isInfinite($c)) { - $evNull = true; + ${ + if (failOnCastErrorEnabled) { + s"throw new org.apache.spark.sql.catalyst.expressions.TypeCastException(" + + s"${javaDataTypeName(FloatType)}, ${javaDataTypeName(TimestampType)}," + + s" $c);" + } else { + s"$evNull = true;" + } + } } else { $evPrim = (long)($c * 1000000L); } @@ -686,7 +797,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } else if ($stringUtils.isFalseString($c)) { $evPrim = false; } else { - $evNull = true; + ${ + if (failOnCastErrorEnabled) { + s"throw new org.apache.spark.sql.catalyst.expressions.TypeCastException(" + + s"${javaDataTypeName(StringType)}, ${javaDataTypeName(BooleanType)}," + + s" $c);" + } else { + s"$evNull = true;" + } + } } """ case TimestampType => @@ -703,13 +822,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToByteCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" - try { - $evPrim = $c.toByte(); - } catch (java.lang.NumberFormatException e) { - $evNull = true; - } - """ + castStringToNumberCode(s"$evPrim = $c.toByte();", evNull, c, ByteType) case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" case DateType => @@ -725,13 +838,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToShortCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" - try { - $evPrim = $c.toShort(); - } catch (java.lang.NumberFormatException e) { - $evNull = true; - } - """ + castStringToNumberCode(s"$evPrim = $c.toShort();", evNull, c, ShortType) case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" case DateType => @@ -747,13 +854,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToIntCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" - try { - $evPrim = $c.toInt(); - } catch (java.lang.NumberFormatException e) { - $evNull = true; - } - """ + castStringToNumberCode(s"$evPrim = $c.toInt();", evNull, c, IntegerType) case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" case DateType => @@ -769,13 +870,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToLongCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" - try { - $evPrim = $c.toLong(); - } catch (java.lang.NumberFormatException e) { - $evNull = true; - } - """ + castStringToNumberCode(s" $evPrim = $c.toLong();", evNull, c, LongType) case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" case DateType => @@ -791,13 +886,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToFloatCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" - try { - $evPrim = Float.valueOf($c.toString()); - } catch (java.lang.NumberFormatException e) { - $evNull = true; - } - """ + castStringToNumberCode(s"$evPrim = Float.valueOf($c.toString());", evNull, c, FloatType) case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;" case DateType => @@ -810,16 +899,24 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s"$evPrim = (float) $c;" } + private def failOnCastErrorEnabled: Boolean = { + val failOnCastError: String = if (TaskContext.get() != null) { + TaskContext.get().getLocalProperty("snappydata.failOnCastError") + } else if (SparkContext.activeContext.get() != null){ + SparkContext.activeContext.get().getLocalProperty("snappydata.failOnCastError") + } else { + "false" + } + Option(failOnCastError) match { + case Some(value) => value.toBoolean + case None => false + } + } + private[this] def castToDoubleCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" - try { - $evPrim = Double.valueOf($c.toString()); - } catch (java.lang.NumberFormatException e) { - $evNull = true; - } - """ + castStringToNumberCode(s"$evPrim = Double.valueOf($c.toString());", evNull, c, DoubleType) case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? 1.0d : 0.0d;" case DateType => @@ -832,6 +929,36 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s"$evPrim = (double) $c;" } + private val javaDataTypesClassName = classOf[DataTypes].getCanonicalName + + private def javaDataTypeName(dataType: DataType) : String = { + dataType match { + case decimalType: DecimalType => + s"$javaDataTypesClassName.createDecimalType(${decimalType.precision}," + + s" ${decimalType.scale})" + case _ => s"$javaDataTypesClassName.$dataType" + } + } + + private[this] def castStringToNumberCode(code: String, evNull: String, c: String, + dataType: DataType): String = { + s""" + try { + $code + } catch (java.lang.NumberFormatException e) { + ${ + if (failOnCastErrorEnabled) { + s"throw new org.apache.spark.sql.catalyst.expressions.TypeCastException(" + + s"${javaDataTypeName(StringType)}, ${javaDataTypeName(dataType)}," + + s" $c);" + } else { + s"$evNull = true;" + } + } + } + """ + } + private[this] def castArrayCode( fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) @@ -955,3 +1082,10 @@ case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[Str extends UnaryExpression with Unevaluable { override lazy val resolved = false } + +class TypeCastException(sourceType: DataType, targetType: DataType, value: Any) + extends RuntimeException { + override def getMessage: String = { + s"Can not cast ${sourceType.simpleString} type value '$value' to ${targetType.simpleString}." + } +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9b53d21deed97..dd2fafd7d2077 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.nio.charset.StandardCharsets import java.sql.Timestamp +import java.util.NoSuchElementException import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -89,7 +90,23 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ - lazy val toRdd: RDD[InternalRow] = executedPlan.execute() + lazy val toRdd: RDD[InternalRow] = { + // setting snappydata.failOnCastError local property every time before + // executing the query to make the change to the property effective + + try { + sparkSession.sparkContext.setLocalProperty("snappydata.failOnCastError", + sparkSession.sessionState.conf.getConfString("snappydata.failOnCastError")) + } catch { + case ex: NoSuchElementException + if (ex.getMessage.equalsIgnoreCase("snappydata.failOnCastError")) => + // Only SnappySession config will have "snappydata.failOnCastError" set. + // While using spark session this config won't be there hence ignoring this + // failure. + } + + executedPlan.execute() + } /** * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal