From dff0550d699406163c3649c3f5f6fc79c3d3a4f2 Mon Sep 17 00:00:00 2001 From: Patrick Oscar Boykin Date: Mon, 9 Aug 2021 19:31:26 -1000 Subject: [PATCH 1/2] Make StackSafeMonad extend Defer --- core/src/main/scala/cats/Eval.scala | 1 + core/src/main/scala/cats/StackSafeMonad.scala | 28 ++++++++++++++++++- core/src/main/scala/cats/data/Kleisli.scala | 19 +++++-------- .../main/scala/cats/instances/tailrec.scala | 2 +- 4 files changed, 36 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/cats/Eval.scala b/core/src/main/scala/cats/Eval.scala index 3b52cfcb22..0fe7093409 100644 --- a/core/src/main/scala/cats/Eval.scala +++ b/core/src/main/scala/cats/Eval.scala @@ -375,6 +375,7 @@ sealed abstract private[cats] class EvalInstances extends EvalInstances0 { def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa)) override def unit: Eval[Unit] = Eval.Unit override def void[A](a: Eval[A]): Eval[Unit] = Eval.Unit + override def defer[A](fa: => Eval[A]): Eval[A] = Eval.defer(fa) } implicit val catsDeferForEval: Defer[Eval] = diff --git a/core/src/main/scala/cats/StackSafeMonad.scala b/core/src/main/scala/cats/StackSafeMonad.scala index 8b815dc6f1..a98cf1aeda 100644 --- a/core/src/main/scala/cats/StackSafeMonad.scala +++ b/core/src/main/scala/cats/StackSafeMonad.scala @@ -9,12 +9,38 @@ import scala.util.{Either, Left, Right} * will inherit will not be sound, and will result in unexpected stack overflows. This * trait is only provided because a large number of monads ''do'' define a stack-safe * flatMap, and so this particular implementation was being repeated over and over again. + * + * Note, tailRecM being safe and pure implies that the function passed to flatMap + * is not called immediately on the current stack (since otherwise tailRecM would + * stack overflow for sufficiently deep recursions). This implies we can implement + * + * defer(fa) = unit.flatMap(_ => fa) */ -trait StackSafeMonad[F[_]] extends Monad[F] { +trait StackSafeMonad[F[_]] extends Monad[F] with Defer[F] { override def tailRecM[A, B](a: A)(f: A => F[Either[A, B]]): F[B] = flatMap(f(a)) { case Left(a) => tailRecM(a)(f) case Right(b) => pure(b) } + + /** + * This is always safe for a StackSafeMonad. + * proof: we know flatMap can't blow the stack + * because if it could, tailRecM would not be safe: + * if the function was called in the same stack then + * the depth would diverse on tailRecM(())(_ => pure(Left(()))) + * + * It may be better to override this for your particular Monad + */ + def defer[A](fa: => F[A]): F[A] = + flatMap(unit)(_ => fa) +} + +object StackSafeMonad { + def shiftFunctor[F[_], A, B](fn: A => F[B])(implicit F: Functor[F]): A => F[B] = + F match { + case ssm: StackSafeMonad[F] @unchecked => { a => ssm.defer(fn(a)) } + case _ => fn + } } diff --git a/core/src/main/scala/cats/data/Kleisli.scala b/core/src/main/scala/cats/data/Kleisli.scala index 2374161467..767c971a01 100644 --- a/core/src/main/scala/cats/data/Kleisli.scala +++ b/core/src/main/scala/cats/data/Kleisli.scala @@ -11,14 +11,14 @@ import cats.evidence.As final case class Kleisli[F[_], -A, B](run: A => F[B]) { self => private[data] def ap[C, AA <: A](f: Kleisli[F, AA, B => C])(implicit F: Apply[F]): Kleisli[F, AA, C] = - Kleisli(a => F.ap(f.run(a))(run(a))) + Kleisli(StackSafeMonad.shiftFunctor(a => F.ap(f.run(a))(run(a)))) def ap[C, D, AA <: A](f: Kleisli[F, AA, C])(implicit F: Apply[F], ev: B As (C => D)): Kleisli[F, AA, D] = { - Kleisli { a => + Kleisli(StackSafeMonad.shiftFunctor { a => val fb: F[C => D] = F.map(run(a))(ev.coerce) val fc: F[C] = f.run(a) F.ap(fb)(fc) - } + }) } /** @@ -97,7 +97,7 @@ final case class Kleisli[F[_], -A, B](run: A => F[B]) { self => * }}} */ def local[AA](f: AA => A): Kleisli[F, AA, B] = - Kleisli(aa => run(f(aa))) + Kleisli(AndThen(f).andThen(run)) @deprecated("Use mapK", "1.0.0-RC2") private[cats] def transform[G[_]](f: FunctionK[F, G]): Kleisli[G, A, B] = @@ -156,12 +156,7 @@ object Kleisli * in `flatMap`. */ private[data] def shift[F[_], A, B](run: A => F[B])(implicit F: FlatMap[F]): Kleisli[F, A, B] = - F match { - case ap: Applicative[F] @unchecked => - Kleisli(r => F.flatMap(ap.pure(r))(run)) - case _ => - Kleisli(run) - } + Kleisli(StackSafeMonad.shiftFunctor(run)) /** * Creates a `FunctionK` that transforms a `Kleisli[F, A, B]` into an `F[B]` by applying the value of type `a:A`. @@ -660,14 +655,14 @@ private[data] trait KleisliApply[F[_], A] extends Apply[Kleisli[F, A, *]] with K // We should only evaluate fb once val memoFb = fb.memoize - Eval.now(Kleisli { a => + Eval.now(Kleisli(StackSafeMonad.shiftFunctor { a => val fb = fa.run(a) val efc = memoFb.map(_.run(a)) val efz: Eval[F[Z]] = F.map2Eval(fb, efc)(f) // This is not safe and results in stack overflows: // see: https://github.com/typelevel/cats/issues/3947 efz.value - }) + })) } override def product[B, C](fb: Kleisli[F, A, B], fc: Kleisli[F, A, C]): Kleisli[F, A, (B, C)] = diff --git a/core/src/main/scala/cats/instances/tailrec.scala b/core/src/main/scala/cats/instances/tailrec.scala index fc125a2ae4..d21918bfd5 100644 --- a/core/src/main/scala/cats/instances/tailrec.scala +++ b/core/src/main/scala/cats/instances/tailrec.scala @@ -11,7 +11,7 @@ trait TailRecInstances { private object TailRecInstances { val catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] = new StackSafeMonad[TailRec] with Defer[TailRec] { - def defer[A](fa: => TailRec[A]): TailRec[A] = tailcall(fa) + override def defer[A](fa: => TailRec[A]): TailRec[A] = tailcall(fa) def pure[A](a: A): TailRec[A] = done(a) From 8e4fff6d2d93969473ce8c161c4548e667df15d6 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Sun, 12 Sep 2021 13:13:03 -1000 Subject: [PATCH 2/2] Apply suggestions from code review Accept @djspiewak's suggestions Co-authored-by: Daniel Spiewak --- core/src/main/scala/cats/StackSafeMonad.scala | 2 +- core/src/main/scala/cats/instances/tailrec.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/cats/StackSafeMonad.scala b/core/src/main/scala/cats/StackSafeMonad.scala index a98cf1aeda..9cb79e7737 100644 --- a/core/src/main/scala/cats/StackSafeMonad.scala +++ b/core/src/main/scala/cats/StackSafeMonad.scala @@ -24,7 +24,7 @@ trait StackSafeMonad[F[_]] extends Monad[F] with Defer[F] { case Right(b) => pure(b) } - /** + /* * This is always safe for a StackSafeMonad. * proof: we know flatMap can't blow the stack * because if it could, tailRecM would not be safe: diff --git a/core/src/main/scala/cats/instances/tailrec.scala b/core/src/main/scala/cats/instances/tailrec.scala index d21918bfd5..8ccfc1e3fe 100644 --- a/core/src/main/scala/cats/instances/tailrec.scala +++ b/core/src/main/scala/cats/instances/tailrec.scala @@ -9,8 +9,8 @@ trait TailRecInstances { } private object TailRecInstances { - val catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] = - new StackSafeMonad[TailRec] with Defer[TailRec] { + val catsInstancesForTailRec: StackSafeMonad[TailRec] = + new StackSafeMonad[TailRec] { override def defer[A](fa: => TailRec[A]): TailRec[A] = tailcall(fa) def pure[A](a: A): TailRec[A] = done(a)