From 2846a0473c3f035158e0635b1f1f0ceec7376c8e Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 30 May 2015 03:31:03 +0800
Subject: [PATCH 1/7] fix 7952

---
 .../catalyst/analysis/HiveTypeCoercion.scala  | 58 +++++++++++++------
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 22 +++++++
 2 files changed, 62 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 195418d6dfb1f..8a9b5494112e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -76,7 +76,7 @@ trait HiveTypeCoercion {
       WidenTypes ::
       PromoteStrings ::
       DecimalPrecision ::
-      BooleanComparisons ::
+      BooleanEqualization ::
       StringToIntegralCasts ::
       FunctionArgumentConversion ::
       CaseWhenCoercion ::
@@ -482,30 +482,52 @@ trait HiveTypeCoercion {
   }
 
   /**
-   * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
+   * Changes numeric values to booleans so that expressions like true = 1 can be Evaluated.
    */
-  object BooleanComparisons extends Rule[LogicalPlan] {
-    val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_))
-    val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_))
+  object BooleanEqualization extends Rule[LogicalPlan] {
+    val trueValue = Literal(new java.math.BigDecimal(1))
+    val falseValue = Literal(new java.math.BigDecimal(0))
+
+    def isNull(expr: Expression) = EqualNullSafe(expr, Literal.create(null, expr.dataType))
+
+    def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
+      CaseKeyWhen(Cast(numericExpr, DecimalType.Unlimited),
+        Seq(
+          trueValue, booleanExpr,
+          falseValue, Not(booleanExpr),
+          Literal(false)))
+    }
+
+    def transform(booleanExpr: Expression, numericExpr: Expression) = {
+      CaseWhen(Seq(
+        isNull(booleanExpr), Literal.create(null, BooleanType),
+        isNull(numericExpr), Literal.create(null, BooleanType),
+        buildCaseKeyWhen(booleanExpr, numericExpr)
+      ))
+    }
+
+    def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
+      CaseWhen(Seq(
+        And(isNull(booleanExpr), isNull(numericExpr)), Literal(true),
+        isNull(booleanExpr), Literal(false),
+        isNull(numericExpr), Literal(false),
+        buildCaseKeyWhen(booleanExpr, numericExpr)
+      ))
+    }
 
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
       // Hive treats (true = 1) as true and (false = 0) as true.
-      case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l
-      case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r
-      case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l)
-      case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r)
-
-      // No need to change other EqualTo operators as that actually makes sense for boolean types.
-      case e: EqualTo => e
-      // No need to change the EqualNullSafe operators, too
-      case e: EqualNullSafe => e
-      // Otherwise turn them to Byte types so that there exists and ordering.
-      case p: BinaryComparison if p.left.dataType == BooleanType &&
-                                  p.right.dataType == BooleanType =>
-        p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
+      case EqualTo(l @ BooleanType(), r) if r.dataType.isInstanceOf[NumericType] =>
+        transform(l, r)
+      case EqualTo(l, r @ BooleanType()) if l.dataType.isInstanceOf[NumericType] =>
+        transform(r, l)
+      case EqualNullSafe(l @ BooleanType(), r) if r.dataType.isInstanceOf[NumericType] =>
+        transformNullSafe(l, r)
+      case EqualNullSafe(l, r @ BooleanType()) if l.dataType.isInstanceOf[NumericType] =>
+        transformNullSafe(r, l)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bf18bf854aa4a..42418bcf9b36b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1331,4 +1331,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
   }
+
+  test("SPARK-7952: fix the equality check between boolean type and numeric typegd") {
+    val data = Seq(
+      (1, true),
+      (0, false),
+      (2, true),
+      (2, false),
+      (null, true),
+      (null, false),
+      (0, null),
+      (1, null),
+      (null, null)
+    )
+    val rowRDD = sparkContext.makeRDD(data).map(r => Row(r._1, r._2))
+    val schema = StructType(Seq(StructField("i", IntegerType), StructField("b", BooleanType)))
+    createDataFrame(rowRDD, schema).registerTempTable("t")
+
+    checkAnswer(sql("select i = b from t"),
+      Seq(true, true, false, false, null, null, null, null, null).map(Row(_)))
+    checkAnswer(sql("select i <=> b from t"),
+      Seq(true, true, false, false, false, false, false, false, true).map(Row(_)))
+  }
 }

From fc0d7410ce24bdfc693e55558ba0007dc4c84ed4 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 30 May 2015 11:34:45 +0800
Subject: [PATCH 2/7] fix style

---
 .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 8a9b5494112e9..ce049275b1fba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -488,9 +488,9 @@ trait HiveTypeCoercion {
     val trueValue = Literal(new java.math.BigDecimal(1))
     val falseValue = Literal(new java.math.BigDecimal(0))
 
-    def isNull(expr: Expression) = EqualNullSafe(expr, Literal.create(null, expr.dataType))
+    private def isNull(expr: Expression) = EqualNullSafe(expr, Literal.create(null, expr.dataType))
 
-    def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
+    private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
       CaseKeyWhen(Cast(numericExpr, DecimalType.Unlimited),
         Seq(
           trueValue, booleanExpr,
@@ -498,7 +498,7 @@ trait HiveTypeCoercion {
          Literal(false)))
     }
 
-    def transform(booleanExpr: Expression, numericExpr: Expression) = {
+    private def transform(booleanExpr: Expression, numericExpr: Expression) = {
       CaseWhen(Seq(
         isNull(booleanExpr), Literal.create(null, BooleanType),
         isNull(numericExpr), Literal.create(null, BooleanType),
@@ -506,7 +506,7 @@ trait HiveTypeCoercion {
     }
 
-    def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
+    private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
       CaseWhen(Seq(
         And(isNull(booleanExpr), isNull(numericExpr)), Literal(true),
         isNull(booleanExpr), Literal(false),

From 9ba2130873f6911253c0a81d73812e2f2c7ec27d Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 30 May 2015 22:31:00 +0800
Subject: [PATCH 3/7] address comments

---
 .../sql/catalyst/analysis/HiveTypeCoercion.scala |  5 +++--
 .../org/apache/spark/sql/SQLQuerySuite.scala     | 15 +++++++--------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index ce049275b1fba..1db9524161ddf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -482,7 +482,7 @@ trait HiveTypeCoercion {
   }
 
   /**
-   * Changes numeric values to booleans so that expressions like true = 1 can be Evaluated.
+   * Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
    */
   object BooleanEqualization extends Rule[LogicalPlan] {
     val trueValue = Literal(new java.math.BigDecimal(1))
@@ -519,7 +519,8 @@ trait HiveTypeCoercion {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      // Hive treats (true = 1) as true and (false = 0) as true.
+      // Hive treats (true = 1) as true and (false = 0) as true,
+      // all other cases are considered as false.
       case EqualTo(l @ BooleanType(), r) if r.dataType.isInstanceOf[NumericType] =>
         transform(l, r)
       case EqualTo(l, r @ BooleanType()) if l.dataType.isInstanceOf[NumericType] =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 42418bcf9b36b..bece2e4afc401 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1332,8 +1332,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
   }
 
-  test("SPARK-7952: fix the equality check between boolean type and numeric typegd") {
-    val data = Seq(
+  test("SPARK-7952: fix the equality check between boolean and numeric types") {
+    val df = Seq(
       (1, true),
       (0, false),
       (2, true),
@@ -1343,14 +1343,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       (0, null),
       (1, null),
       (null, null)
-    )
-    val rowRDD = sparkContext.makeRDD(data).map(r => Row(r._1, r._2))
-    val schema = StructType(Seq(StructField("i", IntegerType), StructField("b", BooleanType)))
-    createDataFrame(rowRDD, schema).registerTempTable("t")
+    ).map { case (i, b) =>
+      (i.asInstanceOf[Integer], b.asInstanceOf[java.lang.Boolean])
+    }.toDF("i", "b")
 
-    checkAnswer(sql("select i = b from t"),
+    checkAnswer(df.select('i === 'b),
       Seq(true, true, false, false, null, null, null, null, null).map(Row(_)))
-    checkAnswer(sql("select i <=> b from t"),
+    checkAnswer(df.select('i <=> 'b),
       Seq(true, true, false, false, false, false, false, false, true).map(Row(_)))
   }
 }

From 625973c0d7c3f9a647dacbb5088f7ed15518cda9 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 30 May 2015 22:47:56 +0800
Subject: [PATCH 4/7] improve

---
 .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 1db9524161ddf..6eba3608649e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -488,8 +488,6 @@ trait HiveTypeCoercion {
     val trueValue = Literal(new java.math.BigDecimal(1))
     val falseValue = Literal(new java.math.BigDecimal(0))
 
-    private def isNull(expr: Expression) = EqualNullSafe(expr, Literal.create(null, expr.dataType))
-
     private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
       CaseKeyWhen(Cast(numericExpr, DecimalType.Unlimited),
         Seq(
@@ -500,17 +498,15 @@ trait HiveTypeCoercion {
 
     private def transform(booleanExpr: Expression, numericExpr: Expression) = {
       CaseWhen(Seq(
-        isNull(booleanExpr), Literal.create(null, BooleanType),
-        isNull(numericExpr), Literal.create(null, BooleanType),
+        Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal.create(null, BooleanType),
         buildCaseKeyWhen(booleanExpr, numericExpr)
       ))
     }
 
     private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
       CaseWhen(Seq(
-        And(isNull(booleanExpr), isNull(numericExpr)), Literal(true),
-        isNull(booleanExpr), Literal(false),
-        isNull(numericExpr), Literal(false),
+        And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true),
+        Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false),
         buildCaseKeyWhen(booleanExpr, numericExpr)
       ))
     }

From ebc8c613577c39eac01bcf504c890306b640e0e0 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 30 May 2015 23:01:55 +0800
Subject: [PATCH 5/7] use SQLTestUtils and If

---
 .../catalyst/analysis/HiveTypeCoercion.scala  |  7 ++-
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 54 ++++++++++---------
 2 files changed, 31 insertions(+), 30 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 6eba3608649e5..d54491482cbc3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -497,10 +497,9 @@ trait HiveTypeCoercion {
     }
 
     private def transform(booleanExpr: Expression, numericExpr: Expression) = {
-      CaseWhen(Seq(
-        Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal.create(null, BooleanType),
-        buildCaseKeyWhen(booleanExpr, numericExpr)
-      ))
+      If(Or(IsNull(booleanExpr), IsNull(numericExpr)),
+        Literal.create(null, BooleanType),
+        buildCaseKeyWhen(booleanExpr, numericExpr))
     }
 
     private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bece2e4afc401..312ea3355eda5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.execution.GeneratedAggregate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
 import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
 import org.apache.spark.sql.types._
 
@@ -32,12 +32,12 @@ import org.apache.spark.sql.types._
 /** A SQL Dialect for testing purpose, and it can not be nested type */
 class MyDialect extends DefaultParserDialect
 
-class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
+class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
   // Make sure the tables are loaded.
   TestData
 
-  import org.apache.spark.sql.test.TestSQLContext.implicits._
-  val sqlCtx = TestSQLContext
+  val sqlContext = TestSQLContext
+  import sqlContext.implicits._
 
   test("SPARK-6743: no columns from cache") {
     Seq(
@@ -915,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       Row(values(0).toInt, values(1), values(2).toBoolean, v4)
     }
 
-    val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
+    val df1 = createDataFrame(rowRDD1, schema1)
     df1.registerTempTable("applySchema1")
     checkAnswer(
       sql("SELECT * FROM applySchema1"),
@@ -945,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
     }
 
-    val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
+    val df2 = createDataFrame(rowRDD2, schema2)
     df2.registerTempTable("applySchema2")
     checkAnswer(
       sql("SELECT * FROM applySchema2"),
@@ -970,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
         Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
     }
 
-    val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
+    val df3 = createDataFrame(rowRDD3, schema2)
     df3.registerTempTable("applySchema3")
 
     checkAnswer(
@@ -1015,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       .build()
     val schemaWithMeta = new StructType(Array(
       schema("id"), schema("name").copy(metadata = metadata), schema("age")))
-    val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
+    val personWithMeta = createDataFrame(person.rdd, schemaWithMeta)
     def validateMetadata(rdd: DataFrame): Unit = {
       assert(rdd.schema("name").metadata.getString(docKey) == docValue)
     }
@@ -1333,23 +1333,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   }
 
   test("SPARK-7952: fix the equality check between boolean and numeric types") {
-    val df = Seq(
-      (1, true),
-      (0, false),
-      (2, true),
-      (2, false),
-      (null, true),
-      (null, false),
-      (0, null),
-      (1, null),
-      (null, null)
-    ).map { case (i, b) =>
-      (i.asInstanceOf[Integer], b.asInstanceOf[java.lang.Boolean])
-    }.toDF("i", "b")
-
-    checkAnswer(df.select('i === 'b),
-      Seq(true, true, false, false, null, null, null, null, null).map(Row(_)))
-    checkAnswer(df.select('i <=> 'b),
-      Seq(true, true, false, false, false, false, false, false, true).map(Row(_)))
+    withTempTable("t") {
+      Seq(
+        (1, true),
+        (0, false),
+        (2, true),
+        (2, false),
+        (null, true),
+        (null, false),
+        (0, null),
+        (1, null),
+        (null, null)
+      ).map { case (i, b) =>
+        (i.asInstanceOf[Integer], b.asInstanceOf[java.lang.Boolean])
+      }.toDF("i", "b").registerTempTable("t")
+
+      checkAnswer(sql("select i = b from t"),
+        Seq(true, true, false, false, null, null, null, null, null).map(Row(_)))
+      checkAnswer(sql("select i <=> b from t"),
+        Seq(true, true, false, false, false, false, false, false, true).map(Row(_)))
+    }
   }
 }

From b6401ba59cf98cd9218a6f69da9bf089f4ee5240 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 1 Jun 2015 01:01:39 +0800
Subject: [PATCH 6/7] add type coercion for CaseKeyWhen and address comments

---
 .../catalyst/analysis/HiveTypeCoercion.scala  | 61 +++++++++++++----
 .../sql/catalyst/expressions/predicates.scala |  5 +-
 .../analysis/HiveTypeCoercionSuite.scala      | 55 ++++++++++++---
 .../ExpressionEvaluationSuite.scala           |  8 +--
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 33 +++++-----
 5 files changed, 116 insertions(+), 46 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index d54491482cbc3..4e625e94fb791 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -485,15 +485,14 @@ trait HiveTypeCoercion {
    * Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
    */
   object BooleanEqualization extends Rule[LogicalPlan] {
-    val trueValue = Literal(new java.math.BigDecimal(1))
-    val falseValue = Literal(new java.math.BigDecimal(0))
+    val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
+    val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))
 
     private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
-      CaseKeyWhen(Cast(numericExpr, DecimalType.Unlimited),
-        Seq(
-          trueValue, booleanExpr,
-          falseValue, Not(booleanExpr),
-          Literal(false)))
+      CaseKeyWhen(numericExpr, Seq(
+        Literal(trueValues.head), booleanExpr,
+        Literal(falseValues.head), Not(booleanExpr),
+        Literal(false)))
     }
 
     private def transform(booleanExpr: Expression, numericExpr: Expression) = {
@@ -516,13 +515,32 @@ trait HiveTypeCoercion {
 
       // Hive treats (true = 1) as true and (false = 0) as true,
       // all other cases are considered as false.
-      case EqualTo(l @ BooleanType(), r) if r.dataType.isInstanceOf[NumericType] =>
-        transform(l, r)
-      case EqualTo(l, r @ BooleanType()) if l.dataType.isInstanceOf[NumericType] =>
+
+      // We may simplify the expression if one side is literal numeric values
+      case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
+        if trueValues.contains(value) => l
+      case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
+        if falseValues.contains(value) => Not(l)
+      case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
+        if trueValues.contains(value) => r
+      case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
+        if falseValues.contains(value) => Not(r)
+      case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
+        if trueValues.contains(value) => And(IsNotNull(l), l)
+      case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
+        if falseValues.contains(value) => Or(IsNull(l), l)
+      case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
+        if trueValues.contains(value) => And(IsNotNull(r), r)
+      case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
+        if falseValues.contains(value) => Or(IsNull(r), r)
+
+      case EqualTo(l @ BooleanType(), r @ NumericType()) =>
+        transform(l , r)
+      case EqualTo(l @ NumericType(), r @ BooleanType()) =>
         transform(r, l)
-      case EqualNullSafe(l @ BooleanType(), r) if r.dataType.isInstanceOf[NumericType] =>
+      case EqualNullSafe(l @ BooleanType(), r @ NumericType()) =>
         transformNullSafe(l, r)
-      case EqualNullSafe(l, r @ BooleanType()) if l.dataType.isInstanceOf[NumericType] =>
+      case EqualNullSafe(l @ NumericType(), r @ BooleanType()) =>
         transformNullSafe(r, l)
     }
   }
@@ -624,7 +642,7 @@ trait HiveTypeCoercion {
     import HiveTypeCoercion._
 
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-      case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual =>
+      case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
         logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
         val commonType = cw.valueTypes.reduce { (v1, v2) =>
           findTightestCommonType(v1, v2).getOrElse(sys.error(
@@ -643,6 +661,23 @@ trait HiveTypeCoercion {
           case CaseKeyWhen(key, _) => CaseKeyWhen(key, transformedBranches)
         }
+
+      case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
+        val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
+          findTightestCommonType(v1, v2).getOrElse(sys.error(
+            s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
+        }
+        val transformedBranches = ckw.branches.sliding(2, 2).map {
+          case Seq(when, then) if when.dataType != commonType =>
+            Seq(Cast(when, commonType), then)
+          case s => s
+        }.reduce(_ ++ _)
+        val transformedKey = if (ckw.key.dataType != commonType) {
+          Cast(ckw.key, commonType)
+        } else {
+          ckw.key
+        }
+        CaseKeyWhen(transformedKey, transformedBranches)
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index e2d1c8115e051..4f422d69c4382 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -366,7 +366,7 @@ trait CaseWhenLike extends Expression {
 
   // both then and else val should be considered.
   def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
-  def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1
+  def valueTypesEqual: Boolean = valueTypes.distinct.size == 1
 
   override def dataType: DataType = {
     if (!resolved) {
@@ -442,7 +442,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
   override def children: Seq[Expression] = key +: branches
 
   override lazy val resolved: Boolean =
-    childrenResolved && valueTypesEqual
+    childrenResolved && valueTypesEqual &&
+    (key +: whenList).map(_.dataType).distinct.size == 1
 
   /** Written in imperative fashion for performance considerations. */
   override def eval(input: Row): Any = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index f0101f4a88f86..0c561b2703ea1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types._
 
 class HiveTypeCoercionSuite extends PlanTest {
@@ -104,15 +105,16 @@ class HiveTypeCoercionSuite extends PlanTest {
     widenTest(ArrayType(IntegerType), StructType(Seq()), None)
   }
 
+  private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
+    val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+    comparePlans(
+      rule(Project(Seq(Alias(initial, "a")()), testRelation)),
+      Project(Seq(Alias(transformed, "a")()), testRelation))
+  }
+
   test("coalesce casts") {
     val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
-    def ruleTest(initial: Expression, transformed: Expression) {
-      val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
-      comparePlans(
-        fac(Project(Seq(Alias(initial, "a")()), testRelation)),
-        Project(Seq(Alias(transformed, "a")()), testRelation))
-    }
-    ruleTest(
+    ruleTest(fac,
       Coalesce(Literal(1.0)
         :: Literal(1)
         :: Literal.create(1.0, FloatType)
@@ -121,7 +123,7 @@ class HiveTypeCoercionSuite extends PlanTest {
         :: Cast(Literal(1), DoubleType)
         :: Cast(Literal.create(1.0, FloatType), DoubleType)
         :: Nil))
-    ruleTest(
+    ruleTest(fac,
      Coalesce(Literal(1L)
        :: Literal(1)
        :: Literal(new java.math.BigDecimal("1000000000000000000000"))
@@ -131,4 +133,39 @@ class HiveTypeCoercionSuite extends PlanTest {
        :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
        :: Nil))
   }
+
+  test("type coercion for CaseKeyWhen") {
+    val cwc = new HiveTypeCoercion {}.CaseWhenCoercion
+    ruleTest(cwc,
+      CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
+      CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
+    )
+    // Will remove exception expectation in PR#6405
+    intercept[RuntimeException] {
+      ruleTest(cwc,
+        CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
+        CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
+      )
+    }
+  }
+
+  test("type coercion simplification for equal to") {
+    val be = new HiveTypeCoercion {}.BooleanEqualization
+    ruleTest(be,
+      EqualTo(Literal(true), Literal(1)),
+      Literal(true)
+    )
+    ruleTest(be,
+      EqualTo(Literal(true), Literal(0)),
+      Not(Literal(true))
+    )
+    ruleTest(be,
+      EqualNullSafe(Literal(true), Literal(1)),
+      And(IsNotNull(Literal(true)), Literal(true))
+    )
+    ruleTest(be,
+      EqualNullSafe(Literal(true), Literal(0)),
+      Or(IsNull(Literal(true)), Literal(true))
+    )
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 10181366c2fcd..56c027ef466e9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -862,7 +862,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
     val c5 = 'a.string.at(4)
     val c6 = 'a.string.at(5)
 
-    val literalNull = Literal.create(null, BooleanType)
+    val literalNull = Literal.create(null, IntegerType)
     val literalInt = Literal(1)
     val literalString = Literal("a")
 
@@ -871,12 +871,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
     checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
     checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
     checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
-    checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row)
+    checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row)
 
     checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
     checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
-    checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row)
-    checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row)
+    checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row)
+    checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
   }
 
   test("complex type") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 312ea3355eda5..63f7d314fb699 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1334,24 +1334,21 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
 
   test("SPARK-7952: fix the equality check between boolean and numeric types") {
     withTempTable("t") {
-      Seq(
-        (1, true),
-        (0, false),
-        (2, true),
-        (2, false),
-        (null, true),
-        (null, false),
-        (0, null),
-        (1, null),
-        (null, null)
-      ).map { case (i, b) =>
-        (i.asInstanceOf[Integer], b.asInstanceOf[java.lang.Boolean])
-      }.toDF("i", "b").registerTempTable("t")
-
-      checkAnswer(sql("select i = b from t"),
-        Seq(true, true, false, false, null, null, null, null, null).map(Row(_)))
-      checkAnswer(sql("select i <=> b from t"),
-        Seq(true, true, false, false, false, false, false, false, true).map(Row(_)))
+      // numeric field i, boolean field j, result of i = j, result of i <=> j
+      Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)](
+        (1, true, true, true),
+        (0, false, true, true),
+        (2, true, false, false),
+        (2, false, false, false),
+        (null, true, null, false),
+        (null, false, null, false),
+        (0, null, null, false),
+        (1, null, null, false),
+        (null, null, null, true)
+      ).toDF("i", "b", "r1", "r2").registerTempTable("t")
+
+      checkAnswer(sql("select i = b from t"), sql("select r1 from t"))
+      checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
     }
   }
 }

From 77f0f39fdbe8a5858e38a79f72e131bed23e385e Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 1 Jun 2015 03:28:12 +0800
Subject: [PATCH 7/7] minor fix

---
 .../sql/catalyst/analysis/HiveTypeCoercion.scala | 16 ++++++++--------
 .../analysis/HiveTypeCoercionSuite.scala         |  2 +-
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 4e625e94fb791..5d9911adbac05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -119,7 +119,7 @@ trait HiveTypeCoercion {
    * the appropriate numeric equivalent.
    */
   object ConvertNaNs extends Rule[LogicalPlan] {
-    val stringNaN = Literal("NaN")
+    private val stringNaN = Literal("NaN")
 
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case q: LogicalPlan => q transformExpressions {
@@ -349,17 +349,17 @@ trait HiveTypeCoercion {
     import scala.math.{max, min}
 
     // Conversion rules for integer types into fixed-precision decimals
-    val intTypeToFixed: Map[DataType, DecimalType] = Map(
+    private val intTypeToFixed: Map[DataType, DecimalType] = Map(
      ByteType -> DecimalType(3, 0),
      ShortType -> DecimalType(5, 0),
      IntegerType -> DecimalType(10, 0),
      LongType -> DecimalType(20, 0)
    )
 
-    def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
+    private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
 
     // Conversion rules for float and double into fixed-precision decimals
-    val floatTypeToFixed: Map[DataType, DecimalType] = Map(
+    private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
      FloatType -> DecimalType(7, 7),
      DoubleType -> DecimalType(15, 15)
    )
@@ -485,8 +485,8 @@ trait HiveTypeCoercion {
    * Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
    */
   object BooleanEqualization extends Rule[LogicalPlan] {
-    val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
-    val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))
+    private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
+    private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))
 
     private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
       CaseKeyWhen(numericExpr, Seq(
@@ -528,11 +528,11 @@ trait HiveTypeCoercion {
       case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
         if trueValues.contains(value) => And(IsNotNull(l), l)
       case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
-        if falseValues.contains(value) => Or(IsNull(l), l)
+        if falseValues.contains(value) => And(IsNotNull(l), Not(l))
       case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
         if trueValues.contains(value) => And(IsNotNull(r), r)
       case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
-        if falseValues.contains(value) => Or(IsNull(r), r)
+        if falseValues.contains(value) => And(IsNotNull(r), Not(r))
 
       case EqualTo(l @ BooleanType(), r @ NumericType()) =>
         transform(l , r)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 0c561b2703ea1..a0798428db094 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -165,7 +165,7 @@ class HiveTypeCoercionSuite extends PlanTest {
     )
     ruleTest(be,
       EqualNullSafe(Literal(true), Literal(0)),
-      Or(IsNull(Literal(true)), Literal(true))
+      And(IsNotNull(Literal(true)), Not(Literal(true)))
     )
   }
 }
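
Illustration (not part of the patch series above, and not Spark code): the semantics that BooleanEqualization converges on can be sanity-checked against the truth table used in the final version of the SQLQuerySuite test. Below is a minimal standalone Scala sketch of the intended behaviour, assuming Hive's convention described in the rule's comments: a boolean matches 1 for true and 0 for false, any other numeric value compares as false, `=` yields null when either side is null, and `<=>` treats two nulls as equal. The object and method names are hypothetical, chosen only for this sketch.

  // Illustrative only: a plain-Scala model of `boolean = numeric` and
  // `boolean <=> numeric` as rewritten by the BooleanEqualization rule.
  object BooleanNumericEqualitySketch {
    // `=` semantics: null (None) if either side is null, otherwise
    // true matches 1, false matches 0, any other value yields false.
    def equalTo(i: Option[Int], b: Option[Boolean]): Option[Boolean] =
      for (n <- i; bool <- b) yield n match {
        case 1 => bool
        case 0 => !bool
        case _ => false
      }

    // `<=>` (null-safe) semantics: two nulls are equal, a single null is
    // not equal, otherwise the same comparison as `=`.
    def equalNullSafe(i: Option[Int], b: Option[Boolean]): Boolean =
      (i, b) match {
        case (None, None)          => true
        case (None, _) | (_, None) => false
        case (Some(_), Some(_))    => equalTo(i, b).get
      }

    def main(args: Array[String]): Unit = {
      // The nine rows from the SQLQuerySuite test.
      val rows = Seq(
        (Some(1), Some(true)), (Some(0), Some(false)), (Some(2), Some(true)),
        (Some(2), Some(false)), (None, Some(true)), (None, Some(false)),
        (Some(0), None), (Some(1), None), (None, None))
      rows.foreach { case (i, b) =>
        println(s"i=$i b=$b  i = b -> ${equalTo(i, b)}  i <=> b -> ${equalNullSafe(i, b)}")
      }
    }
  }

Running the sketch over the nine rows reproduces the expected r1 and r2 columns introduced in PATCH 6 of the series.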