Skip to content

Commit

Permalink
Merge pull request #4234 from armanbilge/topic/tailrec-replicateA
Browse files Browse the repository at this point in the history
Tail-recursive `replicateA`
  • Loading branch information
danicheg authored Jun 13, 2022
2 parents 7233e8a + fd13841 commit 77fbe02
Showing 1 changed file with 26 additions and 20 deletions.
46 changes: 26 additions & 20 deletions core/src/main/scala/cats/Applicative.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ package cats
import cats.arrow.Arrow
import cats.data.Chain

import scala.annotation.tailrec

/**
* Applicative functor.
*
Expand Down Expand Up @@ -92,18 +94,18 @@ trait Applicative[F[_]] extends Apply[F] with InvariantMonoidal[F] { self =>
val one = map(fa)(Chain.one(_))

// invariant: n >= 1
def loop(n: Int): F[Chain[A]] =
if (n == 1) one
else {
@tailrec def loop(fa: F[Chain[A]], n: Int, acc: F[Chain[A]]): F[Chain[A]] =
if (n == 1) map2(fa, acc)(_.concat(_))
else
// n >= 2
// so (n >> 1) >= 1 and we are allowed to call loop
val half = loop(n >> 1)
val both = map2(half, half)(_.concat(_))
if ((n & 1) == 1) map2(one, both)(_.concat(_))
else both
}
loop(
map2(fa, fa)(_.concat(_)),
n >> 1,
if ((n & 1) == 1) map2(acc, fa)(_.concat(_)) else acc
)

map(loop(n))(_.toList)
map(loop(one, n - 1, one))(_.toList)
}

/**
Expand All @@ -123,19 +125,23 @@ trait Applicative[F[_]] extends Apply[F] with InvariantMonoidal[F] { self =>
*/
def replicateA_[A](n: Int, fa: F[A]): F[Unit] =
if (n <= 0) unit
else if (n == 1) void(fa)
else {
val fvoid = void(fa)
// invariant n >= 1
def loop(n: Int): F[Unit] =
if (n == 1) fvoid
else {
// since n >= 2, then (n >> 1) >= 1 so we can call loop
val half = loop(n >> 1)
val both = productR(half)(half)
if ((n & 1) == 1) productR(both)(fvoid)
else both
}
loop(n)

// invariant: n >= 1
@tailrec def loop(fa: F[Unit], n: Int, acc: F[Unit]): F[Unit] =
if (n == 1) productR(fa)(acc)
else
// n >= 2
// so (n >> 1) >= 1 and we are allowed to call loop
loop(
productR(fa)(fa),
n >> 1,
if ((n & 1) == 1) productR(acc)(fa) else acc
)

loop(fvoid, n - 1, fvoid)
}

/**
Expand Down

0 comments on commit 77fbe02

Please sign in to comment.