From 736da602863823bb54058c4e3f164646b4916457 Mon Sep 17 00:00:00 2001
From: Aman Omer
Date: Fri, 15 Nov 2019 14:34:27 +0530
Subject: [PATCH] Initial commit

---
 .../catalyst/analysis/PostgreSQLDialect.scala | 23 +++++-
 .../spark/sql/catalyst/expressions/Cast.scala |  6 +-
 .../postgreSQL/PostgreCastToLong.scala        | 92 +++++++++++++++++++
 .../expressions/postgreSQL/CastSuite.scala    |  4 +
 .../sql/PostgreSQLDialectQuerySuite.scala     |  5 ++
 5 files changed, 125 insertions(+), 5 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToLong.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala
index e7f0e571804d..6d19db38d1b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala
@@ -19,15 +19,15 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.expressions.postgreSQL.PostgreCastToBoolean
+import org.apache.spark.sql.catalyst.expressions.postgreSQL.{PostgreCastToBoolean, PostgreCastToLong}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, StringType}
+import org.apache.spark.sql.types.{BooleanType, LongType, StringType}
 
 object PostgreSQLDialect {
   val postgreSQLDialectRules: List[Rule[LogicalPlan]] =
-    CastToBoolean ::
+    CastToBoolean :: CastToLong ::
       Nil
 
   object CastToBoolean extends Rule[LogicalPlan] with Logging {
@@ -46,4 +46,21 @@ object PostgreSQLDialect {
       }
     }
   }
+
+  object CastToLong extends Rule[LogicalPlan] with Logging {
+    override def apply(plan: LogicalPlan): LogicalPlan = {
+      // The SQL configuration `spark.sql.dialect` can be changed in runtime.
+      // To make sure the configuration is effective, we have to check it during rule execution.
+      val conf = SQLConf.get
+      if (conf.usePostgreSQLDialect) {
+        plan.transformExpressions {
+          case Cast(child, dataType, timeZoneId)
+            if child.dataType != LongType && dataType == LongType =>
+            PostgreCastToLong(child, timeZoneId)
+        }
+      } else {
+        plan
+      }
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f3b58fa3137b..d23c1f254ebc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -471,7 +471,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   }
 
   // LongConverter
-  private[this] def castToLong(from: DataType): Any => Any = from match {
+  protected[this] def castToLong(from: DataType): Any => Any = from match {
     case StringType =>
       val result = new LongWrapper()
       buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
@@ -1422,7 +1422,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
       (c, evPrim, evNull) => code"$evPrim = (int) $c;"
   }
 
-  private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
+  protected[this] def castToLongCode(
+      from: DataType,
+      ctx: CodegenContext): CastFunction = from match {
     case StringType =>
       val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper])
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToLong.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToLong.scala
new file mode 100644
index 000000000000..e2c5834be509
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToLong.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.postgreSQL
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.{CastBase, Expression, TimeZoneAwareExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.UTF8String.LongWrapper
+
+/**
+ * Casts `child` to `LongType` following the PostgreSQL dialect: a string that cannot be parsed
+ * as a long raises an error instead of evaluating to null.
+ */
+case class PostgreCastToLong(child: Expression, timeZoneId: Option[String])
+  extends CastBase {
+  override def dataType: DataType = LongType
+
+  // The ANSI flag has no meaning for this dialect-specific cast, so reading it is an error.
+  override protected def ansiEnabled: Boolean =
+    throw new UnsupportedOperationException("PostgreSQL dialect doesn't support ansi mode")
+
+  // Accept exactly the source types handled by castToLong and castToLongCode, so that every
+  // other type is rejected here rather than failing later with a MatchError.
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case StringType | BooleanType | _: NumericType =>
+      TypeCheckResult.TypeCheckSuccess
+    case _ =>
+      TypeCheckResult.TypeCheckFailure(s"cannot cast type ${child.dataType} to long")
+  }
+
+  /** Returns a copy of this expression with the specified timeZoneId. */
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
+
+  // Interpreted path. Unparseable strings throw instead of returning null.
+  override def castToLong(from: DataType): Any => Any = from match {
+    case StringType =>
+      val result = new LongWrapper()
+      buildCast[UTF8String](_, s => {
+        if (s.toLong(result)) {
+          result.value
+        } else {
+          throw new IllegalArgumentException(s"invalid input syntax for type long: $s")
+        }
+      })
+    case BooleanType =>
+      buildCast[Boolean](_, b => if (b) 1L else 0L)
+    case x: NumericType =>
+      b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
+  }
+
+  // Codegen path. Must stay in step with castToLong, including the exception it throws.
+  override def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
+    case StringType =>
+      val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper])
+      // The template below is emitted as Java source, so the message is built with Java string
+      // concatenation and includes the offending value at run time.
+      (c, evPrim, _) =>
+        code"""
+          UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
+          if ($c.toLong($wrapper)) {
+            $evPrim = $wrapper.value;
+          } else {
+            throw new IllegalArgumentException("invalid input syntax for type long: " + $c);
+          }
+          $wrapper = null;
+        """
+    case BooleanType =>
+      (c, evPrim, _) => code"$evPrim = $c ? 1L : 0L;"
+    case DecimalType() =>
+      (c, evPrim, _) => code"$evPrim = $c.toLong();"
+    case _: NumericType =>
+      (c, evPrim, _) => code"$evPrim = (long) $c;"
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala
index 6c5218b379f3..66f1182aec97 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala
@@ -70,4 +70,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(PostgreCastToBoolean(Literal(1.toDouble), None).checkInputDataTypes().isFailure)
     assert(PostgreCastToBoolean(Literal(1.toFloat), None).checkInputDataTypes().isFailure)
   }
+
+  test("unsupported data types to cast to long") {
+    // TODO: Test cases to be added
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala
index 7056f483609a..1ddf10553187 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.DateType
 
 class PostgreSQLDialectQuerySuite extends QueryTest with SharedSparkSession {
 
@@ -39,4 +40,8 @@ class PostgreSQLDialectQuerySuite extends QueryTest with SharedSparkSession {
       intercept[IllegalArgumentException](sql(s"select cast('$input' as boolean)").collect())
     }
   }
+
+  test("cast to long") {
+    // TODO: Add test cases
+  }
 }