Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce stack depth in StateT #1466

Merged
merged 4 commits into from
Dec 31, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 73 additions & 27 deletions core/src/main/scala/cats/data/StateT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,46 @@ import cats.syntax.either._
*/
final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable {

def flatMap[B](fas: A => StateT[F, S, B])(implicit F: Monad[F]): StateT[F, S, B] =
StateT(s =>
F.flatMap(runF) { fsf =>
F.flatMap(fsf(s)) { case (s, a) =>
def flatMap[B](fas: A => StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, B] =
StateT.applyF(F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
F.flatMap(fsa) { case (s, a) =>
fas(a).run(s)
}
})
}
})

def flatMapF[B](faf: A => F[B])(implicit F: Monad[F]): StateT[F, S, B] =
StateT(s =>
F.flatMap(runF) { fsf =>
F.flatMap(fsf(s)) { case (s, a) =>
F.map(faf(a))((s, _))
}
def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): StateT[F, S, B] =
StateT.applyF(F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) }
}
)
})

def map[B](f: A => B)(implicit F: Monad[F]): StateT[F, S, B] =
def map[B](f: A => B)(implicit F: Functor[F]): StateT[F, S, B] =
transform { case (s, a) => (s, f(a)) }

def map2[B, Z](sb: StateT[F, S, B])(fn: (A, B) => Z)(implicit F: FlatMap[F]): StateT[F, S, Z] =
StateT.applyF(F.map2(runF, sb.runF) { (ssa, ssb) =>
ssa.andThen { fsa =>
F.flatMap(fsa) { case (s, a) =>
F.map(ssb(s)) { case (s, b) => (s, fn(a, b)) }
}
}
})

def map2Eval[B, Z](sb: Eval[StateT[F, S, B]])(fn: (A, B) => Z)(implicit F: FlatMap[F]): Eval[StateT[F, S, Z]] =
F.map2Eval(runF, sb.map(_.runF)) { (ssa, ssb) =>
ssa.andThen { fsa =>
F.flatMap(fsa) { case (s, a) =>
F.map(ssb(s)) { case (s, b) => (s, fn(a, b)) }
}
}
}.map(StateT.applyF)

def product[B](sb: StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, (A, B)] =
map2(sb)((_, _))

/**
* Run with the provided initial state value
*/
Expand Down Expand Up @@ -69,10 +89,13 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable
/**
* Like [[map]], but also allows the state (`S`) value to be modified.
*/
def transform[B](f: (S, A) => (S, B))(implicit F: Monad[F]): StateT[F, S, B] =
transformF { fsa =>
F.map(fsa){ case (s, a) => f(s, a) }
}
def transform[B](f: (S, A) => (S, B))(implicit F: Functor[F]): StateT[F, S, B] =
StateT.applyF(
F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
F.map(fsa) { case (s, a) => f(s, a) }
}
})

/**
* Like [[transform]], but allows the context to change from `F` to `G`.
Expand All @@ -98,31 +121,31 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable
* res1: Option[(GlobalEnv, Double)] = Some(((6,hello),5.0))
* }}}
*/
def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Monad[F]): StateT[F, R, A] =
StateT { r =>
F.flatMap(runF) { ff =>
def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Functor[F]): StateT[F, R, A] =
StateT.applyF(F.map(runF) { sfsa =>
{ r: R =>
val s = f(r)
val nextState = ff(s)
F.map(nextState) { case (s, a) => (g(r, s), a) }
val fsa = sfsa(s)
F.map(fsa) { case (s, a) => (g(r, s), a) }
}
}
})

/**
* Modify the state (`S`) component.
*/
def modify(f: S => S)(implicit F: Monad[F]): StateT[F, S, A] =
def modify(f: S => S)(implicit F: Functor[F]): StateT[F, S, A] =
transform((s, a) => (f(s), a))

/**
* Inspect a value from the input state, without modifying the state.
*/
def inspect[B](f: S => B)(implicit F: Monad[F]): StateT[F, S, B] =
def inspect[B](f: S => B)(implicit F: Functor[F]): StateT[F, S, B] =
transform((s, _) => (s, f(s)))

/**
* Get the input state, without modifying the state.
*/
def get(implicit F: Monad[F]): StateT[F, S, S] =
def get(implicit F: Functor[F]): StateT[F, S, S] =
inspect(identity)
}

Expand Down Expand Up @@ -182,11 +205,16 @@ private[data] sealed trait StateTInstances2 extends StateTInstances3 {
new StateTSemigroupK[F, S] { implicit def F = F0; implicit def G = G0 }
}

