Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A few more chain optimizations #4170

Merged
merged 11 commits into from
Apr 11, 2022
6 changes: 6 additions & 0 deletions bench/src/main/scala-2.12/cats/bench/ChainBench.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,10 @@ class ChainBench {

// Builds a Chain from `intOption` by first turning it into a Seq and then
// going through `Chain.fromSeq`.
@Benchmark
def createChainSeqOption: Chain[Int] = {
  val asSeq = intOption.toSeq
  Chain.fromSeq(asSeq)
}
// Builds a Chain from `intOption` directly through `Chain.fromOption`.
@Benchmark
def createChainOption: Chain[Int] =
  Chain.fromOption(intOption)

// Reverses the `largeList` fixture; the List counterpart of `reverseLargeChain`.
@Benchmark
def reverseLargeList: List[Int] =
  largeList.reverse
// Reverses the `largeChain` fixture via `Chain#reverse`.
@Benchmark
def reverseLargeChain: Chain[Int] =
  largeChain.reverse

// Measures `length` on the `largeList` fixture; the List counterpart of `lengthLargeChain`.
@Benchmark
def lengthLargeList: Int =
  largeList.length
// Measures `Chain#length` on the `largeChain` fixture (the result is a Long, unlike List's Int).
@Benchmark
def lengthLargeChain: Long =
  largeChain.length
}
134 changes: 97 additions & 37 deletions core/src/main/scala/cats/data/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,38 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
/**
 * Reverses this `Chain`.
 *
 * An `Append` tree is walked left to right while the right-hand sub-chains still to be
 * visited are kept on an explicit list, so the traversal is tail recursive. Every leaf is
 * placed in front of the accumulator, which is what flips the overall order; a `Wrap` leaf
 * has its underlying sequence reversed before it is placed.
 *
 * @return a `Chain` with the same elements in the opposite order
 */
def reverse: Chain[A] = {
  // `h` is the sub-chain currently inspected, `tail` the sub-chains still pending and
  // `acc` the already-reversed prefix.
  @annotation.tailrec
  def loop[B <: A](h: Chain.NonEmpty[B], tail: List[Chain.NonEmpty[B]], acc: Chain[A]): Chain[A] =
    h match {
      case Append(l, r) =>
        // Descend into the left side first; the right side waits in `tail`.
        loop(l, r :: tail, acc)
      case sing @ Singleton(_) =>
        val nextAcc = sing.concat(acc)
        tail match {
          case h1 :: t1 =>
            loop(h1, t1, nextAcc)
          case _ =>
            nextAcc
        }
      case Wrap(seq) =>
        val nextAcc = Wrap(seq.reverse).concat(acc)
        tail match {
          case h1 :: t1 =>
            loop(h1, t1, nextAcc)
          case _ =>
            nextAcc
        }
    }

  this match {
    case Append(l, r) =>
      loop(l, r :: Nil, Empty)
    case Wrap(seq) => Wrap(seq.reverse)
    case _ =>
      // Empty | Singleton(_)
      this
  }
}

/**
* Yields to Some(a, Chain[A]) with `a` removed where `f` holds for the first time,
Expand Down Expand Up @@ -587,19 +617,35 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
* Returns the number of elements in this structure
*/
final def length: Long = {
  // This is an optimized (unboxed) implementation
  // of the same code as foldLeft
  //
  // `head` is the sub-chain currently inspected, `tail` holds the sub-chains still to be
  // visited and `acc` is the number of elements counted so far. Only non-empty nodes are
  // ever passed in, so no `Empty` case is needed inside the loop.
  @annotation.tailrec
  def loop(head: Chain.NonEmpty[A], tail: List[Chain.NonEmpty[A]], acc: Long): Long =
    head match {
      case Append(l, r) =>
        // Count the left side first; the right side waits in `tail`.
        loop(l, r :: tail, acc)
      case Singleton(_) =>
        val nextAcc = acc + 1L
        tail match {
          case h1 :: t1 =>
            loop(h1, t1, nextAcc)
          case _ =>
            nextAcc
        }
      case Wrap(seq) =>
        // `Seq#length` is an `Int`; widen it before adding to the `Long` total.
        val nextAcc = acc + seq.length.toLong
        tail match {
          case h1 :: t1 =>
            loop(h1, t1, nextAcc)
          case _ =>
            nextAcc
        }
    }

  this match {
    case ne: Chain.NonEmpty[A] =>
      loop(ne, Nil, 0L)
    case _ => 0L // anything that is not NonEmpty holds no elements
  }
}

