diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index f000fe53f239..2bdf83d55d3e 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -38,6 +38,8 @@ import java.lang.ref.WeakReference import compiletime.uninitialized import cc.{CapturingType, CaptureSet, derivedCapturingType, isBoxedCapturing, EventuallyCapturingType, boxedUnlessFun} import CaptureSet.{CompareResult, IdempotentCaptRefMap, IdentityCaptRefMap} +import scala.collection.mutable.ListBuffer +import dotty.tools.dotc.util._ import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -3482,6 +3484,55 @@ object Types { case that: OrType => tp1.eq(that.tp1) && tp2.eq(that.tp2) && isSoft == that.isSoft case _ => false } + + /** Returns the set of non-union (leaf) types composing this union tree. + * For example:
+ * `(A | B | A | B | (A & (B | C)))` returns `{A, B, (A & (B | C))}`. + */ + private def gatherTreeUniqueMembersAbsorbingNothingTypes(using Context): MutableSet[Type] = { + + var unvisitedSubtrees = List(this) + val uniqueTreeMembers = new EqLinkedHashSet[Type] + + while (unvisitedSubtrees.nonEmpty) { + unvisitedSubtrees match + case head :: tail => + head match + case OrType(l: OrType, r: OrType) => + unvisitedSubtrees = l :: r :: tail + case OrType(l, r: OrType) => + unvisitedSubtrees = r :: tail + if !l.isNothingType then uniqueTreeMembers += l + case OrType(l: OrType, r) => + unvisitedSubtrees = l :: tail + if !r.isNothingType then uniqueTreeMembers += r + case OrType(l, r) => + unvisitedSubtrees = tail + uniqueTreeMembers += l + uniqueTreeMembers += r + case _ => + } + + uniqueTreeMembers + } + + /** Returns an equivalent union tree without repeated members. Weaker than LUB. + */ + def deduplicatedAbsorbingNothingTypes(using Context): Type = { + if tp1.isInstanceOf[OrType] || tp2.isInstanceOf[OrType] then + val uniqueTreeMembers = this.gatherTreeUniqueMembersAbsorbingNothingTypes + + uniqueTreeMembers.size match { + case 1 => + uniqueTreeMembers.iterator.next() + case _ => + val members = uniqueTreeMembers.iterator + val startingUnion = OrType(members.next(), members.next(), soft = true) + members.foldLeft(startingUnion)(OrType(_, _, soft = true)) + } + else if tp1 eq tp2 then tp1 + else this + } } final class CachedOrType(tp1: Type, tp2: Type, override val isSoft: Boolean) extends OrType(tp1, tp2) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 7eb8519739c6..58a2cb1a8b5e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2422,7 +2422,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer completeAnnotations(vdef, sym) if (sym.isOneOf(GivenOrImplicit)) checkImplicitConversionDefOK(sym) if sym.is(Module) then checkNoModuleClash(sym) - val tpt1 = checkSimpleKinded(typedType(tpt)) + val tpdType = typedType(tpt) + val tpt1: Tree = checkSimpleKinded(tpdType) match { + case inferred: InferredTypeTree => +// println(i"inferred type = $inferred") + inferred.tpe match { + case or: OrType => + inferred.overwriteType(or.deduplicatedAbsorbingNothingTypes) + case _ => + } + inferred + case self => self + } val rhs1 = vdef.rhs match { case rhs @ Ident(nme.WILDCARD) => rhs withType tpt1.tpe case rhs => typedExpr(rhs, tpt1.tpe.widenExpr) @@ -3004,17 +3015,24 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedUnadapted(initTree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree = { record("typedUnadapted") val xtree = expanded(initTree) +// println("> typing unadapted") +// println(i"initTree = $initTree") +// println(i"pt = $pt") + xtree.removeAttachment(TypedAhead) match { case Some(ttree) => ttree case none => def typedNamed(tree: untpd.NameTree, pt: Type)(using Context): Tree = { val sym = retrieveSym(xtree) +// println(i"sym = $sym") tree match { case tree: untpd.Ident => typedIdent(tree, pt) case tree: untpd.Select => typedSelect(tree, pt) case tree: untpd.Bind => typedBind(tree, pt) case tree: untpd.ValDef => +// println("> typing val def") +// println(i"$tree") if (tree.isEmpty) tpd.EmptyValDef else typedValDef(tree, sym)(using ctx.localContext(tree, sym)) case tree: untpd.DefDef => @@ -3105,8 +3123,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer then makeContextualFunction(xtree, ifpt) else xtree match - case xtree: untpd.NameTree => typedNamed(xtree, pt) - case xtree => typedUnnamed(xtree) + case xtree: untpd.NameTree => +// println(s"> typing named") +// println(i"xtree = $xtree") +// println(i"pt = $pt") + typedNamed(xtree, pt) + case xtree => +// println(s"> typing unnamed") +// println(i"xtree = $xtree") + typedUnnamed(xtree) simplify(result, pt, locked) catch case ex: TypeError => errorTree(xtree, ex, xtree.srcPos.focus) @@ -3169,6 +3194,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer /** Typecheck and adapt tree, returning a typed tree. Parameters as for `typedUnadapted` */ def typed(tree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree = +// println("> typing tree") trace(i"typing $tree, pt = $pt", typr, show = true) { record(s"typed $getClass") record("typed total") @@ -3297,6 +3323,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer withoutMode(Mode.PatternOrTypeBits)(typed(tree, pt)) def typedType(tree: untpd.Tree, pt: Type = WildcardType, mapPatternBounds: Boolean = false)(using Context): Tree = +// println("> typing type") +// println(i"tree = $tree") +// println(i"pt = $pt") val tree1 = withMode(Mode.Type) { typed(tree, pt) } if mapPatternBounds && ctx.mode.is(Mode.Pattern) && !ctx.isAfterTyper then tree1 match diff --git a/compiler/src/dotty/tools/dotc/util/EqLinkedHashSet.scala b/compiler/src/dotty/tools/dotc/util/EqLinkedHashSet.scala new file mode 100644 index 000000000000..3b5fe820da18 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/util/EqLinkedHashSet.scala @@ -0,0 +1,34 @@ +package dotty.tools.dotc.util + +import scala.collection.mutable.ArrayBuffer + +class EqLinkedHashSet[T]( + initialCapacity: Int = 8, capacityMultiple: Int = 2 +) extends MutableSet[T] { + + private val map: MutableMap[T, Unit] = new EqHashMap(initialCapacity, capacityMultiple) + private val linkingArray: ArrayBuffer[T] = new ArrayBuffer(initialCapacity) + + override def +=(x: T): Unit = + map.update(x, ()) + if map.size != linkingArray.size then linkingArray += x + + override def put(x: T): T = + this += x + x + + override def -=(x: T): Unit = + map -= x + if map.size != linkingArray.size then linkingArray -= x + + override def clear(resetToInitial: Boolean = true): Unit = + map.clear(resetToInitial) + linkingArray.clear() + + override def lookup(x: T): T | Null = if map.contains(x) then x else null + + override def size: Int = map.size + + override def iterator: Iterator[T] = linkingArray.iterator + +} diff --git a/tests/printing/i10693.check b/tests/printing/i10693.check new file mode 100644 index 000000000000..3e1f5858d6a5 --- /dev/null +++ b/tests/printing/i10693.check @@ -0,0 +1,20 @@ +[[syntax trees at end of typer]] // local/i10693.scala +package { + final lazy module val i10693$package: i10693$package = new i10693$package() + final module class i10693$package() extends Object() { + this: i10693$package.type => + def test[A >: Nothing <: Any, B >: Nothing <: Any](a: A, b: B): A | B = a + val v0: String | Int = test[String, Int]("string", 1) + val v1: Int | String = test[Int, String](1, "string") + val v2: String | Int = test[String | Int, Int | String](v0, v1) + val v3: Int | String = test[Int | String, String | Int](v1, v0) + val v4: String | Int = + test[String | Int | (Int | String), Int | String | (String | Int)](v2, v3) + val v5: Int | String = + test[Int | String | (String | Int), String | Int | (Int | String)](v3, v2) + val v6: String | Int = + test[String | Int | (Int | String) | (Int | String | (String | Int)), + Int | String | (String | Int) | (String | Int | (Int | String))](v4, v5) + } +} + diff --git a/tests/printing/i10693.scala b/tests/printing/i10693.scala new file mode 100644 index 000000000000..122984484658 --- /dev/null +++ b/tests/printing/i10693.scala @@ -0,0 +1,8 @@ +def test[A, B](a: A, b: B): A | B = a +val v0 = test("string", 1) +val v1 = test(1, "string") +val v2 = test(v0, v1) +val v3 = test(v1, v0) +val v4 = test(v2, v3) +val v5 = test(v3, v2) +val v6 = test(v4, v5)