Skip to content

Commit

Permalink
Merge pull request #4694 from typelevel/oscar/20250102_drop_take_chain
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek authored Jan 18, 2025
2 parents 8ce7326 + fdcff72 commit 32a50dc
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 11 deletions.
14 changes: 8 additions & 6 deletions core/src/main/scala-2.12/cats/data/ChainCompanionCompat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,17 @@ private[data] trait ChainCompanionCompat {
}

private def fromImmutableSeq[A](s: immutable.Seq[A]): Chain[A] = {
if (s.isEmpty) nil
else if (s.lengthCompare(1) == 0) one(s.head)
else Wrap(s)
val lc = s.lengthCompare(1)
if (lc < 0) nil
else if (lc > 0) Wrap(s)
else one(s.head)
}

private def fromMutableSeq[A](s: Seq[A]): Chain[A] = {
if (s.isEmpty) nil
else if (s.lengthCompare(1) == 0) one(s.head)
else Wrap(s.toVector)
val lc = s.lengthCompare(1)
if (lc < 0) nil
else if (lc > 0) Wrap(s.toVector)
else one(s.head)
}

/**
Expand Down
10 changes: 6 additions & 4 deletions core/src/main/scala-2.13+/cats/data/ChainCompanionCompat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ private[data] trait ChainCompanionCompat {
/**
* Creates a Chain from the specified sequence.
*/
def fromSeq[A](s: Seq[A]): Chain[A] =
if (s.isEmpty) nil
else if (s.lengthCompare(1) == 0) one(s.head)
else Wrap(s)
def fromSeq[A](s: Seq[A]): Chain[A] = {
val lc = s.lengthCompare(1)
if (lc < 0) nil
else if (lc > 0) Wrap(s)
else one(s.head)
}

/**
* Creates a Chain from the specified IterableOnce.
Expand Down
195 changes: 194 additions & 1 deletion core/src/main/scala/cats/data/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,99 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
result
}

/**
* take a certain amount of items from the front of the Chain
*/
final def take(count: Long): Chain[A] = {
// invariant count >= 1
@tailrec
def go(lhs: Chain[A], count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] =
arg match {
case Wrap(seq) =>
if (count == 1) {
lhs.append(seq.head)
} else {
// count > 1
val taken =
if (count < Int.MaxValue) seq.take(count.toInt)
else seq.take(Int.MaxValue)
// we may have not taken all of count
val newCount = count - taken.length
val wrapped = Wrap(taken)
// this is more efficient than using concat
val newLhs = if (lhs.isEmpty) wrapped else Append(lhs, wrapped)
rhs match {
case rhsNE: NonEmpty[A] if newCount > 0L =>
// we have to keep taking on the rhs
go(newLhs, newCount, rhsNE, Empty)
case _ =>
newLhs
}
}
case Append(l, r) =>
go(lhs, count, l, if (rhs.isEmpty) r else Append(r, rhs))
case s @ Singleton(_) =>
// due to the invariant count >= 1
val newLhs = if (lhs.isEmpty) s else Append(lhs, s)
rhs match {
case rhsNE: NonEmpty[A] if count > 1L =>
go(newLhs, count - 1L, rhsNE, Empty)
case _ => newLhs
}
}

this match {
case ne: NonEmpty[A] if count > 0L =>
go(Empty, count, ne, Empty)
case _ => Empty
}
}

/**
* take a certain amount of items from the back of the Chain
*/
final def takeRight(count: Long): Chain[A] = {
// invariant count >= 1
@tailrec
def go(lhs: Chain[A], count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] =
arg match {
case Wrap(seq) =>
if (count == 1L) {
seq.last +: rhs
} else {
// count > 1
val taken =
if (count < Int.MaxValue) seq.takeRight(count.toInt)
else seq.takeRight(Int.MaxValue)
// we may have not taken all of count
val newCount = count - taken.length
val wrapped = Wrap(taken)
val newRhs = if (rhs.isEmpty) wrapped else Append(wrapped, rhs)
lhs match {
case lhsNE: NonEmpty[A] if newCount > 0 =>
go(Empty, newCount, lhsNE, newRhs)
case _ => newRhs
}
}
case Append(l, r) =>
go(if (lhs.isEmpty) l else Append(lhs, l), count, r, rhs)
case s @ Singleton(_) =>
// due to the invariant count >= 1
val newRhs = if (rhs.isEmpty) s else Append(s, rhs)
lhs match {
case lhsNE: NonEmpty[A] if count > 1 =>
go(Empty, count - 1, lhsNE, newRhs)
case _ => newRhs
}
}

this match {
case ne: NonEmpty[A] if count > 0L =>
go(Empty, count, ne, Empty)
case _ => Empty
}
}

