Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some fixes for AnnotatedTypes mapping #19957

Merged
merged 4 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package dotty.tools.benchmarks

import org.openjdk.jmh.annotations.{Benchmark, BenchmarkMode, Fork, Level, Measurement, Mode as JMHMode, Param, Scope, Setup, State, Warmup}
import java.util.concurrent.TimeUnit.SECONDS

import dotty.tools.dotc.{Driver, Run, Compiler}
import dotty.tools.dotc.ast.{tpd, TreeTypeMap}, tpd.{Apply, Block, Tree, TreeAccumulator, TypeApply}
import dotty.tools.dotc.core.Annotations.{Annotation, ConcreteAnnotation, EmptyAnnotation}
import dotty.tools.dotc.core.Contexts.{ContextBase, Context, ctx, withMode}
import dotty.tools.dotc.core.Mode
import dotty.tools.dotc.core.Phases.Phase
import dotty.tools.dotc.core.Symbols.{defn, mapSymbols, Symbol}
import dotty.tools.dotc.core.Types.{AnnotatedType, NoType, SkolemType, TermRef, Type, TypeMap}
import dotty.tools.dotc.parsing.Parser
import dotty.tools.dotc.typer.TyperPhase

/** Measures the performance of mapping over annotated types.
*
* Run with: scala3-bench-micro / Jmh / run AnnotationsMappingBenchmark
*/
@Fork(value = 4)
@Warmup(iterations = 4, time = 1, timeUnit = SECONDS)
@Measurement(iterations = 4, time = 1, timeUnit = SECONDS)
@BenchmarkMode(Array(JMHMode.Throughput))
@State(Scope.Thread)
class AnnotationsMappingBenchmark:
var tp: Type = null
var specialIntTp: Type = null
var context: Context = null
var typeFunction: Context ?=> Type => Type = null
var typeMap: TypeMap = null

@Param(Array("v1", "v2", "v3", "v4"))
var valName: String = null

@Param(Array("id", "mapInts"))
var typeFunctionName: String = null

@Setup(Level.Iteration)
def setup(): Unit =
val testPhase =
new Phase:
final override def phaseName = "testPhase"
final override def run(using ctx: Context): Unit =
val pkg = ctx.compilationUnit.tpdTree.symbol
tp = pkg.requiredClass("Test").requiredValueRef(valName).underlying
specialIntTp = pkg.requiredClass("Test").requiredType("SpecialInt").typeRef
context = ctx

val compiler =
new Compiler:
private final val baseCompiler = new Compiler()
final override def phases = List(List(Parser()), List(TyperPhase()), List(testPhase))

val driver =
new Driver:
final override def newCompiler(using Context): Compiler = compiler

driver.process(Array("-classpath", System.getProperty("BENCH_CLASS_PATH"), "tests/someAnnotatedTypes.scala"))

typeFunction =
typeFunctionName match
case "id" => tp => tp
case "mapInts" => tp => (if tp frozen_=:= defn.IntType then specialIntTp else tp)
case _ => throw new IllegalArgumentException(s"Unknown type function: $typeFunctionName")

typeMap =
new TypeMap(using context):
final override def apply(tp: Type): Type = typeFunction(mapOver(tp))

