Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement AssumeInfo #15928

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ object desugar {
// Propagate down the expected type to the leafs of the expression
case Block(stats, expr) =>
cpy.Block(tree)(stats, adaptToExpectedTpt(expr))
case AssumeInfo(sym, info, body) =>
cpy.AssumeInfo(tree)(sym, info, adaptToExpectedTpt(body))
case If(cond, thenp, elsep) =>
cpy.If(tree)(cond, adaptToExpectedTpt(thenp), adaptToExpectedTpt(elsep))
case untpd.Parens(expr) =>
Expand Down Expand Up @@ -1645,6 +1647,7 @@ object desugar {
case Tuple(trees) => (pats corresponds trees)(isIrrefutable)
case Parens(rhs1) => matchesTuple(pats, rhs1)
case Block(_, rhs1) => matchesTuple(pats, rhs1)
case AssumeInfo(_, _, rhs1) => matchesTuple(pats, rhs1)
case If(_, thenp, elsep) => matchesTuple(pats, thenp) && matchesTuple(pats, elsep)
case Match(_, cases) => cases forall (matchesTuple(pats, _))
case CaseDef(_, _, rhs1) => matchesTuple(pats, rhs1)
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
case If(_, thenp, elsep) => forallResults(thenp, p) && forallResults(elsep, p)
case Match(_, cases) => cases forall (c => forallResults(c.body, p))
case Block(_, expr) => forallResults(expr, p)
case AssumeInfo(_, _, body) => forallResults(body, p)
case _ => p(tree)
}

Expand Down Expand Up @@ -1088,6 +1089,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
case Typed(expr, _) => unapply(expr)
case Inlined(_, Nil, expr) => unapply(expr)
case Block(Nil, expr) => unapply(expr)
case AssumeInfo(_, _, body) => unapply(body)
case _ =>
tree.tpe.widenTermRefExpr.dealias.normalized match
case ConstantType(Constant(x)) => Some(x)
Expand Down
14 changes: 14 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ class TreeTypeMap(
cpy.LambdaTypeTree(tdef)(tparams1, tmap1.transform(body))
case inlined: Inlined =>
transformInlined(inlined)
case tree: AssumeInfo =>
def mapBody(body: Tree) = body match
case tree @ AssumeInfo(_, _, _) =>
val tree1 = treeMap(tree)
tree1.withType(mapType(tree1.tpe))
case _ => body
tree.fold(transform, mapBody) { case (assumeInfo @ AssumeInfo(sym, info, _), body) =>
mapType(sym.typeRef) match
case tp: TypeRef if tp eq sym.typeRef =>
val sym1 = sym.subst(substFrom, substTo)
val info1 = mapType(info)
cpy.AssumeInfo(assumeInfo)(sym = sym1, info = info1, body = body)
case _ => body // if the AssumeInfo symbol maps (as a type) to another type, we lose the associated info
}
case cdef @ CaseDef(pat, guard, rhs) =>
val tmap = withMappedSyms(patVars(pat))
val pat1 = tmap.transform(pat)
Expand Down
26 changes: 26 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,20 @@ object Trees {
override def isTerm: Boolean = !isType // this will classify empty trees as terms, which is necessary
}

case class AssumeInfo[+T <: Untyped] private[ast] (sym: Symbol, info: Type, body: Tree[T])(implicit @constructorOnly src: SourceFile)
extends ProxyTree[T] {
type ThisTree[+T <: Untyped] <: AssumeInfo[T]
def forwardTo: Tree[T] = body

def fold[U >: T <: Untyped, A](
start: Context ?=> Tree[U] => A, mapBody: Tree[U] => Tree[U] = (body: Tree[U]) => body,
)(combine: Context ?=> (AssumeInfo[U], A) => A)(using Context): A =
val body1 = mapBody(body)
inContext(ctx.withAssumeInfo(ctx.assumeInfo.add(sym, info))) {
combine(this, start(body1))
}
}

/** if cond then thenp else elsep */
case class If[+T <: Untyped] private[ast] (cond: Tree[T], thenp: Tree[T], elsep: Tree[T])(implicit @constructorOnly src: SourceFile)
extends TermTree[T] {
Expand Down Expand Up @@ -1074,6 +1088,7 @@ object Trees {
type NamedArg = Trees.NamedArg[T]
type Assign = Trees.Assign[T]
type Block = Trees.Block[T]
type AssumeInfo = Trees.AssumeInfo[T]
type If = Trees.If[T]
type InlineIf = Trees.InlineIf[T]
type Closure = Trees.Closure[T]
Expand Down Expand Up @@ -1212,6 +1227,9 @@ object Trees {
case tree: Block if (stats eq tree.stats) && (expr eq tree.expr) => tree
case _ => finalize(tree, untpd.Block(stats, expr)(sourceFile(tree)))
}
def AssumeInfo(tree: Tree)(sym: Symbol, info: Type, body: Tree)(using Context): AssumeInfo = tree match
case tree: AssumeInfo if (sym eq tree.sym) && (info eq tree.info) && (body eq tree.body) => tree
case _ => finalize(tree, untpd.AssumeInfo(sym, info, body)(sourceFile(tree)))
def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = tree match {
case tree: If if (cond eq tree.cond) && (thenp eq tree.thenp) && (elsep eq tree.elsep) => tree
case tree: InlineIf => finalize(tree, untpd.InlineIf(cond, thenp, elsep)(sourceFile(tree)))
Expand Down Expand Up @@ -1344,6 +1362,8 @@ object Trees {

// Copier methods with default arguments; these demand that the original tree
// is of the same class as the copy. We only include trees with more than 2 elements here.
def AssumeInfo(tree: AssumeInfo)(sym: Symbol = tree.sym, info: Type = tree.info, body: Tree = tree.body)(using Context): AssumeInfo =
AssumeInfo(tree: Tree)(sym, info, body)
def If(tree: If)(cond: Tree = tree.cond, thenp: Tree = tree.thenp, elsep: Tree = tree.elsep)(using Context): If =
If(tree: Tree)(cond, thenp, elsep)
def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure =
Expand Down Expand Up @@ -1433,6 +1453,10 @@ object Trees {
cpy.Closure(tree)(transform(env), transform(meth), transform(tpt))
case Match(selector, cases) =>
cpy.Match(tree)(transform(selector), transformSub(cases))
case tree @ AssumeInfo(sym, info, body) =>
tree.fold(transform) { (assumeInfo, body) =>
cpy.AssumeInfo(assumeInfo)(body = body)
}
case CaseDef(pat, guard, body) =>
cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body))
case Labeled(bind, expr) =>
Expand Down Expand Up @@ -1569,6 +1593,8 @@ object Trees {
this(this(this(x, env), meth), tpt)
case Match(selector, cases) =>
this(this(x, selector), cases)
case tree @ AssumeInfo(sym, info, body) =>
tree.fold(this(x, _))((_, x) => x)
case CaseDef(pat, guard, body) =>
this(this(this(x, pat), guard), body)
case Labeled(bind, expr) =>
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
Block(stats, expr)
}

def AssumeInfo(sym: Symbol, info: Type, body: Tree)(using Context): AssumeInfo =
ta.assignType(untpd.AssumeInfo(sym, info, body), body)

def If(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If =
ta.assignType(untpd.If(cond, thenp, elsep), thenp, elsep)

Expand Down Expand Up @@ -683,6 +686,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

override def AssumeInfo(tree: Tree)(sym: Symbol, info: Type, body: Tree)(using Context): AssumeInfo =
val tree1 = untpdCpy.AssumeInfo(tree)(sym, info, body)
tree match
case tree: AssumeInfo if body.tpe eq tree.body.tpe => tree1.withTypeUnchecked(tree.tpe)
case _ => ta.assignType(tree1, body)

override def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = {
val tree1 = untpdCpy.If(tree)(cond, thenp, elsep)
tree match {
Expand Down Expand Up @@ -767,6 +776,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

override def AssumeInfo(tree: AssumeInfo)(sym: Symbol = tree.sym, info: Type = tree.info, body: Tree = tree.body)(using Context): AssumeInfo =
AssumeInfo(tree: Tree)(sym, info, body)
override def If(tree: If)(cond: Tree = tree.cond, thenp: Tree = tree.thenp, elsep: Tree = tree.elsep)(using Context): If =
If(tree: Tree)(cond, thenp, elsep)
override def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure =
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def NamedArg(name: Name, arg: Tree)(implicit src: SourceFile): NamedArg = new NamedArg(name, arg)
def Assign(lhs: Tree, rhs: Tree)(implicit src: SourceFile): Assign = new Assign(lhs, rhs)
def Block(stats: List[Tree], expr: Tree)(implicit src: SourceFile): Block = new Block(stats, expr)
def AssumeInfo(sym: Symbol, info: Type, body: Tree)(implicit src: SourceFile): AssumeInfo = new AssumeInfo(sym, info, body)
def If(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new If(cond, thenp, elsep)
def InlineIf(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new InlineIf(cond, thenp, elsep)
def Closure(env: List[Tree], meth: Tree, tpt: Tree)(implicit src: SourceFile): Closure = new Closure(env, meth, tpt)
Expand Down
28 changes: 28 additions & 0 deletions compiler/src/dotty/tools/dotc/core/AssumeInfoMap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dotty.tools
package dotc
package core

import Contexts.*, Decorators.*, NameKinds.*, Symbols.*, Types.*
import ast.*, Trees.*
import printing.*, Texts.*

import scala.annotation.internal.sharable
import util.{SimpleIdentitySet, SimpleIdentityMap}

object AssumeInfoMap:
@sharable val empty: AssumeInfoMap = AssumeInfoMap(SimpleIdentityMap.empty)

class AssumeInfoMap private (
private val map: SimpleIdentityMap[Symbol, Type],
) extends Showable:
def info(sym: Symbol)(using Context): Type | Null = map(sym)

def add(sym: Symbol, info: Type) = new AssumeInfoMap(map.updated(sym, info))

override def toText(p: Printer): Text =
given Context = p match
case p: PlainPrinter => p.printerContext
case _ => Contexts.NoContext
val deps = for (sym, info) <- map.toList yield
(p.toText(sym.typeRef) ~ p.toText(info)).close
("AssumeInfo(" ~ Text(deps, ", ") ~ ")").close
19 changes: 19 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ object Contexts {
def typerState: TyperState
def gadt: GadtConstraint = gadtState.gadt
def gadtState: GadtState
def assumeInfo: AssumeInfoMap
def searchHistory: SearchHistory
def source: SourceFile

Expand Down Expand Up @@ -470,6 +471,15 @@ object Contexts {
case None => fresh.dropProperty(key)
}

final def withGadt(gadt: GadtConstraint): Context =
if this.gadt eq gadt then this else fresh.setGadtState(GadtState(gadt))

final def withGadtState(gadt: GadtState): Context =
if this.gadtState eq gadt then this else fresh.setGadtState(gadt)

final def withAssumeInfo(assumeInfo: AssumeInfoMap): Context =
if this.assumeInfo eq assumeInfo then this else fresh.setAssumeInfo(assumeInfo)

def typer: Typer = this.typeAssigner match {
case typer: Typer => typer
case _ => new Typer
Expand Down Expand Up @@ -545,6 +555,9 @@ object Contexts {
private var _gadtState: GadtState = uninitialized
final def gadtState: GadtState = _gadtState

private var _assumeInfo: AssumeInfoMap = uninitialized
final def assumeInfo: AssumeInfoMap = _assumeInfo

private var _searchHistory: SearchHistory = uninitialized
final def searchHistory: SearchHistory = _searchHistory

Expand All @@ -569,6 +582,7 @@ object Contexts {
_tree = origin.tree
_scope = origin.scope
_gadtState = origin.gadtState
_assumeInfo = origin.assumeInfo
_searchHistory = origin.searchHistory
_source = origin.source
_moreProperties = origin.moreProperties
Expand Down Expand Up @@ -632,6 +646,10 @@ object Contexts {
def setFreshGADTBounds: this.type =
setGadtState(gadtState.fresh)

def setAssumeInfo(assumeInfo: AssumeInfoMap): this.type =
this._assumeInfo= assumeInfo
this

def setSearchHistory(searchHistory: SearchHistory): this.type =
util.Stats.record("Context.setSearchHistory")
this._searchHistory = searchHistory
Expand Down Expand Up @@ -723,6 +741,7 @@ object Contexts {
.updated(compilationUnitLoc, NoCompilationUnit)
c._searchHistory = new SearchRoot
c._gadtState = GadtState(GadtConstraint.empty)
c._assumeInfo = AssumeInfoMap.empty
c
end FreshContext

Expand Down
13 changes: 12 additions & 1 deletion compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package core

import Contexts.*, Decorators.*, Symbols.*, Types.*
import NameKinds.UniqueName
import ast.*, Trees.*
import config.Printers.{gadts, gadtsConstr}
import util.{SimpleIdentitySet, SimpleIdentityMap}
import printing._
Expand All @@ -27,6 +28,7 @@ class GadtConstraint private (
def symbols: List[Symbol] = mapping.keys
def withConstraint(c: Constraint) = copy(myConstraint = c)
def withWasConstrained = copy(wasConstrained = true)
def isEmpty: Boolean = mapping.isEmpty

def add(sym: Symbol, tv: TypeVar): GadtConstraint = copy(
mapping = mapping.updated(sym, tv),
Expand Down Expand Up @@ -136,6 +138,13 @@ class GadtConstraint private (

override def toText(printer: Printer): Texts.Text = printer.toText(this)

def eql(that: GadtConstraint): Boolean = (this eq that) || {
myConstraint == that.myConstraint
&& mapping == that.mapping
&& reverseMapping == that.reverseMapping
&& wasConstrained == that.wasConstrained
}

/** Provides more information than toText, by showing the underlying Constraint details. */
def debugBoundsDescription(using Context): String = i"$this\n$constraint"

Expand Down Expand Up @@ -201,7 +210,7 @@ sealed trait GadtState {
)

val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) =>
val tv = TypeVar(paramRef, creatorState = null)
val tv = TypeVar(paramRef, creatorState = null, ctx.nestingLevel)
gadt = gadt.add(sym, tv)
tv
}
Expand Down Expand Up @@ -277,6 +286,8 @@ sealed trait GadtState {
override def fullLowerBound(param: TypeParamRef)(using Context): Type = gadt.fullLowerBound(param)
override def fullUpperBound(param: TypeParamRef)(using Context): Type = gadt.fullUpperBound(param)

def symbols: List[Symbol] = gadt.symbols

// ---- Debug ------------------------------------------------------------

override def constr = gadtsConstr
Expand Down
Loading