Skip to content

Commit da97deb

Browse files
committed
Restrict allowed trees in annotations
1 parent bed0e86 commit da97deb

22 files changed

+199
-66
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
144144
def allTermArguments(tree: Tree): List[Tree] = unsplice(tree) match {
145145
case Apply(fn, args) => allTermArguments(fn) ::: args
146146
case TypeApply(fn, args) => allTermArguments(fn)
147+
// TOOD(mbovel): is it really safe to skip all blocks here and in `allArguments`?
147148
case Block(_, expr) => allTermArguments(expr)
148149
case _ => Nil
149150
}

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,9 @@ class Definitions {
499499

500500
@tu lazy val DummyImplicitClass: ClassSymbol = requiredClass("scala.DummyImplicit")
501501

502+
@tu lazy val SymbolModule: Symbol = requiredModule("scala.Symbol")
503+
@tu lazy val JSSymbolModule: Symbol = requiredModule("scala.scalajs.js.Symbol")
504+
502505
@tu lazy val ScalaRuntimeModule: Symbol = requiredModule("scala.runtime.ScalaRunTime")
503506
def runtimeMethodRef(name: PreName): TermRef = ScalaRuntimeModule.requiredMethodRef(name)
504507
def ScalaRuntime_drop: Symbol = runtimeMethodRef(nme.drop).symbol

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,17 @@ object TreeChecker {
827827
|${mismatch.message}${mismatch.explanation}
828828
|tree = $tree ${tree.className}""".stripMargin
829829
})
830+
checkWellFormedType(tp1)
831+
checkWellFormedType(tp2)
832+
833+
/** Check that the type `tp` is well-formed. Currently this only means
834+
* checking that annotated types have valid annotation arguments.
835+
*/
836+
private def checkWellFormedType(tp: Type)(using Context): Unit =
837+
tp.foreachPart:
838+
case AnnotatedType(underlying, annot) => checkAnnot(annot.tree)
839+
case _ => ()
840+
830841
}
831842

832843
/** Tree checker that can be applied to a local tree. */

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,6 @@ object Checking {
916916
annot
917917
case _ => annot
918918
end checkNamedArgumentForJavaAnnotation
919-
920919
}
921920

922921
trait Checking {
@@ -1387,12 +1386,21 @@ trait Checking {
13871386
if !Inlines.inInlineMethod && !ctx.isInlineContext then
13881387
report.error(em"$what can only be used in an inline method", pos)
13891388

1389+
def checkAnnot(tree: Tree)(using Context): Tree =
1390+
tree match
1391+
case Ident(tpnme.BOUNDTYPE_ANNOT) =>
1392+
// `FirstTransform.toTypeTree` creates `Annotated` nodes whose `annot` are
1393+
// `Ident`s, not annotation instances. See `tests/pos/annot-boundtype.scala`.
1394+
tree
1395+
case _ =>
1396+
checkAnnotArgs(checkAnnotClass(tree))
1397+
13901398
/** Check that the class corresponding to this tree is either a Scala or Java annotation.
13911399
*
13921400
* @return The original tree or an error tree in case `tree` isn't a valid
13931401
* annotation or already an error tree.
13941402
*/
1395-
def checkAnnotClass(tree: Tree)(using Context): Tree =
1403+
private def checkAnnotClass(tree: Tree)(using Context): Tree =
13961404
if tree.tpe.isError then
13971405
return tree
13981406
val cls = Annotations.annotClass(tree)
@@ -1404,8 +1412,8 @@ trait Checking {
14041412
errorTree(tree, em"$cls is not a valid Scala annotation: it does not extend `scala.annotation.Annotation`")
14051413
else tree
14061414

1407-
/** Check arguments of compiler-defined annotations */
1408-
def checkAnnotArgs(tree: Tree)(using Context): tree.type =
1415+
/** Check arguments of annotations */
1416+
private def checkAnnotArgs(tree: Tree)(using Context): Tree =
14091417
val cls = Annotations.annotClass(tree)
14101418
tree match
14111419
case Apply(tycon, arg :: Nil) if cls == defn.TargetNameAnnot =>
@@ -1416,8 +1424,40 @@ trait Checking {
14161424
case _ =>
14171425
report.error(em"@${cls.name} needs a string literal as argument", arg.srcPos)
14181426
case _ =>
1427+
if cls.isRetainsLike then () // Do not check @retain annotations
1428+
else if cls == defn.ThrowsAnnot then
1429+
// Do not check @throws annotations.
1430+
// TODO(mbovel): in tests/run/t6380.scala, an annotation tree is
1431+
// `new throws[Exception](throws.<init>[Exception])`. What is this?
1432+
()
1433+
else
1434+
tpd.allTermArguments(tree).foreach(checkAnnotArg)
14191435
tree
14201436

1437+
private def checkAnnotArg(tree: Tree)(using Context): Unit =
1438+
def valid(t: Tree): Boolean =
1439+
t match
1440+
case _ if t.tpe.isEffectivelySingleton => true
1441+
case Literal(_) => true
1442+
// `_` is used as placeholder for unspecified arguments of Java
1443+
// annotations. Example: tests/run/java-ann-super-class
1444+
case Ident(nme.WILDCARD) => true
1445+
case Apply(fun, args) => valid(fun) && args.forall(valid)
1446+
case TypeApply(fun, args) => valid(fun)
1447+
case SeqLiteral(elems, _) => elems.forall(valid)
1448+
case Typed(expr, _) => valid(expr)
1449+
case NamedArg(_, arg) => valid(arg)
1450+
case Splice(_) => true
1451+
case Hole(_, _, _, _) => true
1452+
case _ => false
1453+
if !valid(tree) then
1454+
report.error(
1455+
i"""Implementation restriction: not a valid annotation argument.
1456+
|Argument: $tree
1457+
|Type: ${tree.tpe}""",
1458+
tree.srcPos
1459+
)
1460+
14211461
/** 1. Check that all case classes that extend `scala.reflect.Enum` are `enum` cases
14221462
* 2. Check that parameterised `enum` cases do not extend java.lang.Enum.
14231463
* 3. Check that only a static `enum` base class can extend java.lang.Enum.
@@ -1665,7 +1705,7 @@ trait NoChecking extends ReChecking {
16651705
override def checkImplicitConversionDefOK(sym: Symbol)(using Context): Unit = ()
16661706
override def checkImplicitConversionUseOK(tree: Tree, expected: Type)(using Context): Unit = ()
16671707
override def checkFeasibleParent(tp: Type, pos: SrcPos, where: => String = "")(using Context): Type = tp
1668-
override def checkAnnotArgs(tree: Tree)(using Context): tree.type = tree
1708+
override def checkAnnot(tree: Tree)(using Context): tree.type = tree
16691709
override def checkNoTargetNameConflict(stats: List[Tree])(using Context): Unit = ()
16701710
override def checkParentCall(call: Tree, caller: ClassSymbol)(using Context): Unit = ()
16711711
override def checkSimpleKinded(tpt: Tree)(using Context): Tree = tpt

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2780,7 +2780,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
27802780
}
27812781

27822782
def typedAnnotation(annot: untpd.Tree)(using Context): Tree =
2783-
checkAnnotClass(checkAnnotArgs(typed(annot)))
2783+
checkAnnot(typed(annot))
27842784

27852785
def registerNowarn(tree: Tree, mdef: untpd.Tree)(using Context): Unit =
27862786
val annot = Annotations.Annotation(tree)
@@ -3310,7 +3310,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
33103310
end typedPackageDef
33113311

33123312
def typedAnnotated(tree: untpd.Annotated, pt: Type)(using Context): Tree = {
3313-
val annot1 = checkAnnotClass(typedExpr(tree.annot))
3313+
val annot1 = checkAnnot(typedExpr(tree.annot))
33143314
val annotCls = Annotations.annotClass(annot1)
33153315
if annotCls == defn.NowarnAnnot then
33163316
registerNowarn(annot1, tree)

tests/bench/inductive-implicits.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ package shapeless {
6161
import shapeless.*
6262

6363
object Test extends App {
64+
import Selector.given
6465
val sel = Selector[L, Boolean]
6566

6667
type L =

tests/neg/annot-invalid.check

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
-- Error: tests/neg/annot-invalid.scala:4:21 ---------------------------------------------------------------------------
2+
4 | val x1: Int @annot(new Object {}) = 0 // error
3+
| ^^^^^^^^^^^^^
4+
| Implementation restriction: not a valid annotation argument.
5+
| Argument: {
6+
| final class $anon() extends Object() {}
7+
| new $anon():Object
8+
| }
9+
| Type: Object
10+
-- Error: tests/neg/annot-invalid.scala:5:21 ---------------------------------------------------------------------------
11+
5 | val x2: Int @annot({val x = 1}) = 0 // error
12+
| ^^^^^^^^^^^
13+
| Implementation restriction: not a valid annotation argument.
14+
| Argument: {
15+
| val x: Int = 1
16+
| ()
17+
| }
18+
| Type: Unit
19+
-- Error: tests/neg/annot-invalid.scala:6:21 ---------------------------------------------------------------------------
20+
6 | val x3: Int @annot((x: Int) => x) = 0 // error
21+
| ^^^^^^^^^^^^^
22+
| Implementation restriction: not a valid annotation argument.
23+
| Argument: (x: Int) => x
24+
| Type: Int => Int
25+
-- Error: tests/neg/annot-invalid.scala:8:9 ----------------------------------------------------------------------------
26+
8 | @annot(new Object {}) val y1: Int = 0 // error
27+
| ^^^^^^^^^^^^^
28+
| Implementation restriction: not a valid annotation argument.
29+
| Argument: {
30+
| final class $anon() extends Object() {}
31+
| new $anon():Object
32+
| }
33+
| Type: Object
34+
-- Error: tests/neg/annot-invalid.scala:9:9 ----------------------------------------------------------------------------
35+
9 | @annot({val x = 1}) val y2: Int = 0 // error
36+
| ^^^^^^^^^^^
37+
| Implementation restriction: not a valid annotation argument.
38+
| Argument: {
39+
| val x: Int = 1
40+
| ()
41+
| }
42+
| Type: Unit
43+
-- Error: tests/neg/annot-invalid.scala:10:9 ---------------------------------------------------------------------------
44+
10 | @annot((x: Int) => x) val y3: Int = 0 // error
45+
| ^^^^^^^^^^^^^
46+
| Implementation restriction: not a valid annotation argument.
47+
| Argument: (x: Int) => x
48+
| Type: Int => Int

tests/neg/annot-invalid.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
class annot[T](arg: T) extends scala.annotation.Annotation
2+
3+
def main =
4+
val x1: Int @annot(new Object {}) = 0 // error
5+
val x2: Int @annot({val x = 1}) = 0 // error
6+
val x3: Int @annot((x: Int) => x) = 0 // error
7+
8+
@annot(new Object {}) val y1: Int = 0 // error
9+
@annot({val x = 1}) val y2: Int = 0 // error
10+
@annot((x: Int) => x) val y3: Int = 0 // error
11+
12+
()

tests/neg/i15054.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import scala.annotation.Annotation
2+
3+
class AnAnnotation(function: Int => String) extends Annotation
4+
5+
@AnAnnotation(_.toString) // error: not a valid annotation
6+
val a = 1
7+
@AnAnnotation(_.toString.length.toString) // error: not a valid annotation
8+
val b = 2
9+
10+
def test =
11+
@AnAnnotation(_.toString) // error: not a valid annotation
12+
val a = 1
13+
@AnAnnotation(_.toString.length.toString) // error: not a valid annotation
14+
val b = 2
15+
a + b

tests/neg/i7740a.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class A(a: Any) extends annotation.StaticAnnotation
2+
@A({val x = 0}) trait B // error: not a valid annotation

0 commit comments

Comments
 (0)