diff --git a/core/src/main/scala/cats/Defer.scala b/core/src/main/scala/cats/Defer.scala index 2ab135c54d..c24b65b0ec 100644 --- a/core/src/main/scala/cats/Defer.scala +++ b/core/src/main/scala/cats/Defer.scala @@ -28,6 +28,27 @@ package cats */ trait Defer[F[_]] extends Serializable { def defer[A](fa: => F[A]): F[A] + + /** + * Defer instances, like functions, parsers, generators, IO, etc... + * often are used in recursive settings where this function is useful + * + * fix(fn) == fn(fix(fn)) + * + * example: + * + * val parser: P[Int] = + * Defer[P].fix[Int] { rec => + * CharsIn("0123456789") | P("(") ~ rec ~ P(")") + * } + * + * Note, fn may not yield a terminating value in which case both + * of the above F[A] run forever. + */ + def fix[A](fn: F[A] => F[A]): F[A] = { + lazy val res: F[A] = defer(fn(res)) + res + } } object Defer { diff --git a/tests/src/test/scala/cats/tests/FunctionSuite.scala b/tests/src/test/scala/cats/tests/FunctionSuite.scala index 4cc013752c..ffa5eb15bf 100644 --- a/tests/src/test/scala/cats/tests/FunctionSuite.scala +++ b/tests/src/test/scala/cats/tests/FunctionSuite.scala @@ -23,6 +23,7 @@ import cats.laws.discipline.eq._ import cats.laws.discipline.arbitrary._ import cats.kernel.{CommutativeGroup, CommutativeMonoid, CommutativeSemigroup} import cats.kernel.{Band, BoundedSemilattice, Semilattice} +import org.scalacheck.Gen class FunctionSuite extends CatsSuite { @@ -46,6 +47,22 @@ class FunctionSuite extends CatsSuite { // TODO: make an binary compatible way to do this // checkAll("Function1[Int => *]", DeferTests[Function1[Int, *]].defer[Int]) + test("Defer[Function1[Int, *]].fix computing sum") { + val sum2 = Defer[Function1[Int, *]].fix[Int] { + rec => + { n: Int => + if (n <= 0) 0 else n * n + rec(n - 1) + } + } + + forAll(Gen.choose(0, 1000)) { n => + // don't let n get too large because this consumes stack + assert(sum2(n) == (0 to n).map { n => + n * n + }.sum) + } + } + checkAll("Semigroupal[Function1[Int, *]]", SerializableTests.serializable(Semigroupal[Function1[Int, *]])) checkAll("Function1[MiniInt, Int]", MonadTests[MiniInt => *].monad[Int, Int, Int])