diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 4a9b28a455a4..72c6f53c69ee 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DecimalType, DoubleType}
+import org.apache.spark.sql.types.{ArrayType, DecimalType, DoubleType}
 import org.apache.spark.util.Utils
 
 private[sql] class HiveSessionCatalog(
@@ -166,8 +166,13 @@ private[sql] class HiveSessionCatalog(
     } catch {
       case NonFatal(_) =>
         // SPARK-16228 ExternalCatalog may recognize `double`-type only.
+        // SPARK-18527 Percentile needs explicit cast to array
         val newChildren = children.map { child =>
-          if (child.dataType.isInstanceOf[DecimalType]) Cast(child, DoubleType) else child
+          child.dataType match {
+            case ArrayType(DecimalType(), nullable) => Cast(child, ArrayType(DoubleType, nullable))
+            case DecimalType() => Cast(child, DoubleType)
+            case _ => child
+          }
         }
         lookupFunction0(name, newChildren)
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 48adc833f4b2..5c8ee0bac1b6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -149,6 +149,12 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     sql("select percentile_approx(value, 0.5) from values 1.0,2.0,3.0 T(value)")
   }
 
+  test("SPARK-18527 Percentile needs explicit cast to array") {
+    sql("select percentile(value, cast(array(0.1, 0.5) as array<double>)) " +
+      "from values 1,2,3 T(value)")
+    sql("select percentile(value, array(0.1, 0.5)) from values 1,2,3 T(value)")
+  }
+
   test("Generic UDAF aggregates") {
     checkAnswer(sql(
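
Editor's note, not part of the patch: a minimal way to exercise the fix outside the test suite is a Hive-enabled local session. The sketch below is illustrative only; the object name Spark18527Check and the local-mode settings are assumptions. It runs the uncast form from the new test, which resolves only once the lookup fallback re-casts array-of-decimal children to array<double>.

import org.apache.spark.sql.SparkSession

object Spark18527Check extends App {
  // Hive support is needed so that a failed built-in lookup falls through to
  // the Hive function registry, where percentile(bigint, array<double>) lives.
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("spark-18527-check")
    .enableHiveSupport()
    .getOrCreate()

  // 0.1 and 0.5 are decimal literals, so array(0.1, 0.5) has an array-of-decimal
  // type. Before this patch only scalar DecimalType children were re-cast to
  // double, so this call failed to resolve the Hive UDAFPercentile signature.
  spark.sql("select percentile(value, array(0.1, 0.5)) from values 1,2,3 T(value)").show()

  spark.stop()
}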