Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,30 @@ case class Invoke(

@transient lazy val method = targetObject.dataType match {
case ObjectType(cls) =>
val m = cls.getMethods.find(_.getName == encodedFunctionName)
if (m.isEmpty) {
sys.error(s"Couldn't find $encodedFunctionName on $cls")
} else {
m
// Looking with function name + argument classes first.
try {
Some(cls.getMethod(encodedFunctionName, argClasses: _*))
} catch {
case _: NoSuchMethodException =>
// For some cases, e.g. arg class is Object, `getMethod` cannot find the method.
// We look at function name + argument length
val m = cls.getMethods.filter { m =>
m.getName == encodedFunctionName && m.getParameterCount == arguments.length
}
if (m.isEmpty) {
sys.error(s"Couldn't find $encodedFunctionName on $cls")
} else if (m.length > 1) {
// More than one matched method signature. Exclude synthetic one, e.g. generic one.
val realMethods = m.filter(!_.isSynthetic)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot filter out synthetic ones in L336?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if it is possible that there is only one method and it is also a synthetic one. If we filter synthetic ones at L336, we may miss it?

if (realMethods.length > 1) {
// Ambiguous case, we don't know which method to choose, just fail it.
sys.error(s"Found ${realMethods.length} $encodedFunctionName on $cls")
} else {
Some(realMethods.head)
}
} else {
Some(m.head)
}
}
case _ => None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,29 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkExceptionInExpression[ArithmeticException](
StaticInvoke(mathCls, IntegerType, "addExact", Seq(Literal(Int.MaxValue), Literal(1))), "")
}

test("SPARK-35278: invoke should find method with correct number of parameters") {
val strClsType = ObjectType(classOf[String])
checkExceptionInExpression[StringIndexOutOfBoundsException](
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(3))), "")

checkObjectExprEvaluation(
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0))), "a")

checkExceptionInExpression[StringIndexOutOfBoundsException](
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0), Literal(3))), "")

checkObjectExprEvaluation(
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0), Literal(1))), "a")
}

test("SPARK-35278: invoke should correctly invoke override method") {
val clsType = ObjectType(classOf[ConcreteClass])
val obj = new ConcreteClass

checkObjectExprEvaluation(
Invoke(Literal(obj, clsType), "testFunc", IntegerType, Seq(Literal(1))), 0)
}
}

class TestBean extends Serializable {
Expand All @@ -628,3 +651,11 @@ class TestBean extends Serializable {
def setNonPrimitive(i: AnyRef): Unit =
assert(i != null, "this setter should not be called with null.")
}

abstract class BaseClass[T] {
def testFunc(param: T): T
}

class ConcreteClass extends BaseClass[Int] with Serializable {
override def testFunc(param: Int): Int = param - 1
}