Skip to content

Commit a8ea206

Browse files
Merge pull request #7903 from dotty-staging/add-required-symbols-to-tasty-reflect
Add requiredXYZ symbols to TASTy reflect context
2 parents a49363b + 811804d commit a8ea206

File tree

8 files changed

+91
-42
lines changed

8 files changed

+91
-42
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
5959
def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type =
6060
self.gadt.approximation(sym, fromBelow)
6161

62+
def Context_requiredPackage(self: Context)(path: String): Symbol = self.requiredPackage(path)
63+
def Context_requiredClass(self: Context)(path: String): Symbol = self.requiredClass(path)
64+
def Context_requiredModule(self: Context)(path: String): Symbol = self.requiredModule(path)
65+
def Context_requiredMethod(self: Context)(path: String): Symbol = self.requiredMethod(path)
66+
6267
//
6368
// REPORTING
6469
//

library/src/scala/internal/quoted/Matcher.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ private[quoted] object Matcher {
366366
*/
367367
private def patternsMatches(scrutinee: Tree, pattern: Tree)(given Context, Env): (Env, Matching) = (scrutinee, pattern) match {
368368
case (v1: Term, Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
369-
if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" =>
369+
if patternHole.symbol.owner == summon[Context].requiredModule("scala.runtime.quoted.Matcher") =>
370370
(summon[Env], matched(v1.seal))
371371

372372
case (Ident("_"), Ident("_")) =>

library/src/scala/tasty/reflect/CompilerInterface.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,18 @@ trait CompilerInterface {
148148
def Context_GADT_addToConstraint(self: Context)(syms: List[Symbol]): Boolean
149149
def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type
150150

151+
/** Get package symbol if package is either defined in current compilation run or present on classpath. */
152+
def Context_requiredPackage(self: Context)(path: String): Symbol
153+
154+
/** Get class symbol if class is either defined in current compilation run or present on classpath. */
155+
def Context_requiredClass(self: Context)(path: String): Symbol
156+
157+
/** Get module symbol if module is either defined in current compilation run or present on classpath. */
158+
def Context_requiredModule(self: Context)(path: String): Symbol
159+
160+
/** Get method symbol if method is either defined in current compilation run or present on classpath. Throws if the method has an overload. */
161+
def Context_requiredMethod(self: Context)(path: String): Symbol
162+
151163
//
152164
// REPORTING
153165
//

library/src/scala/tasty/reflect/ContextOps.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,18 @@ trait ContextOps extends Core {
1010
/** Returns the source file being compiled. The path is relative to the current working directory. */
1111
def source: java.nio.file.Path = internal.Context_source(self)
1212

13+
/** Get package symbol if package is either defined in current compilation run or present on classpath. */
14+
def requiredPackage(path: String): Symbol = internal.Context_requiredPackage(self)(path)
15+
16+
/** Get class symbol if class is either defined in current compilation run or present on classpath. */
17+
def requiredClass(path: String): Symbol = internal.Context_requiredClass(self)(path)
18+
19+
/** Get module symbol if module is either defined in current compilation run or present on classpath. */
20+
def requiredModule(path: String): Symbol = internal.Context_requiredModule(self)(path)
21+
22+
/** Get method symbol if method is either defined in current compilation run or present on classpath. Throws if the method has an overload. */
23+
def requiredMethod(path: String): Symbol = internal.Context_requiredMethod(self)(path)
24+
1325
}
1426

1527
/** Context of the macro expansion */

library/src/scala/tasty/reflect/SourceCodePrinter.scala

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
150150
}
151151

152152
val parents1 = parents.filter {
153-
case Apply(Select(New(tpt), _), _) => !Types.JavaLangObject.unapply(tpt.tpe)
153+
case Apply(Select(New(tpt), _), _) => tpt.tpe.typeSymbol != ctx.requiredClass("java.lang.Object")
154154
case TypeSelect(Select(Ident("_root_"), "scala"), "Product") => false
155155
case TypeSelect(Select(Ident("_root_"), "scala"), "Serializable") => false
156156
case _ => true
@@ -356,7 +356,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
356356
this += "throw "
357357
printTree(expr)
358358

359-
case Apply(fn, args) if fn.symbol.fullName == "scala.internal.Quoted$.exprQuote" =>
359+
case Apply(fn, args) if fn.symbol == ctx.requiredMethod("scala.internal.Quoted.exprQuote") =>
360360
args.head match {
361361
case Block(stats, expr) =>
362362
this += "'{"
@@ -371,12 +371,12 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
371371
this += "}"
372372
}
373373

374-
case TypeApply(fn, args) if fn.symbol.fullName == "scala.internal.Quoted$.typeQuote" =>
374+
case TypeApply(fn, args) if fn.symbol == ctx.requiredMethod("scala.internal.Quoted.typeQuote") =>
375375
this += "'["
376376
printTypeTree(args.head)
377377
this += "]"
378378

379-
case Apply(fn, arg :: Nil) if fn.symbol.fullName == "scala.internal.Quoted$.exprSplice" =>
379+
case Apply(fn, arg :: Nil) if fn.symbol == ctx.requiredMethod("scala.internal.Quoted.exprSplice") =>
380380
this += "${"
381381
printTree(arg)
382382
this += "}"
@@ -573,7 +573,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
573573
def printFlatBlock(stats: List[Statement], expr: Term)(given elideThis: Option[Symbol]): Buffer = {
574574
val (stats1, expr1) = flatBlock(stats, expr)
575575
val stats2 = stats1.filter {
576-
case tree: TypeDef => !tree.symbol.annots.exists(_.symbol.owner.fullName == "scala.internal.Quoted$.quoteTypeTag")
576+
case tree: TypeDef => !tree.symbol.annots.exists(_.symbol.owner == ctx.requiredClass("scala.internal.Quoted.quoteTypeTag"))
577577
case _ => true
578578
}
579579
if (stats2.isEmpty) {
@@ -971,7 +971,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
971971
printTypeAndAnnots(tp)
972972
this += " "
973973
printAnnotation(annot)
974-
case tpe: TypeRef if tpe.typeSymbol.fullName == "scala.runtime.Null$" || tpe.typeSymbol.fullName == "scala.runtime.Nothing$" =>
974+
case tpe: TypeRef if tpe.typeSymbol == ctx.requiredClass("scala.runtime.Null$") || tpe.typeSymbol == ctx.requiredClass("scala.runtime.Nothing$") =>
975975
// scala.runtime.Null$ and scala.runtime.Nothing$ are not modules, those are their actual names
976976
printType(tpe)
977977
case tpe: TermRef if tpe.termSymbol.isClassDef && tpe.termSymbol.name.endsWith("$") =>
@@ -1014,7 +1014,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
10141014
case Annotated(tpt, annot) =>
10151015
val Annotation(ref, args) = annot
10161016
ref.tpe match {
1017-
case Types.RepeatedAnnotation() =>
1017+
case tpe: TypeRef if tpe.typeSymbol == ctx.requiredClass("scala.annotation.internal.Repeated") =>
10181018
val Types.Sequence(tp) = tpt.tpe
10191019
printType(tp)
10201020
this += highlightTypeDef("*")
@@ -1115,7 +1115,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
11151115
tp match {
11161116
case tp: TypeLambda =>
11171117
printType(tpe.dealias)
1118-
case TypeRef(Types.ScalaPackage(), "<repeated>") =>
1118+
case tp: TypeRef if tp.typeSymbol == ctx.requiredClass("scala.<repeated>") =>
11191119
this += "_*"
11201120
case _ =>
11211121
printType(tp)
@@ -1228,7 +1228,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
12281228

12291229
def printAnnotation(annot: Term)(given elideThis: Option[Symbol]): Buffer = {
12301230
val Annotation(ref, args) = annot
1231-
if (annot.symbol.maybeOwner.fullName == "scala.internal.quoted.showName") this
1231+
if (annot.symbol.maybeOwner == ctx.requiredClass("scala.internal.quoted.showName")) this
12321232
else {
12331233
this += "@"
12341234
printTypeTree(ref)
@@ -1242,12 +1242,9 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
12421242
def printDefAnnotations(definition: Definition)(given elideThis: Option[Symbol]): Buffer = {
12431243
val annots = definition.symbol.annots.filter {
12441244
case Annotation(annot, _) =>
1245-
annot.tpe match {
1246-
case TypeRef(prefix: TermRef, _) if prefix.termSymbol.fullName == "scala.annotation.internal" => false
1247-
case TypeRef(prefix: TypeRef, _) if prefix.typeSymbol.fullName == "scala.annotation.internal" => false
1248-
case TypeRef(Types.ScalaPackage(), "forceInline") => false
1249-
case _ => true
1250-
}
1245+
val sym = annot.tpe.typeSymbol
1246+
sym != ctx.requiredClass("scala.forceInline") &&
1247+
sym.maybeOwner != ctx.requiredPackage("scala.annotation.internal")
12511248
case x => throw new MatchError(x.showExtractors)
12521249
}
12531250
printAnnotations(annots)
@@ -1404,7 +1401,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
14041401
}
14051402

14061403
private def splicedName(sym: Symbol)(given ctx: Context): Option[String] = {
1407-
sym.annots.find(_.symbol.owner.fullName == "scala.internal.quoted.showName").flatMap {
1404+
sym.annots.find(_.symbol.owner == ctx.requiredClass("scala.internal.quoted.showName")).flatMap {
14081405
case Apply(_, Literal(Constant(c: String)) :: Nil) => Some(c)
14091406
case Apply(_, Inlined(_, _, Literal(Constant(c: String))) :: Nil) => Some(c)
14101407
case annot => None
@@ -1435,43 +1432,20 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
14351432
// TODO Provide some of these in scala.tasty.Reflection.scala and implement them using checks on symbols for performance
14361433
private object Types {
14371434

1438-
object JavaLangObject {
1439-
def unapply(tpe: Type)(given ctx: Context): Boolean = tpe match {
1440-
case TypeRef(prefix: TermRef, "Object") => prefix.typeSymbol.fullName == "java.lang"
1441-
case _ => false
1442-
}
1443-
}
1444-
14451435
object Sequence {
14461436
def unapply(tpe: Type)(given ctx: Context): Option[Type] = tpe match {
1447-
case AppliedType(TypeRef(prefix: TermRef, "Seq"), (tp: Type) :: Nil) if prefix.termSymbol.fullName == "scala.collection" => Some(tp)
1448-
case AppliedType(TypeRef(prefix: TypeRef, "Seq"), (tp: Type) :: Nil) if prefix.typeSymbol.fullName == "scala.collection" => Some(tp)
1437+
case AppliedType(seq, (tp: Type) :: Nil) if seq.typeSymbol == ctx.requiredClass("scala.collection.Seq") => Some(tp)
14491438
case _ => None
14501439
}
14511440
}
14521441

1453-
object RepeatedAnnotation {
1454-
def unapply(tpe: Type)(given ctx: Context): Boolean = tpe match {
1455-
case TypeRef(prefix: TermRef, "Repeated") => prefix.termSymbol.fullName == "scala.annotation.internal"
1456-
case TypeRef(prefix: TypeRef, "Repeated") => prefix.typeSymbol.fullName == "scala.annotation.internal"
1457-
case _ => false
1458-
}
1459-
}
1460-
14611442
object Repeated {
14621443
def unapply(tpe: Type)(given ctx: Context): Option[Type] = tpe match {
1463-
case AppliedType(TypeRef(ScalaPackage(), "<repeated>"), (tp: Type) :: Nil) => Some(tp)
1444+
case AppliedType(rep, (tp: Type) :: Nil) if rep.typeSymbol == ctx.requiredClass("scala.<repeated>") => Some(tp)
14641445
case _ => None
14651446
}
14661447
}
14671448

1468-
object ScalaPackage {
1469-
def unapply(tpe: TypeOrBounds)(given ctx: Context): Boolean = tpe match {
1470-
case tpe: Type => tpe.termSymbol == defn.ScalaPackage
1471-
case _ => false
1472-
}
1473-
}
1474-
14751449
}
14761450

14771451
object PackageObject {
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
java
2+
java.lang
3+
scala
4+
scala.collection
5+
java.lang.Object
6+
scala.Any
7+
scala.Any
8+
scala.AnyVal
9+
scala.Unit
10+
scala.Null
11+
scala.None
12+
scala.package$.Nil
13+
scala.collection.immutable.List$.empty
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import scala.quoted._
2+
3+
object Macro {
4+
inline def foo: String = ${ fooImpl }
5+
def fooImpl(given qctx: QuoteContext): Expr[String] = {
6+
import qctx.tasty.{given, _}
7+
val list = List(
8+
rootContext.requiredPackage("java"),
9+
rootContext.requiredPackage("java.lang"),
10+
rootContext.requiredPackage("scala"),
11+
rootContext.requiredPackage("scala.collection"),
12+
13+
rootContext.requiredClass("java.lang.Object"),
14+
rootContext.requiredClass("scala.Any"),
15+
rootContext.requiredClass("scala.AnyRef"),
16+
rootContext.requiredClass("scala.AnyVal"),
17+
rootContext.requiredClass("scala.Unit"),
18+
rootContext.requiredClass("scala.Null"),
19+
20+
rootContext.requiredModule("scala.None"),
21+
rootContext.requiredModule("scala.Nil"),
22+
23+
rootContext.requiredMethod("scala.List.empty"),
24+
)
25+
Expr(list.map(_.fullName).mkString("\n"))
26+
}
27+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
println(Macro.foo)
5+
}
6+
}

0 commit comments

Comments
 (0)