Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IterableOps.groupFlatMap #136

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/main/scala/scala/collection/next/NextIterableOpsOps.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Scala (https://www.scala-lang.org)
*
* Copyright Skylight IPV Ltd.
*
* Licensed under Apache License 2.0
* (http://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/
package scala.collection
package next

private[next] final class NextIterableOpsOps[A, CC[_], C](
private val coll: IterableOps[A, CC, C]
) extends AnyVal {
/**
* Partitions this $coll into a map of ${coll}s according to a discriminator function `key`.
* Each element in a group is transformed into a collection of type `B` using the `value` function.
*
* It is equivalent to `groupBy(key).mapValues(_.flatMap(f))`, but more efficient.
*
* {{{
* case class User(name: String, age: Int, pets: Seq[String])
*
* def petsByAge(users: Seq[User]): Map[Int, Seq[String]] =
* users.groupFlatMap(_.age)(_.pets)
* }}}
*
* $willForceEvaluation
*
* @param key the discriminator function
* @param f the element transformation function
* @tparam K the type of keys returned by the discriminator function
* @tparam B the type of values returned by the transformation function
*/
def groupFlatMap[K, B](key: A => K)(f: A => IterableOnce[B]): immutable.Map[K, CC[B]] = {
val m = mutable.Map.empty[K, mutable.Builder[B, CC[B]]]
coll.foreach { elem =>
val k = key(elem)
val b = m.getOrElseUpdate(k, coll.iterableFactory.newBuilder[B])
b ++= f(elem)
}
var result = immutable.Map.empty[K, CC[B]]
m.foreach { case (k, b) =>
result = result + ((k, b.result()))
}
result
}
}
5 changes: 5 additions & 0 deletions src/main/scala/scala/collection/next/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,9 @@ package object next {
col: IterableOnceOps[A, CC, C]
): NextIterableOnceOpsExtensions[A, CC, C] =
new NextIterableOnceOpsExtensions(col)

implicit final def scalaNextSyntaxForIterableOps[A, CC[_], C](
coll: IterableOps[A, CC, C]
): NextIterableOpsOps[A, CC, C] =
new NextIterableOpsOps(coll)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Scala (https://www.scala-lang.org)
*
* Copyright Skylight IPV Ltd.
*
* Licensed under Apache License 2.0
* (http://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/
package scala.collection.next

import org.junit.Assert._
import org.junit.Test
import scala.collection.IterableOps
import scala.collection.generic.IsIterable

final class TestIterableOpsExtensions {
import TestIterableOpsExtensions.LowerCaseString

@Test
def iterableOpsGroupFlatMap(): Unit = {
def groupedStrings[A, CC[_], C](coll: IterableOps[A, CC, C]): Map[A, CC[A]] =
coll.groupFlatMap(identity)(Seq(_))

val xs = Seq('a', 'b', 'c', 'b', 'c', 'c')
val expected = Map('a' -> Seq('a'), 'b' -> Seq('b', 'b'), 'c' -> Seq('c', 'c', 'c'))
assertEquals(expected, groupedStrings(xs))
}

@Test
def anyLikeIterableGroupFlatMap(): Unit = {
def groupedStrings[Repr](coll: Repr)(implicit it: IsIterable[Repr]): Map[it.A, Iterable[it.A]] =
it(coll).groupFlatMap(identity)(Seq(_))

val xs = "abbcaaab"
val expected = Map('a' -> Seq('a', 'a', 'a', 'a'), 'b' -> Seq('b', 'b', 'b'), 'c' -> Seq('c'))
assertEquals(expected, groupedStrings(xs))
}

@Test
def customIterableOnceOpsGroupMapReduce(): Unit = {
def groupedStrings(coll: LowerCaseString): Map[Char, Iterable[Char]] =
coll.groupFlatMap(identity)(Seq(_))

val xs = LowerCaseString("abBcAaAb")
val expected = Map('a' -> Seq('a', 'a', 'a', 'a'), 'b' -> Seq('b', 'b', 'b'), 'c' -> Seq('c'))
assertEquals(expected, groupedStrings(xs))
}
}

object TestIterableOpsExtensions {
final case class LowerCaseString(source: String) extends Iterable[Char] {
override def iterator: Iterator[Char] = source.iterator.map(_.toLower)
}
}