Skip to content

Commit

Permalink
Merge pull request #15134 from dwijnand/local-classes-are-uncheckable
Browse files Browse the repository at this point in the history
Local classes are uncheckable (type tests)
  • Loading branch information
odersky authored Jul 6, 2022
2 parents 6a62bb7 + 3ca5b70 commit 9d07d52
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 15 deletions.
2 changes: 1 addition & 1 deletion community-build/community-projects/fs2
35 changes: 21 additions & 14 deletions compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ object TypeTestsCasts {
import typer.Inferencing.maximizeType
import typer.ProtoTypes.constrained

/** Whether `(x:X).isInstanceOf[P]` can be checked at runtime?
/** Whether `(x: X).isInstanceOf[P]` can be checked at runtime?
*
* First do the following substitution:
* (a) replace `T @unchecked` and pattern binder types (e.g., `_$1`) in P with WildcardType
Expand All @@ -48,7 +48,8 @@ object TypeTestsCasts {
* (c) maximize `pre.F[Xs]` and check `pre.F[Xs] <:< P`
* 6. if `P = T1 | T2` or `P = T1 & T2`, checkable(X, T1) && checkable(X, T2).
* 7. if `P` is a refinement type, FALSE
* 8. otherwise, TRUE
* 8. if `P` is a local class which is not statically reachable from the scope where `X` is defined, FALSE
* 9. otherwise, TRUE
*/
def checkable(X: Type, P: Type, span: Span)(using Context): Boolean = atPhase(Phases.refchecksPhase.next) {
// Run just before ElimOpaque transform (which follows RefChecks)
Expand Down Expand Up @@ -152,10 +153,13 @@ object TypeTestsCasts {
case AnnotatedType(t, _) => recur(X, t)
case tp2: RefinedType => recur(X, tp2.parent) && TypeComparer.hasMatchingMember(tp2.refinedName, X, tp2)
case tp2: RecType => recur(X, tp2.parent)
case _
if P.classSymbol.isLocal && foundClasses(X).exists(P.classSymbol.isInaccessibleChildOf) => // 8
false
case _ => true
})

val res = recur(X.widen, replaceP(P))
val res = X.widenTermRefExpr.hasAnnotation(defn.UncheckedAnnot) || recur(X.widen, replaceP(P))

debug.println(i"checking ${X.show} isInstanceOf ${P} = $res")

Expand All @@ -174,15 +178,6 @@ object TypeTestsCasts {
def derivedTree(expr1: Tree, sym: Symbol, tp: Type) =
cpy.TypeApply(tree)(expr1.select(sym).withSpan(expr.span), List(TypeTree(tp)))

def effectiveClass(tp: Type): Symbol =
if tp.isRef(defn.PairClass) then effectiveClass(erasure(tp))
else if tp.isRef(defn.AnyValClass) then defn.AnyClass
else tp.classSymbol

def foundClasses(tp: Type, acc: List[Symbol]): List[Symbol] = tp.dealias match
case OrType(tp1, tp2) => foundClasses(tp2, foundClasses(tp1, acc))
case _ => effectiveClass(tp) :: acc

def inMatch =
tree.fun.symbol == defn.Any_typeTest || // new scheme
expr.symbol.is(Case) // old scheme
Expand Down Expand Up @@ -251,7 +246,7 @@ object TypeTestsCasts {
if expr.tpe.isBottomType then
report.warning(TypeTestAlwaysDiverges(expr.tpe, testType), tree.srcPos)
val nestedCtx = ctx.fresh.setNewTyperState()
val foundClsSyms = foundClasses(expr.tpe.widen, Nil)
val foundClsSyms = foundClasses(expr.tpe.widen)
val sensical = checkSensical(foundClsSyms)(using nestedCtx)
if (!sensical) {
nestedCtx.typerState.commit()
Expand All @@ -272,7 +267,7 @@ object TypeTestsCasts {
def transformAsInstanceOf(testType: Type): Tree = {
def testCls = effectiveClass(testType.widen)
def foundClsSymPrimitive = {
val foundClsSyms = foundClasses(expr.tpe.widen, Nil)
val foundClsSyms = foundClasses(expr.tpe.widen)
foundClsSyms.size == 1 && foundClsSyms.head.isPrimitiveValueClass
}
if (erasure(expr.tpe) <:< testType)
Expand Down Expand Up @@ -372,4 +367,16 @@ object TypeTestsCasts {
}
interceptWith(expr)
}

private def effectiveClass(tp: Type)(using Context): Symbol =
if tp.isRef(defn.PairClass) then effectiveClass(erasure(tp))
else if tp.isRef(defn.AnyValClass) then defn.AnyClass
else tp.classSymbol

private[transform] def foundClasses(tp: Type)(using Context): List[Symbol] =
def go(tp: Type, acc: List[Type])(using Context): List[Type] = tp.dealias match
case OrType(tp1, tp2) => go(tp2, go(tp1, acc))
case AndType(tp1, tp2) => (for t1 <- go(tp1, Nil); t2 <- go(tp2, Nil); yield AndType(t1, t2)) ::: acc
case _ => tp :: acc
go(tp, Nil).map(effectiveClass)
}
45 changes: 45 additions & 0 deletions compiler/test/dotty/tools/dotc/transform/TypeTestsCastsTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package dotty.tools
package dotc
package transform

import core.*
import Contexts.*, Decorators.*, Denotations.*, SymDenotations.*, Symbols.*, Types.*
import Annotations.*

import org.junit.Test
import org.junit.Assert.*

class TypeTestsCastsTest extends DottyTest:
val defn = ctx.definitions; import defn.*

@Test def orL = checkFound(List(StringType, LongType), OrType(LongType, StringType, false))
@Test def orR = checkFound(List(LongType, StringType), OrType(StringType, LongType, false))

@Test def annot = checkFound(List(StringType, LongType), AnnotatedType(OrType(LongType, StringType, false), Annotation(defn.UncheckedAnnot)))

@Test def andL = checkFound(List(StringType), AndType(StringType, AnyType))
@Test def andR = checkFound(List(StringType), AndType(AnyType, StringType))
@Test def andX = checkFound(List(NoType), AndType(StringType, BooleanType))

// (A | B) & C => {(A & B), (A & C)}
// A & (B | C) => {(A & B), (A & C)}
// (A | B) & (C | D) => {(A & C), (A & D), (B & C), (B & D)}
@Test def orInAndL = checkFound(List(StringType, LongType), AndType(OrType(LongType, StringType, false), AnyType))
@Test def orInAndR = checkFound(List(StringType, LongType), AndType(AnyType, OrType(LongType, StringType, false)))
@Test def orInAndZ =
// (Throwable | Exception) & (RuntimeException | Any) =
// Throwable & RuntimeException = RuntimeException
// Throwable & Any = Throwable
// Exception & RuntimeException = RuntimeException
// Exception & Any = Exception
val ExceptionType = defn.ExceptionClass.typeRef
val RuntimeExceptionType = defn.RuntimeExceptionClass.typeRef
val tp = AndType(OrType(ThrowableType, ExceptionType, false), OrType(RuntimeExceptionType, AnyType, false))
val exp = List(ExceptionType, RuntimeExceptionType, ThrowableType, RuntimeExceptionType)
checkFound(exp, tp)

def checkFound(found: List[Type], tp: Type) =
val expected = found.map(_.classSymbol)
val obtained = TypeTestsCasts.foundClasses(tp)
assertEquals(expected, obtained)
end TypeTestsCastsTest
28 changes: 28 additions & 0 deletions tests/neg/i4812.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-- Error: tests/neg/i4812.scala:8:11 -----------------------------------------------------------------------------------
8 | case prev: A => // error: the type test for A cannot be checked at runtime
| ^
| the type test for A cannot be checked at runtime
-- Error: tests/neg/i4812.scala:18:11 ----------------------------------------------------------------------------------
18 | case prev: A => // error: the type test for A cannot be checked at runtime
| ^
| the type test for A cannot be checked at runtime
-- Error: tests/neg/i4812.scala:28:11 ----------------------------------------------------------------------------------
28 | case prev: A => // error: the type test for A cannot be checked at runtime
| ^
| the type test for A cannot be checked at runtime
-- Error: tests/neg/i4812.scala:38:11 ----------------------------------------------------------------------------------
38 | case prev: A => // error: the type test for A cannot be checked at runtime
| ^
| the type test for A cannot be checked at runtime
-- Error: tests/neg/i4812.scala:50:13 ----------------------------------------------------------------------------------
50 | case prev: A => // error: the type test for A cannot be checked at runtime
| ^
| the type test for A cannot be checked at runtime
-- Error: tests/neg/i4812.scala:60:11 ----------------------------------------------------------------------------------
60 | case prev: A => // error: the type test for A cannot be checked at runtime
| ^
| the type test for A cannot be checked at runtime
-- Error: tests/neg/i4812.scala:96:11 ----------------------------------------------------------------------------------
96 | case x: B => // error: the type test for B cannot be checked at runtime
| ^
| the type test for B cannot be checked at runtime
119 changes: 119 additions & 0 deletions tests/neg/i4812.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// scalac: -Werror
object Test:
var prev: Any = _

def test[T](x: T): T =
class A(val elem: (T, Boolean))
prev match
case prev: A => // error: the type test for A cannot be checked at runtime
prev.elem._1
case _ =>
prev = new A((x, true))
x

def test2[T](x: T): T =
abstract class Parent(_elem: T) { def elem: T = _elem }
class A extends Parent(x)
prev match
case prev: A => // error: the type test for A cannot be checked at runtime
prev.elem
case _ =>
prev = new A
x

def test3[T](x: T): T =
class Holder(val elem: T)
class A(val holder: Holder)
prev match
case prev: A => // error: the type test for A cannot be checked at runtime
prev.holder.elem
case _ =>
prev = new A(new Holder(x))
x

def test4[T](x: T): T =
class Holder(val elem: (Int, (Unit, (T, Boolean))))
class A { var holder: Holder = null }
prev match
case prev: A => // error: the type test for A cannot be checked at runtime
prev.holder.elem._2._2._1
case _ =>
val a = new A
a.holder = new Holder((42, ((), (x, true))))
prev = a
x

class Foo[U]:
def test5(x: U): U =
class A(val elem: U)
prev match
case prev: A => // error: the type test for A cannot be checked at runtime
prev.elem
case _ =>
prev = new A(x)
x

def test6[T](x: T): T =
class A { var b: B = null }
class B { var a: A = null; var elem: T = _ }
prev match
case prev: A => // error: the type test for A cannot be checked at runtime
prev.b.elem
case _ =>
val a = new A
val b = new B
b.elem = x
a.b = b
prev = a
x

def test7[T](x: T): T =
class A(val elem: T)
prev match
case prev: A @unchecked => prev.elem
case _ => prev = new A(x); x

def test8[T](x: T): T =
class A(val elem: T)
val p = prev
(p: @unchecked) match
case prev: A => prev.elem
case _ => prev = new A(x); x

def test9 =
trait A
class B extends A
val x: A = new B
x match
case x: B => x

sealed class A
var prevA: A = _
def test10: A =
val methodCallId = System.nanoTime()
class B(val id: Long) extends A
prevA match
case x: B => // error: the type test for B cannot be checked at runtime
x.ensuring(x.id == methodCallId, s"Method call id $methodCallId != ${x.id}")
case _ =>
val x = new B(methodCallId)
prevA = x
x

def test11 =
trait A
trait B
class C extends A with B
val x: A = new C
x match
case x: B => x

def test12 =
class Foo
class Bar
val x: Foo | Bar = new Foo
x.isInstanceOf[Foo]

def main(args: Array[String]): Unit =
test(1)
val x: String = test("") // was: ClassCastException: java.lang.Integer cannot be cast to java.lang.String

0 comments on commit 9d07d52

Please sign in to comment.