Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Help implement Metals' infer expected type feature #21390

Merged
merged 7 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3268,9 +3268,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

/** The trace of comparison operations when performing `op` */
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean)(using Context): String =
val cmp = explainingTypeComparer(short)
inSubComparer(cmp)(op)
cmp.lastTrace(header)
explaining(cmp => { op(cmp); cmp.lastTrace(header) }, short)

def explaining[T](op: ExplainingTypeComparer => T, short: Boolean)(using Context): T =
inSubComparer(explainingTypeComparer(short))(op)

def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
inSubComparer(matchReducer)(op)
Expand Down Expand Up @@ -3440,6 +3441,9 @@ object TypeComparer {
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean = false)(using Context): String =
comparing(_.explained(op, header, short))

def explaining[T](op: ExplainingTypeComparer => T, short: Boolean = false)(using Context): T =
comparing(_.explaining(op, short))

def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
comparing(_.reduceMatchWith(op))

Expand Down Expand Up @@ -3873,7 +3877,7 @@ class ExplainingTypeComparer(initctx: Context, short: Boolean) extends TypeCompa
override def recur(tp1: Type, tp2: Type): Boolean =
def moreInfo =
if Config.verboseExplainSubtype || ctx.settings.verbose.value
then s" ${tp1.getClass} ${tp2.getClass}"
then s" ${tp1.className} ${tp2.className}"
else ""
val approx = approxState
def approxStr = if short then "" else approx.show
Expand Down
19 changes: 15 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,22 @@ object TypeOps:
val hiBound = instantiate(bounds.hi, skolemizedArgTypes)
val loBound = instantiate(bounds.lo, skolemizedArgTypes)

def check(using Context) = {
if (!(lo <:< hiBound)) violations += ((arg, "upper", hiBound))
if (!(loBound <:< hi)) violations += ((arg, "lower", loBound))
def check(tp1: Type, tp2: Type, which: String, bound: Type)(using Context) = {
val isSub = TypeComparer.explaining { cmp =>
val isSub = cmp.isSubType(tp1, tp2)
if !isSub then
if !ctx.typerState.constraint.domainLambdas.isEmpty then
typr.println(i"${ctx.typerState.constraint}")
if !ctx.gadt.symbols.isEmpty then
typr.println(i"${ctx.gadt}")
typr.println(cmp.lastTrace(i"checkOverlapsBounds($lo, $hi, $arg, $bounds)($which)"))
//trace.dumpStack()
isSub
}//(using ctx.fresh.setSetting(ctx.settings.verbose, true)) // uncomment to enable moreInfo in ExplainingTypeComparer recur
if !isSub then violations += ((arg, which, bound))
}
check(using checkCtx)
check(lo, hiBound, "upper", hiBound)(using checkCtx)
check(loBound, hi, "lower", loBound)(using checkCtx)
}

def loop(args: List[Tree], boundss: List[TypeBounds]): Unit = args match
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/dotty/tools/dotc/reporting/trace.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ object trace extends TraceSyntax:
object log extends TraceSyntax:
inline def isEnabled: true = true
protected val isForced = false

def dumpStack(limit: Int = -1): Unit = {
val out = Console.out
val exc = new Exception("Dump Stack")
var stack = exc.getStackTrace
.filter(e => !e.getClassName.startsWith("dotty.tools.dotc.reporting.TraceSyntax"))
.filter(e => !e.getClassName.startsWith("dotty.tools.dotc.reporting.trace"))
if limit >= 0 then
stack = stack.take(limit)
exc.setStackTrace(stack)
exc.printStackTrace(out)
}
end trace

