Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,9 @@ trait HiveTypeCoercion {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

// Skip if the type is boolean type already. Note that this extra cast should be removed
// by optimizer.SimplifyCasts.
case Cast(e, BooleanType) if e.dataType == BooleanType => e
case Cast(e, BooleanType) => Not(Equals(e, Literal(0)))
case Cast(e, dataType) if e.dataType == BooleanType =>
Cast(If(e, Literal(1), Literal(0)), dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, Equals}
import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.hive.test.TestHive

/**
* A set of tests that validate type promotion rules.
* A set of tests that validate type promotion and coercion rules.
*/
class HiveTypeCoercionSuite extends HiveComparisonTest {
val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")
Expand All @@ -28,4 +32,23 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1")
}
}

test("[SPARK-2210] boolean cast on boolean value should be removed") {
val q = "select cast(cast(key=0 as boolean) as boolean) from src"
val project = TestHive.hql(q).queryExecution.executedPlan.collect { case e: Project => e }.head

// No cast expression introduced
project.transformAllExpressions { case c: Cast =>
assert(false, "unexpected cast " + c)
c
}

// Only one Equals
var numEquals = 0
project.transformAllExpressions { case e: Equals =>
numEquals += 1
e
}
assert(numEquals === 1)
}
}