Skip to content

Add requiredXYZ symbols to TASTy reflect context #7903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type =
self.gadt.approximation(sym, fromBelow)

def Context_requiredPackage(self: Context)(path: String): Symbol = self.requiredPackage(path)
def Context_requiredClass(self: Context)(path: String): Symbol = self.requiredClass(path)
def Context_requiredModule(self: Context)(path: String): Symbol = self.requiredModule(path)
def Context_requiredMethod(self: Context)(path: String): Symbol = self.requiredMethod(path)

//
// REPORTING
//
Expand Down
2 changes: 1 addition & 1 deletion library/src/scala/internal/quoted/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ private[quoted] object Matcher {
*/
private def patternsMatches(scrutinee: Tree, pattern: Tree)(given Context, Env): (Env, Matching) = (scrutinee, pattern) match {
case (v1: Term, Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" =>
if patternHole.symbol.owner == summon[Context].requiredModule("scala.runtime.quoted.Matcher") =>
(summon[Env], matched(v1.seal))

case (Ident("_"), Ident("_")) =>
Expand Down
12 changes: 12 additions & 0 deletions library/src/scala/tasty/reflect/CompilerInterface.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,18 @@ trait CompilerInterface {
def Context_GADT_addToConstraint(self: Context)(syms: List[Symbol]): Boolean
def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type

/** Get package symbol if package is either defined in current compilation run or present on classpath. */
def Context_requiredPackage(self: Context)(path: String): Symbol

/** Get class symbol if class is either defined in current compilation run or present on classpath. */
def Context_requiredClass(self: Context)(path: String): Symbol

/** Get module symbol if module is either defined in current compilation run or present on classpath. */
def Context_requiredModule(self: Context)(path: String): Symbol
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The meta-programmer may not know what's a module. Maybe rename to requiredObject?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to keep names consistent. It will be simpler for documentation. We also have moduleClass and companionModule that use the same naming scheme.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I totally agree, for advanced meta-programming, keeping the concepts closer to the compiler will help both meta-programmers who want to get into the compiler and maintenance of the framework.


/** Get method symbol if method is either defined in current compilation run or present on classpath. Throws if the method has an overload. */
def Context_requiredMethod(self: Context)(path: String): Symbol

//
// REPORTING
//
Expand Down
12 changes: 12 additions & 0 deletions library/src/scala/tasty/reflect/ContextOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@ trait ContextOps extends Core {
/** Returns the source file being compiled. The path is relative to the current working directory. */
def source: java.nio.file.Path = internal.Context_source(self)

/** Get package symbol if package is either defined in current compilation run or present on classpath. */
def requiredPackage(path: String): Symbol = internal.Context_requiredPackage(self)(path)

/** Get class symbol if class is either defined in current compilation run or present on classpath. */
def requiredClass(path: String): Symbol = internal.Context_requiredClass(self)(path)

/** Get module symbol if module is either defined in current compilation run or present on classpath. */
def requiredModule(path: String): Symbol = internal.Context_requiredModule(self)(path)

/** Get method symbol if method is either defined in current compilation run or present on classpath. Throws if the method has an overload. */
def requiredMethod(path: String): Symbol = internal.Context_requiredMethod(self)(path)

}

/** Context of the macro expansion */
Expand Down
56 changes: 15 additions & 41 deletions library/src/scala/tasty/reflect/SourceCodePrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
}

val parents1 = parents.filter {
case Apply(Select(New(tpt), _), _) => !Types.JavaLangObject.unapply(tpt.tpe)
case Apply(Select(New(tpt), _), _) => tpt.tpe.typeSymbol != ctx.requiredClass("java.lang.Object")
case TypeSelect(Select(Ident("_root_"), "scala"), "Product") => false
case TypeSelect(Select(Ident("_root_"), "scala"), "Serializable") => false
case _ => true
Expand Down Expand Up @@ -356,7 +356,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
this += "throw "
printTree(expr)

case Apply(fn, args) if fn.symbol.fullName == "scala.internal.Quoted$.exprQuote" =>
case Apply(fn, args) if fn.symbol == ctx.requiredMethod("scala.internal.Quoted.exprQuote") =>
args.head match {
case Block(stats, expr) =>
this += "'{"
Expand All @@ -371,12 +371,12 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
this += "}"
}

case TypeApply(fn, args) if fn.symbol.fullName == "scala.internal.Quoted$.typeQuote" =>
case TypeApply(fn, args) if fn.symbol == ctx.requiredMethod("scala.internal.Quoted.typeQuote") =>
this += "'["
printTypeTree(args.head)
this += "]"

case Apply(fn, arg :: Nil) if fn.symbol.fullName == "scala.internal.Quoted$.exprSplice" =>
case Apply(fn, arg :: Nil) if fn.symbol == ctx.requiredMethod("scala.internal.Quoted.exprSplice") =>
this += "${"
printTree(arg)
this += "}"
Expand Down Expand Up @@ -573,7 +573,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
def printFlatBlock(stats: List[Statement], expr: Term)(given elideThis: Option[Symbol]): Buffer = {
val (stats1, expr1) = flatBlock(stats, expr)
val stats2 = stats1.filter {
case tree: TypeDef => !tree.symbol.annots.exists(_.symbol.owner.fullName == "scala.internal.Quoted$.quoteTypeTag")
case tree: TypeDef => !tree.symbol.annots.exists(_.symbol.owner == ctx.requiredClass("scala.internal.Quoted.quoteTypeTag"))
case _ => true
}
if (stats2.isEmpty) {
Expand Down Expand Up @@ -971,7 +971,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
printTypeAndAnnots(tp)
this += " "
printAnnotation(annot)
case tpe: TypeRef if tpe.typeSymbol.fullName == "scala.runtime.Null$" || tpe.typeSymbol.fullName == "scala.runtime.Nothing$" =>
case tpe: TypeRef if tpe.typeSymbol == ctx.requiredClass("scala.runtime.Null$") || tpe.typeSymbol == ctx.requiredClass("scala.runtime.Nothing$") =>
// scala.runtime.Null$ and scala.runtime.Nothing$ are not modules, those are their actual names
printType(tpe)
case tpe: TermRef if tpe.termSymbol.isClassDef && tpe.termSymbol.name.endsWith("$") =>
Expand Down Expand Up @@ -1014,7 +1014,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
case Annotated(tpt, annot) =>
val Annotation(ref, args) = annot
ref.tpe match {
case Types.RepeatedAnnotation() =>
case tpe: TypeRef if tpe.typeSymbol == ctx.requiredClass("scala.annotation.internal.Repeated") =>
val Types.Sequence(tp) = tpt.tpe
printType(tp)
this += highlightTypeDef("*")
Expand Down Expand Up @@ -1115,7 +1115,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
tp match {
case tp: TypeLambda =>
printType(tpe.dealias)
case TypeRef(Types.ScalaPackage(), "<repeated>") =>
case tp: TypeRef if tp.typeSymbol == ctx.requiredClass("scala.<repeated>") =>
this += "_*"
case _ =>
printType(tp)
Expand Down Expand Up @@ -1228,7 +1228,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig

def printAnnotation(annot: Term)(given elideThis: Option[Symbol]): Buffer = {
val Annotation(ref, args) = annot
if (annot.symbol.maybeOwner.fullName == "scala.internal.quoted.showName") this
if (annot.symbol.maybeOwner == ctx.requiredClass("scala.internal.quoted.showName")) this
else {
this += "@"
printTypeTree(ref)
Expand All @@ -1242,12 +1242,9 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
def printDefAnnotations(definition: Definition)(given elideThis: Option[Symbol]): Buffer = {
val annots = definition.symbol.annots.filter {
case Annotation(annot, _) =>
annot.tpe match {
case TypeRef(prefix: TermRef, _) if prefix.termSymbol.fullName == "scala.annotation.internal" => false
case TypeRef(prefix: TypeRef, _) if prefix.typeSymbol.fullName == "scala.annotation.internal" => false
case TypeRef(Types.ScalaPackage(), "forceInline") => false
case _ => true
}
val sym = annot.tpe.typeSymbol
sym != ctx.requiredClass("scala.forceInline") &&
sym.maybeOwner != ctx.requiredPackage("scala.annotation.internal")
case x => throw new MatchError(x.showExtractors)
}
printAnnotations(annots)
Expand Down Expand Up @@ -1404,7 +1401,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
}

private def splicedName(sym: Symbol)(given ctx: Context): Option[String] = {
sym.annots.find(_.symbol.owner.fullName == "scala.internal.quoted.showName").flatMap {
sym.annots.find(_.symbol.owner == ctx.requiredClass("scala.internal.quoted.showName")).flatMap {
case Apply(_, Literal(Constant(c: String)) :: Nil) => Some(c)
case Apply(_, Inlined(_, _, Literal(Constant(c: String))) :: Nil) => Some(c)
case annot => None
Expand Down Expand Up @@ -1435,43 +1432,20 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
// TODO Provide some of these in scala.tasty.Reflection.scala and implement them using checks on symbols for performance
private object Types {

object JavaLangObject {
def unapply(tpe: Type)(given ctx: Context): Boolean = tpe match {
case TypeRef(prefix: TermRef, "Object") => prefix.typeSymbol.fullName == "java.lang"
case _ => false
}
}

object Sequence {
def unapply(tpe: Type)(given ctx: Context): Option[Type] = tpe match {
case AppliedType(TypeRef(prefix: TermRef, "Seq"), (tp: Type) :: Nil) if prefix.termSymbol.fullName == "scala.collection" => Some(tp)
case AppliedType(TypeRef(prefix: TypeRef, "Seq"), (tp: Type) :: Nil) if prefix.typeSymbol.fullName == "scala.collection" => Some(tp)
case AppliedType(seq, (tp: Type) :: Nil) if seq.typeSymbol == ctx.requiredClass("scala.collection.Seq") => Some(tp)
case _ => None
}
}

object RepeatedAnnotation {
def unapply(tpe: Type)(given ctx: Context): Boolean = tpe match {
case TypeRef(prefix: TermRef, "Repeated") => prefix.termSymbol.fullName == "scala.annotation.internal"
case TypeRef(prefix: TypeRef, "Repeated") => prefix.typeSymbol.fullName == "scala.annotation.internal"
case _ => false
}
}

object Repeated {
def unapply(tpe: Type)(given ctx: Context): Option[Type] = tpe match {
case AppliedType(TypeRef(ScalaPackage(), "<repeated>"), (tp: Type) :: Nil) => Some(tp)
case AppliedType(rep, (tp: Type) :: Nil) if rep.typeSymbol == ctx.requiredClass("scala.<repeated>") => Some(tp)
case _ => None
}
}

object ScalaPackage {
def unapply(tpe: TypeOrBounds)(given ctx: Context): Boolean = tpe match {
case tpe: Type => tpe.termSymbol == defn.ScalaPackage
case _ => false
}
}

}

object PackageObject {
Expand Down
13 changes: 13 additions & 0 deletions tests/run-macros/requiredSymbols.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
java
java.lang
scala
scala.collection
java.lang.Object
scala.Any
scala.Any
scala.AnyVal
scala.Unit
scala.Null
scala.None
scala.package$.Nil
scala.collection.immutable.List$.empty
27 changes: 27 additions & 0 deletions tests/run-macros/requiredSymbols/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import scala.quoted._

object Macro {
inline def foo: String = ${ fooImpl }
def fooImpl(given qctx: QuoteContext): Expr[String] = {
import qctx.tasty.{given, _}
val list = List(
rootContext.requiredPackage("java"),
rootContext.requiredPackage("java.lang"),
rootContext.requiredPackage("scala"),
rootContext.requiredPackage("scala.collection"),

rootContext.requiredClass("java.lang.Object"),
rootContext.requiredClass("scala.Any"),
rootContext.requiredClass("scala.AnyRef"),
rootContext.requiredClass("scala.AnyVal"),
rootContext.requiredClass("scala.Unit"),
rootContext.requiredClass("scala.Null"),

rootContext.requiredModule("scala.None"),
rootContext.requiredModule("scala.Nil"),

rootContext.requiredMethod("scala.List.empty"),
)
Expr(list.map(_.fullName).mkString("\n"))
}
}
6 changes: 6 additions & 0 deletions tests/run-macros/requiredSymbols/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

object Test {
def main(args: Array[String]): Unit = {
println(Macro.foo)
}
}