/** This module is carefully optimized to give zero overhead if Config.tracingEnabled
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ trait Applications extends Compatibility {
fail(TypeMismatch(methType.resultType, resultType, None))

// match all arguments with corresponding formal parameters
matchArgs(orderedArgs, methType.paramInfos, 0)
if success then matchArgs(orderedArgs, methType.paramInfos, 0)
case _ =>
if (methType.isError) ok = false
else fail(em"$methString does not take parameters")
Expand Down Expand Up @@ -666,7 +666,7 @@ trait Applications extends Compatibility {
* @param n The position of the first parameter in formals in `methType`.
*/
def matchArgs(args: List[Arg], formals: List[Type], n: Int): Unit =
if (success) formals match {
formals match {
case formal :: formals1 =>

def checkNoVarArg(arg: Arg) =
Expand Down Expand Up @@ -878,7 +878,9 @@ trait Applications extends Compatibility {
init()

def addArg(arg: Tree, formal: Type): Unit =
typedArgBuf += adapt(arg, formal.widenExpr)
val typedArg = adapt(arg, formal.widenExpr)
typedArgBuf += typedArg
ok = ok & !typedArg.tpe.isError

def makeVarArg(n: Int, elemFormal: Type): Unit = {
val args = typedArgBuf.takeRight(n).toList
Expand Down Expand Up @@ -943,7 +945,7 @@ trait Applications extends Compatibility {
var typedArgs = typedArgBuf.toList
def app0 = cpy.Apply(app)(normalizedFun, typedArgs) // needs to be a `def` because typedArgs can change later
val app1 =
if (!success || typedArgs.exists(_.tpe.isError)) app0.withType(UnspecifiedErrorType)
if !success then app0.withType(UnspecifiedErrorType)
else {
if isJavaAnnotConstr(methRef.symbol) then
// #19951 Make sure all arguments are NamedArgs for Java annotations
Expand Down
55 changes: 35 additions & 20 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,25 +240,12 @@ object Inferencing {
&& {
var fail = false
var skip = false
val direction = instDirection(tvar.origin)
if minimizeSelected then
if direction <= 0 && tvar.hasLowerBound then
skip = instantiate(tvar, fromBelow = true)
else if direction >= 0 && tvar.hasUpperBound then
skip = instantiate(tvar, fromBelow = false)
// else hold off instantiating unbounded unconstrained variable
else if direction != 0 then
skip = instantiate(tvar, fromBelow = direction < 0)
else if variance >= 0 && tvar.hasLowerBound then
skip = instantiate(tvar, fromBelow = true)
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
&& force.ifBottom == IfBottom.ok
then // if variance == 0, prefer upper bound if one is given
skip = instantiate(tvar, fromBelow = true)
else if variance >= 0 && force.ifBottom == IfBottom.fail then
fail = true
else
toMaximize = tvar :: toMaximize
instDecision(tvar, variance, minimizeSelected, force.ifBottom) match
case Decision.Min => skip = instantiate(tvar, fromBelow = true)
case Decision.Max => skip = instantiate(tvar, fromBelow = false)
case Decision.Skip => // hold off instantiating unbounded unconstrained variable
case Decision.Fail => fail = true
case Decision.ToMax => toMaximize ::= tvar
!fail && (skip || foldOver(x, tvar))
}
case tp => foldOver(x, tp)
Expand Down Expand Up @@ -452,9 +439,32 @@ object Inferencing {
if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0
val approxAbove =
if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0
//println(i"instDirection($param) = $approxAbove - $approxBelow original=[$original] constrained=[$constrained]")
approxAbove - approxBelow
}

/** The instantiation decision for given poly param computed from the constraint. */
enum Decision { case Min; case Max; case ToMax; case Skip; case Fail }
private def instDecision(tvar: TypeVar, v: Int, minimizeSelected: Boolean, ifBottom: IfBottom)(using Context): Decision =
import Decision.*
val direction = instDirection(tvar.origin)
val dec = if minimizeSelected then
if direction <= 0 && tvar.hasLowerBound then Min
else if direction >= 0 && tvar.hasUpperBound then Max
else Skip
else if direction != 0 then if direction < 0 then Min else Max
else if tvar.hasLowerBound then if v >= 0 then Min else ToMax
else ifBottom match
// What's left are unconstrained tvars with at most a non-Any param upperbound:
// * IfBottom.flip will always maximise to the param upperbound, for all variances
// * IfBottom.fail will fail the IFD check, for covariant or invariant tvars, maximise contravariant tvars
// * IfBottom.ok will minimise to Nothing covariant and unbounded invariant tvars, and max to Any the others
case IfBottom.ok => if v > 0 || v == 0 && !tvar.hasUpperBound then Min else ToMax // prefer upper bound if one is given
case IfBottom.fail => if v >= 0 then Fail else ToMax
case ifBottom_flip => ToMax
//println(i"instDecision($tvar, v=v, minimizedSelected=$minimizeSelected, $ifBottom) dir=$direction = $dec")
dec

/** Following type aliases and stripping refinements and annotations, if one arrives at a
* class type reference where the class has a companion module, a reference to
* that companion module. Otherwise NoType
Expand Down Expand Up @@ -651,7 +661,7 @@ trait Inferencing { this: Typer =>

val ownedVars = state.ownedVars
if (ownedVars ne locked) && !ownedVars.isEmpty then
val qualifying = ownedVars -- locked
val qualifying = (ownedVars -- locked).toList
if (!qualifying.isEmpty) {
typr.println(i"interpolate $tree: ${tree.tpe.widen} in $state, pt = $pt, owned vars = ${state.ownedVars.toList}%, %, qualifying = ${qualifying.toList}%, %, previous = ${locked.toList}%, % / ${state.constraint}")
val resultAlreadyConstrained =
Expand Down Expand Up @@ -687,6 +697,10 @@ trait Inferencing { this: Typer =>

def constraint = state.constraint

trace(i"interpolateTypeVars($tree: ${tree.tpe}, $pt, $qualifying)", typr, (_: Any) => i"$qualifying\n$constraint\n${ctx.gadt}") {
//println(i"$constraint")
//println(i"${ctx.gadt}")

/** Values of this type report type variables to instantiate with variance indication:
* +1 variable appears covariantly, can be instantiated from lower bound
* -1 variable appears contravariantly, can be instantiated from upper bound
Expand Down Expand Up @@ -804,6 +818,7 @@ trait Inferencing { this: Typer =>
end doInstantiate

doInstantiate(filterByDeps(toInstantiate))
}
}
end if
tree
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ import config.Printers.typr
import Inferencing.*
import ErrorReporting.*
import util.SourceFile
import util.Spans.{NoSpan, Span}
import TypeComparer.necessarySubType
import reporting.*

import scala.annotation.internal.sharable
import dotty.tools.dotc.util.Spans.{NoSpan, Span}

object ProtoTypes {

Expand Down Expand Up @@ -83,6 +84,7 @@ object ProtoTypes {
* fits the given expected result type.
*/
def constrainResult(mt: Type, pt: Type)(using Context): Boolean =
trace(i"constrainResult($mt, $pt)", typr):
val savedConstraint = ctx.typerState.constraint
val res = pt.widenExpr match {
case pt: FunProto =>
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/util/Signatures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ object Signatures {
*
* @param err The error message to inspect.
* @param params The parameters that were given at the call site.
* @param alreadyCurried Index of paramss we are currently in.
* @param paramssIndex Index of paramss we are currently in.
*
* @return A pair composed of the index of the best alternative (0 if no alternatives
* were found), and the list of alternatives.
Expand Down
49 changes: 23 additions & 26 deletions compiler/test/dotty/tools/dotc/typer/InstantiateModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,16 @@ package typer

// Modelling the decision in IsFullyDefined
object InstantiateModel:
enum LB { case NN; case LL; case L1 }; import LB.*
enum UB { case AA; case UU; case U1 }; import UB.*
enum Var { case V; case NotV }; import Var.*
enum MSe { case M; case NotM }; import MSe.*
enum Bot { case Fail; case Ok; case Flip }; import Bot.*
enum Act { case Min; case Max; case ToMax; case Skip; case False }; import Act.*
enum LB { case NN; case LL; case L1 }; import LB.*
enum UB { case AA; case UU; case U1 }; import UB.*
enum Decision { case Min; case Max; case ToMax; case Skip; case Fail }; import Decision.*

// NN/AA = Nothing/Any
// LL/UU = the original bounds, on the type parameter
// L1/U1 = the constrained bounds, on the type variable
// V = variance >= 0 ("non-contravariant")
// MSe = minimisedSelected
// Bot = IfBottom
// ToMax = delayed maximisation, via addition to toMaximize
// Skip = minimisedSelected "hold off instantiating"
// False = return false
// Fail = IfBottom.fail's bail option

// there are 9 combinations:
// # | LB | UB | d | // d = direction
Expand All @@ -34,24 +28,27 @@ object InstantiateModel:
// 8 | NN | UU | 0 | T <: UU
// 9 | NN | AA | 0 | T

def decide(lb: LB, ub: UB, v: Var, bot: Bot, m: MSe): Act = (lb, ub) match
def instDecision(lb: LB, ub: UB, v: Int, ifBottom: IfBottom, min: Boolean) = (lb, ub) match
case (L1, AA) => Min
case (L1, UU) => Min
case (LL, U1) => Max
case (NN, U1) => Max

case (L1, U1) => if m==M || v==V then Min else ToMax
case (LL, UU) => if m==M || v==V then Min else ToMax
case (LL, AA) => if m==M || v==V then Min else ToMax

case (NN, UU) => bot match
case _ if m==M => Max
//case Ok if v==V => Min // removed, i14218 fix
case Fail if v==V => False
case _ => ToMax

case (NN, AA) => bot match
case _ if m==M => Skip
case Ok if v==V => Min
case Fail if v==V => False
case _ => ToMax
case (L1, U1) => if min then Min else pickVar(v, Min, Min, ToMax)
case (LL, UU) => if min then Min else pickVar(v, Min, Min, ToMax)
case (LL, AA) => if min then Min else pickVar(v, Min, Min, ToMax)

case (NN, UU) => ifBottom match
case _ if min => Max
case IfBottom.ok => pickVar(v, Min, ToMax, ToMax)
case IfBottom.fail => pickVar(v, Fail, Fail, ToMax)
case IfBottom.flip => ToMax

case (NN, AA) => ifBottom match
case _ if min => Skip
case IfBottom.ok => pickVar(v, Min, Min, ToMax)
case IfBottom.fail => pickVar(v, Fail, Fail, ToMax)
case IfBottom.flip => ToMax

def pickVar[A](v: Int, cov: A, inv: A, con: A) =
if v > 0 then cov else if v == 0 then inv else con
Loading
Loading