Skip to content

Commit

Permalink
Freeze GADTs more when comparing type member infos
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Sep 23, 2022
1 parent 63344e7 commit 9084be9
Show file tree
Hide file tree
Showing 12 changed files with 217 additions and 5 deletions.
11 changes: 8 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1923,7 +1923,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
|| symInfo.isInstanceOf[MethodType]
&& symInfo.signature.consistentParams(info2.signature)

def tp1IsSingleton: Boolean = tp1.isInstanceOf[SingletonType]
def allowGadt: Boolean =
def rec(tp: Type): Boolean = tp match
case RefinedType(parent, name1, _) => name == name1 || rec(parent)
case tp: TypeRef => tp.symbol.isClass
case _ => false
!approx.low && rec(tp1)

// A relaxed version of isSubType, which compares method types
// under the standard arrow rule which is contravarient in the parameter types,
Expand All @@ -1939,8 +1944,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
matchingMethodParams(info1, info2, precise = false)
&& isSubInfo(info1.resultType, info2.resultType.subst(info2, info1), symInfo1.resultType)
&& sigsOK(symInfo1, info2)
case _ => inFrozenGadtIf(tp1IsSingleton) { isSubType(info1, info2) }
case _ => inFrozenGadtIf(tp1IsSingleton) { isSubType(info1, info2) }
case _ => inFrozenGadtIf(!allowGadt) { isSubType(info1, info2) }
case _ => inFrozenGadtIf(!allowGadt) { isSubType(info1, info2) }

def qualifies(m: SingleDenotation): Boolean =
val info1 = m.info.widenExpr
Expand Down
78 changes: 76 additions & 2 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,14 @@ object QuoteMatcher {
if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole) &&
tpt2.tpe.derivesFrom(defn.RepeatedParamClass) =>
scrutinee match
case Typed(s, tpt1) if s.tpe <:< tpt.tpe => matched(scrutinee)
case Typed(s, tpt1) if patSub(s.tpe, tpt.tpe) => matched(scrutinee)
case _ => notMatched

/* Term hole */
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
case TypeApply(patternHole, tpt :: Nil)
if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole) &&
scrutinee.tpe <:< tpt.tpe =>
patSub(scrutinee.tpe, tpt.tpe) =>
scrutinee match
case ClosedPatternTerm(scrutinee) => matched(scrutinee)
case _ => notMatched
Expand Down Expand Up @@ -480,4 +480,78 @@ object QuoteMatcher {

}

def patSub(scr: Type, pat: Type)(using Context): Boolean =
val scrCls = scr.classSymbol
val patCls = pat.classSymbol
val upcPat = patCls.derivesFrom(scrCls)
val upcScr = scrCls.derivesFrom(patCls)
val tp = if upcScr then scr.refinedBaseType(patCls) else scr
val pt = if upcPat then pat.refinedBaseType(scrCls) else pat
tp <:< pt

import dotty.tools.dotc.*, core.*, cc.*, reporting.*, Decorators.*, SymDenotations.*
extension (tp: Type) def refinedBaseType(base: Symbol)(using Context): Type = base.denot match
case classd: ClassDenotation => classd.refinedBaseTypeOf(tp)
case _ => NoType

