diff --git a/src/main/scala/scala/collection/next/NextIterableOpsOps.scala b/src/main/scala/scala/collection/next/NextIterableOpsOps.scala new file mode 100644 index 0000000..e516588 --- /dev/null +++ b/src/main/scala/scala/collection/next/NextIterableOpsOps.scala @@ -0,0 +1,51 @@ +/* + * Scala (https://www.scala-lang.org) + * + * Copyright Skylight IPV Ltd. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ +package scala.collection +package next + +private[next] final class NextIterableOpsOps[A, CC[_], C]( + private val coll: IterableOps[A, CC, C] +) extends AnyVal { + /** + * Partitions this $coll into a map of ${coll}s according to a discriminator function `key`. + * Each element in a group is transformed into a collection of type `B` using the `value` function. + * + * It is equivalent to `groupBy(key).mapValues(_.flatMap(f))`, but more efficient. + * + * {{{ + * case class User(name: String, age: Int, pets: Seq[String]) + * + * def petsByAge(users: Seq[User]): Map[Int, Seq[String]] = + * users.groupFlatMap(_.age)(_.pets) + * }}} + * + * $willForceEvaluation + * + * @param key the discriminator function + * @param f the element transformation function + * @tparam K the type of keys returned by the discriminator function + * @tparam B the type of values returned by the transformation function + */ + def groupFlatMap[K, B](key: A => K)(f: A => IterableOnce[B]): immutable.Map[K, CC[B]] = { + val m = mutable.Map.empty[K, mutable.Builder[B, CC[B]]] + coll.foreach { elem => + val k = key(elem) + val b = m.getOrElseUpdate(k, coll.iterableFactory.newBuilder[B]) + b ++= f(elem) + } + var result = immutable.Map.empty[K, CC[B]] + m.foreach { case (k, b) => + result = result + ((k, b.result())) + } + result + } +} diff --git a/src/main/scala/scala/collection/next/package.scala b/src/main/scala/scala/collection/next/package.scala index db9e1a6..cd16757 100644 --- a/src/main/scala/scala/collection/next/package.scala +++ b/src/main/scala/scala/collection/next/package.scala @@ -19,4 +19,9 @@ package object next { col: IterableOnceOps[A, CC, C] ): NextIterableOnceOpsExtensions[A, CC, C] = new NextIterableOnceOpsExtensions(col) + + implicit final def scalaNextSyntaxForIterableOps[A, CC[_], C]( + coll: IterableOps[A, CC, C] + ): NextIterableOpsOps[A, CC, C] = + new NextIterableOpsOps(coll) } diff --git a/src/test/scala/scala/collection/next/TestIterableOpsExtensions.scala b/src/test/scala/scala/collection/next/TestIterableOpsExtensions.scala new file mode 100644 index 0000000..2b14c6e --- /dev/null +++ b/src/test/scala/scala/collection/next/TestIterableOpsExtensions.scala @@ -0,0 +1,57 @@ +/* + * Scala (https://www.scala-lang.org) + * + * Copyright Skylight IPV Ltd. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ +package scala.collection.next + +import org.junit.Assert._ +import org.junit.Test +import scala.collection.IterableOps +import scala.collection.generic.IsIterable + +final class TestIterableOpsExtensions { + import TestIterableOpsExtensions.LowerCaseString + + @Test + def iterableOpsGroupFlatMap(): Unit = { + def groupedStrings[A, CC[_], C](coll: IterableOps[A, CC, C]): Map[A, CC[A]] = + coll.groupFlatMap(identity)(Seq(_)) + + val xs = Seq('a', 'b', 'c', 'b', 'c', 'c') + val expected = Map('a' -> Seq('a'), 'b' -> Seq('b', 'b'), 'c' -> Seq('c', 'c', 'c')) + assertEquals(expected, groupedStrings(xs)) + } + + @Test + def anyLikeIterableGroupFlatMap(): Unit = { + def groupedStrings[Repr](coll: Repr)(implicit it: IsIterable[Repr]): Map[it.A, Iterable[it.A]] = + it(coll).groupFlatMap(identity)(Seq(_)) + + val xs = "abbcaaab" + val expected = Map('a' -> Seq('a', 'a', 'a', 'a'), 'b' -> Seq('b', 'b', 'b'), 'c' -> Seq('c')) + assertEquals(expected, groupedStrings(xs)) + } + + @Test + def customIterableOnceOpsGroupMapReduce(): Unit = { + def groupedStrings(coll: LowerCaseString): Map[Char, Iterable[Char]] = + coll.groupFlatMap(identity)(Seq(_)) + + val xs = LowerCaseString("abBcAaAb") + val expected = Map('a' -> Seq('a', 'a', 'a', 'a'), 'b' -> Seq('b', 'b', 'b'), 'c' -> Seq('c')) + assertEquals(expected, groupedStrings(xs)) + } +} + +object TestIterableOpsExtensions { + final case class LowerCaseString(source: String) extends Iterable[Char] { + override def iterator: Iterator[Char] = source.iterator.map(_.toLower) + } +}