From 855411a519e341342c18ff1e05fb1880f1be29bd Mon Sep 17 00:00:00 2001 From: Kalra Date: Sat, 29 Feb 2020 10:38:35 +0800 Subject: [PATCH] backported #3041, added tailrec instance for StacksafeMonad and Defer --- .../scala/cats/bench/TrampolineBench.scala | 19 ++++++------ core/src/main/scala/cats/Eval.scala | 1 + core/src/main/scala/cats/instances/all.scala | 1 + .../main/scala/cats/instances/package.scala | 1 + .../main/scala/cats/instances/tailrec.scala | 26 +++++++++++++++++ .../test/scala/cats/tests/TailRecSuite.scala | 29 +++++++++++++++++++ 6 files changed, 67 insertions(+), 10 deletions(-) create mode 100644 core/src/main/scala/cats/instances/tailrec.scala create mode 100644 tests/src/test/scala/cats/tests/TailRecSuite.scala diff --git a/bench/src/main/scala/cats/bench/TrampolineBench.scala b/bench/src/main/scala/cats/bench/TrampolineBench.scala index e7f547f252..d79e6c81e2 100644 --- a/bench/src/main/scala/cats/bench/TrampolineBench.scala +++ b/bench/src/main/scala/cats/bench/TrampolineBench.scala @@ -1,11 +1,12 @@ package cats.bench import org.openjdk.jmh.annotations.{Benchmark, Scope, State} - import cats._ import cats.implicits._ import cats.free.Trampoline +import scala.util.control.TailCalls + @State(Scope.Benchmark) class TrampolineBench { @@ -30,14 +31,12 @@ class TrampolineBench { y <- Trampoline.defer(trampolineFib(n - 2)) } yield x + y - // TailRec[A] only has .flatMap in 2.11. + @Benchmark + def stdlib(): Int = stdlibFib(N).result - // @Benchmark - // def stdlib(): Int = stdlibFib(N).result - // - // def stdlibFib(n: Int): TailCalls.TailRec[Int] = - // if (n < 2) TailCalls.done(n) else for { - // x <- TailCalls.tailcall(stdlibFib(n - 1)) - // y <- TailCalls.tailcall(stdlibFib(n - 2)) - // } yield x + y + def stdlibFib(n: Int): TailCalls.TailRec[Int] = + if (n < 2) TailCalls.done(n) else for { + x <- TailCalls.tailcall(stdlibFib(n - 1)) + y <- TailCalls.tailcall(stdlibFib(n - 2)) + } yield x + y } diff --git a/core/src/main/scala/cats/Eval.scala b/core/src/main/scala/cats/Eval.scala index 53676f0e04..f477492785 100644 --- a/core/src/main/scala/cats/Eval.scala +++ b/core/src/main/scala/cats/Eval.scala @@ -378,6 +378,7 @@ sealed abstract private[cats] class EvalInstances extends EvalInstances0 { def flatMap[A, B](fa: Eval[A])(f: A => Eval[B]): Eval[B] = fa.flatMap(f) def extract[A](la: Eval[A]): A = la.value def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa)) + override def unit: Eval[Unit] = Eval.Unit } implicit val catsDeferForEval: Defer[Eval] = diff --git a/core/src/main/scala/cats/instances/all.scala b/core/src/main/scala/cats/instances/all.scala index 5dcc1a3e02..f823626109 100644 --- a/core/src/main/scala/cats/instances/all.scala +++ b/core/src/main/scala/cats/instances/all.scala @@ -69,3 +69,4 @@ trait AllInstancesBinCompat7 with VectorInstancesBinCompat1 with EitherInstancesBinCompat0 with StreamInstancesBinCompat1 + with TailRecInstances diff --git a/core/src/main/scala/cats/instances/package.scala b/core/src/main/scala/cats/instances/package.scala index 097bc4434b..8af3346ab2 100644 --- a/core/src/main/scala/cats/instances/package.scala +++ b/core/src/main/scala/cats/instances/package.scala @@ -39,6 +39,7 @@ package object instances { object sortedSet extends SortedSetInstances with SortedSetInstancesBinCompat0 with SortedSetInstancesBinCompat1 object stream extends StreamInstances with StreamInstancesBinCompat0 with StreamInstancesBinCompat1 object string extends StringInstances + object tailRec extends TailRecInstances object try_ extends TryInstances object tuple extends TupleInstances with Tuple2InstancesBinCompat0 object unit extends UnitInstances diff --git a/core/src/main/scala/cats/instances/tailrec.scala b/core/src/main/scala/cats/instances/tailrec.scala new file mode 100644 index 0000000000..d0e7bee15e --- /dev/null +++ b/core/src/main/scala/cats/instances/tailrec.scala @@ -0,0 +1,26 @@ +package cats.instances + +import cats.{Defer, StackSafeMonad} +import scala.util.control.TailCalls.{done, tailcall, TailRec} + +trait TailRecInstances { + implicit def catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] = + TailRecInstances.catsInstancesForTailRec +} + +private object TailRecInstances { + val catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] = + new StackSafeMonad[TailRec] with Defer[TailRec] { + def defer[A](fa: => TailRec[A]): TailRec[A] = tailcall(fa) + + def pure[A](a: A): TailRec[A] = done(a) + + override def map[A, B](fa: TailRec[A])(f: A => B): TailRec[B] = + fa.map(f) + + def flatMap[A, B](fa: TailRec[A])(f: A => TailRec[B]): TailRec[B] = + fa.flatMap(f) + + override val unit: TailRec[Unit] = done(()) + } +} diff --git a/tests/src/test/scala/cats/tests/TailRecSuite.scala b/tests/src/test/scala/cats/tests/TailRecSuite.scala new file mode 100644 index 0000000000..5bb31169c7 --- /dev/null +++ b/tests/src/test/scala/cats/tests/TailRecSuite.scala @@ -0,0 +1,29 @@ +package cats.tests + +import cats.{Defer, Eq, Monad} +import cats.laws.discipline.{DeferTests, MonadTests, SerializableTests} +import org.scalacheck.Arbitrary.arbitrary +import org.scalacheck.{Arbitrary, Cogen, Gen} + +import scala.util.control.TailCalls.{done, tailcall, TailRec} + +class TailRecSuite extends CatsSuite { + + implicit def tailRecArb[A: Arbitrary: Cogen]: Arbitrary[TailRec[A]] = + Arbitrary( + Gen.frequency( + (3, arbitrary[A].map(done)), + (1, Gen.lzy(arbitrary[(A, A => TailRec[A])].map { case (a, fn) => tailcall(fn(a)) })), + (1, Gen.lzy(arbitrary[(TailRec[A], A => TailRec[A])].map { case (a, fn) => a.flatMap(fn) })) + ) + ) + + implicit def eqTailRec[A: Eq]: Eq[TailRec[A]] = + Eq.by[TailRec[A], A](_.result) + + checkAll("TailRec[Int]", MonadTests[TailRec].monad[Int, Int, Int]) + checkAll("Monad[TailRec]", SerializableTests.serializable(Monad[TailRec])) + + checkAll("TailRec[Int]", DeferTests[TailRec].defer[Int]) + checkAll("Defer[TailRec]", SerializableTests.serializable(Defer[TailRec])) +}