|
| 1 | +package dotty.tools.dotc |
| 2 | +package quoted |
| 3 | + |
| 4 | +import scala.language.unsafeNulls |
| 5 | + |
| 6 | +import scala.collection.mutable |
| 7 | +import scala.reflect.ClassTag |
| 8 | + |
| 9 | +import java.io.{PrintWriter, StringWriter} |
| 10 | +import java.lang.reflect.{InvocationTargetException, Method => JLRMethod} |
| 11 | + |
| 12 | +import dotty.tools.dotc.ast.tpd |
| 13 | +import dotty.tools.dotc.ast.TreeMapWithImplicits |
| 14 | +import dotty.tools.dotc.core.Annotations._ |
| 15 | +import dotty.tools.dotc.core.Constants._ |
| 16 | +import dotty.tools.dotc.core.Contexts._ |
| 17 | +import dotty.tools.dotc.core.Decorators._ |
| 18 | +import dotty.tools.dotc.core.Denotations.staticRef |
| 19 | +import dotty.tools.dotc.core.Flags._ |
| 20 | +import dotty.tools.dotc.core.NameKinds.FlatName |
| 21 | +import dotty.tools.dotc.core.Names._ |
| 22 | +import dotty.tools.dotc.core.StagingContext._ |
| 23 | +import dotty.tools.dotc.core.StdNames._ |
| 24 | +import dotty.tools.dotc.core.Symbols._ |
| 25 | +import dotty.tools.dotc.core.TypeErasure |
| 26 | +import dotty.tools.dotc.core.Types._ |
| 27 | +import dotty.tools.dotc.quoted._ |
| 28 | +import dotty.tools.dotc.transform.TreeMapWithStages._ |
| 29 | +import dotty.tools.dotc.typer.ImportInfo.withRootImports |
| 30 | +import dotty.tools.dotc.util.SrcPos |
| 31 | +import dotty.tools.repl.AbstractFileClassLoader |
| 32 | + |
| 33 | +/** Tree interpreter for metaprogramming constructs */ |
| 34 | +abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context): |
| 35 | + import Interpreter._ |
| 36 | + import tpd._ |
| 37 | + |
| 38 | + type Env = Map[Symbol, Object] |
| 39 | + |
| 40 | + /** Returns the result of interpreting the code in the tree. |
| 41 | + * Return Some of the result or None if the result type is not consistent with the expected type. |
| 42 | + * Throws a StopInterpretation if the tree could not be interpreted or a runtime exception ocurred. |
| 43 | + */ |
| 44 | + final def interpret[T](tree: Tree)(implicit ct: ClassTag[T]): Option[T] = |
| 45 | + interpretTree(tree)(Map.empty) match { |
| 46 | + case obj: T => Some(obj) |
| 47 | + case obj => |
| 48 | + // TODO upgrade to a full type tag check or something similar |
| 49 | + report.error(s"Interpreted tree returned a result of an unexpected type. Expected ${ct.runtimeClass} but was ${obj.getClass}", pos) |
| 50 | + None |
| 51 | + } |
| 52 | + |
| 53 | + /** Returns the result of interpreting the code in the tree. |
| 54 | + * Throws a StopInterpretation if the tree could not be interpreted or a runtime exception ocurred. |
| 55 | + */ |
| 56 | + protected def interpretTree(tree: Tree)(implicit env: Env): Object = tree match { |
| 57 | + case Literal(Constant(value)) => |
| 58 | + interpretLiteral(value) |
| 59 | + |
| 60 | + case tree: Ident if tree.symbol.is(Inline, butNot = Method) => |
| 61 | + tree.tpe.widenTermRefExpr match |
| 62 | + case ConstantType(c) => c.value.asInstanceOf[Object] |
| 63 | + case _ => throw new StopInterpretation(em"${tree.symbol} could not be inlined", tree.srcPos) |
| 64 | + |
| 65 | + // TODO disallow interpreted method calls as arguments |
| 66 | + case Call(fn, args) => |
| 67 | + if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) |
| 68 | + interpretNew(fn.symbol, args.flatten.map(interpretTree)) |
| 69 | + else if (fn.symbol.is(Module)) |
| 70 | + interpretModuleAccess(fn.symbol) |
| 71 | + else if (fn.symbol.is(Method) && fn.symbol.isStatic) { |
| 72 | + val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol) |
| 73 | + staticMethodCall(interpretArgs(args, fn.symbol.info)) |
| 74 | + } |
| 75 | + else if fn.symbol.isStatic then |
| 76 | + assert(args.isEmpty) |
| 77 | + interpretedStaticFieldAccess(fn.symbol) |
| 78 | + else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) |
| 79 | + if (fn.name == nme.asInstanceOfPM) |
| 80 | + interpretModuleAccess(fn.qualifier.symbol) |
| 81 | + else { |
| 82 | + val staticMethodCall = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol) |
| 83 | + staticMethodCall(interpretArgs(args, fn.symbol.info)) |
| 84 | + } |
| 85 | + else if (env.contains(fn.symbol)) |
| 86 | + env(fn.symbol) |
| 87 | + else if (tree.symbol.is(InlineProxy)) |
| 88 | + interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs) |
| 89 | + else |
| 90 | + unexpectedTree(tree) |
| 91 | + |
| 92 | + case closureDef((ddef @ DefDef(_, ValDefs(arg :: Nil) :: Nil, _, _))) => |
| 93 | + (obj: AnyRef) => interpretTree(ddef.rhs)(using env.updated(arg.symbol, obj)) |
| 94 | + |
| 95 | + // Interpret `foo(j = x, i = y)` which it is expanded to |
| 96 | + // `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)` |
| 97 | + case Block(stats, expr) => interpretBlock(stats, expr) |
| 98 | + case NamedArg(_, arg) => interpretTree(arg) |
| 99 | + |
| 100 | + case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion) |
| 101 | + |
| 102 | + case Typed(expr, _) => |
| 103 | + interpretTree(expr) |
| 104 | + |
| 105 | + case SeqLiteral(elems, _) => |
| 106 | + interpretVarargs(elems.map(e => interpretTree(e))) |
| 107 | + |
| 108 | + case _ => |
| 109 | + unexpectedTree(tree) |
| 110 | + } |
| 111 | + |
| 112 | + private def interpretArgs(argss: List[List[Tree]], fnType: Type)(using Env): List[Object] = { |
| 113 | + def interpretArgsGroup(args: List[Tree], argTypes: List[Type]): List[Object] = |
| 114 | + assert(args.size == argTypes.size) |
| 115 | + val view = |
| 116 | + for (arg, info) <- args.lazyZip(argTypes) yield |
| 117 | + info match |
| 118 | + case _: ExprType => () => interpretTree(arg) // by-name argument |
| 119 | + case _ => interpretTree(arg) // by-value argument |
| 120 | + view.toList |
| 121 | + |
| 122 | + fnType.dealias match |
| 123 | + case fnType: MethodType if fnType.isErasedMethod => interpretArgs(argss, fnType.resType) |
| 124 | + case fnType: MethodType => |
| 125 | + val argTypes = fnType.paramInfos |
| 126 | + assert(argss.head.size == argTypes.size) |
| 127 | + interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, fnType.resType) |
| 128 | + case fnType: AppliedType if defn.isContextFunctionType(fnType) => |
| 129 | + val argTypes :+ resType = fnType.args: @unchecked |
| 130 | + interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, resType) |
| 131 | + case fnType: PolyType => interpretArgs(argss, fnType.resType) |
| 132 | + case fnType: ExprType => interpretArgs(argss, fnType.resType) |
| 133 | + case _ => |
| 134 | + assert(argss.isEmpty) |
| 135 | + Nil |
| 136 | + } |
| 137 | + |
| 138 | + private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = { |
| 139 | + var unexpected: Option[Object] = None |
| 140 | + val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match { |
| 141 | + case stat: ValDef => |
| 142 | + accEnv.updated(stat.symbol, interpretTree(stat.rhs)(accEnv)) |
| 143 | + case stat => |
| 144 | + if (unexpected.isEmpty) |
| 145 | + unexpected = Some(unexpectedTree(stat)) |
| 146 | + accEnv |
| 147 | + }) |
| 148 | + unexpected.getOrElse(interpretTree(expr)(newEnv)) |
| 149 | + } |
| 150 | + |
| 151 | + private def interpretLiteral(value: Any)(implicit env: Env): Object = |
| 152 | + value.asInstanceOf[Object] |
| 153 | + |
| 154 | + private def interpretVarargs(args: List[Object])(implicit env: Env): Object = |
| 155 | + args.toSeq |
| 156 | + |
| 157 | + private def interpretedStaticMethodCall(moduleClass: Symbol, fn: Symbol)(implicit env: Env): List[Object] => Object = { |
| 158 | + val (inst, clazz) = |
| 159 | + try |
| 160 | + if (moduleClass.name.startsWith(str.REPL_SESSION_LINE)) |
| 161 | + (null, loadReplLineClass(moduleClass)) |
| 162 | + else { |
| 163 | + val inst = loadModule(moduleClass) |
| 164 | + (inst, inst.getClass) |
| 165 | + } |
| 166 | + catch |
| 167 | + case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable => |
| 168 | + if (ctx.settings.XprintSuspension.value) |
| 169 | + report.echo(i"suspension triggered by a dependency on $sym", pos) |
| 170 | + ctx.compilationUnit.suspend() // this throws a SuspendException |
| 171 | + |
| 172 | + val name = fn.name.asTermName |
| 173 | + val method = getMethod(clazz, name, paramsSig(fn)) |
| 174 | + (args: List[Object]) => stopIfRuntimeException(method.invoke(inst, args: _*), method) |
| 175 | + } |
| 176 | + |
| 177 | + private def interpretedStaticFieldAccess(sym: Symbol)(implicit env: Env): Object = { |
| 178 | + val clazz = loadClass(sym.owner.fullName.toString) |
| 179 | + val field = clazz.getField(sym.name.toString) |
| 180 | + field.get(null) |
| 181 | + } |
| 182 | + |
| 183 | + private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object = |
| 184 | + loadModule(fn.moduleClass) |
| 185 | + |
| 186 | + private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = { |
| 187 | + val clazz = loadClass(fn.owner.fullName.toString) |
| 188 | + val constr = clazz.getConstructor(paramsSig(fn): _*) |
| 189 | + constr.newInstance(args: _*).asInstanceOf[Object] |
| 190 | + } |
| 191 | + |
| 192 | + private def unexpectedTree(tree: Tree)(implicit env: Env): Object = |
| 193 | + throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.srcPos) |
| 194 | + |
| 195 | + private def loadModule(sym: Symbol): Object = |
| 196 | + if (sym.owner.is(Package)) { |
| 197 | + // is top level object |
| 198 | + val moduleClass = loadClass(sym.fullName.toString) |
| 199 | + moduleClass.getField(str.MODULE_INSTANCE_FIELD).get(null) |
| 200 | + } |
| 201 | + else { |
| 202 | + // nested object in an object |
| 203 | + val className = { |
| 204 | + val pack = sym.topLevelClass.owner |
| 205 | + if (pack == defn.RootPackage || pack == defn.EmptyPackageClass) sym.flatName.toString |
| 206 | + else pack.showFullName + "." + sym.flatName |
| 207 | + } |
| 208 | + val clazz = loadClass(className) |
| 209 | + clazz.getConstructor().newInstance().asInstanceOf[Object] |
| 210 | + } |
| 211 | + |
| 212 | + private def loadReplLineClass(moduleClass: Symbol)(implicit env: Env): Class[?] = { |
| 213 | + val lineClassloader = new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader) |
| 214 | + lineClassloader.loadClass(moduleClass.name.firstPart.toString) |
| 215 | + } |
| 216 | + |
| 217 | + private def loadClass(name: String): Class[?] = |
| 218 | + try classLoader.loadClass(name) |
| 219 | + catch { |
| 220 | + case _: ClassNotFoundException if ctx.compilationUnit.isSuspendable => |
| 221 | + if (ctx.settings.XprintSuspension.value) |
| 222 | + report.echo(i"suspension triggered by a dependency on $name", pos) |
| 223 | + ctx.compilationUnit.suspend() |
| 224 | + case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable => |
| 225 | + if (ctx.settings.XprintSuspension.value) |
| 226 | + report.echo(i"suspension triggered by a dependency on $sym", pos) |
| 227 | + ctx.compilationUnit.suspend() // this throws a SuspendException |
| 228 | + } |
| 229 | + |
| 230 | + private def getMethod(clazz: Class[?], name: Name, paramClasses: List[Class[?]]): JLRMethod = |
| 231 | + try clazz.getMethod(name.toString, paramClasses: _*) |
| 232 | + catch { |
| 233 | + case _: NoSuchMethodException => |
| 234 | + val msg = em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)" |
| 235 | + throw new StopInterpretation(msg, pos) |
| 236 | + case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable => |
| 237 | + if (ctx.settings.XprintSuspension.value) |
| 238 | + report.echo(i"suspension triggered by a dependency on $sym", pos) |
| 239 | + ctx.compilationUnit.suspend() // this throws a SuspendException |
| 240 | + } |
| 241 | + |
| 242 | + private def stopIfRuntimeException[T](thunk: => T, method: JLRMethod): T = |
| 243 | + try thunk |
| 244 | + catch { |
| 245 | + case ex: RuntimeException => |
| 246 | + val sw = new StringWriter() |
| 247 | + sw.write("A runtime exception occurred while executing macro expansion\n") |
| 248 | + sw.write(ex.getMessage) |
| 249 | + sw.write("\n") |
| 250 | + ex.printStackTrace(new PrintWriter(sw)) |
| 251 | + sw.write("\n") |
| 252 | + throw new StopInterpretation(sw.toString, pos) |
| 253 | + case ex: InvocationTargetException => |
| 254 | + ex.getTargetException match { |
| 255 | + case ex: scala.quoted.runtime.StopMacroExpansion => |
| 256 | + throw ex |
| 257 | + case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable => |
| 258 | + if (ctx.settings.XprintSuspension.value) |
| 259 | + report.echo(i"suspension triggered by a dependency on $sym", pos) |
| 260 | + ctx.compilationUnit.suspend() // this throws a SuspendException |
| 261 | + case targetException => |
| 262 | + val sw = new StringWriter() |
| 263 | + sw.write("Exception occurred while executing macro expansion.\n") |
| 264 | + if (!ctx.settings.Ydebug.value) { |
| 265 | + val end = targetException.getStackTrace.lastIndexWhere { x => |
| 266 | + x.getClassName == method.getDeclaringClass.getCanonicalName && x.getMethodName == method.getName |
| 267 | + } |
| 268 | + val shortStackTrace = targetException.getStackTrace.take(end + 1) |
| 269 | + targetException.setStackTrace(shortStackTrace) |
| 270 | + } |
| 271 | + targetException.printStackTrace(new PrintWriter(sw)) |
| 272 | + sw.write("\n") |
| 273 | + throw new StopInterpretation(sw.toString, pos) |
| 274 | + } |
| 275 | + } |
| 276 | + |
| 277 | + private object MissingClassDefinedInCurrentRun { |
| 278 | + def unapply(targetException: NoClassDefFoundError)(using Context): Option[Symbol] = { |
| 279 | + val className = targetException.getMessage |
| 280 | + if (className eq null) None |
| 281 | + else { |
| 282 | + val sym = staticRef(className.toTypeName).symbol |
| 283 | + if (sym.isDefinedInCurrentRun) Some(sym) else None |
| 284 | + } |
| 285 | + } |
| 286 | + } |
| 287 | + |
| 288 | + /** List of classes of the parameters of the signature of `sym` */ |
| 289 | + private def paramsSig(sym: Symbol): List[Class[?]] = { |
| 290 | + def paramClass(param: Type): Class[?] = { |
| 291 | + def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match { |
| 292 | + case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1) |
| 293 | + case _ => (tpe, depth) |
| 294 | + } |
| 295 | + def javaArraySig(tpe: Type): String = { |
| 296 | + val (elemType, depth) = arrayDepth(tpe, 0) |
| 297 | + val sym = elemType.classSymbol |
| 298 | + val suffix = |
| 299 | + if (sym == defn.BooleanClass) "Z" |
| 300 | + else if (sym == defn.ByteClass) "B" |
| 301 | + else if (sym == defn.ShortClass) "S" |
| 302 | + else if (sym == defn.IntClass) "I" |
| 303 | + else if (sym == defn.LongClass) "J" |
| 304 | + else if (sym == defn.FloatClass) "F" |
| 305 | + else if (sym == defn.DoubleClass) "D" |
| 306 | + else if (sym == defn.CharClass) "C" |
| 307 | + else "L" + javaSig(elemType) + ";" |
| 308 | + ("[" * depth) + suffix |
| 309 | + } |
| 310 | + def javaSig(tpe: Type): String = tpe match { |
| 311 | + case tpe: JavaArrayType => javaArraySig(tpe) |
| 312 | + case _ => |
| 313 | + // Take the flatten name of the class and the full package name |
| 314 | + val pack = tpe.classSymbol.topLevelClass.owner |
| 315 | + val packageName = if (pack == defn.EmptyPackageClass) "" else s"${pack.fullName}." |
| 316 | + packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString |
| 317 | + } |
| 318 | + |
| 319 | + val sym = param.classSymbol |
| 320 | + if (sym == defn.BooleanClass) classOf[Boolean] |
| 321 | + else if (sym == defn.ByteClass) classOf[Byte] |
| 322 | + else if (sym == defn.CharClass) classOf[Char] |
| 323 | + else if (sym == defn.ShortClass) classOf[Short] |
| 324 | + else if (sym == defn.IntClass) classOf[Int] |
| 325 | + else if (sym == defn.LongClass) classOf[Long] |
| 326 | + else if (sym == defn.FloatClass) classOf[Float] |
| 327 | + else if (sym == defn.DoubleClass) classOf[Double] |
| 328 | + else java.lang.Class.forName(javaSig(param), false, classLoader) |
| 329 | + } |
| 330 | + def getExtraParams(tp: Type): List[Type] = tp.widenDealias match { |
| 331 | + case tp: AppliedType if defn.isContextFunctionType(tp) => |
| 332 | + // Call context function type direct method |
| 333 | + tp.args.init.map(arg => TypeErasure.erasure(arg)) ::: getExtraParams(tp.args.last) |
| 334 | + case _ => Nil |
| 335 | + } |
| 336 | + val extraParams = getExtraParams(sym.info.finalResultType) |
| 337 | + val allParams = TypeErasure.erasure(sym.info) match { |
| 338 | + case meth: MethodType => meth.paramInfos ::: extraParams |
| 339 | + case _ => extraParams |
| 340 | + } |
| 341 | + allParams.map(paramClass) |
| 342 | + } |
| 343 | +end Interpreter |
| 344 | + |
| 345 | +object Interpreter: |
| 346 | + /** Exception that stops interpretation if some issue is found */ |
| 347 | + class StopInterpretation(val msg: String, val pos: SrcPos) extends Exception |
| 348 | + |
| 349 | + object Call: |
| 350 | + import tpd._ |
| 351 | + /** Matches an expression that is either a field access or an application |
| 352 | + * It retruns a TermRef containing field accessed or a method reference and the arguments passed to it. |
| 353 | + */ |
| 354 | + def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] = |
| 355 | + Call0.unapply(arg).map((fn, args) => (fn, args.reverse)) |
| 356 | + |
| 357 | + private object Call0 { |
| 358 | + def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] = arg match { |
| 359 | + case Select(Call0(fn, args), nme.apply) if defn.isContextFunctionType(fn.tpe.widenDealias.finalResultType) => |
| 360 | + Some((fn, args)) |
| 361 | + case fn: Ident => Some((tpd.desugarIdent(fn).withSpan(fn.span), Nil)) |
| 362 | + case fn: Select => Some((fn, Nil)) |
| 363 | + case Apply(f @ Call0(fn, args1), args2) => |
| 364 | + if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1)) |
| 365 | + else Some((fn, args2 :: args1)) |
| 366 | + case TypeApply(Call0(fn, args), _) => Some((fn, args)) |
| 367 | + case _ => None |
| 368 | + } |
| 369 | + } |
| 370 | + end Call |
0 commit comments