From 8ef9cde33fc7eb98214b538bd0e9a3ff6226d6bd Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Tue, 14 Nov 2023 15:25:07 +0100 Subject: [PATCH] Patch `underlyingArgument` to avoid mapping into modules Fixes #18911 [Cherry-picked 9bfeb1c43e7f54ea36b53183726d6b83489d355a] --- compiler/src/dotty/tools/dotc/ast/tpd.scala | 2 +- tests/pos-macros/i18911/Macros_1.scala | 91 +++++++++++++++++++++ tests/pos-macros/i18911/Test_2.scala | 5 ++ 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 tests/pos-macros/i18911/Macros_1.scala create mode 100644 tests/pos-macros/i18911/Test_2.scala diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index d207e133f958..a1066a6131fc 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1253,7 +1253,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { */ private class MapToUnderlying extends TreeMap { override def transform(tree: Tree)(using Context): Tree = tree match { - case tree: Ident if isBinding(tree.symbol) && skipLocal(tree.symbol) => + case tree: Ident if isBinding(tree.symbol) && skipLocal(tree.symbol) && !tree.symbol.is(Module) => tree.symbol.defTree match { case defTree: ValOrDefDef => val rhs = defTree.rhs diff --git a/tests/pos-macros/i18911/Macros_1.scala b/tests/pos-macros/i18911/Macros_1.scala new file mode 100644 index 000000000000..677610fd9536 --- /dev/null +++ b/tests/pos-macros/i18911/Macros_1.scala @@ -0,0 +1,91 @@ +import scala.quoted._ +import scala.compiletime.testing.{typeChecks, typeCheckErrors} + +trait Assertion +trait Bool { + def value: Boolean +} +class SimpleMacroBool(expression: Boolean) extends Bool { + override def value: Boolean = expression +} +class BinaryMacroBool(left: Any, operator: String, right: Any, expression: Boolean) extends Bool { + override def value: Boolean = expression +} +object Bool { + def simpleMacroBool(expression: Boolean): Bool = new SimpleMacroBool(expression) + def binaryMacroBool(left: Any, operator: String, right: Any, expression: Boolean): Bool = + new BinaryMacroBool(left, operator, right, expression) + def binaryMacroBool(left: Any, operator: String, right: Any, bool: Bool): Bool = + new BinaryMacroBool(left, operator, right, bool.value) +} + +object Assertions { + inline def assert(inline condition: Boolean): Assertion = + ${ AssertionsMacro.assert('{ condition }) } +} + +object AssertionsMacro { + def assert(condition: Expr[Boolean])(using Quotes): Expr[Assertion] = + transform(condition) + + def transform( + condition: Expr[Boolean] + )(using Quotes): Expr[Assertion] = { + val bool = BooleanMacro.parse(condition) + '{ + new Assertion { + val condition = $bool + } + } + } +} + +object BooleanMacro { + private val supportedBinaryOperations = + Set("!=", "==") + + def parse(condition: Expr[Boolean])(using Quotes): Expr[Bool] = { + import quotes.reflect._ + import quotes.reflect.ValDef.let + import util._ + + def exprStr: String = condition.show + def defaultCase = '{ Bool.simpleMacroBool($condition) } + + def isByNameMethodType(tp: TypeRepr): Boolean = tp.widen match { + case MethodType(_, ByNameType(_) :: Nil, _) => true + case _ => false + } + + condition.asTerm.underlyingArgument match { // WARNING: unsound use of `underlyingArgument` + case Apply(sel @ Select(lhs, op), rhs :: Nil) => + def binaryDefault = + if (isByNameMethodType(sel.tpe)) defaultCase + else if (supportedBinaryOperations.contains(op)) { + let(Symbol.spliceOwner, lhs) { left => + let(Symbol.spliceOwner, rhs) { right => + val app = left.select(sel.symbol).appliedTo(right) + let(Symbol.spliceOwner, app) { result => + val l = left.asExpr + val r = right.asExpr + val b = result.asExprOf[Boolean] + val code = '{ Bool.binaryMacroBool($l, ${ Expr(op) }, $r, $b) } + code.asTerm + } + } + }.asExprOf[Bool] + } else defaultCase + + op match { + case "==" => binaryDefault + case _ => binaryDefault + } + + case Literal(_) => + '{ Bool.simpleMacroBool($condition) } + + case _ => + defaultCase + } + } +} diff --git a/tests/pos-macros/i18911/Test_2.scala b/tests/pos-macros/i18911/Test_2.scala new file mode 100644 index 000000000000..5c253d5e8e1b --- /dev/null +++ b/tests/pos-macros/i18911/Test_2.scala @@ -0,0 +1,5 @@ +@main def Test = { + case class Document() + val expected: Document = ??? + Assertions.assert( expected == Document()) // error +}