diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 0ca715f42472..edef40388e7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils}
 import org.apache.spark.sql.types._
 
@@ -33,13 +33,24 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
 
   override def foldable: Boolean = children.forall(_.foldable)
 
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.map(_.dataType).forall(_.isInstanceOf[DecimalType])) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
+    }
+  }
 
   override def dataType: DataType = {
-    ArrayType(
-      children.headOption.map(_.dataType).getOrElse(NullType),
-      containsNull = children.exists(_.nullable))
+    var elementType: DataType = children.headOption.map(_.dataType).getOrElse(NullType)
+    if (elementType.isInstanceOf[DecimalType]) {
+      children.foreach { child =>
+        if (elementType.asInstanceOf[DecimalType].isTighterThan(child.dataType)) {
+          elementType = child.dataType
+        }
+      }
+    }
+    ArrayType(elementType, containsNull = children.exists(_.nullable))
   }
 
   override def nullable: Boolean = false
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 0c307b2b8576..d74d3426616d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -134,6 +134,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
   }
 
+  test("SPARK-16714: CreateArray with Decimals") {
+    val array1 = CreateArray(Seq(Literal(Decimal(0.001)), Literal(Decimal(0.02))))
+    val array2 = CreateArray(Seq(Literal(Decimal(0.02)), Literal(Decimal(0.001))))
+
+    assert(array1.checkInputDataTypes() == TypeCheckResult.TypeCheckSuccess)
+    assert(array2.checkInputDataTypes() == TypeCheckResult.TypeCheckSuccess)
+    assert(array1.dataType == array2.dataType)
+
+    checkEvaluation(array1, Seq(Decimal(0.001), Decimal(0.02)))
+    checkEvaluation(array2, Seq(Decimal(0.02), Decimal(0.001)))
+  }
+
   test("CreateMap") {
     def interlace(keys: Seq[Literal], values: Seq[Literal]): Seq[Literal] = {
       keys.zip(values).flatMap { case (k, v) => Seq(k, v) }
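
For reference, here is a minimal, self-contained sketch of the widest-decimal selection the patched CreateArray.dataType performs. Dec is a hypothetical stand-in for Spark's DecimalType, its isTighterThan is my reading of the comparison DecimalType uses (a type is tighter when both its integral digit count and its scale fit within the other's), and the Dec(3, 3) / Dec(2, 2) values below assume how Literal(Decimal(0.001)) and Literal(Decimal(0.02)) resolve; this is an illustration, not Spark's actual code.

object WidestDecimalSketch {
  // Hypothetical stand-in for org.apache.spark.sql.types.DecimalType.
  case class Dec(precision: Int, scale: Int) {
    // Assumed semantics of DecimalType.isTighterThan: every value of this
    // type is representable in `other` without loss.
    def isTighterThan(other: Dec): Boolean =
      (precision - scale) <= (other.precision - other.scale) && scale <= other.scale
  }

  // The same loop the patch runs with a var and foreach: start from the
  // first element's type and adopt a child's type whenever the current
  // candidate is tighter than it.
  def widest(types: Seq[Dec]): Dec =
    types.reduceLeft((acc, t) => if (acc.isTighterThan(t)) t else acc)

  def main(args: Array[String]): Unit = {
    // Element order must not change the result; this is what the new
    // test's `array1.dataType == array2.dataType` assertion checks.
    println(widest(Seq(Dec(3, 3), Dec(2, 2))))  // Dec(3,3)
    println(widest(Seq(Dec(2, 2), Dec(3, 3))))  // Dec(3,3)
  }
}

Note that when neither type is tighter than the other (say Dec(3, 1) and Dec(2, 2)), this loop, like the patch itself, keeps the first candidate rather than computing a true common supertype; widening of that kind is what the analyzer's TypeCoercion rules provide, which may be why the patch adds that import.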