Skip to content

Commit

Permalink
Beta-reduce directly applied PolymorphicFunction (#16623)
Browse files Browse the repository at this point in the history
Beta-reduce directly applied PolymorphicFunction such as

```scala
([Z] => (arg: Z) => { def a: Z = arg; a }).apply[Int](2)
```
into
```scala
type Z = Int
val arg = 2
def a: Z = arg
a
```

Apply this beta reduction in the `BetaReduce` phase and
`Expr.betaReduce`. Also, refactor the beta-reduce logic to avoid code
duplication.

Fixes #15968
  • Loading branch information
odersky authored Feb 12, 2023
2 parents a2c89fb + db2d3eb commit c5c5aa6
Show file tree
Hide file tree
Showing 12 changed files with 191 additions and 91 deletions.
40 changes: 0 additions & 40 deletions compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName}
import config.Printers.inlining
import util.SimpleIdentityMap

import dotty.tools.dotc.transform.BetaReduce

import collection.mutable

/** A utility class offering methods for rewriting inlined code */
Expand Down Expand Up @@ -150,44 +148,6 @@ class InlineReducer(inliner: Inliner)(using Context):
binding1.withSpan(call.span)
}

/** Rewrite an application
*
* ((x1, ..., xn) => b)(e1, ..., en)
*
* to
*
* val/def x1 = e1; ...; val/def xn = en; b
*
* where `def` is used for call-by-name parameters. However, we shortcut any NoPrefix
* refs among the ei's directly without creating an intermediate binding.
*
* This variant of beta-reduction preserves the integrity of `Inlined` tree nodes.
*/
def betaReduce(tree: Tree)(using Context): Tree = tree match {
case Apply(Select(cl, nme.apply), args) if defn.isFunctionType(cl.tpe) =>
val bindingsBuf = new mutable.ListBuffer[ValDef]
def recur(cl: Tree): Option[Tree] = cl match
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
ddef.tpe.widen match
case mt: MethodType if ddef.paramss.head.length == args.length =>
Some(BetaReduce.reduceApplication(ddef, args, bindingsBuf))
case _ => None
case Block(stats, expr) if stats.forall(isPureBinding) =>
recur(expr).map(cpy.Block(cl)(stats, _))
case Inlined(call, bindings, expr) if bindings.forall(isPureBinding) =>
recur(expr).map(cpy.Inlined(cl)(call, bindings, _))
case Typed(expr, tpt) =>
recur(expr)
case _ => None
recur(cl) match
case Some(reduced) =>
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
case None =>
tree
case _ =>
tree
}