@Benchmark def applyTypeMap() = typeMap.apply(tp)
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -16,12 +17,12 @@ import java.util.concurrent.{Executors, ExecutorService}
class ContendedInitialization {

@Param(Array("2000000", "5000000"))
var size: Int = _
var size: Int = uninitialized

@Param(Array("2", "4", "8"))
var nThreads: Int = _
var nThreads: Int = uninitialized

var executor: ExecutorService = _
var executor: ExecutorService = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccess {

var holder: LazyHolder = _
var holder: LazyHolder = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyAnyHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccessAny {

var holder: LazyAnyHolder = _
var holder: LazyAnyHolder = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyGenericHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccessGeneric {

var holder: LazyGenericHolder[String] = _
var holder: LazyGenericHolder[String] = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations.*
import org.openjdk.jmh.infra.Blackhole
import LazyVals.LazyIntHolder
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccessInt {

var holder: LazyIntHolder = _
var holder: LazyIntHolder = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccessMultiple {

var holders: Array[LazyHolder] = _
var holders: Array[LazyHolder] = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyStringHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccessString {

var holder: LazyStringHolder = _
var holder: LazyStringHolder = uninitialized

@Setup
def prepare: Unit = {
Expand Down
28 changes: 28 additions & 0 deletions bench-micro/tests/someAnnotatedTypes.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
class Test:
class FlagAnnot extends annotation.StaticAnnotation
class StringAnnot(val s: String) extends annotation.StaticAnnotation
class LambdaAnnot(val f: Int => Boolean) extends annotation.StaticAnnotation

type SpecialInt <: Int

val v1: Int @FlagAnnot = 42

val v2: Int @StringAnnot("hello") = 42

val v3: Int @LambdaAnnot(it => it == 42) = 42

val v4: Int @LambdaAnnot(it => {
def g(x: Int, y: Int) = x - y + 5
g(it, 7) * 2 == 80
}) = 42

/*val v5: Int @LambdaAnnot(it => {
class Foo(x: Int):
def xPlus10 = x + 10
def xPlus20 = x + 20
def xPlus(y: Int) = x + y
val foo = Foo(it)
foo.xPlus10 - foo.xPlus20 + foo.xPlus(30) == 62
}) = 42*/

def main(args: Array[String]): Unit = ???
10 changes: 9 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,17 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
loop(tree, Nil)

/** All term arguments of an application in a single flattened list */
def allTermArguments(tree: Tree): List[Tree] = unsplice(tree) match {
case Apply(fn, args) => allArguments(fn) ::: args
case TypeApply(fn, args) => allArguments(fn)
case Block(_, expr) => allArguments(expr)
case _ => Nil
}

/** All type and term arguments of an application in a single flattened list */
mbovel marked this conversation as resolved.
Show resolved Hide resolved
def allArguments(tree: Tree): List[Tree] = unsplice(tree) match {
case Apply(fn, args) => allArguments(fn) ::: args
case TypeApply(fn, _) => allArguments(fn)
case TypeApply(fn, args) => allArguments(fn) ::: args
case Block(_, expr) => allArguments(expr)
case _ => Nil
}
Expand Down
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ object Annotations {
def derivedAnnotation(tree: Tree)(using Context): Annotation =
if (tree eq this.tree) this else Annotation(tree)

/** All arguments to this annotation in a single flat list */
def arguments(using Context): List[Tree] = tpd.allArguments(tree)
/** All term arguments of this annotation in a single flat list */
def arguments(using Context): List[Tree] = tpd.allTermArguments(tree)

def argument(i: Int)(using Context): Option[Tree] = {
val args = arguments
Expand All @@ -54,15 +54,18 @@ object Annotations {
* type, since ranges cannot be types of trees.
*/
def mapWith(tm: TypeMap)(using Context) =
val args = arguments
val args = tpd.allArguments(tree)
if args.isEmpty then this
else
// Checks if `tm` would result in any change by applying it to types
// inside the annotations' arguments and checking if the resulting types
// are different.
val findDiff = new TreeAccumulator[Type]:
def apply(x: Type, tree: Tree)(using Context): Type =
if tm.isRange(x) then x
else
val tp1 = tm(tree.tpe)
foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree)
foldOver(if !tp1.exists || (tp1 frozen_=:= tree.tpe) then x else tp1, tree)
val diff = findDiff(NoType, args)
if tm.isRange(diff) then EmptyAnnotation
else if diff.exists then derivedAnnotation(tm.mapOver(tree))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ object PositionPickler:
pickler: TastyPickler,
addrOfTree: TreeToAddr,
treeAnnots: untpd.MemberDef => List[tpd.Tree],
typeAnnots: List[tpd.Tree],
relativePathReference: String,
source: SourceFile,
roots: List[Tree],
Expand Down Expand Up @@ -136,6 +137,9 @@ object PositionPickler:
}
for (root <- roots)
traverse(root, NoSource)

for annotTree <- typeAnnots do
traverse(annotTree, NoSource)
end picklePositions
end PositionPickler

7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
*/
private val annotTrees = util.EqHashMap[untpd.MemberDef, mutable.ListBuffer[Tree]]()

/** A set of annotation trees appearing in annotated types.
*/
private val annotatedTypeTrees = mutable.ListBuffer[Tree]()

/** A map from member definitions to their doc comments, so that later
* parallel comment pickling does not need to access symbols of trees (which
* would involve accessing symbols of named types and possibly changing phases
Expand All @@ -57,6 +61,8 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
val ts = annotTrees.lookup(tree)
if ts == null then Nil else ts.toList

def typeAnnots: List[Tree] = annotatedTypeTrees.toList

def docString(tree: untpd.MemberDef): Option[Comment] =
Option(docStrings.lookup(tree))

Expand Down Expand Up @@ -278,6 +284,7 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
case tpe: AnnotatedType =>
writeByte(ANNOTATEDtype)
withLength { pickleType(tpe.parent, richTypes); pickleTree(tpe.annot.tree) }
annotatedTypeTrees += tpe.annot.tree
case tpe: AndType =>
writeByte(ANDtype)
withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) }
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ object PickledQuotes {
if tree.span.exists then
val positionWarnings = new mutable.ListBuffer[Message]()
val reference = ctx.settings.sourceroot.value
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
ctx.compilationUnit.source, tree :: Nil, positionWarnings)
positionWarnings.foreach(report.warning(_))

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/Pickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ class Pickler extends Phase {
if tree.span.exists then
val reference = ctx.settings.sourceroot.value
PositionPickler.picklePositions(
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
unit.source, tree :: Nil, positionWarnings,
scratch.positionBuffer, scratch.pickledIndices)

Expand Down
6 changes: 4 additions & 2 deletions compiler/test/dotty/tools/dotc/printing/PrintingTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import scala.language.unsafeNulls

import vulpix.FileDiff
import vulpix.TestConfiguration
import vulpix.TestConfiguration
import vulpix.ParallelTesting
import reporting.TestReporter

import java.io._
Expand All @@ -25,7 +25,9 @@ import java.io.File
class PrintingTest {

def options(phase: String, flags: List[String]) =
List(s"-Xprint:$phase", "-color:never", "-nowarn", "-classpath", TestConfiguration.basicClasspath) ::: flags
val outDir = ParallelTesting.defaultOutputDir + "printing" + File.pathSeparator
File(outDir).mkdirs()
List(s"-Xprint:$phase", "-color:never", "-nowarn", "-d", outDir, "-classpath", TestConfiguration.basicClasspath) ::: flags

private def compileFile(path: JPath, phase: String): Boolean = {
val baseFilePath = path.toString.stripSuffix(".scala")
Expand Down
10 changes: 10 additions & 0 deletions tests/pos/annot-17939b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import scala.annotation.Annotation
class myRefined(f: ? => Boolean) extends Annotation

def test(axes: Int) = true

trait Tensor:
def mean(axes: Int): Int @myRefined(_ => test(axes))

class TensorImpl() extends Tensor:
def mean(axes: Int) = ???
9 changes: 9 additions & 0 deletions tests/pos/annot-18064.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//> using options "-Xprint:typer"

class myAnnot[T]() extends annotation.Annotation

trait Tensor[T]:
def add: Tensor[T] @myAnnot[T]()

class TensorImpl[A]() extends Tensor[A]:
def add /* : Tensor[A] @myAnnot[A] */ = this
10 changes: 10 additions & 0 deletions tests/pos/annot-5789.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class Annot[T] extends scala.annotation.Annotation

class D[T](val f: Int@Annot[T])

object A{
def main(a:Array[String]) = {
val c = new D[Int](1)
c.f
}
}
Loading
Loading