diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala index e7f0e571804d..c71b7ff90632 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala @@ -19,15 +19,15 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.spark.sql.catalyst.expressions.postgreSQL.PostgreCastToBoolean +import org.apache.spark.sql.catalyst.expressions.postgreSQL.{PostgreCastToBoolean, PostgreCastToDecimal} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BooleanType, StringType} +import org.apache.spark.sql.types.{BooleanType, DecimalType} object PostgreSQLDialect { val postgreSQLDialectRules: List[Rule[LogicalPlan]] = - CastToBoolean :: + CastToBoolean :: CastToDecimal :: Nil object CastToBoolean extends Rule[LogicalPlan] with Logging { @@ -46,4 +46,24 @@ object PostgreSQLDialect { } } } + + object CastToDecimal extends Rule[LogicalPlan] with Logging { + override def apply(plan: LogicalPlan): LogicalPlan = { + // The SQL configuration `spark.sql.dialect` can be changed in runtime. + // To make sure the configuration is effective, we have to check it during rule execution. + val conf = SQLConf.get + if (conf.usePostgreSQLDialect) { + plan.transformExpressions { + case Cast(child, dataType, timeZoneId) + if child.dataType != DecimalType => + dataType match { + case _: DecimalType => PostgreCastToDecimal(child, timeZoneId) + } + } + } else { + plan + } + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f3b58fa3137b..40f875dc17de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -596,7 +596,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit * * NOTE: this modifies `value` in-place, so don't call it on external data. */ - private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { + protected[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { if (value.changePrecision(decimalType.precision, decimalType.scale)) { value } else { @@ -619,7 +619,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled) - private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { + protected[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try { changePrecision(Decimal(new JavaBigDecimal(s.toString)), target) @@ -1073,7 +1073,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } - private[this] def changePrecision(d: ExprValue, decimalType: DecimalType, + protected[this] def changePrecision(d: ExprValue, decimalType: DecimalType, evPrim: ExprValue, evNull: ExprValue, canNullSafeCast: Boolean): Block = { if (canNullSafeCast) { code""" @@ -1099,7 +1099,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } - private[this] def castToDecimalCode( + protected[this] def castToDecimalCode( from: DataType, target: DecimalType, ctx: CodegenContext): CastFunction = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToDecimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToDecimal.scala new file mode 100644 index 000000000000..b583d4d60efe --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToDecimal.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.postgreSQL + +import java.math.{BigDecimal => JavaBigDecimal} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +case class PostgreCastToDecimal(child: Expression, timeZoneId: Option[String]) + extends CastBase { + + override def dataType: DataType = DecimalType.defaultConcreteType + + override def toString: String = s"PostgreCastToDecimal($child as ${dataType.simpleString})" + + override def nullable: Boolean = child.nullable + + override protected def ansiEnabled = + throw new UnsupportedOperationException("PostgreSQL dialect doesn't support ansi mode") + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case StringType | LongType | IntegerType | NullType | FloatType | ShortType | + DoubleType | ByteType => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"cannot cast type ${child.dataType} to decimal") + } + + override def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { + case StringType => + buildCast[UTF8String](_, s => try { + changePrecision(Decimal(new JavaBigDecimal(s.toString)), target) + } catch { + case _: NumberFormatException => + throw new AnalysisException(s"invalid input syntax for type numeric: $s") + }) + case t: IntegralType => + super.castToDecimal(from, target) + + case x: FractionalType => + b => try { + changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target) + } catch { + case _: NumberFormatException => + throw new AnalysisException(s"invalid input syntax for type numeric: $x") + } + } + + override def castToDecimalCode(from: DataType, target: DecimalType, + ctx: CodegenContext): CastFunction = { + val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) + val canNullSafeCast = Cast.canNullSafeCastToDecimal(from, target) + from match { + case StringType => + (c, evPrim, evNull) => + code""" + try { + Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} + } catch (java.lang.NumberFormatException e) { + throw new AnalysisException("invalid input syntax for type numeric: $c") + } + """ + case t: IntegralType => + super.castToDecimalCode(from, target, ctx) + + case x: FractionalType => + // All other numeric types can be represented precisely as Doubles + (c, evPrim, evNull) => + code""" + try { + Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} + } catch (java.lang.NumberFormatException e) { + throw new AnalysisException("invalid input syntax for type numeric: $c") + } + """ + } + } + + override def sql: String = s"CAST(${child.sql} AS ${dataType.sql})" +}