Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ class Analyzer(
Batch("Substitution", fixedPoint,
CTESubstitution,
WindowsSubstitution,
EliminateUnions),
EliminateUnions,
SubstituteFunctionAliases),
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._

/**
* An analyzer rule that replaces Hive-style type-name function calls, such as `int(x)` or
* `string(x)`, with the equivalent `Cast` expression.
*/
object SubstituteFunctionAliases extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
// SPARK-16730: The following functions are aliases for cast in Hive.
case u: UnresolvedFunction
if u.name.database.isEmpty && u.children.size == 1 && !u.isDistinct =>
u.name.funcName.toLowerCase match {
case "boolean" => Cast(u.children.head, BooleanType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use FunctionRegistry to handle these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean putting in FunctionRegistry?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, but not sure if it can work

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for this. I think you can make it work. Implement a method that creates the same as FunctionRegistry.expression[T <: Expression](name: String) method, e.g.:

def cast(name: String, dt: DataType): (String, (ExpressionInfo, FunctionBuilder)) = {
  val info = new ExpressionInfo(classOf[Cast].getName, name)
  name -> (info, Cast(_, dt))
}

and use that method to register these casts...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was in #14364

case "tinyint" => Cast(u.children.head, ByteType)
case "smallint" => Cast(u.children.head, ShortType)
case "int" => Cast(u.children.head, IntegerType)
case "bigint" => Cast(u.children.head, LongType)
case "float" => Cast(u.children.head, FloatType)
case "double" => Cast(u.children.head, DoubleType)
case "decimal" => Cast(u.children.head, DecimalType.USER_DEFAULT)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm using whatever `cast(... as decimal)` uses here, but I think it is a bug to cast to USER_DEFAULT by default, since it has scale = 0.

case "date" => Cast(u.children.head, DateType)
case "timestamp" => Cast(u.children.head, TimestampType)
case "binary" => Cast(u.children.head, BinaryType)
case "string" => Cast(u.children.head, StringType)
case _ => u
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types._

/**
 * Unit tests for [[SubstituteFunctionAliases]].
 *
 * Covers every type-name alias the rule recognizes, its case-insensitive name matching, and
 * the guard conditions under which a function call must be left untouched.
 */
class SubstituteFunctionAliasesSuite extends PlanTest {

  /** Runs [[SubstituteFunctionAliases]] on `initial` and expects `transformed` as the result. */
  private def ruleTest(initial: Expression, transformed: Expression): Unit = {
    ruleTest(SubstituteFunctionAliases, initial, transformed)
  }

  /** Builds an unqualified, single-argument [[UnresolvedFunction]] whose argument is a literal. */
  private def func(name: String, arg: Any, isDistinct: Boolean = false): UnresolvedFunction = {
    UnresolvedFunction(name, Literal(arg) :: Nil, isDistinct)
  }

  // One test per alias; the test name is the alias itself.
  Seq[(String, DataType)](
    "boolean" -> BooleanType,
    "tinyint" -> ByteType,
    "smallint" -> ShortType,
    "int" -> IntegerType,
    "bigint" -> LongType,
    "float" -> FloatType,
    "double" -> DoubleType,
    "decimal" -> DecimalType.USER_DEFAULT,
    "date" -> DateType,
    "timestamp" -> TimestampType,
    "binary" -> BinaryType,
    "string" -> StringType
  ).foreach { case (alias, dataType) =>
    test(alias) {
      ruleTest(func(alias, 10), Cast(Literal(10), dataType))
    }
  }

  test("alias names are matched case-insensitively") {
    // The rule lower-cases the function name before matching it.
    ruleTest(func("INT", 10), Cast(Literal(10), IntegerType))
    ruleTest(func("String", 10), Cast(Literal(10), StringType))
  }

  test("function is not an alias for cast if its name is not a known type name") {
    val f = func("not_a_cast_alias", 10)
    ruleTest(f, f)
  }

  test("function is not an alias for cast if it has a database defined") {
    val f = UnresolvedFunction(
      FunctionIdentifier("int", database = Option("db")), Literal(10) :: Nil, isDistinct = false)
    ruleTest(f, f)
  }

  test("function is not an alias for cast if it has zero input arg") {
    val f = UnresolvedFunction("int", Nil, isDistinct = false)
    ruleTest(f, f)
  }

  test("function is not an alias for cast if it has more than one input args") {
    val f = UnresolvedFunction("int", Literal(10) :: Literal(11) :: Nil, isDistinct = false)
    ruleTest(f, f)
  }

  test("function is not an alias for cast if it is distinct (aggregate function)") {
    val f = UnresolvedFunction("int", Literal(10) :: Nil, isDistinct = true)
    ruleTest(f, f)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -200,24 +200,6 @@ class TypeCoercionSuite extends PlanTest {
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}

// Single-rule convenience overload: wraps the rule in a one-element batch and delegates.
private def ruleTest(
    rule: Rule[LogicalPlan],
    initial: Expression,
    transformed: Expression): Unit = {
  ruleTest(Seq(rule), initial, transformed)
}

// Runs `rules` (as one batch, up to 3 iterations) over a projection of `initial` and checks
// that the resulting plan matches the same projection built from `transformed`.
private def ruleTest(
    rules: Seq[Rule[LogicalPlan]],
    initial: Expression,
    transformed: Expression): Unit = {
  val relation = LocalRelation(AttributeReference("a", IntegerType)())

  // Puts an expression into a single-column projection, aliased as "a", over the relation.
  def project(expr: Expression): LogicalPlan = Project(Seq(Alias(expr, "a")()), relation)

  val executor = new RuleExecutor[LogicalPlan] {
    override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*))
  }

  val actual = executor.execute(project(initial))
  comparePlans(actual, project(transformed))
}

test("cast NullType for expressions that implement ExpectsInputTypes") {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this into PlanTest

import TypeCoercionSuite._

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types.IntegerType

/**
* Provides helper methods for comparing plans.
Expand Down Expand Up @@ -83,4 +85,35 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
protected def compareExpressions(e1: Expression, e2: Expression): Unit = {
  // Each expression is placed in a Filter over OneRowRelation so that the pair can be
  // checked with the plan-level comparison.
  val left = Filter(e1, OneRowRelation)
  val right = Filter(e2, OneRowRelation)
  comparePlans(left, right)
}

/**
 * Tests the behavior of the given rule on a specific expression.
 *
 * Convenience overload that wraps `rule` in a one-element batch and delegates to the
 * multi-rule variant.
 *
 * @param rule rule to run
 * @param initial the initial expression
 * @param transformed the expected expression after rule execution.
 */
protected def ruleTest(
    rule: Rule[LogicalPlan],
    initial: Expression,
    transformed: Expression): Unit = {
  ruleTest(Seq(rule), initial, transformed)
}

/**
 * Tests the behavior of the given batch of rules on a specific expression.
 *
 * Both expressions are aliased as "a" and projected over a local relation with a single
 * integer attribute; the rules run as one batch for up to 3 iterations, and the resulting
 * plan is compared with the expected one.
 *
 * @param rules a batch of rules to run
 * @param initial the initial expression
 * @param transformed the expected expression after rule execution.
 */
protected def ruleTest(
    rules: Seq[Rule[LogicalPlan]],
    initial: Expression,
    transformed: Expression): Unit = {
  val relation = LocalRelation(AttributeReference("a", IntegerType)())

  // Puts an expression into a single-column projection, aliased as "a", over the relation.
  def project(expr: Expression): LogicalPlan = Project(Seq(Alias(expr, "a")()), relation)

  val executor = new RuleExecutor[LogicalPlan] {
    override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*))
  }

  val actual = executor.execute(project(initial))
  comparePlans(actual, project(transformed))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@

package org.apache.spark.sql

import java.math.BigDecimal
import java.sql.Timestamp

import org.apache.spark.sql.test.SharedSQLContext

/**
* A test suite for functions added for compatibility with other databases such as Oracle, MSSQL.
*
* These functions are typically implemented using the trait
* [[org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable]].
* [[org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable]]
*
* or using analyzer rule [[org.apache.spark.sql.catalyst.analysis.SubstituteFunctionAliases]].
*/
class SQLCompatibilityFunctionSuite extends QueryTest with SharedSQLContext {

Expand Down Expand Up @@ -69,4 +75,22 @@ class SQLCompatibilityFunctionSuite extends QueryTest with SharedSQLContext {
sql("SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d)"),
Row(2.1, 1.0))
}

test("SPARK-16730 cast alias functions for Hive compatibility") {
  // Integral and boolean aliases.
  checkAnswer(
    sql("SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1)"),
    Row(true, 1.toByte, 1.toShort, 1, 1L))

  // Fractional and decimal aliases.
  checkAnswer(
    sql("SELECT float(1), double(1), decimal(1)"),
    Row(1.toFloat, 1.0, new BigDecimal(1)))

  // Date/time aliases. The expected values are built with `valueOf` rather than the
  // deprecated year-offset constructors; both denote local midnight on 2014-04-04.
  checkAnswer(
    sql("SELECT date(\"2014-04-04\"), timestamp(date(\"2014-04-04\"))"),
    Row(java.sql.Date.valueOf("2014-04-04"), Timestamp.valueOf("2014-04-04 00:00:00")))

  checkAnswer(
    sql("SELECT string(1)"),
    Row("1"))
}
}