/** The result type of reducing a match. It consists optionally of a list of bindings
* for the pattern-bound variables and the RHS of the selected case.
* Returns `None` if no case was selected.
Expand Down
7 changes: 4 additions & 3 deletions compiler/src/dotty/tools/dotc/inlines/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import collection.mutable
import reporting.trace
import util.Spans.Span
import dotty.tools.dotc.transform.Splicer
import dotty.tools.dotc.transform.BetaReduce
import quoted.QuoteUtils
import scala.annotation.constructorOnly

Expand Down Expand Up @@ -811,7 +812,7 @@ class Inliner(val call: tpd.Tree)(using Context):
case Quoted(Spliced(inner)) => inner
case _ => tree
val locked = ctx.typerState.ownedVars
val res = cancelQuotes(constToLiteral(betaReduce(super.typedApply(tree, pt)))) match {
val res = cancelQuotes(constToLiteral(BetaReduce(super.typedApply(tree, pt)))) match {
case res: Apply if res.symbol == defn.QuotedRuntime_exprSplice
&& StagingContext.level == 0
&& !hasInliningErrors =>
Expand All @@ -825,7 +826,7 @@ class Inliner(val call: tpd.Tree)(using Context):

override def typedTypeApply(tree: untpd.TypeApply, pt: Type)(using Context): Tree =
val locked = ctx.typerState.ownedVars
val tree1 = inlineIfNeeded(constToLiteral(betaReduce(super.typedTypeApply(tree, pt))), pt, locked)
val tree1 = inlineIfNeeded(constToLiteral(BetaReduce(super.typedTypeApply(tree, pt))), pt, locked)
if tree1.symbol.isQuote then
ctx.compilationUnit.needsStaging = true
tree1
Expand Down Expand Up @@ -1006,7 +1007,7 @@ class Inliner(val call: tpd.Tree)(using Context):
super.transform(t1)
case t: Apply =>
val t1 = super.transform(t)
if (t1 `eq` t) t else reducer.betaReduce(t1)
if (t1 `eq` t) t else BetaReduce(t1)
case Block(Nil, expr) =>
super.transform(expr)
case _ =>
Expand Down
116 changes: 77 additions & 39 deletions compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ import scala.collection.mutable.ListBuffer

/** Rewrite an application
*
* (((x1, ..., xn) => b): T)(y1, ..., yn)
* (([X1, ..., Xm] => (x1, ..., xn) => b): T)[T1, ..., Tm](y1, ..., yn)
*
* where
*
* - all yi are pure references without a prefix
* - the closure can also be contextual or erased, but cannot be a SAM type
* _ the type ascription ...: T is optional
* - the type parameters Xi and type arguments Ti are optional
* - the type ascription ...: T is optional
*
* to
*
Expand All @@ -38,51 +39,88 @@ class BetaReduce extends MiniPhase:

override def description: String = BetaReduce.description

override def transformApply(app: Apply)(using Context): Tree = app.fun match
case Select(fn, nme.apply) if defn.isFunctionType(fn.tpe) =>
val app1 = BetaReduce(app, fn, app.args)
if app1 ne app then report.log(i"beta reduce $app -> $app1")
app1
case _ =>
app

override def transformApply(app: Apply)(using Context): Tree =
val app1 = BetaReduce(app)
if app1 ne app then report.log(i"beta reduce $app -> $app1")
app1

object BetaReduce:
import ast.tpd._

val name: String = "betaReduce"
val description: String = "reduce closure applications"

/** Beta-reduces a call to `fn` with arguments `argSyms` or returns `tree` */
def apply(original: Tree, fn: Tree, args: List[Tree])(using Context): Tree =
fn match
case Typed(expr, _) =>
BetaReduce(original, expr, args)
case Block((anonFun: DefDef) :: Nil, closure: Closure) =>
BetaReduce(anonFun, args)
case Block(stats, expr) =>
val tree = BetaReduce(original, expr, args)
if tree eq original then original
else cpy.Block(fn)(stats, tree)
case Inlined(call, bindings, expr) =>
val tree = BetaReduce(original, expr, args)
if tree eq original then original
else cpy.Inlined(fn)(call, bindings, tree)
/** Rewrite an application
*
* ((x1, ..., xn) => b)(e1, ..., en)
*
* to
*
* val/def x1 = e1; ...; val/def xn = en; b
*
* where `def` is used for call-by-name parameters. However, we shortcut any NoPrefix
* refs among the ei's directly without creating an intermediate binding.
*
* Similarly, rewrites type applications
*
* ([X1, ..., Xm] => (x1, ..., xn) => b).apply[T1, .., Tm](e1, ..., en)
*
* to
*
* type X1 = T1; ...; type Xm = Tm;val/def x1 = e1; ...; val/def xn = en; b
*
* This beta-reduction preserves the integrity of `Inlined` tree nodes.
*/
def apply(tree: Tree)(using Context): Tree =
val bindingsBuf = new ListBuffer[DefTree]
def recur(fn: Tree, argss: List[List[Tree]]): Option[Tree] = fn match
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
Some(reduceApplication(ddef, argss, bindingsBuf))
case Block((TypeDef(_, template: Template)) :: Nil, Typed(Apply(Select(New(_), _), _), _)) if template.constr.rhs.isEmpty =>
template.body match
case (ddef: DefDef) :: Nil => Some(reduceApplication(ddef, argss, bindingsBuf))
case _ => None
case Block(stats, expr) if stats.forall(isPureBinding) =>
recur(expr, argss).map(cpy.Block(fn)(stats, _))
case Inlined(call, bindings, expr) if bindings.forall(isPureBinding) =>
recur(expr, argss).map(cpy.Inlined(fn)(call, bindings, _))
case Typed(expr, tpt) =>
recur(expr, argss)
case TypeApply(Select(expr, nme.asInstanceOfPM), List(tpt)) =>
recur(expr, argss)
case _ => None
tree match
case Apply(Select(fn, nme.apply), args) if defn.isFunctionType(fn.tpe) =>
recur(fn, List(args)) match
case Some(reduced) =>
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
case None =>
tree
case Apply(TypeApply(Select(fn, nme.apply), targs), args) if fn.tpe.typeSymbol eq dotc.core.Symbols.defn.PolyFunctionClass =>
recur(fn, List(targs, args)) match
case Some(reduced) =>
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
case None =>
tree
case _ =>
original
end apply