private[data] sealed trait StateTInstances3 {
private[data] sealed trait StateTInstances3 extends StateTInstances4 {
implicit def catsDataMonadForStateT[F[_], S](implicit F0: Monad[F]): Monad[StateT[F, S, ?]] =
new StateTMonad[F, S] { implicit def F = F0 }
}

private[data] sealed trait StateTInstances4 {
implicit def catsDataFunctorForStateT[F[_], S](implicit F0: Functor[F]): Functor[StateT[F, S, ?]] =
new StateTFunctor[F, S] { implicit def F = F0 }
}

// To workaround SI-7139 `object State` needs to be defined inside the package object
// together with the type alias.
private[data] abstract class StateFunctions {
Expand Down Expand Up @@ -220,6 +248,12 @@ private[data] abstract class StateFunctions {
def set[S](s: S): State[S, Unit] = State(_ => (s, ()))
}

private[data] sealed trait StateTFunctor[F[_], S] extends Functor[StateT[F, S, ?]] {
implicit def F: Functor[F]

def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f)
}

private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] {
implicit def F: Monad[F]

Expand All @@ -229,8 +263,20 @@ private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] {
def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] =
fa.flatMap(f)

override def ap[A, B](ff: StateT[F, S, A => B])(fa: StateT[F, S, A]): StateT[F, S, B] =
ff.map2(fa) { case (f, a) => f(a) }

override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f)

override def map2[A, B, Z](fa: StateT[F, S, A], fb: StateT[F, S, B])(fn: (A, B) => Z): StateT[F, S, Z] =
fa.map2(fb)(fn)

override def map2Eval[A, B, Z](fa: StateT[F, S, A], fb: Eval[StateT[F, S, B]])(fn: (A, B) => Z): Eval[StateT[F, S, Z]] =
fa.map2Eval(fb)(fn)

override def product[A, B](fa: StateT[F, S, A], fb: StateT[F, S, B]): StateT[F, S, (A, B)] =
fa.product(fb)

def tailRecM[A, B](a: A)(f: A => StateT[F, S, Either[A, B]]): StateT[F, S, B] =
StateT[F, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) {
case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) }
Expand Down
24 changes: 21 additions & 3 deletions tests/src/test/scala/cats/tests/StateTTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class StateTTests extends CatsSuite {
}

test("State.get and StateT.get are consistent") {
forAll{ (s: String) =>
forAll{ (s: String) =>
val state: State[String, String] = State.get
val stateT: State[String, String] = StateT.get
state.run(s) should === (stateT.run(s))
Expand Down Expand Up @@ -195,7 +195,25 @@ class StateTTests extends CatsSuite {
}


implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[ListWrapper, Int, ?]](StateT.catsDataMonadForStateT(ListWrapper.monad))
implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[ListWrapper, Int, ?]](StateT.catsDataFunctorForStateT(ListWrapper.monad))

{
// F has a Functor
implicit val F: Functor[ListWrapper] = ListWrapper.monad
// We only need a Functor on F to find a Functor on StateT
Functor[StateT[ListWrapper, Int, ?]]
}

{
// F needs a Monad to do Eq on StateT
implicit val F: Monad[ListWrapper] = ListWrapper.monad
implicit val FS: Functor[StateT[ListWrapper, Int, ?]] = StateT.catsDataFunctorForStateT

checkAll("StateT[ListWrapper, Int, Int]", FunctorTests[StateT[ListWrapper, Int, ?]].functor[Int, Int, Int])
checkAll("Functor[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(Functor[StateT[ListWrapper, Int, ?]]))

Functor[StateT[ListWrapper, Int, ?]]
}

{
// F has a Monad
Expand Down Expand Up @@ -265,7 +283,7 @@ class StateTTests extends CatsSuite {
// F has a MonadError
implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[Option, Int, ?]]
implicit val eqEitherTFA: Eq[EitherT[StateT[Option, Int , ?], Unit, Int]] = EitherT.catsDataEqForEitherT[StateT[Option, Int , ?], Unit, Int]

checkAll("StateT[Option, Int, Int]", MonadErrorTests[StateT[Option, Int , ?], Unit].monadError[Int, Int, Int])
checkAll("MonadError[StateT[Option, Int , ?], Unit]", SerializableTests.serializable(MonadError[StateT[Option, Int , ?], Unit]))
}
Expand Down