Skip to content

Commit

Permalink
Merge pull request #4277 from emilhotkowski/4276-added-traverse-collect
Browse files Browse the repository at this point in the history
Add `traverseCollect` to `TraverseFilter` typeclass
  • Loading branch information
armanbilge authored Aug 17, 2022
2 parents 06f798e + 8b7dcb6 commit 791c65d
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 2 deletions.
16 changes: 16 additions & 0 deletions core/src/main/scala/cats/TraverseFilter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ trait TraverseFilter[F[_]] extends FunctorFilter[F] {
*/
def traverseFilter[G[_], A, B](fa: F[A])(f: A => G[Option[B]])(implicit G: Applicative[G]): G[F[B]]

/**
* A combined [[traverse]] and [[collect]].
*
* scala> import cats.implicits._
* scala> val m: Map[Int, String] = Map(1 -> "one", 2 -> "two")
* scala> val l: List[Int] = List(1, 2, 3, 4)
* scala> def asString: PartialFunction[Int, Eval[Option[String]]] = { case n if n % 2 == 0 => Now(m.get(n)) }
* scala> val result: Eval[List[Option[String]]] = l.traverseCollect(asString)
* scala> result.value
* res0: List[Option[String]] = List(Some(two), None)
*/
def traverseCollect[G[_], A, B](fa: F[A])(f: PartialFunction[A, G[B]])(implicit G: Applicative[G]): G[F[B]] = {
val optF = f.lift
traverseFilter(fa)(a => Traverse[Option].sequence(optF(a)))
}

/**
* {{{
* scala> import cats.implicits._
Expand Down
12 changes: 12 additions & 0 deletions core/src/main/scala/cats/syntax/traverseFilter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ trait TraverseFilterSyntax extends TraverseFilter.ToTraverseFilterOps
private[syntax] trait TraverseFilterSyntaxBinCompat0 {
implicit def toSequenceFilterOps[F[_], G[_], A](fgoa: F[G[Option[A]]]): SequenceFilterOps[F, G, A] =
new SequenceFilterOps(fgoa)

implicit def toTraverseFilterOps[F[_], G[_], A](fa: F[A]): TraverseFilterOps[F, G, A] =
new TraverseFilterOps(fa)
}

final class SequenceFilterOps[F[_], G[_], A](private val fgoa: F[G[Option[A]]]) extends AnyVal {
Expand All @@ -41,3 +44,12 @@ final class SequenceFilterOps[F[_], G[_], A](private val fgoa: F[G[Option[A]]])
*/
def sequenceFilter(implicit F: TraverseFilter[F], G: Applicative[G]): G[F[A]] = F.sequenceFilter(fgoa)
}

final class TraverseFilterOps[F[_], G[_], A](private val fa: F[A]) extends AnyVal {

def traverseCollect[B](f: PartialFunction[A, G[B]])(implicit
F: TraverseFilter[F],
G: Applicative[G]
): G[F[B]] =
F.traverseCollect(fa)(f)
}
8 changes: 8 additions & 0 deletions laws/src/main/scala/cats/laws/TraverseFilterLaws.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ trait TraverseFilterLaws[F[_]] extends FunctorFilterLaws[F] {
G: Monad[G]
): IsEq[G[F[B]]] =
fa.traverseEither(a => f(a).map(_.toRight(e)))((_, _) => Applicative[G].unit) <-> fa.traverseFilter(f)

def traverseCollectRef[G[_], A, B](fa: F[A], f: PartialFunction[A, G[B]])(implicit
G: Applicative[G]
): IsEq[G[F[B]]] = {
val lhs = fa.traverseCollect(f)
val rhs = fa.traverseFilter(a => f.lift(a).sequence)
lhs <-> rhs
}
}

object TraverseFilterLaws {
Expand Down
16 changes: 14 additions & 2 deletions laws/src/main/scala/cats/laws/discipline/TraverseFilterTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,17 @@ trait TraverseFilterTests[F[_]] extends FunctorFilterTests[F] {
EqFC: Eq[F[C]],
EqGFA: Eq[Option[F[A]]],
EqMNFC: Eq[Nested[Option, Option, F[C]]]
): RuleSet =
): RuleSet = {
implicit val arbFAOB: Arbitrary[PartialFunction[A, Option[B]]] =
Arbitrary(ArbFABoo.arbitrary.map { pfab =>
{
case a if pfab.isDefinedAt(a) =>
val b = pfab(a)
if (((a.hashCode ^ b.hashCode) & 1) == 1) Some(b)
else None
}
})

new DefaultRuleSet(
name = "traverseFilter",
parent = Some(functorFilter[A, B, C]),
Expand All @@ -58,8 +68,10 @@ trait TraverseFilterTests[F[_]] extends FunctorFilterTests[F] {
"filterA consistent with traverseFilter" -> forAll(laws.filterAConsistentWithTraverseFilter[Option, A] _),
"traverseEither consistent with traverseFilter" -> forAll(
laws.traverseEitherConsistentWithTraverseFilter[Option, F[A], A, B] _
)
),
"traverseCollect reference" -> forAll(laws.traverseCollectRef[Option, A, B] _)
)
}
}

object TraverseFilterTests {
Expand Down
7 changes: 7 additions & 0 deletions tests/shared/src/test/scala/cats/tests/SyntaxSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -580,4 +580,11 @@ object SyntaxSuite {

val result: Either[A, List[B]] = f.sequenceFilter
}

def testTraverseCollect[A, B]: Unit = {
val list = mock[List[A]]
val f = mock[PartialFunction[A, Option[B]]]

val result: Option[List[B]] = list.traverseCollect(f)
}
}

0 comments on commit 791c65d

Please sign in to comment.