Skip to content

Commit

Permalink
Fix InferExpectedTypeSuite.map/flatMap
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Aug 17, 2024
1 parent a619128 commit 1d56a51
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 43 deletions.
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package cc
import core.*
import Phases.*, DenotTransformers.*, SymDenotations.*
import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
import Types.*, StdNames.*, Denotations.*
import Types.*, StdNames.*, Denotations.*, NamerOps.linkConstructorParams
import config.Printers.{capt, recheckr, noPrinter}
import config.{Config, Feature}
import ast.{tpd, untpd, Trees}
Expand Down Expand Up @@ -1552,7 +1552,8 @@ class CheckCaptures extends Recheck, SymTransformer:
val checker = new TreeTraverser:
def traverse(tree: Tree)(using Context): Unit =
val lctx = tree match
case _: DefTree | _: TypeDef if tree.symbol.exists => ctx.withOwner(tree.symbol)
case _: DefDef => linkConstructorParams(tree.symbol)(using ctx.withOwner(tree.symbol))
case _: DefTree if tree.symbol.exists => ctx.withOwner(tree.symbol)
case _ => ctx
trace(i"post check $tree"):
traverseChildren(tree)(using lctx)
Expand Down
14 changes: 12 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ object Contexts {
/** AbstractFile with given path, memoized */
def getFile(name: String): AbstractFile = getFile(name.toTermName)

private var related: SimpleIdentityMap[Phase | SourceFile, Context] | Null = null
private var related: SimpleIdentityMap[Phase | SourceFile | GadtState, Context] | Null = null

private def lookup(key: Phase | SourceFile): Context | Null =
private def lookup(key: Phase | SourceFile | GadtState): Context | Null =
util.Stats.record("Context.related.lookup")
if related == null then
related = SimpleIdentityMap.empty
Expand Down Expand Up @@ -326,6 +326,16 @@ object Contexts {
related = related.nn.updated(source, ctx2)
ctx1

/** The context that is identical to this context except that its GADT state
 *  is `gadtState`; memoized in the `related` map, in the same way as
 *  `withSource` above. */
final def withGadtState(gadtState: GadtState): Context =
if this.gadtState eq gadtState then
this
else
// Reuse a memoized context for this state, or create and cache a fresh one.
var ctx1 = lookup(gadtState)
if ctx1 == null then
ctx1 = fresh.setGadtState(gadtState)
related = related.nn.updated(gadtState, ctx1)
ctx1

// `creationTrace`-related code. To enable, uncomment the code below and the
// call to `setCreationTrace()` in this file.
/*
Expand Down
13 changes: 8 additions & 5 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -904,8 +904,9 @@ class TreeUnpickler(reader: TastyReader,

def DefDef(paramss: List[ParamClause], tpt: Tree) =
sym.setParamssFromDefs(paramss)
val rhsCtx = linkConstructorParams(sym)(using localCtx)
ta.assignType(
untpd.DefDef(sym.name.asTermName, paramss, tpt, readRhs(using localCtx)),
untpd.DefDef(sym.name.asTermName, paramss, tpt, readRhs(using rhsCtx)),
sym)

def TypeDef(rhs: Tree) =
Expand Down Expand Up @@ -1127,7 +1128,7 @@ class TreeUnpickler(reader: TastyReader,
val mappedParents: LazyTreeList =
if parents.exists(_.isInstanceOf[InferredTypeTree]) then
// parents were not read fully, will need to be read again later on demand
new LazyReader(parentReader, localDummy, ctx.mode, ctx.source,
new LazyReader(parentReader, localDummy, ctx.mode, ctx.source, ctx.gadtState,
_.readParents(withArgs = true)
.map(_.changeOwner(localDummy, constr.symbol)))
else parents
Expand Down Expand Up @@ -1748,7 +1749,8 @@ class TreeUnpickler(reader: TastyReader,
goto(end)
val mode = ctx.mode
val source = ctx.source
owner => new LazyReader(localReader, owner, mode, source, op)
val gadtState = ctx.gadtState
owner => new LazyReader(localReader, owner, mode, source, gadtState, op)
}

// ------ Setting positions ------------------------------------------------
Expand Down Expand Up @@ -1810,15 +1812,16 @@ class TreeUnpickler(reader: TastyReader,
}

class LazyReader[T <: AnyRef](
reader: TreeReader, owner: Symbol, mode: Mode, source: SourceFile,
reader: TreeReader, owner: Symbol, mode: Mode, source: SourceFile, gadtState: GadtState,
op: TreeReader => Context ?=> T) extends Trees.Lazy[T] {
def complete(using Context): T = {
pickling.println(i"starting to read at ${reader.reader.currentAddr} with owner $owner")
atPhaseBeforeTransforms {
op(reader)(using ctx
.withOwner(owner)
.withModeBits(mode)
.withSource(source))
.withSource(source)
.withGadtState(gadtState))
}
}
}
Expand Down
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import typer.ErrorReporting.errorTree
import Types.*, Contexts.*, Names.*, Flags.*, DenotTransformers.*, Phases.*
import SymDenotations.*, StdNames.*, Annotations.*, Trees.*, Scopes.*
import Decorators.*
import Symbols.*, NameOps.*
import Symbols.*, NameOps.*, NamerOps.linkConstructorParams
import ContextFunctionResults.annotateContextResults
import config.Printers.typr
import config.Feature
Expand Down Expand Up @@ -441,7 +441,10 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
Checking.checkPolyFunctionType(tree.tpt)
annotateContextResults(tree)
val tree1 = cpy.DefDef(tree)(tpt = makeOverrideTypeDeclared(tree.symbol, tree.tpt), rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef]))
processValOrDefDef(superAcc.wrapDefDef(tree1):
inContext(linkConstructorParams(tree1.symbol)):
super.transform(tree1).asInstanceOf[DefDef]
)
case tree: TypeDef =>
registerIfHasMacroAnnotations(tree)
val sym = tree.symbol
Expand Down
74 changes: 51 additions & 23 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,25 +240,12 @@ object Inferencing {
&& {
var fail = false
var skip = false
val direction = instDirection(tvar.origin)
if minimizeSelected then
if direction <= 0 && tvar.hasLowerBound then
skip = instantiate(tvar, fromBelow = true)
else if direction >= 0 && tvar.hasUpperBound then
skip = instantiate(tvar, fromBelow = false)
// else hold off instantiating unbounded unconstrained variable
else if direction != 0 then
skip = instantiate(tvar, fromBelow = direction < 0)
else if variance >= 0 && tvar.hasLowerBound then
skip = instantiate(tvar, fromBelow = true)
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
&& force.ifBottom == IfBottom.ok
then // if variance == 0, prefer upper bound if one is given
skip = instantiate(tvar, fromBelow = true)
else if variance >= 0 && force.ifBottom == IfBottom.fail then
fail = true
else
toMaximize = tvar :: toMaximize
instDecision(tvar, variance, minimizeSelected, force.ifBottom) match
case Decision.Min => skip = instantiate(tvar, fromBelow = true)
case Decision.Max => skip = instantiate(tvar, fromBelow = false)
case Decision.Skip => // hold off instantiating unbounded unconstrained variable
case Decision.Fail => fail = true
case Decision.ToMax => toMaximize ::= tvar
!fail && (skip || foldOver(x, tvar))
}
case tp => foldOver(x, tp)
Expand Down Expand Up @@ -455,6 +442,41 @@ object Inferencing {
approxAbove - approxBelow
}

/** The instantiation decision for given poly param computed from the constraint. */
enum Decision { case Min; case Max; case ToMax; case Skip; case Fail }
/** Decide how to instantiate `tvar`, which occurs with variance `v`:
 *  instantiate from below (`Min`), from above (`Max`), defer maximisation
 *  (`ToMax`), hold off entirely (`Skip`), or fail the IFD check (`Fail`).
 *  The constraint direction comes from `instDirection`; `ifBottom` governs
 *  the remaining unconstrained tvars. */
private def instDecision(tvar: TypeVar, v: Int, minimizeSelected: Boolean, ifBottom: IfBottom)(using Context): Decision =
import Decision.*
val direction = instDirection(tvar.origin)
val dec = if minimizeSelected then
if direction <= 0 && tvar.hasLowerBound then Min
else if direction >= 0 && tvar.hasUpperBound then Max
else Skip // hold off instantiating an unbounded unconstrained variable
else if direction != 0 then if direction < 0 then Min else Max
else if tvar.hasLowerBound then if v >= 0 then Min else ToMax
else ifBottom match
// What's left are unconstrained tvars with at most a non-Any param upperbound:
// * IfBottom.flip will always maximise to the param upperbound, for all variances
// * IfBottom.fail will fail the IFD check for covariant or invariant tvars, and maximise contravariant tvars
// * IfBottom.ok will minimise covariant and unbounded invariant tvars to Nothing, and maximise the others to Any
case IfBottom.ok => if v > 0 || v == 0 && !tvar.hasUpperBound then Min else ToMax // prefer upper bound if one is given
case IfBottom.fail => if v >= 0 then Fail else ToMax
case ifBottom_flip => ToMax
//println(i"instDecision($tvar, v=$v, minimizeSelected=$minimizeSelected, $ifBottom) dir=$direction = $dec")
dec

/** The interpolation variant of `instDecision` (with `IfBottom.fail` and
 *  `minimizeSelected = false`): a `Fail` outcome is converted into an
 *  instantiation direction based on `v`, and every other non-`Min` outcome
 *  (including `ToMax` and `Skip`) becomes `Max`. */
private def interpDecision(tvar: TypeVar, v: Int)(using Context): Decision =
import Decision.*
val dec = instDecision(tvar, v, minimizeSelected = false, IfBottom.fail) match
case Min => Min
case Fail => if v > 0 then Min else Max
// like IfBottom.ok,
// but only minimising unconstrained covariant tvars,
// which means that unconstrained unbounded tvars
// will be maximised to Any rather than minimised to Nothing
case _ => Max
//println(i"interpDecision($tvar, v=$v) = $dec")
dec

/** Following type aliases and stripping refinements and annotations, if one arrives at a
* class type reference where the class has a companion module, a reference to
* that companion module. Otherwise NoType
Expand Down Expand Up @@ -651,12 +673,13 @@ trait Inferencing { this: Typer =>

val ownedVars = state.ownedVars
if (ownedVars ne locked) && !ownedVars.isEmpty then
val qualifying = ownedVars -- locked
val qualifying = (ownedVars -- locked).toList
if (!qualifying.isEmpty) {
typr.println(i"interpolate $tree: ${tree.tpe.widen} in $state, pt = $pt, owned vars = ${state.ownedVars.toList}%, %, qualifying = ${qualifying.toList}%, %, previous = ${locked.toList}%, % / ${state.constraint}")
val resultAlreadyConstrained =
tree.isInstanceOf[Apply] || tree.tpe.isInstanceOf[MethodOrPoly]
if (!resultAlreadyConstrained)
trace(i"constrainResult($tree ${tree.symbol}, ${tree.tpe}, $pt)"):
constrainResult(tree.symbol, tree.tpe, pt)
// This is needed because it could establish singleton type upper bounds. See i2998.scala.

Expand Down Expand Up @@ -687,6 +710,10 @@ trait Inferencing { this: Typer =>

def constraint = state.constraint

trace(i"interpolateTypeVars($tree: ${tree.tpe}, $pt, $qualifying)", typr, (_: Any) => i"$qualifying\n$constraint\n${ctx.gadt}") {
//println(i"$constraint")
//println(i"${ctx.gadt}")

/** Values of this type report type variables to instantiate with variance indication:
* +1 variable appears covariantly, can be instantiated from lower bound
* -1 variable appears contravariantly, can be instantiated from upper bound
Expand Down Expand Up @@ -782,12 +809,10 @@ trait Inferencing { this: Typer =>
/** Try to instantiate `tvs`, return any suspended type variables */
def tryInstantiate(tvs: ToInstantiate): ToInstantiate = tvs match
case (hd @ (tvar, v)) :: tvs1 =>
val fromBelow = v == 1 || (v == 0 && tvar.hasLowerBound)
typr.println(
i"interpolate${if v == 0 then " non-occurring" else ""} $tvar in $state in $tree: $tp, fromBelow = $fromBelow, $constraint")
if tvar.isInstantiated then
tryInstantiate(tvs1)
else
val fromBelow = interpDecision(tvar, v) == Decision.Min
val suspend = tvs1.exists{ (following, _) =>
if fromBelow
then constraint.isLess(following.origin, tvar.origin)
Expand All @@ -797,13 +822,16 @@ trait Inferencing { this: Typer =>
typr.println(i"suspended: $hd")
hd :: tryInstantiate(tvs1)
else
typr.println(
i"interpolate${if v == 0 then " non-occurring" else ""} $tvar in $state in $tree: $tp, fromBelow = $fromBelow, $constraint")
tvar.instantiate(fromBelow)
tryInstantiate(tvs1)
case Nil => Nil
if tvs.nonEmpty then doInstantiate(tryInstantiate(tvs))
end doInstantiate

doInstantiate(filterByDeps(toInstantiate))
}
}
end if
tree
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
|""".stripMargin
)

@Ignore("Generic functions are not handled correctly.")
@Test def flatmap =
check(
"""|val _ : List[Int] = List().flatMap(_ => @@)
Expand All @@ -230,7 +229,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
|""".stripMargin
)

@Ignore("Generic functions are not handled correctly.")
@Test def map =
check(
"""|val _ : List[Int] = List().map(_ => @@)
Expand All @@ -239,7 +237,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
|""".stripMargin
)

@Ignore("Generic functions are not handled correctly.")
@Test def `for-comprehension` =
check(
"""|val _ : List[Int] =
Expand Down
10 changes: 5 additions & 5 deletions tests/neg-deep-subtype/i5877.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
object Main {
object Main { // error // error
def main(a: Array[String]): Unit = {
println("you may not run `testHasThisType` - just check that it compiles")
// comment lines after "// this line of code makes" comments to make it compilable again
Expand All @@ -18,25 +18,25 @@ object Main {

// ---- ---- ---- ----

def testHasThisType(): Unit = {
def testHasThisType(): Unit = { // error // error
def testSelf[PThis <: HasThisType[_ <: PThis]](that: HasThisType[PThis]): Unit = {
val thatSelf = that.self()
// that.self().type <: that.This
assert(implicitly[thatSelf.type <:< that.This] != null)
}
val that: HasThisType[_] = Foo() // null.asInstanceOf
testSelf(that) // error: recursion limit exceeded
testSelf(that) // error: recursion limit exceeded // error
}


def testHasThisType2(): Unit = {
def testHasThisType2(): Unit = { // error // error
def testSelf[PThis <: HasThisType[_ <: PThis]](that: PThis with HasThisType[PThis]): Unit = {
// that.type <: that.This
assert(implicitly[that.type <:< that.This] != null)
}
val that: HasThisType[_] = Foo() // null.asInstanceOf
// this line of code makes Dotty compiler infinite recursion (stopped only by overflow) - comment it to make it compilable again
testSelf(that) // error: recursion limit exceeded
testSelf(that) // error: recursion limit exceeded // error
}

// ---- ---- ---- ----
Expand Down
2 changes: 1 addition & 1 deletion tests/neg/recursive-lower-constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ class Bar extends Foo[Bar]

class A {
def foo[T <: Foo[T], U >: Foo[T] <: T](x: T): T = x
foo(new Bar) // error // error
foo(new Bar) // error
}
12 changes: 12 additions & 0 deletions tests/pos/i21390.TrieMap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Minimised from scala.collection.concurrent.LNode
// Useful as a minimisation of how,
// if we were to change the type interpolation
// to minimise to the inferred "X" type,
// the (ab)use of
// GADT constraints to handle class type params
// can fail PostTyper, -Ytest-pickler, and probably others.

import scala.language.experimental.captureChecking

class Foo[X](xs: List[X]):
def this(a: X, b: X) = this(if (a == b) then a :: Nil else a :: b :: Nil)

0 comments on commit 1d56a51

Please sign in to comment.