Skip to content

Commit

Permalink
Can properly set expectation when a private overloaded method exists (#…
Browse files Browse the repository at this point in the history
…70)

* Extract findMethodsToOverride as an utility

So that when and mock can share it

* When should only consider some methods for override
  • Loading branch information
fmonniot authored Aug 11, 2024
1 parent e29cfb4 commit da6992e
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 33 deletions.
33 changes: 1 addition & 32 deletions core/src/main/scala/eu/monniot/scala3mock/macros/MockImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,37 +187,6 @@ private class MockImpl[T](ctx: Expr[MockContext], debug: Boolean)(using
.map(_.tree.asInstanceOf[DefDef])
.headOption

private def findMethodsToOverride(sym: Symbol): List[Symbol] =
val objectMembers = Symbol.requiredClass("java.lang.Object").methodMembers
val anyMembers = Symbol.requiredClass("scala.Any").methodMembers

// First we refine the methods by removing the methods inherited from Object and Any
val candidates = sym.methodMembers
.filter { m =>
!(objectMembers.contains(m) || anyMembers.contains(m))
}
.filterNot(_.flags.is(Flags.Private)) // Do not override private members

// We then generate a list of methods to ignore for default values. We do this because the
// compiler generate methods (following the `<methodName>$default$<parameterPosition>` naming
// scheme) to hold the default value of a parameter (and insert them automatically at call site).
// We do not want to override those.
val namesToIgnore = candidates
.flatMap { sym =>
sym.paramSymss
.filterNot(_.exists(_.isType))
.flatten
.zipWithIndex
.collect {
case (parameter, position)
if parameter.flags.is(Flags.HasDefault) =>
s"${sym.name}$$default$$${position + 1}"
}
}

candidates.filterNot(m => namesToIgnore.contains(m.name))
end findMethodsToOverride

/** Walk the given symbol hierarchy to find all trait which have parameters */
private def findParameterizedTraits(
lookedUpSymbol: Symbol
Expand Down Expand Up @@ -264,7 +233,7 @@ private class MockImpl[T](ctx: Expr[MockContext], debug: Boolean)(using
val objectMembers = Symbol.requiredClass("java.lang.Object").methodMembers
val anyMembers = Symbol.requiredClass("scala.Any").methodMembers

val methodsToOverride = findMethodsToOverride(classSymbol)
val methodsToOverride = utils.findMethodsToOverride(classSymbol)

// fields are values. TIL that scala has val without implementation :)
val fieldsToOverride =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ private[scala3mock] object WhenImpl:
name

case Some(signature) =>
classSymbol.map(_.methodMembers) match
classSymbol.map(utils.findMethodsToOverride) match
case None =>
report.errorAndAbort(
"The when parameter is composed of more than one type, which isn't supported at the moment."
Expand Down
35 changes: 35 additions & 0 deletions core/src/main/scala/eu/monniot/scala3mock/macros/utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,38 @@ private[macros] object utils:
val sig = sym.signature
sig.resultSig + sig.paramSigs.map(_.toString()).mkString
}

def findMethodsToOverride(using quotes: Quotes)(
sym: quotes.reflect.Symbol
): List[quotes.reflect.Symbol] =
import quotes.reflect.*

val objectMembers = Symbol.requiredClass("java.lang.Object").methodMembers
val anyMembers = Symbol.requiredClass("scala.Any").methodMembers

// First we refine the methods by removing the methods inherited from Object and Any
val candidates = sym.methodMembers
.filter { m =>
!(objectMembers.contains(m) || anyMembers.contains(m))
}
.filterNot(_.flags.is(Flags.Private)) // Do not override private members

// We then generate a list of methods to ignore for default values. We do this because the
// compiler generate methods (following the `<methodName>$default$<parameterPosition>` naming
// scheme) to hold the default value of a parameter (and insert them automatically at call site).
// We do not want to override those.
val namesToIgnore = candidates
.flatMap { sym =>
sym.paramSymss
.filterNot(_.exists(_.isType))
.flatten
.zipWithIndex
.collect {
case (parameter, position)
if parameter.flags.is(Flags.HasDefault) =>
s"${sym.name}$$default$$${position + 1}"
}
}

candidates.filterNot(m => namesToIgnore.contains(m.name))
end findMethodsToOverride
14 changes: 14 additions & 0 deletions core/src/test/scala/eu/monniot/scala3mock/mock/MockSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -332,4 +332,18 @@ class MockSuite extends munit.FunSuite with ScalaMocks {
assertEquals(m.multiParamList(1)(1), "default")
}
}

test("overloaded private method - issue #68") {
class Service {
def method(value: String): String = method(2, true)
private def method(value: Int, value2: Boolean): String = "nok"
}

withExpectations() {
val service = mock[Service]

when(service.method).expects("").returns("ok")
assertEquals(service.method(""), "ok")
}
}
}

0 comments on commit da6992e

Please sign in to comment.