Skip to content

Commit

Permalink
Forward-port changes from Linker rewrite rules.
Browse files Browse the repository at this point in the history
Add support for conditional rewrites, simple purity checking as well as
constant-checking.
  • Loading branch information
DarkDimius committed May 9, 2016
1 parent 2283964 commit d7cb648
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 58 deletions.
26 changes: 26 additions & 0 deletions src/dotty/linker/rewrites.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,29 @@ import scala.annotation.Annotation

/** An annotation used to mark classes containing rewriting rules. */
class rewrites extends Annotation {}

sealed class Rewrite[T] private(from: T, to: T)
object Rewrite {
def apply[T](from: T, to: T): Rewrite[T] = new Rewrite(from, to)
}

sealed class Warning[T] private (pattern: T, msgfun: /*List[String] => */String)
object Warning {
// def apply[T](pattern: T, msgfun: List[String] => String): Warning[T] = new Warning[T](pattern, msgfun) // will work only on second compilation
def apply[T](pattern: T, msg: String) : Warning[T] = new Warning[T](pattern, msg)
}

object Error {
// def apply[T](pattern: T, msgfun: List[String] => String): Warning[T] = new Warning[T](pattern, msgfun) // will work only on second compilation
def apply[T](pattern: T, msg: String) : Warning[T] = Warning.apply[T](pattern, msg)
}

final class IsLiteral[T] private ()
final class Source[T] private()
final class IsPure[T] private ()


abstract class MetaRewrite[T, Context, Tree] private(body: Context => Tree) {}
object MetaRewrite {
def apply[T](body: Any => Any) = ???
}
4 changes: 3 additions & 1 deletion src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
| Super(_, _)
| Literal(_) =>
Pure
case Block(List(anon: DefDef), cl: Closure) =>
minOf(exprPurity(anon.rhs), cl.env.map(exprPurity))
case Ident(_) =>
refPurity(tree)
case Select(qual, _) =>
Expand All @@ -332,7 +334,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
// Note: After uncurry, field accesses are represented as Apply(getter, Nil),
// so an Apply can also be pure.
if (args.isEmpty && fn.symbol.is(Stable)) exprPurity(fn)
else if (tree.tpe.isInstanceOf[ConstantType] && isKnownPureOp(tree.symbol))
else if (/*tree.tpe.isInstanceOf[ConstantType] &&*/ isKnownPureOp(tree.symbol))
// A constant expression with pure arguments is pure.
minOf(exprPurity(fn), args.map(exprPurity))
else Impure
Expand Down
176 changes: 132 additions & 44 deletions src/dotty/tools/dotc/transform/linker/Rewrites.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import ast.Trees._
import dotty.tools.dotc.ast.tpd
import util.Positions._
import Names._
import dotty.tools.dotc.core.Constants.Constant
import dotty.tools.dotc.transform.TreeTransforms
import dotty.tools.dotc.transform.TreeTransforms.{MiniPhaseTransform, TransformerInfo, TreeTransform}

