diff --git a/core/src/main/scala-2.13+/cats/data/NonEmptyLazyList.scala b/core/src/main/scala-2.13+/cats/data/NonEmptyLazyList.scala index d7a9b8d14d..606b1311cd 100644 --- a/core/src/main/scala-2.13+/cats/data/NonEmptyLazyList.scala +++ b/core/src/main/scala-2.13+/cats/data/NonEmptyLazyList.scala @@ -473,16 +473,14 @@ sealed abstract private[data] class NonEmptyLazyListInstances extends NonEmptyLa def extract[A](fa: NonEmptyLazyList[A]): A = fa.head - def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyLazyList[A])(f: A => G[B]): G[NonEmptyLazyList[B]] = - Foldable[LazyList] - .reduceRightToOption[A, G[LazyList[B]]](fa.tail)(a => Apply[G].map(f(a))(LazyList.apply(_))) { (a, lglb) => - Apply[G].map2Eval(f(a), lglb)(_ +: _) + def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyLazyList[A])(f: A => G[B]): G[NonEmptyLazyList[B]] = { + def loop(head: A, tail: LazyList[A]): Eval[G[NonEmptyLazyList[B]]] = + tail.headOption.fold(Eval.now(Apply[G].map(f(head))(NonEmptyLazyList(_)))) { h => + Apply[G].map2Eval(f(head), Eval.defer(loop(h, tail.tail)))((b, acc) => b +: acc) } - .map { - case None => Apply[G].map(f(fa.head))(h => create(LazyList(h))) - case Some(gtail) => Apply[G].map2(f(fa.head), gtail)((h, t) => create(LazyList(h) ++ t)) - } - .value + + loop(fa.head, fa.tail).value + } def reduceLeftTo[A, B](fa: NonEmptyLazyList[A])(f: A => B)(g: (B, A) => B): B = fa.reduceLeftTo(f)(g) diff --git a/core/src/main/scala/cats/data/NonEmptyChain.scala b/core/src/main/scala/cats/data/NonEmptyChain.scala index c78c9310ea..d0f577fa71 100644 --- a/core/src/main/scala/cats/data/NonEmptyChain.scala +++ b/core/src/main/scala/cats/data/NonEmptyChain.scala @@ -420,16 +420,14 @@ sealed abstract private[data] class NonEmptyChainInstances extends NonEmptyChain new AbstractNonEmptyInstances[Chain, NonEmptyChain] with Align[NonEmptyChain] { def extract[A](fa: NonEmptyChain[A]): A = fa.head - def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyChain[A])(f: A => G[B]): G[NonEmptyChain[B]] = - Foldable[Chain] - .reduceRightToOption[A, G[Chain[B]]](fa.tail)(a => Apply[G].map(f(a))(Chain.one)) { (a, lglb) => - Apply[G].map2Eval(f(a), lglb)(_ +: _) + def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyChain[A])(f: A => G[B]): G[NonEmptyChain[B]] = { + def loop(head: A, tail: Chain[A]): Eval[G[NonEmptyChain[B]]] = + tail.uncons.fold(Eval.now(Apply[G].map(f(head))(NonEmptyChain(_)))) { + case (h, t) => Apply[G].map2Eval(f(head), Eval.defer(loop(h, t)))((b, acc) => b +: acc) } - .map { - case None => Apply[G].map(f(fa.head))(NonEmptyChain.one) - case Some(gtail) => Apply[G].map2(f(fa.head), gtail)((h, t) => create(Chain.one(h) ++ t)) - } - .value + + loop(fa.head, fa.tail).value + } override def size[A](fa: NonEmptyChain[A]): Long = fa.length diff --git a/core/src/main/scala/cats/data/NonEmptyList.scala b/core/src/main/scala/cats/data/NonEmptyList.scala index 6ffd727fde..be0e1e53b2 100644 --- a/core/src/main/scala/cats/data/NonEmptyList.scala +++ b/core/src/main/scala/cats/data/NonEmptyList.scala @@ -562,16 +562,15 @@ sealed abstract private[data] class NonEmptyListInstances extends NonEmptyListIn def extract[A](fa: NonEmptyList[A]): A = fa.head - def nonEmptyTraverse[G[_], A, B](nel: NonEmptyList[A])(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyList[B]] = - Foldable[List] - .reduceRightToOption[A, G[List[B]]](nel.tail)(a => G.map(f(a))(_ :: Nil)) { (a, lglb) => - G.map2Eval(f(a), lglb)(_ :: _) + def nonEmptyTraverse[G[_], A, B](nel: NonEmptyList[A])(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyList[B]] = { + def loop(head: A, tail: List[A]): Eval[G[NonEmptyList[B]]] = + tail match { + case Nil => Eval.now(G.map(f(head))(NonEmptyList(_, Nil))) + case h :: t => G.map2Eval(f(head), Eval.defer(loop(h, t)))((b, acc) => NonEmptyList(b, acc.toList)) } - .map { - case None => G.map(f(nel.head))(NonEmptyList(_, Nil)) - case Some(gtail) => G.map2(f(nel.head), gtail)(NonEmptyList(_, _)) - } - .value + + loop(nel.head, nel.tail).value + } override def traverse[G[_], A, B]( fa: NonEmptyList[A] diff --git a/core/src/main/scala/cats/data/NonEmptyMapImpl.scala b/core/src/main/scala/cats/data/NonEmptyMapImpl.scala index 5f9a34ca88..751ecd1859 100644 --- a/core/src/main/scala/cats/data/NonEmptyMapImpl.scala +++ b/core/src/main/scala/cats/data/NonEmptyMapImpl.scala @@ -205,16 +205,15 @@ sealed class NonEmptyMapOps[K, A](val value: NonEmptyMap[K, A]) { * through the running of this function on all the values in this map, * returning an NonEmptyMap[K, B] in a G context. */ - def nonEmptyTraverse[G[_], B](f: A => G[B])(implicit G: Apply[G]): G[NonEmptyMap[K, B]] = - reduceRightToOptionWithKey[A, G[SortedMap[K, B]]](tail)({ - case (k, a) => - G.map(f(a))(b => SortedMap.empty[K, B] + ((k, b))) - }) { (t, lglb) => - G.map2Eval(f(t._2), lglb)((b, bs) => bs + ((t._1, b))) - }.map { - case None => G.map(f(head._2))(a => NonEmptyMapImpl.one(head._1, a)) - case Some(gtail) => G.map2(f(head._2), gtail)((a, bs) => NonEmptyMapImpl((head._1, a), bs)) - }.value + def nonEmptyTraverse[G[_], B](f: A => G[B])(implicit G: Apply[G]): G[NonEmptyMap[K, B]] = { + def loop(h: (K, A), t: SortedMap[K, A]): Eval[G[NonEmptyMap[K, B]]] = + if (t.isEmpty) + Eval.now(G.map(f(h._2))(b => NonEmptyMap(h._1 -> b, SortedMap.empty[K, B]))) + else + G.map2Eval(f(h._2), Eval.defer(loop(t.head, t.tail)))((b, acc) => NonEmptyMap(h._1 -> b, acc.toSortedMap)) + + loop(head, tail).value + } /** * Typesafe stringification method. diff --git a/core/src/main/scala/cats/data/NonEmptyVector.scala b/core/src/main/scala/cats/data/NonEmptyVector.scala index b4086cc8e1..7d29f0448c 100644 --- a/core/src/main/scala/cats/data/NonEmptyVector.scala +++ b/core/src/main/scala/cats/data/NonEmptyVector.scala @@ -371,17 +371,15 @@ sealed abstract private[data] class NonEmptyVectorInstances { def extract[A](fa: NonEmptyVector[A]): A = fa.head def nonEmptyTraverse[G[_], A, B]( - nel: NonEmptyVector[A] - )(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyVector[B]] = - Foldable[Vector] - .reduceRightToOption[A, G[Vector[B]]](nel.tail)(a => G.map(f(a))(_ +: Vector.empty)) { (a, lglb) => - G.map2Eval(f(a), lglb)(_ +: _) - } - .map { - case None => G.map(f(nel.head))(NonEmptyVector(_, Vector.empty)) - case Some(gtail) => G.map2(f(nel.head), gtail)(NonEmptyVector(_, _)) - } - .value + nev: NonEmptyVector[A] + )(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyVector[B]] = { + def loop(head: A, tail: Vector[A]): Eval[G[NonEmptyVector[B]]] = + tail.headOption.fold(Eval.now(G.map(f(head))(NonEmptyVector(_, Vector.empty[B]))))(h => + G.map2Eval(f(head), Eval.defer(loop(h, tail.tail)))((b, acc) => b +: acc) + ) + + loop(nev.head, nev.tail).value + } override def traverse[G[_], A, B]( fa: NonEmptyVector[A] diff --git a/core/src/main/scala/cats/data/OneAnd.scala b/core/src/main/scala/cats/data/OneAnd.scala index a294e5a117..e6f0762b43 100644 --- a/core/src/main/scala/cats/data/OneAnd.scala +++ b/core/src/main/scala/cats/data/OneAnd.scala @@ -262,9 +262,20 @@ sealed abstract private[data] class OneAndLowPriority0 extends OneAndLowPriority F2: Alternative[F] ): NonEmptyTraverse[OneAnd[F, *]] = new NonEmptyReducible[OneAnd[F, *], F] with NonEmptyTraverse[OneAnd[F, *]] { - def nonEmptyTraverse[G[_], A, B](fa: OneAnd[F, A])(f: (A) => G[B])(implicit G: Apply[G]): G[OneAnd[F, B]] = - fa.map(a => Apply[G].map(f(a))(OneAnd(_, F2.empty[B])))(F) - .reduceLeft(((acc, a) => G.map2(acc, a)((x: OneAnd[F, B], y: OneAnd[F, B]) => x.combine(y)))) + def nonEmptyTraverse[G[_], A, B](fa: OneAnd[F, A])(f: (A) => G[B])(implicit G: Apply[G]): G[OneAnd[F, B]] = { + import syntax.foldable._ + + def loop(head: A, tail: Iterator[A]): Eval[G[OneAnd[F, B]]] = + if (tail.hasNext) { + val h = tail.next() + val t = tail + G.map2Eval(f(head), Eval.defer(loop(h, t)))((b, acc) => OneAnd(b, acc.unwrap)) + } else { + Eval.now(G.map(f(head))(OneAnd(_, F2.empty[B]))) + } + + loop(fa.head, fa.tail.toIterable.iterator).value + } override def traverse[G[_], A, B](fa: OneAnd[F, A])(f: (A) => G[B])(implicit G: Applicative[G]): G[OneAnd[F, B]] = G.map2Eval(f(fa.head), Always(F.traverse(fa.tail)(f)))(OneAnd(_, _)).value diff --git a/laws/src/main/scala/cats/laws/ShortCircuitingLaws.scala b/laws/src/main/scala/cats/laws/ShortCircuitingLaws.scala index 6d4300cc55..ca8e0c4d22 100644 --- a/laws/src/main/scala/cats/laws/ShortCircuitingLaws.scala +++ b/laws/src/main/scala/cats/laws/ShortCircuitingLaws.scala @@ -5,8 +5,9 @@ import java.util.concurrent.atomic.AtomicLong import cats.instances.option._ import cats.syntax.foldable._ import cats.syntax.traverse._ +import cats.syntax.nonEmptyTraverse._ import cats.syntax.traverseFilter._ -import cats.{Applicative, Foldable, MonoidK, Traverse, TraverseFilter} +import cats.{Applicative, Foldable, MonoidK, NonEmptyTraverse, Traverse, TraverseFilter} trait ShortCircuitingLaws[F[_]] { @@ -46,6 +47,24 @@ trait ShortCircuitingLaws[F[_]] { f.invocations.get <-> size } + def nonEmptyTraverseShortCircuits[A](fa: F[A])(implicit F: NonEmptyTraverse[F]): IsEq[Long] = { + val size = fa.size + val maxInvocationsAllowed = size / 2 + val f = new RestrictedFunction((i: A) => Some(i), maxInvocationsAllowed, None) + + fa.nonEmptyTraverse(f) + f.invocations.get <-> (maxInvocationsAllowed + 1).min(size) + } + + def nonEmptyTraverseWontShortCircuit[A](fa: F[A])(implicit F: NonEmptyTraverse[F]): IsEq[Long] = { + val size = fa.size + val maxInvocationsAllowed = size / 2 + val f = new RestrictedFunction((i: A) => Some(i), maxInvocationsAllowed, None) + + fa.nonEmptyTraverse(f)(nonShortCircuitingApplicative) + f.invocations.get <-> size + } + def traverseFilterShortCircuits[A](fa: F[A])(implicit TF: TraverseFilter[F]): IsEq[Long] = { implicit val F: Traverse[F] = TF.traverse diff --git a/laws/src/main/scala/cats/laws/discipline/ShortCircuitingTests.scala b/laws/src/main/scala/cats/laws/discipline/ShortCircuitingTests.scala index 1f40605a13..91234710ea 100644 --- a/laws/src/main/scala/cats/laws/discipline/ShortCircuitingTests.scala +++ b/laws/src/main/scala/cats/laws/discipline/ShortCircuitingTests.scala @@ -1,7 +1,7 @@ package cats.laws.discipline import cats.laws.ShortCircuitingLaws -import cats.{Eq, Foldable, Traverse, TraverseFilter} +import cats.{Eq, Foldable, NonEmptyTraverse, Traverse, TraverseFilter} import org.scalacheck.Arbitrary import org.scalacheck.Prop.forAll import org.typelevel.discipline.Laws @@ -25,11 +25,17 @@ trait ShortCircuitingTests[F[_]] extends Laws { "traverse won't short-circuit if Applicative[G].map2Eval won't" -> forAll(laws.traverseWontShortCircuit[A] _) ) - def traverseFilter[A: Arbitrary](implicit - TF: TraverseFilter[F], - ArbFA: Arbitrary[F[A]], - lEq: Eq[Long] - ): RuleSet = { + def nonEmptyTraverse[A: Arbitrary](implicit TF: NonEmptyTraverse[F], ArbFA: Arbitrary[F[A]], lEq: Eq[Long]): RuleSet = + new DefaultRuleSet( + name = "nonEmptyTraverseShortCircuiting", + parent = Some(traverse[A]), + "nonEmptyTraverse short-circuits if Applicative[G].map2Eval shorts" -> + forAll(laws.nonEmptyTraverseShortCircuits[A] _), + "nonEmptyTraverse short-circuits if Applicative[G].map2Eval won't" -> + forAll(laws.nonEmptyTraverseWontShortCircuit[A] _) + ) + + def traverseFilter[A: Arbitrary](implicit TF: TraverseFilter[F], ArbFA: Arbitrary[F[A]], lEq: Eq[Long]): RuleSet = { implicit val T: Traverse[F] = TF.traverse new DefaultRuleSet( name = "traverseFilterShortCircuiting", diff --git a/tests/src/test/scala-2.12/cats/tests/NonEmptyStreamSuite.scala b/tests/src/test/scala-2.12/cats/tests/NonEmptyStreamSuite.scala index 7ab74d33ec..45f804abbf 100644 --- a/tests/src/test/scala-2.12/cats/tests/NonEmptyStreamSuite.scala +++ b/tests/src/test/scala-2.12/cats/tests/NonEmptyStreamSuite.scala @@ -32,6 +32,7 @@ class NonEmptyStreamSuite extends CatsSuite { checkAll("NonEmptyStream[Int]", ShortCircuitingTests[NonEmptyStream].foldable[Int]) checkAll("NonEmptyStream[Int]", ShortCircuitingTests[NonEmptyStream].traverse[Int]) + checkAll("NonEmptyStream[Int]", ShortCircuitingTests[NonEmptyStream].nonEmptyTraverse[Int]) { // Test functor and subclasses don't have implicit conflicts diff --git a/tests/src/test/scala-2.13+/cats/tests/NonEmptyLazyListSuite.scala b/tests/src/test/scala-2.13+/cats/tests/NonEmptyLazyListSuite.scala index 794a15f316..9260311c1e 100644 --- a/tests/src/test/scala-2.13+/cats/tests/NonEmptyLazyListSuite.scala +++ b/tests/src/test/scala-2.13+/cats/tests/NonEmptyLazyListSuite.scala @@ -46,6 +46,7 @@ class NonEmptyLazyListSuite extends NonEmptyCollectionSuite[LazyList, NonEmptyLa checkAll("NonEmptyLazyList[Int]", ShortCircuitingTests[NonEmptyLazyList].foldable[Int]) checkAll("NonEmptyLazyList[Int]", ShortCircuitingTests[NonEmptyLazyList].traverse[Int]) + checkAll("NonEmptyLazyList[Int]", ShortCircuitingTests[NonEmptyLazyList].nonEmptyTraverse[Int]) test("show") { Show[NonEmptyLazyList[Int]].show(NonEmptyLazyList(1, 2, 3)) should ===("NonEmptyLazyList(1, ?)") diff --git a/tests/src/test/scala/cats/tests/NonEmptyChainSuite.scala b/tests/src/test/scala/cats/tests/NonEmptyChainSuite.scala index 9b7670a58d..2d6566ef85 100644 --- a/tests/src/test/scala/cats/tests/NonEmptyChainSuite.scala +++ b/tests/src/test/scala/cats/tests/NonEmptyChainSuite.scala @@ -43,6 +43,7 @@ class NonEmptyChainSuite extends NonEmptyCollectionSuite[Chain, NonEmptyChain, N checkAll("NonEmptyChain[Int]", ShortCircuitingTests[NonEmptyChain].foldable[Int]) checkAll("NonEmptyChain[Int]", ShortCircuitingTests[NonEmptyChain].traverse[Int]) + checkAll("NonEmptyChain[Int]", ShortCircuitingTests[NonEmptyChain].nonEmptyTraverse[Int]) { implicit val partialOrder: PartialOrder[ListWrapper[Int]] = ListWrapper.partialOrder[Int] diff --git a/tests/src/test/scala/cats/tests/NonEmptyListSuite.scala b/tests/src/test/scala/cats/tests/NonEmptyListSuite.scala index fb2a906ac2..13b9412be6 100644 --- a/tests/src/test/scala/cats/tests/NonEmptyListSuite.scala +++ b/tests/src/test/scala/cats/tests/NonEmptyListSuite.scala @@ -59,6 +59,7 @@ class NonEmptyListSuite extends NonEmptyCollectionSuite[List, NonEmptyList, NonE checkAll("NonEmptyList[Int]", ShortCircuitingTests[NonEmptyList].foldable[Int]) checkAll("NonEmptyList[Int]", ShortCircuitingTests[NonEmptyList].traverse[Int]) + checkAll("NonEmptyList[Int]", ShortCircuitingTests[NonEmptyList].nonEmptyTraverse[Int]) { implicit val A: PartialOrder[ListWrapper[Int]] = ListWrapper.partialOrder[Int] diff --git a/tests/src/test/scala/cats/tests/NonEmptyMapSuite.scala b/tests/src/test/scala/cats/tests/NonEmptyMapSuite.scala index bbdb89e73c..06454c9eed 100644 --- a/tests/src/test/scala/cats/tests/NonEmptyMapSuite.scala +++ b/tests/src/test/scala/cats/tests/NonEmptyMapSuite.scala @@ -25,6 +25,8 @@ class NonEmptyMapSuite extends CatsSuite { checkAll("NonEmptyMap[String, Int]", AlignTests[NonEmptyMap[String, *]].align[Int, Int, Int, Int]) checkAll("Align[NonEmptyMap]", SerializableTests.serializable(Align[NonEmptyMap[String, *]])) + checkAll("NonEmptyMap[Int, *]", ShortCircuitingTests[NonEmptyMap[Int, *]].nonEmptyTraverse[Int]) + test("Show is not empty and is formatted as expected") { forAll { (nem: NonEmptyMap[String, Int]) => nem.show.nonEmpty should ===(true) diff --git a/tests/src/test/scala/cats/tests/NonEmptyVectorSuite.scala b/tests/src/test/scala/cats/tests/NonEmptyVectorSuite.scala index 30df9d5363..048731e1d4 100644 --- a/tests/src/test/scala/cats/tests/NonEmptyVectorSuite.scala +++ b/tests/src/test/scala/cats/tests/NonEmptyVectorSuite.scala @@ -86,6 +86,7 @@ class NonEmptyVectorSuite extends NonEmptyCollectionSuite[Vector, NonEmptyVector checkAll("NonEmptyVector[Int]", ShortCircuitingTests[NonEmptyVector].foldable[Int]) checkAll("NonEmptyVector[Int]", ShortCircuitingTests[NonEmptyVector].traverse[Int]) + checkAll("NonEmptyVector[Int]", ShortCircuitingTests[NonEmptyVector].nonEmptyTraverse[Int]) test("size is consistent with toList.size") { forAll { (nonEmptyVector: NonEmptyVector[Int]) =>