Skip to content

Commit

Permalink
Optimise NonEmptyTraverse implementation (#3382)
Browse files Browse the repository at this point in the history
* added nonEmptyTraverse short-circuiting laws

* optimised nonEmptyTraverse implementation for chain, list, vector, map, lazylist and oneAnd

* updated instance tests to include nonEmptyTraverse short-circuiting behaviour

* fixed fmt error
  • Loading branch information
gagandeepkalra authored Jun 4, 2020
1 parent 1f8cf3c commit 0ade708
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 58 deletions.
16 changes: 7 additions & 9 deletions core/src/main/scala-2.13+/cats/data/NonEmptyLazyList.scala
Original file line number Diff line number Diff line change
Expand Up @@ -473,16 +473,14 @@ sealed abstract private[data] class NonEmptyLazyListInstances extends NonEmptyLa

def extract[A](fa: NonEmptyLazyList[A]): A = fa.head

def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyLazyList[A])(f: A => G[B]): G[NonEmptyLazyList[B]] =
Foldable[LazyList]
.reduceRightToOption[A, G[LazyList[B]]](fa.tail)(a => Apply[G].map(f(a))(LazyList.apply(_))) { (a, lglb) =>
Apply[G].map2Eval(f(a), lglb)(_ +: _)
def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyLazyList[A])(f: A => G[B]): G[NonEmptyLazyList[B]] = {
def loop(head: A, tail: LazyList[A]): Eval[G[NonEmptyLazyList[B]]] =
tail.headOption.fold(Eval.now(Apply[G].map(f(head))(NonEmptyLazyList(_)))) { h =>
Apply[G].map2Eval(f(head), Eval.defer(loop(h, tail.tail)))((b, acc) => b +: acc)
}
.map {
case None => Apply[G].map(f(fa.head))(h => create(LazyList(h)))
case Some(gtail) => Apply[G].map2(f(fa.head), gtail)((h, t) => create(LazyList(h) ++ t))
}
.value

loop(fa.head, fa.tail).value
}

def reduceLeftTo[A, B](fa: NonEmptyLazyList[A])(f: A => B)(g: (B, A) => B): B = fa.reduceLeftTo(f)(g)

Expand Down
16 changes: 7 additions & 9 deletions core/src/main/scala/cats/data/NonEmptyChain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -420,16 +420,14 @@ sealed abstract private[data] class NonEmptyChainInstances extends NonEmptyChain
new AbstractNonEmptyInstances[Chain, NonEmptyChain] with Align[NonEmptyChain] {
def extract[A](fa: NonEmptyChain[A]): A = fa.head

def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyChain[A])(f: A => G[B]): G[NonEmptyChain[B]] =
Foldable[Chain]
.reduceRightToOption[A, G[Chain[B]]](fa.tail)(a => Apply[G].map(f(a))(Chain.one)) { (a, lglb) =>
Apply[G].map2Eval(f(a), lglb)(_ +: _)
def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyChain[A])(f: A => G[B]): G[NonEmptyChain[B]] = {
def loop(head: A, tail: Chain[A]): Eval[G[NonEmptyChain[B]]] =
tail.uncons.fold(Eval.now(Apply[G].map(f(head))(NonEmptyChain(_)))) {
case (h, t) => Apply[G].map2Eval(f(head), Eval.defer(loop(h, t)))((b, acc) => b +: acc)
}
.map {
case None => Apply[G].map(f(fa.head))(NonEmptyChain.one)
case Some(gtail) => Apply[G].map2(f(fa.head), gtail)((h, t) => create(Chain.one(h) ++ t))
}
.value

loop(fa.head, fa.tail).value
}

override def size[A](fa: NonEmptyChain[A]): Long = fa.length

Expand Down
17 changes: 8 additions & 9 deletions core/src/main/scala/cats/data/NonEmptyList.scala
Original file line number Diff line number Diff line change
Expand Up @@ -562,16 +562,15 @@ sealed abstract private[data] class NonEmptyListInstances extends NonEmptyListIn

def extract[A](fa: NonEmptyList[A]): A = fa.head

