Skip to content

Fix #11178: remove unsound tweak for F-bounds in isInstanceOf check #11768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
* in which case the subtyping relationship "heals" the type.
*/
def constrainPatternType(pat: Type, scrut: Type): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
def constrainPatternType(pat: Type, scrut: Type, widenParams: Boolean = true): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {

def classesMayBeCompatible: Boolean = {
import Flags._
Expand Down Expand Up @@ -135,7 +135,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
case _ => NoType
}
if (upcasted.exists)
constrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted)
constrainSimplePatternType(pat, upcasted, widenParams) || constrainUpcasted(upcasted)
else true
}
}
Expand All @@ -155,7 +155,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
case pat: RefinedOrRecType =>
constrainPatternType(stripRefinement(pat), scrut)
case pat =>
constrainSimplePatternType(pat, scrut) || classesMayBeCompatible && constrainUpcasted(scrut)
constrainSimplePatternType(pat, scrut, widenParams) || classesMayBeCompatible && constrainUpcasted(scrut)
}
}
}
Expand Down Expand Up @@ -194,7 +194,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
* case classes without also appropriately extending the relevant case class
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
*/
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type): Boolean = {
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, widenParams: Boolean): Boolean = {
def refinementIsInvariant(tp: Type): Boolean = tp match {
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
case tp: TypeProxy => refinementIsInvariant(tp.underlying)
Expand All @@ -213,7 +213,8 @@ trait PatternTypeConstrainer { self: TypeComparer =>

val widePt =
if migrateTo3 || refinementIsInvariant(patternTp) then scrutineeTp
else widenVariantParams(scrutineeTp)
else if widenParams then widenVariantParams(scrutineeTp)
else scrutineeTp
val narrowTp = SkolemType(patternTp)
trace(i"constraining simple pattern type $narrowTp <:< $widePt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
isSubType(narrowTp, widePt)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2687,8 +2687,8 @@ object TypeComparer {
def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
comparing(_.dropTransparentTraits(tp, bound))

def constrainPatternType(pat: Type, scrut: Type)(using Context): Boolean =
comparing(_.constrainPatternType(pat, scrut))
def constrainPatternType(pat: Type, scrut: Type, widenParams: Boolean = true)(using Context): Boolean =
comparing(_.constrainPatternType(pat, scrut, widenParams))

def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:")(using Context): String =
comparing(_.explained(op, header))
Expand Down
59 changes: 23 additions & 36 deletions compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ object TypeTestsCasts {
*
* First do the following substitution:
* (a) replace `T @unchecked` and pattern binder types (e.g., `_$1`) in P with WildcardType
* (b) replace pattern binder types (e.g., `_$1`) in X:
* - variance = 1 : hiBound
* - variance = -1 : loBound
* - variance = 0 : OrType(Any, Nothing) // TODO: use original type param bounds
*
* Then check:
*
Expand Down Expand Up @@ -67,29 +63,6 @@ object TypeTestsCasts {
}
}.apply(tp)

def replaceX(tp: Type)(using Context) = new TypeMap {
def apply(tp: Type) = tp match {
case tref: TypeRef if tref.typeSymbol.isPatternBound =>
if (variance == 1) tref.info.hiBound
else if (variance == -1) tref.info.loBound
else OrType(defn.AnyType, defn.NothingType, soft = true) // TODO: what does this line do?
case _ => mapOver(tp)
}
}.apply(tp)

/** Approximate type parameters depending on variance */
def stripTypeParam(tp: Type)(using Context) = new ApproximatingTypeMap {
val boundTypeParams = util.HashMap[TypeRef, TypeVar]()
def apply(tp: Type): Type = tp match {
case _: MatchType =>
tp // break cycles
case tp: TypeRef if !tp.symbol.isClass =>
boundTypeParams.getOrElseUpdate(tp, newTypeVar(tp.underlying.toBounds))
case _ =>
mapOver(tp)
}
}.apply(tp)

/** Returns true if the type arguments of `P` can be determined from `X` */
def typeArgsTrivial(X: Type, P: AppliedType)(using Context) = inContext(ctx.fresh.setExploreTyperState().setFreshGADTBounds) {
val AppliedType(tycon, _) = P
Expand All @@ -102,27 +75,43 @@ object TypeTestsCasts {
val tvars = constrained(typeLambda, untpd.EmptyTree, alwaysAddTypeVars = true)._2.map(_.tpe)
val P1 = tycon.appliedTo(tvars)

debug.println("before " + ctx.typerState.constraint.show)
debug.println("P : " + P.show)
debug.println("P1 : " + P1.show)
debug.println("X : " + X.show)

// It does not matter if P1 is not a subtype of X.
// It does not matter whether P1 is a subtype of X or not.
// It just tries to infer type arguments of P1 from X if the value x
// conforms to the type skeleton pre.F[_]. Then it goes on to check
// if P1 <: P, which means the type arguments in P are trivial,
// thus no runtime checks are needed for them.
P1 <:< X
withMode(Mode.GadtConstraintInference) {
// Why not widen type arguments here? Given the following program
//
// trait Tree[-T] class Ident[-T] extends Tree[T] def foo1(tree:
// Tree[Int]) = tree.isInstanceOf[Ident[Int]]
//
// In checking whether the test tree.isInstanceOf[Ident[Int]]
// is realizable, we want to constrain Ident[X] <: Tree[Int],
// such that we can infer X = Int and Ident[X] <:< Ident[Int].
//
// If we perform widening, we will get X = Nothing, and we don't have
// Ident[X] <:< Ident[Int] any more.
TypeComparer.constrainPatternType(P1, X, widenParams = false)
debug.println(TypeComparer.explained(_.constrainPatternType(P1, X, widenParams = false)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I understand the intent here, but it's a bit surprising that we don't widen the params, given that doing so is necessary for soundness. What would be a motivating example for not widening them here?

Copy link
Contributor Author

@liufengyun liufengyun Mar 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The motivation for not widening is something specific to the algorithm: here for P1 = pre.F[Xs], all Xs are type variables. We want the Xs to be constrained.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel very happy the work on the GADT makes the checking so elegant that we remove a lot of ad-hoc tweaks. The fact that the algorithm is simpler and it found more bugs in the compiler gives more confidence.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One example is the following:

  // tests/neg-custom-args/isInstanceOf/JavaSeqLiteral.scala
  trait Tree[-T]

  class JavaSeqLiteral[-T] extends Tree[T]

  trait Type

  class DummyTree extends JavaSeqLiteral[Any]

  def foo1(tree: Tree[Type]) =
    tree.isInstanceOf[JavaSeqLiteral[Type]]

In checking whether the test tree.isInstanceOf[JavaSeqLiteral[Type]] is realizable, we want to constrain JavaSeqLiteral[X] <: Tree[Type], such that we can infer X = Type and JavaSeqLiteral[X] <:< JavaSeqLiteral[Type].

If we perform widening, we will get X = Nothing, and we don't have JavaSeqLiteral[X] <:< JavaSeqLiteral[Type] any more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think what we're doing here is sound, but with small changes we could easily make it unsound. I'll need to add a comment explaining what's going on, and then we can merge the PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll add a comment explaining the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment added in aa296ba

}

// Maximization of the type means we try to cover all possible values
// which conform to the skeleton pre.F[_] and X. Then we have to make
// sure all of them are actually of the type P, which implies that the
// type arguments in P are trivial (no runtime check needed).
maximizeType(P1, span, fromScala2x = false)

debug.println("after " + ctx.typerState.constraint.show)

val res = P1 <:< P

debug.println(TypeComparer.explained(_.isSubType(P1, P)))

debug.println("P1 : " + P1.show)
debug.println("P1 <:< P = " + res)

Expand All @@ -140,7 +129,7 @@ object TypeTestsCasts {
case _ => recur(defn.AnyType, tpT)
}
case tpe: AppliedType =>
X.widen match {
X.widenDealias match {
case OrType(tp1, tp2) =>
// This case is required to retrofit type inference,
// which cut constraints in the following two cases:
Expand All @@ -151,10 +140,8 @@ object TypeTestsCasts {
case _ =>
// always false test warnings are emitted elsewhere
X.classSymbol.exists && P.classSymbol.exists &&
!X.classSymbol.asClass.mayHaveCommonChild(P.classSymbol.asClass) ||
// first try without striping type parameters for performance
typeArgsTrivial(X, tpe) ||
typeArgsTrivial(stripTypeParam(X), tpe)
!X.classSymbol.asClass.mayHaveCommonChild(P.classSymbol.asClass)
|| typeArgsTrivial(X, tpe)
}
case AndType(tp1, tp2) => recur(X, tp1) && recur(X, tp2)
case OrType(tp1, tp2) => recur(X, tp1) && recur(X, tp2)
Expand All @@ -163,7 +150,7 @@ object TypeTestsCasts {
case _ => true
})

val res = recur(replaceX(X.widen), replaceP(P))
val res = recur(X.widen, replaceP(P))

debug.println(i"checking ${X.show} isInstanceOf ${P} = $res")

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
tpd.Closure(meth, tss => xCheckMacroedOwners(xCheckMacroValidExpr(rhsFn(meth, tss.head.map(withDefaultPos))), meth))

def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
case Block((ddef @ DefDef(_, TermParamClause(params) :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
case Block((ddef @ DefDef(_, tpd.ValDefs(params) :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
if ddef.symbol == meth.symbol =>
Some((params, body))
case _ => None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class StringInterpolationPositionTest extends ParserTest {
def interpolationLiteralPosition: Unit = {
val t = parseText(program)
t match {
case PackageDef(_, List(TypeDef(_, Template(_, _, _, statements: List[Tree])))) => {
val interpolations = statements.collect{ case ValDef(_, _, InterpolatedString(_, int)) => int }
case PackageDef(_, List(TypeDef(_, tpl: Template))) => {
val interpolations = tpl.body.collect{ case ValDef(_, _, InterpolatedString(_, int)) => int }
val lits = interpolations.flatten.flatMap {
case l @ Literal(_) => List(l)
case Thicket(trees) => trees.collect { case l @ Literal(_) => l }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class DocRender(signatureRenderer: SignatureRenderer)(using DocContext):
case md: MdNode => renderMarkdown(md)
case Nil => raw("")
case Seq(elem: WikiDocElement) => renderElement(elem)
case list: Seq[WikiDocElement] => div(list.map(renderElement))
case list: Seq[WikiDocElement @unchecked] => div(list.map(renderElement))

private def renderMarkdown(el: MdNode): AppliedTag =
raw(DocFlexmarkRenderer.render(el)( (link,name) =>
Expand Down
18 changes: 18 additions & 0 deletions tests/neg-custom-args/isInstanceOf/html.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
object HTML:
type AttrArg = AppliedAttr | Seq[AppliedAttr]
opaque type AppliedAttr = String
opaque type AppliedTag = StringBuilder

case class Tag(name: String):
def apply(attrs: AttrArg*): AppliedTag = {
val sb = StringBuilder()
sb.append(s"<$name")
attrs.filter(_ != Nil).foreach{
case s: Seq[AppliedAttr] =>
s.foreach(sb.append(" ").append)
case s: Seq[Int] => // error
case e: AppliedAttr =>
sb.append(" ").append(e)
}
sb
}
39 changes: 39 additions & 0 deletions tests/neg-custom-args/isInstanceOf/i11178.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
trait Box[+T]
case class Foo[+S](s: S) extends Box[S]

def unwrap2[A](b: Box[A]): A =
b match
case _: Foo[Int] => 0 // error

object Test1 {
// Invariant case, OK
sealed trait Bar[A]

def test[A](bar: Bar[A]) =
bar match {
case _: Bar[Boolean] => ??? // error
case _ => ???
}
}

object Test2 {
// Covariant case
sealed trait Bar[+A]

def test[A](bar: Bar[A]) =
bar match {
case _: Bar[Boolean] => ??? // error
case _ => ???
}
}

object Test3 {
// Contravariant case
sealed trait Bar[-A]

def test[A](bar: Bar[A]) =
bar match {
case _: Bar[Boolean] => ??? // error
case _ => ???
}
}