Skip to content

Commit 4ae4dc8

Browse files
authored
Refactor splice interpreter (#16280)
Extract the general purpose logic from the splice interpreter. The new interpreter class will be the basis for the macro annotation interpreter. The splice interpreter only keeps logic related with level -1 quote and type evaluation. Part of https://github.com/dotty-staging/dotty/tree/design-macro-annotations
2 parents 1fccf39 + ee51648 commit 4ae4dc8

File tree

2 files changed

+385
-333
lines changed

2 files changed

+385
-333
lines changed
Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
package dotty.tools.dotc
2+
package quoted
3+
4+
import scala.language.unsafeNulls
5+
6+
import scala.collection.mutable
7+
import scala.reflect.ClassTag
8+
9+
import java.io.{PrintWriter, StringWriter}
10+
import java.lang.reflect.{InvocationTargetException, Method => JLRMethod}
11+
12+
import dotty.tools.dotc.ast.tpd
13+
import dotty.tools.dotc.ast.TreeMapWithImplicits
14+
import dotty.tools.dotc.core.Annotations._
15+
import dotty.tools.dotc.core.Constants._
16+
import dotty.tools.dotc.core.Contexts._
17+
import dotty.tools.dotc.core.Decorators._
18+
import dotty.tools.dotc.core.Denotations.staticRef
19+
import dotty.tools.dotc.core.Flags._
20+
import dotty.tools.dotc.core.NameKinds.FlatName
21+
import dotty.tools.dotc.core.Names._
22+
import dotty.tools.dotc.core.StagingContext._
23+
import dotty.tools.dotc.core.StdNames._
24+
import dotty.tools.dotc.core.Symbols._
25+
import dotty.tools.dotc.core.TypeErasure
26+
import dotty.tools.dotc.core.Types._
27+
import dotty.tools.dotc.quoted._
28+
import dotty.tools.dotc.transform.TreeMapWithStages._
29+
import dotty.tools.dotc.typer.ImportInfo.withRootImports
30+
import dotty.tools.dotc.util.SrcPos
31+
import dotty.tools.repl.AbstractFileClassLoader
32+
33+
/** Tree interpreter for metaprogramming constructs */
34+
abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context):
35+
import Interpreter._
36+
import tpd._
37+
38+
type Env = Map[Symbol, Object]
39+
40+
/** Returns the result of interpreting the code in the tree.
41+
* Return Some of the result or None if the result type is not consistent with the expected type.
42+
* Throws a StopInterpretation if the tree could not be interpreted or a runtime exception ocurred.
43+
*/
44+
final def interpret[T](tree: Tree)(implicit ct: ClassTag[T]): Option[T] =
45+
interpretTree(tree)(Map.empty) match {
46+
case obj: T => Some(obj)
47+
case obj =>
48+
// TODO upgrade to a full type tag check or something similar
49+
report.error(s"Interpreted tree returned a result of an unexpected type. Expected ${ct.runtimeClass} but was ${obj.getClass}", pos)
50+
None
51+
}
52+
53+
/** Returns the result of interpreting the code in the tree.
54+
* Throws a StopInterpretation if the tree could not be interpreted or a runtime exception ocurred.
55+
*/
56+
protected def interpretTree(tree: Tree)(implicit env: Env): Object = tree match {
57+
case Literal(Constant(value)) =>
58+
interpretLiteral(value)
59+
60+
case tree: Ident if tree.symbol.is(Inline, butNot = Method) =>
61+
tree.tpe.widenTermRefExpr match
62+
case ConstantType(c) => c.value.asInstanceOf[Object]
63+
case _ => throw new StopInterpretation(em"${tree.symbol} could not be inlined", tree.srcPos)
64+
65+
// TODO disallow interpreted method calls as arguments
66+
case Call(fn, args) =>
67+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package))
68+
interpretNew(fn.symbol, args.flatten.map(interpretTree))
69+
else if (fn.symbol.is(Module))
70+
interpretModuleAccess(fn.symbol)
71+
else if (fn.symbol.is(Method) && fn.symbol.isStatic) {
72+
val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol)
73+
staticMethodCall(interpretArgs(args, fn.symbol.info))
74+
}
75+
else if fn.symbol.isStatic then
76+
assert(args.isEmpty)
77+
interpretedStaticFieldAccess(fn.symbol)
78+
else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic)
79+
if (fn.name == nme.asInstanceOfPM)
80+
interpretModuleAccess(fn.qualifier.symbol)
81+
else {
82+
val staticMethodCall = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol)
83+
staticMethodCall(interpretArgs(args, fn.symbol.info))
84+
}
85+
else if (env.contains(fn.symbol))
86+
env(fn.symbol)
87+
else if (tree.symbol.is(InlineProxy))
88+
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
89+
else
90+
unexpectedTree(tree)
91+
92+
case closureDef((ddef @ DefDef(_, ValDefs(arg :: Nil) :: Nil, _, _))) =>
93+
(obj: AnyRef) => interpretTree(ddef.rhs)(using env.updated(arg.symbol, obj))
94+
95+
// Interpret `foo(j = x, i = y)` which it is expanded to
96+
// `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
97+
case Block(stats, expr) => interpretBlock(stats, expr)
98+
case NamedArg(_, arg) => interpretTree(arg)
99+
100+
case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)
101+
102+
case Typed(expr, _) =>
103+
interpretTree(expr)
104+
105+
case SeqLiteral(elems, _) =>
106+
interpretVarargs(elems.map(e => interpretTree(e)))
107+
108+
case _ =>
109+
unexpectedTree(tree)
110+
}
111+
112+
private def interpretArgs(argss: List[List[Tree]], fnType: Type)(using Env): List[Object] = {
113+
def interpretArgsGroup(args: List[Tree], argTypes: List[Type]): List[Object] =
114+
assert(args.size == argTypes.size)
115+
val view =
116+
for (arg, info) <- args.lazyZip(argTypes) yield
117+
info match
118+
case _: ExprType => () => interpretTree(arg) // by-name argument
119+
case _ => interpretTree(arg) // by-value argument
120+
view.toList
121+
122+
fnType.dealias match
123+
case fnType: MethodType if fnType.isErasedMethod => interpretArgs(argss, fnType.resType)
124+
case fnType: MethodType =>
125+
val argTypes = fnType.paramInfos
126+
assert(argss.head.size == argTypes.size)
127+
interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, fnType.resType)
128+
case fnType: AppliedType if defn.isContextFunctionType(fnType) =>
129+
val argTypes :+ resType = fnType.args: @unchecked
130+
interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, resType)
131+
case fnType: PolyType => interpretArgs(argss, fnType.resType)
132+
case fnType: ExprType => interpretArgs(argss, fnType.resType)
133+
case _ =>
134+
assert(argss.isEmpty)
135+
Nil
136+
}
137+
138+
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
139+
var unexpected: Option[Object] = None
140+
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
141+
case stat: ValDef =>
142+
accEnv.updated(stat.symbol, interpretTree(stat.rhs)(accEnv))
143+
case stat =>
144+
if (unexpected.isEmpty)
145+
unexpected = Some(unexpectedTree(stat))
146+
accEnv
147+
})
148+
unexpected.getOrElse(interpretTree(expr)(newEnv))
149+
}
150+
151+
private def interpretLiteral(value: Any)(implicit env: Env): Object =
152+
value.asInstanceOf[Object]
153+
154+
private def interpretVarargs(args: List[Object])(implicit env: Env): Object =
155+
args.toSeq
156+
157+
private def interpretedStaticMethodCall(moduleClass: Symbol, fn: Symbol)(implicit env: Env): List[Object] => Object = {
158+
val (inst, clazz) =
159+
try
160+
if (moduleClass.name.startsWith(str.REPL_SESSION_LINE))
161+
(null, loadReplLineClass(moduleClass))
162+
else {
163+
val inst = loadModule(moduleClass)
164+
(inst, inst.getClass)
165+
}
166+
catch
167+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
168+
if (ctx.settings.XprintSuspension.value)
169+
report.echo(i"suspension triggered by a dependency on $sym", pos)
170+
ctx.compilationUnit.suspend() // this throws a SuspendException
171+
172+
val name = fn.name.asTermName
173+
val method = getMethod(clazz, name, paramsSig(fn))
174+
(args: List[Object]) => stopIfRuntimeException(method.invoke(inst, args: _*), method)
175+
}
176+
177+
private def interpretedStaticFieldAccess(sym: Symbol)(implicit env: Env): Object = {
178+
val clazz = loadClass(sym.owner.fullName.toString)
179+
val field = clazz.getField(sym.name.toString)
180+
field.get(null)
181+
}
182+
183+
private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
184+
loadModule(fn.moduleClass)
185+
186+
private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
187+
val clazz = loadClass(fn.owner.fullName.toString)
188+
val constr = clazz.getConstructor(paramsSig(fn): _*)
189+
constr.newInstance(args: _*).asInstanceOf[Object]
190+
}
191+
192+
private def unexpectedTree(tree: Tree)(implicit env: Env): Object =
193+
throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.srcPos)
194+
195+
private def loadModule(sym: Symbol): Object =
196+
if (sym.owner.is(Package)) {
197+
// is top level object
198+
val moduleClass = loadClass(sym.fullName.toString)
199+
moduleClass.getField(str.MODULE_INSTANCE_FIELD).get(null)
200+
}
201+
else {
202+
// nested object in an object
203+
val className = {
204+
val pack = sym.topLevelClass.owner
205+
if (pack == defn.RootPackage || pack == defn.EmptyPackageClass) sym.flatName.toString
206+
else pack.showFullName + "." + sym.flatName
207+
}
208+
val clazz = loadClass(className)
209+
clazz.getConstructor().newInstance().asInstanceOf[Object]
210+
}
211+
212+
private def loadReplLineClass(moduleClass: Symbol)(implicit env: Env): Class[?] = {
213+
val lineClassloader = new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader)
214+
lineClassloader.loadClass(moduleClass.name.firstPart.toString)
215+
}
216+
217+
private def loadClass(name: String): Class[?] =
218+
try classLoader.loadClass(name)
219+
catch {
220+
case _: ClassNotFoundException if ctx.compilationUnit.isSuspendable =>
221+
if (ctx.settings.XprintSuspension.value)
222+
report.echo(i"suspension triggered by a dependency on $name", pos)
223+
ctx.compilationUnit.suspend()
224+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
225+
if (ctx.settings.XprintSuspension.value)
226+
report.echo(i"suspension triggered by a dependency on $sym", pos)
227+
ctx.compilationUnit.suspend() // this throws a SuspendException
228+
}
229+
230+
private def getMethod(clazz: Class[?], name: Name, paramClasses: List[Class[?]]): JLRMethod =
231+
try clazz.getMethod(name.toString, paramClasses: _*)
232+
catch {
233+
case _: NoSuchMethodException =>
234+
val msg = em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)"
235+
throw new StopInterpretation(msg, pos)
236+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
237+
if (ctx.settings.XprintSuspension.value)
238+
report.echo(i"suspension triggered by a dependency on $sym", pos)
239+
ctx.compilationUnit.suspend() // this throws a SuspendException
240+
}
241+
242+
private def stopIfRuntimeException[T](thunk: => T, method: JLRMethod): T =
243+
try thunk
244+
catch {
245+
case ex: RuntimeException =>
246+
val sw = new StringWriter()
247+
sw.write("A runtime exception occurred while executing macro expansion\n")
248+
sw.write(ex.getMessage)
249+
sw.write("\n")
250+
ex.printStackTrace(new PrintWriter(sw))
251+
sw.write("\n")
252+
throw new StopInterpretation(sw.toString, pos)
253+
case ex: InvocationTargetException =>
254+
ex.getTargetException match {
255+
case ex: scala.quoted.runtime.StopMacroExpansion =>
256+
throw ex
257+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
258+
if (ctx.settings.XprintSuspension.value)
259+
report.echo(i"suspension triggered by a dependency on $sym", pos)
260+
ctx.compilationUnit.suspend() // this throws a SuspendException
261+
case targetException =>
262+
val sw = new StringWriter()
263+
sw.write("Exception occurred while executing macro expansion.\n")
264+
if (!ctx.settings.Ydebug.value) {
265+
val end = targetException.getStackTrace.lastIndexWhere { x =>
266+
x.getClassName == method.getDeclaringClass.getCanonicalName && x.getMethodName == method.getName
267+
}
268+
val shortStackTrace = targetException.getStackTrace.take(end + 1)
269+
targetException.setStackTrace(shortStackTrace)
270+
}
271+
targetException.printStackTrace(new PrintWriter(sw))
272+
sw.write("\n")
273+
throw new StopInterpretation(sw.toString, pos)
274+
}
275+
}
276+
277+
private object MissingClassDefinedInCurrentRun {
278+
def unapply(targetException: NoClassDefFoundError)(using Context): Option[Symbol] = {
279+
val className = targetException.getMessage
280+
if (className eq null) None
281+
else {
282+
val sym = staticRef(className.toTypeName).symbol
283+
if (sym.isDefinedInCurrentRun) Some(sym) else None
284+
}
285+
}
286+
}
287+
288+
/** List of classes of the parameters of the signature of `sym` */
289+
private def paramsSig(sym: Symbol): List[Class[?]] = {
290+
def paramClass(param: Type): Class[?] = {
291+
def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match {
292+
case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1)
293+
case _ => (tpe, depth)
294+
}
295+
def javaArraySig(tpe: Type): String = {
296+
val (elemType, depth) = arrayDepth(tpe, 0)
297+
val sym = elemType.classSymbol
298+
val suffix =
299+
if (sym == defn.BooleanClass) "Z"
300+
else if (sym == defn.ByteClass) "B"
301+
else if (sym == defn.ShortClass) "S"
302+
else if (sym == defn.IntClass) "I"
303+
else if (sym == defn.LongClass) "J"
304+
else if (sym == defn.FloatClass) "F"
305+
else if (sym == defn.DoubleClass) "D"
306+
else if (sym == defn.CharClass) "C"
307+
else "L" + javaSig(elemType) + ";"
308+
("[" * depth) + suffix
309+
}
310+
def javaSig(tpe: Type): String = tpe match {
311+
case tpe: JavaArrayType => javaArraySig(tpe)
312+
case _ =>
313+
// Take the flatten name of the class and the full package name
314+
val pack = tpe.classSymbol.topLevelClass.owner
315+
val packageName = if (pack == defn.EmptyPackageClass) "" else s"${pack.fullName}."
316+
packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString
317+
}
318+
319+
val sym = param.classSymbol
320+
if (sym == defn.BooleanClass) classOf[Boolean]
321+
else if (sym == defn.ByteClass) classOf[Byte]
322+
else if (sym == defn.CharClass) classOf[Char]
323+
else if (sym == defn.ShortClass) classOf[Short]
324+
else if (sym == defn.IntClass) classOf[Int]
325+
else if (sym == defn.LongClass) classOf[Long]
326+
else if (sym == defn.FloatClass) classOf[Float]
327+
else if (sym == defn.DoubleClass) classOf[Double]
328+
else java.lang.Class.forName(javaSig(param), false, classLoader)
329+
}
330+
def getExtraParams(tp: Type): List[Type] = tp.widenDealias match {
331+
case tp: AppliedType if defn.isContextFunctionType(tp) =>
332+
// Call context function type direct method
333+
tp.args.init.map(arg => TypeErasure.erasure(arg)) ::: getExtraParams(tp.args.last)
334+
case _ => Nil
335+
}
336+
val extraParams = getExtraParams(sym.info.finalResultType)
337+
val allParams = TypeErasure.erasure(sym.info) match {
338+
case meth: MethodType => meth.paramInfos ::: extraParams
339+
case _ => extraParams
340+
}
341+
allParams.map(paramClass)
342+
}
343+
end Interpreter
344+
345+
object Interpreter:
346+
/** Exception that stops interpretation if some issue is found */
347+
class StopInterpretation(val msg: String, val pos: SrcPos) extends Exception
348+
349+
object Call:
350+
import tpd._
351+
/** Matches an expression that is either a field access or an application
352+
* It retruns a TermRef containing field accessed or a method reference and the arguments passed to it.
353+
*/
354+
def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] =
355+
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
356+
357+
private object Call0 {
358+
def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] = arg match {
359+
case Select(Call0(fn, args), nme.apply) if defn.isContextFunctionType(fn.tpe.widenDealias.finalResultType) =>
360+
Some((fn, args))
361+
case fn: Ident => Some((tpd.desugarIdent(fn).withSpan(fn.span), Nil))
362+
case fn: Select => Some((fn, Nil))
363+
case Apply(f @ Call0(fn, args1), args2) =>
364+
if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1))
365+
else Some((fn, args2 :: args1))
366+
case TypeApply(Call0(fn, args), _) => Some((fn, args))
367+
case _ => None
368+
}
369+
}
370+
end Call

0 commit comments

Comments
 (0)