From 3829aaa5efdcf06e9709cec5ecb8d81a5da7a7ed Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 29 Apr 2021 20:42:35 -0700 Subject: [PATCH 1/6] Invoke should find the method with correct number of parameters --- .../catalyst/expressions/objects/objects.scala | 4 +++- .../expressions/ObjectExpressionsSuite.scala | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 53be4bb651a4..9995d4a5574a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -325,7 +325,9 @@ case class Invoke( @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - val m = cls.getMethods.find(_.getName == encodedFunctionName) + val m = cls.getMethods.find { m => + m.getName == encodedFunctionName && m.getParameterCount == arguments.length + } if (m.isEmpty) { sys.error(s"Couldn't find $encodedFunctionName on $cls") } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index cbe37c4f9478..095968cdbc75 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -618,6 +618,21 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkExceptionInExpression[ArithmeticException]( StaticInvoke(mathCls, IntegerType, "addExact", Seq(Literal(Int.MaxValue), Literal(1))), "") } + + test("SPARK-35278: invoke should find method with correct number of parameters") { + val strClsType = ObjectType(classOf[String]) + checkExceptionInExpression[StringIndexOutOfBoundsException]( + Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(3))), "") + + checkObjectExprEvaluation( + Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0))), "a") + + checkExceptionInExpression[StringIndexOutOfBoundsException]( + Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0), Literal(3))), "") + + checkObjectExprEvaluation( + Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0), Literal(1))), "a") + } } class TestBean extends Serializable { From 9fc152c21d7908fcf6f527749d13ac2f7e40e5b3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 29 Apr 2021 23:06:35 -0700 Subject: [PATCH 2/6] Use getMethod. --- .../sql/catalyst/expressions/objects/objects.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9995d4a5574a..00b82292c8a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -325,13 +325,11 @@ case class Invoke( @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - val m = cls.getMethods.find { m => - m.getName == encodedFunctionName && m.getParameterCount == arguments.length - } - if (m.isEmpty) { - sys.error(s"Couldn't find $encodedFunctionName on $cls") - } else { - m + try { + Some(cls.getMethod(encodedFunctionName, argClasses: _*)) + } catch { + case _: NoSuchMethodException => + sys.error(s"Couldn't find $encodedFunctionName on $cls") } case _ => None } From 8eed4d10eabbfdb46fd0ed9dc8a15ef38d7cd8de Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 30 Apr 2021 01:22:47 -0700 Subject: [PATCH 3/6] Revert "Use getMethod." --- .../sql/catalyst/expressions/objects/objects.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 00b82292c8a4..9995d4a5574a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -325,11 +325,13 @@ case class Invoke( @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - try { - Some(cls.getMethod(encodedFunctionName, argClasses: _*)) - } catch { - case _: NoSuchMethodException => - sys.error(s"Couldn't find $encodedFunctionName on $cls") + val m = cls.getMethods.find { m => + m.getName == encodedFunctionName && m.getParameterCount == arguments.length + } + if (m.isEmpty) { + sys.error(s"Couldn't find $encodedFunctionName on $cls") + } else { + m } case _ => None } From d03318a00b2934228f2bf4375889806707aed22e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 30 Apr 2021 12:01:10 -0700 Subject: [PATCH 4/6] Improve function lookup. --- .../expressions/objects/objects.scala | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9995d4a5574a..18e2c4b4d7e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -325,13 +325,24 @@ case class Invoke( @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - val m = cls.getMethods.find { m => - m.getName == encodedFunctionName && m.getParameterCount == arguments.length - } - if (m.isEmpty) { - sys.error(s"Couldn't find $encodedFunctionName on $cls") - } else { - m + // Looking with function name + argument classes first. + try { + Some(cls.getMethod(encodedFunctionName, argClasses: _*)) + } catch { + case _: NoSuchMethodException => + // For some cases, e.g. arg class is Object, `getMethod` cannot find the method. + // We look at function name + argument length. + val m = cls.getMethods.filter { m => + m.getName == encodedFunctionName && m.getParameterCount == arguments.length + } + if (m.isEmpty) { + sys.error(s"Couldn't find $encodedFunctionName on $cls") + } else if (m.length > 1) { + // Ambiguous case, we don't know which method to choose, just fail it. + sys.error(s"Found ${m.length} $encodedFunctionName on $cls") + } else { + Some(m.head) + } } case _ => None } From 176b103ffbbb1e9044e22233e6baaff60b8b6b9a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 30 Apr 2021 16:42:24 -0700 Subject: [PATCH 5/6] Consider generic method. --- .../sql/catalyst/expressions/objects/objects.scala | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 18e2c4b4d7e0..5abf972f79fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -331,15 +331,21 @@ case class Invoke( } catch { case _: NoSuchMethodException => // For some cases, e.g. arg class is Object, `getMethod` cannot find the method. - // We look at function name + argument length. + // We look at function name + argument length val m = cls.getMethods.filter { m => m.getName == encodedFunctionName && m.getParameterCount == arguments.length } if (m.isEmpty) { sys.error(s"Couldn't find $encodedFunctionName on $cls") } else if (m.length > 1) { - // Ambiguous case, we don't know which method to choose, just fail it. - sys.error(s"Found ${m.length} $encodedFunctionName on $cls") + // More than one matched method signature. Exclude synthetic one, e.g. generic one. + val realMethods = m.filter(!_.isSynthetic) + if (realMethods.length > 1) { + // Ambiguous case, we don't know which method to choose, just fail it. + sys.error(s"Found ${realMethods.length} $encodedFunctionName on $cls") + } else { + Some(realMethods.head) + } } else { Some(m.head) } From bfaaf32cbaf17a56443cf029f724d03b5640f605 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 30 Apr 2021 20:24:32 -0700 Subject: [PATCH 6/6] Add test. --- .../expressions/ObjectExpressionsSuite.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 095968cdbc75..14b72f5132bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -633,6 +633,14 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkObjectExprEvaluation( Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0), Literal(1))), "a") } + + test("SPARK-35278: invoke should correctly invoke override method") { + val clsType = ObjectType(classOf[ConcreteClass]) + val obj = new ConcreteClass + + checkObjectExprEvaluation( + Invoke(Literal(obj, clsType), "testFunc", IntegerType, Seq(Literal(1))), 0) + } } class TestBean extends Serializable { @@ -643,3 +651,11 @@ class TestBean extends Serializable { def setNonPrimitive(i: AnyRef): Unit = assert(i != null, "this setter should not be called with null.") } + +abstract class BaseClass[T] { + def testFunc(param: T): T +} + +class ConcreteClass extends BaseClass[Int] with Serializable { + override def testFunc(param: Int): Int = param - 1 +}