/**
Expand Down Expand Up @@ -632,28 +678,37 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
* }}}
*/
final def lengthCompare(len: Long): Int = {
  // This is an optimized (unboxed) implementation
  // of the same code as foldLeft
  //
  // Inside `loop`, `len` is the number of elements still required to reach the requested
  // length once the leaves consumed so far have been subtracted. `head` is the sub-chain
  // currently inspected and `tail` holds the sub-chains still to be visited.
  @annotation.tailrec
  def loop(head: Chain.NonEmpty[A], tail: List[Chain.NonEmpty[A]], len: Long): Int =
    if (len <= 0L) 1 // head is nonempty
    else
      head match {
        case Append(l, r) => loop(l, r :: tail, len)
        case Singleton(_) =>
          tail match {
            case h1 :: t1 =>
              loop(h1, t1, len - 1L)
            case _ =>
              // Last element of the chain: exactly one element is left to compare.
              java.lang.Long.compare(1L, len)
          }
        case Wrap(seq) =>
          // `Seq#lengthCompare` takes an `Int`, and a `Seq` cannot hold more than
          // `Int.MaxValue` elements, so a larger `len` means the seq is shorter.
          val c =
            if (len <= Int.MaxValue) seq.lengthCompare(len.toInt)
            else -1
          tail match {
            case h1 :: t1 =>
              if (c >= 0) 1 // there is definitely more in tail
              else loop(h1, t1, len - seq.length)
            case _ => c
          }
      }

  this match {
    case ne: Chain.NonEmpty[A] =>
      loop(ne, Nil, len)
    case _ => java.lang.Long.compare(0L, len) // no elements: compare 0 against `len`
  }
}

Expand Down Expand Up @@ -767,18 +822,20 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {

/**
 * Sorts the elements of this `Chain` by the key that `f` extracts from each of them.
 *
 * An `Append` tree is first flattened with `toVector`; a `Wrap` sorts its underlying
 * sequence directly. In both cases the sorted sequence is re-wrapped in a `Wrap`.
 * `Empty` and `Singleton` chains are returned unchanged.
 *
 * @param f extracts the sort key from an element
 * @param B the `Order` applied to the extracted keys
 * @return a `Chain` with the same elements, ordered by their keys
 */
final def sortBy[B](f: A => B)(implicit B: Order[B]): Chain[A] =
  this match {
    case Append(_, _) => Wrap(toVector.sortBy(f)(B.toOrdering))
    case Wrap(seq) => Wrap(seq.sortBy(f)(B.toOrdering))
    case _ =>
      // Empty | Singleton(_)
      this
  }

/**
 * Sorts the elements of this `Chain` according to the given `Order`.
 *
 * An `Append` tree is first flattened with `toVector`; a `Wrap` sorts its underlying
 * sequence directly. In both cases the sorted sequence is re-wrapped in a `Wrap`.
 * `Empty` and `Singleton` chains are returned unchanged.
 *
 * @param AA the `Order` used to compare elements
 * @return a `Chain` with the same elements in sorted order
 */
final def sorted[AA >: A](implicit AA: Order[AA]): Chain[AA] =
  this match {
    case Append(_, _) => Wrap(toVector.sorted(AA.toOrdering))
    case Wrap(seq) => Wrap(seq.sorted(AA.toOrdering))
    case _ =>
      // Empty | Singleton(_)
      this
  }
}

Expand Down Expand Up @@ -1115,6 +1172,9 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
Eval.defer(loop(fa))
}

// Maps every element through `f` over the chain's iterator and hands the resulting
// iterator to the Monoid's `combineAll` in a single call.
override def foldMap[A, B](fa: Chain[A])(f: A => B)(implicit B: Monoid[B]): B = {
  val mapped = fa.iterator.map(f)
  B.combineAll(mapped)
}

// Delegates to the Chain's own `map`.
override def map[A, B](fa: Chain[A])(f: A => B): Chain[B] =
  fa.map(f)
// Delegates to the Chain's own `toList`.
override def toList[A](fa: Chain[A]): List[A] =
  fa.toList
// Delegates to the Chain's own `isEmpty`.
override def isEmpty[A](fa: Chain[A]): Boolean =
  fa.isEmpty
Expand Down
6 changes: 6 additions & 0 deletions tests/shared/src/test/scala/cats/tests/ChainSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -434,4 +434,10 @@ class ChainSuite extends CatsSuite {

assert(sumAll == chain.iterator.sum)
}

test("foldRight(b)(fn) == toList.foldRight(b)(fn)") {
  // Chain#foldRight must agree with List#foldRight over the same elements,
  // for any generated chain, seed and combining function.
  forAll { (chain: Chain[Int], init: Long, fn: (Int, Long) => Long) =>
    val viaChain = chain.foldRight(init)(fn)
    val viaList = chain.toList.foldRight(init)(fn)
    assert(viaChain == viaList)
  }
}
}