Skip to content

Commit

Permalink
Implement GadtExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Sep 5, 2022
1 parent 36c66e9 commit 768f067
Show file tree
Hide file tree
Showing 48 changed files with 386 additions and 201 deletions.
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ object desugar {
// Propagate down the expected type to the leafs of the expression
case Block(stats, expr) =>
cpy.Block(tree)(stats, adaptToExpectedTpt(expr))
case GadtExpr(gadt, expr) =>
cpy.GadtExpr(tree)(gadt, adaptToExpectedTpt(expr))
case If(cond, thenp, elsep) =>
cpy.If(tree)(cond, adaptToExpectedTpt(thenp), adaptToExpectedTpt(elsep))
case untpd.Parens(expr) =>
Expand Down Expand Up @@ -1630,6 +1632,7 @@ object desugar {
case Tuple(trees) => (pats corresponds trees)(isIrrefutable)
case Parens(rhs1) => matchesTuple(pats, rhs1)
case Block(_, rhs1) => matchesTuple(pats, rhs1)
case GadtExpr(_, rhs1) => matchesTuple(pats, rhs1)
case If(_, thenp, elsep) => matchesTuple(pats, thenp) && matchesTuple(pats, elsep)
case Match(_, cases) => cases forall (matchesTuple(pats, _))
case CaseDef(_, _, rhs1) => matchesTuple(pats, rhs1)
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
case If(_, thenp, elsep) => forallResults(thenp, p) && forallResults(elsep, p)
case Match(_, cases) => cases forall (c => forallResults(c.body, p))
case Block(_, expr) => forallResults(expr, p)
case GadtExpr(_, expr) => forallResults(expr, p)
case _ => p(tree)
}
}
Expand Down Expand Up @@ -1039,6 +1040,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
case Typed(expr, _) => unapply(expr)
case Inlined(_, Nil, expr) => unapply(expr)
case Block(Nil, expr) => unapply(expr)
case GadtExpr(_, expr) => unapply(expr)
case _ =>
tree.tpe.widenTermRefExpr.normalized match
case ConstantType(Constant(x)) => Some(x)
Expand Down
37 changes: 37 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,20 @@ class TreeTypeMap(
cpy.Block(blk)(stats1, expr1)
case inlined: Inlined =>
transformInlined(inlined)
case GadtExpr(gadt, expr) =>
val tmap = withMappedSyms(gadt.symbols.diff(substFrom)) // CaseDef handles the patVars
val gadt1 = tmap.rebuild(gadt)
inContext(ctx.withGadt(gadt1))(cpy.GadtExpr(expr)(gadt1, tmap.transform(expr)))
case cdef @ CaseDef(pat, guard, expr @ GadtExpr(gadt, rhs)) =>
val patVars1 = patVars(pat)
val tmap = withMappedSyms(patVars1 ::: gadt.symbols.diff(patVars1))
val gadt1 = tmap.rebuild(gadt)
inContext(ctx.withGadt(gadt1)) {
val pat1 = tmap.transform(pat)
val guard1 = tmap.transform(guard)
val rhs1 = cpy.GadtExpr(expr)(gadt1, tmap.transform(rhs))
cpy.CaseDef(cdef)(pat1, guard1, rhs1)
}
case cdef @ CaseDef(pat, guard, rhs) =>
val tmap = withMappedSyms(patVars(pat))
val pat1 = tmap.transform(pat)
Expand All @@ -146,6 +160,29 @@ class TreeTypeMap(
}
}

private def rebuild(gadt: GadtConstraint)(using Context): GadtConstraint =
val constraints = for sym <- gadt.symbols yield
val TypeBounds(lo, hi) = gadt.fullBounds(sym).nn
(sym, lo, hi)
val constraints1 = constraints.mapConserve { triple =>
val (sym, lo, hi) = triple
val sym1 = mapOwner(sym)
val lo1 = mapType(lo)
val hi1 = mapType(hi)
if (sym eq sym1) && (lo eq lo1) && (hi eq hi1)
then triple
else (sym1, lo1, hi1)
}
if constraints eq constraints1 then
gadt
else
val gadt = EmptyGadtConstraint.fresh
for (sym, lo, hi) <- constraints1 do
gadt.addToConstraint(sym)
gadt.addBound(sym, lo, false)
gadt.addBound(sym, hi, true)
gadt

override def transformStats(trees: List[tpd.Tree], exprOwner: Symbol)(using Context): List[Tree] =
transformDefs(trees)._2

Expand Down
18 changes: 18 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,12 @@ object Trees {
override def isTerm: Boolean = !isType // this will classify empty trees as terms, which is necessary
}

case class GadtExpr[-T >: Untyped] private[ast] (gadt: GadtConstraint, expr: Tree[T])(implicit @constructorOnly src: SourceFile)
extends ProxyTree[T] {
type ThisTree[-T >: Untyped] <: GadtExpr[T]
def forwardTo: Tree[T] = expr
}