import collection.mutable
Expand All @@ -31,26 +33,39 @@ class Rewrites extends MiniPhaseTransform { thisTransform =>
def isOptVersion(n: Name) = n.endsWith(namePattern)
def dropOpt(n: Name)(implicit ctx: Context) = if (isOptVersion(n)) n.dropRight(namePattern.length) else n

private var annot: Symbol = null
private var pairs: List[(DefDef, DefDef)] = null
private var annot : Symbol = null
private var rewriteClass : Symbol = null
private var warningClass : Symbol = null
private var errorClass : Symbol = null
private var rewriteCompanion : Symbol = null
private var rewriteApply : Symbol = null
private var sourceArgument : Symbol = null
private var isPureArgument : Symbol = null
private var isLiteralArgument : Symbol = null
private var supportedArguments: Set[Symbol] = null
private var pairs : List[(DefDef, DefDef)] = null


def collectPatters(tree: tpd.Tree)(implicit ctx: Context) = {
val collector = new TreeAccumulator[List[(DefDef, DefDef)]] {
def apply(x: List[(tpd.DefDef, tpd.DefDef)], tree: tpd.Tree)(implicit ctx: Context): List[(tpd.DefDef, tpd.DefDef)] = {
tree match {
case t: tpd.TypeDef if t.isClassDef && t.symbol.hasAnnotation(annot) =>
val stats = t.rhs.asInstanceOf[Template].body
val defsByName = stats.filter(x => x.isInstanceOf[DefDef]).asInstanceOf[List[DefDef]].map(x => (x.symbol.name, x)).toMap
val pairs = defsByName.groupBy(x => dropOpt(x._1)).filter(x => x._2.size > 1)
val (errors, realPairs) = pairs.partition(x => x._2.size > 2)

errors.foreach(x => ctx.error("overloads are not supported", x._2.head._2.pos))
val prepend = realPairs.map(x=>
(x._2.values.find(x => !isOptVersion(x.symbol.name)).get,
x._2.values.find(x => isOptVersion(x.symbol.name)).get)
).toList


val prepend = stats.flatMap{x => x match {
case defdef: DefDef if defdef.symbol.info.finalResultType.derivesFrom(rewriteClass) || defdef.symbol.info.finalResultType.derivesFrom(warningClass)=>
seb(defdef.rhs) match {
case t: Apply if t.symbol eq rewriteApply =>
(cpy.DefDef(defdef)(rhs = t.args.head), cpy.DefDef(defdef)(rhs = t.args.tail.head)) :: Nil
case t: Apply if t.tpe.derivesFrom(warningClass) =>
ctx.error("warning are not implemented yet", t.pos)
Nil
case _ =>
ctx.error("tree not supported", defdef.pos)
Nil
}
case _ => Nil
}}
foldOver(prepend ::: x, t)
case _ => foldOver(x, tree)
}
Expand All @@ -61,6 +76,17 @@ class Rewrites extends MiniPhaseTransform { thisTransform =>

override def prepareForUnit(tree: tpd.Tree)(implicit ctx: Context): TreeTransform = {
annot = ctx.requiredClass("dotty.linker.rewrites")
rewriteClass = ctx.requiredClass("dotty.linker.Rewrite")
warningClass = ctx.requiredClass("dotty.linker.Warning")
errorClass = ctx.requiredClass("dotty.linker.Error")

isPureArgument = ctx.requiredClass("dotty.linker.IsPure")
sourceArgument = ctx.requiredClass("dotty.linker.Source")
isLiteralArgument = ctx.requiredClass("dotty.linker.IsLiteral")

supportedArguments = Set(isLiteralArgument, sourceArgument, isPureArgument)
rewriteCompanion = rewriteClass.companionModule
rewriteApply = rewriteCompanion.requiredMethod(nme.apply)
pairs = collectPatters(tree)
ctx.warning(s"found rewriting rules: ${pairs.map(x=> x._1.symbol.showFullName).mkString(", ")}")
pairs.foreach(checkSupported)
Expand All @@ -71,10 +97,28 @@ class Rewrites extends MiniPhaseTransform { thisTransform =>
val pattern = pair._1
val rewrite = pair._2
def unsupported(reason: String) = ctx.error("Unsupported pattern: " + reason, pattern.pos)
if (pattern.symbol.signature != rewrite.symbol.signature)
unsupported("signatures do not match")
if (pattern.vparamss.flatten.map(_.name) != rewrite.vparamss.flatten.map(_.name))
unsupported("arguments should have same names")
pattern.symbol.info.resultType match {
case t: ImplicitMethodType =>
val ptypes = t.paramTypes
def checkValidCondition(t: Type) = t match{
case t: RefinedType if supportedArguments.contains(t.typeSymbol) =>
val info = t.refinedInfo
info match {
case alias: TypeAlias =>
alias.underlying match {
case t: MethodParam =>
case _ =>
ctx.error(i"Unsupported condition $t", pattern.pos)
}
}
case _ =>
ctx.error(i"Unsupported condition $t", pattern.pos)
}
ptypes.foreach(checkValidCondition)
case t: MethodType =>
ctx.error("multiple argument strings not supported", pattern.pos)
case _ =>
}
new TreeTraverser {
def traverse(tree: tpd.Tree)(implicit ctx: Context): Unit =
tree match {
Expand All @@ -98,32 +142,77 @@ class Rewrites extends MiniPhaseTransform { thisTransform =>
}


override def prepareForTypeDef(tree: tpd.TypeDef)(implicit ctx: Context): TreeTransform = {
if (tree.symbol.hasAnnotation(annot)) TreeTransforms.NoTransform
else this
}

type Substitution = Map[Name, (Tree, Symbol /* owner */)]

private def filtersByPattern(filters: List[ValDef])(implicit ctx: Context): Substitution => Option[Substitution] = {
if (filters.isEmpty) {x: Substitution => Some(x)}
else {
val headFilter = filters.head.symbol.info.typeSymbol
val methodParamName = filters.head.symbol.info.asInstanceOf[RefinedType].
refinedInfo.asInstanceOf[TypeAlias].underlying.asInstanceOf[TermRef].name
if (headFilter eq isLiteralArgument) { x: Substitution =>
x.get(methodParamName) match {
case Some((t: Literal, _)) =>
filtersByPattern(filters.tail)(ctx)(x)
case _ =>
None
}
} else if (headFilter eq sourceArgument) { x: Substitution =>
val sourceText = Literal(Constant(x(methodParamName)._1.show))
val newSubstitution : Substitution = x + (filters.head.name -> (sourceText, NoSymbol))
filtersByPattern(filters.tail)(ctx)(newSubstitution)
} else if (headFilter eq isPureArgument) {x: Substitution =>
if (isPureTree(x(methodParamName)._1)) filtersByPattern(filters.tail)(ctx)(x)
else None
} else ???
}
}

private def isPureTree(t: Tree)(implicit ctx: Context) = {
tpd.isPureExpr(t)
}

override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
if (tree.symbol.ownersIterator.findSymbol(x => x.hasAnnotation(annot)).exists) tree
else {
val mapping = new TreeMap() {
override def transform(tree: tpd.Tree)(implicit ctx: Context):tpd.Tree = tree match {
case tree: DefTree =>
super.transform(tree)(ctx.withOwner(tree.symbol))
case tree =>
case _ =>
val scanner = pairs.iterator.map(x => (x, isSimilar(tree, x._1))).find(x => x._2.nonEmpty)
if (scanner.nonEmpty) {
val ((patern, rewrite), Some(binding)) = scanner.get
ctx.warning(s"Applying rule ${dropOpt(rewrite.symbol.name)}, substitution: ${binding.map(x => s"${x._1} -> ${x._2._1.show}").mkString(", ")}")
val subByName = binding.map(x => (x._1.name, x._2)).toMap[Name, (Tree, Symbol)]
val substitution = new TreeMap() {
override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match {
case tree: DefTree =>
super.transform(tree)(ctx.withOwner(tree.symbol))
case _ =>
if (tree.symbol.maybeOwner == rewrite.symbol && tree.symbol.is(Flags.Param)) {
val (oldTree, oldOwner) = subByName(tree.symbol.name)
oldTree.changeOwner(oldOwner, ctx.owner)
}
else super.transform(tree)
val fistSubstitution: Substitution = binding.map(x => (x._1.name, x._2)).toMap
val filters =
if (patern.symbol.info.resultType.isInstanceOf[ImplicitMethodType])
patern.vparamss.tail.head
else Nil
val secondSubstitution = filtersByPattern(filters)(ctx)(fistSubstitution)

if (secondSubstitution.nonEmpty) {
val finalSubstitution = secondSubstitution.get
ctx.warning(s"Applying rule ${dropOpt(rewrite.symbol.name)}, substitution: ${finalSubstitution.map(x => s"${x._1} -> ${x._2._1.show}").mkString(", ")}")
val substitution = new TreeMap() {
override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match {
case tree: DefTree =>
super.transform(tree)(ctx.withOwner(tree.symbol))
case _ =>
if (tree.symbol.maybeOwner == rewrite.symbol && tree.symbol.is(Flags.Param)) {
val (oldTree, oldOwner) = finalSubstitution(tree.symbol.name)
if (oldOwner.exists) oldTree.changeOwner(oldOwner, ctx.owner)
else oldTree
}
else super.transform(tree)
}
}
}
super.transform(substitution.transform(rewrite.rhs)).changeOwner(rewrite.symbol, ctx.owner)
super.transform(substitution.transform(rewrite.rhs)).changeOwner(rewrite.symbol, ctx.owner)
} else super.transform(tree)
} else super.transform(tree)
}
}
Expand All @@ -141,18 +230,14 @@ class Rewrites extends MiniPhaseTransform { thisTransform =>


private def isSimilar(tree: Tree, pattern: DefDef)(implicit ctx: Context): Option[Map[Symbol, (Tree, Symbol)]] = {
var currentMapping = new mutable.HashMap[Symbol, (Tree, Symbol)]()
def seb/*skipEmptyBlocks*/(x: Tree) = x match {
case Block(Nil, t) => t
case _ => x
}
def abort = currentMapping = null
try {
val currentMapping = new mutable.HashMap[Symbol, (Tree, Symbol)]()
var aborted = false
def abort = aborted = true
def bind(sym: Symbol, tree: Tree, oldOwner: Symbol) = {
if (currentMapping.contains(sym)) abort
else currentMapping.put(sym, (tree, oldOwner))
}
def loop(subtree: Tree, subpat: Tree)(implicit ctx: Context): Unit = seb(subtree) match {
def loop(subtree: Tree, subpat: Tree)(implicit ctx: Context): Unit = if (!aborted) seb(subtree) match {
case Apply(sel, args) => seb(subpat) match {
case Apply(selpat, selargs) if selargs.hasSameLengthAs(args) =>
loop(sel, selpat)
Expand Down Expand Up @@ -189,10 +274,13 @@ class Rewrites extends MiniPhaseTransform { thisTransform =>
case _ => abort
}
loop(tree, pattern.rhs)
} catch {
case e: NullPointerException =>
}
if ((currentMapping eq null) || currentMapping.size != pattern.vparamss.flatten.size) None

if (aborted || currentMapping.size != pattern.vparamss.head.size) None
else Some(currentMapping.toMap)
}

private def seb/*skipEmptyBlocks*/(x: Tree) = x match {
case Block(Nil, t) => t
case _ => x
}
}
45 changes: 32 additions & 13 deletions tests/linker/rewrites/ColRewrites.scala
Original file line number Diff line number Diff line change
@@ -1,25 +1,44 @@
import dotty.linker.rewrites
import dotty.linker._

@rewrites
object rules{
def twoDropsOnes(x: List[Int]) =
x.drop(1).drop(1)
def twoDropsOnes$opt(x: List[Int]) =
x.drop(2)
object rules/* extends NeedsMeta */ {
def twoDropsOnes(x: List[Int]) =
Rewrite(x.drop(1).drop(1), x.drop(2))
def twoDropRights(x: List[Int], a: Int, b: Int) =
x.dropRight(a).dropRight(b)
def twoDropRights$opt(x: List[Int], a: Int, b: Int) =
x.dropRight(a + b)
def twoFilters(x: List[Int], a: Int => Boolean, b: Int => Boolean) =
x.filter(a).filter(b)
def twoFilters$opt(x: List[Int], a: Int => Boolean, b: Int => Boolean) =
x.filter(x => a(x) && b(x))
Rewrite(x.dropRight(a).dropRight(b), x.dropRight(a + b))

/************************************/
def twoFilters(x: List[Int], a: Int => Boolean, b: Int => Boolean)(implicit apure: IsPure[a.type]) =
Rewrite(from = x.filter(a).filter(b),
to = x.filter(x => a(x) && b(x)))


def prettyPrint(x: Any)(implicit source: Source[x.type]) =
Rewrite(Test.myPrettyPrint(x), println(source + " = " + x))

/*
def twoFilters(x: List[Int], a: Int => Boolean, b: Int => Boolean) =
Rewrite(filter = meta.isPure(a), // restriction: the expressions that are used here are very limited
from = x.filter(a).filter(b),
to = x.filter(x => a(x) && b(x)))
def twoFilters(x: List[Int], a: Int => Boolean, b: Int => Boolean)(implicit meta: Semantics) =
if (meta.isPure(a))
Rewrite(x.filter(a).filter(b), x.filter(x => a(x) && b(x)))
def twoFilters(x: List[Int], a: Int => Boolean, b: Int => Boolean) =
RewriteMeta(meta{ /* full blown meta */ })
*/
}

object Test{
def myPrettyPrint(a: Any) = ???
def main(args: Array[String]): Unit = {
List(1,2,3).drop(1).drop(1)
List(1,2,3).dropRight(1).dropRight(1)
List(1,2,3).filter(_ > 2).filter(_ > 1)
myPrettyPrint(args.length)
}
}

0 comments on commit d7cb648

Please sign in to comment.