Skip to content

Commit 02e3805

Browse files
committed
Preserve singletons in unions when they're explicitly written in the code
There are two cases where we should not widen singletons in unions: - When we explicitly write the type, like `val x: 1 | 2` - When pattern matching binds an alternative, like `case x @ (1 | 2) =>` Fixes #829
1 parent b927f66 commit 02e3805

File tree

6 files changed

+35
-16
lines changed

6 files changed

+35
-16
lines changed

src/dotty/tools/dotc/core/TypeComparer.scala

+9-9
Original file line numberDiff line numberDiff line change
@@ -790,9 +790,9 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
790790
(defn.AnyType /: tps)(glb)
791791

792792
/** The least upper bound of two types
793-
* @note We do not admit singleton types in or-types as lubs.
793+
* @param keepSingletons If true, do not widen singletons when forming an OrType
794794
*/
795-
def lub(tp1: Type, tp2: Type): Type = /*>|>*/ ctx.traceIndented(s"lub(${tp1.show}, ${tp2.show})", subtyping, show = true) /*<|<*/ {
795+
def lub(tp1: Type, tp2: Type, keepSingletons: Boolean = false): Type = /*>|>*/ ctx.traceIndented(s"lub(${tp1.show}, ${tp2.show}, $keepSingletons)", subtyping, show = true) /*<|<*/ {
796796
if (tp1 eq tp2) tp1
797797
else if (!tp1.exists) tp1
798798
else if (!tp2.exists) tp2
@@ -805,8 +805,8 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
805805
val t2 = mergeIfSuper(tp2, tp1)
806806
if (t2.exists) t2
807807
else {
808-
val tp1w = tp1.widen
809-
val tp2w = tp2.widen
808+
val tp1w = if (keepSingletons) tp1.widenExpr else tp1.widen
809+
val tp2w = if (keepSingletons) tp2.widenExpr else tp2.widen
810810
if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w)
811811
else orType(tp1w, tp2w) // no need to check subtypes again
812812
}
@@ -815,8 +815,8 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
815815
}
816816

817817
/** The least upper bound of a list of types */
818-
final def lub(tps: List[Type]): Type =
819-
(defn.NothingType /: tps)(lub)
818+
final def lubList(tps: List[Type], keepSingletons: Boolean = false): Type =
819+
(defn.NothingType /: tps)(lub(_, _, keepSingletons))
820820

821821
/** Merge `t1` into `tp2` if t1 is a subtype of some &-summand of tp2.
822822
*/
@@ -1207,9 +1207,9 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
12071207
super.hasMatchingMember(name, tp1, tp2)
12081208
}
12091209

1210-
override def lub(tp1: Type, tp2: Type) =
1211-
traceIndented(s"lub(${show(tp1)}, ${show(tp2)})") {
1212-
super.lub(tp1, tp2)
1210+
override def lub(tp1: Type, tp2: Type, canWiden: Boolean = true) =
1211+
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, $canWiden)") {
1212+
super.lub(tp1, tp2, canWiden)
12131213
}
12141214

12151215
override def glb(tp1: Type, tp2: Type) =

src/dotty/tools/dotc/core/TypeOps.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ trait TypeOps { this: Context => // TODO: Make standalone object.
246246
case AndType(l, r) =>
247247
simplify(l, theMap) & simplify(r, theMap)
248248
case OrType(l, r) =>
249-
simplify(l, theMap) | simplify(r, theMap)
249+
ctx.typeComparer.lub(simplify(l, theMap), simplify(r, theMap), keepSingletons = true)
250250
case _ =>
251251
(if (theMap != null) theMap else new SimplifyMap).mapOver(tp)
252252
}

src/dotty/tools/dotc/typer/TypeAssigner.scala

