Skip to content

Make sure symbols in annotation trees are fresh before pickling #22002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package dotty.tools
package dotc
package transform

import dotty.tools.dotc.ast.{Trees, tpd, untpd, desugar}
import dotty.tools.dotc.ast.{Trees, tpd, untpd, desugar, TreeTypeMap}
import scala.collection.mutable
import core.*
import dotty.tools.dotc.typer.Checking
Expand All @@ -16,7 +16,7 @@ import Symbols.*, NameOps.*
import ContextFunctionResults.annotateContextResults
import config.Printers.typr
import config.Feature
import util.SrcPos
import util.{SrcPos, Stats}
import reporting.*
import NameKinds.WildcardParamName
import cc.*
Expand Down Expand Up @@ -154,7 +154,21 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case _ =>
case _ =>

private def transformAnnot(annot: Tree)(using Context): Tree = {
/** Returns a copy of the given tree with all symbols fresh.
*
* Used to guarantee that no symbols are shared between trees in different
* annotations.
*/
private def copySymbols(tree: Tree)(using Context) =
Stats.trackTime("Annotations copySymbols"):
val ttm =
new TreeTypeMap:
override def withMappedSyms(syms: List[Symbol]) =
withMappedSyms(syms, mapSymbols(syms, this, true))
ttm(tree)

/** Transforms the given annotation tree. */
private def transformAnnotTree(annot: Tree)(using Context): Tree = {
val saved = inJavaAnnot
inJavaAnnot = annot.symbol.is(JavaDefined)
if (inJavaAnnot) checkValidJavaAnnotation(annot)
Expand All @@ -163,7 +177,19 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
}

private def transformAnnot(annot: Annotation)(using Context): Annotation =
annot.derivedAnnotation(transformAnnot(annot.tree))
val tree1 =
annot match
case _: BodyAnnotation => annot.tree
case _ => copySymbols(annot.tree)
annot.derivedAnnotation(transformAnnotTree(tree1))

/** Transforms all annotations in the given type. */
private def transformAnnotsIn(using Context) =
new TypeMap:
def apply(tp: Type) = tp match
case tp @ AnnotatedType(parent, annot) =>
tp.derivedAnnotatedType(mapOver(parent), transformAnnot(annot))
case _ => mapOver(tp)

private def processMemberDef(tree: Tree)(using Context): tree.type = {
val sym = tree.symbol
Expand Down Expand Up @@ -501,7 +527,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
Checking.checkRealizable(tree.tpt.tpe, tree.srcPos, "SAM type")
super.transform(tree)
case tree @ Annotated(annotated, annot) =>
cpy.Annotated(tree)(transform(annotated), transformAnnot(annot))
cpy.Annotated(tree)(transform(annotated), transformAnnotTree(annot))
case tree: AppliedTypeTree =>
if (tree.tpt.symbol == defn.andType)
Checking.checkNonCyclicInherited(tree.tpe, tree.args.tpes, EmptyScope, tree.srcPos)
Expand All @@ -524,11 +550,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
super.transform(tree)
case tree: TypeTree =>
val tpe = if tree.isInferred then CleanupRetains()(tree.tpe) else tree.tpe
tree.withType:
tpe match
case AnnotatedType(parent, annot) =>
AnnotatedType(parent, transformAnnot(annot)) // TODO: Also map annotations embedded in type?
case _ => tpe
tree.withType(transformAnnotsIn(tpe))
case Typed(Ident(nme.WILDCARD), _) =>
withMode(Mode.Pattern)(super.transform(tree))
// The added mode signals that bounds in a pattern need not
Expand Down
8 changes: 8 additions & 0 deletions tests/pos/annot-17939.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import scala.annotation.Annotation
class myRefined[T](f: T => Boolean) extends Annotation

class Box[T](val x: T)
class Box2(val x: Int)

class A(a: String @myRefined((x: Int) => Box(3).x == 3)) // crash
class A2(a2: String @myRefined((x: Int) => Box2(3).x == 3)) // works
9 changes: 9 additions & 0 deletions tests/pos/annot-19846.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package dependentAnnotation

class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation

def f(x: Int): Int @lambdaAnnot(() => x + 1) = x

@main def main =
val y: Int = 5
val z = f(y)
8 changes: 8 additions & 0 deletions tests/pos/annot-19846b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation

class EqualPair(val x: Int, val y: Int @qualified[Int](it => it == x))

@main def main =
val p = EqualPair(42, 42)
val y = p.y
println(42)
15 changes: 15 additions & 0 deletions tests/pos/annot-body.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// This test checks that symbols in `BodyAnnotation` are not copied in
// `transformAnnot` during `PostTyper`.

package json

trait Reads[A] {
def reads(a: Any): A
}

object JsMacroImpl {
inline def reads[A]: Reads[A] =
new Reads[A] { self =>
def reads(a: Any) = ???
}
}
20 changes: 20 additions & 0 deletions tests/pos/annot-i20272a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import language.experimental.captureChecking

trait Iterable[T] { self: Iterable[T]^ =>
def map[U](f: T => U): Iterable[U]^{this, f}
}

object Test {
def assertEquals[A, B](a: A, b: B): Boolean = ???

def foo[T](level: Int, lines: Iterable[T]) =
lines.map(x => x)

def bar(messages: Iterable[String]) =
foo(1, messages)

val it: Iterable[String] = ???
val msgs = bar(it)

assertEquals(msgs, msgs)
}
Loading