Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Union types: deduplication + Nothing type members absorption #16312

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
51 changes: 51 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import java.lang.ref.WeakReference
import compiletime.uninitialized
import cc.{CapturingType, CaptureSet, derivedCapturingType, isBoxedCapturing, EventuallyCapturingType, boxedUnlessFun}
import CaptureSet.{CompareResult, IdempotentCaptRefMap, IdentityCaptRefMap}
import scala.collection.mutable.ListBuffer
import dotty.tools.dotc.util._

import scala.annotation.internal.sharable
import scala.annotation.threadUnsafe
Expand Down Expand Up @@ -3482,6 +3484,55 @@ object Types {
case that: OrType => tp1.eq(that.tp1) && tp2.eq(that.tp2) && isSoft == that.isSoft
case _ => false
}

/** Returns the set of non-union (leaf) types composing this union tree.
* For example:<br>
* `(A | B | A | B | (A & (B | C)))` returns `{A, B, (A & (B | C))}`.
*/
private def gatherTreeUniqueMembersAbsorbingNothingTypes(using Context): MutableSet[Type] = {

var unvisitedSubtrees = List(this)
val uniqueTreeMembers = new EqLinkedHashSet[Type]

while (unvisitedSubtrees.nonEmpty) {
unvisitedSubtrees match
case head :: tail =>
head match
case OrType(l: OrType, r: OrType) =>
unvisitedSubtrees = l :: r :: tail
case OrType(l, r: OrType) =>
unvisitedSubtrees = r :: tail
if !l.isNothingType then uniqueTreeMembers += l
case OrType(l: OrType, r) =>
unvisitedSubtrees = l :: tail
if !r.isNothingType then uniqueTreeMembers += r
case OrType(l, r) =>
unvisitedSubtrees = tail
uniqueTreeMembers += l
uniqueTreeMembers += r
case _ =>
}

uniqueTreeMembers
}

/** Returns an equivalent union tree without repeated members. Weaker than LUB.
*/
def deduplicatedAbsorbingNothingTypes(using Context): Type = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two problems with this approach

  • It looks inefficient. You need a fast mutable set based on identity not structural equality. Everything else would be needlessly slow.
  • Determinism. You need a set that has a reproducible order of elements between compiler runs. A LinkedHashSet would do, except that it's based on equality. A java.util.IdentityHashMap or a dotc.util.EqHashMap would be based on identity, but it's not deterministic. So maybe you have to write that set yourself. A combination of a ListBuffer for the elements and a EqHashMap could work, maybe.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

✔️

if tp1.isInstanceOf[OrType] || tp2.isInstanceOf[OrType] then
val uniqueTreeMembers = this.gatherTreeUniqueMembersAbsorbingNothingTypes

uniqueTreeMembers.size match {
case 1 =>
uniqueTreeMembers.iterator.next()
case _ =>
val members = uniqueTreeMembers.iterator
val startingUnion = OrType(members.next(), members.next(), soft = true)
members.foldLeft(startingUnion)(OrType(_, _, soft = true))
}
else if tp1 eq tp2 then tp1
else this
}
}

