diff --git a/compiler/src/dotty/tools/dotc/core/Constraint.scala b/compiler/src/dotty/tools/dotc/core/Constraint.scala index b849c7aa7093..c634f847e510 100644 --- a/compiler/src/dotty/tools/dotc/core/Constraint.scala +++ b/compiler/src/dotty/tools/dotc/core/Constraint.scala @@ -71,6 +71,9 @@ abstract class Constraint extends Showable { */ def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds + /** The current bounds of type parameter `param` */ + def bounds(param: TypeParamRef)(using Context): TypeBounds + /** A new constraint which is derived from this constraint by adding * entries for all type parameters of `poly`. * @param tvars A list of type variables associated with the params, diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index 6207e0a3d728..9ffe2bda73cb 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -749,16 +749,7 @@ trait ConstraintHandling { } /** The current bounds of type parameter `param` */ - def bounds(param: TypeParamRef)(using Context): TypeBounds = { - val e = constraint.entry(param) - if (e.exists) e.bounds - else { - // TODO: should we change the type of paramInfos to nullable? - val pinfos: List[param.binder.PInfo] | Null = param.binder.paramInfos - if (pinfos != null) pinfos(param.paramNum) // pinfos == null happens in pos/i536.scala - else TypeBounds.empty - } - } + def bounds(param: TypeParamRef)(using Context): TypeBounds = constraint.bounds(param) /** Add type lambda `tl`, possibly with type variables `tvars`, to current constraint * and propagate all bounds. diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 5f82e8c8b6ce..2f28975dd066 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -141,7 +141,8 @@ object Contexts { def tree: Tree[?] def scope: Scope def typerState: TyperState - def gadt: GadtConstraint + def gadt: GadtConstraint = gadtState.gadt + def gadtState: GadtState def searchHistory: SearchHistory def source: SourceFile @@ -410,7 +411,7 @@ object Contexts { val constrCtx = outersIterator.dropWhile(_.outer.owner == owner).next() superOrThisCallContext(owner, constrCtx.scope) .setTyperState(typerState) - .setGadt(gadt) + .setGadtState(gadtState) .fresh .setScope(this.scope) } @@ -541,8 +542,8 @@ object Contexts { private var _typerState: TyperState = uninitialized final def typerState: TyperState = _typerState - private var _gadt: GadtConstraint = uninitialized - final def gadt: GadtConstraint = _gadt + private var _gadtState: GadtState = uninitialized + final def gadtState: GadtState = _gadtState private var _searchHistory: SearchHistory = uninitialized final def searchHistory: SearchHistory = _searchHistory @@ -567,7 +568,7 @@ object Contexts { _owner = origin.owner _tree = origin.tree _scope = origin.scope - _gadt = origin.gadt + _gadtState = origin.gadtState _searchHistory = origin.searchHistory _source = origin.source _moreProperties = origin.moreProperties @@ -624,12 +625,12 @@ object Contexts { this._scope = typer.scope setTypeAssigner(typer) - def setGadt(gadt: GadtConstraint): this.type = - util.Stats.record("Context.setGadt") - this._gadt = gadt + def setGadtState(gadtState: GadtState): this.type = + util.Stats.record("Context.setGadtState") + this._gadtState = gadtState this def setFreshGADTBounds: this.type = - setGadt(gadt.fresh) + setGadtState(gadtState.fresh) def setSearchHistory(searchHistory: SearchHistory): this.type = util.Stats.record("Context.setSearchHistory") @@ -721,7 +722,7 @@ object Contexts { .updated(notNullInfosLoc, Nil) .updated(compilationUnitLoc, NoCompilationUnit) c._searchHistory = new SearchRoot - c._gadt = GadtConstraint.empty + c._gadtState = GadtState(GadtConstraint.empty) c end FreshContext diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 7515898a36df..a863a982a44d 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -2,41 +2,146 @@ package dotty.tools package dotc package core -import Decorators._ -import Contexts._ -import Types._ -import Symbols._ +import Contexts.*, Decorators.*, Symbols.*, Types.* +import config.Printers.{gadts, gadtsConstr} import util.{SimpleIdentitySet, SimpleIdentityMap} -import collection.mutable import printing._ +import scala.annotation.tailrec +import scala.annotation.internal.sharable +import scala.collection.mutable + object GadtConstraint: - def apply(): GadtConstraint = empty - def empty: GadtConstraint = - new ProperGadtConstraint(OrderingConstraint.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, false) + @sharable val empty: GadtConstraint = + GadtConstraint(OrderingConstraint.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, false) /** Represents GADT constraints currently in scope */ -sealed trait GadtConstraint ( - private var myConstraint: Constraint, - private var mapping: SimpleIdentityMap[Symbol, TypeVar], - private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], - private var wasConstrained: Boolean -) extends Showable { - this: ConstraintHandling => - - import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} - - private[core] def getConstraint: Constraint = constraint - private[core] def getMapping: SimpleIdentityMap[Symbol, TypeVar] = mapping - private[core] def getReverseMapping: SimpleIdentityMap[TypeParamRef, Symbol] = reverseMapping - private[core] def getWasConstrained: Boolean = wasConstrained - - /** Exposes ConstraintHandling.subsumes */ - def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = { - def extractConstraint(g: GadtConstraint) = g.constraint - subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) +class GadtConstraint private ( + private val myConstraint: Constraint, + private val mapping: SimpleIdentityMap[Symbol, TypeVar], + private val reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], + private val wasConstrained: Boolean, +) extends Showable: + def constraint: Constraint = myConstraint + def symbols: List[Symbol] = mapping.keys + def withConstraint(c: Constraint) = copy(myConstraint = c) + def withWasConstrained = copy(wasConstrained = true) + + def add(sym: Symbol, tv: TypeVar): GadtConstraint = copy( + mapping = mapping.updated(sym, tv), + reverseMapping = reverseMapping.updated(tv.origin, sym), + ) + + /** Is `sym1` ordered to be less than `sym2`? */ + def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = + constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) + + /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. + * + * @note this performs subtype checks between ordered symbols. + * Using this in isSubType can lead to infinite recursion. Consider `bounds` instead. + */ + def fullBounds(sym: Symbol)(using Context): TypeBounds | Null = mapping(sym) match + case null => null + case tv: TypeVar => fullBounds(tv.origin) // .ensuring(containsNoInternalTypes(_)) + + /** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */ + def bounds(sym: Symbol)(using Context): TypeBounds | Null = + mapping(sym) match + case null => null + case tv: TypeVar => + def retrieveBounds: TypeBounds = externalize(constraint.bounds(tv.origin)).bounds + retrieveBounds + //.showing(i"gadt bounds $sym: $result", gadts) + //.ensuring(containsNoInternalTypes(_)) + + /** Is the symbol registered in the constraint? + * + * @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]]. + */ + def contains(sym: Symbol)(using Context): Boolean = mapping(sym) != null + + /** GADT constraint narrows bounds of at least one variable */ + def isNarrowing: Boolean = wasConstrained + + def fullBounds(param: TypeParamRef)(using Context): TypeBounds = + nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param)) + + def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds = + externalize(constraint.nonParamBounds(param)).bounds + + def fullLowerBound(param: TypeParamRef)(using Context): Type = + constraint.minLower(param).foldLeft(nonParamBounds(param).lo) { + (t, u) => t | externalize(u) + } + + def fullUpperBound(param: TypeParamRef)(using Context): Type = + constraint.minUpper(param).foldLeft(nonParamBounds(param).hi) { (t, u) => + val eu = externalize(u) + // Any as the upper bound means "no bound", but if F is higher-kinded, + // Any & F = F[_]; this is wrong for us so we need to short-circuit + if t.isAny then eu else t & eu + } + + def externalize(tp: Type, theMap: TypeMap | Null = null)(using Context): Type = tp match + case param: TypeParamRef => reverseMapping(param) match + case sym: Symbol => sym.typeRef + case null => param + case tp: TypeAlias => tp.derivedAlias(externalize(tp.alias, theMap)) + case tp => (if theMap == null then ExternalizeMap() else theMap).mapOver(tp) + + private class ExternalizeMap(using Context) extends TypeMap: + def apply(tp: Type): Type = externalize(tp, this)(using mapCtx) + + def tvarOrError(sym: Symbol)(using Context): TypeVar = + mapping(sym).ensuring(_ != null, i"not a constrainable symbol: $sym").uncheckedNN + + @tailrec final def stripInternalTypeVar(tp: Type): Type = tp match + case tv: TypeVar => + val inst = constraint.instType(tv) + if inst.exists then stripInternalTypeVar(inst) else tv + case _ => tp + + def internalize(tp: Type)(using Context): Type = tp match + case nt: NamedType => + val ntTvar = mapping(nt.symbol) + if ntTvar == null then tp + else ntTvar + case _ => tp + + private def containsNoInternalTypes(tp: Type, theAcc: TypeAccumulator[Boolean] | Null = null)(using Context): Boolean = tp match { + case tpr: TypeParamRef => !reverseMapping.contains(tpr) + case tv: TypeVar => !reverseMapping.contains(tv.origin) + case tp => + (if (theAcc != null) theAcc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp) + } + + private class ContainsNoInternalTypesAccumulator(using Context) extends TypeAccumulator[Boolean] { + override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp, this) } + override def toText(printer: Printer): Texts.Text = printer.toText(this) + + /** Provides more information than toText, by showing the underlying Constraint details. */ + def debugBoundsDescription(using Context): String = i"$this\n$constraint" + + private def copy( + myConstraint: Constraint = myConstraint, + mapping: SimpleIdentityMap[Symbol, TypeVar] = mapping, + reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol] = reverseMapping, + wasConstrained: Boolean = wasConstrained, + ): GadtConstraint = GadtConstraint(myConstraint, mapping, reverseMapping, wasConstrained) +end GadtConstraint + +object GadtState: + def apply(gadt: GadtConstraint): GadtState = ProperGadtState(gadt) + +sealed trait GadtState { + this: ConstraintHandling => // Hide ConstraintHandling within GadtConstraintHandling + + def gadt: GadtConstraint + def gadt_=(g: GadtConstraint): Unit + override protected def legalBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Type = // GADT constraints never involve wildcards and are not propagated outside // the case where they're valid, so no approximating is needed. @@ -57,22 +162,19 @@ sealed trait GadtConstraint ( // and used as orderings. def substDependentSyms(tp: Type, isUpper: Boolean)(using Context): Type = { def loop(tp: Type) = substDependentSyms(tp, isUpper) - tp match { + tp match case tp @ AndType(tp1, tp2) if !isUpper => tp.derivedAndType(loop(tp1), loop(tp2)) case tp @ OrType(tp1, tp2) if isUpper => tp.derivedOrType(loop(tp1), loop(tp2)) case tp: NamedType => - params.indexOf(tp.symbol) match { + params.indexOf(tp.symbol) match case -1 => - mapping(tp.symbol) match { + gadt.internalize(tp) match case tv: TypeVar => tv.origin - case null => tp - } + case _ => tp case i => pt.paramRefs(i) - } case tp => tp - } } val tb = param.info.bounds @@ -86,205 +188,87 @@ sealed trait GadtConstraint ( val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) => val tv = TypeVar(paramRef, creatorState = null) - mapping = mapping.updated(sym, tv) - reverseMapping = reverseMapping.updated(tv.origin, sym) + gadt = gadt.add(sym, tv) tv } // The replaced symbols are picked up here. addToConstraint(poly1, tvars) - .showing(i"added to constraint: [$poly1] $params%, % gadt = $this", gadts) + .showing(i"added to constraint: [$poly1] $params%, % gadt = $gadt", gadts) } /** Further constrain a symbol already present in the constraint. */ def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { - @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { - case tv: TypeVar => - val inst = constraint.instType(tv) - if (inst.exists) stripInternalTypeVar(inst) else tv - case _ => tp - } - - val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(sym)) match { + val symTvar: TypeVar = gadt.stripInternalTypeVar(gadt.tvarOrError(sym)) match case tv: TypeVar => tv case inst => gadts.println(i"instantiated: $sym -> $inst") - return if (isUpper) isSub(inst, bound) else isSub(bound, inst) - } + return if isUpper then isSub(inst, bound) else isSub(bound, inst) - val internalizedBound = bound match { - case nt: NamedType => - val ntTvar = mapping(nt.symbol) - if (ntTvar != null) stripInternalTypeVar(ntTvar) else bound - case _ => bound - } + val internalizedBound = gadt.stripInternalTypeVar(gadt.internalize(bound)) val saved = constraint val result = internalizedBound match case boundTvar: TypeVar => - if (boundTvar eq symTvar) true - else if (isUpper) addLess(symTvar.origin, boundTvar.origin) + if boundTvar eq symTvar then true + else if isUpper + then addLess(symTvar.origin, boundTvar.origin) else addLess(boundTvar.origin, symTvar.origin) case bound => addBoundTransitively(symTvar.origin, bound, isUpper) gadts.println { - val descr = if (isUpper) "upper" else "lower" - val op = if (isUpper) "<:" else ">:" + val descr = if isUpper then "upper" else "lower" + val op = if isUpper then "<:" else ">:" i"adding $descr bound $sym $op $bound = $result" } - if constraint ne saved then wasConstrained = true + if constraint ne saved then gadt = gadt.withWasConstrained result } - /** Is `sym1` ordered to be less than `sym2`? */ - def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = - constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) - - /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. - * - * @note this performs subtype checks between ordered symbols. - * Using this in isSubType can lead to infinite recursion. Consider `bounds` instead. - */ - def fullBounds(sym: Symbol)(using Context): TypeBounds | Null = - mapping(sym) match { - case null => null - // TODO: Improve flow typing so that ascription becomes redundant - case tv: TypeVar => - fullBounds(tv.origin) - // .ensuring(containsNoInternalTypes(_)) - } - - /** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */ - def bounds(sym: Symbol)(using Context): TypeBounds | Null = - mapping(sym) match { - case null => null - // TODO: Improve flow typing so that ascription becomes redundant - case tv: TypeVar => - def retrieveBounds: TypeBounds = externalize(bounds(tv.origin)).bounds - retrieveBounds - //.showing(i"gadt bounds $sym: $result", gadts) - //.ensuring(containsNoInternalTypes(_)) - } - - /** Is the symbol registered in the constraint? - * - * @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]]. - */ - def contains(sym: Symbol)(using Context): Boolean = mapping(sym) != null - - /** GADT constraint narrows bounds of at least one variable */ - def isNarrowing: Boolean = wasConstrained - /** See [[ConstraintHandling.approximation]] */ def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type = { - val res = - approximation(tvarOrError(sym).origin, fromBelow, maxLevel) match - case tpr: TypeParamRef => - // Here we do externalization when the returned type is a TypeParamRef, - // b/c ConstraintHandling.approximation may return internal types when - // the type variable is instantiated. See #15531. - externalize(tpr) - case tp => tp - - gadts.println(i"approximating $sym ~> $res") - res + approximation(gadt.tvarOrError(sym).origin, fromBelow, maxLevel).match + case tpr: TypeParamRef => + // Here we do externalization when the returned type is a TypeParamRef, + // b/c ConstraintHandling.approximation may return internal types when + // the type variable is instantiated. See #15531. + gadt.externalize(tpr) + case tp => tp + .showing(i"approximating $sym ~> $result", gadts) } - def symbols: List[Symbol] = mapping.keys + def fresh: GadtState = GadtState(gadt) - def fresh: GadtConstraint = new ProperGadtConstraint(myConstraint, mapping, reverseMapping, wasConstrained) - - /** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */ - def restore(other: GadtConstraint): Unit = - this.myConstraint = other.myConstraint - this.mapping = other.mapping - this.reverseMapping = other.reverseMapping - this.wasConstrained = other.wasConstrained - - def restore(constr: Constraint, mapping: SimpleIdentityMap[Symbol, TypeVar], revMapping: SimpleIdentityMap[TypeParamRef, Symbol], wasConstrained: Boolean): Unit = - this.myConstraint = constr - this.mapping = mapping - this.reverseMapping = revMapping - this.wasConstrained = wasConstrained + /** Restore the GadtConstraint state. */ + def restore(gadt: GadtConstraint): Unit = this.gadt = gadt inline def rollbackGadtUnless(inline op: Boolean): Boolean = - val savedConstr = myConstraint - val savedMapping = mapping - val savedReverseMapping = reverseMapping - val savedWasConstrained = wasConstrained + val saved = gadt var result = false - try - result = op - finally - if !result then - restore(savedConstr, savedMapping, savedReverseMapping, savedWasConstrained) + try result = op + finally if !result then restore(saved) result // ---- Protected/internal ----------------------------------------------- - override protected def constraint = myConstraint - override protected def constraint_=(c: Constraint) = myConstraint = c + override protected def constraint = gadt.constraint + override protected def constraint_=(c: Constraint) = gadt = gadt.withConstraint(c) override protected def isSub(tp1: Type, tp2: Type)(using Context): Boolean = TypeComparer.isSubType(tp1, tp2) override protected def isSame(tp1: Type, tp2: Type)(using Context): Boolean = TypeComparer.isSameType(tp1, tp2) - override def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds = - externalize(constraint.nonParamBounds(param)).bounds - - override def fullLowerBound(param: TypeParamRef)(using Context): Type = - constraint.minLower(param).foldLeft(nonParamBounds(param).lo) { - (t, u) => t | externalize(u) - } - - override def fullUpperBound(param: TypeParamRef)(using Context): Type = - constraint.minUpper(param).foldLeft(nonParamBounds(param).hi) { (t, u) => - val eu = externalize(u) - // Any as the upper bound means "no bound", but if F is higher-kinded, - // Any & F = F[_]; this is wrong for us so we need to short-circuit - if t.isAny then eu else t & eu - } - - // ---- Private ---------------------------------------------------------- - - private def externalize(tp: Type, theMap: TypeMap | Null = null)(using Context): Type = tp match - case param: TypeParamRef => reverseMapping(param) match - case sym: Symbol => sym.typeRef - case null => param - case tp: TypeAlias => tp.derivedAlias(externalize(tp.alias, theMap)) - case tp => (if theMap == null then ExternalizeMap() else theMap).mapOver(tp) - - private class ExternalizeMap(using Context) extends TypeMap: - def apply(tp: Type): Type = externalize(tp, this)(using mapCtx) - - private def tvarOrError(sym: Symbol)(using Context): TypeVar = - mapping(sym).ensuring(_ != null, i"not a constrainable symbol: $sym").uncheckedNN - - private def containsNoInternalTypes(tp: Type, theAcc: TypeAccumulator[Boolean] | Null = null)(using Context): Boolean = tp match { - case tpr: TypeParamRef => !reverseMapping.contains(tpr) - case tv: TypeVar => !reverseMapping.contains(tv.origin) - case tp => - (if (theAcc != null) theAcc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp) - } - - private class ContainsNoInternalTypesAccumulator(using Context) extends TypeAccumulator[Boolean] { - override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp, this) - } + override def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds = gadt.nonParamBounds(param) + override def fullLowerBound(param: TypeParamRef)(using Context): Type = gadt.fullLowerBound(param) + override def fullUpperBound(param: TypeParamRef)(using Context): Type = gadt.fullUpperBound(param) // ---- Debug ------------------------------------------------------------ override def constr = gadtsConstr - - override def toText(printer: Printer): Texts.Text = printer.toText(this) - - /** Provides more information than toText, by showing the underlying Constraint details. */ - def debugBoundsDescription(using Context): String = i"$this\n$constraint" } -private class ProperGadtConstraint ( - myConstraint: Constraint, - mapping: SimpleIdentityMap[Symbol, TypeVar], - reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], - wasConstrained: Boolean, -) extends ConstraintHandling with GadtConstraint(myConstraint, mapping, reverseMapping, wasConstrained) +// Hide ConstraintHandling within GadtState +private class ProperGadtState(private var myGadt: GadtConstraint) extends ConstraintHandling with GadtState: + def gadt: GadtConstraint = myGadt + def gadt_=(gadt: GadtConstraint): Unit = myGadt = gadt diff --git a/compiler/src/dotty/tools/dotc/core/NamerOps.scala b/compiler/src/dotty/tools/dotc/core/NamerOps.scala index 66912537dbce..db6f72590818 100644 --- a/compiler/src/dotty/tools/dotc/core/NamerOps.scala +++ b/compiler/src/dotty/tools/dotc/core/NamerOps.scala @@ -212,11 +212,11 @@ object NamerOps: * by (ab?)-using GADT constraints. See pos/i941.scala. */ def linkConstructorParams(sym: Symbol, tparams: List[Symbol], rhsCtx: Context)(using Context): Unit = - rhsCtx.gadt.addToConstraint(tparams) + rhsCtx.gadtState.addToConstraint(tparams) tparams.lazyZip(sym.owner.typeParams).foreach { (psym, tparam) => val tr = tparam.typeRef - rhsCtx.gadt.addBound(psym, tr, isUpper = false) - rhsCtx.gadt.addBound(psym, tr, isUpper = true) + rhsCtx.gadtState.addBound(psym, tr, isUpper = false) + rhsCtx.gadtState.addBound(psym, tr, isUpper = true) } end NamerOps diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 212b70336f4b..faea30390d2b 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -224,6 +224,17 @@ class OrderingConstraint(private val boundsMap: ParamBounds, def exclusiveUpper(param: TypeParamRef, butNot: TypeParamRef): List[TypeParamRef] = upper(param).filterNot(isLess(butNot, _)) + def bounds(param: TypeParamRef)(using Context): TypeBounds = { + val e = entry(param) + if (e.exists) e.bounds + else { + // TODO: should we change the type of paramInfos to nullable? + val pinfos: List[param.binder.PInfo] | Null = param.binder.paramInfos + if (pinfos != null) pinfos(param.paramNum) // pinfos == null happens in pos/i536.scala + else TypeBounds.empty + } + } + // ---------- Info related to TypeParamRefs ------------------------------------------- def isLess(param1: TypeParamRef, param2: TypeParamRef): Boolean = diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index e7f54d088c09..5e8a960608e6 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -266,7 +266,7 @@ trait PatternTypeConstrainer { self: TypeComparer => case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) => val saved = state.nn.constraint val result = - ctx.gadt.rollbackGadtUnless { + ctx.gadtState.rollbackGadtUnless { tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) => val variance = param.paramVarianceSign if variance == 0 || assumeInvariantRefinement || diff --git a/compiler/src/dotty/tools/dotc/core/Symbols.scala b/compiler/src/dotty/tools/dotc/core/Symbols.scala index d14be2b0dfb9..aa3ae0c3c513 100644 --- a/compiler/src/dotty/tools/dotc/core/Symbols.scala +++ b/compiler/src/dotty/tools/dotc/core/Symbols.scala @@ -686,7 +686,7 @@ object Symbols { addToGadt: Boolean = true, flags: FlagSet = EmptyFlags)(using Context): Symbol = { val sym = newSymbol(ctx.owner, name, Case | flags, info, coord = span) - if (addToGadt && name.isTypeName) ctx.gadt.addToConstraint(sym) + if (addToGadt && name.isTypeName) ctx.gadtState.addToConstraint(sym) sym } diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index cf2507aa1724..2a0072590550 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -116,7 +116,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling private def isBottom(tp: Type) = tp.widen.isRef(NothingClass) protected def gadtBounds(sym: Symbol)(using Context) = ctx.gadt.bounds(sym) - protected def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean = ctx.gadt.addBound(sym, b, isUpper) + protected def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean = ctx.gadtState.addBound(sym, b, isUpper) protected def typeVarInstance(tvar: TypeVar)(using Context): Type = tvar.underlying @@ -1445,14 +1445,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling if tp2 eq NoType then false else if tp1 eq tp2 then true else - val saved = constraint - val savedGadtConstr = ctx.gadt.getConstraint - val savedMapping = ctx.gadt.getMapping - val savedReverseMapping = ctx.gadt.getReverseMapping - val savedWasConstrained = ctx.gadt.getWasConstrained + val savedCstr = constraint + val savedGadt = ctx.gadt inline def restore() = - state.constraint = saved - ctx.gadt.restore(savedGadtConstr, savedMapping, savedReverseMapping, savedWasConstrained) + state.constraint = savedCstr + ctx.gadtState.restore(savedGadt) val savedSuccessCount = successCount try recCount += 1 @@ -1858,16 +1855,17 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling */ private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean = val preConstraint = constraint - val preGadt = ctx.gadt.fresh + val preGadt = ctx.gadt def allSubsumes(leftGadt: GadtConstraint, rightGadt: GadtConstraint, left: Constraint, right: Constraint): Boolean = - subsumes(left, right, preConstraint) && preGadt.subsumes(leftGadt, rightGadt, preGadt) + subsumes(left, right, preConstraint) + && subsumes(leftGadt.constraint, rightGadt.constraint, preGadt.constraint) if op1 then val op1Constraint = constraint - val op1Gadt = ctx.gadt.fresh + val op1Gadt = ctx.gadt constraint = preConstraint - ctx.gadt.restore(preGadt) + ctx.gadtState.restore(preGadt) if op2 then if allSubsumes(op1Gadt, ctx.gadt, op1Constraint, constraint) then gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $op1Gadt") @@ -1876,15 +1874,15 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling gadts.println(i"GADT CUT - prefer $op1Gadt over ${ctx.gadt}") constr.println(i"CUT - prefer $op1Constraint over $constraint") constraint = op1Constraint - ctx.gadt.restore(op1Gadt) + ctx.gadtState.restore(op1Gadt) else gadts.println(i"GADT CUT - no constraint is preferable, reverting to $preGadt") constr.println(i"CUT - no constraint is preferable, reverting to $preConstraint") constraint = preConstraint - ctx.gadt.restore(preGadt) + ctx.gadtState.restore(preGadt) else constraint = op1Constraint - ctx.gadt.restore(op1Gadt) + ctx.gadtState.restore(op1Gadt) true else op2 end necessaryEither @@ -2056,7 +2054,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") if (bound.isRef(tparam)) false else - ctx.gadt.rollbackGadtUnless(gadtAddBound(tparam, bound, isUpper)) + ctx.gadtState.rollbackGadtUnless(gadtAddBound(tparam, bound, isUpper)) } } diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index ea8dcee5fca5..d9da11c561e8 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -687,8 +687,8 @@ object TypeOps: val bound1 = massage(bound) if (bound1 ne bound) { if (checkCtx eq ctx) checkCtx = ctx.fresh.setFreshGADTBounds - if (!checkCtx.gadt.contains(sym)) checkCtx.gadt.addToConstraint(sym) - checkCtx.gadt.addBound(sym, bound1, fromBelow) + if (!checkCtx.gadt.contains(sym)) checkCtx.gadtState.addToConstraint(sym) + checkCtx.gadtState.addBound(sym, bound1, fromBelow) typr.println("install GADT bound $bound1 for when checking F-bounded $sym") } } @@ -872,7 +872,7 @@ object TypeOps: case tp: TypeRef if tp.symbol.exists && !tp.symbol.isClass => foldOver(tp.symbol :: xs, tp) case tp => foldOver(xs, tp) val syms2 = getAbstractSymbols(Nil, tp2).reverse - if syms2.nonEmpty then ctx.gadt.addToConstraint(syms2) + if syms2.nonEmpty then ctx.gadtState.addToConstraint(syms2) // If parent contains a reference to an abstract type, then we should // refine subtype checking to eliminate abstract types according to diff --git a/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala b/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala index 42e86b71eff8..e1b2aaa02866 100644 --- a/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala +++ b/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala @@ -311,11 +311,11 @@ class InlineReducer(inliner: Inliner)(using Context): def addTypeBindings(typeBinds: TypeBindsMap)(using Context): Unit = typeBinds.foreachBinding { case (sym, shouldBeMinimized) => newTypeBinding(sym, - ctx.gadt.approximation(sym, fromBelow = shouldBeMinimized, maxLevel = Int.MaxValue)) + ctx.gadtState.approximation(sym, fromBelow = shouldBeMinimized, maxLevel = Int.MaxValue)) } def registerAsGadtSyms(typeBinds: TypeBindsMap)(using Context): Unit = - if (typeBinds.size > 0) ctx.gadt.addToConstraint(typeBinds.keys) + if (typeBinds.size > 0) ctx.gadtState.addToConstraint(typeBinds.keys) pat match { case Typed(pat1, tpt) => diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 5abb32b15d57..2039a8f19558 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -269,7 +269,7 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase case CaseDef(pat, _, _) => val gadtCtx = pat.removeAttachment(typer.Typer.InferredGadtConstraints) match - case Some(gadt) => ctx.fresh.setGadt(gadt) + case Some(gadt) => ctx.fresh.setGadtState(GadtState(gadt)) case None => ctx super.transform(tree)(using gadtCtx) diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index 3e0e7dd5879d..03d3011b4bcd 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -1031,7 +1031,7 @@ trait Implicits: if result.tstate ne ctx.typerState then result.tstate.commit() if result.gstate ne ctx.gadt then - ctx.gadt.restore(result.gstate) + ctx.gadtState.restore(result.gstate) if hasSkolem(false, result.tree) then report.error(SkolemInInferred(result.tree, pt, argument), ctx.source.atSpan(span)) implicits.println(i"success: $result") diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index 2aef3433228b..3442207653d4 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -262,7 +262,7 @@ object Inferencing { && ctx.gadt.contains(tp.symbol) => val sym = tp.symbol - val res = ctx.gadt.approximation(sym, fromBelow = variance < 0) + val res = ctx.gadtState.approximation(sym, fromBelow = variance < 0) gadts.println(i"approximated $tp ~~ $res") res @@ -432,7 +432,7 @@ object Inferencing { } // We add the created symbols to GADT constraint here. - if (res.nonEmpty) ctx.gadt.addToConstraint(res) + if (res.nonEmpty) ctx.gadtState.addToConstraint(res) res } diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 6cdd0150518b..6f85efb0fc8a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1864,7 +1864,7 @@ class Namer { typer: Typer => // so we must allow constraining its type parameters // compare with typedDefDef, see tests/pos/gadt-inference.scala rhsCtx.setFreshGADTBounds - rhsCtx.gadt.addToConstraint(typeParams) + rhsCtx.gadtState.addToConstraint(typeParams) } def typedAheadRhs(pt: Type) = diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 1a24a94e527e..eb09d30e60f3 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2362,7 +2362,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer ctx.outer.outersIterator.takeWhile(!_.owner.is(Method)) .filter(ctx => ctx.owner.isClass && ctx.owner.typeParams.nonEmpty) .toList.reverse - .foreach(ctx => rhsCtx.gadt.addToConstraint(ctx.owner.typeParams)) + .foreach(ctx => rhsCtx.gadtState.addToConstraint(ctx.owner.typeParams)) if tparamss.nonEmpty then rhsCtx.setFreshGADTBounds @@ -2371,7 +2371,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer // we're typing a polymorphic definition's body, // so we allow constraining all of its type parameters // constructors are an exception as we don't allow constraining type params of classes - rhsCtx.gadt.addToConstraint(tparamSyms) + rhsCtx.gadtState.addToConstraint(tparamSyms) else if !sym.isPrimaryConstructor then linkConstructorParams(sym, tparamSyms, rhsCtx) @@ -3835,7 +3835,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer adaptToSubType(wtp) case CompareResult.OKwithGADTUsed if pt.isValueType - && !inContext(ctx.fresh.setGadt(GadtConstraint.empty)) { + && !inContext(ctx.fresh.setGadtState(GadtState(GadtConstraint.empty))) { val res = (tree.tpe.widenExpr frozen_<:< pt) if res then // we overshot; a cast is not needed, after all. diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index dd6471a882bd..4d08e0582d1d 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -3130,7 +3130,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler if typeHoles.isEmpty then ctx else val ctx1 = ctx.fresh.setFreshGADTBounds.addMode(dotc.core.Mode.GadtConstraintInference) - ctx1.gadt.addToConstraint(typeHoles) + ctx1.gadtState.addToConstraint(typeHoles) ctx1 val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1)