diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 33800c40badf..9e1ccc861aed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -145,6 +145,34 @@ trait InvokeLike extends Expression with NonSQLExpression { } } } + + final def findMethod(cls: Class[_], functionName: String, argClasses: Seq[Class[_]]): Method = { + // Looking with function name + argument classes first. + try { + cls.getMethod(functionName, argClasses: _*) + } catch { + case _: NoSuchMethodException => + // For some cases, e.g. arg class is Object, `getMethod` cannot find the method. + // We look at function name + argument length + val m = cls.getMethods.filter { m => + m.getName == functionName && m.getParameterCount == arguments.length + } + if (m.isEmpty) { + sys.error(s"Couldn't find $functionName on $cls") + } else if (m.length > 1) { + // More than one matched method signature. Exclude synthetic one, e.g. generic one. + val realMethods = m.filter(!_.isSynthetic) + if (realMethods.length > 1) { + // Ambiguous case, we don't know which method to choose, just fail it. + sys.error(s"Found ${realMethods.length} $functionName on $cls") + } else { + realMethods.head + } + } else { + m.head + } + } + } } /** @@ -236,7 +264,7 @@ case class StaticInvoke( override def children: Seq[Expression] = arguments lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) - @transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*) + @transient lazy val method = findMethod(cls, functionName, argClasses) override def eval(input: InternalRow): Any = { invoke(null, method, arguments, input, dataType) @@ -326,31 +354,7 @@ case class Invoke( @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - // Looking with function name + argument classes first. - try { - Some(cls.getMethod(encodedFunctionName, argClasses: _*)) - } catch { - case _: NoSuchMethodException => - // For some cases, e.g. arg class is Object, `getMethod` cannot find the method. - // We look at function name + argument length - val m = cls.getMethods.filter { m => - m.getName == encodedFunctionName && m.getParameterCount == arguments.length - } - if (m.isEmpty) { - sys.error(s"Couldn't find $encodedFunctionName on $cls") - } else if (m.length > 1) { - // More than one matched method signature. Exclude synthetic one, e.g. generic one. - val realMethods = m.filter(!_.isSynthetic) - if (realMethods.length > 1) { - // Ambiguous case, we don't know which method to choose, just fail it. - sys.error(s"Found ${realMethods.length} $encodedFunctionName on $cls") - } else { - Some(realMethods.head) - } - } else { - Some(m.head) - } - } + Some(findMethod(cls, encodedFunctionName, argClasses)) case _ => None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 14b72f5132bf..7bcb2cd75c85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -638,8 +638,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val clsType = ObjectType(classOf[ConcreteClass]) val obj = new ConcreteClass + val input = (1, 2) checkObjectExprEvaluation( - Invoke(Literal(obj, clsType), "testFunc", IntegerType, Seq(Literal(1))), 0) + Invoke(Literal(obj, clsType), "testFunc", IntegerType, + Seq(Literal(input, ObjectType(input.getClass)))), 2) + } + + test("SPARK-35288: static invoke should find method without exact param type match") { + val input = (1, 2) + + checkObjectExprEvaluation( + StaticInvoke(TestStaticInvoke.getClass, IntegerType, "func", + Seq(Literal(input, ObjectType(input.getClass)))), 3) + + checkObjectExprEvaluation( + StaticInvoke(TestStaticInvoke.getClass, IntegerType, "func", + Seq(Literal(1, IntegerType))), -1) } } @@ -652,10 +666,22 @@ class TestBean extends Serializable { assert(i != null, "this setter should not be called with null.") } +object TestStaticInvoke { + def func(param: Any): Int = param match { + case pair: Tuple2[_, _] => + pair.asInstanceOf[Tuple2[Int, Int]]._1 + pair.asInstanceOf[Tuple2[Int, Int]]._2 + case _ => -1 + } +} + abstract class BaseClass[T] { - def testFunc(param: T): T + def testFunc(param: T): Int } -class ConcreteClass extends BaseClass[Int] with Serializable { - override def testFunc(param: Int): Int = param - 1 +class ConcreteClass extends BaseClass[Product] with Serializable { + override def testFunc(param: Product): Int = param match { + case _: Tuple2[_, _] => 2 + case _: Tuple3[_, _, _] => 3 + case _ => 4 + } }