Skip to content

Commit 011befd

Browse files
dongjoon-hyun authored and rxin committed
[SPARK-16228][SQL] HiveSessionCatalog should return double-param functions for decimal param lookups
## What changes were proposed in this pull request?

This PR supports a fallback lookup by casting `DecimalType` into `DoubleType` for external functions with `double`-type parameters.

**Reported Error Scenarios**

```scala
scala> sql("select percentile(value, 0.5) from values 1,2,3 T(value)")
org.apache.spark.sql.AnalysisException: ... No matching method for class org.apache.hadoop.hive.ql.udf.UDAFPercentile with (int, decimal(38,18)). Possible choices: _FUNC_(bigint, array<double>) _FUNC_(bigint, double) ; line 1 pos 7

scala> sql("select percentile_approx(value, 0.5) from values 1.0,2.0,3.0 T(value)")
org.apache.spark.sql.AnalysisException: ... Only a float/double or float/double array argument is accepted as parameter 2, but decimal(38,18) was passed instead.; line 1 pos 7
```

## How was this patch tested?

Pass the Jenkins tests (including a new testcase).

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13930 from dongjoon-hyun/SPARK-16228.

(cherry picked from commit 2eaabfa)
Signed-off-by: Reynold Xin <rxin@databricks.com>
1 parent c4cebd5 commit 011befd

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
3030
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
3131
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
3232
import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, SessionCatalog}
33-
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
33+
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo}
3434
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
3535
import org.apache.spark.sql.catalyst.rules.Rule
3636
import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
3737
import org.apache.spark.sql.hive.client.HiveClient
3838
import org.apache.spark.sql.internal.SQLConf
39+
import org.apache.spark.sql.types.{DecimalType, DoubleType}
3940
import org.apache.spark.util.Utils
4041

4142

@@ -163,6 +164,19 @@ private[sql] class HiveSessionCatalog(
163164
}
164165

165166
/**
 * Resolves the function `name` applied to `children`.
 *
 * The lookup is first attempted with the arguments exactly as given. If that attempt fails
 * with a non-fatal error and at least one argument has a [[DecimalType]], the lookup is
 * retried once with every decimal argument wrapped in a `Cast` to [[DoubleType]]
 * (SPARK-16228). When no argument is decimal the cast would change nothing, so the original
 * failure is propagated directly instead of repeating the same failing lookup.
 *
 * @param name identifier of the function to resolve
 * @param children argument expressions the function is applied to
 * @return the expression produced by `lookupFunction0`
 */
override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
  try {
    lookupFunction0(name, children)
  } catch {
    // Fall back only when the cast below would actually alter the argument list.
    case NonFatal(_) if children.exists(_.dataType.isInstanceOf[DecimalType]) =>
      // SPARK-16228 ExternalCatalog may recognize `double`-type only.
      val newChildren = children.map { child =>
        child.dataType match {
          case _: DecimalType => Cast(child, DoubleType)
          case _ => child
        }
      }
      lookupFunction0(name, newChildren)
  }
}
178+
179+
private def lookupFunction0(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
166180
// TODO: Once lookupFunction accepts a FunctionIdentifier, we should refactor this method to
167181
// if (super.functionExists(name)) {
168182
// super.lookupFunction(name, children)

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,13 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
142142
sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq)
143143
}
144144

145+
test("SPARK-16228 Percentile needs explicit cast to double") {
  // Second arguments that are already cast to double by the query itself.
  val explicitlyCast = Seq(
    "select percentile(value, cast(0.5 as double)) from values 1,2,3 T(value)",
    "select percentile_approx(value, cast(0.5 as double)) from values 1.0,2.0,3.0 T(value)")

  // The same calls with a bare `0.5` literal, i.e. the scenarios reported in SPARK-16228.
  val bareLiteral = Seq(
    "select percentile(value, 0.5) from values 1,2,3 T(value)",
    "select percentile_approx(value, 0.5) from values 1.0,2.0,3.0 T(value)")

  // Results are not inspected: the test passes as long as no statement throws.
  (explicitlyCast ++ bareLiteral).foreach { query =>
    sql(query)
  }
}
151+
145152
test("Generic UDAF aggregates") {
146153
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"),
147154
sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)

0 commit comments

Comments
 (0)