Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #2186: make IndexedStateT stack safe #2187

Merged
merged 10 commits into from
Mar 14, 2018
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions core/src/main/scala/cats/data/AndThen.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package cats.data

import java.io.Serializable

/**
* Internal API (Cats) — A type-aligned seq for representing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually seems like something to move to kernel since it is not higher kinded very useful.

I see it is private, but lots of use cases could exist for an optimized AndThen.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not very clear to me what the boundaries of Cats's sub-projects and packages are.

I can move it to cats-kernel if needed, however I only see a bunch of type classes in cats-kernel, whereas I am seeing data types that aren't type classes in cats.data.

* function composition in constant stack space with amortized
* linear time application (in the number of constituent functions).
*
* A variation of this implementation was first introduced in the
* `cats-effect` project. Implementation is enormously uglier than
* it should be since `@tailrec` doesn't work properly on functions
* with existential types.
*
* Example:
*
* {{{
* val seed = AndThen((x: Int) => x + 1))
* val f = (0 until 10000).foldLeft(seed)((acc, _) => acc.andThen(_ + 1))
* // This should not trigger stack overflow ;-)
* f(0)
* }}}
*/
private[cats] sealed abstract class AndThen[-T, +R]
extends (T => R) with Product with Serializable {

import AndThen._

final def apply(a: T): R =
runLoop(a)

override def andThen[A](g: R => A): T => A = {
// Fusing calls up to a certain threshold, using the fusion
// technique implemented for `cats.effect.IO#map`
this match {
case Single(f, index) if index != 127 =>
Single(f.andThen(g), index + 1)
case _ =>
andThenF(AndThen(g))
}
}

override def compose[A](g: A => T): A => R = {
// Fusing calls up to a certain threshold, using the fusion
// technique implemented for `cats.effect.IO#map`
this match {
case Single(f, index) if index != 127 =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, can we make this 127 (and at line 32) a constant with a sensible name?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion. I've updated the code with a documented constant.

Single(f.compose(g), index + 1)
case _ =>
composeF(AndThen(g))
}
}

private def runLoop(start: T): R = {
var self: AndThen[Any, Any] = this.asInstanceOf[AndThen[Any, Any]]
var current: Any = start.asInstanceOf[Any]
var continue = true

while (continue) {
self match {
case Single(f, _) =>
current = f(current)
continue = false

case Concat(Single(f, _), right) =>
current = f(current)
self = right.asInstanceOf[AndThen[Any, Any]]

case Concat(left @ Concat(_, _), right) =>
self = left.rotateAccum(right)
}
}
current.asInstanceOf[R]
}

private final def andThenF[X](right: AndThen[R, X]): AndThen[T, X] =
Concat(this, right)
private final def composeF[X](right: AndThen[X, T]): AndThen[X, R] =
Concat(right, this)

// converts left-leaning to right-leaning
protected final def rotateAccum[E](_right: AndThen[R, E]): AndThen[T, E] = {
var self: AndThen[Any, Any] = this.asInstanceOf[AndThen[Any, Any]]
var right: AndThen[Any, Any] = _right.asInstanceOf[AndThen[Any, Any]]
var continue = true
while (continue) {
self match {
case Concat(left, inner) =>
self = left.asInstanceOf[AndThen[Any, Any]]
right = inner.andThenF(right)

case _ => // Single
self = self.andThenF(right)
continue = false
}
}
self.asInstanceOf[AndThen[T, E]]
}

override def toString: String =
"AndThen$" + System.identityHashCode(this)
}

private[cats] object AndThen {
/** Builds simple [[AndThen]] reference by wrapping a function. */
def apply[A, B](f: A => B): AndThen[A, B] =
f match {
case ref: AndThen[A, B] @unchecked => ref
case _ => Single(f, 0)
}

/** Alias for `apply` that returns a `Function1` type. */
def of[A, B](f: A => B): (A => B) =
apply(f)

private final case class Single[-A, +B](f: A => B, index: Int)
extends AndThen[A, B]
private final case class Concat[-A, E, +B](left: AndThen[A, E], right: AndThen[E, B])
extends AndThen[A, B]
}
4 changes: 2 additions & 2 deletions core/src/main/scala/cats/data/IndexedStateT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ final class IndexedStateT[F[_], SA, SB, A](val runF: F[SA => F[(SB, A)]]) extend

