Skip to content

Commit

Permalink
Rename GadtConstraintHandling to GadtState
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Jan 24, 2023
1 parent 8d1e5df commit b1a035a
Show file tree
Hide file tree
Showing 14 changed files with 63 additions and 73 deletions.
21 changes: 11 additions & 10 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ object Contexts {
def tree: Tree[?]
def scope: Scope
def typerState: TyperState
def gadt: GadtConstraintHandling
def gadt: GadtConstraint = gadtState.gadt
def gadtState: GadtState
def searchHistory: SearchHistory
def source: SourceFile

Expand Down Expand Up @@ -410,7 +411,7 @@ object Contexts {
val constrCtx = outersIterator.dropWhile(_.outer.owner == owner).next()
superOrThisCallContext(owner, constrCtx.scope)
.setTyperState(typerState)
.setGadt(gadt)
.setGadtState(gadtState)
.fresh
.setScope(this.scope)
}
Expand Down Expand Up @@ -541,8 +542,8 @@ object Contexts {
private var _typerState: TyperState = uninitialized
final def typerState: TyperState = _typerState

private var _gadt: GadtConstraintHandling = uninitialized
final def gadt: GadtConstraintHandling = _gadt
private var _gadtState: GadtState = uninitialized
final def gadtState: GadtState = _gadtState

private var _searchHistory: SearchHistory = uninitialized
final def searchHistory: SearchHistory = _searchHistory
Expand All @@ -567,7 +568,7 @@ object Contexts {
_owner = origin.owner
_tree = origin.tree
_scope = origin.scope
_gadt = origin.gadt
_gadtState = origin.gadtState
_searchHistory = origin.searchHistory
_source = origin.source
_moreProperties = origin.moreProperties
Expand Down Expand Up @@ -624,12 +625,12 @@ object Contexts {
this._scope = typer.scope
setTypeAssigner(typer)

def setGadt(gadt: GadtConstraintHandling): this.type =
util.Stats.record("Context.setGadt")
this._gadt = gadt
def setGadtState(gadtState: GadtState): this.type =
util.Stats.record("Context.setGadtState")
this._gadtState = gadtState
this
def setFreshGADTBounds: this.type =
setGadt(gadt.fresh)
setGadtState(gadtState.fresh)

def setSearchHistory(searchHistory: SearchHistory): this.type =
util.Stats.record("Context.setSearchHistory")
Expand Down Expand Up @@ -721,7 +722,7 @@ object Contexts {
.updated(notNullInfosLoc, Nil)
.updated(compilationUnitLoc, NoCompilationUnit)
c._searchHistory = new SearchRoot
c._gadt = GadtConstraintHandling(GadtConstraint.empty)
c._gadtState = GadtState(GadtConstraint.empty)
c
end FreshContext

Expand Down
33 changes: 11 additions & 22 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,20 +133,14 @@ class GadtConstraint private (
): GadtConstraint = GadtConstraint(myConstraint, mapping, reverseMapping, wasConstrained)
end GadtConstraint

object GadtConstraintHandling:
def apply(gadt: GadtConstraint): GadtConstraintHandling = new ProperGadtConstraintHandling(gadt)
object GadtState:
def apply(gadt: GadtConstraint): GadtState = ProperGadtState(gadt)

sealed trait GadtConstraintHandling(private var myGadt: GadtConstraint) {
this: ConstraintHandling =>
sealed trait GadtState {
this: ConstraintHandling => // Hide ConstraintHandling within GadtConstraintHandling

def gadt: GadtConstraint = myGadt
private def gadt_=(g: GadtConstraint) = myGadt = g

/** Exposes ConstraintHandling.subsumes */
def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = {
def extractConstraint(g: GadtConstraint) = g.constraint
subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre))
}
def gadt: GadtConstraint
def gadt_=(g: GadtConstraint): Unit

override protected def legalBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Type =
// GADT constraints never involve wildcards and are not propagated outside
Expand Down Expand Up @@ -233,13 +227,6 @@ sealed trait GadtConstraintHandling(private var myGadt: GadtConstraint) {
result
}

def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = gadt.isLess(sym1, sym2)
def fullBounds(sym: Symbol)(using Context): TypeBounds | Null = gadt.fullBounds(sym)
def bounds(sym: Symbol)(using Context): TypeBounds | Null = gadt.bounds(sym)
def contains(sym: Symbol)(using Context): Boolean = gadt.contains(sym)
def isNarrowing: Boolean = gadt.isNarrowing
def symbols: List[Symbol] = gadt.symbols

/** See [[ConstraintHandling.approximation]] */
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type = {
approximation(gadt.tvarOrError(sym).origin, fromBelow, maxLevel).match
Expand All @@ -252,7 +239,7 @@ sealed trait GadtConstraintHandling(private var myGadt: GadtConstraint) {
.showing(i"approximating $sym ~> $result", gadts)
}

def fresh: GadtConstraintHandling = GadtConstraintHandling(gadt)
def fresh: GadtState = GadtState(gadt)

/** Restore the GadtConstraint state. */
def restore(gadt: GadtConstraint): Unit = this.gadt = gadt
Expand Down Expand Up @@ -281,5 +268,7 @@ sealed trait GadtConstraintHandling(private var myGadt: GadtConstraint) {
override def constr = gadtsConstr
}

// Hide ConstraintHandling within GadtConstraintHandling
private class ProperGadtConstraintHandling(gadt: GadtConstraint) extends ConstraintHandling with GadtConstraintHandling(gadt)
// Hide ConstraintHandling within GadtState
private class ProperGadtState(private var myGadt: GadtConstraint) extends ConstraintHandling with GadtState:
def gadt: GadtConstraint = myGadt
def gadt_=(gadt: GadtConstraint): Unit = myGadt = gadt
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/NamerOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ object NamerOps:
* by (ab?)-using GADT constraints. See pos/i941.scala.
*/
def linkConstructorParams(sym: Symbol, tparams: List[Symbol], rhsCtx: Context)(using Context): Unit =
rhsCtx.gadt.addToConstraint(tparams)
rhsCtx.gadtState.addToConstraint(tparams)
tparams.lazyZip(sym.owner.typeParams).foreach { (psym, tparam) =>
val tr = tparam.typeRef
rhsCtx.gadt.addBound(psym, tr, isUpper = false)
rhsCtx.gadt.addBound(psym, tr, isUpper = true)
rhsCtx.gadtState.addBound(psym, tr, isUpper = false)
rhsCtx.gadtState.addBound(psym, tr, isUpper = true)
}

end NamerOps
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,12 @@ trait PatternTypeConstrainer { self: TypeComparer =>
val assumeInvariantRefinement =
migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)

trace(i"constraining simple pattern type $tp >:< $pt", gadts, (res: Boolean) => i"$res gadt = ${ctx.gadt.gadt}") {
trace(i"constraining simple pattern type $tp >:< $pt", gadts, (res: Boolean) => i"$res gadt = ${ctx.gadt}") {
(tp, pt) match {
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
val saved = state.nn.constraint
val result =
ctx.gadt.rollbackGadtUnless {
ctx.gadtState.rollbackGadtUnless {
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
val variance = param.paramVarianceSign
if variance == 0 || assumeInvariantRefinement ||
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ object Symbols {
addToGadt: Boolean = true,
flags: FlagSet = EmptyFlags)(using Context): Symbol = {
val sym = newSymbol(ctx.owner, name, Case | flags, info, coord = span)
if (addToGadt && name.isTypeName) ctx.gadt.addToConstraint(sym)
if (addToGadt && name.isTypeName) ctx.gadtState.addToConstraint(sym)
sym
}

Expand Down
32 changes: 16 additions & 16 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
private def isBottom(tp: Type) = tp.widen.isRef(NothingClass)

protected def gadtBounds(sym: Symbol)(using Context) = ctx.gadt.bounds(sym)
protected def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean = ctx.gadt.addBound(sym, b, isUpper)
protected def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean = ctx.gadtState.addBound(sym, b, isUpper)

protected def typeVarInstance(tvar: TypeVar)(using Context): Type = tvar.underlying

Expand Down Expand Up @@ -1446,10 +1446,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
else if tp1 eq tp2 then true
else
val savedCstr = constraint
val savedGadt = ctx.gadt.gadt
val savedGadt = ctx.gadt
inline def restore() =
state.constraint = savedCstr
ctx.gadt.restore(savedGadt)
ctx.gadtState.restore(savedGadt)
val savedSuccessCount = successCount
try
recCount += 1
Expand Down Expand Up @@ -1855,34 +1855,34 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
*/
private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean =
val preConstraint = constraint
val preGadtHandling = ctx.gadt.fresh
val preGadt = preGadtHandling.gadt
val preGadt = ctx.gadt

def allSubsumes(leftGadt: GadtConstraint, rightGadt: GadtConstraint, left: Constraint, right: Constraint): Boolean =
subsumes(left, right, preConstraint) && preGadtHandling.subsumes(leftGadt, rightGadt, preGadt)
subsumes(left, right, preConstraint)
&& subsumes(leftGadt.constraint, rightGadt.constraint, preGadt.constraint)

if op1 then
val op1Constraint = constraint
val op1Gadt = ctx.gadt.gadt
val op1Gadt = ctx.gadt
constraint = preConstraint
ctx.gadt.restore(preGadt)
ctx.gadtState.restore(preGadt)
if op2 then
if allSubsumes(op1Gadt, ctx.gadt.gadt, op1Constraint, constraint) then
gadts.println(i"GADT CUT - prefer ${ctx.gadt.gadt} over $op1Gadt")
if allSubsumes(op1Gadt, ctx.gadt, op1Constraint, constraint) then
gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $op1Gadt")
constr.println(i"CUT - prefer $constraint over $op1Constraint")
else if allSubsumes(ctx.gadt.gadt, op1Gadt, constraint, op1Constraint) then
gadts.println(i"GADT CUT - prefer $op1Gadt over ${ctx.gadt.gadt}")
else if allSubsumes(ctx.gadt, op1Gadt, constraint, op1Constraint) then
gadts.println(i"GADT CUT - prefer $op1Gadt over ${ctx.gadt}")
constr.println(i"CUT - prefer $op1Constraint over $constraint")
constraint = op1Constraint
ctx.gadt.restore(op1Gadt)
ctx.gadtState.restore(op1Gadt)
else
gadts.println(i"GADT CUT - no constraint is preferable, reverting to $preGadt")
constr.println(i"CUT - no constraint is preferable, reverting to $preConstraint")
constraint = preConstraint
ctx.gadt.restore(preGadt)
ctx.gadtState.restore(preGadt)
else
constraint = op1Constraint
ctx.gadt.restore(op1Gadt)
ctx.gadtState.restore(op1Gadt)
true
else op2
end necessaryEither
Expand Down Expand Up @@ -2054,7 +2054,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}")
if (bound.isRef(tparam)) false
else
ctx.gadt.rollbackGadtUnless(gadtAddBound(tparam, bound, isUpper))
ctx.gadtState.rollbackGadtUnless(gadtAddBound(tparam, bound, isUpper))
}
}

Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -687,8 +687,8 @@ object TypeOps:
val bound1 = massage(bound)
if (bound1 ne bound) {
if (checkCtx eq ctx) checkCtx = ctx.fresh.setFreshGADTBounds
if (!checkCtx.gadt.contains(sym)) checkCtx.gadt.addToConstraint(sym)
checkCtx.gadt.addBound(sym, bound1, fromBelow)
if (!checkCtx.gadt.contains(sym)) checkCtx.gadtState.addToConstraint(sym)
checkCtx.gadtState.addBound(sym, bound1, fromBelow)
typr.println("install GADT bound $bound1 for when checking F-bounded $sym")
}
}
Expand Down Expand Up @@ -872,7 +872,7 @@ object TypeOps:
case tp: TypeRef if tp.symbol.exists && !tp.symbol.isClass => foldOver(tp.symbol :: xs, tp)
case tp => foldOver(xs, tp)
val syms2 = getAbstractSymbols(Nil, tp2).reverse
if syms2.nonEmpty then ctx.gadt.addToConstraint(syms2)
if syms2.nonEmpty then ctx.gadtState.addToConstraint(syms2)

// If parent contains a reference to an abstract type, then we should
// refine subtype checking to eliminate abstract types according to
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,11 @@ class InlineReducer(inliner: Inliner)(using Context):
def addTypeBindings(typeBinds: TypeBindsMap)(using Context): Unit =
typeBinds.foreachBinding { case (sym, shouldBeMinimized) =>
newTypeBinding(sym,
ctx.gadt.approximation(sym, fromBelow = shouldBeMinimized, maxLevel = Int.MaxValue))
ctx.gadtState.approximation(sym, fromBelow = shouldBeMinimized, maxLevel = Int.MaxValue))
}

def registerAsGadtSyms(typeBinds: TypeBindsMap)(using Context): Unit =
if (typeBinds.size > 0) ctx.gadt.addToConstraint(typeBinds.keys)
if (typeBinds.size > 0) ctx.gadtState.addToConstraint(typeBinds.keys)

pat match {
case Typed(pat1, tpt) =>
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
case CaseDef(pat, _, _) =>
val gadtCtx =
pat.removeAttachment(typer.Typer.InferredGadtConstraints) match
case Some(gadt) => ctx.fresh.setGadt(GadtConstraintHandling(gadt))
case Some(gadt) => ctx.fresh.setGadtState(GadtState(gadt))
case None =>
ctx
super.transform(tree)(using gadtCtx)
Expand Down
10 changes: 5 additions & 5 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1030,8 +1030,8 @@ trait Implicits:
case result: SearchSuccess =>
if result.tstate ne ctx.typerState then
result.tstate.commit()
if result.gstate ne ctx.gadt.gadt then
ctx.gadt.restore(result.gstate)
if result.gstate ne ctx.gadt then
ctx.gadtState.restore(result.gstate)
if hasSkolem(false, result.tree) then
report.error(SkolemInInferred(result.tree, pt, argument), ctx.source.atSpan(span))
implicits.println(i"success: $result")
Expand Down Expand Up @@ -1145,7 +1145,7 @@ trait Implicits:
SearchFailure(adapted.withType(new MismatchedImplicit(ref, pt, argument)))
}
else
SearchSuccess(adapted, ref, cand.level, cand.isExtension)(ctx.typerState, ctx.gadt.gadt)
SearchSuccess(adapted, ref, cand.level, cand.isExtension)(ctx.typerState, ctx.gadt)
}

/** An implicit search; parameters as in `inferImplicit` */
Expand Down Expand Up @@ -1343,7 +1343,7 @@ trait Implicits:
case _: SearchFailure =>
SearchSuccess(ref(defn.NotGiven_value), defn.NotGiven_value.termRef, 0)(
ctx.typerState.fresh().setCommittable(true),
ctx.gadt.gadt
ctx.gadt
)
case _: SearchSuccess =>
NoMatchingImplicitsFailure
Expand Down Expand Up @@ -1526,7 +1526,7 @@ trait Implicits:
// other candidates need to be considered.
recursiveRef match
case ref: TermRef =>
SearchSuccess(tpd.ref(ref).withSpan(span.startPos), ref, 0)(ctx.typerState, ctx.gadt.gadt)
SearchSuccess(tpd.ref(ref).withSpan(span.startPos), ref, 0)(ctx.typerState, ctx.gadt)
case _ =>
searchImplicit(contextual = true)
end bestImplicit
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ object Inferencing {
&& ctx.gadt.contains(tp.symbol)
=>
val sym = tp.symbol
val res = ctx.gadt.approximation(sym, fromBelow = variance < 0)
val res = ctx.gadtState.approximation(sym, fromBelow = variance < 0)
gadts.println(i"approximated $tp ~~ $res")
res

Expand Down Expand Up @@ -432,7 +432,7 @@ object Inferencing {
}

// We add the created symbols to GADT constraint here.
if (res.nonEmpty) ctx.gadt.addToConstraint(res)
if (res.nonEmpty) ctx.gadtState.addToConstraint(res)
res
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1864,7 +1864,7 @@ class Namer { typer: Typer =>
// so we must allow constraining its type parameters
// compare with typedDefDef, see tests/pos/gadt-inference.scala
rhsCtx.setFreshGADTBounds
rhsCtx.gadt.addToConstraint(typeParams)
rhsCtx.gadtState.addToConstraint(typeParams)
}

def typedAheadRhs(pt: Type) =
Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1782,7 +1782,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// see tests/pos/i12226 and issue #12226. It might be possible that this
// will end up taking too much memory. If it does, we should just limit
// how much GADT constraints we infer - it's always sound to infer less.
pat1.putAttachment(InferredGadtConstraints, ctx.gadt.gadt)
pat1.putAttachment(InferredGadtConstraints, ctx.gadt)
if (pt1.isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
body1 = body1.ensureConforms(pt1)(using originalCtx)
assignType(cpy.CaseDef(tree)(pat1, guard1, body1), pat1, body1)
Expand Down Expand Up @@ -2362,7 +2362,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
ctx.outer.outersIterator.takeWhile(!_.owner.is(Method))
.filter(ctx => ctx.owner.isClass && ctx.owner.typeParams.nonEmpty)
.toList.reverse
.foreach(ctx => rhsCtx.gadt.addToConstraint(ctx.owner.typeParams))
.foreach(ctx => rhsCtx.gadtState.addToConstraint(ctx.owner.typeParams))

if tparamss.nonEmpty then
rhsCtx.setFreshGADTBounds
Expand All @@ -2371,7 +2371,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// we're typing a polymorphic definition's body,
// so we allow constraining all of its type parameters
// constructors are an exception as we don't allow constraining type params of classes
rhsCtx.gadt.addToConstraint(tparamSyms)
rhsCtx.gadtState.addToConstraint(tparamSyms)
else if !sym.isPrimaryConstructor then
linkConstructorParams(sym, tparamSyms, rhsCtx)

Expand Down Expand Up @@ -3835,7 +3835,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
adaptToSubType(wtp)
case CompareResult.OKwithGADTUsed
if pt.isValueType
&& !inContext(ctx.fresh.setGadt(GadtConstraintHandling(GadtConstraint.empty))) {
&& !inContext(ctx.fresh.setGadtState(GadtState(GadtConstraint.empty))) {
val res = (tree.tpe.widenExpr frozen_<:< pt)
if res then
// we overshot; a cast is not needed, after all.
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3130,7 +3130,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
if typeHoles.isEmpty then ctx
else
val ctx1 = ctx.fresh.setFreshGADTBounds.addMode(dotc.core.Mode.GadtConstraintInference)
ctx1.gadt.addToConstraint(typeHoles)
ctx1.gadtState.addToConstraint(typeHoles)
ctx1

val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1)
Expand Down

0 comments on commit b1a035a

Please sign in to comment.