diff --git a/core/src/main/scala/cats/syntax/seq.scala b/core/src/main/scala/cats/syntax/seq.scala index fbcc04f34f6..75a111b94c3 100644 --- a/core/src/main/scala/cats/syntax/seq.scala +++ b/core/src/main/scala/cats/syntax/seq.scala @@ -1,7 +1,10 @@ package cats.syntax +import cats.Order import cats.data.NonEmptySeq + import scala.collection.immutable.Seq +import scala.collection.immutable.SortedMap trait SeqSyntax { implicit final def catsSyntaxSeqs[A](va: Seq[A]): SeqOps[A] = new SeqOps(va) @@ -43,4 +46,27 @@ final class SeqOps[A](private val va: Seq[A]) extends AnyVal { * }}} */ def concatNeSeq[AA >: A](neseq: NonEmptySeq[AA]): NonEmptySeq[AA] = neseq.prependSeq(va) + + /** + * Groups elements inside this `Seq` according to the `Order` of the keys + * produced by the given mapping function. + * + * {{{ + * scala> import cats.data.NonEmptySeq + * scala> import cats.syntax.all._ + * scala> import scala.collection.immutable.Seq + * scala> import scala.collection.immutable.SortedMap + * + * scala> val seq = Seq(12, -2, 3, -5) + * + * scala> val expectedResult = SortedMap(false -> NonEmptySeq.of(-2, -5), true -> NonEmptySeq.of(12, 3)) + * + * scala> seq.groupByNeSeq(_ >= 0) === expectedResult + * res0: Boolean = true + * }}} + */ + def groupByNeSeq[B](f: A => B)(implicit B: Order[B]): SortedMap[B, NonEmptySeq[A]] = { + implicit val ordering: Ordering[B] = B.toOrdering + toNeSeq.fold(SortedMap.empty[B, NonEmptySeq[A]])(_.groupBy(f)) + } } diff --git a/tests/src/test/scala/cats/tests/SeqSuite.scala b/tests/src/test/scala/cats/tests/SeqSuite.scala index 6f6cffb82c2..1f491254fbf 100644 --- a/tests/src/test/scala/cats/tests/SeqSuite.scala +++ b/tests/src/test/scala/cats/tests/SeqSuite.scala @@ -15,10 +15,11 @@ import cats.laws.discipline.{ TraverseTests } import cats.laws.discipline.arbitrary._ -import cats.syntax.show._ -import cats.syntax.seq._ import cats.syntax.eq._ +import cats.syntax.seq._ +import cats.syntax.show._ import org.scalacheck.Prop._ + import scala.collection.immutable.Seq class SeqSuite extends CatsSuite { @@ -71,6 +72,12 @@ class SeqSuite extends CatsSuite { } } + test("groupByNeSeq should be consistent with groupBy")( + forAll { (fa: Seq[Int], f: Int => Int) => + assert((fa.groupByNeSeq(f).map { case (k, v) => (k, v.toSeq) }: Map[Int, Seq[Int]]) === fa.groupBy(f)) + } + ) + test("traverse is stack-safe") { val seq = (0 until 100000).toSeq val sumAll = Traverse[Seq]