Skip to content
Draft
51 changes: 51 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import java.lang.ref.WeakReference
import compiletime.uninitialized
import cc.{CapturingType, CaptureSet, derivedCapturingType, isBoxedCapturing, EventuallyCapturingType, boxedUnlessFun}
import CaptureSet.{CompareResult, IdempotentCaptRefMap, IdentityCaptRefMap}
import scala.collection.mutable.ListBuffer
import dotty.tools.dotc.util._

import scala.annotation.internal.sharable
import scala.annotation.threadUnsafe
Expand Down Expand Up @@ -3482,6 +3484,55 @@ object Types {
case that: OrType => tp1.eq(that.tp1) && tp2.eq(that.tp2) && isSoft == that.isSoft
case _ => false
}

/** Returns the set of non-union (leaf) types composing this union tree.
* For example:<br>
* `(A | B | A | B | (A & (B | C)))` returns `{A, B, (A & (B | C))}`.
*/
private def gatherTreeUniqueMembersAbsorbingNothingTypes(using Context): MutableSet[Type] = {

var unvisitedSubtrees = List(this)
val uniqueTreeMembers = new EqLinkedHashSet[Type]

while (unvisitedSubtrees.nonEmpty) {
unvisitedSubtrees match
case head :: tail =>
head match
case OrType(l: OrType, r: OrType) =>
unvisitedSubtrees = l :: r :: tail
case OrType(l, r: OrType) =>
unvisitedSubtrees = r :: tail
if !l.isNothingType then uniqueTreeMembers += l
case OrType(l: OrType, r) =>
unvisitedSubtrees = l :: tail
if !r.isNothingType then uniqueTreeMembers += r
case OrType(l, r) =>
unvisitedSubtrees = tail
uniqueTreeMembers += l
uniqueTreeMembers += r
case _ =>
}

uniqueTreeMembers
}

/** Returns an equivalent union tree without repeated members. Weaker than LUB.
*/
def deduplicatedAbsorbingNothingTypes(using Context): Type = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two problems with this approach

  • It looks inefficient. You need a fast mutable set based on identity not structural equality. Everything else would be needlessly slow.
  • Determinism. You need a set that has a reproducible order of elements between compiler runs. A LinkedHashSet would do, except that it's based on equality. A java.util.IdentityHashMap or a dotc.util.EqHashMap would be based on identity, but it's not deterministic. So maybe you have to write that set yourself. A combination of a ListBuffer for the elements and a EqHashMap could work, maybe.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

✔️

if tp1.isInstanceOf[OrType] || tp2.isInstanceOf[OrType] then
val uniqueTreeMembers = this.gatherTreeUniqueMembersAbsorbingNothingTypes

uniqueTreeMembers.size match {
case 1 =>
uniqueTreeMembers.iterator.next()
case _ =>
val members = uniqueTreeMembers.iterator
val startingUnion = OrType(members.next(), members.next(), soft = true)
members.foldLeft(startingUnion)(OrType(_, _, soft = true))
}
else if tp1 eq tp2 then tp1
else this
}
}

