diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index d8c8698e31d3..fd5d70259ae5 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -995,0 +996,9 @@ The following options can be used to configure the version of Hive that is used
+<tr>
+  <td><code>spark.sql.broadcastTimeout</code></td>
+  <td>300</td>
+  <td>
+    <p>
+      Timeout in seconds for the broadcast wait time in broadcast joins
+    </p>
+  </td>
+</tr>
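For context, a minimal sketch of how this option can be adjusted at runtime (assuming an existing SparkSession named `spark`; the value is illustrative, not a recommendation):

    // Raise the broadcast-join wait from the 300-second default to 10 minutes.
    spark.conf.set("spark.sql.broadcastTimeout", 600)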
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 0ca715f42472..c06a66fa89b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -93,20 +93,62 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
     if (children.size % 2 != 0) {
       TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.")
     } else if (keys.map(_.dataType).distinct.length > 1) {
-      TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " +
-        "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+      if (keys.map(_.dataType).forall(_.isInstanceOf[DecimalType])) {
+        TypeCheckResult.TypeCheckSuccess
+      } else {
+        TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " +
+          "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+      }
     } else if (values.map(_.dataType).distinct.length > 1) {
-      TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " +
-        "type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+      if (values.map(_.dataType).forall(_.isInstanceOf[DecimalType])) {
+        TypeCheckResult.TypeCheckSuccess
+      } else {
+        TypeCheckResult.TypeCheckFailure("The given values of function map should all be the " +
+          "same type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+      }
     } else {
       TypeCheckResult.TypeCheckSuccess
     }
   }
 
+  // A decimal type is "tighter" than another when every value it can represent also
+  // fits in the other type: no more integral digits and no larger scale.
+  private def isDecimalTypeTighterThan(src: DecimalType, other: DataType): Boolean = other match {
+    case dt: DecimalType =>
+      (src.precision - src.scale) <= (dt.precision - dt.scale) && src.scale <= dt.scale
+    case _ => false
+  }
+
+  /**
+   * Resolves the element type for the given key or value expressions, widening to the
+   * widest decimal type when the elements are of DecimalType.
+   * @param colType the key or value expressions of the map
+   * @return the resolved element type
+   */
+  private def checkDecimalType(colType: Seq[Expression]): DataType = {
+    val elementType = colType.headOption.map(_.dataType).getOrElse(NullType)
+
+    elementType match {
+      case _ if elementType.isInstanceOf[DecimalType] =>
+        // Widen whenever the type accumulated so far fits entirely within a
+        // child's type, so we end up with the widest element type.
+        var tighter: DataType = elementType
+        colType.foreach { child =>
+          if (isDecimalTypeTighterThan(tighter.asInstanceOf[DecimalType], child.dataType)) {
+            tighter = child.dataType
+          }
+        }
+
+        tighter
+      case _ =>
+        elementType
+    }
+  }
+
   override def dataType: DataType = {
     MapType(
-      keyType = keys.headOption.map(_.dataType).getOrElse(NullType),
-      valueType = values.headOption.map(_.dataType).getOrElse(NullType),
+      keyType = checkDecimalType(keys),
+      valueType = checkDecimalType(values),
       valueContainsNull = values.exists(_.nullable))
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 0c307b2b8576..293f3aa8d1fc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -134,16 +134,31 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
   }
 
-  test("CreateMap") {
-    def interlace(keys: Seq[Literal], values: Seq[Literal]): Seq[Literal] = {
-      keys.zip(values).flatMap { case (k, v) => Seq(k, v) }
-    }
+  private def interlace(keys: Seq[Literal], values: Seq[Literal]): Seq[Literal] = {
+    keys.zip(values).flatMap { case (k, v) => Seq(k, v) }
+  }
 
-    def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
-      // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order.
-      scala.collection.immutable.ListMap(keys.zip(values): _*)
-    }
+  private def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
+    // Catalyst maps are order-sensitive, so create a ListMap to preserve element order.
+    scala.collection.immutable.ListMap(keys.zip(values): _*)
+  }
 
+  test("SPARK-16735: CreateMap with Decimals") {
+    // Decimal keys and values of mixed precision/scale: the type check should pass and
+    // the key/value types should widen to the widest decimal type.
+    val keys = Seq(Decimal("0.02"), Decimal("0.004"))
+    val values = Seq(Decimal("0.001"), Decimal("0.5"))
+    val keys1 = Seq(Decimal("0.020"), Decimal("0.004"))
+    val values1 = Seq(Decimal("0.001"), Decimal("0.500"))
+    val map1 = CreateMap(interlace(keys.map(Literal(_)), values.map(Literal(_))))
+
+    assert(map1.checkInputDataTypes() == TypeCheckResult.TypeCheckSuccess)
+
+    checkEvaluation(map1, createMap(keys1, values1))
+    checkEvaluation(map1, createMap(keys, values))
+  }
+
+  test("CreateMap") {
     val intSeq = Seq(5, 10, 15, 20, 25)
     val longSeq = intSeq.map(_.toLong)
     val strSeq = intSeq.map(_.toString)
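For reference, a minimal expression-level sketch of the behavior this patch enables, mirroring the suite above (before the patch this key combination failed the type check):

    import org.apache.spark.sql.catalyst.expressions.{CreateMap, Literal}
    import org.apache.spark.sql.types.{Decimal, DecimalType}

    // Keys of different decimal types: DecimalType(2, 2) and DecimalType(3, 3).
    val m = CreateMap(Seq(
      Literal(Decimal("0.02")), Literal("a"),
      Literal(Decimal("0.004")), Literal("b")))

    m.checkInputDataTypes()  // should now be TypeCheckSuccess, not a TypeCheckFailure
    m.dataType               // keys widen: MapType(DecimalType(3,3), StringType, false)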