extension (classd: ClassDenotation) def refinedBaseTypeOf(tp: Type)(using Context): Type =
val symbol = classd.symbol
def foldGlb(bt: Type, ps: List[Type]): Type = ps match
case p :: ps1 => foldGlb(bt & recur(p), ps1)
case _ => bt
def recur(tp: Type): Type = trace(i"($tp).rbt($symbol)", show = true) {
val normed = tp.tryNormalize
if normed.exists then recur(normed) else tp match
case tp @ TypeRef(prefix, _) =>
val tpSym = tp.symbol
tpSym.denot match
case clsd: ClassDenotation =>
def isOwnThis = prefix match
case prefix: ThisType => prefix.cls eq clsd.owner
case NoPrefix => true
case _ => false
if tpSym eq symbol then tp
else if isOwnThis then
if clsd.derivesFrom(symbol) then
val base =
if symbol.isStatic && symbol.typeParams.isEmpty then symbol.typeRef
else foldGlb(NoType, clsd.info.parents)
// change 1
if base.exists then
val custom = clsd.info.decls.filter(_.name.isTypeName)
custom.foldRight(base)((sym, base) => RefinedType(base, sym.name, sym.info))
else NoType
else NoType
else recur(clsd.typeRef).asSeenFrom(prefix, clsd.owner)
case _ => recur(tp.superTypeNormalized)
case tp @ AppliedType(tycon, args) =>
if tycon.typeSymbol eq symbol then tp
else (tycon.typeParams: @unchecked) match
case LambdaParam(_, _) :: _ => recur(tp.superTypeNormalized)
case tparams: List[Symbol @unchecked] => recur(tycon).substApprox(tparams, args)
case tp: TypeParamRef => recur(TypeComparer.bounds(tp).hi)
case CapturingType(parent, refs) => tp.derivedCapturingType(recur(parent), refs)
case tp @ RefinedType(parent, name, info) =>
// change 2
val parent1 = recur(parent)
if parent1.exists then tp.derivedRefinedType(parent1, name, info)
else NoType
case tp: TypeProxy => recur(tp.superTypeNormalized)
case tp: AndOrType =>
val tp1 = tp.tp1
val tp2 = tp.tp2
if !tp.isAnd && tp1.isBottomType && (tp1 frozen_<:< tp2) then recur(tp2)
else if !tp.isAnd && tp2.isBottomType && (tp2 frozen_<:< tp1) then recur(tp1)
else
val baseTp1 = recur(tp1)
val baseTp2 = recur(tp2)
val combined = if tp.isAnd then baseTp1 & baseTp2 else baseTp1 | baseTp2
combined match
case combined: AndOrType
if (combined.tp1 eq tp1) && (combined.tp2 eq tp2) && (combined.isAnd == tp.isAnd) => tp
case _ => combined
case JavaArrayType(_) if symbol == defn.ObjectClass => classd.typeRef
case _ => NoType
}
recur(tp)
}
31 changes: 31 additions & 0 deletions tests/neg/i15485.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
enum SUB[-L, +R]:
case Refl[C]() extends SUB[C, C]

trait Tag { type T }

def foo[A, B, X <: Tag { type T <: A } ](
e: SUB[X, Tag { type T <: B }],
x: A,
): B = e match {
case SUB.Refl() =>
// SUB.Refl.unapply[?C](e)
// ?C >: X => cstr: C = X..Any
// ?C <: Tag { T = Nothing..B } => cstr: C = X..Tag { T = Nothing..B }
// SUB[Tag { T = Nothing..Int }, Tag { T = Nothing..String }]
// A = Int
// B = String
// X = Tag { T = Nothing..Nothing }
// X <: Tag { T = Nothing..A }
// SUB[X, Tag { T = Nothing..B }]
// SUB[Tag { T = Nothing..A }, Tag { T = Nothing..B }], approxLHS
// Tag { T = Nothing..A } <: C <: Tag { T = Nothing..B }]
// Tag { T = Nothing..A } <: Tag { T = Nothing..B }
// A <: B
x // error: Found: (x: A) Required: B
}

def bad(x: Int): String =
foo[Int, String, Tag { type T = Nothing }](SUB.Refl(), x) // cast Int to String

object Test:
def main(args: Array[String]): Unit = bad(1) // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String
22 changes: 22 additions & 0 deletions tests/neg/i15485b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
enum SUB[-A, +B]:
case Refl[C]() extends SUB[C, C]

trait Tag { type T }

def foo[L, H, X <: Tag { type T >: L <: H }](
e: SUB[X, Tag { type T = Int }],
x: Int,
): L = e match {
case SUB.Refl() =>
// X <: C and C <: Tag { T = Int }
// X <: Tag { T = Int }
// Tag { T >: L <: H } <: Tag { T = Int }
// Int <: L and H <: Int
x // error
}

def bad(x: Int): String =
foo[Nothing, Any, Tag { type T = Int }](SUB.Refl(), x) // cast Int to String!

object Test:
def main(args: Array[String]): Unit = bad(1) // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String
31 changes: 31 additions & 0 deletions tests/neg/i15485c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
enum SUB[-A, +B]:
case Refl[C]() extends SUB[C, C]

trait Tag { type T }

