From df3137fcb852d272ca9c1e65aede0bfb1ad93c1e Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Tue, 15 Nov 2016 12:22:26 -0800 Subject: [PATCH 1/4] Reduce stack depth in StateT --- core/src/main/scala/cats/data/StateT.scala | 42 +++++++++++++++------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/cats/data/StateT.scala b/core/src/main/scala/cats/data/StateT.scala index 65150e3a2a..f96f69479d 100644 --- a/core/src/main/scala/cats/data/StateT.scala +++ b/core/src/main/scala/cats/data/StateT.scala @@ -12,25 +12,32 @@ import cats.syntax.either._ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable { def flatMap[B](fas: A => StateT[F, S, B])(implicit F: Monad[F]): StateT[F, S, B] = - StateT(s => - F.flatMap(runF) { fsf => - F.flatMap(fsf(s)) { case (s, a) => + StateT.applyF(F.map(runF) { sfsa => + sfsa.andThen { fsa => + F.flatMap(fsa) { case (s, a) => fas(a).run(s) } - }) + } + }) def flatMapF[B](faf: A => F[B])(implicit F: Monad[F]): StateT[F, S, B] = - StateT(s => - F.flatMap(runF) { fsf => - F.flatMap(fsf(s)) { case (s, a) => - F.map(faf(a))((s, _)) - } + StateT.applyF(F.map(runF) { sfsa => + sfsa.andThen { fsa => + F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) } } - ) + }) def map[B](f: A => B)(implicit F: Monad[F]): StateT[F, S, B] = transform { case (s, a) => (s, f(a)) } + def product[B](sb: StateT[F, S, B])(implicit F: Monad[F]): StateT[F, S, (A, B)] = + StateT.applyF(F.map2(runF, sb.runF) { (ssa, ssb) => + ssa.andThen { fsa => + F.flatMap(fsa) { case (s, a) => + F.map(ssb(s)) { case (s, b) => (s, (a, b)) } + } + } + }) /** * Run with the provided initial state value */ @@ -70,9 +77,12 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable * Like [[map]], but also allows the state (`S`) value to be modified. */ def transform[B](f: (S, A) => (S, B))(implicit F: Monad[F]): StateT[F, S, B] = - transformF { fsa => - F.map(fsa){ case (s, a) => f(s, a) } - } + StateT.applyF( + F.map(runF) { sfsa => + sfsa.andThen { fsa => + F.map(fsa) { case (s, a) => f(s, a) } + } + }) /** * Like [[transform]], but allows the context to change from `F` to `G`. @@ -231,6 +241,12 @@ private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] { override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f) + override def ap[A, B](ff: StateT[F, S, A => B])(fa: StateT[F, S, A]): StateT[F, S, B] = + map2(ff, fa) { case (f, a) => f(a) } + + override def product[A, B](fa: StateT[F, S, A], fb: StateT[F, S, B]): StateT[F, S, (A, B)] = + fa.product(fb) + def tailRecM[A, B](a: A)(f: A => StateT[F, S, Either[A, B]]): StateT[F, S, B] = StateT[F, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) { case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) } From 9bbbf5da82e0d4ff55e8f7edcef71915e1458f5c Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Tue, 15 Nov 2016 18:18:08 -0800 Subject: [PATCH 2/4] only use FlatMap for product --- core/src/main/scala/cats/data/StateT.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/cats/data/StateT.scala b/core/src/main/scala/cats/data/StateT.scala index f96f69479d..2e4fc875a7 100644 --- a/core/src/main/scala/cats/data/StateT.scala +++ b/core/src/main/scala/cats/data/StateT.scala @@ -30,7 +30,7 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable def map[B](f: A => B)(implicit F: Monad[F]): StateT[F, S, B] = transform { case (s, a) => (s, f(a)) } - def product[B](sb: StateT[F, S, B])(implicit F: Monad[F]): StateT[F, S, (A, B)] = + def product[B](sb: StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, (A, B)] = StateT.applyF(F.map2(runF, sb.runF) { (ssa, ssb) => ssa.andThen { fsa => F.flatMap(fsa) { case (s, a) => From 595be3140a28f479d4a2c5afdf81ffe1abfc0970 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Wed, 16 Nov 2016 11:36:54 -0800 Subject: [PATCH 3/4] Lower required typeclasses --- core/src/main/scala/cats/data/StateT.scala | 39 ++++++++++++------- .../test/scala/cats/tests/StateTTests.scala | 24 ++++++++++-- 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/cats/data/StateT.scala b/core/src/main/scala/cats/data/StateT.scala index 2e4fc875a7..593a1530f4 100644 --- a/core/src/main/scala/cats/data/StateT.scala +++ b/core/src/main/scala/cats/data/StateT.scala @@ -11,7 +11,7 @@ import cats.syntax.either._ */ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable { - def flatMap[B](fas: A => StateT[F, S, B])(implicit F: Monad[F]): StateT[F, S, B] = + def flatMap[B](fas: A => StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, B] = StateT.applyF(F.map(runF) { sfsa => sfsa.andThen { fsa => F.flatMap(fsa) { case (s, a) => @@ -20,14 +20,14 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable } }) - def flatMapF[B](faf: A => F[B])(implicit F: Monad[F]): StateT[F, S, B] = + def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): StateT[F, S, B] = StateT.applyF(F.map(runF) { sfsa => sfsa.andThen { fsa => F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) } } }) - def map[B](f: A => B)(implicit F: Monad[F]): StateT[F, S, B] = + def map[B](f: A => B)(implicit F: Functor[F]): StateT[F, S, B] = transform { case (s, a) => (s, f(a)) } def product[B](sb: StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, (A, B)] = @@ -76,7 +76,7 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable /** * Like [[map]], but also allows the state (`S`) value to be modified. */ - def transform[B](f: (S, A) => (S, B))(implicit F: Monad[F]): StateT[F, S, B] = + def transform[B](f: (S, A) => (S, B))(implicit F: Functor[F]): StateT[F, S, B] = StateT.applyF( F.map(runF) { sfsa => sfsa.andThen { fsa => @@ -108,31 +108,31 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable * res1: Option[(GlobalEnv, Double)] = Some(((6,hello),5.0)) * }}} */ - def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Monad[F]): StateT[F, R, A] = - StateT { r => - F.flatMap(runF) { ff => + def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Functor[F]): StateT[F, R, A] = + StateT.applyF(F.map(runF) { sfsa => + { r: R => val s = f(r) - val nextState = ff(s) - F.map(nextState) { case (s, a) => (g(r, s), a) } + val fsa = sfsa(s) + F.map(fsa) { case (s, a) => (g(r, s), a) } } - } + }) /** * Modify the state (`S`) component. */ - def modify(f: S => S)(implicit F: Monad[F]): StateT[F, S, A] = + def modify(f: S => S)(implicit F: Functor[F]): StateT[F, S, A] = transform((s, a) => (f(s), a)) /** * Inspect a value from the input state, without modifying the state. */ - def inspect[B](f: S => B)(implicit F: Monad[F]): StateT[F, S, B] = + def inspect[B](f: S => B)(implicit F: Functor[F]): StateT[F, S, B] = transform((s, _) => (s, f(s))) /** * Get the input state, without modifying the state. */ - def get(implicit F: Monad[F]): StateT[F, S, S] = + def get(implicit F: Functor[F]): StateT[F, S, S] = inspect(identity) } @@ -192,11 +192,16 @@ private[data] sealed trait StateTInstances2 extends StateTInstances3 { new StateTSemigroupK[F, S] { implicit def F = F0; implicit def G = G0 } } -private[data] sealed trait StateTInstances3 { +private[data] sealed trait StateTInstances3 extends StateTInstances4 { implicit def catsDataMonadForStateT[F[_], S](implicit F0: Monad[F]): Monad[StateT[F, S, ?]] = new StateTMonad[F, S] { implicit def F = F0 } } +private[data] sealed trait StateTInstances4 { + implicit def catsDataFunctorForStateT[F[_], S](implicit F0: Functor[F]): Functor[StateT[F, S, ?]] = + new StateTFunctor[F, S] { implicit def F = F0 } +} + // To workaround SI-7139 `object State` needs to be defined inside the package object // together with the type alias. private[data] abstract class StateFunctions { @@ -230,6 +235,12 @@ private[data] abstract class StateFunctions { def set[S](s: S): State[S, Unit] = State(_ => (s, ())) } +private[data] sealed trait StateTFunctor[F[_], S] extends Functor[StateT[F, S, ?]] { + implicit def F: Functor[F] + + def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f) +} + private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] { implicit def F: Monad[F] diff --git a/tests/src/test/scala/cats/tests/StateTTests.scala b/tests/src/test/scala/cats/tests/StateTTests.scala index 8d1e876a92..28bb8f2403 100644 --- a/tests/src/test/scala/cats/tests/StateTTests.scala +++ b/tests/src/test/scala/cats/tests/StateTTests.scala @@ -30,7 +30,7 @@ class StateTTests extends CatsSuite { } test("State.get and StateT.get are consistent") { - forAll{ (s: String) => + forAll{ (s: String) => val state: State[String, String] = State.get val stateT: State[String, String] = StateT.get state.run(s) should === (stateT.run(s)) @@ -195,7 +195,25 @@ class StateTTests extends CatsSuite { } - implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[ListWrapper, Int, ?]](StateT.catsDataMonadForStateT(ListWrapper.monad)) + implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[ListWrapper, Int, ?]](StateT.catsDataFunctorForStateT(ListWrapper.monad)) + + { + // F has a Functor + implicit val F: Functor[ListWrapper] = ListWrapper.monad + // We only need a Functor on F to find a Functor on StateT + Functor[StateT[ListWrapper, Int, ?]] + } + + { + // F needs a Monad to do Eq on StateT + implicit val F: Monad[ListWrapper] = ListWrapper.monad + implicit val FS: Functor[StateT[ListWrapper, Int, ?]] = StateT.catsDataFunctorForStateT + + checkAll("StateT[ListWrapper, Int, Int]", FunctorTests[StateT[ListWrapper, Int, ?]].functor[Int, Int, Int]) + checkAll("Functor[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(Functor[StateT[ListWrapper, Int, ?]])) + + Functor[StateT[ListWrapper, Int, ?]] + } { // F has a Monad @@ -265,7 +283,7 @@ class StateTTests extends CatsSuite { // F has a MonadError implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[Option, Int, ?]] implicit val eqEitherTFA: Eq[EitherT[StateT[Option, Int , ?], Unit, Int]] = EitherT.catsDataEqForEitherT[StateT[Option, Int , ?], Unit, Int] - + checkAll("StateT[Option, Int, Int]", MonadErrorTests[StateT[Option, Int , ?], Unit].monadError[Int, Int, Int]) checkAll("MonadError[StateT[Option, Int , ?], Unit]", SerializableTests.serializable(MonadError[StateT[Option, Int , ?], Unit])) } From d5e371f36fb659e2c392890d8e8dc5d9d69ce30b Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Tue, 22 Nov 2016 10:32:39 -1000 Subject: [PATCH 4/4] Add more methods --- core/src/main/scala/cats/data/StateT.scala | 27 ++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/cats/data/StateT.scala b/core/src/main/scala/cats/data/StateT.scala index 593a1530f4..4f8b6f8b0c 100644 --- a/core/src/main/scala/cats/data/StateT.scala +++ b/core/src/main/scala/cats/data/StateT.scala @@ -30,14 +30,27 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable def map[B](f: A => B)(implicit F: Functor[F]): StateT[F, S, B] = transform { case (s, a) => (s, f(a)) } - def product[B](sb: StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, (A, B)] = + def map2[B, Z](sb: StateT[F, S, B])(fn: (A, B) => Z)(implicit F: FlatMap[F]): StateT[F, S, Z] = StateT.applyF(F.map2(runF, sb.runF) { (ssa, ssb) => ssa.andThen { fsa => F.flatMap(fsa) { case (s, a) => - F.map(ssb(s)) { case (s, b) => (s, (a, b)) } + F.map(ssb(s)) { case (s, b) => (s, fn(a, b)) } } } }) + + def map2Eval[B, Z](sb: Eval[StateT[F, S, B]])(fn: (A, B) => Z)(implicit F: FlatMap[F]): Eval[StateT[F, S, Z]] = + F.map2Eval(runF, sb.map(_.runF)) { (ssa, ssb) => + ssa.andThen { fsa => + F.flatMap(fsa) { case (s, a) => + F.map(ssb(s)) { case (s, b) => (s, fn(a, b)) } + } + } + }.map(StateT.applyF) + + def product[B](sb: StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, (A, B)] = + map2(sb)((_, _)) + /** * Run with the provided initial state value */ @@ -250,10 +263,16 @@ private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] { def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] = fa.flatMap(f) + override def ap[A, B](ff: StateT[F, S, A => B])(fa: StateT[F, S, A]): StateT[F, S, B] = + ff.map2(fa) { case (f, a) => f(a) } + override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f) - override def ap[A, B](ff: StateT[F, S, A => B])(fa: StateT[F, S, A]): StateT[F, S, B] = - map2(ff, fa) { case (f, a) => f(a) } + override def map2[A, B, Z](fa: StateT[F, S, A], fb: StateT[F, S, B])(fn: (A, B) => Z): StateT[F, S, Z] = + fa.map2(fb)(fn) + + override def map2Eval[A, B, Z](fa: StateT[F, S, A], fb: Eval[StateT[F, S, B]])(fn: (A, B) => Z): Eval[StateT[F, S, Z]] = + fa.map2Eval(fb)(fn) override def product[A, B](fa: StateT[F, S, A], fb: StateT[F, S, B]): StateT[F, S, (A, B)] = fa.product(fb)