Skip to content

Commit e73ccbf

Browse files
committed
reduce influence
1 parent e692b9f commit e73ccbf

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,8 @@ private[hive] case class HiveSimpleUDF(
5858
lazy val function = funcWrapper.createFunction[UDF]()
5959

6060
@transient
61-
private lazy val method = {
62-
// the simple UDF method must be 'evaluate'
63-
val methods = function.getClass.getMethods.filter(_.getName == "evaluate")
64-
val passedMethod = methods.filter(_.getGenericParameterTypes.length == children.length)
65-
66-
// no matching parameter num for evaluate method
67-
if (passedMethod.isEmpty) {
68-
throw new NoMatchingMethodException(function.getClass,
69-
children.map(_.dataType.toTypeInfo).asJava, methods.toSeq.asJava)
70-
}
71-
// if there exists many method, we choose the first
72-
methods.head
73-
}
61+
private lazy val method =
62+
function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo).asJava)
7463

7564
@transient
7665
private lazy val arguments = children.map(toInspector).toArray
@@ -82,7 +71,20 @@ private[hive] case class HiveSimpleUDF(
8271
}
8372

8473
override def inputTypes: Seq[AbstractDataType] = {
85-
method.getGenericParameterTypes.map(javaTypeToDataType)
74+
val inTypes = children.map(_.dataType)
75+
if (!inTypes.exists(_.existsRecursively(_.isInstanceOf[DecimalType]))) {
76+
inTypes
77+
} else {
78+
val expectTypes = method.getGenericParameterTypes.map(javaTypeToDataType)
79+
// check decimal
80+
inTypes.zip(expectTypes).map { case (in, expect) =>
81+
if (in.existsRecursively(_.isInstanceOf[DecimalType])) {
82+
expect
83+
} else {
84+
in
85+
}
86+
}
87+
}
8688
}
8789

8890
override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable)

0 commit comments

Comments
 (0)