diff --git a/core/src/main/scala/cats/TraverseFilter.scala b/core/src/main/scala/cats/TraverseFilter.scala index 3d0dae6ae6..9f2d00ec72 100644 --- a/core/src/main/scala/cats/TraverseFilter.scala +++ b/core/src/main/scala/cats/TraverseFilter.scala @@ -56,6 +56,22 @@ trait TraverseFilter[F[_]] extends FunctorFilter[F] { */ def traverseFilter[G[_], A, B](fa: F[A])(f: A => G[Option[B]])(implicit G: Applicative[G]): G[F[B]] + /** + * A combined [[traverse]] and [[collect]]. + * + * scala> import cats.implicits._ + * scala> val m: Map[Int, String] = Map(1 -> "one", 2 -> "two") + * scala> val l: List[Int] = List(1, 2, 3, 4) + * scala> def asString: PartialFunction[Int, Eval[Option[String]]] = { case n if n % 2 == 0 => Now(m.get(n)) } + * scala> val result: Eval[List[Option[String]]] = l.traverseCollect(asString) + * scala> result.value + * res0: List[Option[String]] = List(Some(two), None) + */ + def traverseCollect[G[_], A, B](fa: F[A])(f: PartialFunction[A, G[B]])(implicit G: Applicative[G]): G[F[B]] = { + val optF = f.lift + traverseFilter(fa)(a => Traverse[Option].sequence(optF(a))) + } + /** * {{{ * scala> import cats.implicits._ diff --git a/core/src/main/scala/cats/syntax/traverseFilter.scala b/core/src/main/scala/cats/syntax/traverseFilter.scala index 20fc6a170c..fde0616d0c 100644 --- a/core/src/main/scala/cats/syntax/traverseFilter.scala +++ b/core/src/main/scala/cats/syntax/traverseFilter.scala @@ -27,6 +27,9 @@ trait TraverseFilterSyntax extends TraverseFilter.ToTraverseFilterOps private[syntax] trait TraverseFilterSyntaxBinCompat0 { implicit def toSequenceFilterOps[F[_], G[_], A](fgoa: F[G[Option[A]]]): SequenceFilterOps[F, G, A] = new SequenceFilterOps(fgoa) + + implicit def toTraverseFilterOps[F[_], G[_], A](fa: F[A]): TraverseFilterOps[F, G, A] = + new TraverseFilterOps(fa) } final class SequenceFilterOps[F[_], G[_], A](private val fgoa: F[G[Option[A]]]) extends AnyVal { @@ -41,3 +44,12 @@ final class SequenceFilterOps[F[_], G[_], A](private val fgoa: F[G[Option[A]]]) */ def sequenceFilter(implicit F: TraverseFilter[F], G: Applicative[G]): G[F[A]] = F.sequenceFilter(fgoa) } + +final class TraverseFilterOps[F[_], G[_], A](private val fa: F[A]) extends AnyVal { + + def traverseCollect[B](f: PartialFunction[A, G[B]])(implicit + F: TraverseFilter[F], + G: Applicative[G] + ): G[F[B]] = + F.traverseCollect(fa)(f) +} diff --git a/laws/src/main/scala/cats/laws/TraverseFilterLaws.scala b/laws/src/main/scala/cats/laws/TraverseFilterLaws.scala index b2d0df5bf0..f4dc4040e5 100644 --- a/laws/src/main/scala/cats/laws/TraverseFilterLaws.scala +++ b/laws/src/main/scala/cats/laws/TraverseFilterLaws.scala @@ -52,6 +52,14 @@ trait TraverseFilterLaws[F[_]] extends FunctorFilterLaws[F] { G: Monad[G] ): IsEq[G[F[B]]] = fa.traverseEither(a => f(a).map(_.toRight(e)))((_, _) => Applicative[G].unit) <-> fa.traverseFilter(f) + + def traverseCollectRef[G[_], A, B](fa: F[A], f: PartialFunction[A, G[B]])(implicit + G: Applicative[G] + ): IsEq[G[F[B]]] = { + val lhs = fa.traverseCollect(f) + val rhs = fa.traverseFilter(a => f.lift(a).sequence) + lhs <-> rhs + } } object TraverseFilterLaws { diff --git a/laws/src/main/scala/cats/laws/discipline/TraverseFilterTests.scala b/laws/src/main/scala/cats/laws/discipline/TraverseFilterTests.scala index 990092fc83..d24d3d81ba 100644 --- a/laws/src/main/scala/cats/laws/discipline/TraverseFilterTests.scala +++ b/laws/src/main/scala/cats/laws/discipline/TraverseFilterTests.scala @@ -48,7 +48,17 @@ trait TraverseFilterTests[F[_]] extends FunctorFilterTests[F] { EqFC: Eq[F[C]], EqGFA: Eq[Option[F[A]]], EqMNFC: Eq[Nested[Option, Option, F[C]]] - ): RuleSet = + ): RuleSet = { + implicit val arbFAOB: Arbitrary[PartialFunction[A, Option[B]]] = + Arbitrary(ArbFABoo.arbitrary.map { pfab => + { + case a if pfab.isDefinedAt(a) => + val b = pfab(a) + if (((a.hashCode ^ b.hashCode) & 1) == 1) Some(b) + else None + } + }) + new DefaultRuleSet( name = "traverseFilter", parent = Some(functorFilter[A, B, C]), @@ -58,8 +68,10 @@ trait TraverseFilterTests[F[_]] extends FunctorFilterTests[F] { "filterA consistent with traverseFilter" -> forAll(laws.filterAConsistentWithTraverseFilter[Option, A] _), "traverseEither consistent with traverseFilter" -> forAll( laws.traverseEitherConsistentWithTraverseFilter[Option, F[A], A, B] _ - ) + ), + "traverseCollect reference" -> forAll(laws.traverseCollectRef[Option, A, B] _) ) + } } object TraverseFilterTests { diff --git a/tests/shared/src/test/scala/cats/tests/SyntaxSuite.scala b/tests/shared/src/test/scala/cats/tests/SyntaxSuite.scala index 432e4f5596..ac2b70ae74 100644 --- a/tests/shared/src/test/scala/cats/tests/SyntaxSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/SyntaxSuite.scala @@ -580,4 +580,11 @@ object SyntaxSuite { val result: Either[A, List[B]] = f.sequenceFilter } + + def testTraverseCollect[A, B]: Unit = { + val list = mock[List[A]] + val f = mock[PartialFunction[A, Option[B]]] + + val result: Option[List[B]] = list.traverseCollect(f) + } }