final class CachedOrType(tp1: Type, tp2: Type, override val isSoft: Boolean) extends OrType(tp1, tp2)
Expand Down
35 changes: 32 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2422,7 +2422,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
completeAnnotations(vdef, sym)
if (sym.isOneOf(GivenOrImplicit)) checkImplicitConversionDefOK(sym)
if sym.is(Module) then checkNoModuleClash(sym)
val tpt1 = checkSimpleKinded(typedType(tpt))
val tpdType = typedType(tpt)
val tpt1: Tree = checkSimpleKinded(tpdType) match {
case inferred: InferredTypeTree =>
// println(i"inferred type = $inferred")
inferred.tpe match {
case or: OrType =>
inferred.overwriteType(or.deduplicatedAbsorbingNothingTypes)
case _ =>
}
inferred
case self => self
}
val rhs1 = vdef.rhs match {
case rhs @ Ident(nme.WILDCARD) => rhs withType tpt1.tpe
case rhs => typedExpr(rhs, tpt1.tpe.widenExpr)
Expand Down Expand Up @@ -3004,17 +3015,24 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedUnadapted(initTree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
record("typedUnadapted")
val xtree = expanded(initTree)
// println("> typing unadapted")
// println(i"initTree = $initTree")
// println(i"pt = $pt")

xtree.removeAttachment(TypedAhead) match {
case Some(ttree) => ttree
case none =>

def typedNamed(tree: untpd.NameTree, pt: Type)(using Context): Tree = {
val sym = retrieveSym(xtree)
// println(i"sym = $sym")
tree match {
case tree: untpd.Ident => typedIdent(tree, pt)
case tree: untpd.Select => typedSelect(tree, pt)
case tree: untpd.Bind => typedBind(tree, pt)
case tree: untpd.ValDef =>
// println("> typing val def")
// println(i"$tree")
if (tree.isEmpty) tpd.EmptyValDef
else typedValDef(tree, sym)(using ctx.localContext(tree, sym))
case tree: untpd.DefDef =>
Expand Down Expand Up @@ -3105,8 +3123,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
then
makeContextualFunction(xtree, ifpt)
else xtree match
case xtree: untpd.NameTree => typedNamed(xtree, pt)
case xtree => typedUnnamed(xtree)
case xtree: untpd.NameTree =>
// println(s"> typing named")
// println(i"xtree = $xtree")
// println(i"pt = $pt")
typedNamed(xtree, pt)
case xtree =>
// println(s"> typing unnamed")
// println(i"xtree = $xtree")
typedUnnamed(xtree)

simplify(result, pt, locked)
catch case ex: TypeError => errorTree(xtree, ex, xtree.srcPos.focus)
Expand Down Expand Up @@ -3169,6 +3194,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

/** Typecheck and adapt tree, returning a typed tree. Parameters as for `typedUnadapted` */
def typed(tree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree =
// println("> typing tree")
trace(i"typing $tree, pt = $pt", typr, show = true) {
record(s"typed $getClass")
record("typed total")
Expand Down Expand Up @@ -3297,6 +3323,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
withoutMode(Mode.PatternOrTypeBits)(typed(tree, pt))

def typedType(tree: untpd.Tree, pt: Type = WildcardType, mapPatternBounds: Boolean = false)(using Context): Tree =
// println("> typing type")
// println(i"tree = $tree")
// println(i"pt = $pt")
val tree1 = withMode(Mode.Type) { typed(tree, pt) }
if mapPatternBounds && ctx.mode.is(Mode.Pattern) && !ctx.isAfterTyper then
tree1 match
Expand Down
34 changes: 34 additions & 0 deletions compiler/src/dotty/tools/dotc/util/EqLinkedHashSet.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package dotty.tools.dotc.util

import scala.collection.mutable.ArrayBuffer

class EqLinkedHashSet[T](
initialCapacity: Int = 8, capacityMultiple: Int = 2
) extends MutableSet[T] {

private val map: MutableMap[T, Unit] = new EqHashMap(initialCapacity, capacityMultiple)
private val linkingArray: ArrayBuffer[T] = new ArrayBuffer(initialCapacity)

override def +=(x: T): Unit =
map.update(x, ())
if map.size != linkingArray.size then linkingArray += x

override def put(x: T): T =
this += x
x

override def -=(x: T): Unit =
map -= x
if map.size != linkingArray.size then linkingArray -= x

override def clear(resetToInitial: Boolean = true): Unit =
map.clear(resetToInitial)
linkingArray.clear()

override def lookup(x: T): T | Null = if map.contains(x) then x else null

override def size: Int = map.size

override def iterator: Iterator[T] = linkingArray.iterator

}
20 changes: 20 additions & 0 deletions tests/printing/i10693.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[[syntax trees at end of typer]] // local/i10693.scala
package <empty> {
final lazy module val i10693$package: i10693$package = new i10693$package()
final module class i10693$package() extends Object() {
this: i10693$package.type =>
def test[A >: Nothing <: Any, B >: Nothing <: Any](a: A, b: B): A | B = a
val v0: String | Int = test[String, Int]("string", 1)
val v1: Int | String = test[Int, String](1, "string")
val v2: String | Int = test[String | Int, Int | String](v0, v1)
val v3: Int | String = test[Int | String, String | Int](v1, v0)
val v4: String | Int =
test[String | Int | (Int | String), Int | String | (String | Int)](v2, v3)
val v5: Int | String =
test[Int | String | (String | Int), String | Int | (Int | String)](v3, v2)
val v6: String | Int =
test[String | Int | (Int | String) | (Int | String | (String | Int)),
Int | String | (String | Int) | (String | Int | (Int | String))](v4, v5)
}
}

8 changes: 8 additions & 0 deletions tests/printing/i10693.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
def test[A, B](a: A, b: B): A | B = a
val v0 = test("string", 1)
val v1 = test(1, "string")
val v2 = test(v0, v1)
val v3 = test(v1, v0)
val v4 = test(v2, v3)
val v5 = test(v3, v2)
val v6 = test(v4, v5)