final class CachedOrType(tp1: Type, tp2: Type, override val isSoft: Boolean) extends OrType(tp1, tp2)
Expand Down
35 changes: 32 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2422,7 +2422,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
completeAnnotations(vdef, sym)
if (sym.isOneOf(GivenOrImplicit)) checkImplicitConversionDefOK(sym)
if sym.is(Module) then checkNoModuleClash(sym)
val tpt1 = checkSimpleKinded(typedType(tpt))
val tpdType = typedType(tpt)
val tpt1: Tree = checkSimpleKinded(tpdType) match {
case inferred: InferredTypeTree =>
// println(i"inferred type = $inferred")
inferred.tpe match {
case or: OrType =>
inferred.overwriteType(or.deduplicatedAbsorbingNothingTypes)
case _ =>
}
inferred
case self => self
}
val rhs1 = vdef.rhs match {
case rhs @ Ident(nme.WILDCARD) => rhs withType tpt1.tpe
case rhs => typedExpr(rhs, tpt1.tpe.widenExpr)
Expand Down Expand Up @@ -3004,17 +3015,24 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedUnadapted(initTree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
record("typedUnadapted")
val xtree = expanded(initTree)
// println("> typing unadapted")
// println(i"initTree = $initTree")
// println(i"pt = $pt")

xtree.removeAttachment(TypedAhead) match {
case Some(ttree) => ttree
case none =>

def typedNamed(tree: untpd.NameTree, pt: Type)(using Context): Tree = {
val sym = retrieveSym(xtree)
// println(i"sym = $sym")
tree match {
case tree: untpd.Ident => typedIdent(tree, pt)
case tree: untpd.Select => typedSelect(tree, pt)
case tree: untpd.Bind => typedBind(tree, pt)
case tree: untpd.ValDef =>
// println("> typing val def")
// println(i"$tree")
if (tree.isEmpty) tpd.EmptyValDef
else typedValDef(tree, sym)(using ctx.localContext(tree, sym))
case tree: untpd.DefDef =>
Expand Down Expand Up @@ -3105,8 +3123,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
then
makeContextualFunction(xtree, ifpt)
else xtree match
case xtree: untpd.NameTree => typedNamed(xtree, pt)
case xtree => typedUnnamed(xtree)
case xtree: untpd.NameTree =>
// println(s"> typing named")
// println(i"xtree = $xtree")
// println(i"pt = $pt")
typedNamed(xtree, pt)
case xtree =>
// println(s"> typing unnamed")
// println(i"xtree = $xtree")
typedUnnamed(xtree)

simplify(result, pt, locked)
catch case ex: TypeError => errorTree(xtree, ex, xtree.srcPos.focus)
Expand Down Expand Up @@ -3169,6 +3194,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

/** Typecheck and adapt tree, returning a typed tree. Parameters as for `typedUnadapted` */
def typed(tree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree =
// println("> typing tree")
trace(i"typing $tree, pt = $pt", typr, show = true) {
record(s"typed $getClass")
record("typed total")
Expand Down Expand Up @@ -3297,6 +3323,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
withoutMode(Mode.PatternOrTypeBits)(typed(tree, pt))

def typedType(tree: untpd.Tree, pt: Type = WildcardType, mapPatternBounds: Boolean = false)(using Context): Tree =
// println("> typing type")
// println(i"tree = $tree")
// println(i"pt = $pt")
val tree1 = withMode(Mode.Type) { typed(tree, pt) }
if mapPatternBounds && ctx.mode.is(Mode.Pattern) && !ctx.isAfterTyper then
tree1 match
Expand Down
34 changes: 34 additions & 0 deletions compiler/src/dotty/tools/dotc/util/EqLinkedHashSet.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package dotty.tools.dotc.util

import scala.collection.mutable.ArrayBuffer

class EqLinkedHashSet[T](
initialCapacity: Int = 8, capacityMultiple: Int = 2
) extends MutableSet[T] {

private val map: MutableMap[T, Unit] = new EqHashMap(initialCapacity, capacityMultiple)
private val linkingArray: ArrayBuffer[T] = new ArrayBuffer(initialCapacity)

override def +=(x: T): Unit =
map.update(x, ())
if map.size != linkingArray.size then linkingArray += x

override def put(x: T): T =
this += x
x

override def -=(x: T): Unit =
map -= x
if map.size != linkingArray.size then linkingArray -= x

override def clear(resetToInitial: Boolean = true): Unit =
map.clear(resetToInitial)
linkingArray.clear()

override def lookup(x: T): T | Null = if map.contains(x) then x else null

override def size: Int = map.size

override def iterator: Iterator[T] = linkingArray.iterator

}
20 changes: 20 additions & 0 deletions tests/printing/i10693.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[[syntax trees at end of typer]] // local/i10693.scala
package <empty> {
final lazy module val i10693$package: i10693$package = new i10693$package()
final module class i10693$package() extends Object() {
this: i10693$package.type =>
def test[A >: Nothing <: Any, B >: Nothing <: Any](a: A, b: B): A | B = a
val v0: String | Int = test[String, Int]("string", 1)
val v1: Int | String = test[Int, String](1, "string")
val v2: String | Int = test[String | Int, Int | String](v0, v1)
val v3: Int | String = test[Int | String, String | Int](v1, v0)
val v4: String | Int =
test[String | Int | (Int | String), Int | String | (String | Int)](v2, v3)
val v5: Int | String =
test[Int | String | (String | Int), String | Int | (Int | String)](v3, v2)
val v6: String | Int =
test[String | Int | (Int | String) | (Int | String | (String | Int)),
Int | String | (String | Int) | (String | Int | (Int | String))](v4, v5)
}
}

8 changes: 8 additions & 0 deletions tests/printing/i10693.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
def test[A, B](a: A, b: B): A | B = a
val v0 = test("string", 1)
val v1 = test(1, "string")
val v2 = test(v0, v1)
val v3 = test(v1, v0)
val v4 = test(v2, v3)
val v5 = test(v3, v2)
val v6 = test(v4, v5)