/**
* Drops longest prefix of elements that satisfy a predicate.
*
Expand All @@ -275,6 +368,105 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
go(this)
}

/**
* Drop a certain amount of items from the front of the Chain
*/
final def drop(count: Long): Chain[A] = {
// invariant count >= 1
@tailrec
def go(count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] =
arg match {
case Wrap(seq) =>
val dropped = if (count < Int.MaxValue) seq.drop(count.toInt) else seq.drop(Int.MaxValue)
val lc = dropped.lengthCompare(1)
if (lc < 0) {
// if dropped.length < 1, then it is zero
// we may have not dropped all of count
val newCount = count - seq.length
rhs match {
case rhsNE: NonEmpty[A] if newCount > 0 =>
// we have to keep dropping on the rhs
go(newCount, rhsNE, Empty)
case _ =>
// we know that count >= seq.length else we wouldn't be empty
// so in this case, it is exactly count == seq.length
rhs
}
} else {
// dropped is not empty
val wrapped = if (lc > 0) Wrap(dropped) else Singleton(dropped.head)
// we must be done
if (rhs.isEmpty) wrapped else Append(wrapped, rhs)
}
case Append(l, r) =>
go(count, l, if (rhs.isEmpty) r else Append(r, rhs))
case Singleton(_) =>
// due to the invariant count >= 1
rhs match {
case rhsNE: NonEmpty[A] if count > 1L =>
go(count - 1L, rhsNE, Empty)
case _ =>
rhs
}
}

this match {
case ne: NonEmpty[A] if count > 0L =>
go(count, ne, Empty)
case _ => this
}
}

/**
* Drop a certain amount of items from the back of the Chain
*/
final def dropRight(count: Long): Chain[A] = {
// invariant count >= 1
@tailrec
def go(lhs: Chain[A], count: Long, arg: NonEmpty[A]): Chain[A] =
arg match {
case Wrap(seq) =>
val dropped = if (count < Int.MaxValue) seq.dropRight(count.toInt) else seq.dropRight(Int.MaxValue)
val lc = dropped.lengthCompare(1)
if (lc < 0) {
// if dropped.length < 1, then it is zero
// we may have not dropped all of count
val newCount = count - seq.length
lhs match {
case lhsNE: NonEmpty[A] if newCount > 0L =>
// we have to keep dropping on the lhs
go(Empty, newCount, lhsNE)
case _ =>
// we know that count >= seq.length else we wouldn't be empty
// so in this case, it is exactly count == seq.length
lhs
}
} else {
// we must be done
// note: dropped.nonEmpty
val wrapped = if (lc > 0) Wrap(dropped) else Singleton(dropped.head)
if (lhs.isEmpty) wrapped else Append(lhs, wrapped)
}
case Append(l, r) =>
go(if (lhs.isEmpty) l else Append(lhs, l), count, r)
case Singleton(_) =>
// due to the invariant count >= 1
lhs match {
case lhsNE: NonEmpty[A] if count > 1L =>
go(Empty, count - 1L, lhsNE)
case _ =>
lhs
}
}

this match {
case ne: NonEmpty[A] if count > 0L =>
go(Empty, count, ne)
case _ =>
this
}
}

/**
* Folds over the elements from right to left using the supplied initial value and function.
*/
Expand Down Expand Up @@ -940,7 +1132,8 @@ object Chain extends ChainInstances with ChainCompanionCompat {
* if the length is one, fromSeq returns Singleton
*
* The only places we create Wrap is in fromSeq and in methods that preserve
* length: zipWithIndex, map, sort
* length: zipWithIndex, map, sort. Additionally, in drop/dropRight we carefully
* preserve this invariant.
*/
final private[data] case class Wrap[A](seq: immutable.Seq[A]) extends NonEmpty[A]

Expand Down
34 changes: 34 additions & 0 deletions tests/shared/src/test/scala/cats/tests/ChainSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -448,4 +448,38 @@ class ChainSuite extends CatsSuite {
assert(chain.foldRight(init)(fn) == chain.toList.foldRight(init)(fn))
}
}

private val genChainDropTakeArgs =
Arbitrary.arbitrary[Chain[Int]].flatMap { chain =>
// Bias to values close to the length
Gen
.oneOf(
Gen.choose(Int.MinValue, Int.MaxValue),
Gen.choose(-1, chain.length.toInt + 1)
)
.map((chain, _))
}

test("drop(cnt).toList == toList.drop(cnt)") {
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
assertEquals(chain.drop(count).toList, chain.toList.drop(count))
}
}

test("dropRight(cnt).toList == toList.dropRight(cnt)") {
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
assertEquals(chain.dropRight(count).toList, chain.toList.dropRight(count))
}
}
test("take(cnt).toList == toList.take(cnt)") {
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
assertEquals(chain.take(count).toList, chain.toList.take(count))
}
}

test("takeRight(cnt).toList == toList.takeRight(cnt)") {
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
assertEquals(chain.takeRight(count).toList, chain.toList.takeRight(count))
}
}
}

0 comments on commit 32a50dc

Please sign in to comment.