/** Beta-reduces a call to `ddef` with arguments `args` */
def apply(ddef: DefDef, args: List[Tree])(using Context) =
val bindings = new ListBuffer[ValDef]()
val expansion1 = reduceApplication(ddef, args, bindings)
val bindings1 = bindings.result()
seq(bindings1, expansion1)
tree

/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */
def reduceApplication(ddef: DefDef, args: List[Tree], bindings: ListBuffer[ValDef])(using Context): Tree =
val vparams = ddef.termParamss.iterator.flatten.toList
assert(args.hasSameLengthAs(vparams))
def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])(using Context): Tree =
val (targs, args) = argss.flatten.partition(_.isType)
val tparams = ddef.leadingTypeParams
val vparams = ddef.termParamss.flatten

val targSyms =
for (targ, tparam) <- targs.zip(tparams) yield
targ.tpe.dealias match
case ref @ TypeRef(NoPrefix, _) =>
ref.symbol
case _ =>
val binding = TypeDef(newSymbol(ctx.owner, tparam.name, EmptyFlags, targ.tpe, coord = targ.span)).withSpan(targ.span)
bindings += binding
binding.symbol

val argSyms =
for (arg, param) <- args.zip(vparams) yield
arg.tpe.dealias match
Expand All @@ -99,8 +137,8 @@ object BetaReduce:
val expansion = TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = vparams.map(_.symbol),
substTo = argSyms
substFrom = (tparams ::: vparams).map(_.symbol),
substTo = targSyms ::: argSyms
).transform(ddef.rhs)

