From 2023db1e52d39c235e7205ed409e570edf505b10 Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Mon, 13 Jan 2020 11:58:12 -0700 Subject: [PATCH] Reimplemented MonadError[FreeT[...]] to be correct (#3241) * Reimplemented MonadError[FreeT[...]] to be correct * Completely forgot to apply scalafmt... * Binary compatibility silliness * Fixed typo in comment * Added ApplicativeError distributivity laws * Added Defer[FreeT[...]] and used it to ensure handleErrorWith is stack-safe (in and of itself) * Applied scalafmt --- free/src/main/scala/cats/free/FreeT.scala | 67 ++++++++++++++++++- .../src/test/scala/cats/free/FreeTSuite.scala | 11 +++ .../cats/laws/ApplicativeErrorLaws.scala | 15 +++++ .../discipline/ApplicativeErrorTests.scala | 8 ++- 4 files changed, 99 insertions(+), 2 deletions(-) diff --git a/free/src/main/scala/cats/free/FreeT.scala b/free/src/main/scala/cats/free/FreeT.scala index ffd06aa1c8..caa8b9e69d 100644 --- a/free/src/main/scala/cats/free/FreeT.scala +++ b/free/src/main/scala/cats/free/FreeT.scala @@ -220,7 +220,12 @@ object FreeT extends FreeTInstances { } sealed abstract private[free] class FreeTInstances extends FreeTInstances0 { - implicit def catsFreeMonadErrorForFreeT[S[_], M[_], E](implicit E: MonadError[M, E]): MonadError[FreeT[S, M, *], E] = + + // retained for binary compatibility. its results are incorrect though and it would fail the laws if we generated things of the form pure(()).flatMap(_ => fa) + @deprecated("does not handle errors beyond the head suspension; use catsFreeMonadErrorForFreeT2", "2.1.0") + def catsFreeMonadErrorForFreeT[S[_], M[_], E]( + implicit E: MonadError[M, E] + ): MonadError[FreeT[S, M, *], E] = new MonadError[FreeT[S, M, *], E] with FreeTMonad[S, M] { override def M = E override def handleErrorWith[A](fa: FreeT[S, M, A])(f: E => FreeT[S, M, A]) = @@ -228,6 +233,66 @@ sealed abstract private[free] class FreeTInstances extends FreeTInstances0 { override def raiseError[A](e: E) = FreeT.liftT(E.raiseError[A](e))(M) } + + // not to be confused with defer... which is something different... sigh... + implicit def catsDeferForFreeT[S[_], M[_]: Applicative]: Defer[FreeT[S, M, *]] = + new Defer[FreeT[S, M, *]] { + def defer[A](fa: => FreeT[S, M, A]): FreeT[S, M, A] = + FreeT.pure[S, M, Unit](()).flatMap(_ => fa) + } + + implicit def catsFreeMonadErrorForFreeT2[S[_], M[_], E](implicit E: MonadError[M, E], + S: Functor[S]): MonadError[FreeT[S, M, *], E] = + new MonadError[FreeT[S, M, *], E] with FreeTMonad[S, M] { + override def M = E + + private[this] val RealDefer = catsDeferForFreeT[S, M] + + /* + * Quick explanation... The previous version of this function (retained above for + * bincompat) was only able to look at the *top* level M[_] suspension in a Free + * program. Any suspensions below that in the compute tree were invisible. Thus, + * if there were errors in that top level suspension, then they would be handled + * by the delegate. Errors buried further in the tree were unhandled. This is most + * easily visualized by the following two expressions: + * + * - fa + * - pure(()).flatMap(_ => fa) + * + * By the monad laws, these *should* be identical in effect, but they do have + * different structural representations within FreeT. More specifically, the latter + * has a meaningless M[_] suspension which sits in front of the rest of the + * computation. The previous iteration of this function would be blind to fa in + * the latter encoding, while it would handle it correctly in the former. + * + * Historical sidebar: this became visible because of the "shift" mechanism in + * Kleisli. + */ + override def handleErrorWith[A](fa: FreeT[S, M, A])(f: E => FreeT[S, M, A]) = { + val ft = FreeT.liftT[S, M, FreeT[S, M, A]] { + val resultsM = E.map(fa.resume) { + case Left(se) => + // we defer here in order to ensure stack-safety in the results even when M[_] is not itself stack-safe + // there's some small performance loss as a consequence, but really, if you care that much about performance, why are you using FreeT? + RealDefer.defer(FreeT.liftF[S, M, FreeT[S, M, A]](S.map(se)(handleErrorWith(_)(f))).flatMap(identity)) + + case Right(a) => + pure(a) + } + + E.handleErrorWith(resultsM) { e => + E.map(f(e).resume) { eth => + FreeT.defer(E.pure(eth.swap)) // why on earth is defer inconsistent with resume?? + } + } + } + + ft.flatMap(identity) + } + + override def raiseError[A](e: E) = + FreeT.liftT(E.raiseError[A](e))(M) + } } sealed abstract private[free] class FreeTInstances0 extends FreeTInstances1 { diff --git a/free/src/test/scala/cats/free/FreeTSuite.scala b/free/src/test/scala/cats/free/FreeTSuite.scala index 5fcdd8f55a..2210a23816 100644 --- a/free/src/test/scala/cats/free/FreeTSuite.scala +++ b/free/src/test/scala/cats/free/FreeTSuite.scala @@ -44,6 +44,8 @@ class FreeTSuite extends CatsSuite { SerializableTests.serializable(MonadError[FreeTOption, Unit])) } + checkAll("FreeT[Option, Option, Int", DeferTests[FreeTOption].defer[Int]) + test("FlatMap stack safety tested with 50k flatMaps") { val expected = Applicative[FreeTOption].pure(()) val result = @@ -115,6 +117,15 @@ class FreeTSuite extends CatsSuite { } } + // NB: this does not analogously cause problems for the SemigroupK implementation as semigroup's effects associate while errors do not + test("handle errors in non-head suspensions") { + type F[A] = FreeT[Id, Option, A] + val F = MonadError[F, Unit] + + val eff = F.flatMap(F.pure(()))(_ => F.raiseError[String](())) + F.attempt(eff).runM(Some(_)) should ===(Some(Left(()))) + } + sealed trait Test1Algebra[A] case class Test1[A](value: Int, f: Int => A) extends Test1Algebra[A] diff --git a/laws/src/main/scala/cats/laws/ApplicativeErrorLaws.scala b/laws/src/main/scala/cats/laws/ApplicativeErrorLaws.scala index 2b6c079a82..dd061468ee 100644 --- a/laws/src/main/scala/cats/laws/ApplicativeErrorLaws.scala +++ b/laws/src/main/scala/cats/laws/ApplicativeErrorLaws.scala @@ -54,6 +54,21 @@ trait ApplicativeErrorLaws[F[_], E] extends ApplicativeLaws[F] { def redeemDerivedFromAttemptMap[A, B](fa: F[A], fe: E => B, fs: A => B): IsEq[F[B]] = F.redeem(fa)(fe, fs) <-> F.map(F.attempt(fa))(_.fold(fe, fs)) + + /* + * These laws, taken together with applicativeErrorHandle, show that errors dominate in + * ap, *and* show that handle has lexical semantics over ap. F.unit is used in both laws + * because we don't have another way of expressing "an F[_] which does *not* contain any + * errors". We could make these laws considerably stronger if such a thing were + * expressible. Specifically, what we're missing here is the ability to say that + * raiseError distributes over an *arbitrary* number of aps. + */ + + def raiseErrorDistributesOverApLeft[A](h: E => F[A], e: E) = + F.handleErrorWith(F.ap(F.raiseError[Unit => A](e))(F.unit))(h) <-> h(e) + + def raiseErrorDistributesOverApRight[A](h: E => F[A], e: E) = + F.handleErrorWith(F.ap(F.pure((a: A) => a))(F.raiseError[A](e)))(h) <-> h(e) } object ApplicativeErrorLaws { diff --git a/laws/src/main/scala/cats/laws/discipline/ApplicativeErrorTests.scala b/laws/src/main/scala/cats/laws/discipline/ApplicativeErrorTests.scala index 9e89895c9e..e53d3918fd 100644 --- a/laws/src/main/scala/cats/laws/discipline/ApplicativeErrorTests.scala +++ b/laws/src/main/scala/cats/laws/discipline/ApplicativeErrorTests.scala @@ -56,7 +56,13 @@ trait ApplicativeErrorTests[F[_], E] extends ApplicativeTests[F] { "applicativeError onError raise" -> forAll(laws.onErrorRaise[A] _), "applicativeError adaptError pure" -> forAll(laws.adaptErrorPure[A] _), "applicativeError adaptError raise" -> forAll(laws.adaptErrorRaise[A] _), - "applicativeError redeem is derived from attempt and map" -> forAll(laws.redeemDerivedFromAttemptMap[A, B] _) + "applicativeError redeem is derived from attempt and map" -> forAll(laws.redeemDerivedFromAttemptMap[A, B] _), + "applicativeError handleError . raiseError left-distributes over ap" -> forAll( + laws.raiseErrorDistributesOverApLeft[A] _ + ), + "applicativeError handleError . raiseError right-distributes over ap" -> forAll( + laws.raiseErrorDistributesOverApRight[A] _ + ) ) } }