/** if cond then thenp else elsep */
case class If[-T >: Untyped] private[ast] (cond: Tree[T], thenp: Tree[T], elsep: Tree[T])(implicit @constructorOnly src: SourceFile)
extends TermTree[T] {
Expand Down Expand Up @@ -1071,6 +1077,7 @@ object Trees {
type NamedArg = Trees.NamedArg[T]
type Assign = Trees.Assign[T]
type Block = Trees.Block[T]
type GadtExpr = Trees.GadtExpr[T]
type If = Trees.If[T]
type InlineIf = Trees.InlineIf[T]
type Closure = Trees.Closure[T]
Expand Down Expand Up @@ -1209,6 +1216,9 @@ object Trees {
case tree: Block if (stats eq tree.stats) && (expr eq tree.expr) => tree
case _ => finalize(tree, untpd.Block(stats, expr)(sourceFile(tree)))
}
def GadtExpr(tree: Tree)(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr = tree match
case tree: GadtExpr if (gadt eq tree.gadt) && (expr eq tree.expr) => tree
case _ => finalize(tree, untpd.GadtExpr(gadt, expr)(sourceFile(tree)))
def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = tree match {
case tree: If if (cond eq tree.cond) && (thenp eq tree.thenp) && (elsep eq tree.elsep) => tree
case tree: InlineIf => finalize(tree, untpd.InlineIf(cond, thenp, elsep)(sourceFile(tree)))
Expand Down Expand Up @@ -1430,6 +1440,10 @@ object Trees {
cpy.Closure(tree)(transform(env), transform(meth), transform(tpt))
case Match(selector, cases) =>
cpy.Match(tree)(transform(selector), transformSub(cases))
case GadtExpr(gadt, expr) =>
inContext(ctx.withGadt(gadt))(cpy.GadtExpr(tree)(gadt, transform(expr)))
case CaseDef(pat, guard, body @ GadtExpr(gadt, _)) =>
inContext(ctx.withGadt(gadt))(cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body)))
case CaseDef(pat, guard, body) =>
cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body))
case Labeled(bind, expr) =>
Expand Down Expand Up @@ -1566,6 +1580,10 @@ object Trees {
this(this(this(x, env), meth), tpt)
case Match(selector, cases) =>
this(this(x, selector), cases)
case GadtExpr(gadt, expr) =>
inContext(ctx.withGadt(gadt))(this(x, expr))
case CaseDef(pat, guard, body @ GadtExpr(gadt, _)) =>
inContext(ctx.withGadt(gadt))(this(this(this(x, pat), guard), body))
case CaseDef(pat, guard, body) =>
this(this(this(x, pat), guard), body)
case Labeled(bind, expr) =>
Expand Down
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
Block(stats, expr)
}

def GadtExpr(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr =
ta.assignType(untpd.GadtExpr(gadt, expr), gadt, expr)

def If(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If =
ta.assignType(untpd.If(cond, thenp, elsep), thenp, elsep)

Expand Down Expand Up @@ -673,6 +676,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

override def GadtExpr(tree: Tree)(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr =
val tree1 = untpdCpy.GadtExpr(tree)(gadt, expr)
tree match
case tree: GadtExpr if expr.tpe eq tree.expr.tpe => tree1.withTypeUnchecked(tree.tpe)
case _ => ta.assignType(tree1, gadt, expr)

override def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = {
val tree1 = untpdCpy.If(tree)(cond, thenp, elsep)
tree match {
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def NamedArg(name: Name, arg: Tree)(implicit src: SourceFile): NamedArg = new NamedArg(name, arg)
def Assign(lhs: Tree, rhs: Tree)(implicit src: SourceFile): Assign = new Assign(lhs, rhs)
def Block(stats: List[Tree], expr: Tree)(implicit src: SourceFile): Block = new Block(stats, expr)
def GadtExpr(gadt: GadtConstraint, expr: Tree)(implicit src: SourceFile): GadtExpr = new GadtExpr(gadt, expr)
def If(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new If(cond, thenp, elsep)
def InlineIf(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new InlineIf(cond, thenp, elsep)
def Closure(env: List[Tree], meth: Tree, tpt: Tree)(implicit src: SourceFile): Closure = new Closure(env, meth, tpt)
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,9 @@ object Contexts {
case None => fresh.dropProperty(key)
}

final def withGadt(gadt: GadtConstraint): Context =
if this.gadt eq gadt then this else fresh.setGadt(gadt)

def typer: Typer = this.typeAssigner match {
case typer: Typer => typer
case _ => new Typer
Expand Down
55 changes: 37 additions & 18 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ sealed abstract class GadtConstraint extends Showable {
*
* @see [[ConstraintHandling.addToConstraint]]
*/
def addToConstraint(syms: List[Symbol])(using Context): Boolean
def addToConstraint(syms: List[Symbol], nestingLevel: Int)(using Context): Boolean
def addToConstraint(syms: List[Symbol])(using Context): Boolean = addToConstraint(syms, ctx.nestingLevel)
def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil)

/** Further constrain a symbol already present in the constraint. */
Expand All @@ -49,14 +50,17 @@ sealed abstract class GadtConstraint extends Showable {
/** See [[ConstraintHandling.approximation]] */
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type

def remove(sym: Symbol)(using Context): Unit

def symbols: List[Symbol]
def inputs: List[(List[Symbol], Int)]

def fresh: GadtConstraint

/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
def restore(other: GadtConstraint): Unit

def debugBoundsDescription(using Context): String
def eql(that: GadtConstraint): Boolean
}

final class ProperGadtConstraint private(
Expand Down Expand Up @@ -88,7 +92,7 @@ final class ProperGadtConstraint private(
// the case where they're valid, so no approximating is needed.
rawBound

override def addToConstraint(params: List[Symbol])(using Context): Boolean = {
override def addToConstraint(params: List[Symbol], nestingLevel: Int)(using Context): Boolean = {
import NameKinds.DepParamName

val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })(
Expand Down Expand Up @@ -126,15 +130,15 @@ final class ProperGadtConstraint private(
)

val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) =>
val tv = TypeVar(paramRef, creatorState = null)
val tv = TypeVar(paramRef, creatorState = null, nestingLevel)
mapping = mapping.updated(sym, tv)
reverseMapping = reverseMapping.updated(tv.origin, sym)
tv
}

// The replaced symbols are picked up here.
addToConstraint(poly1, tvars)
.showing(i"added to constraint: [$poly1] $params%, %\n$debugBoundsDescription", gadts)
.showing(i"added to constraint: [$poly1] $params%, % gadt = $this", gadts)
}

override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = {
Expand Down Expand Up @@ -219,8 +223,22 @@ final class ProperGadtConstraint private(
res
}

override def remove(sym: Symbol)(using Context): Unit =
mapping(sym) match
case tv: TypeVar =>
mapping = mapping.remove(sym)
reverseMapping = reverseMapping.remove(tv.origin)
constraint = constraint.replace(tv.origin, sym.typeRef)
case null =>

override def symbols: List[Symbol] = mapping.keys

override def inputs: List[(List[Symbol], Int)] =
constraint.domainLambdas.flatMap { tl =>
val syms = tl.paramRefs.flatMap(reverseMapping(_).toOption)
syms.headOption.map(sym1 => (syms, mapping(sym1).nn.initNestingLevel))
}

override def fresh: GadtConstraint = new ProperGadtConstraint(
myConstraint,
mapping,
Expand Down Expand Up @@ -291,17 +309,15 @@ final class ProperGadtConstraint private(

override def constr = gadtsConstr

override def toText(printer: Printer): Texts.Text = constraint.toText(printer)
override def eql(that: GadtConstraint): Boolean = (this eq that) || that.match
case that: ProperGadtConstraint =>
myConstraint == that.myConstraint
&& mapping == that.mapping
&& reverseMapping == that.reverseMapping
&& wasConstrained == that.wasConstrained
case _ => false

override def debugBoundsDescription(using Context): String = {
val sb = new mutable.StringBuilder
sb ++= constraint.show
sb += '\n'
mapping.foreachBinding { case (sym, _) =>
sb ++= i"$sym: ${fullBounds(sym)}\n"
}
sb.result
}
override def toText(printer: Printer): Texts.Text = printer.toText(this)
}

@sharable object EmptyGadtConstraint extends GadtConstraint {
Expand All @@ -314,18 +330,21 @@ final class ProperGadtConstraint private(

override def contains(sym: Symbol)(using Context) = false

override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
override def addToConstraint(params: List[Symbol], nestingLevel: Int)(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound")

override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")

override def remove(sym: Symbol)(using Context): Unit = ()

override def symbols: List[Symbol] = Nil
override def inputs: List[(List[Symbol], Int)] = Nil

override def fresh = new ProperGadtConstraint
override def restore(other: GadtConstraint): Unit =
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")

override def debugBoundsDescription(using Context): String = "EmptyGadtConstraint"
override def eql(that: GadtConstraint): Boolean = (this eq that) || that == EmptyGadtConstraint

override def toText(printer: Printer): Texts.Text = "EmptyGadtConstraint"
override def toText(printer: Printer): Texts.Text = printer.toText(this)
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
val assumeInvariantRefinement =
migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)

trace(i"constraining simple pattern type $tp >:< $pt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
trace(i"constraining simple pattern type $tp >:< $pt", gadts, res => s"$res gadt = ${ctx.gadt}") {
(tp, pt) match {
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
val saved = state.nn.constraint
Expand Down
Loading

0 comments on commit 768f067

Please sign in to comment.