val expansion1 = new TreeMap {
Expand Down
13 changes: 10 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import Symbols._, Contexts._, Types._, Decorators._
import NameOps._
import Names._

import scala.collection.mutable.ListBuffer

/** Rewrite an application
*
* {new { def unapply(x0: X0)(x1: X1,..., xn: Xn) = b }}.unapply(y0)(y1, ..., yn)
Expand Down Expand Up @@ -38,7 +40,7 @@ class InlinePatterns extends MiniPhase:
if app.symbol.name.isUnapplyName && !app.tpe.isInstanceOf[MethodicType] then
app match
case App(Select(fn, name), argss) =>
val app1 = betaReduce(app, fn, name, argss.flatten)
val app1 = betaReduce(app, fn, name, argss)
if app1 ne app then report.log(i"beta reduce $app -> $app1")
app1
case _ =>
Expand All @@ -51,11 +53,16 @@ class InlinePatterns extends MiniPhase:
case Apply(App(fn, argss), args) => (fn, argss :+ args)
case _ => (app, Nil)

private def betaReduce(tree: Apply, fn: Tree, name: Name, args: List[Tree])(using Context): Tree =
// TODO merge with BetaReduce.scala
private def betaReduce(tree: Apply, fn: Tree, name: Name, argss: List[List[Tree]])(using Context): Tree =
fn match
case Block(TypeDef(_, template: Template) :: Nil, Apply(Select(New(_),_), Nil)) if template.constr.rhs.isEmpty =>
template.body match
case List(ddef @ DefDef(`name`, _, _, _)) => BetaReduce(ddef, args)
case List(ddef @ DefDef(`name`, _, _, _)) =>
val bindings = new ListBuffer[DefTree]()
val expansion1 = BetaReduce.reduceApplication(ddef, argss, bindings)
val bindings1 = bindings.result()
seq(bindings1, expansion1)
case _ => tree
case _ => tree

Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ object PickleQuotes {
}
val Block(List(ddef: DefDef), _) = splice: @unchecked
// TODO: beta reduce inner closure? Or wait until BetaReduce phase?
BetaReduce(ddef, spliceArgs).select(nme.apply).appliedTo(args(2).asInstance(quotesType))
BetaReduce(
splice
.select(nme.apply).appliedToArgs(spliceArgs))
.select(nme.apply).appliedTo(args(2).asInstance(quotesType))
}
CaseDef(Literal(Constant(idx)), EmptyTree, rhs)
}
Expand Down
9 changes: 4 additions & 5 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,15 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
object Term extends TermModule:
def betaReduce(tree: Term): Option[Term] =
tree match
case app @ tpd.Apply(tpd.Select(fn, nme.apply), args) if dotc.core.Symbols.defn.isFunctionType(fn.tpe) =>
val app1 = dotc.transform.BetaReduce(app, fn, args)
if app1 eq app then None
else Some(app1.withSpan(tree.span))
case tpd.Block(Nil, expr) =>
for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e)
case tpd.Inlined(_, Nil, expr) =>
betaReduce(expr)
case _ =>
None
val tree1 = dotc.transform.BetaReduce(tree)
if tree1 eq tree then None
else Some(tree1.withSpan(tree.span))

end Term

given TermMethods: TermMethods with
Expand Down
57 changes: 57 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,63 @@ class InlineBytecodeTests extends DottyBytecodeTest {
}
}

@Test def beta_reduce_polymorphic_function = {
val source = """class Test:
| def test =
| ([Z] => (arg: Z) => { val a: Z = arg; a }).apply[Int](2)
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Test.class", directory = false).input
val clsNode = loadClassNode(clsIn)

val fun = getMethod(clsNode, "test")
val instructions = instructionsFromMethod(fun)
val expected =
List(
Op(ICONST_2),
VarOp(ISTORE, 1),
VarOp(ILOAD, 1),
Op(IRETURN)
)

assert(instructions == expected,
"`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected))

}
}

@Test def beta_reduce_function_of_opaque_types = {
val source = """object foo:
| opaque type T = Int
| inline def apply(inline op: T => T): T = op(2)
|
|class Test:
| def test = foo { n => n }
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Test.class", directory = false).input
val clsNode = loadClassNode(clsIn)

val fun = getMethod(clsNode, "test")
val instructions = instructionsFromMethod(fun)
val expected =
List(
Field(GETSTATIC, "foo$", "MODULE$", "Lfoo$;"),
VarOp(ASTORE, 1),
VarOp(ALOAD, 1),
VarOp(ASTORE, 2),
Op(ICONST_2),
Op(IRETURN),
)

assert(instructions == expected,
"`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected))

}
}

@Test def i9456 = {
val source = """class Foo {
| def test: Int = inline2(inline1(2.+))
Expand Down
5 changes: 5 additions & 0 deletions tests/run-macros/i15968.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
type Z = java.lang.String
"foo".toString()
}
"foo".toString()
15 changes: 15 additions & 0 deletions tests/run-macros/i15968/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import scala.quoted.*

inline def macroPolyFun[A](inline arg: A, inline f: [Z] => Z => String): String =
${ macroPolyFunImpl[A]('arg, 'f) }

private def macroPolyFunImpl[A: Type](arg: Expr[A], f: Expr[[Z] => Z => String])(using Quotes): Expr[String] =
Expr(Expr.betaReduce('{ $f($arg) }).show)


inline def macroFun[A](inline arg: A, inline f: A => String): String =
${ macroFunImpl[A]('arg, 'f) }

private def macroFunImpl[A: Type](arg: Expr[A], f: Expr[A => String])(using Quotes): Expr[String] =
Expr(Expr.betaReduce('{ $f($arg) }).show)

Loading

0 comments on commit c5c5aa6

Please sign in to comment.