diff --git a/core/src/main/scala/cats/Traverse.scala b/core/src/main/scala/cats/Traverse.scala index 4b6f86f534..547a01079a 100644 --- a/core/src/main/scala/cats/Traverse.scala +++ b/core/src/main/scala/cats/Traverse.scala @@ -130,4 +130,38 @@ import simulacrum.typeclass */ def zipWithIndex[A](fa: F[A]): F[(A, Int)] = mapWithIndex(fa)((a, i) => (a, i)) + + /** + * Statefully traverse through the structure F while carrying and updating a + * value of type S. + * + * Value `s0` is used as the initial state and function `fs` is used + * to update the state after each application of `f` through the structure. + * + * Function `fs` can update the state using the current traversal value, the + * previous state, and the output of the current application of `f`. + * + * Both the final state value and the traversed result are returned. + * + * The parameters are grouped in this order for better type inferencing. + * + * @see [[traverseWithStateMA]] if the final state value isn't needed. + */ + def traverseWithStateM[G[_], A, S, B](fa: F[A])(s0: S)(f: (A, S) => G[B])(fs: (A, S, B) => S)(implicit G: Monad[G]): G[(S, F[B])] = + internalTraverseWithStateM(fa)(s0)(f)(fs).run(s0) + + /** + * Statefully traverse through the structure F while carrying and updating a + * value of type S. + * + * This is identical to [[traverseWithStateM]] except it does not + * include the final state value in the result. + */ + def traverseWithStateMA[G[_], A, S, B](fa: F[A])(s0: S)(f: (A, S) => G[B])(fs: (A, S, B) => S)(implicit G: Monad[G]): G[F[B]] = + internalTraverseWithStateM(fa)(s0)(f)(fs).runA(s0) + + private[this] def internalTraverseWithStateM[G[_], A, S, B]( + fa: F[A])(s0: S)(f: (A, S) => G[B])(fs: (A, S, B) => S)(implicit G: Monad[G] + ): StateT[G, S, F[B]] = + traverse(fa)(a => StateT[G, S, B](s => G.map(f(a, s))(b => (fs(a, s, b), b)))) } diff --git a/tests/src/test/scala/cats/tests/TraverseTests.scala b/tests/src/test/scala/cats/tests/TraverseTests.scala index de48c339a1..c64f7d6961 100644 --- a/tests/src/test/scala/cats/tests/TraverseTests.scala +++ b/tests/src/test/scala/cats/tests/TraverseTests.scala @@ -28,6 +28,25 @@ abstract class TraverseCheck[F[_]: Traverse](name: String)(implicit ArbFInt: Arb } } + test(s"Traverse[$name].traverseWithStateM") { + forAll { (fa: F[Int]) => + val left = fa.traverseWithStateM(Set.empty[Int])( + (a, s) => if (s.contains(a)) Eval.now("duplicate") else Eval.later(a.toString))( + (a, s, _) => s + a) + val fal = fa.toList + left.value.map(_.toList.filterNot(_ == "duplicate")) should === (fal.toSet -> fal.distinct.map(_.toString)) + } + } + + test(s"Traverse[$name].traverseWithStateMA") { + forAll { (fa: F[Int]) => + val left = fa.traverseWithStateMA(Set.empty[Int])( + (a, s) => if (s.contains(a)) Eval.now("duplicate") else Eval.later(a.toString))( + (a, s, _) => s + a) + left.value.toList.filterNot(_ == "duplicate") should === (fa.toList.distinct.map(_.toString)) + } + } + } object TraverseCheck {