diff --git a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala index a2e69dc8d1eb..27a79b055071 100644 --- a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala +++ b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala @@ -59,6 +59,11 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type = self.gadt.approximation(sym, fromBelow) + def Context_requiredPackage(self: Context)(path: String): Symbol = self.requiredPackage(path) + def Context_requiredClass(self: Context)(path: String): Symbol = self.requiredClass(path) + def Context_requiredModule(self: Context)(path: String): Symbol = self.requiredModule(path) + def Context_requiredMethod(self: Context)(path: String): Symbol = self.requiredMethod(path) + // // REPORTING // diff --git a/library/src/scala/internal/quoted/Matcher.scala b/library/src/scala/internal/quoted/Matcher.scala index b4ddfc856cea..5b41e3fd8ca6 100644 --- a/library/src/scala/internal/quoted/Matcher.scala +++ b/library/src/scala/internal/quoted/Matcher.scala @@ -366,7 +366,7 @@ private[quoted] object Matcher { */ private def patternsMatches(scrutinee: Tree, pattern: Tree)(given Context, Env): (Env, Matching) = (scrutinee, pattern) match { case (v1: Term, Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil)) - if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" => + if patternHole.symbol.owner == summon[Context].requiredModule("scala.runtime.quoted.Matcher") => (summon[Env], matched(v1.seal)) case (Ident("_"), Ident("_")) => diff --git a/library/src/scala/tasty/reflect/CompilerInterface.scala b/library/src/scala/tasty/reflect/CompilerInterface.scala index 8c07834c7618..452792c77f99 100644 --- a/library/src/scala/tasty/reflect/CompilerInterface.scala +++ b/library/src/scala/tasty/reflect/CompilerInterface.scala @@ -148,6 +148,18 @@ trait CompilerInterface { def Context_GADT_addToConstraint(self: Context)(syms: List[Symbol]): Boolean def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type + /** Get package symbol if package is either defined in current compilation run or present on classpath. */ + def Context_requiredPackage(self: Context)(path: String): Symbol + + /** Get class symbol if class is either defined in current compilation run or present on classpath. */ + def Context_requiredClass(self: Context)(path: String): Symbol + + /** Get module symbol if module is either defined in current compilation run or present on classpath. */ + def Context_requiredModule(self: Context)(path: String): Symbol + + /** Get method symbol if method is either defined in current compilation run or present on classpath. Throws if the method has an overload. */ + def Context_requiredMethod(self: Context)(path: String): Symbol + // // REPORTING // diff --git a/library/src/scala/tasty/reflect/ContextOps.scala b/library/src/scala/tasty/reflect/ContextOps.scala index f09fcf353dc7..1d182bffe23f 100644 --- a/library/src/scala/tasty/reflect/ContextOps.scala +++ b/library/src/scala/tasty/reflect/ContextOps.scala @@ -10,6 +10,18 @@ trait ContextOps extends Core { /** Returns the source file being compiled. The path is relative to the current working directory. */ def source: java.nio.file.Path = internal.Context_source(self) + /** Get package symbol if package is either defined in current compilation run or present on classpath. */ + def requiredPackage(path: String): Symbol = internal.Context_requiredPackage(self)(path) + + /** Get class symbol if class is either defined in current compilation run or present on classpath. */ + def requiredClass(path: String): Symbol = internal.Context_requiredClass(self)(path) + + /** Get module symbol if module is either defined in current compilation run or present on classpath. */ + def requiredModule(path: String): Symbol = internal.Context_requiredModule(self)(path) + + /** Get method symbol if method is either defined in current compilation run or present on classpath. Throws if the method has an overload. */ + def requiredMethod(path: String): Symbol = internal.Context_requiredMethod(self)(path) + } /** Context of the macro expansion */ diff --git a/library/src/scala/tasty/reflect/SourceCodePrinter.scala b/library/src/scala/tasty/reflect/SourceCodePrinter.scala index b7ef83616a84..fc7a76f30bd6 100644 --- a/library/src/scala/tasty/reflect/SourceCodePrinter.scala +++ b/library/src/scala/tasty/reflect/SourceCodePrinter.scala @@ -150,7 +150,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig } val parents1 = parents.filter { - case Apply(Select(New(tpt), _), _) => !Types.JavaLangObject.unapply(tpt.tpe) + case Apply(Select(New(tpt), _), _) => tpt.tpe.typeSymbol != ctx.requiredClass("java.lang.Object") case TypeSelect(Select(Ident("_root_"), "scala"), "Product") => false case TypeSelect(Select(Ident("_root_"), "scala"), "Serializable") => false case _ => true @@ -356,7 +356,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig this += "throw " printTree(expr) - case Apply(fn, args) if fn.symbol.fullName == "scala.internal.Quoted$.exprQuote" => + case Apply(fn, args) if fn.symbol == ctx.requiredMethod("scala.internal.Quoted.exprQuote") => args.head match { case Block(stats, expr) => this += "'{" @@ -371,12 +371,12 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig this += "}" } - case TypeApply(fn, args) if fn.symbol.fullName == "scala.internal.Quoted$.typeQuote" => + case TypeApply(fn, args) if fn.symbol == ctx.requiredMethod("scala.internal.Quoted.typeQuote") => this += "'[" printTypeTree(args.head) this += "]" - case Apply(fn, arg :: Nil) if fn.symbol.fullName == "scala.internal.Quoted$.exprSplice" => + case Apply(fn, arg :: Nil) if fn.symbol == ctx.requiredMethod("scala.internal.Quoted.exprSplice") => this += "${" printTree(arg) this += "}" @@ -573,7 +573,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig def printFlatBlock(stats: List[Statement], expr: Term)(given elideThis: Option[Symbol]): Buffer = { val (stats1, expr1) = flatBlock(stats, expr) val stats2 = stats1.filter { - case tree: TypeDef => !tree.symbol.annots.exists(_.symbol.owner.fullName == "scala.internal.Quoted$.quoteTypeTag") + case tree: TypeDef => !tree.symbol.annots.exists(_.symbol.owner == ctx.requiredClass("scala.internal.Quoted.quoteTypeTag")) case _ => true } if (stats2.isEmpty) { @@ -971,7 +971,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig printTypeAndAnnots(tp) this += " " printAnnotation(annot) - case tpe: TypeRef if tpe.typeSymbol.fullName == "scala.runtime.Null$" || tpe.typeSymbol.fullName == "scala.runtime.Nothing$" => + case tpe: TypeRef if tpe.typeSymbol == ctx.requiredClass("scala.runtime.Null$") || tpe.typeSymbol == ctx.requiredClass("scala.runtime.Nothing$") => // scala.runtime.Null$ and scala.runtime.Nothing$ are not modules, those are their actual names printType(tpe) case tpe: TermRef if tpe.termSymbol.isClassDef && tpe.termSymbol.name.endsWith("$") => @@ -1014,7 +1014,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig case Annotated(tpt, annot) => val Annotation(ref, args) = annot ref.tpe match { - case Types.RepeatedAnnotation() => + case tpe: TypeRef if tpe.typeSymbol == ctx.requiredClass("scala.annotation.internal.Repeated") => val Types.Sequence(tp) = tpt.tpe printType(tp) this += highlightTypeDef("*") @@ -1115,7 +1115,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig tp match { case tp: TypeLambda => printType(tpe.dealias) - case TypeRef(Types.ScalaPackage(), "") => + case tp: TypeRef if tp.typeSymbol == ctx.requiredClass("scala.") => this += "_*" case _ => printType(tp) @@ -1228,7 +1228,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig def printAnnotation(annot: Term)(given elideThis: Option[Symbol]): Buffer = { val Annotation(ref, args) = annot - if (annot.symbol.maybeOwner.fullName == "scala.internal.quoted.showName") this + if (annot.symbol.maybeOwner == ctx.requiredClass("scala.internal.quoted.showName")) this else { this += "@" printTypeTree(ref) @@ -1242,12 +1242,9 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig def printDefAnnotations(definition: Definition)(given elideThis: Option[Symbol]): Buffer = { val annots = definition.symbol.annots.filter { case Annotation(annot, _) => - annot.tpe match { - case TypeRef(prefix: TermRef, _) if prefix.termSymbol.fullName == "scala.annotation.internal" => false - case TypeRef(prefix: TypeRef, _) if prefix.typeSymbol.fullName == "scala.annotation.internal" => false - case TypeRef(Types.ScalaPackage(), "forceInline") => false - case _ => true - } + val sym = annot.tpe.typeSymbol + sym != ctx.requiredClass("scala.forceInline") && + sym.maybeOwner != ctx.requiredPackage("scala.annotation.internal") case x => throw new MatchError(x.showExtractors) } printAnnotations(annots) @@ -1404,7 +1401,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig } private def splicedName(sym: Symbol)(given ctx: Context): Option[String] = { - sym.annots.find(_.symbol.owner.fullName == "scala.internal.quoted.showName").flatMap { + sym.annots.find(_.symbol.owner == ctx.requiredClass("scala.internal.quoted.showName")).flatMap { case Apply(_, Literal(Constant(c: String)) :: Nil) => Some(c) case Apply(_, Inlined(_, _, Literal(Constant(c: String))) :: Nil) => Some(c) case annot => None @@ -1435,43 +1432,20 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig // TODO Provide some of these in scala.tasty.Reflection.scala and implement them using checks on symbols for performance private object Types { - object JavaLangObject { - def unapply(tpe: Type)(given ctx: Context): Boolean = tpe match { - case TypeRef(prefix: TermRef, "Object") => prefix.typeSymbol.fullName == "java.lang" - case _ => false - } - } - object Sequence { def unapply(tpe: Type)(given ctx: Context): Option[Type] = tpe match { - case AppliedType(TypeRef(prefix: TermRef, "Seq"), (tp: Type) :: Nil) if prefix.termSymbol.fullName == "scala.collection" => Some(tp) - case AppliedType(TypeRef(prefix: TypeRef, "Seq"), (tp: Type) :: Nil) if prefix.typeSymbol.fullName == "scala.collection" => Some(tp) + case AppliedType(seq, (tp: Type) :: Nil) if seq.typeSymbol == ctx.requiredClass("scala.collection.Seq") => Some(tp) case _ => None } } - object RepeatedAnnotation { - def unapply(tpe: Type)(given ctx: Context): Boolean = tpe match { - case TypeRef(prefix: TermRef, "Repeated") => prefix.termSymbol.fullName == "scala.annotation.internal" - case TypeRef(prefix: TypeRef, "Repeated") => prefix.typeSymbol.fullName == "scala.annotation.internal" - case _ => false - } - } - object Repeated { def unapply(tpe: Type)(given ctx: Context): Option[Type] = tpe match { - case AppliedType(TypeRef(ScalaPackage(), ""), (tp: Type) :: Nil) => Some(tp) + case AppliedType(rep, (tp: Type) :: Nil) if rep.typeSymbol == ctx.requiredClass("scala.") => Some(tp) case _ => None } } - object ScalaPackage { - def unapply(tpe: TypeOrBounds)(given ctx: Context): Boolean = tpe match { - case tpe: Type => tpe.termSymbol == defn.ScalaPackage - case _ => false - } - } - } object PackageObject { diff --git a/tests/run-macros/requiredSymbols.check b/tests/run-macros/requiredSymbols.check new file mode 100644 index 000000000000..34699a22dc6f --- /dev/null +++ b/tests/run-macros/requiredSymbols.check @@ -0,0 +1,13 @@ +java +java.lang +scala +scala.collection +java.lang.Object +scala.Any +scala.Any +scala.AnyVal +scala.Unit +scala.Null +scala.None +scala.package$.Nil +scala.collection.immutable.List$.empty diff --git a/tests/run-macros/requiredSymbols/Macro_1.scala b/tests/run-macros/requiredSymbols/Macro_1.scala new file mode 100644 index 000000000000..ec92c43f2ffe --- /dev/null +++ b/tests/run-macros/requiredSymbols/Macro_1.scala @@ -0,0 +1,27 @@ +import scala.quoted._ + +object Macro { + inline def foo: String = ${ fooImpl } + def fooImpl(given qctx: QuoteContext): Expr[String] = { + import qctx.tasty.{given, _} + val list = List( + rootContext.requiredPackage("java"), + rootContext.requiredPackage("java.lang"), + rootContext.requiredPackage("scala"), + rootContext.requiredPackage("scala.collection"), + + rootContext.requiredClass("java.lang.Object"), + rootContext.requiredClass("scala.Any"), + rootContext.requiredClass("scala.AnyRef"), + rootContext.requiredClass("scala.AnyVal"), + rootContext.requiredClass("scala.Unit"), + rootContext.requiredClass("scala.Null"), + + rootContext.requiredModule("scala.None"), + rootContext.requiredModule("scala.Nil"), + + rootContext.requiredMethod("scala.List.empty"), + ) + Expr(list.map(_.fullName).mkString("\n")) + } +} diff --git a/tests/run-macros/requiredSymbols/Test_2.scala b/tests/run-macros/requiredSymbols/Test_2.scala new file mode 100644 index 000000000000..faf28953a857 --- /dev/null +++ b/tests/run-macros/requiredSymbols/Test_2.scala @@ -0,0 +1,6 @@ + +object Test { + def main(args: Array[String]): Unit = { + println(Macro.foo) + } +}