Skip to content

Commit

Permalink
List(...) optimization to avoid intermediate array (#17166)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasstucki authored Nov 28, 2023
2 parents 07ff4fb + bdb89d8 commit 841bbd4
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 28 deletions.
26 changes: 15 additions & 11 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -517,14 +517,16 @@ class Definitions {
methodNames.map(getWrapVarargsArrayModule.requiredMethod(_))
})

@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
def ListType: TypeRef = ListClass.typeRef
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
def NilType: TermRef = NilModule.termRef
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
def ConsType: TypeRef = ConsClass.typeRef
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")
@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
def ListType: TypeRef = ListClass.typeRef
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
@tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply)
def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List)
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
def NilType: TermRef = NilModule.termRef
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
def ConsType: TypeRef = ConsClass.typeRef
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")

@tu lazy val SingletonClass: ClassSymbol =
// needed as a synthetic class because Scala 2.x refers to it in classfiles
Expand All @@ -534,16 +536,18 @@ class Definitions {
List(AnyType), EmptyScope)
@tu lazy val SingletonType: TypeRef = SingletonClass.typeRef

@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
@tu lazy val SeqModule_apply: Symbol = SeqModule.requiredMethod(nme.apply)
def SeqModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.Seq)
def SeqClass(using Context): ClassSymbol = SeqType.symbol.asClass
@tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply)
@tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head)
@tu lazy val Seq_drop : Symbol = SeqClass.requiredMethod(nme.drop)
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")


@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")
Expand Down
71 changes: 55 additions & 16 deletions compiler/src/dotty/tools/dotc/transform/ArrayApply.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
package dotty.tools.dotc
package dotty.tools
package dotc
package transform

import core.*
import ast.tpd
import core.*, Contexts.*, Decorators.*, Symbols.*, Flags.*, StdNames.*
import reporting.trace
import util.Property
import MegaPhase.*
import Contexts.*
import Symbols.*
import Flags.*
import StdNames.*
import dotty.tools.dotc.ast.tpd



/** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode.
*
Expand All @@ -22,27 +19,69 @@ class ArrayApply extends MiniPhase {

override def description: String = ArrayApply.description

override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
private val TransformListApplyBudgetKey = new Property.Key[Int]
private def transformListApplyBudget(using Context) =
ctx.property(TransformListApplyBudgetKey).getOrElse(8) // default is 8, as originally implemented in nsc

override def prepareForApply(tree: Apply)(using Context): Context = tree match
case SeqApplyArgs(elems) =>
ctx.fresh.setProperty(TransformListApplyBudgetKey, transformListApplyBudget - elems.length)
case _ => ctx

override def transformApply(tree: Apply)(using Context): Tree =
if isArrayModuleApply(tree.symbol) then
tree.args match {
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil
tree.args match
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: ct :: Nil
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) =>
seqLit

case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil
case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: Nil
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) =>
tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)
JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)

case _ =>
tree
}

else tree
else tree match
case SeqApplyArgs(elems) if transformListApplyBudget > 0 || elems.isEmpty =>
val consed = elems.foldRight(ref(defn.NilModule)): (elem, acc) =>
New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc))
consed.cast(tree.tpe)
case _ => tree

private def isArrayModuleApply(sym: Symbol)(using Context): Boolean =
sym.name == nme.apply
&& (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension)))

private def isListApply(tree: Tree)(using Context): Boolean =
(tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match
case Select(qual, _) =>
val sym = qual.symbol
sym == defn.ListModule
|| sym == defn.ListModuleAlias
case _ => false

private def isSeqApply(tree: Tree)(using Context): Boolean =
isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match
case Select(qual, _) =>
val sym = qual.symbol
sym == defn.SeqModule
|| sym == defn.SeqModuleAlias
|| sym == defn.CollectionSeqType.symbol.companionModule
case _ => false

private object SeqApplyArgs:
def unapply(tree: Apply)(using Context): Option[List[Tree]] =
if isSeqApply(tree) then
tree.args match
// <List or Seq>(a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil
if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) =>
Some(rest.elems)
case _ => None
else None


/** Only optimize when classtag if it is one of
* - `ClassTag.apply(classOf[XYZ])`
* - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ``
Expand Down
152 changes: 151 additions & 1 deletion compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package dotty.tools.backend.jvm
package dotty.tools
package backend.jvm

import org.junit.Test
import org.junit.Assert._
Expand Down Expand Up @@ -160,4 +161,153 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
}
}

@Test def testListApplyAvoidsIntermediateArray = {
checkApplyAvoidsIntermediateArray("List"):
"""import scala.collection.immutable.{ ::, Nil }
|class Foo {
| def meth1: List[String] = List("1", "2", "3")
| def meth2: List[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray = {
checkApplyAvoidsIntermediateArray("Seq"):
"""import scala.collection.immutable.{ ::, Nil }
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray2 = {
checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"):
"""import scala.collection.immutable.{ ::, Seq, Nil }
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray3 = {
checkApplyAvoidsIntermediateArray("scala.collection.Seq"):
"""import scala.collection.immutable.{ ::, Nil }, scala.collection.Seq
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_max1 = {
checkApplyAvoidsIntermediateArray_examples("max1"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", "7")
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::("7", Nil)))))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_max2 = {
checkApplyAvoidsIntermediateArray_examples("max2"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", List[Object]())
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::(Nil, Nil)))))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_max3 = {
checkApplyAvoidsIntermediateArray_examples("max3"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", List[Object]("6"))
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::(new ::("6", Nil), Nil))))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_max4 = {
checkApplyAvoidsIntermediateArray_examples("max4"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", List[Object]("5", "6"))
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::(new ::("5", new ::("6", Nil)), Nil)))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over1 = {
checkApplyAvoidsIntermediateArray_examples("over1"):
""" def meth1: List[Object] = List("1", "2", "3", "4", "5", "6", "7", "8")
| def meth2: List[Object] = List(wrapRefArray(Array("1", "2", "3", "4", "5", "6", "7", "8"))*)
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over2 = {
checkApplyAvoidsIntermediateArray_examples("over2"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", "7", List[Object]())
| def meth2: List[Object] = List(wrapRefArray(Array[Object]("1", "2", "3", "4", "5", "6", "7", Nil))*)
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over3 = {
checkApplyAvoidsIntermediateArray_examples("over3"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", List[Object]("7"))
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::(List(wrapRefArray(Array[Object]("7"))*), Nil)))))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over4 = {
checkApplyAvoidsIntermediateArray_examples("over4"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", List[Object]("6", "7"))
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::(List(wrapRefArray(Array[Object]("6", "7"))*), Nil))))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_max5 = {
checkApplyAvoidsIntermediateArray_examples("max5"):
""" def meth1: List[Object] = List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object]())))))))
| def meth2: List[Object] = new ::(new ::(new ::(new ::(new ::(new ::(new ::(Nil, Nil), Nil), Nil), Nil), Nil), Nil), Nil)
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over5 = {
checkApplyAvoidsIntermediateArray_examples("over5"):
""" def meth1: List[Object] = List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object]()))))))))
| def meth2: List[Object] = new ::(new ::(new ::(new ::(new ::(new ::(new ::(List[Object](wrapRefArray(Array[Object](Nil))*), Nil), Nil), Nil), Nil), Nil), Nil), Nil)
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_max6 = {
checkApplyAvoidsIntermediateArray_examples("max6"):
""" def meth1: List[Object] = List[Object]("1", "2", List[Object]("3", "4", List[Object](List[Object]())))
| def meth2: List[Object] = new ::("1", new ::("2", new ::(new ::("3", new ::("4", new ::(new ::(Nil, Nil), Nil))), Nil)))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over6 = {
checkApplyAvoidsIntermediateArray_examples("over6"):
""" def meth1: List[Object] = List[Object]("1", "2", List[Object]("3", "4", List[Object]("5")))
| def meth2: List[Object] = new ::("1", new ::("2", new ::(new ::("3", new ::("4", new ::(new ::("5", Nil), Nil))), Nil)))
""".stripMargin
}

def checkApplyAvoidsIntermediateArray_examples(name: String)(body: String): Unit = {
checkApplyAvoidsIntermediateArray(s"List_$name"):
s"""import scala.collection.immutable.{ ::, Nil }, scala.runtime.ScalaRunTime.wrapRefArray
|class Foo {
|$body
|}
""".stripMargin
}

def checkApplyAvoidsIntermediateArray(name: String)(source: String): Unit = {
checkBCode(source) { dir =>
val clsIn = dir.lookupName("Foo.class", directory = false).input
val clsNode = loadClassNode(clsIn)
val meth1 = getMethod(clsNode, "meth1")
val meth2 = getMethod(clsNode, "meth2")

val instructions1 = instructionsFromMethod(meth1).filter { case TypeOp(CHECKCAST, _) => false case _ => true }
val instructions2 = instructionsFromMethod(meth2).filter { case TypeOp(CHECKCAST, _) => false case _ => true }

assert(instructions1 == instructions2,
s"the $name.apply method\n" +
diffInstructions(instructions1, instructions2))
}
}

}
88 changes: 88 additions & 0 deletions tests/run/list-apply-eval.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
object Test:

var counter = 0

def next =
counter += 1
counter.toString

def main(args: Array[String]): Unit =
//List.apply is subject to an optimisation in cleanup
//ensure that the arguments are evaluated in the currect order
// Rewritten to:
// val myList: List = new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), scala.collection.immutable.Nil)));
val myList = List(next, next, next)
assert(myList == List("1", "2", "3"), myList)

val mySeq = Seq(next, next, next)
assert(mySeq == Seq("4", "5", "6"), mySeq)

val emptyList = List[Int]()
assert(emptyList == Nil)

// just assert it doesn't throw CCE to List
val queue = scala.collection.mutable.Queue[String]()

// test for the cast instruction described in checkApplyAvoidsIntermediateArray
def lub(b: Boolean): List[(String, String)] =
if b then List(("foo", "bar")) else Nil

// from minimising CI failure in oslib
// again, the lub of :: and Nil is Product, which breaks ++ (which requires IterableOnce)
def lub2(b: Boolean): Unit =
Seq(1) ++ (if (b) Seq(2) else Nil)

// Examples of arity and nesting arity
// to find the thresholds and reproduce the behaviour of nsc
def examples(): Unit =
val max1 = List[Object]("1", "2", "3", "4", "5", "6", "7") // 7 cons w/ 7 string heads + nil
val max2 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]()) // 7 cons w/ 6 string heads + 1 nil head + nil
val max3 = List[Object]("1", "2", "3", "4", "5", List[Object]("6"))
val max4 = List[Object]("1", "2", "3", "4", List[Object]("5", "6"))

val over1 = List[Object]("1", "2", "3", "4", "5", "6", "7", "8") // wrap 8-sized array
val over2 = List[Object]("1", "2", "3", "4", "5", "6", "7", List[Object]()) // wrap 8-sized array
val over3 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]("7")) // wrap 1-sized array with 7
val over4 = List[Object]("1", "2", "3", "4", "5", List[Object]("6", "7")) // wrap 2

val max5 =
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
)))))))) // 7 cons + 1 nil

val over5 =
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object]( List[Object]()
)))))))) // 7 cons + 1-sized array wrapping nil

val max6 =
List[Object]( // ::(
"1", "2", List[Object]( // 1, ::(2, ::(::(
"3", "4", List[Object]( // 3, ::(4, ::(::(
List[Object]() // Nil, Nil
) // ), Nil))
) // ), Nil))
) // )
// 7 cons + 4 string heads + 4 nils for nested lists

val max7 =
List[Object]( // ::(
"1", "2", List[Object]( // 1, ::(2, ::(::(
"3", "4", List[Object]( // 3, ::(4, ::(::(
"5" // 5, Nil
) // ), Nil))
) // ), Nil))
) // )
// 7 cons + 5 string heads + 3 nils for nested lists

0 comments on commit 841bbd4

Please sign in to comment.