+6-6
Original file line numberDiff line numberDiff line change
@@ -339,23 +339,23 @@ trait TypeAssigner {
339339
tree.withType(body.tpe)
340340

341341
def assignType(tree: untpd.Match, cases: List[CaseDef])(implicit ctx: Context) =
342-
tree.withType(ctx.typeComparer.lub(cases.tpes))
342+
tree.withType(ctx.typeComparer.lubList(cases.tpes))
343343

344344
def assignType(tree: untpd.Return)(implicit ctx: Context) =
345345
tree.withType(defn.NothingType)
346346

347347
def assignType(tree: untpd.Try, expr: Tree, cases: List[CaseDef])(implicit ctx: Context) = {
348348
if (cases.isEmpty) tree.withType(expr.tpe)
349-
else tree.withType(ctx.typeComparer.lub(expr.tpe :: cases.tpes))
349+
else tree.withType(ctx.typeComparer.lubList(expr.tpe :: cases.tpes))
350350
}
351351

352352
def assignType(tree: untpd.SeqLiteral, elems: List[Tree])(implicit ctx: Context) = tree match {
353353
case tree: JavaSeqLiteral =>
354-
tree.withType(defn.ArrayType(ctx.typeComparer.lub(elems.tpes).widen))
354+
tree.withType(defn.ArrayType(ctx.typeComparer.lubList(elems.tpes).widen))
355355
case _ =>
356356
val ownType =
357357
if (ctx.erasedTypes) defn.SeqType
358-
else defn.SeqType.appliedTo(ctx.typeComparer.lub(elems.tpes).widen)
358+
else defn.SeqType.appliedTo(ctx.typeComparer.lubList(elems.tpes).widen)
359359
tree.withType(ownType)
360360
}
361361

@@ -366,7 +366,7 @@ trait TypeAssigner {
366366
tree.withType(left.tpe & right.tpe)
367367

368368
def assignType(tree: untpd.OrTypeTree, left: Tree, right: Tree)(implicit ctx: Context) =
369-
tree.withType(left.tpe | right.tpe)
369+
tree.withType(ctx.typeComparer.lub(left.tpe, right.tpe, keepSingletons = true))
370370

371371
// RefinedTypeTree is missing, handled specially in Typer and Unpickler.
372372

@@ -388,7 +388,7 @@ trait TypeAssigner {
388388
tree.withType(NamedType.withFixedSym(NoPrefix, sym))
389389

390390
def assignType(tree: untpd.Alternative, trees: List[Tree])(implicit ctx: Context) =
391-
tree.withType(ctx.typeComparer.lub(trees.tpes))
391+
tree.withType(ctx.typeComparer.lubList(trees.tpes, keepSingletons = true))
392392

393393
def assignType(tree: untpd.UnApply, proto: Type)(implicit ctx: Context) =
394394
tree.withType(proto)

test/dotc/tests.scala

+1
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class tests extends CompilerTest {
160160
@Test def neg_validate = compileFile(negDir, "validate", xerrors = 18)
161161
@Test def neg_validateParsing = compileFile(negDir, "validate-parsing", xerrors = 7)
162162
@Test def neg_validateRefchecks = compileFile(negDir, "validate-refchecks", xerrors = 2)
163+
@Test def neg_singletonsLubs = compileFile(negDir, "singletons-lubs", xerrors = 2)
163164

164165
@Test def run_all = runFiles(runDir)
165166

tests/neg/singletons-lubs.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
object Test {
2+
def oneOrTwo(x: 1 | 2): 1 | 2 = x
3+
def test: Unit = {
4+
val foo: 3 | 4 = 1 // error
5+
oneOrTwo(foo) // error
6+
}
7+
}

tests/pos/singletons-lubs.scala

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
object Test {
2+
def oneOrTwo(x: 1 | 2): 1 | 2 = x
3+
def test: Unit = {
4+
val foo: 1 | 2 = 1
5+
oneOrTwo(oneOrTwo(foo))
6+
1 match {
7+
case x: (1 | 2) => oneOrTwo(x)
8+
case x @ (1 | 2) => oneOrTwo(x)
9+
}
10+
}
11+
}

0 commit comments

Comments
 (0)