diff --git a/bench/src/main/scala-2.12/cats/bench/ChainBench.scala b/bench/src/main/scala-2.12/cats/bench/ChainBench.scala index d0019dd16a..81c77916e9 100644 --- a/bench/src/main/scala-2.12/cats/bench/ChainBench.scala +++ b/bench/src/main/scala-2.12/cats/bench/ChainBench.scala @@ -88,4 +88,10 @@ class ChainBench { @Benchmark def createChainSeqOption: Chain[Int] = Chain.fromSeq(intOption.toSeq) @Benchmark def createChainOption: Chain[Int] = Chain.fromOption(intOption) + + @Benchmark def reverseLargeList: List[Int] = largeList.reverse + @Benchmark def reverseLargeChain: Chain[Int] = largeChain.reverse + + @Benchmark def lengthLargeList: Int = largeList.length + @Benchmark def lengthLargeChain: Long = largeChain.length } diff --git a/core/src/main/scala/cats/data/Chain.scala b/core/src/main/scala/cats/data/Chain.scala index 669364cf4e..9453f7b729 100644 --- a/core/src/main/scala/cats/data/Chain.scala +++ b/core/src/main/scala/cats/data/Chain.scala @@ -503,8 +503,38 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { /** * Reverses this `Chain` */ - def reverse: Chain[A] = - fromSeq(reverseIterator.toVector) + def reverse: Chain[A] = { + @annotation.tailrec + def loop[B <: A](h: Chain.NonEmpty[B], tail: List[Chain.NonEmpty[B]], acc: Chain[A]): Chain[A] = + h match { + case Append(l, r) => loop(l, r :: tail, acc) + case sing @ Singleton(_) => + val nextAcc = sing.concat(acc) + tail match { + case h1 :: t1 => + loop(h1, t1, nextAcc) + case _ => + nextAcc + } + case Wrap(seq) => + val nextAcc = Wrap(seq.reverse).concat(acc) + tail match { + case h1 :: t1 => + loop(h1, t1, nextAcc) + case _ => + nextAcc + } + } + + this match { + case Append(l, r) => + loop(l, r :: Nil, Empty) + case Wrap(seq) => Wrap(seq.reverse) + case _ => + // Empty | Singleton(_) + this + } + } /** * Yields to Some(a, Chain[A]) with `a` removed where `f` holds for the first time, @@ -587,19 +617,35 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { * Returns the number of elements in this structure */ final def length: Long = { + // This is an optimized (unboxed) implementation + // of the same code as foldLeft @annotation.tailrec - def loop(chains: List[Chain[A]], acc: Long): Long = - chains match { - case Nil => acc - case h :: tail => - h match { - case Empty => loop(tail, acc) - case Wrap(seq) => loop(tail, acc + seq.length) - case Singleton(a) => loop(tail, acc + 1) - case Append(l, r) => loop(l :: r :: tail, acc) + def loop(head: Chain.NonEmpty[A], tail: List[Chain.NonEmpty[A]], acc: Long): Long = + head match { + case Append(l, r) => loop(l, r :: tail, acc) + case Singleton(_) => + val nextAcc = acc + 1L + tail match { + case h1 :: t1 => + loop(h1, t1, nextAcc) + case _ => + nextAcc + } + case Wrap(seq) => + val nextAcc = acc + seq.length.toLong + tail match { + case h1 :: t1 => + loop(h1, t1, nextAcc) + case _ => + nextAcc } } - loop(this :: Nil, 0L) + + this match { + case ne: Chain.NonEmpty[A] => + loop(ne, Nil, 0L) + case _ => 0L + } } /** @@ -632,28 +678,37 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { * }}} */ final def lengthCompare(len: Long): Int = { - import java.lang.Long - this match { - // `isEmpty` check should be faster than `== Chain.Empty`, - // but the compiler fails to prove that the match is still exhaustive. - case _ if isEmpty => Long.compare(0L, len) - case Chain.Singleton(_) => Long.compare(1L, len) - case _ if len < 2 => 1 // the following cases should always have `length >= 2` - case Chain.Wrap(seq) => - if (len > Int.MaxValue) -1 // `Seq#length` has `Int` type so cannot be `> Int.MaxValue` - else - seq.lengthCompare(len.toInt) - case _ => // should always be `Chain.Append` (i.e. `NonEmpty` with 2+ elements) - var sz = 2L - val it = new ChainIterator(this) - it.next() - it.next() - while (it.hasNext) { - if (sz == len) return 1 - it.next() - sz += 1L + // This is an optimized (unboxed) implementation + // of the same code as foldLeft + @annotation.tailrec + def loop(head: Chain.NonEmpty[A], tail: List[Chain.NonEmpty[A]], len: Long): Int = + if (len <= 0L) 1 // head is nonempty + else + head match { + case Append(l, r) => loop(l, r :: tail, len) + case Singleton(_) => + tail match { + case h1 :: t1 => + loop(h1, t1, len - 1L) + case _ => + java.lang.Long.compare(1L, len) + } + case Wrap(seq) => + val c = + if (len <= Int.MaxValue) seq.lengthCompare(len.toInt) + else -1 + tail match { + case h1 :: t1 => + if (c >= 0) 1 // there is definitely more in tail + else loop(h1, t1, len - seq.length) + case _ => c + } } - Long.compare(sz, len) + + this match { + case ne: Chain.NonEmpty[A] => + loop(ne, Nil, len) + case _ => java.lang.Long.compare(0L, len) } } @@ -767,18 +822,20 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { final def sortBy[B](f: A => B)(implicit B: Order[B]): Chain[A] = this match { - case Singleton(_) => this case Append(_, _) => Wrap(toVector.sortBy(f)(B.toOrdering)) case Wrap(seq) => Wrap(seq.sortBy(f)(B.toOrdering)) - case _ => this + case _ => + // Empty | Singleton(_) + this } final def sorted[AA >: A](implicit AA: Order[AA]): Chain[AA] = this match { - case Singleton(_) => this case Append(_, _) => Wrap(toVector.sorted(AA.toOrdering)) case Wrap(seq) => Wrap(seq.sorted(AA.toOrdering)) - case _ => this + case _ => + // Empty | Singleton(_) + this } } @@ -1115,6 +1172,9 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 { Eval.defer(loop(fa)) } + override def foldMap[A, B](fa: Chain[A])(f: A => B)(implicit B: Monoid[B]): B = + B.combineAll(fa.iterator.map(f)) + override def map[A, B](fa: Chain[A])(f: A => B): Chain[B] = fa.map(f) override def toList[A](fa: Chain[A]): List[A] = fa.toList override def isEmpty[A](fa: Chain[A]): Boolean = fa.isEmpty diff --git a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala index 099f2bd246..f568639e48 100644 --- a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala @@ -434,4 +434,10 @@ class ChainSuite extends CatsSuite { assert(sumAll == chain.iterator.sum) } + + test("foldRight(b)(fn) == toList.foldRight(b)(fn)") { + forAll { (chain: Chain[Int], init: Long, fn: (Int, Long) => Long) => + assert(chain.foldRight(init)(fn) == chain.toList.foldRight(init)(fn)) + } + } }