Skip to content

Commit

Permalink
Reimplemented MonadError[FreeT[...]] to be correct (#3241)
Browse files Browse the repository at this point in the history
* Reimplemented MonadError[FreeT[...]] to be correct

* Completely forgot to apply scalafmt...

* Binary compatibility silliness

* Fixed typo in comment

* Added ApplicativeError distributivity laws

* Added Defer[FreeT[...]] and used it to ensure handleErrorWith is stack-safe (in and of itself)

* Applied scalafmt
  • Loading branch information
djspiewak authored and LukaJCB committed Jan 13, 2020
1 parent bf14649 commit 2023db1
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 2 deletions.
67 changes: 66 additions & 1 deletion free/src/main/scala/cats/free/FreeT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,79 @@ object FreeT extends FreeTInstances {
}

sealed abstract private[free] class FreeTInstances extends FreeTInstances0 {
implicit def catsFreeMonadErrorForFreeT[S[_], M[_], E](implicit E: MonadError[M, E]): MonadError[FreeT[S, M, *], E] =

// retained for binary compatibility. its results are incorrect though and it would fail the laws if we generated things of the form pure(()).flatMap(_ => fa)
@deprecated("does not handle errors beyond the head suspension; use catsFreeMonadErrorForFreeT2", "2.1.0")
def catsFreeMonadErrorForFreeT[S[_], M[_], E](
implicit E: MonadError[M, E]
): MonadError[FreeT[S, M, *], E] =
new MonadError[FreeT[S, M, *], E] with FreeTMonad[S, M] {
override def M = E
override def handleErrorWith[A](fa: FreeT[S, M, A])(f: E => FreeT[S, M, A]) =
FreeT.liftT[S, M, FreeT[S, M, A]](E.handleErrorWith(fa.toM)(f.andThen(_.toM)))(M).flatMap(identity)
override def raiseError[A](e: E) =
FreeT.liftT(E.raiseError[A](e))(M)
}

// not to be confused with defer... which is something different... sigh...
implicit def catsDeferForFreeT[S[_], M[_]: Applicative]: Defer[FreeT[S, M, *]] =
new Defer[FreeT[S, M, *]] {
def defer[A](fa: => FreeT[S, M, A]): FreeT[S, M, A] =
FreeT.pure[S, M, Unit](()).flatMap(_ => fa)
}

implicit def catsFreeMonadErrorForFreeT2[S[_], M[_], E](implicit E: MonadError[M, E],
S: Functor[S]): MonadError[FreeT[S, M, *], E] =
new MonadError[FreeT[S, M, *], E] with FreeTMonad[S, M] {
override def M = E

private[this] val RealDefer = catsDeferForFreeT[S, M]

/*
* Quick explanation... The previous version of this function (retained above for
* bincompat) was only able to look at the *top* level M[_] suspension in a Free
* program. Any suspensions below that in the compute tree were invisible. Thus,
* if there were errors in that top level suspension, then they would be handled
* by the delegate. Errors buried further in the tree were unhandled. This is most
* easily visualized by the following two expressions:
*
* - fa
* - pure(()).flatMap(_ => fa)
*
* By the monad laws, these *should* be identical in effect, but they do have
* different structural representations within FreeT. More specifically, the latter
* has a meaningless M[_] suspension which sits in front of the rest of the
* computation. The previous iteration of this function would be blind to fa in
* the latter encoding, while it would handle it correctly in the former.
*
* Historical sidebar: this became visible because of the "shift" mechanism in
* Kleisli.
*/
override def handleErrorWith[A](fa: FreeT[S, M, A])(f: E => FreeT[S, M, A]) = {
val ft = FreeT.liftT[S, M, FreeT[S, M, A]] {
val resultsM = E.map(fa.resume) {
case Left(se) =>
// we defer here in order to ensure stack-safety in the results even when M[_] is not itself stack-safe
// there's some small performance loss as a consequence, but really, if you care that much about performance, why are you using FreeT?
RealDefer.defer(FreeT.liftF[S, M, FreeT[S, M, A]](S.map(se)(handleErrorWith(_)(f))).flatMap(identity))

case Right(a) =>
pure(a)
}

E.handleErrorWith(resultsM) { e =>
E.map(f(e).resume) { eth =>
FreeT.defer(E.pure(eth.swap)) // why on earth is defer inconsistent with resume??
}
}
}

ft.flatMap(identity)
}

override def raiseError[A](e: E) =
FreeT.liftT(E.raiseError[A](e))(M)
}
}

sealed abstract private[free] class FreeTInstances0 extends FreeTInstances1 {
Expand Down
11 changes: 11 additions & 0 deletions free/src/test/scala/cats/free/FreeTSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class FreeTSuite extends CatsSuite {
SerializableTests.serializable(MonadError[FreeTOption, Unit]))
}

checkAll("FreeT[Option, Option, Int", DeferTests[FreeTOption].defer[Int])

test("FlatMap stack safety tested with 50k flatMaps") {
val expected = Applicative[FreeTOption].pure(())
val result =
Expand Down Expand Up @@ -115,6 +117,15 @@ class FreeTSuite extends CatsSuite {
}
}

// NB: this does not analogously cause problems for the SemigroupK implementation as semigroup's effects associate while errors do not
test("handle errors in non-head suspensions") {
type F[A] = FreeT[Id, Option, A]
val F = MonadError[F, Unit]

val eff = F.flatMap(F.pure(()))(_ => F.raiseError[String](()))
F.attempt(eff).runM(Some(_)) should ===(Some(Left(())))
}

sealed trait Test1Algebra[A]

case class Test1[A](value: Int, f: Int => A) extends Test1Algebra[A]
Expand Down
15 changes: 15 additions & 0 deletions laws/src/main/scala/cats/laws/ApplicativeErrorLaws.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ trait ApplicativeErrorLaws[F[_], E] extends ApplicativeLaws[F] {

def redeemDerivedFromAttemptMap[A, B](fa: F[A], fe: E => B, fs: A => B): IsEq[F[B]] =
F.redeem(fa)(fe, fs) <-> F.map(F.attempt(fa))(_.fold(fe, fs))

/*
* These laws, taken together with applicativeErrorHandle, show that errors dominate in
* ap, *and* show that handle has lexical semantics over ap. F.unit is used in both laws
* because we don't have another way of expressing "an F[_] which does *not* contain any
* errors". We could make these laws considerably stronger if such a thing were
* expressible. Specifically, what we're missing here is the ability to say that
* raiseError distributes over an *arbitrary* number of aps.
*/

def raiseErrorDistributesOverApLeft[A](h: E => F[A], e: E) =
F.handleErrorWith(F.ap(F.raiseError[Unit => A](e))(F.unit))(h) <-> h(e)

def raiseErrorDistributesOverApRight[A](h: E => F[A], e: E) =
F.handleErrorWith(F.ap(F.pure((a: A) => a))(F.raiseError[A](e)))(h) <-> h(e)
}

object ApplicativeErrorLaws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ trait ApplicativeErrorTests[F[_], E] extends ApplicativeTests[F] {
"applicativeError onError raise" -> forAll(laws.onErrorRaise[A] _),
"applicativeError adaptError pure" -> forAll(laws.adaptErrorPure[A] _),
"applicativeError adaptError raise" -> forAll(laws.adaptErrorRaise[A] _),
"applicativeError redeem is derived from attempt and map" -> forAll(laws.redeemDerivedFromAttemptMap[A, B] _)
"applicativeError redeem is derived from attempt and map" -> forAll(laws.redeemDerivedFromAttemptMap[A, B] _),
"applicativeError handleError . raiseError left-distributes over ap" -> forAll(
laws.raiseErrorDistributesOverApLeft[A] _
),
"applicativeError handleError . raiseError right-distributes over ap" -> forAll(
laws.raiseErrorDistributesOverApRight[A] _
)
)
}
}
Expand Down

0 comments on commit 2023db1

Please sign in to comment.