Skip to content

Avoid inference getting stuck when the expected type contains a union/intersection #8635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,10 @@ trait ConstraintHandling[AbstractContext] {
* recording an isLess relationship instead (even though this is not implied
* by the bound).
*
* Narrowing a constraint is better than widening it, because narrowing leads
* to incompleteness (which we face anyway, see for instance eitherIsSubType)
* but widening leads to unsoundness.
* Normally, narrowing a constraint is better than widening it, because
* narrowing leads to incompleteness (which we face anyway, see for
* instance `TypeComparer#either`) but widening leads to unsoundness,
* but note the special handling in `ConstrainResult` mode below.
*
* A test case that demonstrates the problem is i864.scala.
* Turn Config.checkConstraintsSeparated on to get an accurate diagnostic
Expand Down Expand Up @@ -544,10 +545,23 @@ trait ConstraintHandling[AbstractContext] {
case bound: TypeParamRef if constraint contains bound =>
addParamBound(bound)
case _ =>
val savedConstraint = constraint
val pbound = prune(bound)
pbound.exists
&& kindCompatible(param, pbound)
&& (if fromBelow then addLowerBound(param, pbound) else addUpperBound(param, pbound))
val constraintsNarrowed = constraint ne savedConstraint

val res =
pbound.exists
&& kindCompatible(param, pbound)
&& (if fromBelow then addLowerBound(param, pbound) else addUpperBound(param, pbound))
// If we're in `ConstrainResult` mode, we don't want to commit to a
// set of constraints that would later prevent us from typechecking
// arguments, so if `pruneParams` had to narrow the constraints, we
// simply do not record any new constraint.
// Unlike in `TypeComparer#either`, the same reasoning does not apply
// to GADT mode because this code is never run on GADT constraints.
Comment on lines +560 to +561
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AleksanderBG Can you confirm that this comment is correct ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@smarter missed your comment! GADT code only goes through addToConstraint/addUpperBound/addLowerBound. AFAIT none of these invoke addBound.

if ctx.mode.is(Mode.ConstrainResult) && constraintsNarrowed then
constraint = savedConstraint
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks good to me. It would be even nicer to have a prune that does not narrow. But let's save this for another day.

res
}
finally addConstraintInvocations -= 1
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Mode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ object Mode {
*/
val Printing: Mode = newMode(10, "Printing")

/** We are constraining a method based on its expected type. */
val ConstrainResult: Mode = newMode(11, "ConstrainResult")

/** We are currently in a `viewExists` check. In that case, ambiguous
* implicits checks are disabled and we succeed with the first implicit
* found.
Expand Down
85 changes: 49 additions & 36 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1364,14 +1364,26 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w

/** Returns true iff the result of evaluating either `op1` or `op2` is true and approximates resulting constraints.
*
* If we're _not_ in GADTFlexible mode, we try to keep the smaller of the two constraints.
* If we're _in_ GADTFlexible mode, we keep the smaller constraint if any, or no constraint at all.
* If we're inferring GADT bounds or constraining a method based on its
* expected type, we infer only the _necessary_ constraints, this means we
* keep the smaller constraint if any, or no constraint at all. This is
* necessary for GADT bounds inference to be sound. When constraining a
* method, this avoid committing of constraints that would later prevent us
* from typechecking method arguments, see or-inf.scala and and-inf.scala for
* examples.
*
* Otherwise, we infer _sufficient_ constraints: we try to keep the smaller of
* the two constraints, but if never is smaller than the other, we just pick
* the first one.
*
* @see [[necessaryEither]] for the GADT / result type case
* @see [[sufficientEither]] for the normal case
* @see [[necessaryEither]] for the GADTFlexible case
*/
protected def either(op1: => Boolean, op2: => Boolean): Boolean =
if (ctx.mode.is(Mode.GadtConstraintInference)) necessaryEither(op1, op2) else sufficientEither(op1, op2)
if ctx.mode.is(Mode.GadtConstraintInference) || ctx.mode.is(Mode.ConstrainResult) then
necessaryEither(op1, op2)
else
sufficientEither(op1, op2)

/** Returns true iff the result of evaluating either `op1` or `op2` is true,
* trying at the same time to keep the constraint as wide as possible.
Expand Down Expand Up @@ -1438,8 +1450,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
* T1 & T2 <:< T3
* T1 <:< T2 | T3
*
* Unlike [[sufficientEither]], this method is used in GADTFlexible mode, when we are attempting to infer GADT
* constraints that necessarily follow from the subtyping relationship. For instance, if we have
* Unlike [[sufficientEither]], this method is used in GADTConstraintInference mode, when we are attempting
* to infer GADT constraints that necessarily follow from the subtyping relationship. For instance, if we have
*
* enum Expr[T] {
* case IntExpr(i: Int) extends Expr[Int]
Expand All @@ -1466,48 +1478,49 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
*
* then the necessary constraint is { A = Int }, but correctly inferring that is, as far as we know, too expensive.
*
* This method is also used in ConstrainResult mode
* to avoid inference getting stuck due to lack of backtracking,
* see or-inf.scala and and-inf.scala for examples.
*
* Method name comes from the notion that we are keeping the constraint which is necessary to satisfy both
* subtyping relationships.
*/
private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean = {
private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean =
val preConstraint = constraint

val preGadt = ctx.gadt.fresh
// if GADTflexible mode is on, we expect to always have a ProperGadtConstraint
val pre = preGadt.asInstanceOf[ProperGadtConstraint]
if (op1) {
val leftConstraint = constraint
val leftGadt = ctx.gadt.fresh

def allSubsumes(leftGadt: GadtConstraint, rightGadt: GadtConstraint, left: Constraint, right: Constraint): Boolean =
subsumes(left, right, preConstraint) && preGadt.match
case preGadt: ProperGadtConstraint =>
preGadt.subsumes(leftGadt, rightGadt, preGadt)
case _ =>
true

if op1 then
val op1Constraint = constraint
val op1Gadt = ctx.gadt.fresh
constraint = preConstraint
ctx.gadt.restore(preGadt)
if (op2)
if (pre.subsumes(leftGadt, ctx.gadt, preGadt) && subsumes(leftConstraint, constraint, preConstraint)) {
gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $leftGadt")
constr.println(i"CUT - prefer $constraint over $leftConstraint")
true
}
else if (pre.subsumes(ctx.gadt, leftGadt, preGadt) && subsumes(constraint, leftConstraint, preConstraint)) {
gadts.println(i"GADT CUT - prefer $leftGadt over ${ctx.gadt}")
constr.println(i"CUT - prefer $leftConstraint over $constraint")
constraint = leftConstraint
ctx.gadt.restore(leftGadt)
true
}
else {
if op2 then
if allSubsumes(op1Gadt, ctx.gadt, op1Constraint, constraint) then
gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $op1Gadt")
constr.println(i"CUT - prefer $constraint over $op1Constraint")
else if allSubsumes(ctx.gadt, op1Gadt, constraint, op1Constraint) then
gadts.println(i"GADT CUT - prefer $op1Gadt over ${ctx.gadt}")
constr.println(i"CUT - prefer $op1Constraint over $constraint")
constraint = op1Constraint
ctx.gadt.restore(op1Gadt)
else
gadts.println(i"GADT CUT - no constraint is preferable, reverting to $preGadt")
constr.println(i"CUT - no constraint is preferable, reverting to $preConstraint")
constraint = preConstraint
ctx.gadt.restore(preGadt)
true
}
else {
constraint = leftConstraint
ctx.gadt.restore(leftGadt)
true
}
}
else
constraint = op1Constraint
ctx.gadt.restore(op1Gadt)
true
else op2
}
end necessaryEither

/** Does type `tp1` have a member with name `name` whose normalized type is a subtype of
* the normalized type of the refinement `tp2`?
Expand Down
11 changes: 4 additions & 7 deletions compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,14 @@ object ProtoTypes {
else ctx.test(testCompat)
}

private def disregardProto(pt: Type)(implicit ctx: Context): Boolean = pt.dealias match {
case _: OrType => true
// Don't constrain results with union types, since comparison with a union
// type on the right might commit too early into one side.
case pt => pt.isRef(defn.UnitClass)
}
private def disregardProto(pt: Type)(implicit ctx: Context): Boolean =
pt.dealias.isRef(defn.UnitClass)

/** Check that the result type of the current method
* fits the given expected result type.
*/
def constrainResult(mt: Type, pt: Type)(implicit ctx: Context): Boolean = {
def constrainResult(mt: Type, pt: Type)(implicit parentCtx: Context): Boolean = {
given ctx as Context = parentCtx.addMode(Mode.ConstrainResult)
val savedConstraint = ctx.typerState.constraint
val res = pt.widenExpr match {
case pt: FunProto =>
Expand Down
1 change: 1 addition & 0 deletions compiler/test/dotty/tools/dotc/CompilationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class CompilationTests extends ParallelTesting {
compileFile("tests/neg-custom-args/i3882.scala", allowDeepSubtypes),
compileFile("tests/neg-custom-args/i4372.scala", allowDeepSubtypes),
compileFile("tests/neg-custom-args/i1754.scala", allowDeepSubtypes),
compileFile("tests/neg-custom-args/interop-polytypes.scala", allowDeepSubtypes.and("-Yexplicit-nulls")),
compileFile("tests/neg-custom-args/conditionalWarnings.scala", allowDeepSubtypes.and("-deprecation").and("-Xfatal-warnings")),
compileFilesInDir("tests/neg-custom-args/isInstanceOf", allowDeepSubtypes and "-Xfatal-warnings"),
compileFile("tests/neg-custom-args/i3627.scala", allowDeepSubtypes),
Expand Down
4 changes: 2 additions & 2 deletions tests/neg/i6565.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ def (o: Lifted[O]) flatMap [O, U] (f: O => Lifted[U]): Lifted[U] = ???
val error: Err = Err()

lazy val ok: Lifted[String] = { // ok despite map returning a union
point("a").map(_ => if true then "foo" else error) // error
point("a").map(_ => if true then "foo" else error) // ok
}

lazy val bad: Lifted[String] = { // found Lifted[Object]
point("a").flatMap(_ => point("b").map(_ => if true then "foo" else error)) // error
}
}
2 changes: 1 addition & 1 deletion tests/neg/union.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ object O {

val x: A = f(new A { }, new A)

val y1: A | B = f(new A { }, new B) // error
val y1: A | B = f(new A { }, new B) // ok
val y2: A | B = f[A | B](new A { }, new B) // ok

val z = if (???) new A{} else new B
Expand Down
13 changes: 13 additions & 0 deletions tests/pos/and-inf.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class A
class B

class Inv[T]
class Contra[-T]

class Test {
def foo[T, S](x: T, y: S): Contra[Inv[T] & Inv[S]] = ???
val a: A = new A
val b: B = new B

val x: Contra[Inv[A] & Inv[B]] = foo(a, b)
}
27 changes: 27 additions & 0 deletions tests/pos/i7829.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
class X
class Y

object Test {
type Id[T] = T

val a: 1 = identity(1)
val b: Id[1] = identity(1)

val c: X | Y = identity(if (true) new X else new Y)
val d: Id[X | Y] = identity(if (true) new X else new Y)

def impUnion: Unit = {
class Base
class A extends Base
class B extends Base
class Inv[T]

implicit def invBase: Inv[Base] = new Inv[Base]

def getInv[T](x: T)(implicit inv: Inv[T]): Int = 1

val a: Int = getInv(if (true) new A else new B)
// If we keep unions when doing the implicit search, this would give us: "no implicit argument of type Inv[X | Y]"
val b: Int | Any = getInv(if (true) new A else new B)
}
}
17 changes: 17 additions & 0 deletions tests/pos/i8378.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
trait Has[A]

trait A
trait B
trait C

trait ZLayer[-RIn, +E, +ROut]

object ZLayer {
def fromServices[A0, A1, B](f: (A0, A1) => B): ZLayer[Has[A0] with Has[A1], Nothing, Has[B]] =
???
}

val live: ZLayer[Has[A] & Has[B], Nothing, Has[C]] =
ZLayer.fromServices { (a: A, b: B) =>
new C {}
}
14 changes: 14 additions & 0 deletions tests/pos/or-inf.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
object Test {

def a(lis: Set[Int] | Set[String]) = {}
a(Set(1))
a(Set(""))

def b(lis: List[Set[Int] | Set[String]]) = {}
b(List(Set(1)))
b(List(Set("")))

def c(x: Set[Any] | Array[Any]) = {}
c(Set(1))
c(Array(1))
}
6 changes: 0 additions & 6 deletions tests/pos/orinf.scala

This file was deleted.