def flatMap[B, SC](fas: A => IndexedStateT[F, SB, SC, B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SC, B] =
IndexedStateT.applyF(F.map(runF) { safsba =>
safsba.andThen { fsba =>
AndThen(safsba).andThen { fsba =>
F.flatMap(fsba) { case (sb, a) =>
fas(a).run(sb)
}
Expand All @@ -31,7 +31,7 @@ final class IndexedStateT[F[_], SA, SB, A](val runF: F[SA => F[(SB, A)]]) extend

def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SB, B] =
IndexedStateT.applyF(F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
AndThen(sfsa).andThen { fsa =>
F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) }
}
})
Expand Down
53 changes: 53 additions & 0 deletions tests/src/test/scala/cats/tests/AndThenSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package cats.tests

import catalysts.Platform
import cats.data._

class AndThenSuite extends CatsSuite {
test("compose a chain of functions with andThen") {
check { (i: Int, fs: List[Int => Int]) =>
val result = fs.map(AndThen.of(_)).reduceOption(_.andThen(_)).map(_(i))
val expect = fs.reduceOption(_.andThen(_)).map(_(i))

result == expect
}
}

test("compose a chain of functions with compose") {
check { (i: Int, fs: List[Int => Int]) =>
val result = fs.map(AndThen.of(_)).reduceOption(_.compose(_)).map(_(i))
val expect = fs.reduceOption(_.compose(_)).map(_(i))

result == expect
}
}

test("andThen is stack safe") {
val count = if (Platform.isJvm) 500000 else 1000
val fs = (0 until count).map(_ => { i: Int => i + 1 })
val result = fs.foldLeft(AndThen.of((x: Int) => x))(_.andThen(_))(42)

result shouldEqual (count + 42)
}

test("compose is stack safe") {
val count = if (Platform.isJvm) 500000 else 1000
val fs = (0 until count).map(_ => { i: Int => i + 1 })
val result = fs.foldLeft(AndThen.of((x: Int) => x))(_.compose(_))(42)

result shouldEqual (count + 42)
}

test("Function1 andThen is stack safe") {
val count = if (Platform.isJvm) 50000 else 1000
val start: (Int => Int) = AndThen((x: Int) => x)
val fs = (0 until count).foldLeft(start) { (acc, _) =>
acc.andThen(_ + 1)
}
fs(0) shouldEqual count
}

test("toString") {
AndThen((x: Int) => x).toString should startWith("AndThen$")
}
}
19 changes: 18 additions & 1 deletion tests/src/test/scala/cats/tests/IndexedStateTSuite.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package cats
package tests

import catalysts.Platform
import cats.arrow.{Profunctor, Strong}
import cats.data.{EitherT, IndexedStateT, State, StateT}

import cats.arrow.Profunctor
import cats.kernel.instances.tuple._
import cats.laws.discipline._
Expand Down Expand Up @@ -251,6 +251,23 @@ class IndexedStateTSuite extends CatsSuite {
got should === (expected)
}

test("flatMap is stack safe on repeated left binds when F is") {
val unit = StateT.pure[Eval, Unit, Unit](())
val count = if (Platform.isJvm) 100000 else 100
val result = (0 until count).foldLeft(unit) { (acc, _) =>
acc.flatMap(_ => unit)
}
result.run(()).value should === (((), ()))
}

test("flatMap is stack safe on repeated right binds when F is") {
val unit = StateT.pure[Eval, Unit, Unit](())
val count = if (Platform.isJvm) 100000 else 100
val result = (0 until count).foldLeft(unit) { (acc, _) =>
unit.flatMap(_ => acc)
}
result.run(()).value should === (((), ()))
}

implicit val iso = SemigroupalTests.Isomorphisms.invariant[IndexedStateT[ListWrapper, String, Int, ?]](IndexedStateT.catsDataFunctorForIndexedStateT(ListWrapper.monad))

Expand Down