Skip to content

Commit

Permalink
refactor path aliasing constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
Linyxus committed Sep 14, 2022
1 parent a6f3a3a commit 25aa162
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 43 deletions.
46 changes: 19 additions & 27 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,11 @@ sealed abstract class GadtConstraint extends Showable {
/** Further constrain a path-dependent type already present in the constraint. */
def addBound(p: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean

/** Record the equality between two singleton types. */
def addEquality(p: PathType, q: PathType)(using Context): Unit
/** Record the aliasing relationship between two singleton types. */
def recordPathAliasing(p: PathType, q: PathType)(using Context): Unit

/** Check whether two singleton types are equivalent. */
def isEquivalent(p: PathType, q: PathType): Boolean

/** Query the representative member of a singleton type. */
def reprOf(p: PathType): PathType | Null
/** Check whether two paths are equivalent via path aliasing. */
def isAliasingPath(p: PathType, q: PathType): Boolean

/** Scrutinee path of the current pattern matching. */
def scrutineePath: TermRef | Null
Expand Down Expand Up @@ -124,7 +121,7 @@ final class ProperGadtConstraint private(
private var pathDepReverseMapping: SimpleIdentityMap[TypeParamRef, TypeRef],
private var wasConstrained: Boolean,
private var myScrutineePath: TermRef | Null,
private var myUnionFind: SimpleIdentityMap[PathType, PathType],
private var pathAliasingMapping: SimpleIdentityMap[PathType, PathType],
private var myPatternSkolem: SkolemType | Null,
) extends GadtConstraint with ConstraintHandling {
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}
Expand All @@ -137,7 +134,7 @@ final class ProperGadtConstraint private(
pathDepReverseMapping = SimpleIdentityMap.empty,
wasConstrained = false,
myScrutineePath = null,
myUnionFind = SimpleIdentityMap.empty,
pathAliasingMapping = SimpleIdentityMap.empty,
myPatternSkolem = null,
)

Expand Down Expand Up @@ -392,9 +389,6 @@ final class ProperGadtConstraint private(
buf ++= "}"
buf.result

/** Get the representative member of the path in the union find. */
override def reprOf(p: PathType): PathType | Null = lookupPath(p)

override def addToConstraint(params: List[Symbol])(using Context): Boolean = {
import NameKinds.DepParamName

Expand Down Expand Up @@ -512,15 +506,15 @@ final class ProperGadtConstraint private(
}

private def lookupPath(p: PathType): PathType | Null =
def recur(p: PathType): PathType | Null = myUnionFind(p) match
def recur(p: PathType): PathType | Null = pathAliasingMapping(p) match
case null => null
case q: PathType if q eq p => q
case q: PathType =>
recur(q)

recur(p)

override def addEquality(p: PathType, q: PathType)(using Context): Unit =
override def recordPathAliasing(p: PathType, q: PathType)(using Context): Unit =
val pRep: PathType | Null = lookupPath(p)
val qRep: PathType | Null = lookupPath(q)

Expand All @@ -529,13 +523,13 @@ final class ProperGadtConstraint private(
case (null, r: PathType) => r
case (r: PathType, null) => r
case (r1: PathType, r2: PathType) =>
myUnionFind = myUnionFind.updated(r2, r1)
pathAliasingMapping = pathAliasingMapping.updated(r2, r1)
r1

myUnionFind = myUnionFind.updated(p, newRep)
myUnionFind = myUnionFind.updated(q, newRep)
pathAliasingMapping = pathAliasingMapping.updated(p, newRep)
pathAliasingMapping = pathAliasingMapping.updated(q, newRep)

override def isEquivalent(p: PathType, q: PathType): Boolean =
override def isAliasingPath(p: PathType, q: PathType): Boolean =
lookupPath(p) match
case null => false
case p0: PathType => lookupPath(q) match
Expand Down Expand Up @@ -637,7 +631,7 @@ final class ProperGadtConstraint private(
pathDepReverseMapping,
wasConstrained,
myScrutineePath,
myUnionFind,
pathAliasingMapping,
myPatternSkolem,
)

Expand All @@ -650,7 +644,7 @@ final class ProperGadtConstraint private(
this.pathDepReverseMapping = other.pathDepReverseMapping
this.wasConstrained = other.wasConstrained
this.myScrutineePath = other.myScrutineePath
this.myUnionFind = other.myUnionFind
this.pathAliasingMapping = other.pathAliasingMapping
this.myPatternSkolem = other.myPatternSkolem
case _ => ;
}
Expand All @@ -675,10 +669,10 @@ final class ProperGadtConstraint private(
}

def updateUnionFind() =
myUnionFind(myPatternSkolem.nn) match {
pathAliasingMapping(myPatternSkolem.nn) match {
case null =>
case repr: PathType =>
myUnionFind = myUnionFind.updated(path, repr)
pathAliasingMapping = pathAliasingMapping.updated(path, repr)
}

updateMappings()
Expand Down Expand Up @@ -784,7 +778,7 @@ final class ProperGadtConstraint private(
}
}
sb ++= "\nSingleton equalities:\n"
myUnionFind foreachBinding { case (path, _) =>
pathAliasingMapping foreachBinding { case (path, _) =>
val repr = lookupPath(path)
repr match
case repr: PathType if repr ne path =>
Expand All @@ -805,8 +799,6 @@ final class ProperGadtConstraint private(
override def bounds(tp: TypeRef)(using Context): TypeBounds | Null = null
override def fullBounds(tp: TypeRef)(using Context): TypeBounds | Null = null

override def reprOf(p: PathType): PathType | Null = null

override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess")
override def isLess(tp1: NamedType, tp2: NamedType)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess")

Expand All @@ -827,9 +819,9 @@ final class ProperGadtConstraint private(
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound")
override def addBound(path: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound")

override def addEquality(p: PathType, q: PathType)(using Context) = ()
override def recordPathAliasing(p: PathType, q: PathType)(using Context) = ()

override def isEquivalent(p: PathType, q: PathType) = false
override def isAliasingPath(p: PathType, q: PathType) = false

override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
/** Reconstruct subtype from the cohabitation between the scrutinee and the
pattern. */
def constrainPattern: Boolean = {
ctx.gadt.addEquality(scrutineePath, patternPath)
ctx.gadt.recordPathAliasing(scrutineePath, patternPath)

(!registerPattern || reconstructSubTypeFor(patternPath, scrutineePath))
&& (!registerScrutinee || reconstructSubTypeFor(scrutineePath, patternPath))
Expand All @@ -253,7 +253,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
(!registerPtPath || reconstructSubTypeFor(ptPath, scrutineePath))
&& (!registerScrutinee || reconstructSubTypeFor(scrutineePath, ptPath))

ctx.gadt.addEquality(scrutineePath, ptPath)
ctx.gadt.recordPathAliasing(scrutineePath, ptPath)

result
case _ =>
Expand Down
15 changes: 1 addition & 14 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
def compareSingletonGADT: Boolean =
(tp1, tp2) match {
case (tp1: TermRef, tp2: TermRef) =>
ctx.gadt.isEquivalent(tp1, tp2) && { GADTused = true; true }
ctx.gadt.isAliasingPath(tp1, tp2) && { GADTused = true; true }
case _ => false
}

Expand Down Expand Up @@ -2360,19 +2360,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case Atoms.Range(lo2, hi2) =>
if hi1.subsetOf(lo2) then return tp2
if hi2.subsetOf(lo1) then return tp1

def getReprSet(ps: Set[Type]): Set[Type] =
ps.map { x =>
x match
case p: PathType =>
val rep = ctx.gadt.reprOf(p)
if rep == null then p else rep
case t => t
}
val (repLo1, repHi1, repLo2, repHi2) = (getReprSet(lo1), getReprSet(hi1), getReprSet(lo2), getReprSet(hi2))
if repHi2.subsetOf(repLo1) then return tp1
if repHi1.subsetOf(repLo2) then return tp2

if (hi1 & hi2).isEmpty then return orType(tp1, tp2)
case none =>
case none =>
Expand Down

0 comments on commit 25aa162

Please sign in to comment.