def foo[L](g: Tag { type T >: L <: Int })(
e: SUB[g.type, Tag { type T = Int }],
x: Int,
): L = e match {
case SUB.Refl() =>
// L = Nothing
// C = t
// g := Tag { T = Int..Int }
// g <: Tag { T = Nothing..Int }
// SUB[g, Tag { T = Int..Int }]
// SUB[Tag { T = Nothing..Int }, Tag { T = Int..Int }]
// SUB[Tag { T = L..Int }, Tag { T = Int..Int }] <:< SUB[C, C]
// Tag { T = L..Int } <: C <: Tag { T = Int..Int }]
// Tag { T = L..Int } <: Tag { T = Int..Int }
// Int <: L
x // error
}

def bad(x: Int): String =
val s: Tag { type T = Int } = new Tag { type T = Int }
val t: Tag { type T >: Nothing <: Int } & s.type = s
val e: SUB[t.type, Tag { type T = Int }] = SUB.Refl[t.type]()
foo[Nothing](t)(e, x) // cast Int to String!

object Test:
def main(args: Array[String]): Unit = bad(1) // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String
12 changes: 12 additions & 0 deletions tests/pos-macros/i15485.fallout2-monocle/Derivation_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Another minimisation (after tests/run-macros/i15485.fallout-monocle)
// of monocle's GenIsoSpec.scala
// which broke when fixing soundness in infering GADT constraints on refined types
class Can[T]
object Can:
import scala.deriving.*, scala.quoted.*

inline given derived[T](using inline m: Mirror.Of[T]): Can[T] = ${ impl('m) }

private def impl[T](m: Expr[Mirror.Of[T]])(using Quotes, Type[T]): Expr[Can[T]] = m match
case '{ $_ : Mirror.Sum { type MirroredElemTypes = met } } => '{ new Can[T] }
case '{ $_ : Mirror.Product { type MirroredElemTypes = met } } => '{ new Can[T] }
3 changes: 3 additions & 0 deletions tests/pos-macros/i15485.fallout2-monocle/Lib_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class Test:
def test =
Can.derived[EmptyTuple]
10 changes: 10 additions & 0 deletions tests/pos-macros/i15485.fallout3-monocle/Derivation_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import scala.deriving.*, scala.quoted.*

object Iso:
transparent inline def fields[S <: Product](using m: Mirror.ProductOf[S]): Int = ${ Impl.apply[S]('m) }

object Impl:
def apply[S <: Product](m: Expr[Mirror.ProductOf[S]])(using Quotes, Type[S]): Expr[Int] =
import quotes.reflect.*
m match
case '{ type a <: Tuple; $m: Mirror.ProductOf[S] { type MirroredElemTypes = `a` } } => '{ 1 }
4 changes: 4 additions & 0 deletions tests/pos-macros/i15485.fallout3-monocle/Lib_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class Test:
def test =
case object Foo
val iso = Iso.fields[Foo.type]
16 changes: 16 additions & 0 deletions tests/run-macros/i15485.fallout-monocle/Derivation_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import scala.deriving.*, scala.quoted.*

trait Foo[T]:
def foo: Int

// A minimisation of monocle's GenIsoSpec.scala
// which broke when fixing soundness in infering GADT constraints on refined types
object Foo:
inline given derived[T](using inline m: Mirror.Of[T]): Foo[T] = ${ impl('m) }

private def impl[T](m: Expr[Mirror.Of[T]])(using qctx: Quotes, tpe: Type[T]): Expr[Foo[T]] = m match
case '{ $m : Mirror.Product { type MirroredElemTypes = EmptyTuple } } => '{ FooN[T](0) }
case '{ $m : Mirror.Product { type MirroredElemTypes = a *: EmptyTuple } } => '{ FooN[T](1) }
case '{ $m : Mirror.Product { type MirroredElemTypes = mirroredElemTypes } } => '{ FooN[T](9) }

class FooN[T](val foo: Int) extends Foo[T]
1 change: 1 addition & 0 deletions tests/run-macros/i15485.fallout-monocle/Lib_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
final case class Box(value: Int) derives Foo
3 changes: 3 additions & 0 deletions tests/run-macros/i15485.fallout-monocle/Test_3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
@main def Test =
val foo = summon[Foo[Box]].foo
assert(foo == 1, foo)

0 comments on commit 9084be9

Please sign in to comment.