Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add iterateWhileM and iterateUntilM #1809

Merged
merged 4 commits into from
Aug 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions core/src/main/scala/cats/Monad.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package cats

import simulacrum.typeclass
import syntax.either._

/**
* Monad.
*
Expand Down Expand Up @@ -30,7 +30,7 @@ import syntax.either._
Left(G.combineK(xs, G.pure(bv)))
}
},
ifFalse = pure(xs.asRight[G[A]])
ifFalse = pure(Right(xs))
))
}

Expand Down Expand Up @@ -76,28 +76,37 @@ import syntax.either._
* Execute an action repeatedly until its result fails to satisfy the given predicate
* and return that result, discarding all others.
*/
def iterateWhile[A](f: F[A])(p: A => Boolean): F[A] = {
def iterateWhile[A](f: F[A])(p: A => Boolean): F[A] =
flatMap(f) { i =>
tailRecM(i) { a =>
if (p(a))
map(f)(_.asLeft[A])
else pure(a.asRight[A])
}
iterateWhileM(i)(_ => f)(p)
}
}

/**
* Execute an action repeatedly until its result satisfies the given predicate
* and return that result, discarding all others.
*/
def iterateUntil[A](f: F[A])(p: A => Boolean): F[A] = {
def iterateUntil[A](f: F[A])(p: A => Boolean): F[A] =
flatMap(f) { i =>
tailRecM(i) { a =>
if (p(a))
pure(a.asRight[A])
else map(f)(_.asLeft[A])
}
iterateUntilM(i)(_ => f)(p)
}
}

/**
* Apply a monadic function iteratively until its result fails
* to satisfy the given predicate and return that result.
*/
def iterateWhileM[A](init: A)(f: A => F[A])(p: A => Boolean): F[A] =
tailRecM(init) { a =>
if (p(a))
map(f(a))(Left(_))
else
pure(Right(a))
}

/**
* Apply a monadic function iteratively until its result satisfies
* the given predicate and return that result.
*/
def iterateUntilM[A](init: A)(f: A => F[A])(p: A => Boolean): F[A] =
iterateWhileM(init)(f)(!p(_))

}
16 changes: 16 additions & 0 deletions core/src/main/scala/cats/syntax/monad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package syntax

trait MonadSyntax {
implicit final def catsSyntaxMonad[F[_], A](fa: F[A]): MonadOps[F, A] = new MonadOps(fa)

implicit final def catsSyntaxMonadIdOps[A](a: A): MonadIdOps[A] =
new MonadIdOps[A](a)
}

final class MonadOps[F[_], A](val fa: F[A]) extends AnyVal {
Expand All @@ -13,3 +16,16 @@ final class MonadOps[F[_], A](val fa: F[A]) extends AnyVal {
def iterateWhile(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateWhile(fa)(p)
def iterateUntil(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntil(fa)(p)
}

final class MonadIdOps[A](val a: A) extends AnyVal {

/**
* Iterative application of `f` while `p` holds.
*/
def iterateWhileM[F[_]](f: A => F[A])(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateWhileM(a)(f)(p)

/**
* Iterative application of `f` until `p` holds.
*/
def iterateUntilM[F[_]](f: A => F[A])(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntilM(a)(f)(p)
}
38 changes: 32 additions & 6 deletions tests/src/test/scala/cats/tests/MonadTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,35 @@ import org.scalacheck.Gen
class MonadTest extends CatsSuite {
implicit val testInstance: Monad[StateT[Id, Int, ?]] = StateT.catsDataMonadForStateT[Id, Int]

val smallPosInt = Gen.choose(1, 5000)

val increment: StateT[Id, Int, Unit] = StateT.modify(_ + 1)
val incrementAndGet: StateT[Id, Int, Int] = increment >> StateT.get

test("whileM_") {
forAll(Gen.posNum[Int]) { (max: Int) =>
forAll(smallPosInt) { (max: Int) =>
val (result, _) = increment.whileM_(StateT.inspect(i => !(i >= max))).run(0)
result should ===(Math.max(0, max))
}
}

test("whileM") {
forAll(Gen.posNum[Int]) { (max: Int) =>
forAll(smallPosInt) { (max: Int) =>
val (result, aggregation) = incrementAndGet.whileM[Vector](StateT.inspect(i => !(i >= max))).run(0)
result should ===(Math.max(0, max))
aggregation should === ( if(max > 0) (1 to max).toVector else Vector.empty )
}
}

test("untilM_") {
forAll(Gen.posNum[Int]) { (max: Int) =>
forAll(smallPosInt) { (max: Int) =>
val (result, _) = increment.untilM_(StateT.inspect(_ >= max)).run(-1)
result should ===(max)
}
}

test("untilM") {
forAll(Gen.posNum[Int]) { (max: Int) =>
forAll(smallPosInt) { (max: Int) =>
val (result, aggregation) = incrementAndGet.untilM[Vector](StateT.inspect(_ >= max)).run(-1)
result should ===(max)
aggregation should === ((0 to max).toVector)
Expand All @@ -51,7 +53,7 @@ class MonadTest extends CatsSuite {
}

test("iterateWhile") {
forAll(Gen.posNum[Int]) { (max: Int) =>
forAll(smallPosInt) { (max: Int) =>
val (result, _) = incrementAndGet.iterateWhile(_ < max).run(-1)
result should ===(Math.max(0, max))
}
Expand All @@ -63,7 +65,7 @@ class MonadTest extends CatsSuite {
}

test("iterateUntil") {
forAll(Gen.posNum[Int]) { (max: Int) =>
forAll(smallPosInt) { (max: Int) =>
val (result, _) = incrementAndGet.iterateUntil(_ == max).run(-1)
result should ===(Math.max(0, max))
}
Expand All @@ -74,4 +76,28 @@ class MonadTest extends CatsSuite {
result should ===(50000)
}

test("iterateWhileM") {
forAll(smallPosInt) { (max: Int) =>
val (n, sum) = 0.iterateWhileM(s => incrementAndGet map (_ + s))(_ < max).run(0)
sum should ===(n * (n + 1) / 2)
}
}

test("iterateWhileM is stack safe") {
val (n, sum) = 0.iterateWhileM(s => incrementAndGet map (_ + s))(_ < 50000000).run(0)
sum should ===(n * (n + 1) / 2)
}

test("iterateUntilM") {
forAll(smallPosInt) { (max: Int) =>
val (n, sum) = 0.iterateUntilM(s => incrementAndGet map (_ + s))(_ > max).run(0)
sum should ===(n * (n + 1) / 2)
}
}

test("iterateUntilM is stack safe") {
val (n, sum) = 0.iterateUntilM(s => incrementAndGet map (_ + s))(_ > 50000000).run(0)
sum should ===(n * (n + 1) / 2)
}

}