Skip to content

Commit

Permalink
Switch to AssumeInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Oct 5, 2022
1 parent ec5a86c commit 0acaa36
Show file tree
Hide file tree
Showing 31 changed files with 221 additions and 188 deletions.
4 changes: 0 additions & 4 deletions compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -441,10 +441,6 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
genBlockTo(blck, expectedType, dest)
generatedDest = dest

case GadtExpr(_, expr) =>
genLoadTo(expr, expectedType, dest)
generatedDest = dest

case Typed(Super(_, _), _) =>
genLoadTo(tpd.This(claszSymbol.asClass), expectedType, dest)
generatedDest = dest
Expand Down
3 changes: 0 additions & 3 deletions compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1851,9 +1851,6 @@ class JSCodeGen()(using genCtx: Context) {
}
js.Block(genStatsAndExpr)

case GadtExpr(_, expr) =>
genStatOrExpr(expr, isStat)

case Typed(expr, _) =>
expr match {
case _: Super => genThis()
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ object desugar {
// Propagate down the expected type to the leafs of the expression
case Block(stats, expr) =>
cpy.Block(tree)(stats, adaptToExpectedTpt(expr))
case GadtExpr(gadt, expr) =>
cpy.GadtExpr(tree)(gadt, adaptToExpectedTpt(expr))
case AssumeInfo(sym, nestingLevel, info, body) =>
cpy.AssumeInfo(tree)(sym, nestingLevel, info, adaptToExpectedTpt(body))
case If(cond, thenp, elsep) =>
cpy.If(tree)(cond, adaptToExpectedTpt(thenp), adaptToExpectedTpt(elsep))
case untpd.Parens(expr) =>
Expand Down Expand Up @@ -1632,7 +1632,7 @@ object desugar {
case Tuple(trees) => (pats corresponds trees)(isIrrefutable)
case Parens(rhs1) => matchesTuple(pats, rhs1)
case Block(_, rhs1) => matchesTuple(pats, rhs1)
case GadtExpr(_, rhs1) => matchesTuple(pats, rhs1)
case AssumeInfo(_, _, _, rhs1) => matchesTuple(pats, rhs1)
case If(_, thenp, elsep) => matchesTuple(pats, thenp) && matchesTuple(pats, elsep)
case Match(_, cases) => cases forall (matchesTuple(pats, _))
case CaseDef(_, _, rhs1) => matchesTuple(pats, rhs1)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
case If(_, thenp, elsep) => forallResults(thenp, p) && forallResults(elsep, p)
case Match(_, cases) => cases forall (c => forallResults(c.body, p))
case Block(_, expr) => forallResults(expr, p)
case GadtExpr(_, expr) => forallResults(expr, p)
case AssumeInfo(_, _, _, body) => forallResults(body, p)
case _ => p(tree)
}
}
Expand Down Expand Up @@ -1040,7 +1040,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
case Typed(expr, _) => unapply(expr)
case Inlined(_, Nil, expr) => unapply(expr)
case Block(Nil, expr) => unapply(expr)
case GadtExpr(_, expr) => unapply(expr)
case AssumeInfo(_, _, _, body) => unapply(body)
case _ =>
tree.tpe.widenTermRefExpr.normalized match
case ConstantType(Constant(x)) => Some(x)
Expand Down
48 changes: 14 additions & 34 deletions compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,20 @@ class TreeTypeMap(
cpy.Block(blk)(stats1, expr1)
case inlined: Inlined =>
transformInlined(inlined)
case GadtExpr(gadt, expr) =>
cpy.GadtExpr(expr)(gadt, transform(expr))
case cdef @ CaseDef(pat, guard, expr @ GadtExpr(gadt, rhs)) =>
val patVars1 = patVars(pat)
val tmap = withMappedSyms(patVars1 ::: gadt.symbols.diff(substFrom).diff(patVars1))
val gadt1 = tmap.rebuild(gadt)
inContext(ctx.withGadt(gadt1)) {
val pat1 = tmap.transform(pat)
val guard1 = tmap.transform(guard)
val rhs1 = cpy.GadtExpr(expr)(gadt1, tmap.transform(rhs))
cpy.CaseDef(cdef)(pat1, guard1, rhs1)
case tree1: AssumeInfo =>
def mapBody(body: Tree) = body match
case tree @ AssumeInfo(_, _, _, _) =>
val tree1 = treeMap(tree)
tree1.withType(mapType(tree1.tpe))
case _ => body
ctx.gadt.withAssumeInfosIn(tree1, mapBody)(transform) { (assumeInfo, body) =>
val AssumeInfo(sym, _, info, _) = assumeInfo
mapType(sym.typeRef) match
case tp: TypeRef if tp eq sym.typeRef =>
val sym1 = sym.subst(substFrom, substTo)
val info1 = mapType(info)
cpy.AssumeInfo(assumeInfo)(sym = sym1, info = info1, body = body)
case _ => body // if the AssumeInfo symbol maps (as a type) to another type, we lose the associated info
}
case cdef @ CaseDef(pat, guard, rhs) =>
val tmap = withMappedSyms(patVars(pat))
Expand All @@ -158,29 +161,6 @@ class TreeTypeMap(
}
}

private def rebuild(gadt: GadtConstraint)(using Context): GadtConstraint =
val constraints = for sym <- gadt.symbols yield
val TypeBounds(lo, hi) = gadt.fullBounds(sym).nn
(sym, lo, hi)
val constraints1 = constraints.mapConserve { triple =>
val (sym, lo, hi) = triple
val sym1 = mapOwner(sym)
val lo1 = mapType(lo)
val hi1 = mapType(hi)
if (sym eq sym1) && (lo eq lo1) && (hi eq hi1)
then triple
else (sym1, lo1, hi1)
}
if constraints eq constraints1 then
gadt
else
val gadt = EmptyGadtConstraint.fresh
for (sym, lo, hi) <- constraints1 do
gadt.addToConstraint(sym)
gadt.addBound(sym, lo, false)
gadt.addBound(sym, hi, true)
gadt

override def transformStats(trees: List[tpd.Tree], exprOwner: Symbol)(using Context): List[Tree] =
transformDefs(trees)._2

Expand Down
35 changes: 24 additions & 11 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -569,10 +569,19 @@ object Trees {
override def isTerm: Boolean = !isType // this will classify empty trees as terms, which is necessary
}

case class GadtExpr[-T >: Untyped] private[ast] (gadt: GadtConstraint, expr: Tree[T])(implicit @constructorOnly src: SourceFile)
case class AssumeInfo[-T >: Untyped] private[ast] (sym: Symbol, nestingLevel: Int, info: Type, body: Tree[T])(implicit @constructorOnly src: SourceFile)
extends ProxyTree[T] {
type ThisTree[-T >: Untyped] <: GadtExpr[T]
def forwardTo: Tree[T] = expr
type ThisTree[-T >: Untyped] <: AssumeInfo[T]
def forwardTo: Tree[T] = body

/** Un-nests AssumeInfo trees, returning them in a list, along with the last non-AssumeInfo body.
* On each recursion, the body is first mapped with `mapBody`. */
def unnest[U <: T](mapBody: Tree[U] => Tree[U] = (body: Tree[U]) => body): (List[AssumeInfo[U]], Tree[U]) =
val acc = scala.collection.mutable.ListBuffer.empty[AssumeInfo[U]]
def rec(tree: Tree[U]): (List[AssumeInfo[U]], Tree[U]) = tree match
case tree @ AssumeInfo(_, _, _, body) => acc += tree; rec(mapBody(body))
case _ => (acc.toList, tree)
rec(this)
}

/** if cond then thenp else elsep */
Expand Down Expand Up @@ -1077,7 +1086,7 @@ object Trees {
type NamedArg = Trees.NamedArg[T]
type Assign = Trees.Assign[T]
type Block = Trees.Block[T]
type GadtExpr = Trees.GadtExpr[T]
type AssumeInfo = Trees.AssumeInfo[T]
type If = Trees.If[T]
type InlineIf = Trees.InlineIf[T]
type Closure = Trees.Closure[T]
Expand Down Expand Up @@ -1216,9 +1225,9 @@ object Trees {
case tree: Block if (stats eq tree.stats) && (expr eq tree.expr) => tree
case _ => finalize(tree, untpd.Block(stats, expr)(sourceFile(tree)))
}
def GadtExpr(tree: Tree)(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr = tree match
case tree: GadtExpr if (gadt eq tree.gadt) && (expr eq tree.expr) => tree
case _ => finalize(tree, untpd.GadtExpr(gadt, expr)(sourceFile(tree)))
def AssumeInfo(tree: Tree)(sym: Symbol, nestingLevel: Int, info: Type, body: Tree)(using Context): AssumeInfo = tree match
case tree: AssumeInfo if (sym eq tree.sym) && (info eq tree.info) && (body eq tree.body) && (nestingLevel == tree.nestingLevel) => tree
case _ => finalize(tree, untpd.AssumeInfo(sym, nestingLevel, info, body)(sourceFile(tree)))
def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = tree match {
case tree: If if (cond eq tree.cond) && (thenp eq tree.thenp) && (elsep eq tree.elsep) => tree
case tree: InlineIf => finalize(tree, untpd.InlineIf(cond, thenp, elsep)(sourceFile(tree)))
Expand Down Expand Up @@ -1351,6 +1360,8 @@ object Trees {

// Copier methods with default arguments; these demand that the original tree
// is of the same class as the copy. We only include trees with more than 2 elements here.
def AssumeInfo(tree: AssumeInfo)(sym: Symbol = tree.sym, nestingLevel: Int = tree.nestingLevel, info: Type = tree.info, body: Tree = tree.body)(using Context): AssumeInfo =
AssumeInfo(tree: Tree)(sym, nestingLevel, info, body)
def If(tree: If)(cond: Tree = tree.cond, thenp: Tree = tree.thenp, elsep: Tree = tree.elsep)(using Context): If =
If(tree: Tree)(cond, thenp, elsep)
def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure =
Expand Down Expand Up @@ -1440,8 +1451,10 @@ object Trees {
cpy.Closure(tree)(transform(env), transform(meth), transform(tpt))
case Match(selector, cases) =>
cpy.Match(tree)(transform(selector), transformSub(cases))
case GadtExpr(gadt, expr) =>
inContext(ctx.withGadt(gadt))(cpy.GadtExpr(tree)(gadt, transform(expr)))
case tree @ AssumeInfo(sym, nestingLevel, info, body) =>
ctx.gadt.withAssumeInfosIn(tree)(transform) { (assumeInfo, body) =>
cpy.AssumeInfo(assumeInfo)(body = body)
}
case CaseDef(pat, guard, body) =>
cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body))
case Labeled(bind, expr) =>
Expand Down Expand Up @@ -1578,8 +1591,8 @@ object Trees {
this(this(this(x, env), meth), tpt)
case Match(selector, cases) =>
this(this(x, selector), cases)
case GadtExpr(gadt, expr) =>
inContext(ctx.withGadt(gadt))(this(x, expr))
case tree @ AssumeInfo(sym, _, info, body) =>
ctx.gadt.withAssumeInfosIn(tree)(this(x, _))((_, x) => x)
case CaseDef(pat, guard, body) =>
this(this(this(x, pat), guard), body)
case Labeled(bind, expr) =>
Expand Down
16 changes: 9 additions & 7 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
Block(stats, expr)
}

def GadtExpr(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr =
ta.assignType(untpd.GadtExpr(gadt, expr), gadt, expr)
def AssumeInfo(sym: Symbol, nestingLevel: Int, info: Type, body: Tree)(using Context): AssumeInfo =
ta.assignType(untpd.AssumeInfo(sym, nestingLevel, info, body), body)

def If(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If =
ta.assignType(untpd.If(cond, thenp, elsep), thenp, elsep)
Expand Down Expand Up @@ -676,11 +676,11 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

override def GadtExpr(tree: Tree)(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr =
val tree1 = untpdCpy.GadtExpr(tree)(gadt, expr)
override def AssumeInfo(tree: Tree)(sym: Symbol, nestingLevel: Int, info: Type, body: Tree)(using Context): AssumeInfo =
val tree1 = untpdCpy.AssumeInfo(tree)(sym, nestingLevel, info, body)
tree match
case tree: GadtExpr if expr.tpe eq tree.expr.tpe => tree1.withTypeUnchecked(tree.tpe)
case _ => ta.assignType(tree1, gadt, expr)
case tree: AssumeInfo if body.tpe eq tree.body.tpe => tree1.withTypeUnchecked(tree.tpe)
case _ => ta.assignType(tree1, body)

override def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = {
val tree1 = untpdCpy.If(tree)(cond, thenp, elsep)
Expand Down Expand Up @@ -766,6 +766,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

override def AssumeInfo(tree: AssumeInfo)(sym: Symbol = tree.sym, nestingLevel: Int = tree.nestingLevel, info: Type = tree.info, body: Tree = tree.body)(using Context): AssumeInfo =
AssumeInfo(tree: Tree)(sym, nestingLevel, info, body)
override def If(tree: If)(cond: Tree = tree.cond, thenp: Tree = tree.thenp, elsep: Tree = tree.elsep)(using Context): If =
If(tree: Tree)(cond, thenp, elsep)
override def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure =
Expand Down Expand Up @@ -1305,7 +1307,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
else if (tree.tpe.widen isRef numericCls)
tree
else {
report.warning(i"conversion from ${tree.tpe.widen} to ${numericCls.typeRef} will always fail at runtime.", tree.srcPos)
report.warning(i"conversion from ${tree.tpe.widen} to ${numericCls.typeRef} will always fail at runtime.")
Throw(New(defn.ClassCastExceptionClass.typeRef, Nil)).withSpan(tree.span)
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def NamedArg(name: Name, arg: Tree)(implicit src: SourceFile): NamedArg = new NamedArg(name, arg)
def Assign(lhs: Tree, rhs: Tree)(implicit src: SourceFile): Assign = new Assign(lhs, rhs)
def Block(stats: List[Tree], expr: Tree)(implicit src: SourceFile): Block = new Block(stats, expr)
def GadtExpr(gadt: GadtConstraint, expr: Tree)(implicit src: SourceFile): GadtExpr = new GadtExpr(gadt, expr)
def AssumeInfo(sym: Symbol, nestingLevel: Int, info: Type, body: Tree)(implicit src: SourceFile): AssumeInfo = new AssumeInfo(sym, nestingLevel, info, body)
def If(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new If(cond, thenp, elsep)
def InlineIf(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new InlineIf(cond, thenp, elsep)
def Closure(env: List[Tree], meth: Tree, tpt: Tree)(implicit src: SourceFile): Closure = new Closure(env, meth, tpt)
Expand Down
44 changes: 36 additions & 8 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Decorators._
import Contexts._
import Types._
import Symbols._
import ast.*, Trees.*
import util.{SimpleIdentitySet, SimpleIdentityMap}
import collection.mutable
import printing._
Expand Down Expand Up @@ -38,6 +39,38 @@ sealed abstract class GadtConstraint extends Showable {
/** Further constrain a symbol already present in the constraint. */
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean

/** Rebuild and return a GadtConstraint using the given AssumeInfo trees. */
def withAssumeInfos[T >: Untyped](assumeInfos: List[AssumeInfo[T]])(using Context): GadtConstraint =
val assumeInfos1 = assumeInfos.filterConserve(t => isBounds(t.info))
if assumeInfos1.isEmpty then this
else
val gadt = fresh

for case AssumeInfo(sym, nestingLevel, _, _) <- assumeInfos1 do
gadt.addToConstraint(List(sym), nestingLevel)

for case AssumeInfo(sym, _, TypeBounds(lo, hi), _) <- assumeInfos1 do
if (sym.typeRef <:< lo)(using ctx.withGadt(gadt)) then
// add in reverse order so that unification runs in the right direction (keep sym)
// for a counter-example: say the symbol is c: b and the bound is b
// if we add c >: b it will unify to b: c not c: b
gadt.addBound(sym, hi, isUpper = true)
gadt.addBound(sym, lo, isUpper = false)
else
gadt.addBound(sym, lo, isUpper = false)
gadt.addBound(sym, hi, isUpper = true)
gadt

/** Un-nest AssumeInfo trees, build a GadtConstraint, and fold everything, starting from the last body.
* Reuses `AssumeInfo#unnest` and `withAssumeInfos`. */
def withAssumeInfosIn[T >: Untyped <: Type, A](
tree: AssumeInfo[T], mapBody: Tree[T] => Tree[T] = (body: Tree[T]) => body,
)(start: Context ?=> Tree[T] => A)(combine: Context ?=> (AssumeInfo[T], A) => A)(using Context): A =
val (assumeInfos, body) = tree.unnest(mapBody)
inContext(ctx.withGadt(withAssumeInfos(assumeInfos))) {
assumeInfos.foldRight(start(body))(combine)
}

/** Is the symbol registered in the constraint?
*
* @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]].
Expand All @@ -51,7 +84,7 @@ sealed abstract class GadtConstraint extends Showable {
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type

def symbols: List[Symbol]
def inputs: List[(List[Symbol], Int)]
def inputs: List[(Symbol, Int)]

def fresh: GadtConstraint

Expand Down Expand Up @@ -225,12 +258,7 @@ final class ProperGadtConstraint private(
}

override def symbols: List[Symbol] = mapping.keys

override def inputs: List[(List[Symbol], Int)] =
constraint.domainLambdas.flatMap { tl =>
val syms = tl.paramRefs.flatMap(reverseMapping(_).toOption)
syms.headOption.map(sym1 => (syms, mapping(sym1).nn.initNestingLevel))
}
override def inputs: List[(Symbol, Int)] = mapping.map2((sym, tvar) => (sym, tvar.initNestingLevel))

override def fresh: GadtConstraint = new ProperGadtConstraint(
myConstraint,
Expand Down Expand Up @@ -331,7 +359,7 @@ final class ProperGadtConstraint private(
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")

override def symbols: List[Symbol] = Nil
override def inputs: List[(List[Symbol], Int)] = Nil
override def inputs: List[(Symbol, Int)] = Nil

override def fresh = new ProperGadtConstraint
override def restore(other: GadtConstraint): Unit =
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/tasty/TastyPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ class TastyPrinter(bytes: Array[Byte]) {
printTrees()
case PARAMtype =>
printNat(); printNat()
case CONSTRAINT =>
printInt(); until(end) { printNat(); printTree(); printTree() }
case ASSUMEINFO =>
until(end) { printNat(); printTree(); printInt(); printTree() }
case _ =>
printTrees()
}
Expand Down
19 changes: 6 additions & 13 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -472,20 +472,13 @@ class TreePickler(pickler: TastyPickler) {
writeByte(BLOCK)
stats.foreach(preRegister)
withLength { pickleTree(expr); stats.foreach(pickleTree) }
case GadtExpr(gadt, expr) =>
writeByte(GADTEXPR)
case AssumeInfo(sym, nestingLevel, info, body) =>
writeByte(ASSUMEINFO)
withLength {
for (symbols, nestingLevel) <- gadt.inputs do
writeByte(CONSTRAINT)
withLength {
writeInt(nestingLevel)
for sym <- symbols do
val TypeBounds(lo, hi) = gadt.fullBounds(sym).nn
pickleSymRef(sym)
pickleType(lo)
pickleType(hi)
}
pickleTree(expr)
pickleSymRef(sym)
writeInt(nestingLevel)
pickleType(info)
pickleTree(body)
}
case tree @ If(cond, thenp, elsep) =>
writeByte(IF)
Expand Down
Loading

0 comments on commit 0acaa36

Please sign in to comment.