def nonEmptyTraverse[G[_], A, B](nel: NonEmptyList[A])(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyList[B]] =
Foldable[List]
.reduceRightToOption[A, G[List[B]]](nel.tail)(a => G.map(f(a))(_ :: Nil)) { (a, lglb) =>
G.map2Eval(f(a), lglb)(_ :: _)
def nonEmptyTraverse[G[_], A, B](nel: NonEmptyList[A])(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyList[B]] = {
def loop(head: A, tail: List[A]): Eval[G[NonEmptyList[B]]] =
tail match {
case Nil => Eval.now(G.map(f(head))(NonEmptyList(_, Nil)))
case h :: t => G.map2Eval(f(head), Eval.defer(loop(h, t)))((b, acc) => NonEmptyList(b, acc.toList))
}
.map {
case None => G.map(f(nel.head))(NonEmptyList(_, Nil))
case Some(gtail) => G.map2(f(nel.head), gtail)(NonEmptyList(_, _))
}
.value

loop(nel.head, nel.tail).value
}

override def traverse[G[_], A, B](
fa: NonEmptyList[A]
Expand Down
19 changes: 9 additions & 10 deletions core/src/main/scala/cats/data/NonEmptyMapImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,15 @@ sealed class NonEmptyMapOps[K, A](val value: NonEmptyMap[K, A]) {
* through the running of this function on all the values in this map,
* returning an NonEmptyMap[K, B] in a G context.
*/
def nonEmptyTraverse[G[_], B](f: A => G[B])(implicit G: Apply[G]): G[NonEmptyMap[K, B]] =
reduceRightToOptionWithKey[A, G[SortedMap[K, B]]](tail)({
case (k, a) =>
G.map(f(a))(b => SortedMap.empty[K, B] + ((k, b)))
}) { (t, lglb) =>
G.map2Eval(f(t._2), lglb)((b, bs) => bs + ((t._1, b)))
}.map {
case None => G.map(f(head._2))(a => NonEmptyMapImpl.one(head._1, a))
case Some(gtail) => G.map2(f(head._2), gtail)((a, bs) => NonEmptyMapImpl((head._1, a), bs))
}.value
def nonEmptyTraverse[G[_], B](f: A => G[B])(implicit G: Apply[G]): G[NonEmptyMap[K, B]] = {
def loop(h: (K, A), t: SortedMap[K, A]): Eval[G[NonEmptyMap[K, B]]] =
if (t.isEmpty)
Eval.now(G.map(f(h._2))(b => NonEmptyMap(h._1 -> b, SortedMap.empty[K, B])))
else
G.map2Eval(f(h._2), Eval.defer(loop(t.head, t.tail)))((b, acc) => NonEmptyMap(h._1 -> b, acc.toSortedMap))

loop(head, tail).value
}

/**
* Typesafe stringification method.
Expand Down
20 changes: 9 additions & 11 deletions core/src/main/scala/cats/data/NonEmptyVector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -371,17 +371,15 @@ sealed abstract private[data] class NonEmptyVectorInstances {
def extract[A](fa: NonEmptyVector[A]): A = fa.head

def nonEmptyTraverse[G[_], A, B](
nel: NonEmptyVector[A]
)(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyVector[B]] =
Foldable[Vector]
.reduceRightToOption[A, G[Vector[B]]](nel.tail)(a => G.map(f(a))(_ +: Vector.empty)) { (a, lglb) =>
G.map2Eval(f(a), lglb)(_ +: _)
}
.map {
case None => G.map(f(nel.head))(NonEmptyVector(_, Vector.empty))
case Some(gtail) => G.map2(f(nel.head), gtail)(NonEmptyVector(_, _))
}
.value
nev: NonEmptyVector[A]
)(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyVector[B]] = {
def loop(head: A, tail: Vector[A]): Eval[G[NonEmptyVector[B]]] =
tail.headOption.fold(Eval.now(G.map(f(head))(NonEmptyVector(_, Vector.empty[B]))))(h =>
G.map2Eval(f(head), Eval.defer(loop(h, tail.tail)))((b, acc) => b +: acc)
)

loop(nev.head, nev.tail).value
}

override def traverse[G[_], A, B](
fa: NonEmptyVector[A]
Expand Down
17 changes: 14 additions & 3 deletions core/src/main/scala/cats/data/OneAnd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,20 @@ sealed abstract private[data] class OneAndLowPriority0 extends OneAndLowPriority
F2: Alternative[F]
): NonEmptyTraverse[OneAnd[F, *]] =
new NonEmptyReducible[OneAnd[F, *], F] with NonEmptyTraverse[OneAnd[F, *]] {
def nonEmptyTraverse[G[_], A, B](fa: OneAnd[F, A])(f: (A) => G[B])(implicit G: Apply[G]): G[OneAnd[F, B]] =
fa.map(a => Apply[G].map(f(a))(OneAnd(_, F2.empty[B])))(F)
.reduceLeft(((acc, a) => G.map2(acc, a)((x: OneAnd[F, B], y: OneAnd[F, B]) => x.combine(y))))
def nonEmptyTraverse[G[_], A, B](fa: OneAnd[F, A])(f: (A) => G[B])(implicit G: Apply[G]): G[OneAnd[F, B]] = {
import syntax.foldable._

def loop(head: A, tail: Iterator[A]): Eval[G[OneAnd[F, B]]] =
if (tail.hasNext) {
val h = tail.next()
val t = tail
G.map2Eval(f(head), Eval.defer(loop(h, t)))((b, acc) => OneAnd(b, acc.unwrap))
} else {
Eval.now(G.map(f(head))(OneAnd(_, F2.empty[B])))
}

loop(fa.head, fa.tail.toIterable.iterator).value
}

override def traverse[G[_], A, B](fa: OneAnd[F, A])(f: (A) => G[B])(implicit G: Applicative[G]): G[OneAnd[F, B]] =
G.map2Eval(f(fa.head), Always(F.traverse(fa.tail)(f)))(OneAnd(_, _)).value
Expand Down
21 changes: 20 additions & 1 deletion laws/src/main/scala/cats/laws/ShortCircuitingLaws.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import java.util.concurrent.atomic.AtomicLong
import cats.instances.option._
import cats.syntax.foldable._
import cats.syntax.traverse._
import cats.syntax.nonEmptyTraverse._
import cats.syntax.traverseFilter._
import cats.{Applicative, Foldable, MonoidK, Traverse, TraverseFilter}
import cats.{Applicative, Foldable, MonoidK, NonEmptyTraverse, Traverse, TraverseFilter}

trait ShortCircuitingLaws[F[_]] {

Expand Down Expand Up @@ -46,6 +47,24 @@ trait ShortCircuitingLaws[F[_]] {
f.invocations.get <-> size
}

def nonEmptyTraverseShortCircuits[A](fa: F[A])(implicit F: NonEmptyTraverse[F]): IsEq[Long] = {
val size = fa.size
val maxInvocationsAllowed = size / 2
val f = new RestrictedFunction((i: A) => Some(i), maxInvocationsAllowed, None)

fa.nonEmptyTraverse(f)
f.invocations.get <-> (maxInvocationsAllowed + 1).min(size)
}

def nonEmptyTraverseWontShortCircuit[A](fa: F[A])(implicit F: NonEmptyTraverse[F]): IsEq[Long] = {
val size = fa.size
val maxInvocationsAllowed = size / 2
val f = new RestrictedFunction((i: A) => Some(i), maxInvocationsAllowed, None)

fa.nonEmptyTraverse(f)(nonShortCircuitingApplicative)
f.invocations.get <-> size
}

def traverseFilterShortCircuits[A](fa: F[A])(implicit TF: TraverseFilter[F]): IsEq[Long] = {
implicit val F: Traverse[F] = TF.traverse

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package cats.laws.discipline

import cats.laws.ShortCircuitingLaws
import cats.{Eq, Foldable, Traverse, TraverseFilter}
import cats.{Eq, Foldable, NonEmptyTraverse, Traverse, TraverseFilter}
import org.scalacheck.Arbitrary
import org.scalacheck.Prop.forAll
import org.typelevel.discipline.Laws
Expand All @@ -25,11 +25,17 @@ trait ShortCircuitingTests[F[_]] extends Laws {
"traverse won't short-circuit if Applicative[G].map2Eval won't" -> forAll(laws.traverseWontShortCircuit[A] _)
)

def traverseFilter[A: Arbitrary](implicit
TF: TraverseFilter[F],
ArbFA: Arbitrary[F[A]],
lEq: Eq[Long]
): RuleSet = {
def nonEmptyTraverse[A: Arbitrary](implicit TF: NonEmptyTraverse[F], ArbFA: Arbitrary[F[A]], lEq: Eq[Long]): RuleSet =
new DefaultRuleSet(
name = "nonEmptyTraverseShortCircuiting",
parent = Some(traverse[A]),
"nonEmptyTraverse short-circuits if Applicative[G].map2Eval shorts" ->
forAll(laws.nonEmptyTraverseShortCircuits[A] _),
"nonEmptyTraverse short-circuits if Applicative[G].map2Eval won't" ->
forAll(laws.nonEmptyTraverseWontShortCircuit[A] _)
)

def traverseFilter[A: Arbitrary](implicit TF: TraverseFilter[F], ArbFA: Arbitrary[F[A]], lEq: Eq[Long]): RuleSet = {
implicit val T: Traverse[F] = TF.traverse
new DefaultRuleSet(
name = "traverseFilterShortCircuiting",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class NonEmptyStreamSuite extends CatsSuite {

checkAll("NonEmptyStream[Int]", ShortCircuitingTests[NonEmptyStream].foldable[Int])
checkAll("NonEmptyStream[Int]", ShortCircuitingTests[NonEmptyStream].traverse[Int])
checkAll("NonEmptyStream[Int]", ShortCircuitingTests[NonEmptyStream].nonEmptyTraverse[Int])

{
// Test functor and subclasses don't have implicit conflicts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class NonEmptyLazyListSuite extends NonEmptyCollectionSuite[LazyList, NonEmptyLa

checkAll("NonEmptyLazyList[Int]", ShortCircuitingTests[NonEmptyLazyList].foldable[Int])
checkAll("NonEmptyLazyList[Int]", ShortCircuitingTests[NonEmptyLazyList].traverse[Int])
checkAll("NonEmptyLazyList[Int]", ShortCircuitingTests[NonEmptyLazyList].nonEmptyTraverse[Int])

test("show") {
Show[NonEmptyLazyList[Int]].show(NonEmptyLazyList(1, 2, 3)) should ===("NonEmptyLazyList(1, ?)")
Expand Down
1 change: 1 addition & 0 deletions tests/src/test/scala/cats/tests/NonEmptyChainSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class NonEmptyChainSuite extends NonEmptyCollectionSuite[Chain, NonEmptyChain, N

checkAll("NonEmptyChain[Int]", ShortCircuitingTests[NonEmptyChain].foldable[Int])
checkAll("NonEmptyChain[Int]", ShortCircuitingTests[NonEmptyChain].traverse[Int])
checkAll("NonEmptyChain[Int]", ShortCircuitingTests[NonEmptyChain].nonEmptyTraverse[Int])

{
implicit val partialOrder: PartialOrder[ListWrapper[Int]] = ListWrapper.partialOrder[Int]
Expand Down
1 change: 1 addition & 0 deletions tests/src/test/scala/cats/tests/NonEmptyListSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class NonEmptyListSuite extends NonEmptyCollectionSuite[List, NonEmptyList, NonE

checkAll("NonEmptyList[Int]", ShortCircuitingTests[NonEmptyList].foldable[Int])
checkAll("NonEmptyList[Int]", ShortCircuitingTests[NonEmptyList].traverse[Int])
checkAll("NonEmptyList[Int]", ShortCircuitingTests[NonEmptyList].nonEmptyTraverse[Int])

{
implicit val A: PartialOrder[ListWrapper[Int]] = ListWrapper.partialOrder[Int]
Expand Down
2 changes: 2 additions & 0 deletions tests/src/test/scala/cats/tests/NonEmptyMapSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class NonEmptyMapSuite extends CatsSuite {
checkAll("NonEmptyMap[String, Int]", AlignTests[NonEmptyMap[String, *]].align[Int, Int, Int, Int])
checkAll("Align[NonEmptyMap]", SerializableTests.serializable(Align[NonEmptyMap[String, *]]))

checkAll("NonEmptyMap[Int, *]", ShortCircuitingTests[NonEmptyMap[Int, *]].nonEmptyTraverse[Int])

test("Show is not empty and is formatted as expected") {
forAll { (nem: NonEmptyMap[String, Int]) =>
nem.show.nonEmpty should ===(true)
Expand Down
1 change: 1 addition & 0 deletions tests/src/test/scala/cats/tests/NonEmptyVectorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class NonEmptyVectorSuite extends NonEmptyCollectionSuite[Vector, NonEmptyVector

checkAll("NonEmptyVector[Int]", ShortCircuitingTests[NonEmptyVector].foldable[Int])
checkAll("NonEmptyVector[Int]", ShortCircuitingTests[NonEmptyVector].traverse[Int])
checkAll("NonEmptyVector[Int]", ShortCircuitingTests[NonEmptyVector].nonEmptyTraverse[Int])

test("size is consistent with toList.size") {
forAll { (nonEmptyVector: NonEmptyVector[Int]) =>
Expand Down

0 comments on commit 0ade708

Please sign in to comment.