Skip to content

Commit 2ea400a

Browse files
authored
Merge pull request #15544 from dwijnand/gadt/unsound-cast
Use GADT constraints in maximiseType
2 parents 6efd92d + 32826d8 commit 2ea400a

File tree

9 files changed

+103
-32
lines changed

9 files changed

+103
-32
lines changed

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ sealed abstract class GadtConstraint extends Showable {
4949
/** See [[ConstraintHandling.approximation]] */
5050
def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type
5151

52+
def symbols: List[Symbol]
53+
5254
def fresh: GadtConstraint
5355

5456
/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
@@ -193,12 +195,7 @@ final class ProperGadtConstraint private(
193195
case null => null
194196
// TODO: Improve flow typing so that ascription becomes redundant
195197
case tv: TypeVar =>
196-
def retrieveBounds: TypeBounds =
197-
bounds(tv.origin) match {
198-
case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) =>
199-
TypeAlias(reverseMapping(tpr).nn.typeRef)
200-
case tb => tb
201-
}
198+
def retrieveBounds: TypeBounds = externalize(bounds(tv.origin)).bounds
202199
retrieveBounds
203200
//.showing(i"gadt bounds $sym: $result", gadts)
204201
//.ensuring(containsNoInternalTypes(_))
@@ -222,6 +219,8 @@ final class ProperGadtConstraint private(
222219
res
223220
}
224221

222+
override def symbols: List[Symbol] = mapping.keys
223+
225224
override def fresh: GadtConstraint = new ProperGadtConstraint(
226225
myConstraint,
227226
mapping,
@@ -247,13 +246,7 @@ final class ProperGadtConstraint private(
247246
override protected def isSame(tp1: Type, tp2: Type)(using Context): Boolean = TypeComparer.isSameType(tp1, tp2)
248247

249248
override def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds =
250-
val externalizeMap = new TypeMap {
251-
def apply(tp: Type): Type = tp match {
252-
case tpr: TypeParamRef => externalize(tpr)
253-
case tp => mapOver(tp)
254-
}
255-
}
256-
externalizeMap(constraint.nonParamBounds(param)).bounds
249+
externalize(constraint.nonParamBounds(param)).bounds
257250

258251
override def fullLowerBound(param: TypeParamRef)(using Context): Type =
259252
constraint.minLower(param).foldLeft(nonParamBounds(param).lo) {
@@ -270,27 +263,28 @@ final class ProperGadtConstraint private(
270263

271264
// ---- Private ----------------------------------------------------------
272265

273-
private def externalize(param: TypeParamRef)(using Context): Type =
274-
reverseMapping(param) match {
266+
private def externalize(tp: Type, theMap: TypeMap | Null = null)(using Context): Type = tp match
267+
case param: TypeParamRef => reverseMapping(param) match
275268
case sym: Symbol => sym.typeRef
276-
case null => param
277-
}
269+
case null => param
270+
case tp: TypeAlias => tp.derivedAlias(externalize(tp.alias, theMap))
271+
case tp => (if theMap == null then ExternalizeMap() else theMap).mapOver(tp)
272+
273+
private class ExternalizeMap(using Context) extends TypeMap:
274+
def apply(tp: Type): Type = externalize(tp, this)(using mapCtx)
278275

279276
private def tvarOrError(sym: Symbol)(using Context): TypeVar =
280277
mapping(sym).ensuring(_ != null, i"not a constrainable symbol: $sym").uncheckedNN
281278

282-
private def containsNoInternalTypes(
283-
tp: Type,
284-
acc: TypeAccumulator[Boolean] | Null = null
285-
)(using Context): Boolean = tp match {
279+
private def containsNoInternalTypes(tp: Type, theAcc: TypeAccumulator[Boolean] | Null = null)(using Context): Boolean = tp match {
286280
case tpr: TypeParamRef => !reverseMapping.contains(tpr)
287281
case tv: TypeVar => !reverseMapping.contains(tv.origin)
288282
case tp =>
289-
(if (acc != null) acc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp)
283+
(if (theAcc != null) theAcc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp)
290284
}
291285

292286
private class ContainsNoInternalTypesAccumulator(using Context) extends TypeAccumulator[Boolean] {
293-
override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp)
287+
override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp, this)
294288
}
295289

296290
// ---- Debug ------------------------------------------------------------
@@ -325,6 +319,8 @@ final class ProperGadtConstraint private(
325319

326320
override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")
327321

322+
override def symbols: List[Symbol] = Nil
323+
328324
override def fresh = new ProperGadtConstraint
329325
override def restore(other: GadtConstraint): Unit =
330326
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1354,7 +1354,7 @@ trait Applications extends Compatibility {
13541354
// Constraining only fails if the pattern cannot possibly match,
13551355
// but useless pattern checks detect more such cases, so we simply rely on them instead.
13561356
withMode(Mode.GadtConstraintInference)(TypeComparer.constrainPatternType(unapplyArgType, selType))
1357-
val patternBound = maximizeType(unapplyArgType, tree.span)
1357+
val patternBound = maximizeType(unapplyArgType, unapplyFn.span.endPos)
13581358
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)
13591359
unapp.println(i"case 2 $unapplyArgType ${ctx.typerState.constraint}")
13601360
unapplyArgType

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ import collection.mutable
1616

1717
import scala.annotation.internal.sharable
1818

19-
import config.Printers.gadts
20-
2119
object Inferencing {
2220

2321
import tpd._
@@ -408,10 +406,15 @@ object Inferencing {
408406
Stats.record("maximizeType")
409407
val vs = variances(tp)
410408
val patternBindings = new mutable.ListBuffer[(Symbol, TypeParamRef)]
409+
val gadtBounds = ctx.gadt.symbols.map(ctx.gadt.bounds(_).nn)
411410
vs foreachBinding { (tvar, v) =>
412411
if !tvar.isInstantiated then
413-
if (v == 1) tvar.instantiate(fromBelow = false)
414-
else if (v == -1) tvar.instantiate(fromBelow = true)
412+
// if the tvar is covariant/contravariant (v == 1/-1, respectively) in the input type tp
413+
// then it is safe to instantiate if it doesn't occur in any of the GADT bounds.
414+
// Eg neg/i14983 the C in Node[+C] occurs in GADT bound X >: List[C] so maximising to Node[Any] is unsound
415+
// Eg pos/precise-pattern-type the T in Tree[-T] doesn't occur in any GADT bound so can maximise to Tree[Type]
416+
val safeToInstantiate = v != 0 && gadtBounds.forall(!tvar.occursIn(_))
417+
if safeToInstantiate then tvar.instantiate(fromBelow = v == -1)
415418
else {
416419
val bounds = TypeComparer.fullBounds(tvar.origin)
417420
if bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) then

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3764,9 +3764,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
37643764
res
37653765
} =>
37663766
// Insert an explicit cast, so that -Ycheck in later phases succeeds.
3767-
// I suspect, but am not 100% sure that this might affect inferred types,
3768-
// if the expected type is a supertype of the GADT bound. It would be good to come
3769-
// up with a test case for this.
3767+
// The check "safeToInstantiate" in `maximizeType` works to prevent unsound GADT casts.
37703768
val target =
37713769
if tree.tpe.isSingleton then
37723770
val conj = AndType(tree.tpe, pt)

tests/neg/i14983.co-contra.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
case class Showing[-C](show: C => String)
2+
3+
sealed trait Tree[+A]
4+
final case class Leaf[+B](b: B) extends Tree[B]
5+
final case class Node[-C](l: Showing[C]) extends Tree[Showing[C]]
6+
7+
object Test:
8+
def meth[X](tree: Tree[X]): X = tree match
9+
case Leaf(v) => v
10+
case Node(x) =>
11+
// tree: Tree[X] vs Node[C] aka Tree[Showing[C]]
12+
// PTC: X >: Showing[C]
13+
// max: Node[C] to Node[Nothing], instantiating C := Nothing, which makes X >: Showing[Nothing]
14+
// adapt: Showing[String] <: X = OKwithGADTUsed; insert GADT cast asInstanceOf[X]
15+
Showing[String](_ + " boom!") // error: Found: Showing[String] Required: X where: X is a type in method meth with bounds >: Showing[C$1]
16+
// after fix:
17+
// max: Node[C] => Node[C$1], instantiating C := C$1, a new symbol, so X >: Showing[C$1]
18+
// adapt: Showing[String] <: X = Fail, because String !<: C$1
19+
20+
def main(args: Array[String]): Unit =
21+
val tree = Node(Showing[Int](_.toString))
22+
val res = meth(tree)
23+
println(res.show(42)) // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String

tests/neg/i14983.contra.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
sealed trait Show[-A]
2+
final case class Pure[-B](showB: B => String) extends Show[B]
3+
final case class Many[-C](showL: List[C] => String) extends Show[List[C]]
4+
5+
object Test:
6+
def meth[X](show: Show[X]): X => String = show match
7+
case Pure(showB) => showB
8+
case Many(showL) =>
9+
val res = (xs: List[String]) => xs.head.length.toString
10+
res // error: Found: List[String] => String Required: X => String where: X is a type in method meth with bounds <: List[C$1]
11+
12+
def main(args: Array[String]): Unit =
13+
val show = Many((is: List[Int]) => (is.head + 1).toString)
14+
val fn = meth(show)
15+
assert(fn(List(42)) == "43") // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String

tests/neg/i14983.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
sealed trait Tree[+A]
2+
final case class Leaf[+B](b: B) extends Tree[B]
3+
final case class Node[+C](l: List[C]) extends Tree[List[C]]
4+
5+
// The original test case, minimised.
6+
object Test:
7+
def meth[X](tree: Tree[X]): X = tree match
8+
case Leaf(v) => v // ok: Tree[X] vs Leaf[B], PTC: X >: B, max: Leaf[B] => Leaf[X], x: X
9+
case Node(x) =>
10+
// tree: Tree[X] vs Node[C] aka Tree[List[C]]
11+
// PTC: X >: List[C]
12+
// max: Node[C] => Node[Any], instantiating C := Any, which makes X >: List[Any]
13+
// adapt: List[String] <: X = OKwithGADTUsed; insert GADT cast asInstanceOf[X]
14+
List("boom") // error: Found: List[String] Required: X where: X is a type in method meth with bounds >: List[C$1]
15+
// after fix:
16+
// max: Node[C] => Node[C$1], instantiating C := C$1, a new symbol, so X >: List[C$1]
17+
// adapt: List[String] <: X = Fail, because String !<: C$1
18+
19+
def main(args: Array[String]): Unit =
20+
val tree = Node(List(42))
21+
val res = meth(tree)
22+
assert(res.head == 42) // was: ClassCastException: class java.lang.String cannot be cast to class java.lang.Integer

tests/run/i14983.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
sealed trait Tree[+A]
2+
final case class Leaf[+B](b: B) extends Tree[B]
3+
final case class Node[+C](l: List[C]) extends Tree[List[C]]
4+
5+
// A version of the original test case that is sound so should typecheck.
6+
object Test:
7+
def meth[X](tree: Tree[X]): X = tree match
8+
case Leaf(v) => v // ok: Tree[X] vs Leaf[B], PTC: X >: B, max: Leaf[B] => Leaf[X], x: X <:< X
9+
case Node(x) => x // ok: Tree[X] vs Node[C], PTC: X >: List[C], max: Node[C] => Node[C$1], x: C$1 <:< X, w/ GADT cast
10+
11+
def main(args: Array[String]): Unit =
12+
val tree = Node(List(42))
13+
val res = meth(tree)
14+
assert(res.head == 42) // ok

0 commit comments

Comments
 (0)