Skip to content

Commit c1f35aa

Browse files
committed
Instantiate more type variables to hard unions
Fixes #14770
1 parent 63344e7 commit c1f35aa

File tree

9 files changed

+128
-38
lines changed

9 files changed

+128
-38
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import typer.ProtoTypes.{newTypeVar, representedParamRef}
1313
import UnificationDirection.*
1414
import NameKinds.AvoidNameKind
1515
import util.SimpleIdentitySet
16+
import NullOpsDecorator.stripNull
1617

1718
/** Methods for adding constraints and solving them.
1819
*
@@ -627,8 +628,11 @@ trait ConstraintHandling {
627628
* 1. If `inst` is a singleton type, or a union containing some singleton types,
628629
* widen (all) the singleton type(s), provided the result is a subtype of `bound`.
629630
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
630-
* 2. If `inst` is a union type, approximate the union type from above by an intersection
631-
* of all common base types, provided the result is a subtype of `bound`.
631+
* 2a. If `inst` is a union type and `widenUnions` is true, approximate the union type
632+
* from above by an intersection of all common base types, provided the result
633+
* is a subtype of `bound`.
634+
* 2b. If `inst` is a union type and `widenUnions` is false, turn it into a hard
635+
* union type (except for unions | Null, which are kept in the state they were).
632636
* 3. Widen some irreducible applications of higher-kinded types to wildcard arguments
633637
* (see @widenIrreducible).
634638
* 4. Drop transparent traits from intersections (see @dropTransparentTraits).
@@ -641,10 +645,12 @@ trait ConstraintHandling {
641645
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
642646
* as those could leak the annotation to users (see run/inferred-repeated-result).
643647
*/
644-
def widenInferred(inst: Type, bound: Type)(using Context): Type =
648+
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
645649
def widenOr(tp: Type) =
646-
val tpw = tp.widenUnion
647-
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
650+
if widenUnions then
651+
val tpw = tp.widenUnion
652+
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
653+
else tp.hardenUnions
648654

