Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make StackSafeMonad extend Defer #3962

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/src/main/scala/cats/Eval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ sealed abstract private[cats] class EvalInstances extends EvalInstances0 {
def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa))
override def unit: Eval[Unit] = Eval.Unit
override def void[A](a: Eval[A]): Eval[Unit] = Eval.Unit
override def defer[A](fa: => Eval[A]): Eval[A] = Eval.defer(fa)
}

implicit val catsDeferForEval: Defer[Eval] =
Expand Down
28 changes: 27 additions & 1 deletion core/src/main/scala/cats/StackSafeMonad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,38 @@ import scala.util.{Either, Left, Right}
* will inherit will not be sound, and will result in unexpected stack overflows. This
* trait is only provided because a large number of monads ''do'' define a stack-safe
* flatMap, and so this particular implementation was being repeated over and over again.
*
* Note, tailRecM being safe and pure implies that the function passed to flatMap
* is not called immediately on the current stack (since otherwise tailRecM would
* stack overflow for sufficiently deep recursions). This implies we can implement
*
* defer(fa) = unit.flatMap(_ => fa)
*/
trait StackSafeMonad[F[_]] extends Monad[F] {
trait StackSafeMonad[F[_]] extends Monad[F] with Defer[F] {

override def tailRecM[A, B](a: A)(f: A => F[Either[A, B]]): F[B] =
flatMap(f(a)) {
case Left(a) => tailRecM(a)(f)
case Right(b) => pure(b)
}

/*
* This is always safe for a StackSafeMonad.
* proof: we know flatMap can't blow the stack
* because if it could, tailRecM would not be safe:
* if the function was called in the same stack then
* the depth would diverse on tailRecM(())(_ => pure(Left(())))
*
* It may be better to override this for your particular Monad
*/
def defer[A](fa: => F[A]): F[A] =
flatMap(unit)(_ => fa)
}

object StackSafeMonad {
def shiftFunctor[F[_], A, B](fn: A => F[B])(implicit F: Functor[F]): A => F[B] =
F match {
case ssm: StackSafeMonad[F] @unchecked => { a => ssm.defer(fn(a)) }
case _ => fn
}
Comment on lines +41 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there's a way to make this more general. A => F[B] works extremely well for Kleisli and extremely not-well for almost anything else, since it will force boxing into a Function1 in cases where it just isn't needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it shouldn't be public at all. I just took what was in Kleisli and made it more particular. I think the current Kleisli hack is adding cost for any Monad that isn't lazy (e.g. Option, or Either), and this function will help there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When in doubt, leave it private. Can always publicize it later if someone finds a second use.

}
19 changes: 7 additions & 12 deletions core/src/main/scala/cats/data/Kleisli.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ import cats.evidence.As
final case class Kleisli[F[_], -A, B](run: A => F[B]) { self =>

private[data] def ap[C, AA <: A](f: Kleisli[F, AA, B => C])(implicit F: Apply[F]): Kleisli[F, AA, C] =
Kleisli(a => F.ap(f.run(a))(run(a)))
Kleisli(StackSafeMonad.shiftFunctor(a => F.ap(f.run(a))(run(a))))

def ap[C, D, AA <: A](f: Kleisli[F, AA, C])(implicit F: Apply[F], ev: B As (C => D)): Kleisli[F, AA, D] = {
Kleisli { a =>
Kleisli(StackSafeMonad.shiftFunctor { a =>
val fb: F[C => D] = F.map(run(a))(ev.coerce)
val fc: F[C] = f.run(a)
F.ap(fb)(fc)
}
})
}

/**
Expand Down Expand Up @@ -97,7 +97,7 @@ final case class Kleisli[F[_], -A, B](run: A => F[B]) { self =>
* }}}
*/
def local[AA](f: AA => A): Kleisli[F, AA, B] =
Kleisli(aa => run(f(aa)))
Kleisli(AndThen(f).andThen(run))

@deprecated("Use mapK", "1.0.0-RC2")
private[cats] def transform[G[_]](f: FunctionK[F, G]): Kleisli[G, A, B] =
Expand Down Expand Up @@ -156,12 +156,7 @@ object Kleisli
* in `flatMap`.
*/
private[data] def shift[F[_], A, B](run: A => F[B])(implicit F: FlatMap[F]): Kleisli[F, A, B] =
F match {
case ap: Applicative[F] @unchecked =>
Kleisli(r => F.flatMap(ap.pure(r))(run))
case _ =>
Kleisli(run)
}
Kleisli(StackSafeMonad.shiftFunctor(run))

/**
* Creates a `FunctionK` that transforms a `Kleisli[F, A, B]` into an `F[B]` by applying the value of type `a:A`.
Expand Down Expand Up @@ -660,14 +655,14 @@ private[data] trait KleisliApply[F[_], A] extends Apply[Kleisli[F, A, *]] with K
// We should only evaluate fb once
val memoFb = fb.memoize

Eval.now(Kleisli { a =>
Eval.now(Kleisli(StackSafeMonad.shiftFunctor { a =>
val fb = fa.run(a)
val efc = memoFb.map(_.run(a))
val efz: Eval[F[Z]] = F.map2Eval(fb, efc)(f)
// This is not safe and results in stack overflows:
// see: https://github.com/typelevel/cats/issues/3947
efz.value
})
}))
}

override def product[B, C](fb: Kleisli[F, A, B], fc: Kleisli[F, A, C]): Kleisli[F, A, (B, C)] =
Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala/cats/instances/tailrec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ trait TailRecInstances {
}

private object TailRecInstances {
val catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] =
new StackSafeMonad[TailRec] with Defer[TailRec] {
def defer[A](fa: => TailRec[A]): TailRec[A] = tailcall(fa)
val catsInstancesForTailRec: StackSafeMonad[TailRec] =
new StackSafeMonad[TailRec] {
override def defer[A](fa: => TailRec[A]): TailRec[A] = tailcall(fa)

def pure[A](a: A): TailRec[A] = done(a)

Expand Down