Skip to content

Commit

Permalink
Merge pull request #57 from fmonniot/default-values-take-2
Browse files Browse the repository at this point in the history
Fix default parameters (take 2)
  • Loading branch information
fmonniot authored Jun 25, 2024
2 parents 4d10b53 + 24dbd9a commit 83ca10e
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
36 changes: 17 additions & 19 deletions core/src/main/scala/eu/monniot/scala3mock/macros/MockImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -198,27 +198,25 @@ private class MockImpl[T](ctx: Expr[MockContext], debug: Boolean)(using
}
.filterNot(_.flags.is(Flags.Private)) // Do not override private members

// Then we find the methods which have default values, and how many default values.
// This will be used to filter out the default value methods that the compiler generate.
// I don't know of a good way to find those via the Quotes API, so instead I use the fact
// that the compiler always use the same naming scheme to filter them out.
val methodsWithDefault = candidates
.map { sym =>
sym.name -> sym.paramSymss
.flatMap(_.map(s => s.flags.is(Flags.HasDefault)))
.count(identity)
}
.filter(_._2 > 0)

if (methodsWithDefault.isEmpty) candidates
else
val names = methodsWithDefault.flatMap { case (name, count) =>
(0 to count).map(i => s"$name$$default$$$i")
// We then generate a list of methods to ignore for default values. We do this because the
// compiler generate methods (following the `<methodName>$default$<parameterPosition>` naming
// scheme) to hold the default value of a parameter (and insert them automatically at call site).
// We do not want to override those.
val namesToIgnore = candidates
.flatMap { sym =>
sym.paramSymss
.filterNot(_.exists(_.isType))
.flatten
.zipWithIndex
.collect {
case (parameter, position)
if parameter.flags.is(Flags.HasDefault) =>
s"${sym.name}$$default$$${position + 1}"
}
}

candidates.filterNot { m =>
names.contains(m.name)
}
candidates.filterNot(m => namesToIgnore.contains(m.name))
end findMethodsToOverride

/** Walk the given symbol hierarchy to find all trait which have parameters */
private def findParameterizedTraits(
Expand Down
10 changes: 10 additions & 0 deletions core/src/test/scala/eu/monniot/scala3mock/mock/MockSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,18 @@ class MockSuite extends munit.FunSuite with ScalaMocks {
when(m.foo).expects(9).returns("ok")
when(m.foo).expects(42).returns("ok2")

when(m.multiParamList(_: Int, _: Int)(_: Long, _: Long))
.expects(1, 1, 1, 1)
.returns("one")
when(m.multiParamList(_: Int, _: Int)(_: Long, _: Long))
.expects(1, 0, 1, 0)
.returns("default")

assertEquals(m.foo(), "ok")
assertEquals(m.foo(42), "ok2")

assertEquals(m.multiParamList(1, 1)(1, 1), "one")
assertEquals(m.multiParamList(1)(1), "default")
}
}
}
5 changes: 5 additions & 0 deletions core/src/test/scala/fixtures/TestDefaultParameters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,9 @@ package fixtures
trait TestDefaultParameters {
def foo(bar: Int = 9): String
def foo2(b: Int = 1, c: Long): String

def defaultAfterRegular(a: Int, b: Int = 0): String
def multiParamList(a: Int, b: Int = 0)(c: Long, d: Long = 0): String

def withTypeParam[A, B](a: A, b: B)(c: Long, d: Long = 0): String
}

0 comments on commit 83ca10e

Please sign in to comment.