649655
def widenSingle(tp: Type) =
650656
val tpw = tp.widenSingletons
@@ -664,6 +670,23 @@ trait ConstraintHandling {
664670
wideInst.dropRepeatedAnnot
665671
end widenInferred
666672

673+
/** Convert all toplevel union types in `tp` to hard unions */
674+
extension (tp: Type) private def hardenUnions(using Context): Type = tp.widen match
675+
case tp: AndType =>
676+
tp.derivedAndType(tp.tp1.hardenUnions, tp.tp2.hardenUnions)
677+
case tp: RefinedType =>
678+
tp.derivedRefinedType(tp.parent.hardenUnions, tp.refinedName, tp.refinedInfo)
679+
case tp: RecType =>
680+
tp.rebind(tp.parent.hardenUnions)
681+
case tp: HKTypeLambda =>
682+
tp.derivedLambdaType(resType = tp.resType.hardenUnions)
683+
case tp: OrType =>
684+
val tp1 = tp.stripNull
685+
if tp1 ne tp then tp.derivedOrType(tp1.hardenUnions, defn.NullType)
686+
else tp.derivedOrType(tp.tp1.hardenUnions, tp.tp2.hardenUnions, soft = false)
687+
case _ =>
688+
tp
689+
667690
/** The instance type of `param` in the current constraint (which contains `param`).
668691
* If `fromBelow` is true, the instance type is the lub of the parameter's
669692
* lower bounds; otherwise it is the glb of its upper bounds. However,
@@ -672,18 +695,18 @@ trait ConstraintHandling {
672695
* The instance type is not allowed to contain references to types nested deeper
673696
* than `maxLevel`.
674697
*/
675-
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
698+
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean, maxLevel: Int)(using Context): Type = {
676699
val approx = approximation(param, fromBelow, maxLevel).simplified
677700
if fromBelow then
678-
val widened = widenInferred(approx, param)
701+
val widened = widenInferred(approx, param, widenUnions)
679702
// Widening can add extra constraints, in particular the widened type might
680703
// be a type variable which is now instantiated to `param`, and therefore
681704
// cannot be used as an instantiation of `param` without creating a loop.
682705
// If that happens, we run `instanceType` again to find a new instantation.
683706
// (we do not check for non-toplevel occurences: those should never occur
684707
// since `addOneBound` disallows recursive lower bounds).
685708
if constraint.occursAtToplevel(param, widened) then
686-
instanceType(param, fromBelow, maxLevel)
709+
instanceType(param, fromBelow, widenUnions, maxLevel)
687710
else
688711
widened
689712
else

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -487,31 +487,54 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
487487

488488
// If LHS is a hard union, constrain any type variables of the RHS with it as lower bound
489489
// before splitting the LHS into its constituents. That way, the RHS variables are
490-
// constraint by the hard union and can be instantiated to it. If we just split and add
490+
// constrained by the hard union and can be instantiated to it. If we just split and add
491491
// the two parts of the LHS separately to the constraint, the lower bound would become
492492
// a soft union.
493493
def constrainRHSVars(tp2: Type): Boolean = tp2.dealiasKeepRefiningAnnots match
494494
case tp2: TypeParamRef if constraint contains tp2 => compareTypeParamRef(tp2)
495495
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
496496
case _ => true
497497

498-
widenOK
499-
|| joinOK
500-
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
501-
|| containsAnd(tp1)
502-
&& !joined
503-
&& {
504-
joined = true
505-
try inFrozenGadt(recur(tp1.join, tp2))
506-
finally joined = false
507-
}
508-
// An & on the left side loses information. We compensate by also trying the join.
509-
// This is less ad-hoc than it looks since we produce joins in type inference,
510-
// and then need to check that they are indeed supertypes of the original types
511-
// under -Ycheck. Test case is i7965.scala.
512-
// On the other hand, we could get a combinatorial explosion by applying such joins
513-
// recursively, so we do it only once. See i14870.scala as a test case, which would
514-
// loop for a very long time without the recursion brake.
498+
/** Mark toplevel type vars in `tp2` as hard in the current typerState */
499+
def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match
500+
case tvar: TypeVar if constraint.contains(tvar.origin) =>
501+
state.hardVars += tvar
502+
case tp2: TypeParamRef if constraint.contains(tp2) =>
503+
hardenTypeVars(constraint.typeVarOfParam(tp2))
504+
case tp2: AndOrType =>
505+
hardenTypeVars(tp2.tp1)
506+
hardenTypeVars(tp2.tp2)
507+
case _ =>
508+
509+
val res = widenOK
510+
|| joinOK
511+
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
512+
|| containsAnd(tp1)
513+
&& !joined
514+
&& {
515+
joined = true
516+
try inFrozenGadt(recur(tp1.join, tp2))
517+
finally joined = false
518+
}
519+
// An & on the left side loses information. We compensate by also trying the join.
520+
// This is less ad-hoc than it looks since we produce joins in type inference,
521+
// and then need to check that they are indeed supertypes of the original types
522+
// under -Ycheck. Test case is i7965.scala.
523+
// On the other hand, we could get a combinatorial explosion by applying such joins
524+
// recursively, so we do it only once. See i14870.scala as a test case, which would
525+
// loop for a very long time without the recursion brake.
526+
527+
if res && !tp1.isSoft then
528+
// We use a heuristic here where every toplevel type variable on the right hand side
529+
// is marked so that it converts all soft unions in its lower bound to hard unions
530+
// before it is instantiated. The reason is that the union might have come from
531+
// (decomposed and reconstituted) `tp1`. But of course there might be false positives
532+
// where we also treat unions that come from elsewhere as hard unions. Or the constraint
533+
// that created the union is ultimately thrown away, but the type variable will
534+
// stay marked. So it is a coarse measure to take. But it works in the obvious cases.
535+
hardenTypeVars(tp2)
536+
537+
res
515538

516539
case CapturingType(parent1, refs1) =>
517540
if subCaptures(refs1, tp2.captureSet, frozenConstraint).isOK && sameBoxed(tp1, tp2, refs1)
@@ -2960,8 +2983,8 @@ object TypeComparer {
29602983
def subtypeCheckInProgress(using Context): Boolean =
29612984
comparing(_.subtypeCheckInProgress)
29622985

2963-
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
2964-
comparing(_.instanceType(param, fromBelow, maxLevel))
2986+
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
2987+
comparing(_.instanceType(param, fromBelow, widenUnions, maxLevel))
29652988

29662989
def approximation(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
29672990
comparing(_.approximation(param, fromBelow, maxLevel))
@@ -2981,8 +3004,8 @@ object TypeComparer {
29813004
def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(using Context): Boolean =
29823005
comparing(_.addToConstraint(tl, tvars))
29833006

2984-
def widenInferred(inst: Type, bound: Type)(using Context): Type =
2985-
comparing(_.widenInferred(inst, bound))
3007+
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
3008+
comparing(_.widenInferred(inst, bound, widenUnions))
29863009

29873010
def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
29883011
comparing(_.dropTransparentTraits(tp, bound))

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,9 @@ object TypeOps:
537537
override def apply(tp: Type): Type = tp match
538538
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>
539539
val lo = TypeComparer.instanceType(
540-
tp.origin, fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound)(using mapCtx)
540+
tp.origin,
541+
fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound,
542+
widenUnions = tp.widenUnions)(using mapCtx)
541543
val lo1 = apply(lo)
542544
if (lo1 ne lo) lo1 else tp
543545
case _ =>

compiler/src/dotty/tools/dotc/core/TyperState.scala

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,20 @@ object TyperState {
2525

2626
type LevelMap = SimpleIdentityMap[TypeVar, Integer]
2727

28-
opaque type Snapshot = (Constraint, TypeVars, LevelMap)
28+
opaque type Snapshot = (Constraint, TypeVars, TypeVars, LevelMap)
2929

3030
extension (ts: TyperState)
3131
def snapshot()(using Context): Snapshot =
32-
(ts.constraint, ts.ownedVars, ts.upLevels)
32+
(ts.constraint, ts.ownedVars, ts.hardVars, ts.upLevels)
3333

3434
def resetTo(state: Snapshot)(using Context): Unit =
35-
val (constraint, ownedVars, upLevels) = state
35+
val (constraint, ownedVars, hardVars, upLevels) = state
3636
for tv <- ownedVars do
3737
if !ts.ownedVars.contains(tv) then // tv has been instantiated
3838
tv.resetInst(ts)
3939
ts.constraint = constraint
4040
ts.ownedVars = ownedVars
41+
ts.hardVars = hardVars
4142
ts.upLevels = upLevels
4243
}
4344

@@ -91,6 +92,14 @@ class TyperState() {
9192
def ownedVars: TypeVars = myOwnedVars
9293
def ownedVars_=(vs: TypeVars): Unit = myOwnedVars = vs
9394

95+
/** The set of type variables `tv` such that, if `tv` is instantiated to
96+
* its lower bound, top-level soft unions in the instance type are converted
97+
* to hard unions instead of being widened in `widenOr`.
98+
*/
99+
private var myHardVars: TypeVars = _
100+
def hardVars: TypeVars = myHardVars
101+
def hardVars_=(tvs: TypeVars): Unit = myHardVars = tvs
102+
94103
private var upLevels: LevelMap = _
95104

96105
/** Initializes all fields except reporter, isCommittable, which need to be
@@ -103,6 +112,7 @@ class TyperState() {
103112
this.myConstraint = constraint
104113
this.previousConstraint = constraint
105114
this.myOwnedVars = SimpleIdentitySet.empty
115+
this.myHardVars = SimpleIdentitySet.empty
106116
this.upLevels = SimpleIdentityMap.empty
107117
this.isCommitted = false
108118
this
@@ -114,6 +124,7 @@ class TyperState() {
114124
val ts = TyperState().init(this, this.constraint)
115125
.setReporter(reporter)
116126
.setCommittable(committable)
127+
ts.hardVars = this.hardVars
117128
ts.upLevels = upLevels
118129
ts
119130

@@ -180,6 +191,7 @@ class TyperState() {
180191
constr.println(i"committing $this to $targetState, fromConstr = $constraint, toConstr = ${targetState.constraint}")
181192
if targetState.constraint eq previousConstraint then
182193
targetState.constraint = constraint
194+
targetState.hardVars = hardVars
183195
if !ownedVars.isEmpty then ownedVars.foreach(targetState.includeVar)
184196
else
185197
targetState.mergeConstraintWith(this)
@@ -238,6 +250,7 @@ class TyperState() {
238250
val otherLos = other.lower(p)
239251
val otherHis = other.upper(p)
240252
val otherEntry = other.entry(p)
253+
if that.hardVars.contains(tv) then this.myHardVars += tv
241254
( (otherLos eq constraint.lower(p)) || otherLos.forall(_ <:< p)) &&
242255
( (otherHis eq constraint.upper(p)) || otherHis.forall(p <:< _)) &&
243256
((otherEntry eq constraint.entry(p)) || otherEntry.match

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4714,12 +4714,16 @@ object Types {
47144714
* is also a singleton type.
47154715
*/
47164716
def instantiate(fromBelow: Boolean)(using Context): Type =
4717-
val tp = TypeComparer.instanceType(origin, fromBelow, nestingLevel)
4717+
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)
47184718
if myInst.exists then // The line above might have triggered instantiation of the current type variable
47194719
myInst
47204720
else
47214721
instantiateWith(tp)
47224722

4723+
/** Widen unions when instantiating this variable in the current context? */
4724+
def widenUnions(using Context): Boolean =
4725+
!ctx.typerState.hardVars.contains(this)
4726+
47234727
/** For uninstantiated type variables: the entry in the constraint (either bounds or
47244728
* provisional instance value)
47254729
*/

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1888,7 +1888,7 @@ class Namer { typer: Typer =>
18881888
TypeOps.simplify(tp.widenTermRefExpr,
18891889
if defaultTp.exists then TypeOps.SimplifyKeepUnchecked() else null) match
18901890
case ctp: ConstantType if sym.isInlineVal => ctp
1891-
case tp => TypeComparer.widenInferred(tp, pt)
1891+
case tp => TypeComparer.widenInferred(tp, pt, widenUnions = true)
18921892

18931893
// Replace aliases to Unit by Unit itself. If we leave the alias in
18941894
// it would be erased to BoxedUnit.

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
514514
val tparams = poly.paramRefs
515515
val variances = childClass.typeParams.map(_.paramVarianceSign)
516516
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
517-
TypeComparer.instanceType(tparam, fromBelow = variance < 0)
517+
TypeComparer.instanceType(tparam, fromBelow = variance < 0, widenUnions = true)
518518
)
519519
val instanceType = resType.substParams(poly, instanceTypes)
520520
// this is broken in tests/run/i13332intersection.scala,

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2847,7 +2847,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
28472847
if (ctx.mode.is(Mode.Pattern)) app1
28482848
else {
28492849
val elemTpes = elems.lazyZip(pts).map((elem, pt) =>
2850-
TypeComparer.widenInferred(elem.tpe, pt))
2850+
TypeComparer.widenInferred(elem.tpe, pt, widenUnions = true))
28512851
val resTpe = TypeOps.nestedPairs(elemTpes)
28522852
app1.cast(resTpe)
28532853
}

tests/pos/i14770.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
type UndefOr[A] = A | Unit
2+
3+
extension [A](maybe: UndefOr[A])
4+
def foreach(f: A => Unit): Unit =
5+
maybe match
6+
case () => ()
7+
case a: A => f(a)
8+
9+
trait Foo
10+
trait Bar
11+
12+
object Baz:
13+
var booBap: Foo | Bar = _
14+
15+
def z: UndefOr[Foo | Bar] = ???
16+
17+
@main
18+
def main =
19+
z.foreach(x => Baz.booBap = x)
20+
21+
def test[A](v: A | Unit): A | Unit = v
22+
val x1 = test(5: Int | Unit)
23+
val x2 = test(5: String | Int | Unit)
24+
val _: Int | Unit = x1
25+
val _: String | Int | Unit = x2

0 commit comments

Comments
 (0)