diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index 1392ed688959..7488e9e35471 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -235,8 +235,7 @@ object Inferencing { * approx, see gadt-approximation-interaction.scala). */ def apply(tp: Type): Type = tp.dealias match { - case tp @ TypeRef(qual, nme) if (qual eq NoPrefix) - && variance != 0 + case tp @ TypeRef(qual, nme) if variance != 0 && ctx.gadt.contains(tp.symbol) => val sym = tp.symbol diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index b9eff97f1a5c..45f806d47d7e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2030,6 +2030,15 @@ class Typer extends Namer val tparamss = paramss1.collect { case untpd.TypeDefs(tparams) => tparams } + + // Register GADT constraint for class type parameters from outer to inner class definition. (Useful when nested classes exist.) But do not cross a function definition. + if sym.flags.is(Method) then + rhsCtx.setFreshGADTBounds + ctx.outer.outersIterator.takeWhile(!_.owner.is(Method)) + .filter(ctx => ctx.owner.isClass && ctx.owner.typeParams.nonEmpty) + .toList.reverse + .foreach(ctx => rhsCtx.gadt.addToConstraint(ctx.owner.typeParams)) + if tparamss.nonEmpty then rhsCtx.setFreshGADTBounds val tparamSyms = tparamss.flatten.map(_.symbol) diff --git a/tests/pos/class-gadt/basic.scala b/tests/pos/class-gadt/basic.scala new file mode 100644 index 000000000000..620144c4fd91 --- /dev/null +++ b/tests/pos/class-gadt/basic.scala @@ -0,0 +1,14 @@ +object basic { + enum Expr[A] { + case IntExpr(value: Int) extends Expr[Int] + case Other[T](value: T) extends Expr[T] + } + + class C[A] { + def eval(e: Expr[A]): A = + e match { + case Expr.IntExpr(i) => i + 2 + case Expr.Other(v) => v + } + } +} diff --git a/tests/pos/class-gadt/inheritance.scala b/tests/pos/class-gadt/inheritance.scala new file mode 100644 index 000000000000..0c6639abfa3a --- /dev/null +++ b/tests/pos/class-gadt/inheritance.scala @@ -0,0 +1,15 @@ +object inheritance{ + enum SUB[-A, +B]: + case Refl[S]() extends SUB[S, S] + + class A[T](val v: T) { + val foo1: T = v + } + + class C[T](val v1: T) extends A[T](v1) { + def eval1(t: T, e: SUB[T, Int]): Int = + e match { + case SUB.Refl() => foo1 + } + } +} \ No newline at end of file diff --git a/tests/pos/class-gadt/member.scala b/tests/pos/class-gadt/member.scala new file mode 100644 index 000000000000..98156d21d000 --- /dev/null +++ b/tests/pos/class-gadt/member.scala @@ -0,0 +1,11 @@ +object member{ + enum SUB[-A, +B]: + case Refl[S]() extends SUB[S, S] + + class C[T] { + def eval1(t: T, e: SUB[T, Int]): Int = + e match { + case SUB.Refl() => t + 2 + } + } +} \ No newline at end of file diff --git a/tests/pos/class-gadt/nestedClass.scala b/tests/pos/class-gadt/nestedClass.scala new file mode 100644 index 000000000000..e7513cf444fa --- /dev/null +++ b/tests/pos/class-gadt/nestedClass.scala @@ -0,0 +1,59 @@ +// Nested class for GADT constraining +object nestedClass{ + enum Expr[A] { + case IntExpr(value: Int) extends Expr[Int] + case Other[T](value: T) extends Expr[T] + } + + class Outer1[C] { + class Inner1 { + def eval(e: Expr[C]): C = + e match { + case Expr.IntExpr(i) => i + 2 + case Expr.Other(v) => v + } + } + } + + def foo2[C](): Unit = + class Outer2 { + class Inner2 { + def eval(e: Expr[C]): C = + e match { + case Expr.IntExpr(i) => i + 2 + case Expr.Other(v) => v + } + } + } + + class Outer3[C] { + def foo3(): Unit = + class Inner3 { + def eval(e: Expr[C]): C = + e match { + case Expr.IntExpr(i) => i + 2 + case Expr.Other(v) => v + } + } + } + + trait Outer4[C] { + class Inner4 { + def eval(e: Expr[C]): C = + e match { + case Expr.IntExpr(i) => i + 2 + case Expr.Other(v) => v + } + } + } + + class Outer5[C] { + object Inner5 { + def eval(e: Expr[C]): C = + e match { + case Expr.IntExpr(i) => i + 2 + case Expr.Other(v) => v + } + } + } +} \ No newline at end of file diff --git a/tests/pos/class-gadt/variance.scala b/tests/pos/class-gadt/variance.scala new file mode 100644 index 000000000000..0c4dbb7563a0 --- /dev/null +++ b/tests/pos/class-gadt/variance.scala @@ -0,0 +1,77 @@ +object variance { + enum SUB[-A, +B]: + case Refl[S]() extends SUB[S, S] + // Covariant + class C1[+T](v: T){ + def foo(ev: T SUB Int): Int = + ev match { + case SUB.Refl() => v + } + } + + // Contravariant + class B2[-T](v: T){} + + class C2[-T](v: T){ + def foo(ev: Int SUB T): B2[T] = + ev match { + case SUB.Refl() => new B2(v) + } + } + + // Variance with inheritance + + // superclass covariant and subclass covariant + + class A3[+T](v: T) { + val value = v + } + + class C3[+T](v: T) extends A3[T](v){ + def foo(ev: T SUB Int): Int = + ev match { + case SUB.Refl() => value + } + } + + + // superclass covariant and subclass invariant + class A4[+T](v: T) { + val value = v + } + + class C4[T](v: T) extends A4[T](v){ + def foo(ev: T SUB Int): Int = + ev match { + case SUB.Refl() => value + } + } + + // superclass contravariant and subclass contravariant + class B5[-T](v: T){} + + class A5[-T](v: T) { + val value = new B5(v) + } + + class C5[-T](v: T) extends A5[T](v){ + def foo(ev: Int SUB T): B5[T] = + ev match { + case SUB.Refl() => value + } + } + + // superclass contravariant and subclass invariant + class B6[-T](v: T){} + + class A6[-T](v: T) { + val value = new B6(v) + } + + class C6[-T](v: T) extends A6[T](v){ + def foo(ev: Int SUB T): B6[T] = + ev match { + case SUB.Refl() => value + } + } +} \ No newline at end of file diff --git a/tests/pos/i4345.scala b/tests/pos/i4345.scala.ignore similarity index 100% rename from tests/pos/i4345.scala rename to tests/pos/i4345.scala.ignore diff --git a/tests/pos/i5735.scala b/tests/pos/i5735.scala.ignore similarity index 100% rename from tests/pos/i5735.scala rename to tests/pos/i5735.scala.ignore